
import os 
import json 
from contextlib import contextmanager

import pika
from pika.exchange_type import ExchangeType

import logging

# Module-level logger. No handler is configured at import time; the
# commented-out basicConfig below is a template for enabling output during
# ad-hoc runs.
logger = logging.getLogger(__name__)
# logging.basicConfig(
#      # filename='log_file_name.log',
#      level=logging.INFO, 
#      format= '[%(asctime)s] {%(name)s:%(lineno)d} %(levelname)s - %(message)s',
#      datefmt='%H:%M:%S'
#  )


@contextmanager
def context_init_connect_with_priority(connect_str, exc_name, queue_name, routing_key=None,
                                       heartbeat=600, max_priority=10):
    """Open a RabbitMQ channel with a direct exchange and priority queue(s) bound to it.

    Declares ``exc_name`` as a durable ``direct`` exchange, declares every queue in
    ``queue_name`` as durable (non auto-delete) with the ``x-max-priority``
    argument, binds each queue to the exchange with ``routing_key``, and yields
    the channel. The connection is closed when the ``with`` block exits, and
    also when any of the declarations fail.

    Args:
        connect_str: AMQP URL passed to ``pika.connection.URLParameters``.
        exc_name: Name of the exchange to declare and bind to.
        queue_name: A single queue name (``str``) or an iterable of queue names.
        routing_key: Binding key used for every queue.
        heartbeat: Connection heartbeat in seconds (default keeps the old value, 600).
        max_priority: Value of the ``x-max-priority`` queue argument
            (default keeps the old value, 10).

    Yields:
        The channel opened on the new ``pika.BlockingConnection``.
    """
    parameters = pika.connection.URLParameters(url=connect_str)
    parameters.heartbeat = heartbeat
    # Log only host/port/vhost: connect_str embeds the broker credentials.
    logger.info("Connecting to AMQP broker %s:%s vhost=%s",
                parameters.host, parameters.port, parameters.virtual_host)
    connection = pika.BlockingConnection(parameters)
    try:
        # Declarations live inside the try so a failure still closes the connection.
        main_channel = connection.channel()
        main_channel.exchange_declare(exchange=exc_name, durable=True,
                                      exchange_type=ExchangeType.direct)

        queue_names = [queue_name] if isinstance(queue_name, str) else queue_name
        for name in queue_names:
            main_channel.queue_declare(queue=name, durable=True, auto_delete=False,
                                       arguments={'x-max-priority': max_priority})
            main_channel.queue_bind(queue=name, exchange=exc_name, routing_key=routing_key)

        yield main_channel
    finally:
        # The broker may already have dropped the connection; calling close()
        # again would raise and hide the original exception.
        if connection.is_open:
            connection.close()
            logger.info('Connection closed.')


from typing import Any, Mapping, Optional, Union, cast

def parse_extra_eval_params(
    param_str: Optional[str],
) -> Mapping[str, Union[str, int, float]]:
    """Parse a string of the form ``"key1=value1,key2=value2"`` into a dict.

    Each value is converted to ``int`` if possible, else to ``float`` if
    possible, else kept as ``str``. Only the first ``=`` of an item separates
    key from value, so values may themselves contain ``=``. A key repeated
    later in the string overrides the earlier value.

    Args:
        param_str: Comma-separated ``key=value`` items; ``None`` or ``""``
            yields an empty dict.

    Returns:
        A dict mapping each key to its (possibly numeric) value.

    Raises:
        ValueError: If a comma-separated item contains no ``=``
            (including empty items such as ``"a=1,,b=2"``).
    """
    if not param_str:
        return {}

    def to_number(x: str) -> Union[int, float, str]:
        # Only ValueError means "not numeric"; a bare except would also
        # swallow KeyboardInterrupt and genuine bugs.
        try:
            return int(x)
        except ValueError:
            pass
        try:
            return float(x)
        except ValueError:
            return x

    parsed: dict = {}
    for kv in param_str.split(","):
        key, sep, value = kv.partition("=")
        if not sep:
            raise ValueError(f"expected 'key=value' item, got {kv!r} in {param_str!r}")
        parsed[key] = to_number(value)
    return parsed


        
if __name__ == "__main__":
    # Script entry point: publish every record of the matching JSONL files to
    # RabbitMQ priority queues, one exchange per source queue name.
    import datetime
    import random
    from concurrent.futures import ThreadPoolExecutor, as_completed

    import msgpack
    from glob2 import glob
    from tqdm import tqdm

    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] {%(name)s:%(lineno)d} %(levelname)s - %(message)s',
        datefmt='%H:%M:%S',
    )

    # Broker URL; override with the AMQP_URL environment variable instead of
    # editing credentials into the script.
    # Alternative broker: 10.96.187.173:5672 (same credentials/vhost).
    connect_str = os.environ.get(
        "AMQP_URL",
        "amqp://detection:detection123@10.96.183.224:5672/%2Fdetection",
    )

    # Directory holding the "...,q=<queue>...jsonl" files to publish.
    data_dir = "/home/wj2_cuda12/wj_code/dl_chatgpt/remain_dirs/all_in_one"

    # Source queue names to process; uncomment entries to include them.
    # Every entry needs its own trailing comma: adjacent string literals would
    # otherwise be silently concatenated into one name.
    q_list = [
        # "gptzero_me",
        # "originality",
        # "sapling_ai",
        # "scribbr_com",
        # "writer_com",
        # "GPT2_Detector_out",
        "Hello_SimpleAI_out",
        "Hello_SimpleAI_qa_out",
        # "ArguGPT_out",
        # "ICLR2024_AIGC_text_detector_out",
        # "RADAR_nips2023_out",
        # "yaful_out",
    ]

    valid_roles = ("human_answer", "chatgpt_answer")

    def send_into_queue(xpath, routing_key, priority=1, exc_name="my_fix_exc", dryrun=True):
        """Publish every JSON line of ``xpath`` to the queue named in its filename.

        Task metadata is parsed from the basename (``.jsonl`` stripped, then
        ``key=value`` items separated by commas) and attached to every message
        as AMQP headers, with all values stringified. The target queue is the
        ``q`` field with ``"_out"`` removed.

        Args:
            xpath: Path of the ``.jsonl`` file; blank lines are skipped.
            routing_key: Key used both to bind the queue and to publish.
            priority: Priority floor; each message gets
                ``max(random.randint(4, 7), int(priority))``.
            exc_name: Exchange to declare and publish to.
            dryrun: If true, only read the file and return a summary dict.

        Returns:
            A summary dict when ``dryrun`` is true, otherwise ``None``.

        Raises:
            ValueError: If the filename has no valid ``r`` role or no ``q`` field.
        """
        basename = os.path.basename(xpath)
        task_info_str = basename.replace(".jsonl", "")
        task_info = {str(k): str(v) for k, v in parse_extra_eval_params(task_info_str).items()}

        # Explicit checks instead of assert, which is stripped under `python -O`.
        role = task_info.get("r")
        if role not in valid_roles:
            raise ValueError(f"bad or missing role 'r' in {basename!r}: {task_info}")
        if "q" not in task_info:
            raise ValueError(f"missing queue 'q' in {basename!r}: {task_info}")

        with open(xpath, encoding="utf-8") as f:
            lines = [json.loads(line) for line in f if line.strip()]

        queue_name = task_info["q"].replace("_out", "")

        if dryrun:
            return {"queue_name": queue_name, "total.file_len": len(lines), "xpath": basename}

        with context_init_connect_with_priority(connect_str, exc_name, [queue_name],
                                                routing_key=routing_key) as channel:
            for item in tqdm(lines):
                channel.basic_publish(
                    exchange=exc_name,
                    routing_key=routing_key,
                    body=msgpack.packb(item),
                    properties=pika.BasicProperties(
                        headers=task_info,
                        # Random priority in [4, 7], never below the caller's floor.
                        priority=max(random.randint(4, 7), int(priority)),
                        delivery_mode=2,  # persistent message
                    ),
                )
        return None

    def process_file(xpath, routing_key, exc_name):
        """Publish one file; names containing "_test" get priority floor 7, others 3."""
        priority = 7 if "_test" in os.path.basename(xpath) else 3
        send_into_queue(xpath, routing_key, priority=priority, exc_name=exc_name, dryrun=False)

    failures = []
    for q_name in q_list:
        filelist = glob(os.path.join(data_dir, f"*,q={q_name}*jsonl"))
        random.shuffle(filelist)
        logger.info("q=%s: %d file(s)", q_name, len(filelist))
        if not filelist:
            # ThreadPoolExecutor rejects max_workers=0.
            continue

        # Timestamped routing key, regenerated for each queue name.
        routing_key = "global_key__" + datetime.datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
        exc_name = "my_fix_exc_remain" + q_name

        with ThreadPoolExecutor(max_workers=min(len(filelist), 16)) as ex:
            futures = {
                ex.submit(process_file, xpath, routing_key, exc_name): xpath
                for xpath in filelist
            }
            # Collect every result: an unconsumed ex.map() iterator silently
            # discards worker exceptions.
            for fut in as_completed(futures):
                try:
                    fut.result()
                except Exception:
                    logger.exception("failed to publish %s", futures[fut])
                    failures.append(futures[fut])

    if failures:
        raise SystemExit(f"{len(failures)} file(s) failed to publish: {failures}")

