from glob2 import glob 
import re 
import os 
import pickle 
import traceback 

import json 

def load_pickle(xpath):
    """Return the object stored in the pickle file at *xpath*.

    Only call this on trusted files: unpickling can execute arbitrary code.
    """
    with open(xpath, "rb") as stream:
        return pickle.load(stream)


def load_wiki(item, additional_dict=None):
    """Build a normalized record for a KILT-wiki "polish" item.

    The human text is the prompt (``item["data"]``) with the fixed polishing
    instruction removed. The instruction string, including its spelling, must
    match the prompt text exactly, so it is not "corrected" here.

    Args:
        item: unpickled item with "data", "answer", "mapping_key" and either
            "id" or "_id" keys.
        additional_dict: unused; accepted so every loader shares one signature.
            Defaults to None instead of a shared mutable ``{}``.

    Returns:
        dict with keys mapping_key, id, human_answer, chatgpt_answer, prompt.
    """
    instruction = "please polish the folllowing words:\n\n"
    human = item["data"].replace(instruction, "")
    chatgpt = item["answer"]
    idx = item["id"] if "id" in item else item["_id"]
    mapping_key = item["mapping_key"]

    return {"mapping_key": mapping_key, "id": idx, "human_answer": human,
            "chatgpt_answer": chatgpt, "prompt": item["data"]}


# Site-name prefixes whose Stack Exchange / Stack Overflow questions are
# treated as non-English by is_in_nonenglish_domain. The list is mostly
# language sites, plus a few topic sites.
_language_domain = [
    "spanish", "russian", "portuguese", "ukrainian", "latin", "korean",
    "japanese", "ja", "rus", "ru", "italian", "german", "french", "chinese",
    "hinduism", "politics", "hermeneutics", "puzzling", "linguistics",
    "judaism", "es",
]
# Fully-qualified host names: all "*.stackexchange.com" entries first, then
# all "*.stackoverflow.com" entries.
language_domain = [
    f"{name}.{network}.com"
    for network in ("stackexchange", "stackoverflow")
    for name in _language_domain
]

def is_in_nonenglish_domain(task_id):
    """Return True if *task_id* comes from a site listed in ``language_domain``.

    The site host is rebuilt from the id by taking the text before
    "stackexchange_" (or "stackoverflow_") and appending "stackexchange.com"
    (or "stackoverflow.com"). Presumably ids look like
    "<site>.stackexchange_<n>"; verify against the data.

    Args:
        task_id: item id string.

    Returns:
        bool. Ids from superuser, askubuntu or mathoverflow are always False.
    """
    # These sites are known English sites; skip the host lookup for them.
    if any(marker in task_id for marker in ("superuser_", "askubuntu_", "mathoverflow_")):
        return False
    if "stackexchange" in task_id:
        site = task_id.split("stackexchange_")[0] + "stackexchange.com"
    else:
        site = task_id.split("stackoverflow_")[0] + "stackoverflow.com"
    return site in language_domain

def _extract_human_answer(raw):
    """Return the text after the last line-initial "A:" marker in *raw*.

    If no marker is found, *raw* is returned unchanged.
    """
    # With re.MULTILINE, ``^`` matches only at the start of a line, so "A:"
    # inside words such as "DATA:" or "JAVA:" is no longer treated as the
    # answer separator. The old pattern r'\n*A:\n*' matched those too and
    # truncated the answer. A genuine line starting with "A:" inside the
    # answer body still splits.
    parts = re.split(r"\n*^A:\n*", raw, flags=re.MULTILINE)
    return parts[-1] if len(parts) > 1 else raw


def load_stackoverflow(item, additional_dict=None):
    """Build a normalized record for a Stack Exchange / Stack Overflow item.

    Args:
        item: unpickled item with "data" (prompt), "answer", "mapping_key",
            "id" or "_id", and optionally "raw" (prompt plus "A:" answer).
        additional_dict: unused; accepted so every loader shares one signature.

    Returns:
        dict with keys mapping_key, id, human_answer, chatgpt_answer, prompt.
        human_answer is None when "raw" is absent. Returns None when
        is_in_nonenglish_domain(id) is true.

    Raises:
        ValueError: if "raw" is present but does not contain the prompt.
    """
    prompt = item["data"]
    if "raw" in item:
        # An explicit check, unlike ``assert``, is not stripped under ``python -O``.
        if prompt not in item["raw"]:
            raise ValueError("item['raw'] does not contain the prompt item['data']")
        human = _extract_human_answer(item["raw"])
    else:
        human = None

    chatgpt = item["answer"]
    idx = item["id"] if "id" in item else item["_id"]
    mapping_key = item["mapping_key"]

    if is_in_nonenglish_domain(task_id=idx):
        return None

    return {"mapping_key": mapping_key, "id": idx, "human_answer": human,
            "chatgpt_answer": chatgpt, "prompt": item["data"]}


def load_summary_code_to_text(item, additional_dict=None):
    """Build a normalized record for a code-to-text summary item.

    The human answer is looked up in *additional_dict* under the key
    "<mapping_key>/<id>". It is None when that key, or its "human_answer"
    field, is missing.

    Args:
        item: unpickled item with "answer", "data", "mapping_key" and either
            "id" or "_id" keys.
        additional_dict: optional mapping from lookup key to a record dict.
            Defaults to None instead of a shared mutable ``{}``.

    Returns:
        dict with keys mapping_key, id, human_answer, chatgpt_answer, prompt.
    """
    if additional_dict is None:
        additional_dict = {}
    chatgpt = item["answer"]
    idx = item["id"] if "id" in item else item["_id"]
    mapping_key = item["mapping_key"]

    human_dict = additional_dict.get(f"{mapping_key}/{idx}", {})
    human = human_dict.get("human_answer", None)

    return {"mapping_key": mapping_key, "id": idx, "human_answer": human,
            "chatgpt_answer": chatgpt, "prompt": item["data"]}

def load_summary_text_to_code(item, additional_dict=None):
    """Build a normalized record for a text-to-code summary item.

    The human answer is looked up in *additional_dict* under the key
    "<mapping_key>/<id>". It is None when that key, or its "human_answer"
    field, is missing.

    Args:
        item: unpickled item with "answer", "data", "mapping_key" and either
            "id" or "_id" keys.
        additional_dict: optional mapping from lookup key to a record dict.
            Defaults to None instead of a shared mutable ``{}``.

    Returns:
        dict with keys mapping_key, id, human_answer, chatgpt_answer, prompt.
    """
    if additional_dict is None:
        additional_dict = {}
    chatgpt = item["answer"]
    idx = item["id"] if "id" in item else item["_id"]

    mapping_key = item["mapping_key"]
    # "..._code_b" items are looked up under "..._code_bb". The renamed key is
    # also the mapping_key returned in the record.
    if mapping_key == "summary_in_text_xiaofei_code_b":
        mapping_key = "summary_in_text_xiaofei_code_bb"

    human_dict = additional_dict.get(f"{mapping_key}/{idx}", {})
    human = human_dict.get("human_answer", None)

    return {"mapping_key": mapping_key, "id": idx, "human_answer": human,
            "chatgpt_answer": chatgpt, "prompt": item["data"]}



def load_apps(item, additional_dict=None):
    """Build a normalized record for a text-to-code (APPS / CodeXGLUE) item.

    The human answer is looked up in *additional_dict* under the key
    "<mapping_key>/<id>". Ids containing "CodeXGLUE/" or "APPS/" are looked up
    under "<mapping_key>_001/<id>" instead. It is None when the key, or its
    "human_answer" field, is missing.

    Args:
        item: unpickled item with "answer", "data", "mapping_key" and either
            "id" or "_id" (a string) keys.
        additional_dict: optional mapping from lookup key to a record dict.
            Defaults to None instead of a shared mutable ``{}``.

    Returns:
        dict with keys mapping_key, id, human_answer, chatgpt_answer, prompt.
        The "_001" suffix affects only the lookup, not the returned mapping_key.
    """
    if additional_dict is None:
        additional_dict = {}
    chatgpt = item["answer"]
    idx = item["id"] if "id" in item else item["_id"]
    mapping_key = item["mapping_key"]

    needs_001 = "CodeXGLUE/" in idx or "APPS/" in idx
    lookup_prefix = f"{mapping_key}_001" if needs_001 else mapping_key
    human_dict = additional_dict.get(f"{lookup_prefix}/{idx}", {})
    human = human_dict.get("human_answer", None)

    return {"mapping_key": mapping_key, "id": idx, "human_answer": human,
            "chatgpt_answer": chatgpt, "prompt": item["data"]}




    

if __name__ == "__main__":
    from concurrent.futures import ThreadPoolExecutor

    import numpy as np
    from tqdm import tqdm

    # Input pickles, one sub-folder per data source (see role_list).
    root_dir = "/mnt/nvme/code_datasets/wj_text_saved_final/save_processed"
    # One "<folder_name>.jsonl" file per role_list entry is written here.
    root_dir_save = "/data3/icse_dataset/NL-CCD_dirs/raw2"
    # Tree of *.jsonl files whose records are keyed by their "sql_id" field.
    cache_dir = "/data3/icse_dataset/wj_build_prompt_data"

    # (sub-folder of root_dir, loader turning one unpickled item into a record)
    role_list = [
        ("archive_stackexchange", load_stackoverflow),
        ("kilt_wiki_random", load_wiki),
        ("summary_in_text_xiaofei_code_bb_merged", load_summary_text_to_code),
        ("summary_in_text_xiaofei_merged", load_summary_code_to_text),
        ("text-code", load_apps),
        ("text-code-001", load_apps),
    ]

    def load_cache_dict(cache_root):
        """Return {sql_id: record} merged from every *.jsonl file under *cache_root*.

        On duplicate ids, later files overwrite earlier ones (dict.update).
        """
        cache = {}
        file_list = glob(os.path.join(cache_root, "**", "*.jsonl"))
        print("there are total ", len(file_list))
        for path in tqdm(file_list):
            # Explicit UTF-8 so results don't depend on the locale; skip blank
            # lines, which json.loads would reject.
            with open(path, encoding="utf-8") as fh:
                records = [json.loads(line) for line in fh if line.strip()]
            cache.update({rec["sql_id"]: rec for rec in records})
        print("cache keys", len(cache))
        return cache

    def list_input_files(folder):
        """Return the sorted, de-duplicated regular files below root_dir/folder."""
        patterns = (
            os.path.join(root_dir, folder, "**", "*"),
            os.path.join(root_dir, folder, "*"),
        )
        print("search_path", patterns)
        # glob2's "**" may also match zero directories, so the two patterns can
        # overlap. A set keeps a file from being processed and written twice.
        found = set()
        for pattern in patterns:
            found.update(glob(pattern))
        return sorted(p for p in found if os.path.isfile(p))

    print("start load cache ")
    additional_dict = load_cache_dict(cache_dir)
    print("done! load cache ")

    # os.cpu_count() may return None, and cpu_count() - 1 is 0 on a single-core
    # host, which ThreadPoolExecutor rejects. Always keep at least one worker.
    num_workers = max(1, (os.cpu_count() or 2) - 1)
    os.makedirs(root_dir_save, exist_ok=True)

    for folder_name, load_func in role_list:
        vides = list_input_files(folder_name)
        print("total", len(vides))

        def process_file(path, load_func=load_func):
            """Unpickle *path*, convert it with *load_func*, and return None on failure."""
            try:
                with open(path, "rb") as fff:
                    # pickle.load runs arbitrary code; only trusted local dataset files.
                    data = pickle.load(fff)
                return load_func(item=data, additional_dict=additional_dict)
            except Exception:
                # Best effort: log and skip one bad file rather than abort the run.
                traceback.print_exc()
                return None

        with ThreadPoolExecutor(max_workers=num_workers) as ex:
            predictions_all = list(ex.map(process_file, vides))
        total_raw = len(predictions_all)

        predictions = [x for x in predictions_all if x is not None]
        print("total raw", total_raw, "after ", len(predictions))
        new_save = os.path.join(root_dir_save, folder_name + ".jsonl")
        print("path saved into ", new_save)

        with open(new_save, "w") as f:
            f.write("\n".join(json.dumps(x) for x in predictions))

        # Count items that were dropped or have no human answer.
        total_none = [x is None or x["human_answer"] is None for x in predictions_all]
        print(folder_name, "--->", np.unique(total_none, return_counts=True), "--->total_none")