Spaces:
Running
Running
# Consumer
import argparse
import json
import os
import time
import traceback

import pika
import redis

from Server import get_response
from agent.agent_graph.StateTasks import ProblemState
from encryption_utils import decrypt_token_from_json
##################################################
# VARIABLES
##################################################
# Command-line options for this file: --id labels this consumer instance in
# the log lines printed by the processing/consumer code below.
argparse_model = argparse.ArgumentParser()
argparse_model.add_argument("--id", type=int, default=0, help="Consumer ID")
_cli_args = argparse_model.parse_args()
consumer_id = _cli_args.id

# Required connection settings, read from the environment. Indexing (rather
# than .get) means a missing variable raises KeyError at import time, so a
# misconfigured deployment fails immediately instead of at first use.
_env = os.environ
RABBITMQ_URL = _env["RABBITMQ_URL"]
QUEUE_NAME = _env["QUEUE_NAME"]
redis_host = _env["REDIS_HOST"]
redis_port = _env["REDIS_PORT"]  # raw string from the environment
redis_password = _env["REDIS_PASSWORD"]
##################################################
# PROCESSING METHODS
##################################################
# Shared Redis client, created on first use. A redis-py client owns a
# connection pool, so reusing one client avoids building a new pool (and a
# new TCP connection) for every processed message.
_redis_client = None


def _get_redis_client():
    """Return the shared Redis client, creating it on first use."""
    global _redis_client
    if _redis_client is None:
        _redis_client = redis.Redis(
            host=redis_host,
            port=redis_port,
            decode_responses=True,
            username="default",
            password=redis_password,
        )
    return _redis_client


def redis_send(user_id, msg_id, answer, ttl=None):
    """Store *answer* as JSON in Redis under a per-user, per-message key.

    The key has the form ``ANSWER_FOR_USER_ID<user_id>_OF_<msg_id>``.

    Args:
        user_id: ID of the requesting user; embedded in the key.
        msg_id: ID of the message being answered; embedded in the key.
        answer: Value to store; encoded with ``json.dumps``.
        ttl: Optional expiry in seconds passed to ``SET ... EX``. ``None``
            (the default) stores the key without an expiry.

    Returns:
        The reply of redis-py's ``set`` (truthy when the write succeeded).

    Raises:
        TypeError: if *answer* is not JSON-serialisable.
    """
    key = f'ANSWER_FOR_USER_ID{user_id}_OF_{msg_id}'
    return _get_redis_client().set(key, json.dumps(answer), ex=ttl)
def model_call(request, token):
    """Run ``get_response`` for one request and return a JSON-encodable answer.

    The state handed to ``get_response`` is resumed from
    ``request['last_state']`` (a JSON string) when that key is present and
    parses. Otherwise a fresh state is built from the prompt and memory.

    Args:
        request: Decoded queue payload. Keys read here: ``prompt``,
            ``memory``, ``user_email``, ``user_name`` and, optionally,
            ``last_state``.
        token: Decrypted token, passed straight through to ``get_response``.

    Returns:
        The mapping returned by ``get_response``, with its ``llm`` and
        ``rag_model`` entries overwritten by ``""``.
    """
    try:
        state = json.loads(request['last_state'])
    except (KeyError, TypeError, ValueError):
        # KeyError: no last_state was sent; TypeError: last_state is not
        # str/bytes (e.g. None); ValueError (which includes
        # json.JSONDecodeError): last_state is not valid JSON.
        # NOTE(review): a last_state of "null" parses to None and is passed
        # on as-is -- confirm get_response accepts a None state.
        state: ProblemState = {
            "question": request['prompt'],
            "memory": request['memory'],
        }
    answer = get_response(
        request['prompt'],
        request['memory'],
        token,
        state,
        request['user_email'],
        request['user_name'],
    )
    # These entries presumably hold objects json.dumps cannot serialise.
    # They are blanked rather than deleted, so both keys are always present
    # in the returned mapping.
    for k in ("llm", "rag_model"):
        answer[k] = ""
    return answer
def process_message(recieved_msg):
    """Answer one queued request and publish the answer to Redis.

    Args:
        recieved_msg: Decoded queue payload. Keys read here:
            ``ht_token_encrypted_dumped`` (a JSON string that is parsed and
            handed to ``decrypt_token_from_json``), ``user_id`` and
            ``msg_id`` (used for the Redis key), plus the keys read by
            ``model_call``.
    """
    token = decrypt_token_from_json(
        json.loads(recieved_msg['ht_token_encrypted_dumped'])
    )
    model_answer = model_call(recieved_msg, token)
    user_id = recieved_msg["user_id"]
    msg_id = recieved_msg["msg_id"]
    redis_send_res = redis_send(user_id, msg_id, model_answer)
    # Monitoring line: only the write status and this consumer's ID are
    # printed, never user data. consumer_id is logged as the plain value.
    print({"STATUS": redis_send_res, "CONSUMER": consumer_id})
##################################################
# CONSUMER METHODS
##################################################
def get_connection():
    """Open and return a blocking pika connection to the broker at RABBITMQ_URL."""
    return pika.BlockingConnection(pika.URLParameters(RABBITMQ_URL))
def callback(ch, method, properties, body):
    """Handle one RabbitMQ delivery: decode it, process it, then ack or reject.

    Registered as ``on_message_callback`` in ``start_consumer``; the signature
    is dictated by pika and ``properties`` is unused.

    The delivery is acked only after ``process_message`` returns. If decoding
    or processing raises, the error is logged and the delivery is rejected
    with ``requeue=False``. Letting the exception escape would propagate out
    of ``start_consuming`` and stop this consumer. The still-unacked message
    would then be redelivered to the next consumer, so one bad message could
    take down every replica in turn. A rejected message is discarded, or
    routed to a dead-letter exchange if one is configured on the queue. Note
    that this also applies to transient failures.
    """
    try:
        recieved_msg = json.loads(body.decode())
        print("-------------------------------------------------")
        print(f"MSG AT CONSUMER {consumer_id}")
        process_message(recieved_msg)
    except Exception as exc:
        # The message body is not printed, to keep user data out of the logs.
        print(f"FAILED AT CONSUMER {consumer_id}: {type(exc).__name__}")
        traceback.print_exc()
        ch.basic_nack(delivery_tag=method.delivery_tag, requeue=False)
    else:
        ch.basic_ack(delivery_tag=method.delivery_tag)
def start_consumer():
    """Consume QUEUE_NAME with ``callback`` until interrupted.

    Each scaled replica runs its own consumer. ``prefetch_count=1`` makes the
    broker hand this consumer at most one unacknowledged message at a time,
    so work is spread across replicas rather than buffered on one.
    On Ctrl-C, consuming is stopped cleanly; the connection is closed on
    every exit path if it is still open.
    """
    connection = get_connection()
    try:
        channel = connection.channel()
        # durable=True: the queue definition survives a broker restart.
        channel.queue_declare(queue=QUEUE_NAME, durable=True)
        channel.basic_qos(prefetch_count=1)
        channel.basic_consume(
            queue=QUEUE_NAME,
            on_message_callback=callback,
        )
        print("Waiting for messages...")
        try:
            channel.start_consuming()
        except KeyboardInterrupt:
            channel.stop_consuming()
    finally:
        # After a dropped connection is_open is already False; skip close().
        if connection.is_open:
            connection.close()
##################################################
# MAIN
##################################################
if __name__ == "__main__":
    # Announce which consumer instance (--id) is starting, then hand control
    # to the blocking consume loop.
    startup_banner = f"Starting New Consumer {consumer_id}..."
    print(startup_banner)
    start_consumer()