import ast
import astor

 
import json 
import os 
from concurrent.futures import ThreadPoolExecutor
import concurrent 

import traceback 

def ast_chk(content_list):
    """Check whether each code string in *content_list* parses as Python.

    Args:
        content_list: iterable of source-code strings.

    Returns:
        A list of ``(code, ok, err)`` triples in input order. ``ok`` is True
        when ``ast.parse`` accepted the code; ``err`` is the formatted
        traceback (``str``) of the parse failure, or ``None`` on success.
    """
    def _ast(content):
        try:
            ast.parse(source=content)
        except Exception:
            # Parse failures surface as SyntaxError, ValueError (NUL bytes on
            # older Pythons), RecursionError/MemoryError for pathological
            # nesting, ... -- all Exception subclasses. Unlike a bare except,
            # this lets KeyboardInterrupt / SystemExit propagate.
            return False, traceback.format_exc()
        return True, None

    return [(code, *_ast(code)) for code in content_list]

def select_one_from_list(list_item, strategy="longest"):
    """Pick a single representative code block from several candidates.

    Non-list inputs and lists with at most one element are returned unchanged.
    Otherwise, candidates whose text contains "assert" (case-insensitive) two or
    more times are discarded (NOTE(review): presumably to skip test-code blocks
    -- confirm), and one of the remaining candidates is chosen by *strategy*.

    Args:
        list_item: candidate code blocks (normally a list of ``str``; a
            candidate that is itself a list of lines is joined with newlines
            for the "assert" count).
        strategy: selection rule. Only ``"longest"`` is supported: the longest
            candidate wins, and on ties the last one in input order is kept.

    Returns:
        ``list_item`` unchanged for non-list / short input, otherwise a
        one-element list holding the selected string.

    Raises:
        ValueError: if *strategy* is unknown or every candidate was filtered out.
        TypeError: if the selected candidate is not a ``str``.
    """
    if not isinstance(list_item, list) or len(list_item) <= 1:
        return list_item

    if strategy != "longest":
        raise ValueError("unknown strategy: %r" % (strategy,))

    def _has_many_asserts(code_block):
        text = "\n".join(code_block) if isinstance(code_block, list) else code_block
        return text.lower().count("assert") >= 2

    candidates = [x for x in list_item if not _has_many_asserts(x)]
    if not candidates:
        raise ValueError(
            "all %d candidates contain 'assert' at least twice" % len(list_item))

    # sorted() is stable, so [-1] keeps the *last* of equally long candidates.
    selected = sorted(candidates, key=len)[-1]
    if not isinstance(selected, str):
        raise TypeError("expected str, got %s: %r" % (type(selected), selected))
    return [selected]
            



def load_jsonl(p):
    """Load a JSON Lines file into a list of decoded objects (one per line).

    The file is read as UTF-8. Blank or whitespace-only lines, such as a
    trailing empty line, are skipped rather than making ``json.loads`` fail.
    """
    with open(p, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
    

if __name__=="__main__":
    # Command-line driver. Reads the JSONL file given by -path, cleans the
    # "chatgpt_answer" and "human_answer" of every record according to the
    # per-role options, and writes:
    #   <save_dir>/<input basename>   records where both answers survived
    #   <save_dir>_sta/<stem>.csv     per-record flag table
    #   <save_dir>_sta/<stem>.txt     overall counts (CSV format)
    import argparse
    import copy
    import sys

    parser = argparse.ArgumentParser(
        description="Clean human / ChatGPT code answers stored in a JSONL file.")
    parser.add_argument("-path", type=str, required=True,
                        help="input JSONL file")
    parser.add_argument("-filter_python", action="store_true",
                        help="keep only records whose 'id' contains 'python'")
    parser.add_argument("-human_is_extract", action="store_true",
                        help="extract code blocks from human_answer")
    parser.add_argument("-human_is_remove_comment", action="store_true",
                        help="strip comments/docstrings from human_answer code")
    parser.add_argument("-human_is_select_only_one", action="store_true",
                        help="keep a single code block of human_answer (see select_one_from_list)")
    parser.add_argument("-chatgpt_is_extract", action="store_true",
                        help="extract code blocks from chatgpt_answer")
    parser.add_argument("-chatgpt_is_remove_comment", action="store_true",
                        help="strip comments/docstrings from chatgpt_answer code")
    parser.add_argument("-chatgpt_is_select_only_one", action="store_true",
                        help="keep a single code block of chatgpt_answer (see select_one_from_list)")
    parser.add_argument("-save_dir", type=str, required=True,
                        help="output directory; statistics go to <save_dir>_sta")

    args = parser.parse_args()

    root_dir_replace = args.save_dir
    root_dir_replace_sta = args.save_dir + "_sta"
    filter_py = args.filter_python
    json_p = args.path

    config = {
        "human_answer": {
            "is_extract": args.human_is_extract,
            "is_remove_comment": args.human_is_remove_comment,
            "is_select_only_one": args.human_is_select_only_one,
        },
        "chatgpt_answer": {
            "is_extract": args.chatgpt_is_extract,
            "is_remove_comment": args.chatgpt_is_remove_comment,
            "is_select_only_one": args.chatgpt_is_select_only_one,
        },
    }
    print(args, "---config-->", config)

    # Explicit check rather than assert, which is stripped under `python -O`.
    if not os.path.isfile(json_p):
        parser.error("input file not found: %s" % json_p)

    os.makedirs(root_dir_replace, exist_ok=True)
    os.makedirs(root_dir_replace_sta, exist_ok=True)

    # An existing output file means this input was already processed; check
    # before loading the (possibly large) input.
    new_save_path = os.path.join(root_dir_replace, os.path.basename(json_p))
    if os.path.isfile(new_save_path):
        print("skip...", new_save_path)
        sys.exit()

    import clean_utils
    import pandas as pd
    from tqdm import tqdm

    # ThreadPoolExecutor requires max_workers >= 1; os.cpu_count() may return
    # None or 1.
    num_workers = max(1, (os.cpu_count() or 2) - 1)

    vides = load_jsonl(json_p)

    if filter_py:
        # NOTE(review): this filters on "id" while process_file reads "sql_id";
        # confirm both keys are meant here.
        vides = [x for x in vides if "python" in x["id"]]

    def _record_flags(new_info, role, idx, if_python_lang, pass_all, pass_main,
                      extract_c, is_error, is_delete):
        """Merge one role's diagnostic columns into new_info["flags"]."""
        new_info["flags"].update({
            role + "_ast_pass_all": int(pass_all),
            role + "_ast_pass_main": int(pass_main),
            role + "_extract_c": extract_c,
            role + "_error": int(is_error),
            role + "_is_delete": int(is_delete),
            role + "_if_python_lang": int(if_python_lang),
            "idx": idx,
        })

    def process_file(data):
        """Clean both answers of one record and attach per-role flags.

        Returns a deep copy of *data* whose "id" is overwritten with "sql_id",
        whose answer roles hold the cleaned code (or are removed when cleaning
        failed or the answer was rejected), without "prompt", and with a
        "flags" dict of diagnostics.
        """
        new_info = copy.deepcopy(data)
        new_info["id"] = new_info["sql_id"]
        idx = new_info["id"]
        lowered_idx = idx.lower()
        # NOTE(review): "mobpp" may be a typo for "mbpp" -- confirm against ids.
        if_python_lang = any(
            tag in lowered_idx for tag in ("python", "apps", "human", "mobpp"))

        new_info["flags"] = {}

        for role in ("chatgpt_answer", "human_answer"):
            role_cfg = config[role]
            # Defaults recorded when processing stops early.
            pass_all, pass_main, extract_c, is_delete = 0, -1, -1, 0
            try:
                code = data[role]

                if role_cfg["is_extract"]:
                    code_src_list = clean_utils.extract_code_from_markdown_v2(code)
                    extract_c = len(code_src_list)
                else:
                    code_src_list = [code]
                    extract_c = 0

                if role_cfg["is_remove_comment"]:
                    new_code_src_list = []
                    for cx in code_src_list:
                        try:
                            cx = clean_utils.remove_comments_and_docstrings(source=cx, lang="python")
                        except Exception:
                            # Python-mode stripping failed; retry in Java mode.
                            cx = clean_utils.remove_comments_and_docstrings(source=cx, lang="java")
                        new_code_src_list.append(cx)
                else:
                    new_code_src_list = code_src_list

                code_src_list = [x for x in new_code_src_list
                                 if x is not None and x.strip()]
                if not code_src_list:
                    raise ValueError(
                        "[%s]the code_src_list is empty, raw.len: %s, idx: %s "
                        % (role, len(code), idx))

                if if_python_lang:
                    pass_all = all(ok for _, ok, _ in ast_chk(content_list=code_src_list))

                if role_cfg["is_select_only_one"]:
                    code_src_list = select_one_from_list(list_item=code_src_list)
                    pass_main = all(ok for _, ok, _ in ast_chk(content_list=code_src_list))
                    new_info[role] = code_src_list[0]
                else:
                    new_info[role] = "\n\n".join(code_src_list)

                is_delete = clean_utils.cognise_as_ai_language_model(content=new_info[role])
                if is_delete:
                    raise ValueError("delete as a language model")
            except Exception:
                # Drop the answer. pop() rather than del: when data[role] itself
                # was missing, del would raise KeyError here and abort the run.
                new_info.pop(role, None)
                _record_flags(new_info, role, idx, if_python_lang, pass_all,
                              pass_main, extract_c, True, is_delete)
            else:
                _record_flags(new_info, role, idx, if_python_lang, pass_all,
                              pass_main, extract_c, False, is_delete)

        new_info.pop("prompt", None)
        return new_info

    print("start workers")
    # NOTE(review): the per-record work looks CPU-bound, so threads may
    # serialise on the GIL; a process pool would need a picklable top-level
    # worker instead of this closure.
    with ThreadPoolExecutor(max_workers=num_workers) as ex:
        # ex.map yields results in input order, so the outputs are
        # deterministic (as_completed ordered them by finish time).
        predictions = list(tqdm(ex.map(process_file, vides), total=len(vides)))

    df = pd.DataFrame([item["flags"] for item in predictions])
    if df.empty:
        # describe() raises on a frame without columns.
        print("no records to describe")
    else:
        print(df.describe())

    # Derive stats names from the input stem; a plain ".jsonl" -> ".csv"
    # replace left other extensions untouched, so the .txt write clobbered
    # the CSV.
    stem = os.path.splitext(os.path.basename(json_p))[0]
    new_save_path_sta = os.path.join(root_dir_replace_sta, stem + ".csv")
    df.to_csv(new_save_path_sta, index=False)

    has_human = ["human_answer" in item for item in predictions]
    has_chatgpt = ["chatgpt_answer" in item for item in predictions]
    predictions_overall = {
        "human_not_none": sum(has_human),
        "chatgpt_not_none": sum(has_chatgpt),
        "hum_chat_not_none": sum(h and c for h, c in zip(has_human, has_chatgpt)),
        "raw_size": len(predictions),
    }

    new_save_path_sta_overall = os.path.join(root_dir_replace_sta, stem + ".txt")
    pd.DataFrame([predictions_overall]).to_csv(new_save_path_sta_overall, index=False)

    kept = [item for item in predictions
            if "human_answer" in item and "chatgpt_answer" in item]
    with open(new_save_path, "w", encoding="utf-8") as fff:
        fff.write("\n".join(json.dumps(x) for x in kept))
        