
import time
import redis
from redis.exceptions import RedisError

import traceback

# from .config import redis_nodes, rabbitmq_params
import yaml
import os 


# Embedded YAML configuration for the Redis cluster (original note: copied from ChatGPT).
# It is parsed once at import time below; only `redis_cluster.nodes` is read in this file.
# NOTE(review): the last node uses a different host and port 1883, which is
# conventionally an MQTT port — confirm it really is a Redis cluster node.
# NOTE(review): `max_connections` is parsed but not referenced anywhere in the
# visible source — confirm whether it should be passed to the cluster client.
config_str="""
redis_cluster:
  nodes:
  - host: csl-server7.dynip.ntu.edu.sg
    port: 6379
  - host: csl-server7.dynip.ntu.edu.sg
    port: 6380
  - host: csl-server7.dynip.ntu.edu.sg
    port: 6381
  - host: csl-server7.dynip.ntu.edu.sg
    port: 6382
  - host: csl-server7.dynip.ntu.edu.sg
    port: 6383
  - host: csl-server7.dynip.ntu.edu.sg 
    port: 6384
  - host: server2.dynip.ntu.edu.sg  
    port: 1883
  max_connections: 100


"""



# Disabled alternative that read the YAML from a "config.yml" file instead of the string above:
# with open( os.path.join(cur, "config.yml") ) as f :
# safe_load only builds plain Python types (dicts, lists, scalars) from the YAML text.
config_params  = yaml.safe_load(config_str)

# List of {"host": ..., "port": ...} dicts, one per startup node; consumed by is_connect().
redis_nodes = config_params["redis_cluster"]["nodes"]

# rabbitmq_params = config_params["rabbitmq"]






def is_connect(redis_conn=None):
    """Return a usable Redis cluster client, reconnecting when necessary.

    Args:
        redis_conn: An existing client to reuse, or ``None`` to create one.

    Returns:
        ``redis_conn`` itself when it answers ``ping()`` with a truthy value,
        otherwise a freshly created ``redis.cluster.RedisCluster`` built from
        the module-level ``redis_nodes`` list (responses decoded to ``str``).

    Raises:
        redis.exceptions.RedisError: If a new client has to be created and
            the cluster cannot be reached.
    """

    def _connect():
        # Entries may be plain {"host", "port"} dicts (as loaded from YAML) or
        # ready-made ClusterNode objects; only the former need converting.
        startup_nodes = [
            redis.cluster.ClusterNode(**node) if isinstance(node, dict) else node
            for node in redis_nodes
        ]
        return redis.cluster.RedisCluster(
            startup_nodes=startup_nodes, decode_responses=True
        )

    if redis_conn is None:
        return _connect()

    try:
        alive = redis_conn.ping()
    except RedisError:
        # A broken connection makes ping() raise rather than return a falsy
        # value, so the failure has to be caught for the reconnect to happen.
        alive = False

    return redis_conn if alive else _connect()
    
    


def is_validate(account_id ,func_name):
    # Unfinished stub: it only builds a "<prefix><func_name>:<account_id>" key string
    # and discards it, so there is no return value (implicitly None).
    # NOTE(review): `key_prefix` is not defined at module level in the visible source
    # (only as the class attribute RateLimit.key_prefix), so calling this function
    # would raise NameError — confirm the intended behaviour or remove the function.
    
    key = key_prefix + func_name + ":" + str(account_id) #+ ":" + str(args) + ":" + str(kwargs)
    
    
class RateLimit:
    """Per-account request counter with an optional minimum gap, stored in Redis.

    For every ``(func_name, account_id)`` pair a Redis hash
    ``ratelimit:<func_name>:<account_id>`` holds a ``count`` field (requests
    seen in the current window) and a ``timestamp`` field (time of the first
    request of the window).  The hash expires ``window_size`` seconds after the
    window was opened, which resets the counter.

    When ``gap_window`` is set, a marker key
    ``ratelimit_gap:<func_name>:<account_id>`` is written after each admitted
    request and lives for ``gap_window`` seconds; while it exists further
    requests are refused.
    """

    # Class-level defaults; __init__ stores per-instance values over the first two.
    rate_limit = 160  # maximum number of requests allowed per window
    window_size = 60  # length of the counting window, in seconds
    key_prefix = "ratelimit:"  # prefix of the per-account counter hash
    key_prefix_gap = "ratelimit_gap:"  # prefix of the per-account gap marker

    def __init__(self, rate_limit=160, window_size=60, func_name="", gap_window=None):
        """Store the limits and open the Redis connection.

        Args:
            rate_limit: Maximum number of requests per window.
            window_size: Window length in seconds.
            func_name: Name mixed into the Redis keys so that several limited
                functions can share one Redis cluster.
            gap_window: Optional minimum number of seconds between two admitted
                requests of the same account; converted with ``int()``.
        """
        self.rate_limit = rate_limit
        self.window_size = window_size
        self.func_name = func_name

        self.redis_conn = is_connect()

        self.gap_window = None if gap_window is None else int(gap_window)

    def _register_hit(self, key):
        """Count one request against ``key`` and return the new window count.

        The expiry is armed only when the key has none yet.  Re-arming it on
        every request would push the end of the window back each time, so a
        steadily active account would never get its counter reset.
        """
        pipeline = self.redis_conn.pipeline()
        # Only the first request of a window records the timestamp; the field
        # is kept for anything else that inspects the hash.
        pipeline.hsetnx(key, "timestamp", time.time())
        pipeline.hincrby(key, "count", 1)
        pipeline.ttl(key)
        _, count, ttl = pipeline.execute()

        # TTL is negative when the key has no expiry (new window, or an earlier
        # call failed before arming it), so this also repairs orphaned keys.
        if ttl < 0:
            self.redis_conn.expire(key, self.window_size)

        # Use the value returned by HINCRBY instead of re-reading the hash,
        # which could already have expired by the time of a second read.
        return int(count)

    def limit(self, account_id):
        """Check and record one request for ``account_id``.

        Args:
            account_id: Identifier of the caller; converted with ``str()``.

        Returns:
            ``True`` when the request is admitted.  ``False`` when the gap
            marker of the account is still present or when Redis raised an
            error (the traceback is printed and the request is refused).

        Raises:
            StopIteration: When the window count exceeds ``rate_limit``.  The
                exception type is kept for existing callers; note that inside
                a generator Python turns it into ``RuntimeError``.
        """
        self.redis_conn = is_connect(self.redis_conn)

        suffix = self.func_name + ":" + str(account_id)
        key = self.key_prefix + suffix
        key_gap = self.key_prefix_gap + suffix
        gap_enabled = self.gap_window is not None

        try:
            # A live gap marker means the previous admitted request is too
            # recent; refuse without counting this one.
            if gap_enabled and self.redis_conn.get(key_gap) is not None:
                return False

            if self._register_hit(key) > self.rate_limit:
                raise StopIteration("limit now ")

            if gap_enabled:
                # Value and expiry are written in one command so the marker
                # can never be left behind without a TTL.
                self.redis_conn.set(key_gap, 1, ex=self.gap_window)
        except RedisError:
            traceback.print_exc()
            return False

        return True
        
        # Call the decorated function with the account ID and other arguments
        # return func(account_id, *args, **kwargs)
    
        # # print (dec_args,  list(dec_kwargs) ,"list.dec_kwargs" )
        #
        # from functools import partial
        #
        # return partial(wrapper, *dec_args, **dec_kwargs)



    # def rate_limited(self, func, *dec_args, **dec_kwargs):
    #     """Decorator for rate limiting"""
    #     def wrapper(account_id, *args, **kwargs):
    #         # Generate a unique key for the rate limit based on the function name, account ID, and arguments
    #         key = self.key_prefix  + self.func_name + ":" + str(account_id) #+ ":" + str(args) + ":" + str(kwargs)
    #         print ("redis ", key )
    #         # Get the current timestamp
    #         timestamp = time.time()
    #
    #         redis_conn = kwargs.get("redis_conn", None )
    #         # print ("redis_conn from kwargs", redis_conn , type(redis_conn) )
    #         # print ( list(kwargs) , args )
    #         redis_conn = is_connect( redis_conn )
    #         # Use a pipeline to atomically perform the rate limit checks and updates
    #         pipeline = redis_conn.pipeline()
    #
    #         # Add a Redis hash field to store the timestamp of the most recent request
    #         pipeline.hsetnx(key, "timestamp", timestamp)
    #
    #         # Add a Redis hash field to store the number of requests made within the current time window
    #         pipeline.hincrby(key, "count", 1)
    #
    #         # Add a Redis hash field to set the time-to-live (TTL) for the rate limit data
    #         pipeline.expire(key, self,window_size)
    #
    #         try:
    #             # Execute the pipeline
    #             pipeline.execute()
    #
    #             # Get the current count and timestamp values from Redis
    #             count = int(redis_conn.hget(key, "count"))
    #             timestamp = float(redis_conn.hget(key, "timestamp"))
    #
    #             # Calculate the time since the most recent request
    #             elapsed_time = timestamp + self.window_size - time.time()
    #
    #             # Check if the rate limit has been exceeded
    #             if count > self.rate_limit:
    #                 # Calculate the time to wait before allowing the next request
    #                 wait_time = max(elapsed_time, 0)
    #
    #                 # Sleep for the required amount of time
    #                 # time.sleep(wait_time)
    #                 raise StopIteration("limit now ")
    #
    #         except RedisError:
    #             # Handle Redis connection errors
    #             pass
    #
    #         # Call the decorated function with the account ID and other arguments
    #         return func(account_id, *args, **kwargs)
    #
    #     # print (dec_args,  list(dec_kwargs) ,"list.dec_kwargs" )
    #
    #     from functools import partial
    #
    #     return partial(wrapper, *dec_args, **dec_kwargs)

# Example usage of the rate_limited decorator with different account IDs
# @rate_limited
# def api_request(account_id):
#     print(f"API request made for account {account_id}")



#
# import asyncio
# import aioamqp
#
# # from .openai_cls import OpenaiDetector
# # Define the number of consumers to start
# num_consumers = 3
#
# # Define the RabbitMQ connection parameters
#
#
# # Define a callback coroutine to handle incoming messages
# async def callback(channel, body, envelope, properties):
#     print(f"Received message: {body.decode()}")
#
#
#
#     await channel.basic_client_ack(delivery_tag=envelope.delivery_tag)
#
# # Define an async function to start a consumer
# async def start_consumer():
#     try:
#         # Connect to RabbitMQ and create a channel
#         transport, protocol = await aioamqp.connect(**rabbitmq_params)
#         channel = await protocol.channel()
#
#         # Declare the queue and start consuming messages
#         await channel.queue_declare(queue_name='my_queue')
#         await channel.basic_qos(prefetch_count=1)
#         await channel.basic_consume(callback, queue_name='my_queue')
#
#         # Wait for incoming messages indefinitely
#         while True:
#             await asyncio.sleep(1)
#     except Exception as e:
#         print(f"Error: {e}")
#     finally:
#         # Close the connection on exit
#         await protocol.close()
#         transport.close()
#
# # Start multiple consumers in parallel
# async def main():
#     consumers = []
#     for i in range(num_consumers):
#         consumer = asyncio.ensure_future(start_consumer())
#         consumers.append(consumer)
#     await asyncio.gather(*consumers)

# # Run the main async function
# loop = asyncio.get_event_loop()
# loop.run_until_complete(main())
#
#
#
#
# # Make 20 API requests for two different accounts
# for i in range(20):
#     api_request(account_id=1)
#     api_request(account_id=2)
#
#
#
#
#




# Initialize the Redis Cluster
# redis_nodes = [
#     {"host": "redis-node1", "port": 6379},
#     {"host": "redis-node2", "port": 6379},
#     {"host": "redis-node3", "port": 6379}
# ]

# rabbitmq_params = {
#     'host': 'localhost',
#     'port': 5672,
#     'login': 'guest',
#     'password': 'guest',
#     'virtualhost': '/'
# }
import yaml
import pika
import requests


class RabbitMQConsumer:
    """Blocking RabbitMQ consumer that answers every message on a result queue.

    ``config`` must provide two mappings:

    * ``connect_params`` with ``username``, ``password``, ``host``, ``port``
      and ``virtual_host``;
    * ``queue_params`` with ``exchange`` (``name``, ``type``, ``durable``,
      ``auto_delete``), ``queues`` (a list of ``name``/``priority`` entries)
      and ``result_queue`` (``name``, ``durable``, ``auto_delete``).
    """

    def __init__(self, config):
        """Remember the configuration; no connection is opened here."""
        self.config_conn = config["connect_params"]
        self.config_queue = config["queue_params"]

        # Both are filled in by the first successful __do_connect().
        self._connection = None
        self.channel = None

    def __do_connect(self):
        """Open a connection and channel and declare the exchange and queues."""
        conn_cfg = self.config_conn
        credentials = pika.PlainCredentials(conn_cfg['username'], conn_cfg['password'])
        parameters = pika.ConnectionParameters(
            host=conn_cfg['host'],
            port=conn_cfg['port'],
            virtual_host=conn_cfg['virtual_host'],
            credentials=credentials
        )
        self._connection = pika.BlockingConnection(parameters)

        queue_cfg = self.config_queue
        exchange_cfg = queue_cfg['exchange']

        channel = self._connection.channel()
        channel.exchange_declare(
            exchange=exchange_cfg['name'],
            exchange_type=exchange_cfg['type'],
            durable=exchange_cfg['durable'],
            auto_delete=exchange_cfg['auto_delete']
        )

        # Declare every work queue and bind it to the exchange.
        self.dispatch_queue_list = []
        for entry in queue_cfg['queues']:
            queue_name = entry['name']
            # The declare arguments are kept exactly as before, because
            # re-declaring an existing queue with different arguments can be
            # rejected by the broker.
            # NOTE(review): 'x-priority' is normally a consumer argument in
            # RabbitMQ, not a queue argument — confirm it has the intended effect.
            channel.queue_declare(queue=queue_name, arguments={
                'x-max-priority': 10,  # maximum priority range (1-10)
                'x-priority': entry['priority']  # queue priority
            })
            channel.queue_bind(exchange=exchange_cfg['name'], queue=queue_name)
            self.dispatch_queue_list.append(queue_name)

        # Declare the queue the replies are published to.
        result_cfg = queue_cfg['result_queue']
        channel.queue_declare(
            queue=result_cfg['name'],
            durable=result_cfg['durable'],
            auto_delete=result_cfg['auto_delete']
        )

        self.result_queue_name = result_cfg['name']
        self.channel = channel

    def ping(self, retry=5):
        """Make sure an open channel exists, connecting up to ``retry`` times.

        Returns:
            ``True`` when ``self.channel`` is open, ``False`` when it is still
            not open after ``retry`` connection attempts.  Exceptions raised
            while connecting propagate to the caller.
        """
        attempts_left = retry
        while True:
            if self.channel is not None and self.channel.is_open:
                return True
            if attempts_left <= 0:
                return False
            self.__do_connect()
            attempts_left -= 1

    @staticmethod
    def _build_reply(body):
        """Turn a received message body into the serialized reply body.

        The message is expected to be a JSON object with an ``id`` member
        (assumption based on how the payload is indexed — confirm against the
        producer).  The reply echoes that id in both ``id`` and ``data``.

        Raises:
            ValueError: The body is not valid JSON / not decodable.
            KeyError: The decoded object has no ``id`` member.
            TypeError: The decoded JSON value cannot be indexed by ``"id"``.
        """
        import json  # local import: the module does not import json at file level

        message = json.loads(body)
        reply = {
            "id": message["id"],
            "data": message["id"],
        }
        # pika needs a str/bytes body, so the dict is serialized explicitly.
        return json.dumps(reply).encode("utf-8")

    def run(self):
        """Consume all configured queues and publish a reply per message.

        Blocks inside ``start_consuming()`` until consuming is stopped.
        """
        self.ping()

        result_queue_name = self.config_queue["result_queue"]["name"]

        def callback(ch, method, properties, body):
            """Handle one delivery: build the reply, publish it, then ack."""
            self.ping()
            try:
                print("Received message:", body.decode())
                print("Processing message...")
                reply_body = self._build_reply(body)
            except (ValueError, KeyError, TypeError):
                # A malformed message would fail the same way on every
                # redelivery, so it is rejected without requeueing.
                traceback.print_exc()
                ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
                return

            # Publish through the default exchange straight to the result queue.
            ret = ch.basic_publish(
                exchange='',
                routing_key=result_queue_name,
                body=reply_body,
            )
            # Consumers are registered with auto_ack=False, so each delivery
            # must be acknowledged or it stays unacked and is redelivered.
            ch.basic_ack(delivery_tag=method.delivery_tag)

            print("Message processed", "-->", ret)

        # Register the same callback on every work queue.
        for entry in self.config_queue['queues']:
            self.channel.basic_consume(
                queue=entry['name'],
                on_message_callback=callback,
                auto_ack=False,
            )

        print("Listening for messages. To exit, press CTRL+C")
        self.channel.start_consuming()
        
        
        
