# ComFit Copilot - app.py
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.huggingface_optimum import OptimumEmbedding
from llama_index.core import Settings
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.storage.chat_store import SimpleChatStore
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.vector_stores.duckdb import DuckDBVectorStore
from llama_index.core.llms import ChatMessage, MessageRole
import uuid
import os
import json
import nest_asyncio
from datetime import datetime
import copy
import ollama
import gradio as gr
from gradio.themes.utils import colors, fonts, sizes
from gradio.themes import Base
from gradio.events import EditData
from huggingface_hub import whoami
import re
from llama_index.core.evaluation import FaithfulnessEvaluator
from huggingface_hub import snapshot_download
import html
import concurrent.futures
import time
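# nest_asyncio lets LlamaIndex's internal asyncio calls run inside an already-running
# event loop (such as the one Gradio uses); without it, nested event-loop errors can occur.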
nest_asyncio.apply()
PERSISTENT_DIR = "/data"
FORCE_UPDATE_FLAG = False
VECTOR_STORE_DIR = "./vector_stores"
EMBED_MODEL_PATH = "./datas/bge_onnx"
CONFIG_PATH = "config.json"
DEFAULT_LLM = "hf.co/JatinkInnovision/ComFit4:Q4_K_M"
DEFAULT_VECTOR_STORE = "ComFit"
CONVERSATION_HISTORY_PATH = "./conversation_history"
SYSTEM_PROMPT = (
"You are a helpful assistant which helps users to understand scientific knowledge "
"about biomechanics of injuries to human bodies."
)
# Hugging Face Spaces: override the local defaults above so the embedding model,
# vector stores, and conversation history live on the persistent /data volume.
EMBED_MODEL_PATH = os.path.join(PERSISTENT_DIR, "bge_onnx")
VECTOR_STORE_DIR = os.path.join(PERSISTENT_DIR, "vector_stores")
CONVERSATION_HISTORY_PATH = os.path.join(PERSISTENT_DIR, "conversation_history")
token = os.getenv("HF_TOKEN")
dataset_id = os.getenv("DATASET_ID")
def download_data_if_needed():
global FORCE_UPDATE_FLAG
if not os.path.exists(EMBED_MODEL_PATH) or not os.path.exists(VECTOR_STORE_DIR):
FORCE_UPDATE_FLAG = True
if FORCE_UPDATE_FLAG:
snapshot_download(
repo_id=dataset_id,
repo_type="dataset",
token=token,
local_dir=PERSISTENT_DIR
)
print("Data downloaded successfully.")
else:
print("Data exists.")
download_data_if_needed()
def process_text_with_think_tags(text):
# Check if the text contains think tags
think_pattern = r'<think>(.*?)</think>'
think_matches = re.findall(think_pattern, text, re.DOTALL)
if think_matches:
# There are think tags present
# Extract the content inside think tags
think_content = think_matches[0] # Taking the first think block
# Remove the think tags part from the original text
remaining_text = re.sub(think_pattern, '', text, flags=re.DOTALL).strip()
# Return both parts separately
return {
'has_two_parts': True,
'think_part': think_content,
'regular_part': remaining_text
}
else:
# No think tags, just one part
return {
'has_two_parts': False,
'full_text': text
}
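# Illustrative example (not executed): a reply like "<think>step by step...</think>Final answer."
# yields {'has_two_parts': True, 'think_part': 'step by step...', 'regular_part': 'Final answer.'},
# while a reply without <think> tags comes back as {'has_two_parts': False, 'full_text': ...}.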
class VectorStoreManager:
def __init__(self):
self.vector_stores = self.initialize_vector_stores()
def initialize_vector_stores(self):
"""Scan vector store directory for DuckDB files, supporting nested directories"""
vector_stores = {}
if os.path.exists(VECTOR_STORE_DIR):
# Add default store if it exists
comfit_path = os.path.join(VECTOR_STORE_DIR, f"{DEFAULT_VECTOR_STORE}.duckdb")
if os.path.exists(comfit_path):
vector_stores[DEFAULT_VECTOR_STORE] = {
"path": comfit_path,
"display_name": DEFAULT_VECTOR_STORE,
"data": DuckDBVectorStore.from_local(comfit_path)
}
# Scan for .duckdb files in root directory and subdirectories
for root, dirs, files in os.walk(VECTOR_STORE_DIR):
for file in files:
if file.endswith(".duckdb") and file != f"{DEFAULT_VECTOR_STORE}.duckdb":
# Skip the default store since we've already handled it
if root == VECTOR_STORE_DIR and file == f"{DEFAULT_VECTOR_STORE}.duckdb":
continue
# Get the full path to the file
file_path = os.path.join(root, file)
# Calculate store_name: combine category and subcategory
rel_path = os.path.relpath(file_path, VECTOR_STORE_DIR)
path_parts = rel_path.split(os.sep)
if len(path_parts) == 1:
# Files in the root directory
store_name = path_parts[0][:-7] # Remove .duckdb
display_name = store_name
else:
# Files in subdirectories
category = path_parts[0]
file_name = path_parts[-1][:-7] # Remove .duckdb
store_name = f"{category}_{file_name}"
display_name = f"{category} - {file_name}"
vector_stores[store_name] = {
"path": file_path,
"display_name": display_name,
"data": DuckDBVectorStore.from_local(file_path)
}
return vector_stores
def get_vector_store_data(self, store_name):
"""Get the actual vector store data by store name"""
return self.vector_stores[store_name]["data"]
def get_vector_store_by_display_name(self, display_name):
"""Find a vector store by its display name"""
for name, store_info in self.vector_stores.items():
if store_info["display_name"] == display_name:
return self.vector_stores[name]["data"]
return None
def get_all_store_names(self):
"""Get all vector store names"""
return list(self.vector_stores.keys())
def get_all_display_names(self):
"""Get all display names as a list"""
return [store_info["display_name"] for store_info in self.vector_stores.values()]
def get_display_name(self, store_name):
"""Get display name for a store name"""
return self.vector_stores[store_name]["display_name"]
def get_name_display_pairs(self):
"""Get list of (display_name, store_name) tuples for UI dropdowns"""
return [(v["display_name"], k) for k, v in self.vector_stores.items()]
# Create a global instance
vector_store_manager = VectorStoreManager()
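# Typical lookups (illustrative): get_all_display_names() feeds the knowledge-base dropdown,
# and get_vector_store_by_display_name("ComFit") resolves a dropdown selection back to its
# DuckDBVectorStore instance.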
class ComFitChatbot:
def __init__(self):
self.initialize()
def initialize(self):
self.session_manager = SessionManager()
self.embed_model = OptimumEmbedding(folder_name=EMBED_MODEL_PATH)
Settings.embed_model = self.embed_model
self.vector_stores = self.initialize_vector_store()
self.config = self._load_config()
self.llm_options = self._initialize_models()
def get_user_data(self, user_id):
return user_id
def _load_config(self):
"""Load model configuration from JSON file"""
try:
with open(CONFIG_PATH, 'r') as f:
return json.load(f)
except Exception as e:
print(f"Error loading config: {e}")
return {"models": []}
def _initialize_models(self):
"""Initialize and verify all models from config"""
config_models = self.config.get("models", [])
available_models = {}
# Get currently available Ollama models
try:
current_models = {m['name']: m['name'] for m in ollama.list()['models']}
print(current_models)
except Exception as e:
print(f"Error fetching current models: {e}")
current_models = {}
# Check each configured model
for model_name in config_models:
if model_name not in current_models:
print(f"Model {model_name} not found locally. Attempting to pull...")
try:
ollama.pull(model_name)
available_models[model_name] = model_name
print(f"Successfully pulled model {model_name}")
except Exception as e:
print(f"Error pulling model {model_name}: {e}")
continue
else:
available_models[model_name] = current_models[model_name]
return available_models
    def get_available_models(self):
        """Return dictionary of available models"""
        # The models discovered in _initialize_models are stored as self.llm_options
        return self.llm_options
def initialize_vector_store(self):
"""Scan vector store directory for DuckDB files, supporting nested directories"""
vector_stores = {}
if os.path.exists(VECTOR_STORE_DIR):
# Add default store if it exists
comfit_path = os.path.join(VECTOR_STORE_DIR, f"{DEFAULT_VECTOR_STORE}.duckdb")
if os.path.exists(comfit_path):
vector_stores[DEFAULT_VECTOR_STORE] = {
"path": comfit_path,
"display_name": DEFAULT_VECTOR_STORE,
"data": DuckDBVectorStore.from_local(comfit_path)
}
# Scan for .duckdb files in root directory and subdirectories
for root, dirs, files in os.walk(VECTOR_STORE_DIR):
for file in files:
if file.endswith(".duckdb") and file != f"{DEFAULT_VECTOR_STORE}.duckdb":
# Skip the default store since we've already handled it
if root == VECTOR_STORE_DIR and file == f"{DEFAULT_VECTOR_STORE}.duckdb":
continue
# Get the full path to the file
file_path = os.path.join(root, file)
# Calculate store_name: combine category and subcategory
rel_path = os.path.relpath(file_path, VECTOR_STORE_DIR)
path_parts = rel_path.split(os.sep)
if len(path_parts) == 1:
# Files in the root directory
store_name = path_parts[0][:-7] # Remove .duckdb
display_name = store_name
else:
# Files in subdirectories
category = path_parts[0]
file_name = path_parts[-1][:-7] # Remove .duckdb
store_name = f"{category}_{file_name}"
display_name = f"{category} - {file_name}"
vector_stores[store_name] = {
"path": file_path,
"display_name": display_name,
"data": DuckDBVectorStore.from_local(file_path)
}
return vector_stores
def get_vector_store(self, vector_store_name):
return self.vector_stores[vector_store_name]["data"]
class comfitChatEngine:
"""
Manages the core components needed for chat functionality with RAG.
Handles LLM, vector store, memory, chat store, and indexes.
"""
def __init__(self, user_id=None, llm_name=None, vector_store_name=None):
"""Initialize the chat engine with all necessary components"""
self.user_id = user_id
self.llm = None
self.llm_name = llm_name
self.vector_store = None
self.vector_store_name = vector_store_name
self.storage_context = None
self.index = None
self.chat_store = None
self.memory = None
self.chat_engine = None
self.rebuild_chat_engine_flag = True
# Conversation metadata management
self.convs_metadata = {}
self.current_conv_id = None
if user_id:
self.initialize_chat_store()
self.initialize_convs_metadata()
# Set initial components if provided
if llm_name:
self.set_llm(llm_name)
if vector_store_name:
self.set_vector_store(vector_store_name)
def initialize_convs_metadata(self):
print(f"Initializing convs metadata for user {self.user_id}")
self.convs_metadata_file_path = os.path.join(CONVERSATION_HISTORY_PATH, self.user_id, f"{self.user_id}_metadata.json")
self.sorted_conversation_list = []
self.get_convs_metadata()
def get_convs_metadata(self):
if os.path.exists(self.convs_metadata_file_path):
with open(self.convs_metadata_file_path, "r") as f:
self.convs_metadata = json.load(f)
self.sorted_conversation_list = self.get_sorted_conversation_list()
def set_current_conv_id(self, input_value, type="index"):
if len(self.sorted_conversation_list) == 0:
self.current_conv_id = None
self.rebuild_chat_engine_flag = True
return
if type == "index" and self.current_conv_id != self.sorted_conversation_list[input_value]:
self.current_conv_id = self.sorted_conversation_list[input_value]
self.rebuild_chat_engine_flag = True
elif type == "id" and self.current_conv_id != input_value:
self.current_conv_id = input_value
self.rebuild_chat_engine_flag = True
def get_sorted_conversation_list(self):
"""
Returns a list of conversation IDs sorted by update time,
with the most recently updated conversations first.
"""
# Create a list of (conv_id, updated_at) tuples
conv_with_timestamps = []
for conv_id, metadata in self.convs_metadata.items():
# Use updated_at timestamp for sorting
if "updated_at" in metadata:
# Convert the ISO timestamp string to datetime object for comparison
update_time = datetime.fromisoformat(metadata["updated_at"])
conv_with_timestamps.append((conv_id, update_time))
# Sort by timestamp (descending order - newest first)
sorted_convs = sorted(conv_with_timestamps, key=lambda x: x[1], reverse=True)
# Return just the conversation IDs in the sorted order
return [conv_id for conv_id, _ in sorted_convs]
def get_sorted_conversation_list_for_ui(self):
new_list = []
for item in self.sorted_conversation_list:
new_list.append([self.convs_metadata[item]["title"]])
return new_list
def update_convs_metadata(self, conv_id, title=None, create_flag=False):
current_time = datetime.now().isoformat()
if title is not None:
self.convs_metadata[conv_id].update({"title":title})
self.convs_metadata[conv_id].update({"updated_at":current_time, "llm_name": self.llm_name, "vector_store_name": self.vector_store_name})
self.sorted_conversation_list = self.get_sorted_conversation_list()
def set_llm(self, llm_name):
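        # llama-index Ollama wrapper; request_timeout is in seconds, and the low
        # temperature (0.3) favors more deterministic, context-grounded answers.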
self.llm = Ollama(
model=llm_name,
request_timeout=120,
temperature=0.3
)
self.set_rebuild_chat_engine_flag(True)
self.llm_name = llm_name
if self.current_conv_id:
self.convs_metadata[self.current_conv_id].update({"llm_name":self.llm_name})
return self.llm
def set_vector_store(self, vector_store_name):
self.vector_store = vector_store_manager.get_vector_store_by_display_name(vector_store_name)
if self.vector_store:
self.initialize_index()
self.set_rebuild_chat_engine_flag(True)
self.vector_store_name = vector_store_name
if self.current_conv_id:
self.convs_metadata[self.current_conv_id].update({"vector_store_name":self.vector_store_name})
return self.vector_store
def initialize_index(self):
"""Initialize the index using the current vector store"""
if not self.vector_store:
raise ValueError("Vector store must be set before initializing index")
self.storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
self.index = VectorStoreIndex.from_vector_store(
vector_store=self.vector_store,
storage_context=self.storage_context
)
return self.index
def initialize_chat_store(self):
"""Initialize the chat store for the user"""
print(f"Initializing chat store for user {self.user_id}")
chat_store_file_path = os.path.join(CONVERSATION_HISTORY_PATH, self.user_id, f"{self.user_id}.json")
# Ensure directory exists
os.makedirs(os.path.dirname(chat_store_file_path), exist_ok=True)
# Create or load chat store
if not os.path.exists(chat_store_file_path):
self.chat_store = SimpleChatStore()
self.chat_store.persist(persist_path=chat_store_file_path)
else:
self.chat_store = SimpleChatStore.from_persist_path(chat_store_file_path)
self.chat_store_file_path = chat_store_file_path
return self.chat_store
def initialize_memory(self, conversation_id=None):
"""Initialize or reinitialize memory with specified conversation ID"""
if not self.chat_store:
raise ValueError("Chat store must be initialized before memory")
print(f"Initializing memory for conversation {conversation_id}")
self.memory = ChatMemoryBuffer.from_defaults(
token_limit=3000,
chat_store=self.chat_store,
chat_store_key=conversation_id
)
return self.memory
def build_chat_engine(self, conversation_id=None):
"""Build the chat engine with all components"""
if not all([self.llm, self.index, self.chat_store]):
raise ValueError("LLM, index, and chat store must be set before building chat engine")
# Initialize or update memory with conversation ID
# if conversation_id and self.current_conv_id != conversation_id:
self.initialize_memory(conversation_id)
self.current_conv_id = conversation_id
        # Create the chat engine with the shared system prompt
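        # "context" chat mode retrieves relevant chunks from the index and injects them
        # into the prompt alongside the conversation memory.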
self.chat_engine = self.index.as_chat_engine(
chat_mode="context",
llm=self.llm,
memory=self.memory,
system_prompt=SYSTEM_PROMPT
)
self.set_rebuild_chat_engine_flag(False)
return self.chat_engine
def save_chat_history(self):
"""Save chat history to file"""
if self.chat_store and hasattr(self, 'chat_store_file_path'):
self.chat_store.persist(persist_path=self.chat_store_file_path)
def add_message(self, conversation_id, message):
"""Add a message to the chat history"""
if self.chat_store:
self.chat_store.add_message(conversation_id, message)
def get_chat_history(self, conversation_id):
"""Get chat history for a specific conversation"""
if conversation_id is None:
return []
if self.chat_store:
return self.chat_store.to_dict()["store"][conversation_id]
return []
def get_chat_history_for_ui(self, conversation_id):
"""Get chat history for a specific conversation"""
if conversation_id is None:
return []
if self.chat_store:
conv_data = self.chat_store.to_dict()["store"][conversation_id]
conv_data_for_ui = []
for item in conv_data:
if item["role"] == "user":
conv_data_for_ui.append(item)
else:
content = item["content"]
time_str = None
if "time" in item["additional_kwargs"]:
elapsed_time = item["additional_kwargs"]["time"]
time_str = f"\n\n[Total time: {elapsed_time:.2f}s]"
processed_answer_dict = process_text_with_think_tags(content)
if processed_answer_dict["has_two_parts"]:
think_content = processed_answer_dict["think_part"]
conv_data_for_ui.append({"role": "assistant", "content": think_content, "metadata":{"title":"Thinking...", "status":"done"}})
remaining_text = processed_answer_dict["regular_part"]
if time_str:
remaining_text += time_str
conv_data_for_ui.append({"role": "assistant", "content": remaining_text})
else:
item_copy = copy.deepcopy(item)
if time_str:
item_copy["content"] += time_str
conv_data_for_ui.append(item_copy)
return conv_data_for_ui
return []
def set_rebuild_chat_engine_flag(self, flag):
self.rebuild_chat_engine_flag = flag
def chat(self, message, conversation_id=None):
start_time = time.time()
create_flag = False
if conversation_id is None:
conversation_id = self.create_conversation(message=message)
create_flag = True
print(f"Created new conversation {conversation_id}")
self.set_rebuild_chat_engine_flag(True)
elif self.current_conv_id != conversation_id:
self.set_rebuild_chat_engine_flag(True)
if self.rebuild_chat_engine_flag:
self.chat_engine = self.build_chat_engine(conversation_id)
self.rebuild_chat_engine_flag = False
# Get response
response = self.chat_engine.chat(message)
# answer = response.response
elapsed_time = time.time() - start_time
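        # Re-store the assistant message with the elapsed time in additional_kwargs so that
        # get_chat_history_for_ui can show "[Total time: ...]" when the conversation is reloaded.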
answer_dict = self.chat_store.get_messages(conversation_id)[-1].dict()
answer_dict['additional_kwargs'].update({"time":elapsed_time})
new_msg = ChatMessage.model_validate(answer_dict)
self.chat_store.delete_message(conversation_id, -1)
self.chat_store.add_message(conversation_id, new_msg)
self.update_convs_metadata(conversation_id, create_flag=create_flag)
self.save_metadata()
self.save_chat_history()
return response
def create_conversation(self, message=None):
"""
Create a new conversation with metadata
Args:
title: Optional title for the conversation
message: First message to use for generating a title
Returns:
conversation_id: ID of the new conversation
"""
# Generate a new unique conversation ID
conv_id = str(uuid.uuid4())
# Set as current conversation
self.current_conv_id = conv_id
# Generate title from message if not provided
title = message[:50] + ("..." if len(message) > 50 else "")
# Create timestamp
current_time = datetime.now().isoformat()
# Store metadata with resource information
self.convs_metadata[conv_id] = {
"title": title,
"created_at": current_time,
"updated_at": current_time,
"llm": self.llm_name,
"vector_store": self.vector_store_name,
"message_count": 0
}
# Initialize chat engine with the new conversation ID
# self.chat_engine = self.build_chat_engine(conv_id)
return conv_id
def update_conversation_metadata(self, conv_id, title=None, increment_message_count=True):
"""
Update conversation metadata
Args:
conv_id: Conversation ID to update
title: Optional new title
increment_message_count: Whether to increment message count
"""
if conv_id not in self.convs_metadata:
return
# Update timestamp
self.convs_metadata[conv_id]["updated_at"] = datetime.now().isoformat()
# Update title if provided
if title:
self.convs_metadata[conv_id]["title"] = title
# Increment message count if requested
if increment_message_count:
self.convs_metadata[conv_id]["message_count"] = self.convs_metadata[conv_id].get("message_count", 0) + 1
def get_sorted_conversations(self):
"""
Returns a list of conversation IDs sorted by update time,
with the most recently updated conversations first.
"""
# Create a list of (conv_id, updated_at) tuples
conv_with_timestamps = []
for conv_id, metadata in self.convs_metadata.items():
# Use updated_at timestamp for sorting
if "updated_at" in metadata:
# Convert the ISO timestamp string to datetime object for comparison
update_time = datetime.fromisoformat(metadata["updated_at"])
conv_with_timestamps.append((conv_id, update_time))
# Sort by timestamp (descending order - newest first)
sorted_convs = sorted(conv_with_timestamps, key=lambda x: x[1], reverse=True)
# Return just the conversation IDs in the sorted order
return [conv_id for conv_id, _ in sorted_convs]
def get_conversation_info(self, conv_id):
"""Get conversation metadata"""
return self.convs_metadata.get(conv_id, {})
def save_metadata(self):
"""Save conversation metadata to file"""
if hasattr(self, 'chat_store_file_path') and self.user_id:
metadata_path = os.path.join(CONVERSATION_HISTORY_PATH, self.user_id, f"{self.user_id}_metadata.json")
os.makedirs(os.path.dirname(metadata_path), exist_ok=True)
with open(metadata_path, 'w') as f:
json.dump(self.convs_metadata, f)
def load_metadata(self):
"""Load conversation metadata from file"""
if self.user_id:
metadata_path = os.path.join(CONVERSATION_HISTORY_PATH, self.user_id, f"{self.user_id}_metadata.json")
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r') as f:
self.convs_metadata = json.load(f)
except Exception as e:
print(f"Error loading metadata: {e}")
def edit_message(self, index, conversation_id):
if conversation_id is not None:
msg_list = self.chat_store.get_messages(conversation_id)
new_msg_list = msg_list[:index]
self.chat_store.set_messages(conversation_id, new_msg_list)
self.save_metadata()
self.save_chat_history()
def retry_message(self, conversation_id):
if conversation_id is not None:
self.undo_message(conversation_id)
self.save_metadata()
self.save_chat_history()
    def undo_message(self, conversation_id):
        if conversation_id is not None:
            msg_list = self.chat_store.get_messages(conversation_id)
            # Remove the trailing assistant reply (if any), then the user message that prompted it
            if len(msg_list) > 0 and msg_list[-1].role == MessageRole.ASSISTANT:
                self.chat_store.delete_last_message(conversation_id)
                msg_list = self.chat_store.get_messages(conversation_id)
            if len(msg_list) > 0 and msg_list[-1].role == MessageRole.USER:
                self.chat_store.delete_last_message(conversation_id)
self.update_convs_metadata(conversation_id)
self.save_metadata()
self.save_chat_history()
def delete_conversation(self, conversation_id):
if conversation_id is not None:
self.chat_store.delete_messages(conversation_id)
self.convs_metadata.pop(conversation_id)
self.save_metadata()
self.save_chat_history()
self.sorted_conversation_list = self.get_sorted_conversation_list()
class SessionManager:
def __init__(self):
self.sessions = {}
def create_session(self, user_id=None):
if user_id is None:
return None
print(f"Creating session for user {user_id}")
if user_id not in self.sessions:
self.sessions[user_id] = comfitChatEngine(user_id, llm_name=DEFAULT_LLM, vector_store_name=DEFAULT_VECTOR_STORE)
print(f"Session created for user {user_id}")
return self.sessions[user_id]
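# Sessions are created lazily per authenticated user; e.g. (illustrative)
# chatbot.session_manager.create_session("some-user-id") returns that user's comfitChatEngine
# and reuses it on later calls.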
class ChatbotUI:
"""UI handler for the chatbot application"""
def __init__(self, comfit_chatbot):
"""Initialize with a chat engine"""
self.comfit_chatbot = comfit_chatbot
self.init_attr()
def init_attr(self):
self.llm_options = self.comfit_chatbot.llm_options
self.vector_stores = self.comfit_chatbot.vector_stores
# self.vector_stores_options = [(v["display_name"], k) for k, v in self.comfit_chatbot.vector_stores.items()]
# self.init_conversations_history()
# def init_conversations_history(self):
# chat_session = self.comfit_chatbot.session_manager.sessions[USER_NAME]
# self.init_convs_list = chat_session.get_sorted_conversation_list_for_ui()
# if len(self.init_convs_list) > 0:
# self.init_chat_history = chat_session.get_chat_history(chat_session.sorted_conversation_list[0])
# self.init_convs_index = 0
# else:
# self.init_chat_history = []
# self.init_convs_index = None
def create_ui(self):
with gr.Blocks(title="Comfort and Fit Copilot (ComFit Copilot)") as demo:
user_id = gr.State(None)
with gr.Row():
with gr.Column(scale=6):
gr.Markdown("<img src='/gradio_api/file/logo.png' alt='Innovision Logo' height='150' width='390'>")
with gr.Column(scale=1):
login_btn = gr.LoginButton()
with gr.Row():
gr.Markdown("# Comfort and Fit Copilot (ComFit Copilot)")
# Move model selection to the top row
with gr.Row():
with gr.Column(scale=3):
llm_dropdown = gr.Dropdown(
label="Select Language Model",
choices=list(self.llm_options.values()),
value=next(iter(self.llm_options.values()), None)
)
with gr.Column(scale=3):
vector_dropdown = gr.Dropdown(
label="Comfort and Fit Knowledge Base",
choices=[(v["display_name"]) for k, v in self.vector_stores.items()],
value=next(iter(self.vector_stores.keys()), None)
)
# Main content with sidebar and chat area
with gr.Row():
# Left sidebar for conversation history
with gr.Column(scale=1, elem_classes="sidebar"):
new_chat_btn = gr.Button("New Chat", size="sm")
# Hidden textbox for conversation data
conversation_data = gr.Textbox(visible=False)
# Dataset for conversation history
conversation_history = gr.Dataset(
components=[conversation_data],
label="Conversation History",
type="index",
layout="table"
)
# Main chat area
with gr.Column(scale=3):
chatbot = gr.Chatbot(
height=500,
render_markdown=True,
show_copy_button=True,
type="messages",
)
with gr.Row():
msg = gr.Textbox(label="Ask me anything", placeholder="Log in to start chatting", interactive=False)
def get_auth_id(oauth_token: gr.OAuthToken | None) -> str | None:
print(oauth_token)
if oauth_token is None:
return None
try:
user_info = whoami(oauth_token.token)
print(user_info)
return user_info.get('id')
except Exception as e:
print(f"Authentication failed: {e}")
return None
def add_msg(msg, history):
history.append({"role": "user", "content": msg})
return history
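            # chat_with_comfit is a generator: the backend produces (and persists) the full
            # answer first, then yields it character by character so the Chatbot component
            # renders it as a streaming reply.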
def chat_with_comfit(history, user_id, conv_idx):
start_time = time.time()
msg = history[-1]["content"]
user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
# user_engine.che
# conv_id = None
if conv_idx is not None:
conv_id = user_engine.sorted_conversation_list[conv_idx]
else:
conv_id = None
# if len(history) == 1 and conv_idx is None:
# conv_id = None
response = user_engine.chat(msg, conv_id)
answer = response.response
processed_answer_dict = process_text_with_think_tags(answer)
if processed_answer_dict["has_two_parts"]:
think_content = processed_answer_dict["think_part"]
remaining_text = processed_answer_dict["regular_part"]
# thick_msg = gr.ChatMessage(role="assistant", content="", metadata={"title":"Thinking..."})
history.append({"role": "assistant", "content": "", "metadata":{"title":"Thinking...", "status":"pending"}})
# history.append(thick_msg)
for character in think_content:
history[-1]["content"] += character
yield history
elapsed_time = time.time() - start_time
history[-1]["metadata"]["title"] = f"Thinking... [Thinking time: {elapsed_time:.2f}s]"
history[-1]["metadata"]["status"] = "done"
yield history
# Start response time measurement
history.append({"role": "assistant", "content": ""})
for character in remaining_text:
history[-1]["content"] += character
yield history
elapsed_time = time.time() - start_time
history[-1]["content"] += f"\n\n[Total time: {elapsed_time:.2f}s]"
yield history
else:
full_text = processed_answer_dict["full_text"]
history.append({"role": "assistant", "content": ""})
for character in full_text:
history[-1]["content"] += character
yield history
elapsed_time = time.time() - start_time
history[-1]["content"] += f"\n\n[Total time: {elapsed_time:.2f}s]"
yield history
def clear_msg():
return ""
def update_conversation_history(user_id):
user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
ui_list = user_engine.get_sorted_conversation_list_for_ui()
if len(ui_list) > 0:
idx = 0
else:
idx = None
return gr.update(samples=ui_list, value=idx)
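            # Submit chain: append the user message, clear the textbox, stream the assistant
            # reply, then refresh the conversation list in the sidebar.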
msg.submit(
add_msg,
[msg, chatbot],
[chatbot]
).then(
clear_msg,
None,
[msg]
).then(
chat_with_comfit,
[chatbot, user_id, conversation_history],
[chatbot]
).then(
update_conversation_history,
[user_id],
[conversation_history]
)
def click_to_select_conversation(conversation_history, user_id):
user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
user_engine.set_current_conv_id(conversation_history, type="index")
chat_history = user_engine.get_chat_history_for_ui(user_engine.current_conv_id)
llm_name = user_engine.convs_metadata[user_engine.current_conv_id]["llm_name"]
vector_store_name = user_engine.convs_metadata[user_engine.current_conv_id]["vector_store_name"]
return gr.update(value=conversation_history), chat_history, gr.update(value=llm_name), gr.update(value=vector_store_name)
conversation_history.click(
click_to_select_conversation,
[conversation_history, user_id],
[conversation_history, chatbot, llm_dropdown, vector_dropdown]
)
# msg.submit(
# chat_with_comfit,
# [msg, chatbot, user_id_dropdown],
# [chatbot]
# )
# msg.submit(
# clear_msg,
# None,
# [msg]
# ).then(
# chat_with_comfit,
# [msg, chatbot, user_id_dropdown],
# [chatbot]
# )
# clear_btn.click(
# clear_session,
# [session_state],
# [chatbot, session_state],
# queue=False
# )
def create_session(user_id):
if user_id is None:
return
self.comfit_chatbot.session_manager.create_session(user_id)
user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
sorted_conversation_list = user_engine.get_sorted_conversation_list_for_ui()
if len(sorted_conversation_list) > 0:
index = 0
else:
index = None
update_conversation_history = gr.update(samples=sorted_conversation_list, value=index)
user_engine.set_current_conv_id(0, type="index")
chat_history = user_engine.get_chat_history_for_ui(user_engine.current_conv_id)
if len(chat_history) > 0:
llm_name = user_engine.convs_metadata[user_engine.current_conv_id]["llm_name"]
vector_store_name = user_engine.convs_metadata[user_engine.current_conv_id]["vector_store_name"]
else:
llm_name = user_engine.llm_name
vector_store_name = user_engine.vector_store_name
yield llm_name, vector_store_name, update_conversation_history, chat_history
def activate_chat(user_id):
if user_id is None:
return gr.update(placeholder="Log in to start chatting", interactive=False)
return gr.update(placeholder="",interactive=True)
demo.load(
get_auth_id,
inputs=None,
outputs=[user_id]
).then(
create_session,
[user_id],
[llm_dropdown, vector_dropdown, conversation_history, chatbot]
).success(
activate_chat,
[user_id],
[msg]
)
def update_llm(user_id, llm_name):
if user_id is None:
return
user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
user_engine.set_llm(llm_name)
llm_dropdown.change(
update_llm,
[user_id, llm_dropdown],
None
)
def update_vector_store(user_id, vector_store_name):
if user_id is None:
return
user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
user_engine.set_vector_store(vector_store_name)
vector_dropdown.change(
update_vector_store,
[user_id, vector_dropdown],
None
)
def edit_chat(user_id, history, edit_data: EditData):
user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
idx = edit_data.index
# Count how many user messages appear up to this index in the UI history
user_message_count = 0
for i in range(idx + 1):
if history[i]["role"] == "user":
user_message_count += 1
# In backend storage, user messages are at positions 0, 2, 4, 6...
# So the backend index is (user_message_count - 1) * 2
backend_idx = (user_message_count - 1) * 2
user_engine.edit_message(backend_idx, user_engine.current_conv_id)
history = history[: idx+1]
return history
chatbot.edit(
edit_chat,
[user_id, chatbot],
[chatbot]
).success(
chat_with_comfit,
[chatbot, user_id, conversation_history],
[chatbot]
).success(
update_conversation_history,
[user_id],
[conversation_history]
)
def retry_chat(user_id, history):
user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
user_engine.retry_message(user_engine.current_conv_id)
while history[-1]["role"] == "assistant":
history.pop()
yield history
return history
chatbot.retry(
retry_chat,
[user_id, chatbot],
[chatbot]
).then(
chat_with_comfit,
[chatbot, user_id, conversation_history],
[chatbot]
).then(
update_conversation_history,
[user_id],
[conversation_history]
)
def undo_chat(user_id):
user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
user_engine.undo_message(user_engine.current_conv_id)
chat_history = user_engine.get_chat_history_for_ui(user_engine.current_conv_id)
return chat_history
chatbot.undo(
undo_chat,
[user_id],
[chatbot]
)
def clear_conversation(user_id):
user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
user_engine.delete_conversation(user_engine.current_conv_id)
sorted_conversation_list = user_engine.get_sorted_conversation_list_for_ui()
if len(sorted_conversation_list) > 0:
index = 0
else:
index = None
update_conversation_history = gr.update(samples=sorted_conversation_list, value=index)
user_engine.set_current_conv_id(index, type="index")
chat_history = user_engine.get_chat_history_for_ui(user_engine.current_conv_id)
yield update_conversation_history, chat_history
chatbot.clear(
clear_conversation,
[user_id],
[conversation_history, chatbot]
)
# Create new conversation button should only clear the chat area, but not create a new conversation yet
def prepare_new_chat():
print("prepare_new_chat")
return [], gr.update(value=None)
def print_dataset(value):
print(value)
# Create new conversation
new_chat_btn.click(
prepare_new_chat,
None,
[chatbot, conversation_history],
).then(
print_dataset,
conversation_history,
None
)
return demo
# Deployment settings
if __name__ == "__main__":
# Check chat store health
# store_health_ok = check_chat_store_health()
# if not store_health_ok:
# print("WARNING: Chat store health check failed! Some functionality may not work correctly.")
# # Run warm-up to pre-initialize resources
# warm_up_resources()
comfit_chatbot = ComFitChatbot()
ui = ChatbotUI(comfit_chatbot)
demo = ui.create_ui()
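    # Bound queue size and concurrency so a single Ollama backend is not overloaded.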
demo.queue(max_size=10, default_concurrency_limit=3)
demo.launch(allowed_paths=["logo.png"])