import json
import numpy as np 
import pandas as pd 
from glob2 import glob 
import os 
import logging 
from tqdm import tqdm 


from typing import List 


# Task groups analysed by the __main__ script. Each entry is matched against
# the "task=" field of prediction filenames; "doc2code_python_" is expanded
# there into both the doc2code_python_valid and doc2code_python_test files.
task_list=[
    "apps_test",
    "doc2code_python_",
    ]


def metric_cal(y_true, y_prob=None ,y_label =None  ):
    """Compute binary-detection metrics, treating label 1 as ChatGPT and 0 as human.

    Args:
        y_true: ground-truth 0/1 labels.
        y_prob: per-sample score for class 1, used for the ROC curve. When
            ``y_label`` is omitted and the scores have a float dtype, hard
            labels are derived by thresholding at 0.5; a non-float array is
            used as the labels directly.
        y_label: hard 0/1 predictions. When omitted, derived from ``y_prob``.
            When ``y_prob`` is omitted, these labels double as the scores.

    Returns:
        dict with keys "auc", "fpr", "fnr", "eer" (floats; "auc"/"eer" are
        NaN when ``y_true`` contains a single class) and "chatgpt_size" /
        "human_size" (counts of label 1 / label 0 in ``y_true``).

    Raises:
        ValueError: if both ``y_prob`` and ``y_label`` are None.
    """
    # Validate before the local imports so misuse fails fast with a clear
    # message (previously this surfaced as an AttributeError on y_prob.dtype).
    if y_prob is None and y_label is None:
        raise ValueError("metric_cal needs at least one of y_prob or y_label")

    from sklearn import metrics
    from scipy.optimize import brentq
    from scipy.interpolate import interp1d

    y_true = np.asarray(y_true)

    if y_label is None :
        y_prob = np.asarray(y_prob)
        # Float scores are thresholded at 0.5; anything else is taken as labels.
        y_label = y_prob>0.5 if "float" in str(y_prob.dtype)  else y_prob
        y_label = y_label.astype(int)

    if y_prob is None :
        y_prob = y_label

    # labels=[0, 1] forces a 2x2 matrix even when a class is absent from both
    # y_true and y_label; without it ravel() can yield a single value and the
    # 4-way unpacking fails.
    TN, FP, FN, TP = metrics.confusion_matrix(
        y_true=y_true,
        y_pred=y_label,
        labels=[0, 1],
    ).ravel()

    FPR = FP / float(FP + TN)
    FNR = FN / float(FN + TP)

    if len(np.unique(y_true)) < 2:
        # ROC-based metrics are undefined with a single ground-truth class;
        # report NaN instead of letting brentq fail on an all-NaN curve.
        auc = float("nan")
        eer = float("nan")
    else:
        _fpr, _tpr, _ = metrics.roc_curve(y_true, y_prob, pos_label=1)
        auc = metrics.auc(_fpr, _tpr)
        # Equal error rate: the FPR x at which 1 - TPR(x) (the FNR) equals x.
        eer = brentq(lambda x : 1. - x - interp1d(_fpr, _tpr)(x), 0., 1.)

    chatgpt_len = np.sum(y_true)
    human_len = len(y_true) - np.sum(y_true)

    return {
        "auc":float(auc),
        "fpr":float(FPR),
        "fnr":float(FNR),
        "eer":float(eer),
        "chatgpt_size":int(chatgpt_len),
        "human_size":int(human_len),
        }

def parse_dict_v1(xpath):
    """Decode the ``key=value`` fields encoded in a result filename.

    The directory part is discarded and every ".jsonl" occurrence removed;
    the remaining name is split on "," into ``key=value`` segments. Only the
    text between the first and second "=" is kept as the value, later keys
    overwrite earlier duplicates, and a segment without "=" raises
    ValueError. All values are returned as strings.

    Example: ".../task=apps_test,r=human_answer.jsonl" ->
    {"task": "apps_test", "r": "human_answer"}
    """
    stem = os.path.basename(xpath).replace(".jsonl", "")
    pairs = (segment.split("=")[:2] for segment in stem.split(","))
    return {key: value for key, value in pairs}

def dict_to_str(xpath_info):
    """Serialise a dict as "k=v,k=v", ordering entries by key in descending order."""
    ordered_pairs = sorted(xpath_info.items(), reverse=True)
    rendered = (f"{key}={value}" for key, value in ordered_pairs)
    return ",".join(rendered)


def load_all_files(xpath_list):
    """Load and concatenate the records of one or more JSON-lines files.

    Args:
        xpath_list: a single path, or a list/tuple of paths.

    Returns:
        list of decoded JSON objects, ordered by file and then by line.
        Whitespace-only lines are skipped instead of crashing json.loads.
    """
    if not isinstance(xpath_list, (list, tuple)):
        xpath_list = [xpath_list]

    example_list = []
    for one_xpath in xpath_list:
        # `with` guarantees the handle is closed; the previous version leaked
        # one open file per path.
        with open(one_xpath, encoding="utf-8") as fh:
            example_list.extend(json.loads(line) for line in fh if line.strip())
    return example_list
def convert_sing_prob_softmax_prob(y_pred:List[int],y_prob:List[float]) :
    """Expand (predicted label, confidence) pairs into a two-column probability matrix.

    Row ``i`` gets ``y_prob[i]`` in column ``y_pred[i]`` and ``1 - y_prob[i]``
    in the other column, so column 1 is the probability of class 1::

        y_pred [1,0,1]
        y_prob [0.9,0.9,0.9]
        -->
        [[0.1,0.9],[0.9,0.1],[0.1,0.9]]

    Args:
        y_pred: predicted labels, each 0 or 1.
        y_prob: confidence of each prediction, each in [0, 1].

    Returns:
        float ndarray of shape (len(y_pred), 2).

    Raises:
        ValueError: if the inputs differ in length, a label is not 0/1, or a
            probability lies outside [0, 1] (NaN included).
    """
    labels = np.asarray(y_pred).reshape(-1)
    probs = np.asarray(y_prob, dtype=float).reshape(-1)

    # The old zip() silently truncated on a length mismatch, leaving zero rows.
    if len(labels) != len(probs):
        raise ValueError(
            f"y_pred and y_prob differ in length: {len(labels)} != {len(probs)}"
        )
    if not np.isin(labels, (0, 1)).all():
        raise ValueError("y_pred must contain only 0/1 labels")
    # The original `np.max(blank)<=1` / `np.min(blank)>=0` lines were no-op
    # expressions (and raised on empty input); this check actually runs.
    if not np.all((probs >= 0) & (probs <= 1)):
        raise ValueError("y_prob values must lie in [0, 1]")

    labels = labels.astype(int)
    rows = np.arange(len(labels))
    blank = np.zeros((len(labels), 2), dtype=float)
    blank[rows, labels] = probs
    blank[rows, 1 - labels] = 1 - probs
    return blank

def calc_metric( example_list ):
    """Score a list of prediction records with metric_cal.

    Each record must provide "labels" (ground truth), "pred" (hard label)
    and "prob" (confidence of that prediction). Label/prediction counts and
    the probability range are printed for inspection. The confidence is
    turned into a class-1 probability before computing the metrics.
    """
    def _column(key, dtype):
        # Pull one field out of every record as a typed numpy array.
        return np.array([record[key] for record in example_list], dtype=dtype)

    y_true = _column("labels", int)
    y_label = _column("pred", int)
    y_prob = _column("prob", float)

    print("y_true", np.unique(y_true, return_counts=True))
    print("y_label", np.unique(y_label, return_counts=True))
    print("y_prob", np.max(y_prob), np.min(y_prob))

    # The last column of the two-column matrix is P(class 1).
    two_col = convert_sing_prob_softmax_prob(y_pred=y_label, y_prob=y_prob)
    positive_prob = two_col[:, -1]

    return metric_cal(y_true=y_true, y_label=y_label, y_prob=positive_prob)


def _load_ids_with_label(xpath, label):
    """Return the "idx" of every record whose true and predicted label both equal ``label``."""
    interest_list = xpath if isinstance(xpath, list) else [xpath]
    example_list = load_all_files(xpath_list=interest_list)
    return [
        item["idx"]
        for item in example_list
        if int(item["labels"]) == label and int(item["pred"]) == label
    ]


def load_success_ids(xpath ):
    """Return ids of label-1 (ChatGPT) records predicted as 1 in the given file(s).

    Args:
        xpath: a JSONL path or a list of paths.

    Returns:
        list of "idx" values, in file/line order (may contain duplicates).
    """
    return _load_ids_with_label(xpath, 1)


def load_success_ids_human(xpath ):
    """Return ids of label-0 (human) records predicted as 0 in the given file(s).

    Args:
        xpath: a JSONL path or a list of paths.

    Returns:
        list of "idx" values, in file/line order (may contain duplicates).
    """
    return _load_ids_with_label(xpath, 0)


import pprint 
if __name__=="__main__":
    # Baseline (unperturbed) predictions: define which sample ids the baseline
    # classified correctly, and supply the opposite class per role below.
    baseline_search_dir = "/home/wj_cuda113/wj_code/dl_chatgpt/nlccd/finetune/baseline_save"

    # Pretrained-model tag contained in every prediction filename of interest.
    selected_pretrained= "see4231_train_fse_doc2code_concde_apps"
    # Experiment predictions to analyse.
    search_dir ="/home/wj_cuda113/wj_code/dl_chatgpt/nlccd/finetune/save_ob_mix_final_ablation"

    # "<task>_<role>" -> set of ids whose baseline prediction was correct.
    common_ids_set= {}
    # One record per experiment file: task group, path and filename fields.
    df_list = []

    for task in task_list:
        file_list =[]
        if task == "apps_test":
            search_pattern = f"*task=apps_test*{selected_pretrained},*.jsonl"
            file_list += glob( os.path.join(search_dir,  search_pattern ) )
            file_list += glob( os.path.join(search_dir, "**" ,search_pattern ) )
        elif task=="doc2code_python_":
            search_pattern = f"*task=doc2code_python_valid*{selected_pretrained},*.jsonl"
            file_list += glob( os.path.join(search_dir,  search_pattern ) )
            file_list += glob( os.path.join(search_dir, "**", search_pattern ) )
            search_pattern = f"*task=doc2code_python_test*{selected_pretrained},*.jsonl"
            file_list += glob( os.path.join(search_dir,  search_pattern ) )
            file_list += glob( os.path.join(search_dir, "**" ,search_pattern ) )
        else:
            # Previously an unknown task crashed later with a NameError on
            # search_pattern.
            raise ValueError(f"unknown task {task!r}")

        # The flat and "**" globs may return the same file twice.
        file_list = list(set(file_list))
        assert len(file_list)>0  , ( task, os.path.join(search_dir,  search_pattern )  )
        df_list.extend(  [{"g":task,  "path":one_path,  **parse_dict_v1(one_path)}      for one_path in file_list  ] )

        ## load common ids
        # NOTE(review): for "doc2code_python_", search_pattern still holds the
        # *_test pattern assigned last above, so valid-split baseline files are
        # not used here, and only the top level of baseline_search_dir is
        # searched. Confirm this is intended.
        file_list_baseline = glob( os.path.join(baseline_search_dir,  search_pattern ) )
        assert len (file_list_baseline)> 0 ,  os.path.join(baseline_search_dir,  search_pattern )

        comm_id = load_success_ids( xpath= file_list_baseline)
        assert task+"_chatgpt_answer" not in common_ids_set
        common_ids_set [task+"_chatgpt_answer"] = set(comm_id)

        assert task+"_human_answer" not in common_ids_set
        comm_id2 = load_success_ids_human( xpath= file_list_baseline)
        common_ids_set [task+"_human_answer"] = set(comm_id2)

        # Sanity check: the baseline must get more than 1000 of each class right.
        assert len(comm_id)>1000 , ( len(set(comm_id)), "chat" )
        assert len(comm_id2)>1000 ,  ( len(set(comm_id2)), "human" )

        print (task,"fiind apps _human_answer", len( common_ids_set [task+"_human_answer"] ))
        print (task, "fiind  _chatgpt_answer", len( common_ids_set [task+"_chatgpt_answer"] ) )

    df = pd.DataFrame( df_list )
    # One row per (mt, task group, role) holding the list of its files.
    grp_new_df  = df.groupby(['mt',"g","r"])['path'].apply(list)
    grp_new_df = grp_new_df.reset_index(name='interest_list')
    print ("-----"*8 , grp_new_df )

    cal_list= []
    for meta_info in grp_new_df.to_dict(orient="records") :
        _task  = meta_info["g"]
        _role  = meta_info["r"]
        assert _role in ["chatgpt_answer","human_answer"], meta_info

        # Baseline predictions for this task; they supply the opposite class
        # to the one taken from the experiment files. The tag now comes from
        # selected_pretrained instead of being hard-coded a second time.
        baseline_pattern = f"task={_task}*q=gpt-3.5-turbo,mar={selected_pretrained},m=gpt-3.5-turbo_data.jsonl"
        baseline_pathlist = glob( os.path.join(baseline_search_dir, baseline_pattern ))
        print ("======"*8 , baseline_pathlist )
        assert len(baseline_pathlist) > 0, os.path.join(baseline_search_dir, baseline_pattern)

        interest_list = meta_info["interest_list"]
        assert len(interest_list)>0 , meta_info
        interest_list = sorted(interest_list)

        print ("meta_info", meta_info )

        ret = {"mt":meta_info["mt"],"g":meta_info["g"]}

        # Experiment records of this role's class plus baseline records of the
        # other class. int() matches how load_success_ids compares labels.
        target_label = 1 if _role == "chatgpt_answer" else 0
        ex_list = [x for x in load_all_files( xpath_list=interest_list ) if int(x["labels"])==target_label ]
        pre_list = load_all_files( xpath_list= baseline_pathlist )
        ex_list += [x for x in pre_list if int(x["labels"])==1-target_label ]

        # Keep only ids the baseline classified correctly.
        raw_len =  len(ex_list)
        keep_ids = common_ids_set[_task+"_"+_role ]
        ex_list = [x for x in ex_list if x["idx"] in keep_ids ]
        raw_len_after =  len(ex_list)
        print (np.unique( [x["labels"] for x in ex_list ], return_counts=True ),"raw_len ", raw_len, "raw_len_after", raw_len_after)

        chatgpt_count = sum(1 for x in ex_list if int(x["labels"])==1 )
        inconsist_chatgpt_count = sum(1 for x in ex_list if int(x["labels"])==1 and int(x["pred"])==0 )
        h_count = sum(1 for x in ex_list if int(x["labels"])==0 )
        inconsist_h_count = sum(1 for x in ex_list if int(x["labels"])==0 and int(x["pred"])==1 )

        ret.update(  calc_metric( example_list=ex_list ) )
        ret.update({
            "_role":_role,
            "inconsist_c":inconsist_chatgpt_count,
            "inconsist_c_rate":float( inconsist_chatgpt_count/chatgpt_count) if chatgpt_count!=0 else 0  ,
            "inconsist_h":inconsist_h_count,
            "inconsist_h_rate":float( inconsist_h_count/h_count) if h_count!=0 else 0  ,

            "c_c":chatgpt_count,
            "h_c":h_count,
            "sz_raw":raw_len,
            "sz_common_id":raw_len_after,
            # Files in this group; previously reported the stale file_list
            # left over from the last task of the discovery loop.
            "file_c": len(interest_list),
            "sample_size": len(ex_list),
            })
        cal_list.append(ret )
        print (ret)

    df = pd.DataFrame( cal_list )
    df.to_csv("./rq3_inconsist_scale.csv",index=False )
    
        
