import json 
import pickle 
import os 
import pandas as pd 
from glob2 import glob 
import numpy as np 

if __name__=="__main__":
    # CSV that carries, per sample, its task group, id ("idx") and
    # "content_complexity_level".
    label_csv_path = "/home/wj2_cuda12/wj_code/dl_chatgpt/tosem_data/fse2023_dataset/complexity_score_v3/prepare_common_set_with_complexity_level.csv"

    df_label = pd.read_csv(label_csv_path)

    # Keep only the doc2code rows; no programming-language filter is applied.
    is_doc2code = df_label["task_grp"] == "doc2code"
    df_label_doc2code = df_label[is_doc2code]

    # Map each complexity level (1, 2, 3) to the set of sample ids at that level.
    # The name LANG_INFO is kept because process_one looks it up by that name.
    LANG_INFO = {}
    for level in (1, 2, 3):
        at_level = df_label_doc2code["content_complexity_level"] == level
        level_ids = set(df_label_doc2code[at_level]["idx"])
        LANG_INFO[level] = level_ids
        print (level, len(level_ids))
    
    # exit()
    #
    # # complexity,
    # a_heights, a_bins = np.histogram( df_label_doc2code["content_complexity"]  , bins=3  )
    #
    # df_label_doc2code ["python_complex"] = np.digitize(df_label_doc2code["content_complexity"]  , bins=[0,5,10,200] )
    #
    # print (df_label_doc2code["python_complex"] .value_counts() )
    # exit()
    
    def auc_process(y_true, y_pred, y_prob, expect_full_size=None):
        """Compute AUC, FPR and FNR for one slice of detector output.

        When ``expect_full_size`` is given, each class (label 0 and label 1)
        is padded up to that many samples.  Every padded sample is counted as
        an error: its predicted label and its score are both ``1 - label``.

        Args:
            y_true: Ground-truth binary labels (0/1), list or array.
            y_pred: Hard predicted labels, expected to match ``y_true`` in length.
            y_prob: Scores used for the ROC curve, same length as ``y_true``.
            expect_full_size: Target number of samples per class.  ``None``
                disables padding.

        Returns:
            dict with the metrics ``auc``, ``fpr``, ``fnr`` followed by the
            bookkeeping sizes ``y_true``, ``y_true_add``, ``y_pred``,
            ``y_pred_add``, ``y_prob``, ``y_prob_add`` and ``expect_full_size``.
        """

        def metric_cal(y_true, y_prob, y_prob_label=None):
            """Return {"auc", "fpr", "fnr"} for binary labels and scores."""
            assert len(y_true) == len(y_prob)

            nan = float("nan")
            if len(y_true) == 0:
                # Nothing to score: report undefined metrics instead of
                # failing inside the metric routines.
                return {"auc": nan, "fpr": nan, "fnr": nan}

            from sklearn import metrics

            if y_prob_label is None:
                # Derive hard labels: threshold float scores at 0.5, otherwise
                # treat the scores as labels already.
                y_prob_label = y_prob > 0.5 if "float" in str(y_prob.dtype) else y_prob
                y_prob_label = y_prob_label.astype(int)

            # labels=[0, 1] forces a 2x2 matrix even when only one class is
            # present, so the four-way unpacking below cannot fail.
            tn, fp, fn, tp = metrics.confusion_matrix(
                y_true=y_true,
                y_pred=y_prob_label,
                labels=[0, 1],
            ).ravel()

            # A rate is undefined when its class has no samples.
            fpr = fp / float(fp + tn) if (fp + tn) > 0 else nan
            fnr = fn / float(fn + tp) if (fn + tp) > 0 else nan

            roc_fpr, roc_tpr, _ = metrics.roc_curve(y_true, y_prob, pos_label=1)
            return {
                "auc": metrics.auc(roc_fpr, roc_tpr),
                "fpr": fpr,
                "fnr": fnr,
            }

        # Normalise inputs so that list arguments work in the element-wise
        # comparisons and concatenations below.
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        y_prob = np.asarray(y_prob)

        if expect_full_size is None:
            y_true_add = []
        else:
            # Number of samples missing per class; a class that already has
            # expect_full_size samples or more gets no padding.
            missing_neg = max(expect_full_size - int(np.count_nonzero(y_true == 0)), 0)
            missing_pos = max(expect_full_size - int(np.count_nonzero(y_true == 1)), 0)
            y_true_add = [0] * missing_neg + [1] * missing_pos

        n_added = len(y_true_add)
        add_meta = {
            "y_true": len(y_true),
            "y_true_add": n_added,
            "y_pred": len(y_pred),
            "y_pred_add": n_added,
            "y_prob": len(y_prob),
            "y_prob_add": n_added,
            "expect_full_size": expect_full_size,
        }

        if n_added > 0:
            pad_true = np.array(y_true_add)
            # Padded samples are scored as wrong: prediction and score are
            # the opposite of the true label.
            pad_wrong = 1 - pad_true
            y_true = np.concatenate([y_true, pad_true])
            y_pred = np.concatenate([y_pred, pad_wrong])
            y_prob = np.concatenate([y_prob, pad_wrong])
        y_true = y_true.astype(int)

        metric_info = metric_cal(y_true=y_true, y_prob=y_prob, y_prob_label=y_pred)
        metric_info.update(add_meta)
        return metric_info
    
    
    def process_one (fn ):
        """Score one detector result pickle per complexity level.

        Loads the pickle at ``fn``, keeps the items whose ``"task"`` is
        ``"doc2code"``, and for every entry of the module-level ``LANG_INFO``
        mapping computes metrics with ``auc_process`` on the samples whose id
        belongs to that entry.  Each metric dict is printed and appended to
        the module-level ``final_list``.

        Args:
            fn: Path to a ``.pkl`` file; its basename without ``.pkl`` is
                stored in the ``"detector"`` field.

        Returns:
            None.  Results are accumulated in ``final_list``.
        """
        # NOTE: pickle.load can execute arbitrary code from the file; only
        # point this at result files you trust.
        with open(fn, "rb") as fread:
            data = pickle.load(fread)

        print (fn )
        detector = os.path.basename(fn).replace(".pkl", "")

        for one_item in data:
            if one_item["task"] != "doc2code":
                continue
            print  (list(one_item) ,"list.one_item", one_item["task"], )

            # Sample ids are the part after the last "@" of each unique key.
            uniq_list = [x.split("@")[-1] for x in one_item["uniq_list"]]
            n_samples = len(uniq_list)

            # Truncate the prediction arrays to the number of unique keys.
            y_true = np.array(one_item["y_true"][:n_samples])
            y_prob = np.array(one_item["y_prob"][:n_samples])
            y_pred = np.array(one_item["y_pred"][:n_samples])

            assert len(y_true) == n_samples
            assert len(y_prob) == n_samples
            assert len(y_pred) == n_samples

            for lang, lang_idx_list in LANG_INFO.items():
                # NOTE(review): uniq_list holds strings; confirm the ids stored
                # in LANG_INFO are strings too, otherwise nothing matches.
                in_level = np.array(
                    [x in lang_idx_list for x in uniq_list], dtype=bool
                )

                metric_lang = auc_process(
                    y_true=y_true[in_level],
                    y_pred=y_pred[in_level],
                    y_prob=y_prob[in_level],
                    expect_full_size=len(lang_idx_list),
                )

                metric_lang["task"] = one_item["task"]
                metric_lang["model_name"] = one_item["model_name"]
                metric_lang["temp"] = one_item["temp"]
                # The "lang" field stores the LANG_INFO key for this slice.
                metric_lang["lang"] = lang
                metric_lang["detector"] = detector

                print (metric_lang)
                final_list.append(metric_lang)

    # Accumulator that process_one appends its metric dicts to.
    final_list = list()

    # Run every result pickle found directly under the results directory.
    result_dir = "/home/wj2_cuda12/wj_code/dl_chatgpt/tosem_data/save_dir2/each_q"
    pickle_pattern = os.path.join(result_dir, "*.pkl")
    for pkl_path in glob(pickle_pattern):
        process_one(fn=pkl_path)

    # Write all collected metric rows to one CSV, without the index column.
    out_csv = "/home/wj2_cuda12/wj_code/dl_chatgpt/tosem_data/fse2023_dataset/complexity_score_v3/3_complex.csv"
    pd.DataFrame(final_list).to_csv(out_csv, index=False)
    
    
    