from __future__ import absolute_import, division, print_function

import argparse
import logging
logger = logging.getLogger(__name__)
logger.setLevel( logging.INFO )
import os
import random
import torch
import json
#
from sklearn.model_selection import train_test_split 

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import numpy as np
from datasets import Dataset, concatenate_datasets
import evaluate
import pandas as pd
import torch
from transformers import (
    HfArgumentParser, 
    RobertaForSequenceClassification, RobertaTokenizer, RobertaConfig,
    DataCollatorWithPadding,
    Trainer, TrainingArguments
)

from transformers import pipeline
from glob2 import glob 

from datetime import datetime 



os.environ["WANDB_DISABLED"] ="true"

def list_field(default=None, metadata=None):
    """Build a dataclass field whose default factory hands back ``default``.

    The very same ``default`` object is returned by every factory call (it is
    not copied).  ``metadata`` is forwarded to ``dataclasses.field`` as is.
    """
    def _produce_default():
        return default

    return field(metadata=metadata, default_factory=_produce_default)


def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports the seed as ``PYTHONHASHSEED`` and switches cuDNN to its
    deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # The variable name used to be misspelled ('PYHTONHASHSEED'), so nothing
    # ever read it.  Python only consults PYTHONHASHSEED at interpreter
    # start-up, so this affects child processes, not the running one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True




def read_train_test(file_path,split="train",args=None ):
    """Load JSON-lines answer files into one labelled ``Dataset``.

    Every non-blank line of every file is parsed as a JSON object that must
    carry the keys ``id``, ``human_answer`` and ``chatgpt_answer``.  Each
    non-empty human answer becomes a row with ``labels=0`` and each non-empty
    ChatGPT answer a row with ``labels=1``.  An answer that is not a string
    (e.g. a list of answers) contributes only its first element.

    Args:
        file_path: A single path, or a list of paths, to ``.jsonl`` files.
        split: Unused; kept so existing callers keep working.
        args: Unused; kept so existing callers keep working.

    Returns:
        A ``Dataset`` holding the human rows followed by the ChatGPT rows,
        with the columns ``answer``, ``labels`` and ``idx``.
    """
    if isinstance(file_path, (str, os.PathLike)):
        file_path = [file_path]

    raw_lines = []
    for one_file_path in file_path:
        with open(one_file_path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                # A blank line (e.g. a trailing newline) is not a JSON
                # document and would make json.loads raise.
                if line:
                    raw_lines.append(line)

    print ("finish initial", len(raw_lines) )
    logger.info("the final examples size {} ".format(len(raw_lines)))
    examples = [json.loads(x) for x in raw_lines]

    def save_pick_up(item_str):
        """Return a string answer as is, else the first element (None if empty)."""
        if isinstance(item_str, str):
            return item_str
        # An empty container used to raise IndexError here, before the
        # "drop empty answers" filter below ever had a chance to run.
        return item_str[0] if item_str else None

    def collect(answer_key, label):
        """Build the rows for one answer source, dropping missing/empty answers."""
        rows = []
        for item in examples:
            if item[answer_key] is None:
                continue
            answer = save_pick_up(item[answer_key])
            if answer is not None and len(answer) > 0:
                rows.append({"answer": answer, "labels": label, "idx": item["id"]})
        return rows

    human_examples = collect("human_answer", 0)
    chatgpt_examples = collect("chatgpt_answer", 1)

    human_dataset = Dataset.from_pandas(pd.DataFrame(human_examples))
    chatgpt_dataset = Dataset.from_pandas(pd.DataFrame(chatgpt_examples))
    return concatenate_datasets([human_dataset, chatgpt_dataset])

def metric_cal(y_true, y_prob=None ,y_label =None  ):
    """Compute binary detection metrics and the size of each class.

    Args:
        y_true: Ground-truth labels (0 or 1).
        y_prob: Scores for class 1, used for the ROC curve.  When ``y_label``
            is omitted, float scores are thresholded at 0.5 to obtain hard
            predictions and non-float values are used as predictions directly.
        y_label: Hard predictions.  When ``y_prob`` is omitted these are also
            used as the ROC scores.

    Returns:
        A dict with the keys
        ``auc``  - area under the ROC curve,
        ``fpr``  - FP / (FP + TN) of the hard predictions,
        ``fnr``  - FN / (FN + TP) of the hard predictions,
        ``eer``  - the ROC point ``x`` where ``1 - x == tpr(x)``,
        ``chatgpt_size`` - sum of ``y_true``,
        ``human_size``   - ``len(y_true)`` minus that sum.
        A rate whose denominator is zero is NaN; ``auc`` and ``eer`` are NaN
        when ``y_true`` does not contain both classes.

    Raises:
        ValueError: If neither ``y_prob`` nor ``y_label`` is given.
    """
    # Fail fast with a clear message; this used to surface as an
    # AttributeError on ``None.dtype``.
    if y_prob is None and y_label is None:
        raise ValueError("metric_cal needs at least one of y_prob or y_label")

    from sklearn import metrics
    from scipy.optimize import brentq
    from scipy.interpolate import interp1d

    y_true = np.asarray(y_true)
    if y_prob is not None:
        y_prob = np.asarray(y_prob)

    if y_label is None :
        y_label = y_prob>0.5 if "float" in str(y_prob.dtype)  else y_prob
        y_label = y_label.astype(int)

    if y_prob is None :
        y_prob = y_label

    chatgpt_len = np.sum(y_true)
    human_len = len(y_true) - chatgpt_len

    # labels=[0, 1] pins the matrix to 2x2, so the 4-way unpack still works
    # when only one class shows up in y_true / y_label.
    TN, FP, FN, TP = metrics.confusion_matrix(
        y_true=y_true,
        y_pred=y_label,
        labels=[0, 1],
    ).ravel()

    def _rate(numerator, denominator):
        """Return numerator / denominator, or NaN for an empty denominator."""
        if not denominator:
            return float("nan")
        return float(numerator) / float(denominator)

    FPR = _rate(FP, FP + TN)
    FNR = _rate(FN, FN + TP)

    if chatgpt_len == 0 or human_len == 0:
        # The ROC curve is undefined with a single class in y_true.
        auc = eer = float("nan")
    else:
        _fpr, _tpr, _ = metrics.roc_curve(y_true, y_prob, pos_label=1)
        auc = metrics.auc(_fpr, _tpr)
        # Build the interpolator once instead of once per brentq probe.
        roc = interp1d(_fpr, _tpr)
        eer = brentq(lambda x: 1. - x - roc(x), 0., 1.)

    return {
        "auc":float(auc),
        "fpr":float(FPR),
        "fnr":float(FNR),
        "eer":float(eer),
        "chatgpt_size":int(chatgpt_len),
        "human_size":int(human_len),
        }


def convert_sing_prob_softmax_prob(y_pred:List[int],y_prob:List[float]) :
    '''
    Expand (predicted class, confidence) pairs into two-column probabilities.

    Row ``i`` gets ``y_prob[i]`` in column ``y_pred[i]`` and
    ``1 - y_prob[i]`` in the other column.

    y_pred [1,0,1]
    y_prob [0.9,0.9,0.9]

    -->
    [ [0.1,0.9],[0.9,0.1],[0.1,0.9]
    ]

    Empty input yields an array of shape (0, 2).

    Raises:
        ValueError: if the two inputs differ in length, a predicted class is
            not 0 or 1, or a probability is NaN or outside [0, 1].
    '''
    pred = np.asarray(y_pred, dtype=int).reshape(-1)
    prob = np.asarray(y_prob, dtype=float).reshape(-1)

    if pred.shape != prob.shape:
        raise ValueError(
            f"y_pred and y_prob differ in length: {pred.size} vs {prob.size}"
        )
    if pred.size and not ((pred == 0) | (pred == 1)).all():
        raise ValueError("y_pred must only contain the classes 0 and 1")
    # The original ended with two bare comparisons (``np.max(blank)<=1`` and
    # ``np.min(blank)>=0``) whose results were thrown away, and np.max raised
    # on empty input.  Enforce the intended range check for real.
    if prob.size and not (prob.min() >= 0 and prob.max() <= 1):
        raise ValueError("y_prob values must lie in [0, 1]")

    blank = np.zeros((pred.size, 2), dtype=float)
    rows = np.arange(pred.size)
    blank[rows, pred] = prob
    blank[rows, 1 - pred] = 1 - prob
    return blank

# dim_2_np = convert_sing_prob_softmax_prob(y_pred=y_pred, y_prob=y_prob )
# dim_2_th = torch.from_numpy(dim_2_np)    
# dim_2_th_argmax = torch.softmax(dim_2_th,dim=1)
# dim_2_th_argmax = dim_2_th_argmax.argmax(1)
def parse_dict_v1(xpath):
    """Parse ``key=value`` pairs out of a file name.

    Only the base name of ``xpath`` is used and every ``.jsonl`` occurrence is
    removed from it.  The remainder is split on commas; each segment
    contributes ``key -> value``, where ``value`` is the text between the
    first and the second ``=`` (anything after a second ``=`` is ignored).
    Segments without ``=`` are skipped, so a plain name such as
    ``test.jsonl`` yields an empty dict.

    Args:
        xpath: Path or file name, e.g. ``/data/mt=gpt,temp=0.7.jsonl``.

    Returns:
        Dict mapping each key to its (string) value.
    """
    name = os.path.basename(xpath).replace(".jsonl", "")
    dict_info = {}
    for segment in name.split(","):
        pieces = segment.split("=")
        if len(pieces) < 2:
            # No '=' in this segment; unpacking used to raise ValueError here.
            continue
        dict_info[pieces[0]] = pieces[1]
    return dict_info

def dict_to_str(xpath_info):
    """Join a mapping into ``key=value`` pairs separated by commas.

    The pairs are emitted in descending order of the mapping's items.
    """
    ordered_pairs = sorted(xpath_info.items(), reverse=True)
    return ",".join("{}={}".format(key, value) for key, value in ordered_pairs)



if __name__=="__main__":
    # Script entry point: resolve the evaluation files, run the text
    # classifier stored in --output_dir over each of them and write the
    # metrics plus the per-example predictions into --save_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument("--root-dir", default=None,  type=str, required=True,
                        help="Directory that relative --test_data_file entries and patterns are resolved against.")
    parser.add_argument("--test_data_file", default=None, action="append", required=True,
                        help="Evaluation .jsonl file, or a pattern containing '*' searched under --root-dir; may be repeated.")

    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="Directory of the trained model that is loaded for prediction.")

    parser.add_argument("--cur_test_file", default=None, type=str, required=False,
                        help="Overwritten with the file currently being evaluated; stored in the saved metadata.")

    parser.add_argument("--save_dir", default="save", type=str, required=False,
                        help="Directory where the metric and per-example prediction files are written.")

    parser.add_argument("--batch_size", default=64, type=int, required=False,
                        help="Number of examples handed to the classifier per mapping batch.")
    args = parser.parse_args()


    def parse_eval_list():
        """Expand ``args.test_data_file`` into concrete paths (stored back on ``args``)."""
        eval_path_list = []

        for eval_path in args.test_data_file:
            if os.path.isfile(eval_path) :
                eval_path_list.append(eval_path)
                continue
            # Was ``os.path.eval_path(...)``, which does not exist and raised
            # AttributeError for every file given relative to --root-dir.
            search_path = os.path.join(args.root_dir, eval_path)
            if os.path.isfile(search_path) :
                eval_path_list.append(search_path)
                continue

            if "*" in eval_path:
                search_path_list = glob( search_path )
                logger.info(f"try search from {search_path}, get total {len(search_path_list)} ")
                print (f"try search from {search_path}, get total {len(search_path_list)} " )
                eval_path_list.extend(search_path_list )
                continue

            # Previously such entries vanished without a trace.
            logger.warning(f"skip {eval_path}: neither an existing file nor a '*' pattern")

        args.test_data_file =  eval_path_list
        return eval_path_list

    logger.info("parse eval test list")
    parse_eval_list()

    # Built on first use and then shared by all files: the model directory is
    # the same for every file, and nothing is loaded when all results exist.
    detector = None

    def func(batch):
        """Add the predicted class id and its score to one batch of rows."""
        out = detector(batch['answer'], max_length=512, truncation=True)
        # The class id is taken from the last character of the label string
        # (presumably labels look like "LABEL_1" - verify against the model config).
        batch['pred'] = [int(o['label'][-1]) for o in out]
        batch['prob'] = [ o['score'] for o in out]
        return batch

    for one_eval_path in args.test_data_file :
        print (f"start , {one_eval_path}")
        args.cur_test_file = one_eval_path
        ##########
        logger.info("pare save path")
        # The result file name is derived from the key=value pairs in the
        # evaluation file name, minus the keys dropped below.
        path_dict_info = parse_dict_v1( one_eval_path  )
        if "mt" not in path_dict_info:
            path_dict_info["mt"]="baseline"
        path_dict_info["mar"]= os.path.basename(  os.path.dirname(args.output_dir) )
        path_dict_info.pop("temp",None)
        path_dict_info.pop("topp",None)
        path_dict_info.pop("formated",None)
        path_str = dict_to_str(path_dict_info)

        save_dir = args.save_dir
        # NOTE: the file carries a ".csv" suffix but its content is JSON; the
        # name is kept because it doubles as the "already done" marker.
        save_path  =os.path.join(save_dir, f"./{path_str}.csv" )
        os.makedirs(save_dir , exist_ok=True )
        logger.info (f"will save {save_path}")

        if os.path.isfile(save_path):
            logger.info(f"exist {save_path}")
            print ("========="*8 , "exist", "\n\n\n", save_path )
            continue

        ###########
        if detector is None:
            # Explicit check instead of ``assert``, which is stripped under -O.
            if not os.path.isdir(args.output_dir):
                raise NotADirectoryError(f"model directory not found: {args.output_dir}")
            # Use the first GPU when there is one, otherwise run on the CPU
            # (the device used to be hard-coded to cuda:0).
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            logger.info ('build pipeline...')
            # A second, never-used pipeline and tokenizer used to be created
            # here as well, loading the model twice per file.
            detector = pipeline('text-classification', model=args.output_dir, device=device, framework='pt')

        logger.info ('dataset mapping ...')
        ##############
        test_dataset = read_train_test(file_path=one_eval_path  ,split="valid",args=args )
        dataset = test_dataset.map(func, batched=True, batch_size=args.batch_size, desc='test')

        y_true =np.array( [item["labels"] for item in dataset ] )
        y_pred =np.array( [item["pred"] for item in dataset ] )
        y_prob =np.array( [item["prob"] for item in dataset ] )

        logger.info(f" the uniq y_true {np.unique(y_true,return_counts=True)} ")
        print (f" the uniq y_true {np.unique(y_true,return_counts=True)} ")
        logger.info(f" the uniq y_pred {np.unique(y_pred,return_counts=True)} ")
        print(f" the uniq y_pred {np.unique(y_pred,return_counts=True)} ")

        # The pipeline reports the score of the predicted class only; turn it
        # into a score for class 1 (last column) for the ROC-based metrics.
        x_cnt  =convert_sing_prob_softmax_prob (y_pred=y_pred, y_prob=y_prob )
        y_pred2 = x_cnt[:,-1]

        metric_value = metric_cal(y_true=y_true, y_prob=y_pred2 )
        metric_value["meta"]= vars(args)

        with open(save_path ,"w" ) as f :
            json.dump(obj= metric_value,fp=f,  indent=4 )

        # Per-example predictions go next to the metrics, one JSON object per line.
        save_data_path = save_path.replace(".csv","_data.jsonl")
        kept_columns = ("labels", "pred", "prob", "idx")
        new_dataset_list = [
            {k: v for k, v in item.items() if k in kept_columns}
            for item in dataset
        ]

        with open(save_data_path,"w") as f :
            f.write( "\n".join( [json.dumps(x) for x in new_dataset_list ] ) )





