import json
import numpy as np 
import pandas as pd 
from glob2 import glob 
import os 
import logging 
from tqdm import tqdm 


from typing import List 

logger = logging.getLogger(__name__)

# Evaluation plan: run name -> list of dataset-group keys to score it on.
# The __main__ section iterates these keys in insertion order, passes each
# key to get_file_list as ``pretrain`` (matched against ``mar=<key>`` in the
# result file names) and looks every group key up in DATASET_LIST.
MAR_LIST = {
    # Runs scored on the QA / code2doc groups.
    "see4231_qa": ["QA", "code2doc"],
    "see4231_train_fse_code2doc": ["QA", "code2doc"],
    "see4231_train_fse_code2doc_qa": ["QA", "code2doc"],
    # Runs scored on the doc2code / apps / concde groups.
    "see4231_apps": ["doc2code", "apps", "concde"],
    "see4231_train_fse_concode": ["doc2code", "apps", "concde"],
    "see4231_train_fse_doc2code": ["doc2code", "apps", "concde"],
    "see4231_train_fse_doc2code_concde": ["doc2code", "apps", "concde"],
    "see4231_train_fse_doc2code_concde_apps": ["doc2code", "apps", "concde"],
}

# Dataset-group key -> task-name fragments.  get_file_list places each
# fragment in the ``task=<fragment>*`` part of its glob, so ``*`` inside a
# fragment is a wildcard.
# NOTE(review): "concde" is spelled this way both as a key and inside the
# "concde_dev" fragment; presumably it matches the on-disk file names, so
# confirm against the saved results before "correcting" it.
DATASET_LIST = {
    "QA": ["archive_stackexchange_test"],
    "code2doc": ["code2doc*test", "code2doc*valid"],
    "doc2code": ["doc2code*test", "doc2code*valid"],
    "concde": ["concde_dev"],
    "apps": ["apps_test"],
}
    

def metric_cal(y_true, y_prob=None, y_label=None):
    """Compute binary-classification metrics, treating label ``1`` as positive.

    Args:
        y_true: ground-truth labels (0/1), array-like.
        y_prob: score for the positive class, array-like.  When ``y_label``
            is omitted, labels are derived from it: a float array is
            thresholded at 0.5, any other dtype is used as labels directly.
        y_label: hard predicted labels (0/1), array-like.  When ``y_prob`` is
            omitted, these labels double as the ROC scores.

    Returns:
        dict with ``acc``, ``auc``, ``fpr``, ``fnr``, ``eer`` and the class
        sizes ``chatgpt_size`` (count of ``y_true == 1``) and ``human_size``
        (all other samples).

    Raises:
        ValueError: if neither ``y_prob`` nor ``y_label`` is given.
    """
    # Validate before anything touches y_prob; otherwise the failure is an
    # opaque AttributeError on ``None.dtype``.
    if y_prob is None and y_label is None:
        raise ValueError("metric_cal needs at least one of y_prob or y_label")

    from sklearn.metrics import accuracy_score
    from sklearn import metrics
    from scipy.optimize import brentq
    from scipy.interpolate import interp1d

    y_true = np.asarray(y_true)

    if y_label is None:
        # asarray lets plain lists through (they have no ``.dtype``).
        y_prob = np.asarray(y_prob)
        y_label = y_prob > 0.5 if "float" in str(y_prob.dtype) else y_prob
        y_label = y_label.astype(int)

    if y_prob is None:
        y_prob = y_label

    # labels=[0, 1] always yields a 2x2 matrix, so the 4-way unpack still
    # works when a class is absent from both y_true and y_label.
    TN, FP, FN, TP = metrics.confusion_matrix(
        y_true=y_true,
        y_pred=y_label,
        labels=[0, 1],
    ).ravel()

    FPR = FP / float(FP + TN)
    FNR = FN / float(FN + TP)

    _fpr, _tpr, _thresholds = metrics.roc_curve(y_true, y_prob, pos_label=1)
    auc = metrics.auc(_fpr, _tpr)

    # EER: the ROC point where FNR (= 1 - TPR) equals FPR, i.e. the root of
    # 1 - x - TPR(x).  Build the interpolant once rather than per evaluation.
    roc_tpr = interp1d(_fpr, _tpr)
    eer = brentq(lambda x: 1.0 - x - roc_tpr(x), 0.0, 1.0)

    chatgpt_len = np.sum(y_true)
    human_len = len(y_true) - np.sum(y_true)

    return {
        "acc": accuracy_score(y_true, y_label),
        "auc": float(auc),
        "fpr": float(FPR),
        "fnr": float(FNR),
        "eer": float(eer),
        "chatgpt_size": int(chatgpt_len),
        "human_size": int(human_len),
    }

    
def get_file_list(pretrain, search_key, root_dir=None):
    """Collect result files for one run and one or more task keys.

    For every key the glob ``task={key}*,mar={pretrain},*.jsonl`` is matched
    inside ``root_dir``; matches are concatenated in key order.

    Args:
        pretrain: value matched after ``mar=`` in the file name.
        search_key: one task-name fragment (may contain glob wildcards) or a
            list/tuple of fragments.
        root_dir: directory to search.  Defaults to the module-level
            ``root_dir`` global (assigned by the ``__main__`` section) so
            existing callers keep working.

    Returns:
        list of matching paths (empty when nothing matches).

    Raises:
        ValueError: if no ``root_dir`` is passed and no module-level
            ``root_dir`` is defined (e.g. when imported as a library).
    """
    if root_dir is None:
        # Backward compatibility: the script sets ``root_dir`` as a global.
        root_dir = globals().get("root_dir")
    if root_dir is None:
        raise ValueError(
            "get_file_list: pass root_dir or define a module-level root_dir"
        )

    keys = list(search_key) if isinstance(search_key, (list, tuple)) else [search_key]

    matches = []
    for one_search_key in keys:
        pattern = os.path.join(
            root_dir, f"task={one_search_key}*,mar={pretrain},*.jsonl"
        )
        matches.extend(glob(pattern))
    return matches
    
def load_all_files(xpath_list):
    """Read one or more JSON-Lines files and return all records in order.

    Args:
        xpath_list: a single path or a list/tuple of paths.

    Returns:
        list of decoded JSON objects, file by file and line by line.
        Blank lines (e.g. a trailing newline) are skipped.
    """
    paths = xpath_list if isinstance(xpath_list, (list, tuple)) else [xpath_list]

    example_list = []
    for one_xpath in paths:
        # ``with`` closes each file promptly; the original leaked handles.
        # JSON text is UTF-8 (RFC 8259), so do not rely on the locale.
        with open(one_xpath, encoding="utf-8") as fh:
            example_list.extend(json.loads(line) for line in fh if line.strip())
    return example_list
def convert_sing_prob_softmax_prob(y_pred: List[int], y_prob: List[float]):
    '''
    Expand (predicted label, confidence of that label) pairs into a two-column
    probability matrix: column ``pred`` gets ``prob``, the other column gets
    ``1 - prob``.

    y_pred [1,0,1]
    y_prob [0.9,0.9,0.9]

    -->
    [ [0.1,0.9],[0.9,0.1],[0.1,0.9]
    ]

    Returns:
        float ndarray of shape (len(y_pred), 2); shape (0, 2) for empty input.

    Raises:
        ValueError: if the inputs differ in length, a label is not 0/1, or a
            probability lies outside [0, 1].
    '''
    pred = np.asarray(y_pred, dtype=int)
    prob = np.asarray(y_prob, dtype=float)

    # zip() used to truncate silently on a length mismatch.
    if pred.shape != prob.shape:
        raise ValueError(
            f"y_pred and y_prob differ in shape: {pred.shape} vs {prob.shape}"
        )
    if pred.size and not np.isin(pred, (0, 1)).all():
        raise ValueError("y_pred must contain only 0/1 labels")
    # These range checks were previously bare expressions that did nothing.
    if pred.size and (prob.min() < 0 or prob.max() > 1):
        raise ValueError("y_prob values must lie in [0, 1]")

    blank = np.zeros((pred.size, 2), dtype=float)
    rows = np.arange(pred.size)
    blank[rows, pred] = prob
    blank[rows, 1 - pred] = 1 - prob
    return blank

def calc_metric(example_list):
    """Score a list of prediction records with ``metric_cal``.

    Each record must provide ``labels`` (ground truth), ``pred`` (predicted
    label) and ``prob`` (confidence of the predicted label).  Label and
    probability summaries are printed before scoring.  ``prob`` is converted
    into a positive-class score via ``convert_sing_prob_softmax_prob``
    (column 1 of its two-column output) before being passed on as ``y_prob``.
    """
    truth = np.array([rec["labels"] for rec in example_list], dtype=int)
    predicted = np.array([rec["pred"] for rec in example_list], dtype=int)
    confidence = np.array([rec["prob"] for rec in example_list], dtype=float)

    print("y_true", np.unique(truth, return_counts=True))
    print("y_label", np.unique(predicted, return_counts=True))
    print("y_prob", np.max(confidence), np.min(confidence))

    two_column = convert_sing_prob_softmax_prob(y_pred=predicted, y_prob=confidence)
    positive_score = two_column[:, 1]

    return metric_cal(y_true=truth, y_label=predicted, y_prob=positive_score)

    
if __name__ == "__main__":
    # Score every (run, dataset-group) pair from MAR_LIST / DATASET_LIST and
    # write one metrics row per pair to ./rq2_result.csv.

    # Module-level on purpose: get_file_list falls back to this global.
    root_dir = "/home/wj_cuda113/wj_code/dl_chatgpt/nlccd/finetune/save"

    final_list = []

    for pretrain in tqdm(list(MAR_LIST)):
        for dt_key in MAR_LIST[pretrain]:
            dt_pattern = get_file_list(
                pretrain=pretrain,
                search_key=DATASET_LIST[dt_key],
            )

            logger.info("start load")
            print(pretrain, dt_key, "-->", dt_pattern, "-->", len(dt_pattern))

            # With no files, calc_metric would crash on np.max of an empty
            # array and abort every remaining run; skip this pair instead.
            if not dt_pattern:
                logger.warning(
                    "no result files for pretrain=%s dataset=%s; skipping",
                    pretrain,
                    dt_key,
                )
                continue

            ex_list = load_all_files(xpath_list=dt_pattern)
            score = calc_metric(example_list=ex_list)
            score.update({
                "pretrain": pretrain,
                "dataset": dt_key,
                "file_c": len(dt_pattern),
                "sample_size": len(ex_list),
            })
            final_list.append(score)

    df = pd.DataFrame(final_list)
    # sort_values returns a new frame; the result was previously discarded,
    # so the CSV was never sorted.  An empty frame has no columns to sort by.
    if not df.empty:
        df = df.sort_values(["pretrain", "dataset"], ascending=[True, False])

    df.to_csv("./rq2_result.csv", index=False)
    
    
    
    