import json 
from glob2 import glob 
import os 

import numpy as np 

import itertools 

def metric_cal(y_true, y_prob ):
    """Compute AUC, FPR, FNR and EER for a binary detector.

    Args:
        y_true: array-like of ground-truth labels.  ``1`` is the positive
            class (the ROC curve is computed with ``pos_label=1``) and ``0``
            the negative class.
        y_prob: array-like of predictions.  Floating-point values are treated
            as scores and thresholded at 0.5 to obtain hard labels; any other
            dtype is taken to already hold hard labels.

    Returns:
        dict with the keys ``"auc"``, ``"fpr"``, ``"fnr"`` and ``"eer"``.
        A metric that is undefined for the given data (FPR without negative
        samples, FNR without positive samples, AUC/EER when ``y_true`` does
        not contain both classes) is reported as NaN.
    """
    from sklearn import metrics
    from scipy.optimize import brentq
    from scipy.interpolate import interp1d

    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)

    if np.issubdtype(y_prob.dtype, np.floating):
        y_pred_label = (y_prob > 0.5).astype(int)
    else:
        y_pred_label = y_prob.astype(int)

    # Pin the label order so the matrix is always 2x2.  Without ``labels``
    # sklearn sizes the matrix from the labels that actually occur, and the
    # four-way unpack below fails when only one class is present.
    TN, FP, FN, TP = metrics.confusion_matrix(
        y_true=y_true,
        y_pred=y_pred_label,
        labels=[0, 1],
    ).ravel()

    def _rate(count, total):
        # A rate over an empty class is undefined: report NaN instead of
        # dividing by zero.
        return count / float(total) if total else float("nan")

    FPR = _rate(FP, FP + TN)
    FNR = _rate(FN, FN + TP)

    positives = np.count_nonzero(y_true == 1)
    negatives = y_true.size - positives
    if positives == 0 or negatives == 0:
        # The ROC curve needs both classes; AUC and EER are undefined.
        auc = eer = float("nan")
    else:
        _fpr, _tpr, _ = metrics.roc_curve(y_true, y_prob, pos_label=1)
        # EER is the point on the ROC curve where FPR == FNR == 1 - TPR,
        # i.e. the root of 1 - x - TPR(x) with x = FPR, found on the
        # linearly interpolated curve.
        eer = brentq(lambda x: 1. - x - interp1d(_fpr, _tpr)(x), 0., 1.)
        auc = metrics.auc(_fpr, _tpr)

    return {
        "auc": auc,
        "fpr": FPR,
        "fnr": FNR,
        "eer": eer,
    }


def test_metric ():
    """Smoke-test ``metric_cal`` on hand-written binary predictions.

    Every case shares the same ground truth (9 positives, 16 negatives).
    The first case is checked against hand-computed TPR/FPR values; the
    remaining cases only print the resulting metrics for manual inspection.
    """
    ground_truth = np.array(
        [1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0]
    )

    # 2 of the 9 positives are detected and 1 of the 16 negatives is flagged.
    checked_prediction = np.array(
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    )
    checked = metric_cal(y_true=ground_truth, y_prob=checked_prediction)
    assert np.isclose(1 - checked["fnr"], 0.2222222222)
    assert np.isclose(checked["fpr"], 0.0625)

    printed_predictions = (
        [1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
        [1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0],
        [1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0],
    )
    for prediction in printed_predictions:
        outcome = metric_cal(y_true=ground_truth, y_prob=np.array(prediction))
        print (outcome)


def _prediction_record(item, label, prob):
    """Build the normalised record shared by every detector loader."""
    return {
        "pred_label": label,
        "pred_prob": prob,
        "idx": item["id"],
        "is_content_null": item["content"] is None or len(item["content"]) <= 0,
    }


def _load_gpt2_detector(item):
    """Loader for GPT2_Detector_out: fold the probability towards the label."""
    label = item["result"]["result"]
    if label == -1:
        return None
    prob = item["result"]["probability"]
    # Whichever side of 0.5 the stored probability is on, report it on the
    # side that agrees with the predicted label.
    prob = max(prob, 1 - prob) if label > 0 else min(prob, 1 - prob)
    if label == 0:
        assert prob <= 0.5, item
    return _prediction_record(item, label, prob)


def _load_label_only(item):
    """Loader that ignores any probability and reuses the label as the score."""
    label = item["result"]["result"]
    if label == -1 or label == 2:
        return None
    return _prediction_record(item, label, label)


def _load_label_and_probability(item):
    """Loader that passes the stored label and probability through unchanged."""
    label = item["result"]["result"]
    prob = item["result"]["probability"]
    if label == -1 or label == 2:
        return None
    return _prediction_record(item, label, prob)


def _load_optional_result(item):
    """Like ``_load_label_and_probability`` but tolerates a missing result."""
    if item["result"] is None:
        return None
    return _load_label_and_probability(item)


# Detector output name -> loader.  Names that are not listed (for example the
# currently disabled "ai_detector__compilatio__net_out" and "sapling_ai_out")
# are not loaded and make ``load_func`` return None.
# NOTE(review): both Hello-SimpleAI outputs go through the label-only loader,
# so a "probability" stored in those files is never read -- confirm this is
# intended.
_DETECTOR_LOADERS = {
    "GPT2_Detector_out": _load_gpt2_detector,
    "GPTZero_out": _load_label_only,
    "Hello-SimpleAI-qa_out": _load_label_only,
    "Hello-SimpleAI_out": _load_label_only,
    "scribbr_com_out": _load_label_and_probability,
    "gptzero_me_out": _load_optional_result,
    "writer_com_out": _load_optional_result,
}


def load_func( item , q_name ):
    """Convert one raw detector output record into a normalised prediction.

    Args:
        item: dict decoded from one JSONL line.  The loaders read
            ``item["result"]["result"]`` (the predicted label), for most
            detectors ``item["result"]["probability"]``, plus ``item["id"]``
            and ``item["content"]``.
        q_name: detector output name used to select the loader.

    Returns:
        A dict with the keys ``"pred_label"``, ``"pred_prob"``, ``"idx"`` and
        ``"is_content_null"``, or ``None`` when the record is skipped: label
        -1 (and, except for GPT2_Detector_out, label 2), a ``None`` result for
        gptzero_me_out / writer_com_out, or a ``q_name`` without a loader.
        The skipped labels are presumably failure/undecided markers written
        by the detector clients -- confirm against the producers.
    """
    loader = _DETECTOR_LOADERS.get(q_name)
    if loader is None:
        return None
    return loader(item)
        

# task_list = {
#     "apps":"apps_test",
#     "doc2code_test_valid":["_".join([x,y,z]) for x,y,z in itertools.product(["doc2code"],["go","python","php","java","javascript","ruby"],["test","valid"]) ],
#     "doc2code_test":["_".join([x,y,z]) for x,y,z in itertools.product(["doc2code"],["go","python","php","java","javascript","ruby"],["test"]) ],
#     "code2doc_test_valid":["_".join([x,y,z]) for x,y,z in itertools.product(["code2doc"],["go","python","php","java","javascript","ruby"],["test","valid"]) ],
#     "code2doc_test":["_".join([x,y,z]) for x,y,z in itertools.product(["code2doc"],["go","python","php","java","javascript","ruby"],["test"]) ],
#
#     "concde_dev":"concde_dev",
#     "archive_stackexchange":"archive_stackexchange",
#     }

def parse_dict_v1(xpath):
    """Decode the metadata embedded in a result file name.

    The base name (with ".jsonl" removed) is a comma-separated list of
    ``key=value`` pairs.  The mandatory ``task`` value is further split on
    "_" into a dataset name, an optional language and a trailing split name.

    Args:
        xpath: path of the result file.

    Returns:
        dict holding ``"path"`` (the argument as given), every ``key=value``
        pair from the file name, and the derived ``"split"``, ``"lang"``
        (``None`` when the task names no language), ``"name"`` and ``"mt"``
        (defaults to "baseline") entries.
    """
    info = {"path": xpath}

    stem = os.path.basename(xpath).replace(".jsonl", "")
    for segment in stem.split(","):
        key, value = segment.split("=")[:2]
        info[key] = value

    full_task = info["task"]
    info["split"] = full_task.split("_")[-1]

    # What is left after dropping the split is "<name>" or "<name>_..._<lang>".
    remainder = full_task.replace("_" + info["split"], "").split("_")
    info["lang"] = remainder[-1] if len(remainder) > 1 else None
    info["name"] = remainder[0]
    info.setdefault("mt", "baseline")

    # The apps task names no language in its task string; it is set to python.
    if info["name"] == "apps":
        info["lang"] = "python"

    # This task name carries no split suffix, so the generic parsing above
    # mis-reads it; set the fields explicitly.
    if full_task == "archive_stackexchange":
        info.update(split="test", lang=None, name=full_task)

    return info


def load_one_path_result(xpath):
    """Load one detector result file and attach the ground-truth label.

    The file name is decoded with ``parse_dict_v1``: its ``q`` field selects
    the detector loader and its ``r`` field decides the ground truth (a file
    is treated as ChatGPT-written when ``r`` is "chatgpt_answer" or absent).

    Args:
        xpath: path of a JSONL file holding one detector output per line.

    Returns:
        list of the records produced by ``load_func`` (records it rejects
        are dropped), each extended with ``"y_true"`` set to 1 for ChatGPT
        files and 0 otherwise.

    Raises:
        AssertionError: if the file name carries no ``q`` field.
    """
    meta_info = parse_dict_v1(xpath)
    is_chatgpt = meta_info.get("r", "chatgpt_answer") == "chatgpt_answer"

    q_name = meta_info.get("q", None)
    assert q_name is not None

    data = []
    with open(xpath) as f:
        for one_line in f:
            try:
                data.append(json.loads(one_line))
            except json.JSONDecodeError:
                # Malformed (or blank) lines are reported and skipped; any
                # other error, including KeyboardInterrupt, propagates.
                print ("---->", one_line )

    result = [load_func(item, q_name=q_name) for item in data]
    r1 = len(result)
    result = [x for x in result if x is not None]
    r2 = len(result)
    result = [{**item, "y_true": int(is_chatgpt)} for item in result]
    r3 = len(result)

    print ("load, ", q_name ,"len.result.r1",r1, "len.result.r2",r2, "len.result.r3",r3 , is_chatgpt, "is_chatgpt" )

    return result

def calc(x_path_list ):
    """Pool the predictions of several result files and score them together.

    Args:
        x_path_list: non-empty list of result file paths.

    Returns:
        The metric dict produced by ``metric_cal`` extended with
        ``"human_c"`` and ``"chatgpt_c"``, the number of pooled samples whose
        ground truth is 0 and 1 respectively.

    Note:
        The hard ``pred_label`` values (not ``pred_prob``) are handed to
        ``metric_cal`` as ``y_prob``.
    """
    assert len(x_path_list)>0, x_path_list

    records = [
        record
        for one_path in x_path_list
        for record in load_one_path_result(xpath=one_path)
    ]

    y_true = np.array([record["y_true"] for record in records])
    y_pred = np.array([record["pred_label"] for record in records])

    # Class balance of ground truth and predictions, for the run log.
    print ("true,", np.unique(y_true,return_counts=True) )
    print ("pred,", np.unique(y_pred,return_counts=True) )

    shapes = (y_true.shape, y_pred.shape)
    assert len(y_true)>0, shapes
    assert len(y_pred)>0, shapes

    sample_total = len(y_true)
    chatgpt_count = int(np.sum(y_true))
    assert chatgpt_count <= sample_total

    metric_result = metric_cal(y_true=y_true, y_prob=y_pred)
    metric_result.update({
        "human_c": sample_total - chatgpt_count,
        "chatgpt_c": chatgpt_count,
    })
    return metric_result

import pprint
from tqdm import tqdm  

# Names of code-mutation transforms -- presumably the values the ``mt`` field
# of a result file name can take (confirm against the data producer).  The
# spellings are runtime identifiers and must stay exactly as written.  Within
# this file the list is not read by any active code.
mt_list = [
    "AssginAddLine",
    "AugAssgin2Assign",
    "Comp2For",
    "ExprStmt2Assign",
    "IfExpr2Stmt",
    "IfStmt2IfStmt",
    "WhileStmt",
    "fb_obfuscator",
    "rename_func",
    "rename_var",
    ]
import pprint 
if __name__=="__main__":
    # Script entry point: for every detector, language and model, group the
    # matching result files by task name, score each group with ``calc`` and
    # write all rows to one CSV per mutation name.
    import pandas as pd 

    # Root directory holding one sub-directory of *.jsonl files per detector.
    dir_path = "/mnt/nvme/data3/icse_dataset/llm_save_data_new2"
    detector_name_list = [
        "GPT2_Detector_out",
        "GPTZero_out",
        "Hello-SimpleAI-qa_out",
        "Hello-SimpleAI_out",
        # "mt_func1_out",
        # "ai_detector__compilatio__net_out",
        # "sapling_ai_out",
        # "writefull_com_out",
        "scribbr_com_out",
        "gptzero_me_out",
        "writer_com_out",
    ]
    lang_list = ["go", "python", "php", "java", "javascript", "ruby"]
    # model_list = ["gpt-3.5-turbo", "CodeLlama-34b-hf"]
    model_list = ["gpt-3.5-turbo"]

    # mutate_name= "baseline"
    # for mutate_name in mt_list:
    for mutate_name in ["baseline"]:

        # test_metric()
        cal_list = []
        for detector_name in tqdm(detector_name_list):
            # File discovery and name parsing depend only on the detector, so
            # they run once per detector rather than once per language/model.
            fl = glob( os.path.join(dir_path,detector_name, "*.jsonl") )
            info_list = [parse_dict_v1(xpath=one_f) for one_f in fl]

            df = pd.DataFrame(info_list)

            # Dump of the parsed file-name metadata; overwritten per detector.
            df.to_csv("tmp_fill.csv",index=False)

            print( "total files",len(df)  )
            print( df["lang"].value_counts()  )
            print( df["name"].value_counts()  )
            print( df["mt"].value_counts()  )
            print( df["split"].value_counts()  )

            assert len( df[ pd.isnull(df["mt"])]  )==0 

            eligible = df["split"].isin(["test","valid","dev"]) & (df["mt"]==mutate_name)

            for lang in lang_list:
                for m in model_list:
                    new_df = df[eligible & (df["m"]==m) & (df["lang"]==lang)]

                    # One row per task name, carrying the list of its files.
                    grp_new_df = new_df.groupby('name')['path'].apply(list)
                    grp_new_df = grp_new_df.reset_index(name='interest_list')
                    group_records = grp_new_df.to_dict(orient="records")
                    pprint.pprint (grp_new_df["interest_list"].tolist())

                    # Append the selected groups to a per-language log file.
                    with open(f"/tmp/{lang}.log_c","a") as fff :
                        fff.write( json.dumps( group_records , indent=4 ) ) 

                    for meta_info in group_records:
                        interest_list = meta_info["interest_list"]
                        task_name = meta_info["name"]

                        assert len(interest_list)>0 , meta_info 

                        interest_list = sorted(interest_list)
                        ret = {
                            "detector": detector_name.replace("_out",""),
                            "task": task_name,
                            "m": m,
                            "lang": lang,
                        }
                        ret.update( calc( interest_list ) )
                        cal_list.append(ret)
                        print (ret)

        final_df = pd.DataFrame( cal_list )  
        final_df.to_csv(f"./final_{mutate_name}_6language_result.csv",index=False) 

            