# chatbot-mimic-notes / src / server.py
# Jesse Liu
# init hf
# 6a725a4
from src.agents.chat_agent import BaseChatAgent
from src.utils.utils import save_as_json
from flask import Flask, request, jsonify
from flask_cors import CORS
import os
import sys
import time
import traceback
import boto3
import argparse
import pytz
import json
from datetime import datetime
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class Server:
    """Hold the state of one doctor/patient conversation session.

    The instance keeps the two agents, the running transcript (a list of
    ``(speaker, text)`` tuples) and bookkeeping such as the round counter and
    the timestamp used when the session is serialised with ``to_dict``.
    """

    def __init__(self) -> None:
        self.patient_info = ""
        self.conversation = []
        self.patient_out = None
        self.doctor_out = None
        self.patient = None
        self.doctor = None
        self.conversation_round = 0
        self.interview_protocol_index = None
        # Defined up front so to_dict() works even before set_timestamp() runs.
        self.timestamp = None
        # The save route in this file reads these two attributes to build the
        # output file name; nothing in this class assigns them.
        self.chapter_name = None
        self.topic_name = None

    def set_timestamp(self):
        """Store the current US/Eastern time as ``MM/DD/YYYY-HH:MM:SS``."""
        self.timestamp = datetime.now(pytz.timezone('US/Eastern')).strftime("%m/%d/%Y-%H:%M:%S")

    def set_patient(self, patient):
        """Attach the patient agent and record its ``agent_config``."""
        self.patient = patient
        self.patient_info = {
            "patient_model_config": patient.agent_config,
        }

    def set_doctor(self, doctor):
        """Attach the doctor (chatbot) agent."""
        self.doctor = doctor

    def set_interview_protocol_index(self, interview_protocol_index):
        """Remember which interview protocol this session uses."""
        self.interview_protocol_index = interview_protocol_index

    def generate_doctor_response(self):
        """Ask the doctor agent to answer the latest patient message.

        Must be called after a doctor has been set. The first element of the
        value returned by ``doctor.talk_to_user`` is stored in
        ``self.doctor_out`` and returned.
        """
        self.doctor_out = self.doctor.talk_to_user(
            self.patient_out, conversations=self.conversation)[0]
        return self.doctor_out

    def submit_doctor_response(self, response):
        """Append a doctor turn to the transcript and to the doctor's context."""
        self.conversation.append(("doctor", response))
        self.doctor.context.add_assistant_prompt(response)

    def submit_patient_response(self, response):
        """Append a patient turn to the transcript.

        The turn is mirrored into the patient agent's context only when a
        patient agent has been attached; a session may run with a doctor only.
        """
        self.conversation.append(("patient", response))
        if self.patient is not None:
            self.patient.context.add_assistant_prompt(response)

    def get_response(self, patient_prompt):
        """Run one exchange: record the patient prompt, then the doctor reply.

        Args:
            patient_prompt: Text of the patient's message.

        Returns:
            ``{"response": <doctor reply>}``.
        """
        self.patient_out = patient_prompt
        self.submit_patient_response(patient_prompt)
        print(f'Round {self.conversation_round} Patient: {patient_prompt}')
        # NOTE(review): only the round counter is conditional on the prompt;
        # a None prompt still triggers a doctor reply - confirm this is intended.
        if patient_prompt is not None:
            self.conversation_round += 1
        doctor_out = self.generate_doctor_response()
        self.submit_doctor_response(doctor_out)
        print(f'Round {self.conversation_round} Doctor: {doctor_out}')
        return {"response": doctor_out}

    def to_dict(self):
        """Return a plain-dict snapshot of the session.

        Fields that depend on an agent which has not been attached are
        reported as ``None`` instead of raising ``AttributeError``.
        """
        patient = self.patient
        doctor = self.doctor
        return {
            'time_stamp': self.timestamp,
            'patient': {
                'patient_user_id': patient.patient_id if patient is not None else None,
                'patient_info': self.patient_info,
                'patient_context': patient.context.msg if patient is not None else None,
            },
            'doctor': {
                'doctor_model_config': doctor.agent_config if doctor is not None else None,
                'doctor_context': doctor.context.msg if doctor is not None else None,
            },
            "conversation": self.conversation,
            "interview_protocol_index": self.interview_protocol_index,
        }

    def __json__(self):
        """Serialisation hook; identical to ``to_dict``."""
        return self.to_dict()

    def reset(self):
        """Clear the transcript and per-turn state, and reset both agents.

        An agent is reset only if it exposes a callable ``reset`` attribute.
        """
        self.conversation = []
        self.conversation_round = 0
        # Drop the last exchanged messages so a fresh conversation does not
        # start from a stale patient prompt.
        self.patient_out = None
        self.doctor_out = None
        for agent in (self.doctor, self.patient):
            agent_reset = getattr(agent, 'reset', None)
            if callable(agent_reset):
                agent_reset()
def create_app():
    """Build the Flask application used by this server.

    The returned app has CORS enabled and carries an empty ``user_servers``
    dict, which the routes use to look up per-user session objects.
    """
    flask_app = Flask(__name__)
    # Per-user session registry, filled in by the initialization route.
    flask_app.user_servers = dict()
    CORS(flask_app)
    return flask_app
def configure_routes(app, args):
    """Register every HTTP route of the chatbot backend on ``app``.

    Args:
        app: Flask application. It must expose a ``user_servers`` dict, which
            the routes use as a username -> session registry.
        args: Parsed command-line namespace; only ``counselor_config_path``
            is read here.
    """

    def _json_body():
        """Return the request's JSON object, or ``{}`` if it is missing or not an object."""
        payload = request.get_json(silent=True)
        return payload if isinstance(payload, dict) else {}

    def _is_safe_path_component(value):
        """Tell whether ``value`` can be used as a single directory name.

        Rejects non-strings, empty names, ``.``/``..`` and anything holding a
        path separator or NUL, so client input cannot escape the data folder.
        """
        if not isinstance(value, str) or value in ('', '.', '..'):
            return False
        return not any(ch in value for ch in ('/', '\\', '\x00'))

    @app.route('/', methods=['GET'])
    def home():
        """Health-check style endpoint.

        Currently returns an empty JSON object with status 200.
        NOTE(review): an earlier description listed default prompts
        (system_prompt, autobio_generation_prompt, therapy_prompt,
        conv_instruction_prompt) as the payload; none are returned today.
        """
        return jsonify({
        }), 200

    @app.route('/api/initialization', methods=['POST'])
    def initialization():
        """Create a fresh session for a user.

        Reads ``username`` and ``api_key`` from the JSON body. A non-blank
        string ``api_key`` is exported as ``OPENAI_API_KEY``. A new
        ``BaseChatAgent`` is built from ``args.counselor_config_path``,
        attached to a new ``Server`` as its doctor, and the server is stored
        in ``app.user_servers`` under ``username``.
        """
        data = _json_body()
        username = data.get('username')
        api_key = data.get('api_key')
        if isinstance(api_key, str) and api_key.strip():
            # NOTE(review): the environment is process-wide, so the most
            # recent caller's key is the one every session sees.
            os.environ["OPENAI_API_KEY"] = api_key.strip()
        counselor = BaseChatAgent(config_path=args.counselor_config_path)
        print(counselor)
        server = Server()
        server.set_doctor(counselor)
        app.user_servers[username] = server
        return jsonify({"message": "API key set successfully"}), 200

    @app.route('/save/download_conversations', methods=['POST'])
    def download_conversations():
        """Return every saved conversation of a user.

        Reads ``username`` and ``chatbot_type`` from the JSON body and loads
        each file found in ``user_data/<chatbot_type>/<username>/conversation``.

        Return:
            JSON list of ``{"file_name": str, "conversation": list}``;
            400 when a field is missing or invalid, 404 when the directory
            does not exist.
        """
        data = _json_body()
        username = data.get('username')
        chatbot_type = data.get('chatbot_type')
        if not username:
            return jsonify({'error': 'Username not provided'}), 400
        if not chatbot_type:
            return jsonify({'error': 'Chatbot type not provided'}), 400
        # Both values come from the client and end up in a filesystem path.
        if not (_is_safe_path_component(username) and _is_safe_path_component(chatbot_type)):
            return jsonify({'error': 'Invalid username or chatbot type'}), 400
        conversation_dir = os.path.join('user_data', chatbot_type, username, 'conversation')
        if not os.path.exists(conversation_dir):
            return jsonify({'error': 'User not found or no conversations available'}), 404
        conversations = []
        # List all files in the conversation directory (sorted for a stable order)
        for file_name in sorted(os.listdir(conversation_dir)):
            file_path = os.path.join(conversation_dir, file_name)
            try:
                with open(file_path, 'r') as f:
                    conversation_data = json.load(f)
                # extract the 'conversation' entry from the JSON document
                conversation_content = conversation_data.get('conversation', [])
            except Exception as e:
                # Best effort: an unreadable or malformed file is reported and
                # skipped so the remaining conversations are still returned.
                print(f"Error reading {file_name}: {e}")
                continue
            conversations.append({
                'file_name': file_name,
                'conversation': conversation_content
            })
        return jsonify(conversations), 200

    @app.route('/save/end_and_save', methods=['POST'])
    def save_conversation_memory():
        """Persist the current conversation and memory graph of a user.

        Writes the session snapshot and the memory graph to the paths exposed
        by the session's patient object, then attempts a Google Drive upload
        of both files; an upload failure is logged and does not fail the
        request.
        """
        data = _json_body()
        username = data.get('username')
        chatbot_type = data.get('chatbot_type')
        if not username:
            return jsonify({"error": "Username not provided"}), 400
        if not chatbot_type:
            return jsonify({"error": "Chatbot type not provided"}), 400
        server = app.user_servers.get(username)
        if server is None:
            return jsonify({"error": "User session not found"}), 400
        # The save paths live on the patient object; without one there is
        # nowhere to write to.
        if server.patient is None:
            return jsonify({"error": "Patient not configured for this session"}), 400
        # save conversation history
        server.set_timestamp()
        # The session may not define a chapter/topic; fall back to a placeholder.
        chapter_name = getattr(server, 'chapter_name', None)
        topic_name = getattr(server, 'topic_name', None)
        if chapter_name is None:
            chapter_name = 'unknown'
        if topic_name is None:
            topic_name = 'unknown'
        save_name = f'{chapter_name}-{topic_name}-{server.timestamp}.json'
        save_name = save_name.replace(' ', '-').replace('/', '-')
        print(save_name)
        # save to local file
        local_conv_file_path = os.path.join(server.patient.conv_history_path, save_name)
        save_as_json(local_conv_file_path, server.to_dict())
        local_memory_graph_file = os.path.join(server.patient.memory_graph_path, save_name)
        # if the chatbot type is 'baseline', create a dummy memory graph file
        if chatbot_type == 'baseline':
            save_as_json(local_memory_graph_file, {'time_indexed_memory_chain': []})
        else:
            # save memory graph
            server.doctor.memory_graph.save(local_memory_graph_file)
        # Auto-upload to Google Drive (best effort)
        try:
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            # Add the import root once instead of on every request.
            if project_root not in sys.path:
                sys.path.append(project_root)
            from google_drive_sync import auto_upload_to_drive
            # Upload conversation file
            auto_upload_to_drive(local_conv_file_path, user_id=username, folder_name="Chatbot_Conversations")
            # Upload memory graph file
            auto_upload_to_drive(local_memory_graph_file, user_id=username, folder_name="Chatbot_Conversations")
        except Exception as e:
            # Deliberately non-fatal: the local copies are already written.
            print(f"Google Drive auto-upload failed (non-critical): {str(e)}")
        return jsonify({"message": "Current conversation and memory graph are saved!"}), 200

    @app.route('/responses/doctor', methods=['POST'])
    def get_response():
        """Return the chatbot's reply for the posted payload.

        The whole JSON body is handed to the session doctor's
        ``talk_to_user`` and its result is returned as-is.
        NOTE(review): ``Server.generate_doctor_response`` calls
        ``talk_to_user(prompt, conversations=...)`` and takes element ``[0]``;
        this route does neither - confirm which calling convention is right.

        Return:
            {
                doctor_response: <value returned by talk_to_user>
            }
            or 400 when no session exists for ``username``.
        """
        data = _json_body()
        username = data.get('username')
        print('username', username)
        server = app.user_servers.get(username)
        if server is None:
            return jsonify({"error": "User session not found"}), 400
        llm_chatbot = server.doctor
        response = llm_chatbot.talk_to_user(data)
        return jsonify({'doctor_response': response})
def main():
    """Parse the command line, build the Flask app and start serving.

    The listening port is taken from the ``PORT`` environment variable and
    defaults to 8080; the server binds to all interfaces with debug off.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--counselor-config-path',
        type=str,
        default='./src/configs/counselor_config.yaml',
    )
    cli.add_argument(
        '--store-dir',
        type=str,
        default='./user_data',
    )
    options = cli.parse_args()

    flask_app = create_app()
    configure_routes(flask_app, options)

    listen_port = int(os.environ.get('PORT', 8080))
    flask_app.run(host='0.0.0.0', port=listen_port, debug=False)
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()