import json
import numpy as np 
import pandas as pd 
from glob2 import glob 
import os 
import logging 
from tqdm import tqdm 


from typing import List 


# Task name prefixes to evaluate. Each entry is spliced into the glob pattern
# f"task={task}*{selected_pretrained},*.jsonl" in the __main__ block, and is
# also used verbatim as the "g" (group) column of the output.
# NOTE(review): the trailing "," in "apps_test," appears to anchor the match to
# exactly task=apps_test (result file names separate key=value fields with
# commas, as parse_dict_v1 assumes) so e.g. "apps_test_xxx" is not picked up —
# confirm this is intentional.
task_list=[
    "apps_test,",
    "doc2code_python_",
    ]


def metric_cal(y_true, y_prob=None, y_label=None):
    """Compute binary classification metrics (AUC, FPR, FNR, EER).

    Label 1 is the positive class (``pos_label=1``) and is counted as
    "chatgpt"; label 0 is counted as "human".

    Args:
        y_true: Ground-truth 0/1 labels.
        y_prob: Scores for the positive class, used for the ROC curve, AUC and
            EER. If omitted, ``y_label`` is used as the score.
        y_label: Hard 0/1 predictions, used for FPR and FNR. If omitted, they
            are derived from ``y_prob``: floating-point scores are thresholded
            at 0.5, and any other dtype is used as-is.

    Returns:
        dict with float ``auc``, ``fpr``, ``fnr``, ``eer`` and int
        ``chatgpt_size`` (sum of ``y_true``) / ``human_size``
        (``len(y_true)`` minus that sum).

    Raises:
        ValueError: if neither ``y_prob`` nor ``y_label`` is provided.
    """
    # Validate before touching y_prob.dtype: deriving labels from a None
    # y_prob would otherwise fail with an unrelated AttributeError.
    if y_prob is None and y_label is None:
        raise ValueError("metric_cal needs at least one of y_prob or y_label")

    if y_prob is None:
        y_prob = y_label
    if y_label is None:
        y_prob = np.asarray(y_prob)
        if np.issubdtype(y_prob.dtype, np.floating):
            y_label = y_prob > 0.5
        else:
            y_label = y_prob
        y_label = np.asarray(y_label).astype(int)

    from sklearn import metrics
    from scipy.optimize import brentq
    from scipy.interpolate import interp1d

    # labels=[0, 1] keeps the matrix 2x2 even if one class is absent from
    # both y_true and y_label; without it ravel() can return a single value
    # and the 4-way unpack fails.
    tn, fp, fn, tp = metrics.confusion_matrix(
        y_true=y_true,
        y_pred=y_label,
        labels=[0, 1],
    ).ravel()

    fpr = fp / float(fp + tn)
    fnr = fn / float(fn + tp)

    roc_fpr, roc_tpr, _ = metrics.roc_curve(y_true, y_prob, pos_label=1)
    auc = metrics.auc(roc_fpr, roc_tpr)

    # EER is the ROC operating point where FPR == FNR == 1 - TPR, i.e. the
    # root of 1 - x - TPR(x) on the linearly interpolated ROC curve.
    tpr_at = interp1d(roc_fpr, roc_tpr)
    eer = brentq(lambda x: 1.0 - x - tpr_at(x), 0.0, 1.0)

    chatgpt_len = np.sum(y_true)
    human_len = len(y_true) - chatgpt_len

    return {
        "auc": float(auc),
        "fpr": float(fpr),
        "fnr": float(fnr),
        "eer": float(eer),
        "chatgpt_size": int(chatgpt_len),
        "human_size": int(human_len),
    }

def parse_dict_v1(xpath):
    """Parse comma-separated ``key=value`` fields out of a result-file name.

    The directory part and a trailing ``.jsonl`` extension are removed, and the
    rest is split on ``,``. For each field, the text before the first ``=`` is
    the key and the text between the first and second ``=`` is the value.
    Anything after a second ``=`` is ignored. Fields that contain no ``=``
    (e.g. an empty field produced by a trailing comma) are skipped.

    Args:
        xpath: Path (or bare file name) of a result file.

    Returns:
        dict mapping field names to their string values. A later duplicate key
        overwrites an earlier one.
    """
    name = os.path.basename(xpath)
    # Strip only the extension; a plain str.replace would also delete
    # ".jsonl" occurring inside a field value.
    suffix = ".jsonl"
    if name.endswith(suffix):
        name = name[: -len(suffix)]

    dict_info = {}
    for field in name.split(","):
        parts = field.split("=")
        if len(parts) < 2:
            # Not a key=value field; unpacking it would raise ValueError.
            continue
        dict_info[parts[0]] = parts[1]
    return dict_info

def dict_to_str(xpath_info):
    """Render a dict as comma-separated ``key=value`` fields.

    The (key, value) pairs are sorted in descending order (keys are unique,
    so this is descending key order). The output therefore does not depend
    on the dict's insertion order.
    """
    ordered_pairs = sorted(xpath_info.items(), reverse=True)
    fields = (f"{key}={value}" for key, value in ordered_pairs)
    return ",".join(fields)


def load_all_files(xpath_list):
    """Load and concatenate the records of one or more JSON-Lines files.

    Args:
        xpath_list: A single path (``str``, ``bytes`` or ``os.PathLike``) or
            an iterable of paths. Files are read in the order given.

    Returns:
        list of decoded JSON values, one per non-blank line, in file order
        and then line order.
    """
    # A lone path is wrapped. Any other iterable (list, tuple, generator) is
    # iterated, rather than being passed whole to open().
    if isinstance(xpath_list, (str, bytes, os.PathLike)):
        xpath_list = [xpath_list]

    example_list = []
    for one_xpath in xpath_list:
        # `with` closes each file as soon as it has been read. JSON text is
        # UTF-8 (RFC 8259), so do not depend on the platform default encoding.
        with open(one_xpath, encoding="utf-8") as fh:
            # Blank or whitespace-only lines (e.g. trailing newlines) are
            # skipped; json.loads would reject them.
            example_list.extend(json.loads(line) for line in fh if line.strip())
    return example_list
def convert_sing_prob_softmax_prob(y_pred: List[int], y_prob: List[float]):
    '''Expand per-row (predicted label, probability) pairs into two columns.

    ``y_prob[i]`` is the probability of the predicted class ``y_pred[i]``
    (0 or 1). That column receives ``y_prob[i]`` and the other column receives
    ``1 - y_prob[i]``, so column ``c`` holds the probability of class ``c``::

        y_pred [1, 0, 1]
        y_prob [0.9, 0.9, 0.9]
        -->
        [[0.1, 0.9], [0.9, 0.1], [0.1, 0.9]]

    Returns:
        float ndarray of shape ``(len(y_pred), 2)``.

    Raises:
        ValueError: if the inputs are not 1-D sequences of equal length, or if
            any resulting entry is outside [0, 1] (NaN included).
        IndexError: if a predicted label is not a valid column index.
    '''
    # Casting to int also makes boolean predictions index columns 0/1.
    pred = np.asarray(y_pred, dtype=int)
    prob = np.asarray(y_prob, dtype=float)
    # zip() would silently truncate on a length mismatch and leave zero rows.
    if pred.ndim != 1 or pred.shape != prob.shape:
        raise ValueError(
            f"y_pred and y_prob must be 1-D with equal length, "
            f"got shapes {pred.shape} and {prob.shape}"
        )

    rows = np.arange(pred.shape[0])
    blank = np.zeros((pred.shape[0], 2), dtype=float)
    blank[rows, pred] = prob
    blank[rows, 1 - pred] = 1 - prob

    # Enforce the intended [0, 1] range check. np.all is True on empty
    # input, whereas np.max/np.min raise on zero-size arrays.
    if not np.all((blank >= 0) & (blank <= 1)):
        raise ValueError("probabilities must lie in [0, 1]")

    return blank

def calc_metric(example_list):
    """Score a list of prediction records with ``metric_cal``.

    Each record must provide ``"labels"`` (ground truth), ``"pred"``
    (predicted label) and ``"prob"`` (probability of the predicted label).
    The per-record probability is expanded to two columns, and the class-1
    column is passed to ``metric_cal`` as the score. Class counts and the
    probability range are printed for a quick sanity check.
    """
    def _column(key, dtype):
        # Gather one field from every record into a typed array.
        return np.array([record[key] for record in example_list], dtype=dtype)

    y_true = _column("labels", int)
    y_label = _column("pred", int)
    y_prob = _column("prob", float)

    print("y_true", np.unique(y_true, return_counts=True))
    print("y_label", np.unique(y_label, return_counts=True))
    print("y_prob", np.max(y_prob), np.min(y_prob))

    two_column = convert_sing_prob_softmax_prob(y_pred=y_label, y_prob=y_prob)
    positive_prob = two_column[:, -1]
    return metric_cal(y_true=y_true, y_label=y_label, y_prob=positive_prob)


import pprint 
if __name__=="__main__":
    # Collect every result file for `selected_pretrained` on the tasks in
    # `task_list`, group the files by (mt, task), score each group, and write
    # one metrics row per group to ./rq3_full_scale.csv.
    selected_pretrained = "see4231_train_fse_doc2code_concde_apps"
    search_dir = "/home/wj_cuda113/wj_code/dl_chatgpt/nlccd/finetune/save2"

    df_list = []
    for task in task_list:
        search_pattern = f"task={task}*{selected_pretrained},*.jsonl"
        file_list = glob(os.path.join(search_dir, search_pattern))
        # Result files for the "add_dead_code" variant are excluded.
        file_list = [x for x in file_list if "add_dead_code" not in x]
        # One row per file: task group "g", its path, and every key=value
        # field parsed from the file name (the "mt" column comes from here).
        df_list.extend(
            {"g": task, "path": one_path, **parse_dict_v1(one_path)}
            for one_path in file_list
        )

    df = pd.DataFrame(df_list)
    print(df)
    # groupby(["mt", "g"]) raises an opaque KeyError when nothing matched,
    # or when no file name carried an "mt" field. Stop with a clear message.
    if df.empty or "mt" not in df.columns:
        raise SystemExit(
            f"no result files with an 'mt' field matched in {search_dir}"
        )

    grp_new_df = df.groupby(["mt", "g"])["path"].apply(list)
    grp_new_df = grp_new_df.reset_index(name="interest_list")
    pprint.pprint(grp_new_df["interest_list"].tolist())

    cal_list = []
    for meta_info in grp_new_df.to_dict(orient="records"):
        interest_list = meta_info["interest_list"]
        assert len(interest_list) > 0, meta_info
        interest_list = sorted(interest_list)

        ret = {"mt": meta_info["mt"], "g": meta_info["g"]}
        ex_list = load_all_files(xpath_list=interest_list)
        # Flag examples whose ChatGPT or human answer contains the "###@##"
        # marker; None answers count as not containing it.
        effect_list = [
            (x["chatgpt_answer"] is not None and "###@##" in x["chatgpt_answer"])
            or (x["human_answer"] is not None and "###@##" in x["human_answer"])
            for x in ex_list
        ]

        ret.update(calc_metric(example_list=ex_list))
        ret.update({
            "effect": sum(effect_list),
            # Number of result files merged into this (mt, g) group. The stale
            # `file_list` from the search loop is not this group's files.
            "file_c": len(interest_list),
            "sample_size": len(ex_list),
        })
        cal_list.append(ret)
        print(ret)

    df = pd.DataFrame(cal_list)
    df.to_csv("./rq3_full_scale.csv", index=False)
    
        