import json 
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm 
import os 
import sys 
# Thread-pool size: leave one CPU free, but never go below one worker.
# os.cpu_count() may return None, and on a single-CPU host "count - 1" is 0,
# which ThreadPoolExecutor rejects.
num_workers = max(1, (os.cpu_count() or 1) - 1)
import pandas as pd 
import numpy as np 
import random 
import traceback 

from itertools import product 

def metric_cal(y_true, y_prob, y_prob_label=None):
    """Compute ROC AUC, false-positive rate and false-negative rate.

    Args:
        y_true: Ground-truth binary labels (0 or 1), one per sample.
        y_prob: Scores for the positive class (label 1), same length as
            ``y_true``. Used for the ROC curve.
        y_prob_label: Optional hard 0/1 predictions used for the confusion
            matrix. When omitted they are derived from ``y_prob``: floating
            point scores are thresholded at ``> 0.5``; any other dtype is
            taken to already hold labels.

    Returns:
        dict with keys ``"auc"`` (area under the ROC curve), ``"fpr"``
        (FP / (FP + TN)) and ``"fnr"`` (FN / (FN + TP)). A rate whose
        denominator is zero is reported as ``nan``.

    Raises:
        AssertionError: if ``y_true`` and ``y_prob`` differ in length.
    """
    assert len(y_true) == len(y_prob), (len(y_true), len(y_prob))

    # Imported lazily so that importing this module does not require
    # scikit-learn.
    from sklearn import metrics

    if y_prob_label is None:
        scores = np.asarray(y_prob)
        if np.issubdtype(scores.dtype, np.floating):
            y_prob_label = scores > 0.5
        else:
            y_prob_label = scores

    true_labels = np.asarray(y_true).astype(int)
    pred_labels = np.asarray(y_prob_label).astype(int)

    # labels=[0, 1] forces a 2x2 matrix even when only one class occurs;
    # without it the matrix can be 1x1 and the 4-way unpacking fails.
    matrix = metrics.confusion_matrix(
        y_true=true_labels,
        y_pred=pred_labels,
        labels=[0, 1],
    )
    tn, fp, fn, tp = (int(cell) for cell in matrix.ravel())

    def rate(numerator, denominator):
        # An undefined rate (no samples of that class) is nan, not an error.
        return numerator / float(denominator) if denominator else float("nan")

    roc_fpr, roc_tpr, _thresholds = metrics.roc_curve(y_true, y_prob, pos_label=1)

    return {
        "auc": metrics.auc(roc_fpr, roc_tpr),
        "fpr": rate(fp, fp + tn),
        "fnr": rate(fn, fn + tp),
    }

# Model names to evaluate; each is compared against a record's "m" field.
MODEL_LIST = [
    "CodeLlama-34b-Instruct-hf",
    "WizardCoder-15B-V1.0",
    "WizardCoder-Python-34B-V1.0",
    "gpt-3.5-turbo",
]


# Expected record count per dataset file, kept as a list of one-entry dicts
# ({file name: count}). The counts are later summed per task group and used as
# the expected number of samples of each class in that group.
full_match_list = [
    {file_name: expected_count}
    for file_name, expected_count in (
        ("concde_dev.jsonl", 1999),
        # gdoc2code
        ("gdoc2code_java_test.jsonl", 10949),
        ("gdoc2code_python_valid.jsonl", 13906),
        ("gdoc2code_javascript_valid.jsonl", 3879),
        ("gdoc2code_javascript_test.jsonl", 3287),
        ("gdoc2code_php_test.jsonl", 14008),
        ("gdoc2code_php_valid.jsonl", 12976),
        ("gdoc2code_python_test.jsonl", 14912),
        ("gdoc2code_ruby_valid.jsonl", 1397),
        ("gdoc2code_java_valid.jsonl", 5178),
        ("gdoc2code_ruby_test.jsonl", 1258),
        ("gdoc2code_go_test.jsonl", 8116),
        ("gdoc2code_go_valid.jsonl", 7320),
        # doc2code
        ("doc2code_java_test.jsonl", 10949),
        ("doc2code_python_valid.jsonl", 13906),
        ("doc2code_javascript_valid.jsonl", 3879),
        ("doc2code_javascript_test.jsonl", 3287),
        ("doc2code_php_test.jsonl", 14008),
        ("doc2code_php_valid.jsonl", 12976),
        ("doc2code_python_test.jsonl", 14912),
        ("doc2code_ruby_valid.jsonl", 1397),
        ("doc2code_java_valid.jsonl", 5178),
        ("doc2code_ruby_test.jsonl", 1258),
        ("doc2code_go_test.jsonl", 8116),
        ("doc2code_go_valid.jsonl", 7320),
        # stand-alone datasets
        ("apps_test.jsonl", 4999),
        ("archive_stackexchange_test.jsonl", 2195),
        ("concde_test.jsonl", 1999),
        # code2doc
        ("code2doc_go_valid.jsonl", 7324),
        ("code2doc_python_valid.jsonl", 13913),
        ("code2doc_php_test.jsonl", 14013),
        ("code2doc_ruby_valid.jsonl", 1399),
        ("code2doc_java_valid.jsonl", 5182),
        ("code2doc_java_test.jsonl", 10954),
        ("code2doc_ruby_test.jsonl", 1260),
        ("code2doc_python_test.jsonl", 14917),
        ("code2doc_javascript_test.jsonl", 3290),
        ("code2doc_go_test.jsonl", 8121),
        ("code2doc_javascript_valid.jsonl", 3884),
        ("code2doc_php_valid.jsonl", 12981),
    )
]

# Languages and splits shared by the code2doc / doc2code / gdoc2code families.
_TASK_LANGUAGES = ("go", "java", "javascript", "php", "python", "ruby")
_TASK_SPLITS = ("test", "valid")


def _family_tasks(family):
    """Return {"<family>_<language>_<split>": family} for every language/split."""
    return {
        "{}_{}_{}".format(family, language, split): family
        for language in _TASK_LANGUAGES
        for split in _TASK_SPLITS
    }


# Maps a task name (a dataset file name without ".jsonl") to the group under
# which its records are pooled for evaluation.
TASK_GROUP = {
    "apps_test": "apps_test",
    "archive_stackexchange_test": "archive_stackexchange_test",
    **_family_tasks("code2doc"),
    # Both concde splits are pooled under the single "concde_dev" group.
    "concde_dev": "concde_dev",
    "concde_test": "concde_dev",
    **_family_tasks("doc2code"),
    **_family_tasks("gdoc2code"),
}

# Distinct group names found in TASK_GROUP.
TASK_GROUP_VALUE = set(TASK_GROUP.values())

# Flatten the one-entry dicts into a single {task name: expected count}
# mapping, dropping the ".jsonl" suffix so the keys line up with TASK_GROUP.
FULL_MATCH_META = {
    file_name.replace(".jsonl", ""): expected_count
    for entry in full_match_list
    for file_name, expected_count in entry.items()
}

# Expected count per group: the sum over every task mapped to that group.
FULL_MATCH_META_group = {}
for task_name, expected_count in FULL_MATCH_META.items():
    group_name = TASK_GROUP[task_name]
    FULL_MATCH_META_group[group_name] = (
        FULL_MATCH_META_group.get(group_name, 0) + expected_count
    )
    



# Directory that receives one summary CSV per input file.
DEFAULT_SAVE_DIR = "/home/wj2_cuda12/wj_code/dl_chatgpt/tosem_data/save_dir2/each_q"

# Values of the records' "temp" field to evaluate; compared as strings.
TEMPERATURES = ["0.2", "0.8", "0.01"]


def parse_record(index, line):
    """Parse one JSONL line into a prediction record.

    Reads the fields "label", "task", "r" (or "role"), "q", "temp", "m",
    "id" and "prob", and adds:
      * "task2": the record's task group from TASK_GROUP (None if unknown),
      * "gt":    1 when the role is "chatgpt_answer", otherwise 0,
      * "uniq":  a key joining role, q, temp, task, m and id with "@",
    and clamps a negative "prob" to 0.

    Args:
        index: Position of the line in the input; only used for progress output.
        line: Raw text of one line.

    Returns:
        The enriched dict, or None when the line is not valid JSON, the label
        is not 0/1, or the role is None.
    """
    if index % 50000 == 0:
        print(index)
    try:
        record = json.loads(line)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        print(line, "-->")
        return None

    if record["label"] not in [0, 1]:
        return None

    record["task2"] = TASK_GROUP.get(record["task"], None)
    role = record["r"] if "r" in record else record["role"]
    if role is None:
        return None
    record["gt"] = int(role == "chatgpt_answer")
    record["uniq"] = "@".join(
        [role, record["q"], record["temp"], record["task"], record["m"], record["id"]]
    )
    record["prob"] = record["prob"] if record["prob"] >= 0 else 0
    return record


def drop_duplicate_records(records):
    """Return ``records`` keeping only the first record for each "uniq" key."""
    seen = set()
    unique = []
    for record in records:
        key = record["uniq"]
        if key not in seen:
            seen.add(key)
            unique.append(record)
    return unique


def pad_with_missing(y_true, y_prob, y_pred, expected_per_class):
    """Append one wrong prediction for every sample missing from a class.

    Each class (0 and 1) is expected to have ``expected_per_class`` samples.
    For every sample a class is short of that number, a synthetic entry is
    added whose true label is that class and whose predicted label and score
    are the opposite one.

    Args:
        y_true: Observed ground-truth labels (0/1).
        y_prob: Observed scores.
        y_pred: Observed predicted labels.
        expected_per_class: Expected number of samples for each class.

    Returns:
        ``(y_true, y_prob, y_pred, n_added)``: the padded arrays and the number
        of synthetic entries appended.
    """
    y_true = np.asarray(y_true)
    n_negative = int(np.count_nonzero(y_true == 0))
    n_positive = int(np.count_nonzero(y_true == 1))

    # dtype=int matters when nothing is missing: np.array([]) is float64 and
    # would turn the labels into 0.0/1.0 (and the CSV columns into "gt_0.0").
    missing_true = np.array(
        [0] * max(expected_per_class - n_negative, 0)
        + [1] * max(expected_per_class - n_positive, 0),
        dtype=int,
    )
    missing_wrong = 1 - missing_true

    return (
        np.concatenate([y_true, missing_true]),
        np.concatenate([np.asarray(y_prob), missing_wrong]),
        np.concatenate([np.asarray(y_pred), missing_wrong]),
        len(missing_true),
    )


def evaluate_group(records, task, model_name, temp, source_path):
    """Compute the metric row for one (model, temperature, task group).

    Args:
        records: Non-empty list of parsed records belonging to the group.
        task: Task-group name; must be a key of FULL_MATCH_META_group.
        model_name: Model the records belong to.
        temp: Temperature (string) the records belong to.
        source_path: Input file path, only used in error messages.

    Returns:
        dict with the metric_cal results plus "task", "temp", "meta",
        "model_name" and the per-class counts "gt_<label>" after padding.

    Raises:
        ValueError: if a probability is null or outside [0, 1].
        KeyError: if ``task`` has no expected size.
    """
    raw_len = len(records)
    records = drop_duplicate_records(records)
    if raw_len > len(records):
        print("duplicated..", raw_len - len(records), raw_len)

    y_prob = np.array([record["prob"] for record in records])
    y_pred = [record["label"] for record in records]
    y_true = np.array([record["gt"] for record in records])

    n_null = int(pd.isnull(y_prob).sum())
    if n_null:
        raise ValueError(
            "{} of {} probabilities are null (task={}, model={}, file={})".format(
                n_null, len(y_prob), task, model_name, source_path
            )
        )
    if y_prob.max() > 1 or y_prob.min() < 0:
        raise ValueError(
            "probabilities outside [0, 1]: min={}, max={} (task={}, model={}, file={})".format(
                y_prob.min(), y_prob.max(), task, model_name, source_path
            )
        )

    if task not in FULL_MATCH_META_group:
        raise KeyError(task)
    expected_per_class = FULL_MATCH_META_group[task]

    y_true_all, y_prob_all, y_pred_all, n_added = pad_with_missing(
        y_true, y_prob, y_pred, expected_per_class
    )
    add_meta = {
        "y_true": len(y_true),
        "y_true_add": n_added,
        "y_pred": len(y_pred),
        "y_pred_add": n_added,
        "y_prob": len(y_prob),
        "y_prob_add": n_added,
    }
    print(task, "-->one_task", add_meta, "add_meta")

    eval_result = metric_cal(y_true=y_true_all, y_prob=y_prob_all, y_prob_label=y_pred_all)

    # Class counts are reported after padding.
    classes, counts = np.unique(y_true_all, return_counts=True)
    eval_result["task"] = task
    eval_result["temp"] = temp
    eval_result["meta"] = add_meta
    eval_result["model_name"] = model_name
    eval_result.update({"gt_" + str(cls): cnt for cls, cnt in zip(classes, counts)})
    return eval_result


def main(argv=None, save_dir=DEFAULT_SAVE_DIR):
    """Evaluate every model/temperature/task group found in a JSONL file.

    The input path is the last command-line argument. One CSV named after the
    input file (".jsonl" replaced by ".csv") is written into ``save_dir``.

    Args:
        argv: Argument vector; defaults to ``sys.argv``.
        save_dir: Output directory; created if it does not exist.

    Returns:
        Process exit code: 0 on success, 2 when no input path was given.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        # Without this check argv[-1] would be the script itself.
        print("usage: {} <predictions.jsonl>".format(argv[0] if argv else "script"), file=sys.stderr)
        return 2
    xp = argv[-1]

    # Fail before the expensive work if the output location is unusable.
    os.makedirs(save_dir, exist_ok=True)

    print("read file")
    with open(xp, encoding="utf-8") as f:
        lines = [line for line in f if line.strip()]
    print("total ,", len(lines))

    # NOTE(review): parsing is CPU-bound, so the thread pool mostly preserves
    # the original structure rather than adding parallelism.
    with ThreadPoolExecutor(max_workers=max(1, num_workers)) as ex:
        parsed = list(ex.map(parse_record, range(len(lines)), lines))

    print("predictions.raw", len(parsed))
    predictions = [record for record in parsed if record is not None]
    print("predictions.not none ", len(predictions))

    # Bucket once so each (model, temp, task) lookup does not rescan everything.
    grouped = {}
    for record in predictions:
        key = (record["m"], record["temp"], record["task2"])
        grouped.setdefault(key, []).append(record)

    final_csv = []
    for model_name, temp in tqdm(list(product(MODEL_LIST, TEMPERATURES))):
        for task in TASK_GROUP_VALUE:
            records = grouped.get((model_name, temp, task))
            if not records:
                continue
            try:
                final_csv.append(evaluate_group(records, task, model_name, temp, xp))
            except Exception:
                # Best effort: report the failing group and keep going.
                print(model_name, temp, task)
                traceback.print_exc()

    out_name = os.path.basename(xp).replace(".jsonl", ".csv")
    pd.DataFrame(final_csv).to_csv(os.path.join(save_dir, out_name), index=False)
    return 0


if __name__ == "__main__":
    sys.exit(main())
    
