
import astor
import ast

from tokenize import tokenize, untokenize, COMMENT, STRING, NEWLINE, ENCODING, ENDMARKER, NL, INDENT, NUMBER
import random
import copy 

import re 
import os
import sys 

# sys.path.append("./")

from . import expr as mt # import mt 


from functools import partial

from io import StringIO
import traceback 



# Names of the mutation operators ExprProcessor accepts; any other name fails
# the membership assertion in ExprProcessor.__init__.
# The trailing "check" / "error" marks are the original author's status notes
# (their exact meaning is not recorded — presumably "verified" vs. "known to
# misbehave"; confirm before relying on them). "ForStmt" is intentionally
# absent: the author commented it out with an "error" mark.
OPERATOR_LIST = [
    "Scope2Func",        # check
    "ZipAssgin",         # check
    "IfExpr2Stmt",       # error
    "WhileStmt",         # error
    "IfStmt2IfStmt",     # error
    "ExprStmt2Assign",   # check
    "AugAssgin2Assign",  # check
    "Assert2If",         # check
    "With2Try",          # check
    "Lambda2Func",       # check
    "Comp2For",          # check
    "AssginAddLine",     # check
]



def format_content(fstr):
    """Normalise line endings of ``fstr`` and guarantee a trailing newline.

    CRLF pairs and lone CR characters are both converted to LF. If the result
    does not already end with LF, one is appended, so an empty string becomes
    a single newline.
    """
    # Collapse CRLF first so its CR is not turned into a second newline.
    unix_text = "\n".join(fstr.replace("\r\n", "\n").split("\r"))
    return unix_text if unix_text.endswith("\n") else unix_text + "\n"


def do_lines_intersect(line1, line2):
    """Return True when two closed spans share at least one point.

    Only the first two items of each argument are read; they are the span's
    endpoints and may be given in either order. Spans that merely touch,
    e.g. (1, 3) and (3, 5), count as intersecting.
    """
    low_a, high_a = min(line1[0], line1[1]), max(line1[0], line1[1])
    low_b, high_b = min(line2[0], line2[1]), max(line2[0], line2[1])
    return low_a <= high_b and low_b <= high_a

def find_max_non_intersecting_lines(lines):
    """Greedily keep spans that do not overlap any span kept before them.

    Spans are visited in order of their lower endpoint; each one is kept
    only if it intersects none of the spans already kept (overlap as defined
    by ``do_lines_intersect``).

    NOTE: a second ``find_max_non_intersecting_lines`` defined later in this
    module rebinds this name at import time, so nothing in this module ever
    calls this version.

    Args:
        lines: sequence of tuples whose first two items are span endpoints
            (either order); extra items are carried through untouched.

    Returns:
        A new list of the kept tuples, ordered by lower endpoint. ``lines``
        is neither reordered nor emptied (the previous ``pop(0)`` loop drained
        the caller's list and was quadratic).
    """
    kept = []
    for current_line in sorted(lines, key=lambda x: min(x[0], x[1])):
        if not any(do_lines_intersect(current_line, line) for line in kept):
            kept.append(current_line)
    return kept

def find_max_non_intersecting_lines(lines):
    """Select a maximum-size set of pairwise non-overlapping closed spans.

    Earliest-finish greedy interval scheduling: visit spans by upper
    endpoint and keep each one that starts after the last kept span ends.
    Spans sharing an endpoint line overlap, matching ``do_lines_intersect``.

    Args:
        lines: sequence of tuples whose first two items are span endpoints,
            in either order; any further items are carried through untouched.

    Returns:
        A new list of the selected tuples, ordered by upper endpoint (ties
        keep their input order). ``lines`` itself is not modified.
    """
    selected_lines = []
    last_end = None
    # sorted() instead of list.sort(): do not reorder the caller's list.
    # Key on the larger endpoint so reversed spans are scheduled correctly.
    for line in sorted(lines, key=lambda span: max(span[0], span[1])):
        start = min(line[0], line[1])
        # The last kept span has the largest end among kept spans, so
        # clearing it is enough to clear all of them.
        if last_end is None or start > last_end:
            selected_lines.append(line)
            last_end = max(line[0], line[1])
    return selected_lines




def remove_any_interset_expr(if_Expr, op_name=None):
    """Drop nodes whose line spans overlap, keeping a non-overlapping subset.

    Args:
        if_Expr: list of nodes for one operator; each must expose ``lineno``
            and ``end_lineno`` (as AST nodes do).
        op_name: operator the nodes belong to. ``"Scope2Func"`` and
            ``"ZipAssgin"`` are returned untouched without any checks.

    Returns:
        The kept nodes in their original order. Exempt operators and lists
        of length <= 1 are returned as the very object passed in.

    Raises:
        AssertionError: if ``if_Expr`` is not a list (non-exempt operators).
    """
    if op_name in ("Scope2Func", "ZipAssgin"):
        return if_Expr  # skip
    assert isinstance(if_Expr, list), (type(if_Expr), "type.if_Expr")
    if len(if_Expr) <= 1:
        return if_Expr

    # Tag each span with its node's position so the selection maps back
    # with a set lookup instead of scanning a list of string ids.
    spans = [(node.lineno, node.end_lineno, idx) for idx, node in enumerate(if_Expr)]
    selected = {span[2] for span in find_max_non_intersecting_lines(lines=spans)}
    return [node for idx, node in enumerate(if_Expr) if idx in selected]

    

class ExprProcessor:
    """Apply a sequence of source-to-source mutation operators to Python code.

    Operators run in the order given. Before each operator the current text
    is re-parsed (``parse_tree``), so the line numbers the operator sees
    always refer to the text being edited.
    """

    def __init__(self, operator_list=None):
        """
        Args:
            operator_list: operator names to apply, in order. Every name must
                appear in the module-level ``OPERATOR_LIST``. ``None`` (the
                default) means no operators.

        Raises:
            AssertionError: if ``operator_list`` names an unknown operator.
        """
        # A fresh list per instance; the former ``[]`` default was one list
        # object shared by every instance built without an argument.
        self.operator_list = [] if operator_list is None else operator_list

        assert len(set(self.operator_list) - set(OPERATOR_LIST)) == 0, self.operator_list

    @staticmethod
    def _operator_dict(name_set):
        """Map each operator name to its ``mt`` mutator, bound to ``name_set`` where it takes one."""
        return {
            "IfExpr2Stmt": partial(mt.mutateIfExpr2Stmt),
            "Lambda2Func": partial(mt.mutateLambda2func, name_set=name_set),
            "Comp2For": partial(mt.mutateComp2For),
            "ExprStmt2Assign": partial(mt.mutateExprStmt2Assign, name_set=name_set),
            "AugAssgin2Assign": partial(mt.mutateAugAssgin2Assgin),
            "Assert2If": partial(mt.mutateAssert),
            "IfStmt2IfStmt": partial(mt.mutateIfStmt),
            "WhileStmt": partial(mt.mutateWhileStmt),
            "ForStmt": partial(mt.mutateForStmt, name_set=name_set),
            "With2Try": partial(mt.mutateWith2Try),
            "Scope2Func": partial(mt.mutateZipScope, name_set=name_set),
            "ZipAssgin": partial(mt.mutateZipAssgin),
            "AssginAddLine": partial(mt.mutateAssginAddLine),
        }

    def run(self, code_content):
        """Apply every configured operator to ``code_content``; return the new source.

        Operators with no candidate nodes are skipped. After each applied
        operator the text is passed through ``format_content``; if no
        operator applies, the input is returned unchanged.
        """
        for op_name in self.operator_list:
            # Re-parse every pass: earlier operators may have moved lines.
            nodes_dict, name_set, _scope_list = self.parse_tree(code_content=code_content)

            nodes = nodes_dict.get(op_name)
            if not nodes:  # also covers an operator missing from nodes_dict
                continue
            func_call = self._operator_dict(name_set)[op_name]
            action_list = func_call(nodes=nodes)

            # Index 0 is a placeholder so code_list[lineno] is AST line
            # `lineno`. It is None so it is dropped before joining; the
            # former "\n" placeholder was joined into the output and
            # prepended blank lines on every pass.
            code_list = [None] + code_content.split("\n")
            code_list = self.replace_mt(code_list=code_list, new_action_list=action_list)
            code_content = "\n".join(line for line in code_list if line is not None)
            code_content = format_content(fstr=code_content)

        return code_content

    @staticmethod
    def parse_tree(code_content):
        """Parse ``code_content`` and collect per-operator candidate nodes.

        Returns:
            ``(nodes_dict, name_set, scope_list)``. ``nodes_dict`` comes from
            ``mt.analyzeCode``, with each operator's nodes reduced by
            ``remove_any_interset_expr`` so they do not overlap one another.
            ``name_set`` and ``scope_list`` come from ``mt.findIdentifier``.
        """
        tree = ast.parse(code_content)
        # Return value unused; presumably annotates the tree in place — confirm in expr.py.
        mt.visit(tree, [], 0)
        scope_list, name_set = mt.findIdentifier(tree)

        nodes_dict = mt.analyzeCode(tree)
        for k in list(nodes_dict):
            nodes_dict[k] = remove_any_interset_expr(nodes_dict[k], op_name=k)
        return nodes_dict, name_set, scope_list

    @staticmethod
    def replace_mt(code_list, new_action_list):
        """Splice mutated source fragments into ``code_list`` in place.

        Args:
            code_list: source lines indexed so ``code_list[n]`` is line n
                (callers put a placeholder at index 0).
            new_action_list: actions whose first five fields are
                ``(indent, name, _, start_end, code_src_new)``. ``start_end``
                is ``[line]`` or ``[start, end]`` (inclusive), and
                ``code_src_new`` is the replacement text.

        Returns:
            The same list object. Each span's first entry receives the
            replacement with every line prefixed by ``indent`` spaces and a
            guaranteed trailing newline. The remaining entries of the span
            become ``None``, so the list length never changes.

        Raises:
            AssertionError: on an empty ``start_end``, an index outside
                ``code_list``, or a change in list length.
        """
        ori_len = len(code_list)
        for one_act in new_action_list:
            indent, _name, _, start_end, code_src_new = one_act[:5]
            assert len(start_end) > 0, (one_act, "len.start_end")

            start = start_end[0]
            end = start_end[0] if len(start_end) == 1 else start_end[1]
            # Valid indices are < len(code_list); the former `<=` let
            # index len(code_list) through to an IndexError / list growth.
            assert start < len(code_list), (len(code_list), start_end, "")
            assert end < len(code_list), (len(code_list), start_end, "")

            padding = " " * indent
            replacement = "\n".join(padding + x for x in code_src_new.split("\n"))
            if not replacement.endswith("\n"):
                replacement += "\n"

            if start != end:
                code_list[start:end + 1] = [None] * (end + 1 - start)
            code_list[start] = replacement

        assert ori_len == len(code_list), "should not change total line number, the raw %s not equal %s " % (ori_len, len(code_list))
        return code_list
    
            





#
class ExprProcessorConflict(ExprProcessor):
    """ExprProcessor variant that resolves overlaps *across* operators.

    The source is parsed once. The candidate nodes of all requested
    operators (except ``_UNFILTERED_OPS``) are pooled and reduced with
    ``find_max_non_intersecting_lines``, so no two chosen nodes share a
    source line, whichever operator owns them. The resulting actions are
    spliced into the original line list in a single ``replace_mt`` call, so
    every action's line numbers refer to the same unmodified text.

    The previous implementation could not run: it referenced undefined names
    (``raw_exp_list``, ``exp``, ``name``, bare ``mutate*`` functions,
    ``negate_if_expr``) and popped the absent ``"ForStmt"`` key. It also
    applied stale line numbers after earlier operators had changed the text.
    """

    # Operators whose nodes are never position-filtered (the same names
    # remove_any_interset_expr skips), so they cannot safely join the single
    # splice. They are applied afterwards by the sequential base runner.
    _UNFILTERED_OPS = ("Scope2Func", "ZipAssgin")

    def run(self, code_content):
        """Apply the configured operators and return the mutated source.

        Returns ``code_content`` unchanged if nothing applies.
        """
        nodes_dict, name_set, _scope_list = self.parse_tree(code_content=code_content)

        # NOTE(review): every mutator below shares this one name_set. If the
        # mutators do not record the fresh names they generate in it, two
        # operators could pick the same name — confirm in expr.py.
        operator_dict = {
            "IfExpr2Stmt": partial(mt.mutateIfExpr2Stmt),
            "Lambda2Func": partial(mt.mutateLambda2func, name_set=name_set),
            "Comp2For": partial(mt.mutateComp2For),
            "ExprStmt2Assign": partial(mt.mutateExprStmt2Assign, name_set=name_set),
            "AugAssgin2Assign": partial(mt.mutateAugAssgin2Assgin),
            "Assert2If": partial(mt.mutateAssert),
            "IfStmt2IfStmt": partial(mt.mutateIfStmt),
            "WhileStmt": partial(mt.mutateWhileStmt),
            "ForStmt": partial(mt.mutateForStmt, name_set=name_set),
            "With2Try": partial(mt.mutateWith2Try),
            "Scope2Func": partial(mt.mutateZipScope, name_set=name_set),
            "ZipAssgin": partial(mt.mutateZipAssgin),
            "AssginAddLine": partial(mt.mutateAssginAddLine),
        }

        # Distinct operators handled by the single splice, in request order.
        pooled_ops = [
            op_name for op_name in dict.fromkeys(self.operator_list)
            if op_name not in self._UNFILTERED_OPS
        ]

        # Every candidate node of a pooled operator, with the operator name.
        candidates = [
            (node.lineno, node.end_lineno, op_name, node)
            for op_name in pooled_ops
            for node in (nodes_dict.get(op_name) or [])
        ]
        # Select on (start, end, position) tuples so the choice maps back to
        # candidates by index and `candidates` keeps analyzeCode's order.
        chosen = find_max_non_intersecting_lines(
            lines=[(entry[0], entry[1], pos) for pos, entry in enumerate(candidates)]
        )
        kept_positions = {span[2] for span in chosen}

        selected = {}
        for pos, (_start, _end, op_name, node) in enumerate(candidates):
            if pos in kept_positions:
                selected.setdefault(op_name, []).append(node)

        action_list = []
        for op_name in pooled_ops:
            if op_name in selected:
                action_list += operator_dict[op_name](nodes=selected[op_name])

        if action_list:
            # None at index 0 makes code_list[lineno] address AST line
            # `lineno`, and is dropped with the spans replace_mt blanks out.
            code_list = self.replace_mt(
                code_list=[None] + code_content.split("\n"),
                new_action_list=action_list,
            )
            code_content = "\n".join(line for line in code_list if line is not None)
            code_content = format_content(fstr=code_content)

        unfiltered_ops = [op_name for op_name in self.operator_list if op_name in self._UNFILTERED_OPS]
        if unfiltered_ops:
            # Sequential base runner: re-parses before each of these operators.
            code_content = ExprProcessor(operator_list=unfiltered_ops).run(code_content)

        return code_content

    def parse_tree(self, code_content):
        """Parse ``code_content`` *without* per-operator overlap filtering.

        Overlaps are resolved across operators in ``run`` instead.

        Returns:
            ``(nodes_dict, name_set, scope_list)``. ``nodes_dict`` comes from
            ``mt.analyzeCode``; ``name_set`` and ``scope_list`` come from
            ``mt.findIdentifier``.
        """
        tree = ast.parse(code_content)
        # Return value unused; presumably annotates the tree in place — confirm in expr.py.
        mt.visit(tree, [], 0)
        scope_list, name_set = mt.findIdentifier(tree)

        nodes_dict = mt.analyzeCode(tree)
        return nodes_dict, name_set, scope_list


