import os
import os
from enum import Enum
from typing import List, Optional, Tuple

import sys 
import json 

from tree_sitter import Node, Parser, Tree
from tree_sitter import Language

import traceback 

# Shared library holding the compiled tree-sitter grammars; expected at
# ~/language.so. Every Language below is loaded from this one file.
so = os.path.join(os.path.expanduser("~"), "language.so")

# One Language/Parser pair per supported language. The module-level names are
# kept explicit (rather than built in a loop) so they stay importable.
GO_LANGUAGE = Language(so, "go")
GO_PARSER = Parser()
GO_PARSER.set_language(GO_LANGUAGE)

JS_LANGUAGE = Language(so, "javascript")
JS_PARSER = Parser()
JS_PARSER.set_language(JS_LANGUAGE)

PY_LANGUAGE = Language(so, "python")
PY_PARSER = Parser()
PY_PARSER.set_language(PY_LANGUAGE)

PHP_LANGUAGE = Language(so, "php")
PHP_PARSER = Parser()
PHP_PARSER.set_language(PHP_LANGUAGE)

JAVA_LANGUAGE = Language(so, "java")
JAVA_PARSER = Parser()
JAVA_PARSER.set_language(JAVA_LANGUAGE)

RUBY_LANGUAGE = Language(so, "ruby")
RUBY_PARSER = Parser()
RUBY_PARSER.set_language(RUBY_LANGUAGE)

CPP_LANGUAGE = Language(so, "cpp")
CPP_PARSER = Parser()
CPP_PARSER.set_language(CPP_LANGUAGE)

C_LANGUAGE = Language(so, "c")
C_PARSER = Parser()
C_PARSER.set_language(C_LANGUAGE)


# Language name -> parser configured for that language's grammar.
LANGUAGE_DICT = {
    "go": GO_PARSER,
    "javascript": JS_PARSER,
    "python": PY_PARSER,
    # Previously mapped to PY_PARSER, which parsed PHP source with the Python
    # grammar and left PHP_PARSER unused.
    "php": PHP_PARSER,
    "java": JAVA_PARSER,
    "ruby": RUBY_PARSER,
    "cpp": CPP_PARSER,
    "c": C_PARSER,
}

def Mccabe_score(lang, source_code  ):
    """Count the decision-point nodes in ``source_code``.

    The source is parsed with the parser registered for ``lang`` in
    ``LANGUAGE_DICT``; every syntax-tree node whose ``type`` appears in
    ``X_JUDGE_NODES[lang]`` adds 1 to the score. Nodes nested more than 100
    levels below the root are ignored.

    Args:
        lang: Language key present in both ``LANGUAGE_DICT`` and
            ``X_JUDGE_NODES``.
        source_code: Source text to score.

    Returns:
        The number of matching nodes, or 0 if parsing or traversal raised.

    Raises:
        KeyError: If ``lang`` is not a registered language (the lookups happen
            before the best-effort ``try``).
    """
    parser = LANGUAGE_DICT[lang]
    counted_types = X_JUDGE_NODES[lang]
    try:
        tree = parser.parse(source_code.encode())
        total = 0
        # Explicit stack of (node, depth) pairs instead of recursion.
        pending = [(tree.root_node, 0)]
        while pending:
            node, depth = pending.pop()
            if depth > 100:
                # Too deep: neither counted nor descended into.
                continue
            if node.type in counted_types:
                total += 1
            pending.extend((child, depth + 1) for child in node.children)
        return total
    except Exception:
        # Best effort: any failure scores the snippet as 0.
        return 0

import contextlib
import signal

@contextlib.contextmanager
def time_limit(seconds: float):
    """Abort the enclosed block with ``TimeoutException`` after ``seconds``.

    Implemented with ``SIGALRM`` and the real-time interval timer, so it is
    only usable on Unix and only from the main thread (``signal.signal``
    raises ``ValueError`` elsewhere).

    Args:
        seconds: Wall-clock budget for the ``with`` body.

    Raises:
        TimeoutException: Inside the ``with`` body, if the timer expires.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Install the handler BEFORE arming the timer: with the opposite order a
    # very short timeout could fire while the default SIGALRM action
    # (terminate the process) was still in place.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        # Disarm first so no alarm can arrive while the handler is swapped.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # Put back whatever handler the caller had; None means it was not
        # installed from Python and cannot be re-installed here.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

class TimeoutException(Exception):
    """Raised by ``time_limit`` when its SIGALRM handler fires."""



    
# Syntax-tree node types that each count as one decision point, per language.
# Mccabe_score adds 1 for every node whose ``type`` is in the relevant list.
c_judge_nodes = [
    "if_statement",
    "case_statement",
    "do_statement",
    "for_range_loop",
    "for_statement",
    "goto_statement",
    "function_declarator",
    "pointer_declarator",
    "struct_specifier",
    "preproc_elif",
    "while_statement",
    "switch_statement",
    "&&",
    "||",
]

# Same as the C list plus "class_specifier".
cpp_judge_nodes = [
    "if_statement",
    "case_statement",
    "do_statement",
    "for_range_loop",
    "for_statement",
    "goto_statement",
    "function_declarator",
    "pointer_declarator",
    "class_specifier",
    "struct_specifier",
    "preproc_elif",
    "while_statement",
    "switch_statement",
    "&&",
    "||",
]

# "for_statement" used to be listed twice; the list is only used for
# membership tests, so the duplicate had no effect and was dropped.
go_judge_nodes = [
    "if_statement",
    "for_statement",
    "function_declaration",
    "expression_case",
    "&&",
    "||",
]

js_judge_nodes = [
    "if_statement",
    "while_statement",
    "for_statement",
    "catch_clause",
    "with_statement",
    "function_declaration",
    "&&",
    "||",
]

py_judge_nodes = [
    "if_statement",
    "elif_clause",
    "while_statement",
    "for_statement",
    "except_clause",
    "boolean_operator",
    "with_statement",
    "assert_statement",
    "list_comprehension",
    "function_definition",
]

# NOTE(review): Java, Ruby and PHP reuse the Python list (same list object).
# Their grammars presumably name several of these constructs differently, so
# some entries may never match - TODO confirm against each grammar's node
# types before relying on these scores.
java_judge_nodes = py_judge_nodes
ruby_judge_nodes = py_judge_nodes
php_judge_nodes = py_judge_nodes

# Language name -> list of counted node types; keys mirror LANGUAGE_DICT.
X_JUDGE_NODES = {
    "go": go_judge_nodes,
    "javascript": js_judge_nodes,
    "python": py_judge_nodes,
    "php": php_judge_nodes,
    "java": java_judge_nodes,
    "ruby": ruby_judge_nodes,
    "cpp": cpp_judge_nodes,
    "c": c_judge_nodes,
}

def parse_dict_v1(xpath):
    """Parse run metadata encoded in a result file name.

    The base name is expected to look like ``key=value,key=value.jsonl`` and
    must contain a ``task`` key shaped like ``<name>[_<lang>]_<split>``.

    Args:
        xpath: Path (or bare file name) of the result file.

    Returns:
        A dict holding every ``key=value`` pair from the name plus:
        ``path`` (the argument as given), ``split``, ``lang`` (``None`` when
        the task carries no language part), ``name``, ``mt`` (defaults to
        ``"baseline"``), and ``r``/``role`` mirrored onto each other when only
        one of them is present.

    Raises:
        KeyError: If the file name has no ``task=...`` segment.
    """
    dict_info = {"path": xpath}
    stem = os.path.basename(xpath).replace(".jsonl", "")

    for segment in stem.split(","):
        if "=" not in segment:
            # A segment without "=" used to abort the whole parse with a
            # ValueError on unpacking; ignore it instead.
            continue
        key, value = segment.split("=")[:2]
        dict_info[key] = value

    task = dict_info["task"]
    split = task.split("_")[-1]
    dict_info["split"] = split
    task = task.replace("_" + split, "")
    dict_info["lang"] = task.split("_")[-1] if "_" in task else None
    dict_info["name"] = task.split("_")[0]

    if "mt" not in dict_info:
        dict_info["mt"] = "baseline"

    # "name" is the text before the first underscore, so it can never equal
    # "apps_test"; comparing against "apps" covers both task spellings.
    if dict_info["name"] == "apps":
        dict_info["lang"] = "python"

    # This task name does not follow the <name>_<lang>_<split> layout, so its
    # fields are set explicitly.
    if dict_info["task"] in ("archive_stackexchange", "archive_stackexchange_test"):
        dict_info["split"] = "test"
        dict_info["lang"] = None
        dict_info["name"] = dict_info["task"]

    # Keep the short and long spellings of the role key in sync.
    if "r" not in dict_info and "role" in dict_info:
        dict_info["r"] = dict_info["role"]
    if "r" in dict_info and "role" not in dict_info:
        dict_info["role"] = dict_info["r"]

    return dict_info



import tiktoken
# Tokenizer shared by num_tokens_from_string; created once at import time so
# each call reuses the same "cl100k_base" encoding object.
ENCODING = tiktoken.get_encoding("cl100k_base")

def num_tokens_from_string(string: str ) -> int:
    """Return the number of tokens ``ENCODING`` produces for ``string``.

    Args:
        string: Text to tokenize.

    Returns:
        The token count.
    """
    # With default arguments encode() raises ValueError when the text contains
    # a special-token literal such as "<|endoftext|>". Passing
    # disallowed_special=() makes such literals count as ordinary text, so
    # arbitrary source code can be measured.
    return len(ENCODING.encode(string, disallowed_special=()))



if __name__=="__main__":
    # Script entry point: read the .jsonl file named by the last command-line
    # argument, attach complexity scores and token counts to every record,
    # and write the surviving records to a same-named file under save_dir.
    from concurrent.futures import ThreadPoolExecutor

    # Leave one core free, but never ask for fewer than one worker
    # (cpu_count() may return None or 1).
    num_workers = max(1, (os.cpu_count() or 2) - 1)

    p = sys.argv[-1]
    # Explicit checks instead of assert, which is stripped under "python -O".
    if not os.path.isfile(p):
        raise SystemExit(f"input file not found: {p}")
    if not p.endswith(".jsonl"):
        raise SystemExit(f"expected a .jsonl input file, got: {p}")

    with open(p, encoding="utf-8") as f:
        lines = f.readlines()

    raw_len = len(lines)

    print ("total readline" , raw_len )
    xpath = os.path.basename(p)

    # Metadata encoded in the input file name, parsed once for all records.
    # It stays empty when the name carries no key=value pairs or cannot be
    # parsed; records that need it are then dropped, as before.
    file_info = {}
    if "=" in xpath:
        try:
            file_info = parse_dict_v1(xpath)
        except Exception:
            traceback.print_exc()

    def process_file(i, line):
        """Score one JSONL line; return the enriched record or None to drop it."""
        if i % 500 == 0:
            print (i)
        if not line.strip():
            # Blank lines (e.g. a trailing newline) carry no record.
            return None

        try:
            data = json.loads(line)

            if "human_answer" in data:
                # Paired record: the language comes from the input file name.
                human_answer_content = data.pop("human_answer")
                chatgpt_answer_content = data.pop("chatgpt_answer")

                lang = file_info.get("lang")
                if lang is None:
                    return None

                data["human_complexity"] = Mccabe_score(lang=lang, source_code=human_answer_content)
                data["chatgpt_complexity"] = Mccabe_score(lang=lang, source_code=chatgpt_answer_content)

                data["human_len"] = num_tokens_from_string(string=human_answer_content)
                data["chatgpt_len"] = num_tokens_from_string(string=chatgpt_answer_content)

            elif "xpath" in data and "content" in data:
                # Single-answer record: the language comes from the record's
                # own "xpath" field.
                content = data.pop("content")

                lang = parse_dict_v1(data["xpath"])["lang"]
                if lang is None:
                    return None

                data["content_complexity"] = Mccabe_score(lang=lang, source_code=content)
                data["content_len"] = num_tokens_from_string(string=content)

            # Records matching neither shape pass through unchanged.
            return data
        except Exception:
            # One bad record must not abort the run: report it and drop it.
            traceback.print_exc()
            return None

    with ThreadPoolExecutor(max_workers=num_workers) as ex:
        # No timeout is passed to map(): it is measured from the map() call
        # and would bound the whole run rather than each record.
        predictions = list(ex.map(process_file, range(len(lines)), lines))

    predictions = [x for x in predictions if x is not None]
    notnone_len = len(predictions)

    # NOTE(review): machine-specific output directory; it must already exist.
    save_dir ="/home/wj2_cuda12/wj_code/dl_chatgpt/tosem_data/fse2023_dataset/complexity_score_v2"
    print ( "raw", raw_len , "not none ", notnone_len)
    with open(os.path.join(save_dir, os.path.basename(p)), "w") as f:
        f.write("\n".join(json.dumps(x) for x in predictions))

    
