import os 
import json 
import traceback 
import jmespath 
import numpy as np 

# exit()

def parse_dict_v1(xpath):
    """Parse a ``k1=v1,k2=v2,...jsonl`` file name into a metadata dict.

    The base name (directory stripped, trailing ``.jsonl`` removed) is split on
    ``,`` into ``key=value`` pairs; everything after the first ``=`` of a pair is
    the value. Derived fields are then added from the ``task`` value, which is
    expected to look like ``<name>[_..._<lang>]_<split>``:

    * ``path``  -- the original ``xpath`` argument.
    * ``split`` -- the last ``_``-separated component of ``task``.
    * ``lang``  -- the component just before ``split`` (``None`` when nothing
      but the name precedes it); forced to ``"python"`` when ``name`` is ``apps``.
    * ``name``  -- the first ``_``-separated component of ``task``.
    * ``mt``    -- defaults to ``"baseline"`` when absent.
    * ``r`` / ``role`` -- mirrored so both keys exist whenever either one does.

    ``task == "archive_stackexchange"`` is special-cased to split ``"test"``,
    no language, and the whole task as its name.

    Raises:
        KeyError: if the file name has no ``task`` field.
        ValueError: if a comma-separated segment contains no ``=``.
    """
    dict_info = {"path": xpath}

    stem = os.path.basename(xpath)
    # Strip the extension only when it is the suffix, not anywhere in the name.
    if stem.endswith(".jsonl"):
        stem = stem[: -len(".jsonl")]
    for pair in stem.split(","):
        # maxsplit=1 keeps any "=" that belongs to the value.
        key, value = pair.split("=", 1)
        dict_info[key] = value

    task = dict_info["task"]
    # Peel off only the trailing "_<split>"; str.replace() would also delete the
    # same substring elsewhere in the task (e.g. "_test" inside "_testcase").
    head, sep, split = task.rpartition("_")
    base = head if sep else task
    dict_info["split"] = split
    dict_info["lang"] = base.split("_")[-1] if "_" in base else None
    dict_info["name"] = base.split("_")[0]
    dict_info.setdefault("mt", "baseline")
    if dict_info["name"] == "apps":
        dict_info["lang"] = "python"

    if task == "archive_stackexchange":
        dict_info["split"] = "test"
        dict_info["lang"] = None
        dict_info["name"] = task

    # Mirror the short and long role keys so callers can read either one.
    if "r" not in dict_info and "role" in dict_info:
        dict_info["r"] = dict_info["role"]
    if "r" in dict_info and "role" not in dict_info:
        dict_info["role"] = dict_info["r"]

    return dict_info




def reverse_dict(meta_info  ):
    """Build a ``k1=v1,k2=v2,....jsonl`` file name from ``meta_info``.

    This produces names in the ``key=value`` format that ``parse_dict_v1`` reads.
    The ``path`` entry is left out. The ``"key=value"`` strings are sorted as
    strings, so the result does not depend on dict insertion order.
    ``meta_info`` itself is not modified.
    """
    pairs = sorted(f"{key}={value}" for key, value in meta_info.items() if key != "path")
    return ",".join(pairs) + ".jsonl"



def metric_cal(y_true, y_prob ):
    """Compute binary detection metrics for scores ``y_prob`` against ``y_true``.

    Args:
        y_true: ground-truth labels in {0, 1}; 1 is the positive class.
        y_prob: per-sample scores or labels (array-like, e.g. a list or ndarray).
            A float array is thresholded at 0.5 to obtain hard predictions; any
            other dtype (int/bool labels) is used as the prediction directly.

    Returns:
        dict with
        ``auc`` -- area under the ROC curve of ``y_prob``;
        ``fpr`` -- FP / (FP + TN) of the hard predictions;
        ``fnr`` -- FN / (FN + TP) of the hard predictions;
        ``eer`` -- equal error rate, the ROC point where FPR == 1 - TPR, found by
        root-finding on the linearly interpolated ROC curve.

    The ROC-based metrics (``auc``, ``eer``) assume both classes occur in ``y_true``.
    """
    from sklearn import metrics
    from scipy.optimize import brentq
    from scipy.interpolate import interp1d

    # Accept plain lists as well as ndarrays (the dtype check below needs an array).
    y_prob = np.asarray(y_prob)
    if "float" in str(y_prob.dtype):
        y_pred = y_prob > 0.5
    else:
        y_pred = y_prob
    y_pred = y_pred.astype(int)

    # Fixing labels=[0, 1] guarantees a 2x2 matrix, so this 4-way unpack also
    # works when y_true and y_pred contain only a single class.
    TN, FP, FN, TP = metrics.confusion_matrix(
        y_true=y_true,
        y_pred=y_pred,
        labels=[0, 1],
    ).ravel()

    FPR = FP / float(FP + TN)
    FNR = FN / float(FN + TP)

    _fpr, _tpr, _ = metrics.roc_curve(y_true, y_prob, pos_label=1)
    # Root of 1 - FPR - TPR(FPR) on [0, 1]: the operating point where FPR == FNR.
    eer = brentq(lambda x: 1. - x - interp1d(_fpr, _tpr)(x), 0., 1.)
    auc = metrics.auc(_fpr, _tpr)

    return {
        "auc": auc,
        "fpr": FPR,
        "fnr": FNR,
        "eer": eer,
    }



if __name__ == "__main__":
    import sys
    from concurrent.futures import ThreadPoolExecutor

    # Collect per-record detector predictions from one human_answer/chatgpt_answer
    # JSONL file (path given as the last CLI argument) and append them to
    # save_dir2/final_<q>.jsonl, where <q> is the "q" field of the file name.

    save_remain_dir = "/home/wj2_cuda12/wj_code/dl_chatgpt/remain_dirs"
    # Still created as before; nothing in this script writes to it any more.
    os.makedirs(save_remain_dir, exist_ok=True)

    # os.cpu_count() may return None, and ThreadPoolExecutor requires
    # max_workers >= 1 (cpu_count() - 1 is 0 on a single-core machine).
    num_workers = max(1, (os.cpu_count() or 2) - 1)

    one_f_human = sys.argv[-1]
    # Explicit checks instead of assert so they are not stripped under `python -O`.
    if "human_answer" not in one_f_human and "chatgpt_answer" not in one_f_human:
        sys.exit(f"expected a human_answer/chatgpt_answer file, got: {one_f_human}")
    if not os.path.isfile(one_f_human):
        print(f"not a file, nothing to do: {one_f_human}", file=sys.stderr)
        sys.exit(0)

    meta_info = parse_dict_v1(xpath=one_f_human)
    if "r" not in meta_info and "role" not in meta_info:
        sys.exit(f"no r=/role= field in file name: {meta_info} ({one_f_human})")
    # meta_info is copied into every output record; the source path is not part of it.
    meta_info.pop("path", None)

    with open(one_f_human, encoding="utf-8") as f:
        videos = f.readlines()

    def process_line(line):
        """Turn one JSONL line into a prediction record, or return None to skip it."""
        try:
            json_one = json.loads(line)
            idx = json_one["idx"] if "idx" in json_one else json_one["id"]
            # The return value is not needed; the call makes the line be skipped
            # (via the except below) when "task" is not a mapping.
            reverse_dict(json_one["task"])
        except Exception:
            # Best effort: malformed JSON and records without id/task are dropped.
            return None

        try:
            task = json_one["task"]
            # .get: a task without "q" just skips the scribbr special case instead
            # of aborting the whole run with a KeyError.
            task_q = task.get("q")

            label = jmespath.search("result.result", json_one)
            if label is None:
                return None

            # A non-None "result.result" means json_one["result"] is a dict.
            # Fall back to the hard label when no probability is reported.
            prob = json_one["result"].get("probability")
            if prob is None:
                prob = label

            if "r" in task or "role" in task:
                role = task["r"] if "r" in task else task["role"]
            else:
                role = meta_info["r"] if "r" in meta_info else meta_info["role"]

            if task_q == "scribbr_com_out":
                # Score is read from scores.fake, falling back to text_score;
                # records with neither are dropped.
                v1 = jmespath.search("result.raw_result.data.scores.fake", json_one)
                v2 = jmespath.search("result.raw_result.data.text_score", json_one)
                prob = v2 if v1 is None else v1
                if prob is None:
                    return None
                label = prob > 0.5
        except TypeError:
            traceback.print_exc()
            return None

        return {**meta_info, "id": idx, "gt": role, "label": label, "prob": prob,
                "role": role, "json_data": None, "status": 1}

    with ThreadPoolExecutor(max_workers=num_workers) as ex:
        predictions = list(ex.map(process_line, videos))
    predictions = [x for x in predictions if x is not None]

    out_path = "/home/wj2_cuda12/wj_code/dl_chatgpt/save_dir2/final_{}.jsonl".format(meta_info["q"])
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # One JSON object per line. Terminating each record with its own newline
    # means an empty batch appends nothing; the previous "\n".join(...) + "\n"
    # appended a lone blank line, which breaks line-by-line json.loads readers.
    with open(out_path, "a", encoding="utf-8") as f:
        f.writelines(json.dumps(x) + "\n" for x in predictions)
        
