import os
import json 


from itertools import product

def build_sql_id(task="code2doc_go_test", raw_task_id="code2doc_go/test.jsonl:0"):
    """Translate a raw sample id into the ``sql_id`` used by the answer cache.

    Each known ``task`` maps to a rule ``(old, new, suffix)``: every
    occurrence of ``old`` in ``raw_task_id`` is replaced by ``new`` and
    ``suffix`` is appended. A task whose rule is ``None`` returns
    ``raw_task_id`` unchanged.

    Args:
        task: Task key, e.g. ``"code2doc_go_test"`` or ``"apps_train"``.
        raw_task_id: Sample id as found in the generated file.

    Returns:
        The rewritten id string.

    Raises:
        KeyError: If ``task`` is not one of the known task keys.
    """
    # Root directory of the summary data, per dataset direction.
    summary_roots = {
        "code2doc": "summary_in_text_xiaofei",
        "doc2code": "summary_in_text_xiaofei_code_bb",
    }
    languages = ("go", "java", "javascript", "php", "python", "ruby")
    splits = ("test", "train", "valid")

    # One rule per (dataset, language, split) combination.
    rules = {
        f"{kind}_{lang}_{part}": (
            f"{kind}_{lang}/{part}.jsonl:",
            f"{root}/summary_in_text_train/{lang}/{part}/",
            ".input",
        )
        for (kind, root), lang, part in product(
            summary_roots.items(), languages, splits
        )
    }

    # Rules for the remaining benchmarks. The "concde_*" spelling is the key
    # callers actually pass, so it is kept as is.
    rules["apps_test"] = ("apps_test_", "text-code/APPS_001/test/", ".input")
    rules["apps_train"] = ("apps_train_", "text-code/APPS_001/train/", ".input")
    rules["archive_stackexchange"] = None
    for part in ("dev", "test", "train"):
        rules[f"concde_{part}"] = (
            "concode_",
            f"text-code/CodeXGLUE_001/{part}/",
            ".input",
        )

    rule = rules[task]
    if rule is None:
        # No rewriting defined for this task: the raw id is the sql_id.
        return raw_task_id

    old_prefix, new_prefix, suffix = rule
    return raw_task_id.replace(old_prefix, new_prefix) + suffix





def format_item(item, additional_dict=None):
    """Pair one generated sample with its cached human answer.

    Args:
        item: Record read from a generated ``.jsonl`` file. The code reads
            ``item["data"]["prompt"]``, ``item["data"]["sampled"]``,
            ``item["sample_id"]`` and ``item["task"]``; the latter is a
            comma-separated list of ``key=value`` pairs that must contain a
            ``task`` key.
        additional_dict: Mapping from ``sql_id`` to a cached record that
            carries at least ``"sql_id"`` and, optionally,
            ``"human_answer"``. ``None`` is treated as an empty mapping.

    Returns:
        A dict with keys ``sql_id``, ``human_answer``, ``chatgpt_answer`` and
        ``prompt``, or ``None`` when the computed ``sql_id`` is not present
        in ``additional_dict``.

    Raises:
        KeyError: If a required key is missing from ``item`` or the task is
            unknown to ``build_sql_id``.
        ValueError: If a ``key=value`` pair in ``item["task"]`` has no ``=``.
        AssertionError: If the cached record's own ``sql_id`` disagrees with
            the key it was stored under.
    """
    if additional_dict is None:
        additional_dict = {}

    prompt = item["data"]["prompt"]
    chatgpt = item["data"]["sampled"]
    idx = item["sample_id"]
    task_str = item["task"]

    # Split on the first "=" only so values may themselves contain "=".
    task_dict = dict(pair.split("=", 1) for pair in task_str.split(","))
    task = task_dict["task"]

    sql_id = build_sql_id(task=task, raw_task_id=idx)

    if sql_id not in additional_dict:
        return None
    human_item = additional_dict[sql_id]
    human = human_item.get("human_answer", None)

    # Sanity check on the cache itself: the record must be filed under its
    # own sql_id.
    assert human_item["sql_id"] == sql_id, (sql_id, "!=", human_item["sql_id"], "---->", "task", task, idx, list(human_item))

    return {"sql_id": sql_id, "human_answer": human, "chatgpt_answer": chatgpt, "prompt": prompt}


def load_cache_map_sqlid(cache_file="/tmp/mapping_sqlid.pkl"):
    """Load the ``sql_id -> record`` mapping, building and caching it if needed.

    When ``cache_file`` exists and is non-empty it is unpickled and returned.
    Otherwise every ``*.jsonl`` file found under the hard-coded data
    directory is parsed, each record is indexed by its ``"sql_id"`` field
    (later records overwrite earlier ones with the same id), and the result
    is pickled to ``cache_file`` for the next run.

    Args:
        cache_file: Path of the pickle cache.

    Returns:
        dict mapping ``sql_id`` to the full record.

    Raises:
        KeyError: If a parsed record has no ``"sql_id"`` field.
    """
    # Imported here so the function also works when this module is imported
    # rather than executed as a script.
    import pickle

    # NOTE(security): unpickling executes arbitrary code; the default cache
    # path lives in /tmp, so only use this where that directory is trusted.
    #
    # An empty file is treated as a miss: an interrupted or failed dump
    # leaves a zero-byte file behind, which would otherwise make every later
    # run fail with EOFError.
    if os.path.isfile(cache_file) and os.path.getsize(cache_file) > 0:
        with open(cache_file, "rb") as ff:
            additional_dict = pickle.load(ff)
    else:
        from glob2 import glob
        from tqdm import tqdm

        print ("start load cache ")
        additional_dict = {}
        p = "/data3/icse_dataset/wj_build_prompt_data"
        fl = glob(os.path.join(p, "**", "*.jsonl"))
        print ("there are total ", len(fl))
        for ffll in tqdm(fl):
            with open(ffll) as fff:
                for line in fff:
                    if not line.strip():
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue
                    record = json.loads(line)
                    additional_dict[record["sql_id"]] = record

        with open(cache_file, "wb") as ff:
            # pickle.dump takes the object first and the file second.
            pickle.dump(additional_dict, ff)

    print ("cache keys", len(additional_dict))
    print ("done! load cache ")
    return additional_dict
        
        

if __name__=="__main__":
    # Script entry point: for every generated .jsonl file, attach the cached
    # human answer to each sample and write the fully-populated samples to
    # the output directory.
    #
    # These imports are kept at script scope (none removed): functions
    # defined above may look some of these names up as module globals.
    from concurrent.futures import ThreadPoolExecutor
    import pickle 
    from glob2 import glob 
    from tqdm import tqdm 
    from functools import partial
    import random 
    import numpy as np 


    additional_dict = load_cache_map_sqlid()


    root_dir_save = "/data3/fse2023_dataset/NL-CCD/raw"

    search_root_dir = "/data3/fse2023_dataset/llm_generated_all/gpt-3.5-turbo"
    generated_filelist = glob( os.path.join(search_root_dir, "*.jsonl")) 
    print ("total there are ", len(generated_filelist ) )

    # Leave one core free, but never drop below one worker:
    # ThreadPoolExecutor rejects max_workers <= 0, and os.cpu_count() may
    # return None when the count cannot be determined.
    num_workers = max(1, (os.cpu_count() or 1) - 1)

    def process_line(raw_line):
        """Parse one JSON line and pair it with its human answer (or None)."""
        return format_item(item=json.loads(raw_line), additional_dict=additional_dict)

    for generated_one_file in generated_filelist:
        basename = os.path.basename(generated_one_file)
        print ("proceess.", basename )
        basename = basename.replace(".jsonl",",formated=true.jsonl")
        new_save = os.path.join(root_dir_save, basename )
        print ("path saved into ", new_save )

        # Outputs are never overwritten; an existing file means this input
        # was already handled.
        if os.path.isfile(new_save):
            print ("skip the file", new_save )
            continue 

        with open(generated_one_file) as f :
            vides = f.readlines()

        with ThreadPoolExecutor(max_workers=num_workers) as ex:
            # ex.map preserves input order; list() surfaces worker exceptions.
            predictions_all = list(ex.map(process_line, vides))

        total_raw = len(predictions_all) 

        # Drop samples whose sql_id was not found in the cache.
        predictions_all = [x for x in predictions_all if x is not None ]

        total_raw_not_none = len(predictions_all) 
        human_none_sta = [x["human_answer"] is None for x in predictions_all]
        chatpgt_none_sta = [x["chatgpt_answer"] is None for x in predictions_all]

        print ("total raw", total_raw, 
               "after none filter", total_raw_not_none, 
               "after human", np.unique(human_none_sta,return_counts=True),
               "after chatgpt", np.unique(chatpgt_none_sta,return_counts=True),
                )

        # Keep only samples that have both a human and a generated answer.
        predictions = [
            x for x in predictions_all
            if x["chatgpt_answer"] is not None and x["human_answer"] is not None
        ]

        with open(new_save,"w" ) as f :
            f.write( "\n".join( json.dumps(x) for x in predictions ) )
        # #

        # total_none = [x is None or x["human_answer"]  is None for x in predictions_all  ]
        # print (folder_name, "--->", np.unique(total_none,return_counts=True) ,"--->total_none"  )
        
    
    
#
#
# test_case=[
#     ("code2doc_go_test", "code2doc_go/test.jsonl:0","summary_in_text_xiaofei/summary_in_text_train/go/test/0.input"),
#     ("code2doc_java_test", "code2doc_java/test.jsonl:3","summary_in_text_xiaofei/summary_in_text_train/java/test/3.input"),
#     ("code2doc_javascript_test", "code2doc_javascript/test.jsonl:7","summary_in_text_xiaofei/summary_in_text_train/javascript/test/7.input"),
#     ("code2doc_python_test", "code2doc_python/test.jsonl:9","summary_in_text_xiaofei/summary_in_text_train/python/test/9.input"),
#
#     ("concde_dev", "concode_2","text-code/CodeXGLUE_001/dev/2.input"),
#
#     # ("doc2code_php_train", "humaneval_001/default/HumanEval/8.input","text-code/humaneval_001/default/HumanEval/8.input"),
#
#     ("doc2code_python_train", "doc2code_python/train.jsonl:12","summary_in_text_xiaofei_code_bb/summary_in_text_train/python/train/12.input"),
#     ("doc2code_ruby_train", "doc2code_ruby/train.jsonl:13","summary_in_text_xiaofei_code_bb/summary_in_text_train/ruby/train/13.input"),
#     ]
#
# for x,y,z in test_case:
#     z_hat= build_sql_id(task=x, raw_task_id=y )
#     print ("run...1")
#     assert z==z_hat , (z,z_hat,x,y)


    