import argparse
import ast
import concurrent
import copy
import json
import logging
import os
import random
import sys
import traceback
from concurrent.futures import ThreadPoolExecutor

import astor
# Setup logging: configure the root logger for timestamped INFO-level output
# (a no-op if the root logger already has handlers), then create this
# script's logger, named after the script's file path.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logger = logging.getLogger(__file__)


def ast_chk(content_list):
    """Check which entries of *content_list* parse as Python source.

    Args:
        content_list: iterable of source strings. Entries that cannot be
            parsed at all (e.g. ``None`` for a mutation that failed) are
            reported as not parseable instead of raising.

    Returns:
        A list of ``(code, ok, err)`` tuples in input order: ``ok`` is True
        when ``ast.parse`` succeeded, and ``err`` is the formatted traceback
        of the failure (``None`` on success).
    """
    def _ast(content):
        try:
            ast.parse(source=content)
        except Exception:
            # SyntaxError, ValueError (NUL bytes), TypeError (non-str input),
            # RecursionError (deeply nested code), ... -- but never swallow
            # KeyboardInterrupt/SystemExit the way a bare ``except:`` did.
            return False, traceback.format_exc()
        return True, None

    return [(code, *_ast(code)) for code in content_list]

def select_one_from_list(list_item, strategy="longest"):
    """Pick a single code block out of a list of candidates.

    Candidates whose text mentions "assert" (case-insensitive) two or more
    times are discarded first; a candidate may be a string or a list of
    lines, which are joined with newlines for that check.

    Args:
        list_item: list of candidate code blocks. A non-list value, or a list
            with at most one element, is returned unchanged.
        strategy: selection strategy; only ``"longest"`` is supported.

    Returns:
        ``[longest_candidate]`` (ties go to the last one), or ``[]`` when
        every candidate was discarded by the assert filter.

    Raises:
        ValueError: if *strategy* is not supported.
    """
    if not isinstance(list_item, list) or len(list_item) <= 1:
        return list_item

    if strategy != "longest":
        raise ValueError(f"unsupported strategy: {strategy!r}")

    def _as_text(code_block):
        return "\n".join(code_block) if isinstance(code_block, list) else code_block

    candidates = [x for x in list_item if _as_text(x).lower().count("assert") < 2]
    if not candidates:
        # Previously crashed with IndexError on the empty list.
        return []

    # Scan in reverse so that, among equally long candidates, the LAST one
    # wins -- the same result as the original stable sort + ``[-1]``.
    chosen = max(reversed(candidates), key=len)
    assert isinstance(chosen, str), (type(chosen), chosen)
    return [chosen]
            



def load_jsonl(p):
    """Load a JSON Lines file into a list of decoded records.

    The file is read as UTF-8 (the encoding JSON text is required to use)
    and streamed line by line. Blank or whitespace-only lines, such as a
    trailing empty line, are skipped instead of raising
    ``json.JSONDecodeError``.
    """
    with open(p, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
    


# Mutation operator names passed to ``MutateModel(op_names=...)``.
# Names in parentheses are as recorded in the original notes (presumably the
# underlying transform -- verify against the ``mutations`` package):
#   AssginAddLine (AssginAddLine), AugAssgin2Assign, Comp2For (For2While),
#   ExprStmt2Assign, IfStmt2IfStmt (IfElse2ElseIf), WhileStmt (While2While),
#   fb_obfuscator (Obfuscator)
MT_OPERATE_RENAME_LIST = [
    # Currently disabled: "rename_var", "rename_func", "add_dead_code".
]
MT_OPERATE_EXPR_LIST = [
    "AssginAddLine",
    "IfStmt2IfStmt",
    "WhileStmt",
    "Comp2For",
    # Currently disabled: "ExprStmt2Assign", "IfExpr2Stmt", "AugAssgin2Assign".
]


def _parse_args(argv=None):
    """Parse and validate the command line (run with ``-h`` for usage)."""
    parser = argparse.ArgumentParser(
        description="Apply code mutations to the human_answer / chatgpt_answer "
                    "fields of every record in a JSONL file.")
    parser.add_argument("-path", type=str, required=True,
                        help="input JSONL file; each record needs 'id', "
                             "'human_answer' and 'chatgpt_answer'")
    parser.add_argument("-count", type=int, default=-1,
                        help="number of mutation operators to sample; "
                             "-1 (default) uses every enabled operator")
    parser.add_argument("-save_dir", type=str, required=True,
                        help="root output directory")
    parser.add_argument("-seed", type=int, default=42,
                        help="random seed used when sampling operators (default: 42)")
    args = parser.parse_args(argv)
    if not os.path.isfile(args.path):
        parser.error(f"-path is not a file: {args.path}")
    if args.count != -1 and args.count < 1:
        parser.error(f"-count must be -1 or a positive integer, got {args.count}")
    return args


def _select_operators(mix_count, seed):
    """Return the operator names to apply.

    ``-1`` selects every enabled operator; a positive count samples that many
    (capped at the number available) after seeding ``random`` with *seed*.
    """
    available = MT_OPERATE_RENAME_LIST + MT_OPERATE_EXPR_LIST
    if mix_count == -1:
        return available
    if mix_count < 1:
        raise ValueError(f"count must be -1 or a positive integer, got {mix_count}")
    random.seed(seed)
    return random.sample(available, min(mix_count, len(available)))


def _mutate_answer(processor, idx, code, new_info, field, role):
    """Mutate *code* and record the outcome under *field* in *new_info*.

    Writes ``new_info[field]`` (mutated code, or ``None`` if the processor
    raised), ``new_info[field + "_raw"]`` (the original code) and
    ``new_info[field + "_flags"]`` (whether each version parses).
    """
    try:
        code_mt = processor(code=code)
    except Exception as ex:
        logger.info("idx:%s role:%s err_msg:%s", idx, role, ex)
        code_mt = None
    (_, raw_pass, _), (_, mt_pass, _) = ast_chk(content_list=[code, code_mt])
    new_info[field] = code_mt
    new_info[field + "_raw"] = code
    new_info[field + "_flags"] = {"raw_pass": raw_pass, "mt_pass": mt_pass}


def _process_record(processor, data):
    """Return a deep copy of *data* with both answers mutated and AST-checked."""
    new_info = copy.deepcopy(data)
    idx = data["id"]
    _mutate_answer(processor, idx, data["human_answer"], new_info, "human_answer", "human")
    _mutate_answer(processor, idx, data["chatgpt_answer"], new_info, "chatgpt_answer", "chatgpt")
    return new_info


def _main():
    """Mutate every record of ``-path`` and write the result as JSONL.

    Output goes to ``<save_dir>/mix_<count>_seed<seed><ops>/<basename of -path>``;
    an existing output file is never overwritten.
    """
    args = _parse_args()

    # The ``mutations`` package lives one directory above this script.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    if project_root not in sys.path:
        sys.path.append(project_root)
    from tqdm import tqdm
    from mutations import processor as mt_producsor

    logger.info("args: %s", args)

    op_names = _select_operators(args.count, args.seed)
    x_name = f"mix_{args.count}_seed{args.seed}" + "".join(op_names)
    new_save_path = os.path.join(args.save_dir, x_name, os.path.basename(args.path))
    # Check the path we will actually write to (the old guard checked a
    # different path), and fail before doing any expensive work.
    if os.path.isfile(new_save_path):
        raise FileExistsError(f"refusing to overwrite existing output: {new_save_path}")
    os.makedirs(os.path.dirname(new_save_path), exist_ok=True)

    vides = load_jsonl(args.path)
    x_processor = mt_producsor.MutateModel(rate=0.8, op_names=op_names)

    # os.cpu_count() may return None, and on a single core cpu_count()-1 == 0
    # would make ThreadPoolExecutor raise; always keep at least one worker.
    num_workers = max(1, (os.cpu_count() or 2) - 1)
    logger.info("start workers (%d); saving to %s", num_workers, new_save_path)

    def process_file(i):
        try:
            return _process_record(x_processor, vides[i])
        except Exception:
            logger.exception("failed to process record %d", i)
            return None

    with ThreadPoolExecutor(max_workers=num_workers) as ex:
        # Executor.map yields results in input order -> deterministic output.
        predictions = list(tqdm(ex.map(process_file, range(len(vides))), total=len(vides)))

    # Failed records used to be written as ``null`` lines; drop them instead.
    records = [r for r in predictions if r is not None]
    if len(records) != len(predictions):
        logger.warning("skipped %d of %d records that failed to process",
                       len(predictions) - len(records), len(predictions))

    # "x" mode: never clobber a file that appeared while we were processing.
    with open(new_save_path, "x", encoding="utf-8") as fff:
        fff.writelines(json.dumps(x) + "\n" for x in records)


if __name__ == "__main__":
    _main()