Spaces:
Sleeping
Sleeping
import gradio as gr | |
import requests | |
import json | |
import uuid | |
import concurrent.futures | |
from requests.exceptions import ChunkedEncodingError | |
import os | |
from dotenv import load_dotenv | |
load_dotenv() | |
# Backend endpoints, derived from the BACKEND_URL environment variable.
# NOTE(review): if BACKEND_URL is unset, host is None and every URL becomes
# "None/..." — confirm the variable is always supplied via .env.
host = os.getenv("BACKEND_URL")
default_endpoint = "{}/chat/default".format(host)
memory_endpoint = "{}/chat/memory".format(host)
history_endpoint = "{}/view_history".format(host)
# The full 78-card tarot deck: the 22 major arcana in order, followed by the
# four minor-arcana suits (Cups, Pentacles, Swords, Wands), each Ace -> King.
tarot_cards = [
    "The Fool",
    "The Magician",
    "The High Priestess",
    "The Empress",
    "The Emperor",
    "The Hierophant",
    "The Lovers",
    "The Chariot",
    "Strength",
    "The Hermit",
    "Wheel of Fortune",
    "Justice",
    "The Hanged Man",
    "Death",
    "Temperance",
    "The Devil",
    "The Tower",
    "The Star",
    "The Moon",
    "The Sun",
    "Judgement",
    "The World",
] + [
    f"{rank} of {suit}"
    for suit in ("Cups", "Pentacles", "Swords", "Wands")
    for rank in (
        "Ace", "Two", "Three", "Four", "Five", "Six", "Seven",
        "Eight", "Nine", "Ten", "Page", "Knight", "Queen", "King",
    )
]
# Define the request payload structure | |
class ChatRequest:
    """Request payload for the default chat endpoint.

    NOTE(review): this class is not referenced by the visible code —
    compare_chatbots builds raw dicts instead; confirm before removing.
    """

    def __init__(self, session_id, messages, model_id, temperature, seer_name, seer_personality):
        # Bind each constructor argument to an identically named attribute.
        for attr, value in (
            ("session_id", session_id),
            ("messages", messages),
            ("model_id", model_id),
            ("temperature", temperature),
            ("seer_name", seer_name),
            ("seer_personality", seer_personality),
        ):
            setattr(self, attr, value)
class ChatRequestWithMemory(ChatRequest):
    """ChatRequest extended with a summary threshold for the memory endpoint."""

    def __init__(self, session_id, messages, model_id, temperature, seer_name, seer_personality, summary_threshold):
        super().__init__(session_id, messages, model_id, temperature, seer_name, seer_personality)
        # Presumably the number of turns kept before summarisation kicks in —
        # TODO confirm against the backend's /chat/memory contract.
        self.summary_threshold = summary_threshold
def compare_chatbots(session_id, messages, model_id, temperature, seer_name, seer_personality, summary_threshold, tarot_card):
    """POST the same message to the default and memory chat endpoints in parallel.

    The two backends get distinct session ids ("_default" / "_memory" suffix)
    so their conversation states never mix.

    Returns:
        (default_reply, memory_reply): the raw response bodies, or
        human-readable "Error: ..." strings on failure.
    """
    print("tarot_card", tarot_card)  # debug trace of the selected cards
    payload_default = json.dumps({
        "session_id": session_id + "_default",
        "messages": messages,
        "model_id": model_id,
        "temperature": temperature,
        "tarot_card": tarot_card,
        "seer_name": seer_name,
        "seer_personality": seer_personality,
    })
    # The memory endpoint additionally takes summary_threshold.
    payload_memory = json.dumps({
        "session_id": session_id + "_memory",
        "messages": messages,
        "model_id": model_id,
        "temperature": temperature,
        "seer_name": seer_name,
        "seer_personality": seer_personality,
        "tarot_card": tarot_card,
        "summary_threshold": summary_threshold,
    })
    headers = {
        'Content-Type': 'application/json'
    }

    def call_endpoint(url, payload):
        """Return the response body for one endpoint, or an "Error: ..." string."""
        try:
            # Timeout so a stalled backend cannot hang the UI forever.
            response = requests.post(url, headers=headers, data=payload, timeout=120)
        except ChunkedEncodingError:
            return "Error: Response ended prematurely"
        except requests.exceptions.RequestException as e:
            # Connection failures/timeouts previously propagated and crashed
            # the worker future; surface them as an error string instead.
            return f"Error: {e}"
        if response.status_code == 200:
            # response.text never raises JSONDecodeError; the old handler
            # around it was dead code and has been removed.
            return response.text
        return f"Error: {response.status_code} - {response.text}"

    # Query both endpoints concurrently so total latency is max(), not sum().
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_default = executor.submit(call_endpoint, default_endpoint, payload_default)
        future_memory = executor.submit(call_endpoint, memory_endpoint, payload_memory)
        return future_default.result(), future_memory.result()
# Function to handle chat interaction | |
def chat_interaction(session_id, message, model_id, temperature, seer_name, seer_personality, summary_threshold, chat_history_default, chat_history_memory, tarot_card):
    """Send one user message to both chatbots and append the replies.

    Appends a (message, reply) pair to each history list in place, then
    returns cleared textbox/card-selection values alongside the histories.
    """
    reply_default, reply_memory = compare_chatbots(
        session_id, message, model_id, temperature,
        seer_name, seer_personality, summary_threshold, tarot_card,
    )
    chat_history_default.append((message, reply_default))
    chat_history_memory.append((message, reply_memory))
    # Reset the textbox and the tarot-card selection for the next turn.
    return "", chat_history_default, chat_history_memory, []
# Function to reload session ID and clear chat history | |
def reload_session_and_clear_chat():
    """Mint a fresh session-id pair and wipe both chat histories.

    Returns (session_id, session_id_memory, [], []) where the memory id is
    the base id with a "_memory" suffix.
    """
    base_id = str(uuid.uuid4())
    return base_id, base_id + "_memory", [], []
# Function to load chat history | |
def load_chat_history(session_id):
    """Fetch the stored conversation for *session_id* from the history endpoint.

    Returns the parsed JSON body on success, otherwise a {"error": ...} dict.
    """
    try:
        # params= URL-encodes the session id; the previous f-string
        # interpolation would break on ids containing '&', '#', spaces, etc.
        response = requests.get(
            history_endpoint,
            params={"session_id": session_id},
            timeout=30,  # don't hang the UI on a stalled backend
        )
        if response.status_code == 200:
            return response.json()
        return {"error": f"Error: {response.status_code} - {response.text}"}
    except Exception as e:
        # Best-effort: surface any failure (connection, timeout, bad JSON)
        # as an error payload rather than crashing the UI callback.
        return {"error": str(e)}
# Create the Gradio interface | |
# --- Gradio UI: two chatbots side by side plus settings/history panels -----
with gr.Blocks() as demo:
    gr.Markdown("# Chatbot Comparison")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Default Chatbot")
            chatbot_default = gr.Chatbot(elem_id="chatbot_default")
        with gr.Column():
            gr.Markdown("## Memory Chatbot")
            chatbot_memory = gr.Chatbot(elem_id="chatbot_memory")
    with gr.Row():
        message = gr.Textbox(label="Message", show_label=False, scale=3)
        submit_button = gr.Button("Submit", scale=1, variant="primary")

    # One base session id shared by both panels; the memory panel appends
    # "_memory" to keep the two backend conversations separate.
    session_id_default = str(uuid.uuid4())
    model_id_choices = [
        "openthai-llama3.1-70b",
        "llama-3.1-8b-instant",
        "llama-3.1-70b-versatile",
        "typhoon-v1.5-instruct",
        "typhoon-v1.5x-70b-instruct",
        "gemma2-9b-it",
    ]
    with gr.Accordion("Settings", open=False):
        reload_button = gr.Button("Reload Session", scale=1, variant="secondary")
        session_id = gr.Textbox(label="Session ID", value=session_id_default)
        model_id = gr.Dropdown(label="Model ID", choices=model_id_choices, value=model_id_choices[0])
        temperature = gr.Slider(0, 1, step=0.1, label="Temperature", value=0.5)
        seer_name = gr.Textbox(label="Seer Name", value="แม่หมอแพตตี้")
        seer_personality = gr.Textbox(label="Seer Personality", value="You are a friend who is always ready to help.")
        tarot_card = gr.Dropdown(label="Tarot Card", value=[], choices=tarot_cards, multiselect=True)
        summary_threshold = gr.Number(label="Summary Threshold", value=7)
    with gr.Accordion("View History of Memory Chatbot", open=False):
        session_id_memory = gr.Textbox(label="Session ID", value=f"{session_id_default}_memory")
        load_history_button = gr.Button("Load Chat History", scale=1, variant="secondary")
        chat_history_json = gr.JSON(label="Chat History")

    # Button click and Enter-in-textbox trigger the same interaction.  The
    # original wrapped chat_interaction in identity lambdas; the handler is
    # wired directly instead and the shared wiring is defined once.
    _chat_inputs = [
        session_id, message, model_id, temperature, seer_name,
        seer_personality, summary_threshold, chatbot_default,
        chatbot_memory, tarot_card,
    ]
    _chat_outputs = [message, chatbot_default, chatbot_memory, tarot_card]
    submit_button.click(chat_interaction, inputs=_chat_inputs, outputs=_chat_outputs)
    message.submit(chat_interaction, inputs=_chat_inputs, outputs=_chat_outputs)
    reload_button.click(
        reload_session_and_clear_chat,
        inputs=[],
        outputs=[session_id, session_id_memory, chatbot_default, chatbot_memory],
    )
    load_history_button.click(
        load_chat_history,
        inputs=[session_id_memory],
        outputs=[chat_history_json],
    )

# Launch the interface (API docs hidden).
demo.launch(show_api=False)