import os
from tqdm import tqdm 
import json 

from concurrent.futures import ThreadPoolExecutor
import concurrent 
import traceback 
import numpy as np 

import pickle 
import copy 

def load_jsonl(xpath, with_enurate_id=False):
    """Load a JSON Lines file.

    Args:
        xpath: path to a UTF-8 encoded file containing one JSON value per line.
        with_enurate_id: (name kept as-is for callers) when True, return
            ``(record, line_index)`` tuples, where ``line_index`` is the
            0-based physical line number in the file.

    Returns:
        A list of parsed records, or of ``(record, line_index)`` tuples.

    Blank (whitespace-only) lines are skipped instead of crashing
    ``json.loads``. Line indices still count them, so ids stay aligned with
    the physical lines of the file.
    """
    records = []
    # JSON Lines is UTF-8 by definition; don't depend on the platform locale.
    with open(xpath, encoding="utf-8") as f:
        for idx, line in enumerate(f):  # stream lines instead of readlines()
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            records.append((record, idx) if with_enurate_id else record)
    return records


def map_raw_sha_to_sql_id(dict_item, language, split, enumerate_id):
    """Attach a ``sql_id`` key that identifies this record in the ChatGPT dump.

    The id has the form
    ``summary_in_text_train/<language>/<split>/<enumerate_id>.input``.

    Args:
        dict_item: the record dict; it is mutated in place.
        language: dataset language directory name (e.g. ``"python"``).
        split: dataset split name (e.g. ``"train"``).
        enumerate_id: position of the record within its split file.

    Returns:
        The same ``dict_item`` object, now carrying ``"sql_id"``. Pass a copy
        if the caller needs the original dict unchanged.
    """
    dict_item["sql_id"] = f"summary_in_text_train/{language}/{split}/{enumerate_id}.input"
    return dict_item
    
def load_chatgpt_data(x_path):
    """Read a JSON Lines file and index its records by their ``"id"`` field.

    If several records share an id, the one appearing last in the file wins.
    """
    return {record["id"]: record for record in load_jsonl(x_path)}
    
def _cache_is_fresh(cache_path, source_paths):
    """Return True if ``cache_path`` exists and no existing source file is newer."""
    if not os.path.isfile(cache_path):
        return False
    cache_mtime = os.path.getmtime(cache_path)
    # Missing sources are ignored so the cache still works on machines where
    # the raw dumps are absent (the old "cache exists -> use it" behaviour).
    return all(
        os.path.getmtime(path) <= cache_mtime
        for path in source_paths
        if os.path.isfile(path)
    )


def _load_sql_dict(cache_path, source_paths):
    """Return the merged ``{id: record}`` mapping built from the ChatGPT dumps.

    The mapping is cached as a pickle at ``cache_path``. The cache is reused
    only while it is at least as new as every existing source file, so editing
    a source no longer leaves a stale cache in place. Sources are merged in
    order, so a later file wins on duplicate ids.
    """
    if _cache_is_fresh(cache_path, source_paths):
        # NOTE(security): pickle.load can execute arbitrary code. /tmp is
        # world-writable, so on a shared host another user could plant this
        # file; consider a private cache directory.
        with open(cache_path, "rb") as ef:
            return pickle.load(file=ef)

    sql_dict = {}
    for number, path in enumerate(source_paths, start=1):
        print(f"build dict.{number} ")
        sql_dict.update(load_chatgpt_data(path))

    # Write to a temp file and rename it into place, so an interrupted dump
    # never leaves a truncated pickle that the next run would try to load.
    tmp_path = f"{cache_path}.{os.getpid()}.tmp"
    with open(tmp_path, "wb") as ef:
        pickle.dump(obj=sql_dict, file=ef)
    os.replace(tmp_path, cache_path)
    return sql_dict


def _normalize_func_name(func_name):
    """Drop '=' characters and surrounding whitespace; keep the last dotted part."""
    return func_name.replace("=", "").strip().split(".")[-1]


def _names_overlap(func_name, answer):
    """Case-insensitive check that either string contains the other.

    NOTE(review): an empty string is contained in every string, so an empty
    ``answer`` or ``func_name`` always counts as a match. Confirm this is
    intended.
    """
    func_lower, answer_lower = func_name.lower(), answer.lower()
    return answer_lower in func_lower or func_lower in answer_lower


def _classify_item(item, enumerate_id, language, split, sql_dict):
    """Tag one CodeXGLUE record with its sql_id and check it against sql_dict.

    Returns ``(flag, record)``. ``record`` is a copy of ``item`` with a
    ``sql_id`` key added. ``flag`` is one of:
      * ``1``  - the sql_id is in ``sql_dict`` and the function names overlap,
      * ``-2`` - the sql_id is in ``sql_dict`` but the names do not overlap,
      * ``-1`` - the sql_id is not in ``sql_dict``.
    """
    # A shallow copy is enough: map_raw_sha_to_sql_id only adds a top-level key.
    record = map_raw_sha_to_sql_id(
        dict_item=dict(item), language=language, split=split, enumerate_id=enumerate_id
    )
    sql_id = record["sql_id"]
    if sql_id not in sql_dict:
        return -1, record
    func_name = _normalize_func_name(record["func_name"])
    answer = sql_dict[sql_id]["human_answer"]
    return (1 if _names_overlap(func_name, answer) else -2), record


if __name__ == "__main__":
    from itertools import product

    root_dir = "/mnt/nvme/code_datasets/CodeXGLUE/Code-Text/code-to-text/dataset"
    root_dir_final_save = "/mnt/nvme/code_datasets/CodeXGLUE/Code-Text/code-to-text/dataset_sqlid"
    os.makedirs(root_dir_final_save, exist_ok=True)
    languages = ["python", "java", "javascript", "ruby", "go", "php"]
    splits = ["train", "valid", "test"]

    chatgpt_save_path = "/data3/icse_dataset/NL-CCD_dirs/raw/retrain_processed/summary_text-to-code_test.jsonl"
    chatgpt_save_path2 = "/data3/icse_dataset/NL-CCD_dirs/raw/retrain_processed/summary_text-to-code_train.jsonl"

    sql_dict = _load_sql_dict(
        "/tmp/load_chatgpt_data.pkl", [chatgpt_save_path, chatgpt_save_path2]
    )

    for language, split in product(languages, splits):
        demand_file = os.path.join(root_dir, language, split + ".jsonl")
        print("load queue ")
        demand_data_list = load_jsonl(xpath=demand_file, with_enurate_id=True)
        print(f"load total {len(demand_data_list)} queue ")

        # The per-record work is pure-Python CPU work that the GIL serializes
        # anyway. A plain ordered loop yields the same results as a thread
        # pool, without the scheduling overhead or a closure over loop variables.
        results = [
            _classify_item(item, enumerate_id, language, split, sql_dict)
            for item, enumerate_id in demand_data_list
        ]
        total_predictions = len(results)
        predictions_flg = [flag for flag, _ in results]
        predictions = [record for _, record in results]

        print("===>", language, split, np.unique(predictions_flg, return_counts=True), "loss.predictions ", "total.predictions", total_predictions, "total.raw", len(demand_data_list))

        out_path = os.path.join(root_dir_final_save, f"summary_in_text_train__mapping_sqlid__{language}__{split}.jsonl")
        with open(out_path, "w", encoding="utf-8") as fff:
            fff.write("\n".join(json.dumps(x) for x in predictions))
            


