import pandas as pd 
import numpy as np 
import os 
from concurrent.futures import ThreadPoolExecutor
import json 

# Maps a task identifier to the name of the task group it belongs to.
# The per-language tasks (<group>_<lang>_<split>) are generated rather than
# listed by hand; the irregular entries are spelled out explicitly.
TASK_GROUP = {
    "apps_test": "apps_test",
    "archive_stackexchange_test": "archive_stackexchange_test",
    "archive_stackexchange": "archive_stackexchange_test",
    **{
        "code2doc_{}_{}".format(lang, split): "code2doc"
        for lang in ("go", "java", "javascript", "php", "python", "ruby")
        for split in ("test", "valid")
    },
    # Both concde splits are grouped under the "concde_dev" label.
    "concde_dev": "concde_dev",
    "concde_test": "concde_dev",
    **{
        "{}_{}_{}".format(group, lang, split): group
        for group in ("doc2code", "gdoc2code")
        for lang in ("go", "java", "javascript", "php", "python", "ruby")
        for split in ("test", "valid")
    },
}

def parse_dict_v1(xpath):
    """Parse the metadata encoded in a result file name.

    The base name of ``xpath`` is read as comma-separated ``key=value``
    pairs, after every ".jsonl", ".csv" and ".txt" substring is removed,
    e.g. ``task=code2doc_go_test,r=chatgpt_answer.jsonl``.

    Args:
        xpath: Path (or bare file name) to parse.

    Returns:
        A dict holding
          * ``path``  - ``xpath`` exactly as passed in,
          * every ``key=value`` pair found in the file name,
          * ``split`` - last "_"-separated token of ``task``,
          * ``lang``  - last token of ``task`` once the split is removed, or
            ``None`` when nothing but the name is left,
          * ``name``  - first "_"-separated token of ``task``,
          * ``mt``    - defaults to "baseline" when absent from the name,
          * ``r`` / ``role`` - mirrored into each other when only one is
            present.

    Raises:
        KeyError: If the file name carries no ``task=...`` pair.
    """
    dict_info = {"path": xpath}

    base = os.path.basename(xpath)
    base = base.replace(".jsonl", "").replace(".csv", "").replace(".txt", "")

    for segment in base.split(","):
        fields = segment.split("=")
        if len(fields) < 2:
            # Not a key=value pair (e.g. a free-form prefix in the file
            # name). Skip it instead of failing with an opaque
            # "not enough values to unpack" ValueError.
            continue
        # Anything after a second "=" is ignored.
        dict_info[fields[0]] = fields[1]

    task = dict_info["task"]
    split = task.split("_")[-1]
    dict_info["split"] = split
    task = task.replace("_" + split, "")
    dict_info["lang"] = task.split("_")[-1] if "_" in task else None
    dict_info["name"] = task.split("_")[0]

    if "mt" not in dict_info:
        dict_info["mt"] = "baseline"
    if dict_info["name"] == "apps":
        # The apps task name carries no language token.
        dict_info["lang"] = "python"

    if dict_info["task"] == "archive_stackexchange":
        # This task name has no split suffix, so the generic parsing above
        # would treat "stackexchange" as the split; override explicitly.
        # NOTE(review): "archive_stackexchange_test" does not take this
        # branch and ends up with name="archive", lang="stackexchange" -
        # confirm that is intended.
        dict_info["split"] = "test"
        dict_info["lang"] = None
        dict_info["name"] = dict_info["task"]

    # Keep the short ("r") and long ("role") spellings in sync.
    if "r" not in dict_info and "role" in dict_info:
        dict_info["r"] = dict_info["role"]
    if "r" in dict_info and "role" not in dict_info:
        dict_info["role"] = dict_info["r"]

    return dict_info


if __name__ == "__main__":
    # Script entry point:
    #   1. (cached) join the per-request records with the metadata parsed
    #      from their source path, keep only records whose idx is in the
    #      common set of their task group, and save the table as CSV;
    #   2. reload that CSV, keep the "chatgpt_answer" rows, add binned
    #      level columns and save the result next to it as "*_level.csv".

    # Leave one core free, but never request fewer than one worker:
    # os.cpu_count() may return None and is 1 on single-core hosts, and
    # ThreadPoolExecutor raises ValueError for max_workers <= 0.
    num_workers = max(1, (os.cpu_count() or 1) - 1)

    save_dir = "/home/wj2_cuda12/wj_code/dl_chatgpt/tosem_data/fse2023_dataset/complexity_score_v3"
    os.makedirs(save_dir, exist_ok=True)

    save__split_p = os.path.join(save_dir, "prepare_common_set_with_complexity.csv")

    # The filtered table acts as an on-disk cache: rebuild only if missing.
    if not os.path.isfile(save__split_p):
        # Load the common-set id lists, one stripped line per id.
        COMMON_SET = {}

        common_dir = "/home/wj2_cuda12/wj_code/dl_chatgpt/tosem_data/fse2023_dataset/NL-CCD/common_set"
        for task in ["apps_test", "concde_dev", "doc2code"]:
            c_p = os.path.join(common_dir, "task={}.common_list".format(task))
            with open(c_p) as fr:
                COMMON_SET[task] = {line.strip() for line in fr}
        # gdoc2code reuses the doc2code id set.
        COMMON_SET["gdoc2code"] = COMMON_SET["doc2code"]

        print("finish load _common ")

        searh_dir = "/home/wj2_cuda12/wj_code/dl_chatgpt/tosem_data/fse2023_dataset/complexity_score_v2"
        p = "all_request_extract_rm_languagemodel_baseline_q.jsonl"
        xap = os.path.join(searh_dir, p)

        # One JSON object per non-blank line.
        with open(xap) as f:
            json_lines = [line.strip() for line in f if line.strip()]

        print("total read", len(json_lines))

        def process_file(raw_line):
            """Turn one JSONL line into a metadata row, or None if filtered out.

            The record must carry an "xpath" key (parsed with parse_dict_v1
            and then dropped) and an "idx" key (kept). Raises KeyError when
            the parsed task is not listed in TASK_GROUP.
            """
            x_data = json.loads(raw_line)
            x_path = x_data.pop("xpath")
            idx = x_data["idx"]
            x_meta = parse_dict_v1(x_path)
            task_grp = TASK_GROUP[x_meta["task"]]
            x_meta["task_grp"] = task_grp
            if task_grp not in ["doc2code", "apps_test", "gdoc2code", "concde_dev"]:
                return None

            # NOTE(review): the common sets hold strings read from text
            # files - presumably idx is a string too; confirm.
            if idx not in COMMON_SET[task_grp]:
                return None

            # Record fields take precedence over same-named parsed fields.
            x_meta.update(x_data)
            return x_meta

        with ThreadPoolExecutor(max_workers=num_workers) as ex:
            predictions = list(ex.map(process_file, json_lines))
        print("raw", len(predictions))

        predictions = [row for row in predictions if row is not None]
        print("not none ", len(predictions))

        df = pd.DataFrame(predictions)

        print(df.shape, "df.shape")

        df.to_csv(save__split_p, index=False)

    df = pd.read_csv(save__split_p)
    # .copy() gives an independent frame, so the column assignments below do
    # not write into a filtered slice (avoids SettingWithCopyWarning).
    df = df[df["role"] == "chatgpt_answer"].copy()
    print(df.shape, "df--", df.columns)

    # (new column, source column, bin edges). np.digitize with increasing
    # edges returns i such that edges[i-1] <= x < edges[i]; values below the
    # first edge map to 0 and values >= the last edge map to len(edges).
    level_specs = [
        ("content_complexity_level", "content_complexity", [0, 2, 4, 200]),
        ("content_len_intuite", "content_len", [0, 64, 128, 512]),
        ("content_len_plot", "content_len", [0, 93, 342, 512]),
        ("content_len_knn", "content_len", [0, 72, 197, 412]),
    ]
    for level_col, source_col, bin_edges in level_specs:
        df[level_col] = np.digitize(df[source_col], bins=bin_edges)

    df.to_csv(save__split_p.replace(".csv", "_level.csv"), index=False)
    