Spaces:
Sleeping
Sleeping
| from src.agents.chat_agent import BaseChatAgent | |
| from src.utils.utils import save_as_json | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| import os | |
| import sys | |
| import time | |
| import traceback | |
| import boto3 | |
| import argparse | |
| import pytz | |
| import json | |
| from datetime import datetime | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
class Server:
    """Per-user conversation session between a patient (human user) and a
    doctor (LLM counselor agent).

    The ``doctor`` and ``patient`` attributes are agent objects attached via
    the ``set_*`` methods after construction; each is expected to expose a
    ``context`` with ``add_assistant_prompt(text)`` and (for the doctor) a
    ``talk_to_user(...)`` method, plus an ``agent_config`` mapping.
    """

    def __init__(self) -> None:
        self.patient_info = ""
        self.conversation = []            # list of (speaker, utterance) tuples
        self.patient_out = None           # most recent patient utterance
        self.doctor_out = None            # most recent doctor utterance
        self.patient = None
        self.doctor = None
        self.conversation_round = 0
        self.interview_protocol_index = None
        # Fix: initialize to None so to_dict() cannot raise AttributeError
        # when set_timestamp() has not been called yet.
        self.timestamp = None

    def set_timestamp(self):
        """Record the current US/Eastern wall-clock time as the session timestamp."""
        self.timestamp = datetime.now(
            pytz.timezone('US/Eastern')).strftime("%m/%d/%Y-%H:%M:%S")

    def set_patient(self, patient):
        """Attach the patient agent and snapshot its model configuration."""
        self.patient = patient
        self.patient_info = {
            "patient_model_config": patient.agent_config,
        }

    def set_doctor(self, doctor):
        """Attach the doctor (counselor) agent."""
        self.doctor = doctor

    def set_interview_protocol_index(self, interview_protocol_index):
        """Record which interview protocol this session follows."""
        self.interview_protocol_index = interview_protocol_index

    def generate_doctor_response(self):
        '''
        Ask the doctor agent for a reply to the last patient utterance.
        Must be called after setting the patient and doctor.
        '''
        # talk_to_user returns a sequence; the reply text is the first element.
        self.doctor_out = self.doctor.talk_to_user(
            self.patient_out, conversations=self.conversation)[0]
        return self.doctor_out

    def submit_doctor_response(self, response):
        """Append a doctor turn to the transcript and the doctor's own context."""
        self.conversation.append(("doctor", response))
        self.doctor.context.add_assistant_prompt(response)

    def submit_patient_response(self, response):
        """Append a patient turn to the transcript and the patient's own context."""
        self.conversation.append(("patient", response))
        self.patient.context.add_assistant_prompt(response)

    def get_response(self, patient_prompt):
        """Run one conversation round: record the patient prompt, generate and
        record the doctor's reply, and return it as ``{"response": ...}``."""
        self.patient_out = patient_prompt
        self.submit_patient_response(patient_prompt)
        print(f'Round {self.conversation_round} Patient: {patient_prompt}')
        if patient_prompt is not None:
            self.conversation_round += 1
        doctor_out = self.generate_doctor_response()
        self.submit_doctor_response(doctor_out)
        print(f'Round {self.conversation_round} Doctor: {doctor_out}')
        return {"response": doctor_out}

    def to_dict(self):
        """Serialize the session (timestamp, agents, transcript) to a plain dict."""
        return {
            'time_stamp': self.timestamp,
            'patient': {
                'patient_user_id': self.patient.patient_id,
                'patient_info': self.patient_info,
                'patient_context': self.patient.context.msg
            },
            'doctor': {
                'doctor_model_config': self.doctor.agent_config,
                'doctor_context': self.doctor.context.msg
            },
            "conversation": self.conversation,
            "interview_protocol_index": self.interview_protocol_index
        }

    def __json__(self):
        """Alias used by JSON encoders that honor the __json__ protocol."""
        return self.to_dict()

    def reset(self):
        """Clear the transcript and round counter and reset both agents if they
        support it."""
        self.conversation = []
        self.conversation_round = 0
        # Fix: also drop the cached last utterances so a reset session does not
        # leak state from the previous conversation.
        self.patient_out = None
        self.doctor_out = None
        if callable(getattr(self.doctor, 'reset', None)):
            self.doctor.reset()
        if callable(getattr(self.patient, 'reset', None)):
            self.patient.reset()
def create_app():
    """Build the Flask application with CORS enabled and an empty per-user
    session registry attached as ``app.user_servers``."""
    application = Flask(__name__)
    CORS(application)
    # Maps username -> Server instance; populated by the initialization route.
    application.user_servers = {}
    return application
def configure_routes(app, args):
    """Define the HTTP handlers for the chat backend.

    NOTE(review): no ``@app.route(...)`` registrations are visible in this
    file — confirm whether the decorators were lost in transit or the
    handlers are bound elsewhere; as written these closures are defined but
    never exposed via HTTP.

    Args:
        app: Flask app from create_app(); handlers read/write
            ``app.user_servers`` (username -> Server).
        args: parsed CLI namespace; only ``counselor_config_path`` is used here.
    """
    def home():
        '''
        This api will return the default prompts used in the backend, including system prompt, autobiography generation prompt, therapy prompt, and conversation instruction prompt
        Return:
        {
            system_prompt: String,
            autobio_generation_prompt: String,
            therapy_prompt: String,
            conv_instruction_prompt: String
        }
        '''
        # NOTE(review): returns an empty JSON object despite the docstring —
        # the advertised prompt fields are not populated.
        return jsonify({
        }), 200

    def initialization():
        '''
        This API processes user configurations to initialize conversation states. It specifically accepts the following parameters:
        api_key, username, chapter_name, topic_name, and prompts. The API will then:
        1. Initialize a Server() instance for managing conversations and sessions.
        2. Configure the user-defined prompts.
        3. Set up the chapter and topic for the conversation.
        4. Configure the save path for both local storage and Amazon S3.
        '''
        data = request.get_json()
        username = data.get('username')
        api_key = data.get('api_key')
        # Only accept a non-empty string key; stored in the process env for
        # downstream OpenAI-backed agents (process-wide, shared by all users).
        if api_key and isinstance(api_key, str) and api_key.strip():
            os.environ["OPENAI_API_KEY"] = api_key.strip()
        # initialize
        # server.patient.patient_id = username
        counselor = BaseChatAgent(config_path=args.counselor_config_path)
        print(counselor)
        server = Server()
        # server.set_doctor = counselor
        server.doctor = counselor
        app.user_servers[username] = server
        return jsonify({"message": "API key set successfully"}), 200

    def download_conversations():
        """
        This API retrieves the user's conversation history based on their username and returns the conversation data to the frontend.
        Return:
            conversations: List[String]
        """
        data = request.get_json()
        username = data.get('username')
        chatbot_type = data.get('chatbot_type')
        if not username:
            return jsonify({'error': 'Username not provided'}), 400
        if not chatbot_type:
            return jsonify({'error': 'Chatbot type not provided'}), 400
        conversation_dir = os.path.join('user_data', chatbot_type, username, 'conversation')
        if not os.path.exists(conversation_dir):
            return jsonify({'error': 'User not found or no conversations available'}), 404
        # List all files in the conversation directory
        files = os.listdir(conversation_dir)
        conversations = []
        # read each conversation file and append the conversation data to the list
        for file_name in files:
            file_path = os.path.join(conversation_dir, file_name)
            try:
                with open(file_path, 'r') as f:
                    conversation_data = json.load(f)
                # extract the 'conversation' from the JSON
                conversation_content = conversation_data.get('conversation', [])
                conversations.append({
                    'file_name': file_name,
                    'conversation': conversation_content
                })
            except Exception as e:
                # Best-effort read: skip unreadable/malformed files rather
                # than failing the whole request.
                print(f"Error reading {file_name}: {e}")
                continue
        return jsonify(conversations), 200

    def save_conversation_memory():
        """
        This API saves the current conversation history and memory events to the backend, then synchronizes the data with the Amazon S3 server.
        """
        data = request.get_json()
        username = data.get('username')
        chatbot_type = data.get('chatbot_type')
        if not username:
            return jsonify({"error": "Username not provided"}), 400
        if not chatbot_type:
            return jsonify({"error": "Chatbot type not provided"}), 400
        server = app.user_servers.get(username)
        if not server:
            return jsonify({"error": "User session not found"}), 400
        # save conversation history
        server.set_timestamp()
        # NOTE(review): Server defines no chapter_name/topic_name attributes
        # in this file — this line raises AttributeError unless they are set
        # somewhere else; verify against the frontend/initialization flow.
        save_name = f'{server.chapter_name}-{server.topic_name}-{server.timestamp}.json'
        save_name = save_name.replace(' ', '-').replace('/', '-')
        print(save_name)
        # save to local file
        local_conv_file_path = os.path.join(server.patient.conv_history_path, save_name)
        save_as_json(local_conv_file_path, server.to_dict())
        local_memory_graph_file = os.path.join(server.patient.memory_graph_path, save_name)
        # if the chatbot type is 'baseline', create a dummy memory graph file
        if chatbot_type == 'baseline':
            save_as_json(local_memory_graph_file, {'time_indexed_memory_chain': []})
        else:
            # save memory graph
            server.doctor.memory_graph.save(local_memory_graph_file)
        # Auto-upload to Google Drive if authenticated
        try:
            import sys
            sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
            from google_drive_sync import auto_upload_to_drive
            # Upload conversation file
            auto_upload_to_drive(local_conv_file_path, user_id=username, folder_name="Chatbot_Conversations")
            # Upload memory graph file
            auto_upload_to_drive(local_memory_graph_file, user_id=username, folder_name="Chatbot_Conversations")
        except Exception as e:
            # Fail silently if Google Drive upload fails (deliberate
            # best-effort: local save already succeeded at this point)
            print(f"Google Drive auto-upload failed (non-critical): {str(e)}")
        return jsonify({"message": "Current conversation and memory graph are saved!"}), 200

    def get_response():
        """
        This API retrieves the chatbot's response and returns both the response and updated memory events to the frontend.
        Return:
        {
            doctor_response: String,
            memory_events: List[dict]
        }
        """
        data = request.get_json()
        username = data.get('username')
        # patient_prompt = data.get('patient_prompt')
        # chatbot_type = data.get('chatbot_type')
        # if not username or not patient_prompt:
        #     return jsonify({"error": "Username or patient prompt not provided"}), 400
        # if not chatbot_type:
        #     return jsonify({"error": "Chatbot type not provided"}), 400
        # if not
        # server = app.user_servers.get(username)
        # if not server:
        #     return jsonify({"error": "User session not found"}), 400
        # print(server.patient.patient_id, server.chapter_name, server.topic_name)
        # doctor_response = server.get_response(patient_prompt=patient_prompt)
        # if chatbot_type == 'baseline':
        #     memory_events = []
        # else:
        #     memory_events = server.doctor.memory_graph.to_list()
        print('username', username)
        # NOTE(review): no guard for an unknown username here (unlike the
        # routes above) — server is None and the next line raises; also the
        # whole request payload, not just the prompt, is passed to the agent.
        server = app.user_servers.get(username)
        llm_chatbot = server.doctor
        response = llm_chatbot.talk_to_user(data)
        return jsonify({'doctor_response': response})
def main():
    """Parse CLI options, build the Flask app, attach routes, and serve."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--counselor-config-path', type=str,
                            default='./src/configs/counselor_config.yaml')
    arg_parser.add_argument('--store-dir',
                            type=str, default='./user_data')
    cli_args = arg_parser.parse_args()

    flask_app = create_app()
    configure_routes(flask_app, cli_args)
    # Honour a PORT environment variable (e.g. for hosted deployments),
    # falling back to 8080; bind on all interfaces.
    serve_port = int(os.environ.get('PORT', 8080))
    flask_app.run(port=serve_port, host='0.0.0.0', debug=False)
# Script entry point: start the Flask chat server when run directly.
if __name__ == '__main__':
    main()