import os 
import json 
from glob2 import glob 
from tqdm import tqdm 
import numpy as np 

# Result sets to compare, as (sub-directory of the data root, temp) pairs.
# The second element is substituted into the "*temp=<value>,*.csv" glob used
# to pick that model's result files.
m_list = [
    ("CodeLlama-34b-Instruct-hf_extract_sta", 0.2),
    ("WizardCoder-15B-V1.0_extract_sta", 0.2),
    ("extract_sta", 0.01),
]

def _build_task_group():
    """Build the task-name -> task-group table.

    The per-language tasks are named ``<prefix>_<lang>_<split>``; they are
    generated here instead of being listed one by one.  Insertion order is
    the same as in the hand-written table this replaces.
    """
    langs = ("go", "java", "javascript", "php", "python", "ruby")
    splits = ("test", "valid")

    def per_language(prefix, group):
        # One entry per (language, split) pair, all mapped to the same group.
        return {
            "{}_{}_{}".format(prefix, lang, split): group
            for lang in langs
            for split in splits
        }

    table = {
        "apps_test": "apps_test",
        "archive_stackexchange_test": "archive_stackexchange_test",
        "archive_stackexchange": "archive_stackexchange_test",
    }
    table.update(per_language("code2doc", "code2doc"))
    table.update({"concde_dev": "concde_dev", "concde_test": "concde_dev"})
    table.update(per_language("doc2code", "doc2code"))
    # The gdoc2code tasks are pooled into the doc2code group.
    table.update(per_language("gdoc2code", "doc2code"))
    return table


# Maps every known task name to the group it is pooled into.
TASK_GROUP = _build_task_group()

def parse_dict_v1(xpath):
    """Decode the ``key=value,key=value`` metadata packed into a file name.

    The original ``xpath`` is kept under ``"path"``.  Only the base name is
    parsed: every ``.jsonl`` / ``.csv`` / ``.txt`` substring is removed, the
    rest is split on ``,`` and each piece on ``=`` (first two fields only).

    Derived from the mandatory ``task`` entry:
      * ``split`` - last ``_``-separated token of the task,
      * ``lang``  - last token of what remains once ``_<split>`` is removed,
                    or ``None`` when nothing but a single token remains,
      * ``name``  - first token of that remainder.
    ``mt`` defaults to ``"baseline"``; ``r`` and ``role`` mirror each other
    when only one of them is present.

    Raises:
        ValueError: a comma-separated piece contains no ``=``.
        KeyError: the file name carries no ``task`` entry.
    """
    dict_info = {"path": xpath}

    stem = os.path.basename(xpath)
    for extension in (".jsonl", ".csv", ".txt"):
        stem = stem.replace(extension, "")

    for piece in stem.split(","):
        key, value = piece.split("=")[:2]
        dict_info[key] = value

    task_full = dict_info["task"]
    split_name = task_full.split("_")[-1]
    task_stem = task_full.replace("_" + split_name, "")

    dict_info["split"] = split_name
    dict_info["lang"] = task_stem.rsplit("_", 1)[-1] if "_" in task_stem else None
    dict_info["name"] = task_stem.partition("_")[0]
    dict_info.setdefault("mt", "baseline")

    # The apps tasks carry no language token in their name.
    if dict_info["name"] == "apps":
        dict_info["lang"] = "python"

    # This task name has no split suffix, so the generic derivation above
    # would mistake "stackexchange" for the split; override all three fields.
    if task_full == "archive_stackexchange":
        dict_info.update(split="test", lang=None, name=task_full)

    # Expose the role under both spellings, whichever one the name used.
    has_short = "r" in dict_info
    has_long = "role" in dict_info
    if has_long and not has_short:
        dict_info["r"] = dict_info["role"]
    elif has_short and not has_long:
        dict_info["role"] = dict_info["r"]

    return dict_info


# Not referenced anywhere else in this file; presumably the temp values as
# they are spelled in result file names - TODO confirm before relying on it.
temp_list = ["0.8", "0.2", "0.01"]
if __name__ == "__main__":
    from itertools import product
    import pandas as pd

    # Data root: one sub-directory per entry of m_list is globbed below it.
    root_dir = "/home/wj2_cuda12/wj_code/dl_chatgpt/tosem_data/fse2023_dataset/NL-CCD"
    # Step 1 output: one id list per (model, temp, task group).
    root_tmp_dir_step1 = "/home/wj2_cuda12/wj_code/dl_chatgpt/tosem_data/fse2023_dataset/NL-CCD/common_set_step1"
    # Step 2 output: one list of common ids per task group.
    root_tmp_dir = "/home/wj2_cuda12/wj_code/dl_chatgpt/tosem_data/fse2023_dataset/NL-CCD/common_set"

    os.makedirs(root_tmp_dir_step1, exist_ok=True)
    os.makedirs(root_tmp_dir, exist_ok=True)
    # def process_one  (xt_path ):
    #     xt_path_base = os.path.basename(xt_path )
    #     meta_info = parse_dict_v1( xpath=xt_path_base )
    #     df = pd .read_csv(xt_path )
    #
    #     df_dict = df.to_dict(orient="records")
    #     assert len(df_dict)>0 , xt_path
    #
    #
    #     task_grp  = TASK_GROUP[  meta_info["task"] ]
    #
    #     meta_info.update( df_dict[0] )
    #     meta_info["task_grp"] =  task_grp 
    #
    #     return meta_info 
    #

    def process_one(xt_path):
        """Append the ids of one result CSV to its step-1 list file.

        The list file is ``m=<m>,temp=<temp>,task=<group>.list.txt`` inside
        ``root_tmp_dir_step1``, where ``m`` and ``temp`` come from the CSV's
        file name and ``<group>`` is ``TASK_GROUP[task]``.  Several CSVs map
        to the same list file, which is why it is opened in append mode.

        * For the groups listed below, only rows with
          ``chatgpt_answer_extract_c > 0`` contribute their ``idx``; when no
          row qualifies nothing is written and the file is not created.
        * For every other group no row may have
          ``chatgpt_answer_extract_c > 0`` (asserted) and all ``idx`` values
          are written.

        Args:
            xt_path: path of a result CSV whose base name follows the
                ``key=value,...`` convention understood by ``parse_dict_v1``.

        Returns:
            Always ``None``; the function is used for its file side effect.
        """
        meta_info = parse_dict_v1(xpath=os.path.basename(xt_path))
        task_grp = TASK_GROUP[meta_info["task"]]
        save_x_path = os.path.join(
            root_tmp_dir_step1,
            "m={},temp={},task={}.list.txt".format(meta_info["m"], meta_info["temp"], task_grp),
        )

        df = pd.read_csv(xt_path)
        positive_rows = df[df["chatgpt_answer_extract_c"] > 0]

        if task_grp in ["doc2code", "gdoc2code", "apps_test", "concde_dev"]:
            if len(positive_rows) <= 0:
                return None
            # De-duplicate while keeping first-seen order so the output is
            # deterministic.
            idx_list = list(dict.fromkeys(positive_rows["idx"].to_list()))
        else:
            assert len(positive_rows) == 0, (positive_rows.shape, xt_path)
            idx_list = df["idx"].to_list()

        # Every id is terminated by its own newline.  Joining with "\n" and
        # no trailing newline made consecutive appends to the same list file
        # fuse the last id of one CSV with the first id of the next one.
        # Formatting (rather than joining) also tolerates non-string ids.
        with open(save_x_path, "a") as fw:
            fw.write("".join("{}\n".format(one_idx) for one_idx in idx_list))

        return None
    
    
    def collect_each_model_success_idx():
        """Step 1: feed every result CSV of every ``m_list`` entry to ``process_one``.

        CSVs are looked up as ``<root_dir>/<model_name>/*temp=<temp>,*.csv``.
        An entry whose pattern matches nothing is reported on stdout and
        skipped.
        """
        for model_name, temp in m_list:
            csv_pattern = os.path.join(root_dir, model_name, "*temp={},*.csv".format(temp))
            csv_paths = glob(csv_pattern)
            if not csv_paths:
                print(model_name, temp, "------>", csv_pattern)
                continue

            for csv_path in tqdm(csv_paths):
                process_one(xt_path=csv_path)

    # Step 1 only runs when its output directory is missing or empty, so
    # lists left by an earlier run are reused as they are.
    if not (os.path.isdir(root_tmp_dir_step1) and os.listdir(root_tmp_dir_step1)):
        collect_each_model_success_idx()
    
    def merge_all_model_common_idx():
        """Step 2: intersect the step-1 id lists, one task group at a time.

        For each distinct value of ``TASK_GROUP`` the ``*.txt`` files in
        ``root_tmp_dir_step1`` whose name contains ``task=<group>`` are read
        (one id per line, blank lines ignored) and intersected.  The ids
        present in all of them are written in sorted order, one per line, to
        ``task=<group>.common_list`` in ``root_tmp_dir``.

        For the ``code2doc`` and ``doc2code`` groups the third-from-last
        ``/``-separated component of every common id is tallied and printed
        (presumably the language directory - confirm against the id format).

        Raises:
            AssertionError: a group does not have exactly three list files.
        """
        def get_lang(one_idx):
            return one_idx.split("/")[-3]

        # Sorted so groups are processed in the same order on every run.
        grp_ids = sorted(set(TASK_GROUP.values()))

        for one_grp in tqdm(grp_ids):
            idx_list_p = os.path.join(root_tmp_dir_step1, "*task={}*.txt".format(one_grp))
            idx_list = glob(idx_list_p)
            # NOTE(review): 3 looks like one list per entry of m_list - keep
            # the two in sync if m_list changes.
            assert len(idx_list) == 3, (idx_list_p, idx_list)

            comm_idx = None
            for each_f in idx_list:
                with open(each_f) as f:
                    data = {line.strip() for line in f}
                # A blank line is not an id; keeping "" would also break
                # get_lang below.
                data.discard("")
                comm_idx = data if comm_idx is None else comm_idx.intersection(data)

            # Sorted so the written file is reproducible; plain set iteration
            # order differs from run to run.
            ordered_idx = sorted(comm_idx)

            idx_list_p_save = os.path.join(root_tmp_dir, "task={}.common_list".format(one_grp))
            with open(idx_list_p_save, "w") as fw:
                fw.write("\n".join(ordered_idx))

            if one_grp in ["code2doc", "doc2code"]:
                lang_st = [get_lang(x) for x in ordered_idx]
                print(one_grp, np.unique(lang_st, return_counts=True))

    merge_all_model_common_idx()
    
