from llama_index.llms.ollama import Ollama
from llama_index.embeddings.huggingface_optimum import OptimumEmbedding
from llama_index.core import Settings
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.storage.chat_store import SimpleChatStore
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.vector_stores.duckdb import DuckDBVectorStore
from llama_index.core.llms import ChatMessage, MessageRole
import uuid
import os
import json
import nest_asyncio
from datetime import datetime
import copy
import ollama
import gradio as gr
from gradio.themes.utils import colors, fonts, sizes
from gradio.themes import Base
from gradio.events import EditData
from huggingface_hub import whoami
import re
from llama_index.core.evaluation import FaithfulnessEvaluator
from huggingface_hub import snapshot_download
import html
import concurrent.futures
import time

nest_asyncio.apply()
| PERSISTENT_DIR = "/data" | |
| FORCE_UPDATE_FLAG = False | |
| VECTOR_STORE_DIR = "./vector_stores" | |
| EMBED_MODEL_PATH = "./datas/bge_onnx" | |
| CONFIG_PATH = "config.json" | |
| DEFAULT_LLM = "hf.co/JatinkInnovision/ComFit4:Q4_K_M" | |
| DEFAULT_VECTOR_STORE = "ComFit" | |
| CONVERSATION_HISTORY_PATH = "./conversation_history" | |
| SYSTEM_PROMPT = ( | |
| "You are a helpful assistant which helps users to understand scientific knowledge " | |
| "about biomechanics of injuries to human bodies." | |
| ) | |
| # HF required | |
| EMBED_MODEL_PATH = os.path.join(PERSISTENT_DIR, "bge_onnx") | |
| VECTOR_STORE_DIR = os.path.join(PERSISTENT_DIR, "vector_stores") | |
| CONVERSATION_HISTORY_PATH = os.path.join(PERSISTENT_DIR, "conversation_history") | |
| token = os.getenv("HF_TOKEN") | |
| dataset_id = os.getenv("DATASET_ID") | |

def download_data_if_needed():
    global FORCE_UPDATE_FLAG
    if not os.path.exists(EMBED_MODEL_PATH) or not os.path.exists(VECTOR_STORE_DIR):
        FORCE_UPDATE_FLAG = True
    if FORCE_UPDATE_FLAG:
        snapshot_download(
            repo_id=dataset_id,
            repo_type="dataset",
            token=token,
            local_dir=PERSISTENT_DIR
        )
        print("Data downloaded successfully.")
    else:
        print("Data exists.")

download_data_if_needed()
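
# Expected layout of the downloaded dataset on the persistent volume
# (inferred from the paths above; illustrative, not verified here):
#   /data/bge_onnx/               ONNX export of the embedding model
#   /data/vector_stores/*.duckdb  DuckDB vector stores, optionally nested by category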

def process_text_with_think_tags(text):
    # Check whether the text contains <think> tags
    think_pattern = r'<think>(.*?)</think>'
    think_matches = re.findall(think_pattern, text, re.DOTALL)
    if think_matches:
        # Think tags are present: extract the content inside them
        think_content = think_matches[0]  # Taking the first think block
        # Remove the think-tag section from the original text
        remaining_text = re.sub(think_pattern, '', text, flags=re.DOTALL).strip()
        # Return both parts separately
        return {
            'has_two_parts': True,
            'think_part': think_content,
            'regular_part': remaining_text
        }
    else:
        # No think tags, just one part
        return {
            'has_two_parts': False,
            'full_text': text
        }
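
# Illustrative example of the splitter above (made-up input, not part of the
# app flow): a reasoning model's reply is separated into the hidden chain of
# thought and the visible answer.
#   process_text_with_think_tags("<think>step 1</think>Answer")
#   -> {'has_two_parts': True, 'think_part': 'step 1', 'regular_part': 'Answer'}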

class VectorStoreManager:
    def __init__(self):
        self.vector_stores = self.initialize_vector_stores()

    def initialize_vector_stores(self):
        """Scan vector store directory for DuckDB files, supporting nested directories"""
        vector_stores = {}
        if os.path.exists(VECTOR_STORE_DIR):
            # Add the default store if it exists
            comfit_path = os.path.join(VECTOR_STORE_DIR, f"{DEFAULT_VECTOR_STORE}.duckdb")
            if os.path.exists(comfit_path):
                vector_stores[DEFAULT_VECTOR_STORE] = {
                    "path": comfit_path,
                    "display_name": DEFAULT_VECTOR_STORE,
                    "data": DuckDBVectorStore.from_local(comfit_path)
                }
            # Scan for .duckdb files in the root directory and subdirectories
            for root, dirs, files in os.walk(VECTOR_STORE_DIR):
                for file in files:
                    if not file.endswith(".duckdb"):
                        continue
                    # Skip the root-level default store since we've already handled it
                    if root == VECTOR_STORE_DIR and file == f"{DEFAULT_VECTOR_STORE}.duckdb":
                        continue
                    # Get the full path to the file
                    file_path = os.path.join(root, file)
                    # Build store_name from the category and file name
                    rel_path = os.path.relpath(file_path, VECTOR_STORE_DIR)
                    path_parts = rel_path.split(os.sep)
                    if len(path_parts) == 1:
                        # Files in the root directory
                        store_name = path_parts[0][:-7]  # Remove .duckdb
                        display_name = store_name
                    else:
                        # Files in subdirectories
                        category = path_parts[0]
                        file_name = path_parts[-1][:-7]  # Remove .duckdb
                        store_name = f"{category}_{file_name}"
                        display_name = f"{category} - {file_name}"
                    vector_stores[store_name] = {
                        "path": file_path,
                        "display_name": display_name,
                        "data": DuckDBVectorStore.from_local(file_path)
                    }
        return vector_stores

    def get_vector_store_data(self, store_name):
        """Get the actual vector store data by store name"""
        return self.vector_stores[store_name]["data"]

    def get_vector_store_by_display_name(self, display_name):
        """Find a vector store by its display name"""
        for name, store_info in self.vector_stores.items():
            if store_info["display_name"] == display_name:
                return store_info["data"]
        return None

    def get_all_store_names(self):
        """Get all vector store names"""
        return list(self.vector_stores.keys())

    def get_all_display_names(self):
        """Get all display names as a list"""
        return [store_info["display_name"] for store_info in self.vector_stores.values()]

    def get_display_name(self, store_name):
        """Get the display name for a store name"""
        return self.vector_stores[store_name]["display_name"]

    def get_name_display_pairs(self):
        """Get a list of (display_name, store_name) tuples for UI dropdowns"""
        return [(v["display_name"], k) for k, v in self.vector_stores.items()]

# Create a global instance
vector_store_manager = VectorStoreManager()
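
# Usage sketch (illustrative, assuming at least one .duckdb store exists under
# VECTOR_STORE_DIR): dropdown entries come from the display names, and a
# selected display name resolves back to its DuckDBVectorStore.
#   pairs = vector_store_manager.get_name_display_pairs()   # [(display_name, store_name), ...]
#   store = vector_store_manager.get_vector_store_by_display_name(pairs[0][0])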

class ComFitChatbot:
    def __init__(self):
        self.initialize()

    def initialize(self):
        self.session_manager = SessionManager()
        self.embed_model = OptimumEmbedding(folder_name=EMBED_MODEL_PATH)
        Settings.embed_model = self.embed_model
        self.vector_stores = self.initialize_vector_store()
        self.config = self._load_config()
        self.llm_options = self._initialize_models()

    def get_user_data(self, user_id):
        return user_id

    def _load_config(self):
        """Load model configuration from JSON file"""
        try:
            with open(CONFIG_PATH, 'r') as f:
                return json.load(f)
        except Exception as e:
            print(f"Error loading config: {e}")
            return {"models": []}

    def _initialize_models(self):
        """Initialize and verify all models from config"""
        config_models = self.config.get("models", [])
        available_models = {}
        # Get the currently available Ollama models
        try:
            current_models = {m['name']: m['name'] for m in ollama.list()['models']}
            print(current_models)
        except Exception as e:
            print(f"Error fetching current models: {e}")
            current_models = {}
        # Check each configured model, pulling any that are missing locally
        for model_name in config_models:
            if model_name not in current_models:
                print(f"Model {model_name} not found locally. Attempting to pull...")
                try:
                    ollama.pull(model_name)
                    available_models[model_name] = model_name
                    print(f"Successfully pulled model {model_name}")
                except Exception as e:
                    print(f"Error pulling model {model_name}: {e}")
                    continue
            else:
                available_models[model_name] = current_models[model_name]
        return available_models

    def get_available_models(self):
        """Return the dictionary of available models"""
        # _initialize_models() stores its result as self.llm_options
        return self.llm_options

    def initialize_vector_store(self):
        """Scan vector store directory for DuckDB files, supporting nested directories"""
        vector_stores = {}
        if os.path.exists(VECTOR_STORE_DIR):
            # Add the default store if it exists
            comfit_path = os.path.join(VECTOR_STORE_DIR, f"{DEFAULT_VECTOR_STORE}.duckdb")
            if os.path.exists(comfit_path):
                vector_stores[DEFAULT_VECTOR_STORE] = {
                    "path": comfit_path,
                    "display_name": DEFAULT_VECTOR_STORE,
                    "data": DuckDBVectorStore.from_local(comfit_path)
                }
            # Scan for .duckdb files in the root directory and subdirectories
            for root, dirs, files in os.walk(VECTOR_STORE_DIR):
                for file in files:
                    if not file.endswith(".duckdb"):
                        continue
                    # Skip the root-level default store since we've already handled it
                    if root == VECTOR_STORE_DIR and file == f"{DEFAULT_VECTOR_STORE}.duckdb":
                        continue
                    # Get the full path to the file
                    file_path = os.path.join(root, file)
                    # Build store_name from the category and file name
                    rel_path = os.path.relpath(file_path, VECTOR_STORE_DIR)
                    path_parts = rel_path.split(os.sep)
                    if len(path_parts) == 1:
                        # Files in the root directory
                        store_name = path_parts[0][:-7]  # Remove .duckdb
                        display_name = store_name
                    else:
                        # Files in subdirectories
                        category = path_parts[0]
                        file_name = path_parts[-1][:-7]  # Remove .duckdb
                        store_name = f"{category}_{file_name}"
                        display_name = f"{category} - {file_name}"
                    vector_stores[store_name] = {
                        "path": file_path,
                        "display_name": display_name,
                        "data": DuckDBVectorStore.from_local(file_path)
                    }
        return vector_stores

    def get_vector_store(self, vector_store_name):
        return self.vector_stores[vector_store_name]["data"]

class comfitChatEngine:
    """
    Manages the core components needed for chat functionality with RAG.
    Handles LLM, vector store, memory, chat store, and indexes.
    """
    def __init__(self, user_id=None, llm_name=None, vector_store_name=None):
        """Initialize the chat engine with all necessary components"""
        self.user_id = user_id
        self.llm = None
        self.llm_name = llm_name
        self.vector_store = None
        self.vector_store_name = vector_store_name
        self.storage_context = None
        self.index = None
        self.chat_store = None
        self.memory = None
        self.chat_engine = None
        self.rebuild_chat_engine_flag = True
        # Conversation metadata management
        self.convs_metadata = {}
        self.current_conv_id = None
        if user_id:
            self.initialize_chat_store()
            self.initialize_convs_metadata()
        # Set initial components if provided
        if llm_name:
            self.set_llm(llm_name)
        if vector_store_name:
            self.set_vector_store(vector_store_name)

    def initialize_convs_metadata(self):
        print(f"Initializing convs metadata for user {self.user_id}")
        self.convs_metadata_file_path = os.path.join(
            CONVERSATION_HISTORY_PATH, self.user_id, f"{self.user_id}_metadata.json"
        )
        self.sorted_conversation_list = []
        self.get_convs_metadata()

    def get_convs_metadata(self):
        if os.path.exists(self.convs_metadata_file_path):
            with open(self.convs_metadata_file_path, "r") as f:
                self.convs_metadata = json.load(f)
            self.sorted_conversation_list = self.get_sorted_conversation_list()

    def set_current_conv_id(self, input_value, type="index"):
        if len(self.sorted_conversation_list) == 0:
            self.current_conv_id = None
            self.rebuild_chat_engine_flag = True
            return
        if type == "index" and self.current_conv_id != self.sorted_conversation_list[input_value]:
            self.current_conv_id = self.sorted_conversation_list[input_value]
            self.rebuild_chat_engine_flag = True
        elif type == "id" and self.current_conv_id != input_value:
            self.current_conv_id = input_value
            self.rebuild_chat_engine_flag = True

    def get_sorted_conversation_list(self):
        """
        Returns a list of conversation IDs sorted by update time,
        with the most recently updated conversations first.
        """
        # Create a list of (conv_id, updated_at) tuples
        conv_with_timestamps = []
        for conv_id, metadata in self.convs_metadata.items():
            # Use the updated_at timestamp for sorting
            if "updated_at" in metadata:
                # Convert the ISO timestamp string to a datetime object for comparison
                update_time = datetime.fromisoformat(metadata["updated_at"])
                conv_with_timestamps.append((conv_id, update_time))
        # Sort by timestamp (descending order - newest first)
        sorted_convs = sorted(conv_with_timestamps, key=lambda x: x[1], reverse=True)
        # Return just the conversation IDs in sorted order
        return [conv_id for conv_id, _ in sorted_convs]

    def get_sorted_conversation_list_for_ui(self):
        new_list = []
        for item in self.sorted_conversation_list:
            new_list.append([self.convs_metadata[item]["title"]])
        return new_list

    def update_convs_metadata(self, conv_id, title=None, create_flag=False):
        current_time = datetime.now().isoformat()
        if title is not None:
            self.convs_metadata[conv_id].update({"title": title})
        self.convs_metadata[conv_id].update({
            "updated_at": current_time,
            "llm_name": self.llm_name,
            "vector_store_name": self.vector_store_name
        })
        self.sorted_conversation_list = self.get_sorted_conversation_list()

    def set_llm(self, llm_name):
        self.llm = Ollama(
            model=llm_name,
            request_timeout=120,
            temperature=0.3
        )
        self.set_rebuild_chat_engine_flag(True)
        self.llm_name = llm_name
        if self.current_conv_id:
            self.convs_metadata[self.current_conv_id].update({"llm_name": self.llm_name})
        return self.llm

    def set_vector_store(self, vector_store_name):
        self.vector_store = vector_store_manager.get_vector_store_by_display_name(vector_store_name)
        if self.vector_store:
            self.initialize_index()
        self.set_rebuild_chat_engine_flag(True)
        self.vector_store_name = vector_store_name
        if self.current_conv_id:
            self.convs_metadata[self.current_conv_id].update({"vector_store_name": self.vector_store_name})
        return self.vector_store

    def initialize_index(self):
        """Initialize the index using the current vector store"""
        if not self.vector_store:
            raise ValueError("Vector store must be set before initializing index")
        self.storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
        self.index = VectorStoreIndex.from_vector_store(
            vector_store=self.vector_store,
            storage_context=self.storage_context
        )
        return self.index

    def initialize_chat_store(self):
        """Initialize the chat store for the user"""
        print(f"Initializing chat store for user {self.user_id}")
        chat_store_file_path = os.path.join(CONVERSATION_HISTORY_PATH, self.user_id, f"{self.user_id}.json")
        # Ensure the directory exists
        os.makedirs(os.path.dirname(chat_store_file_path), exist_ok=True)
        # Create or load the chat store
        if not os.path.exists(chat_store_file_path):
            self.chat_store = SimpleChatStore()
            self.chat_store.persist(persist_path=chat_store_file_path)
        else:
            self.chat_store = SimpleChatStore.from_persist_path(chat_store_file_path)
        self.chat_store_file_path = chat_store_file_path
        return self.chat_store

    def initialize_memory(self, conversation_id=None):
        """Initialize or reinitialize memory with the specified conversation ID"""
        if not self.chat_store:
            raise ValueError("Chat store must be initialized before memory")
        print(f"Initializing memory for conversation {conversation_id}")
        self.memory = ChatMemoryBuffer.from_defaults(
            token_limit=3000,
            chat_store=self.chat_store,
            chat_store_key=conversation_id
        )
        return self.memory

    def build_chat_engine(self, conversation_id=None):
        """Build the chat engine with all components"""
        if not all([self.llm, self.index, self.chat_store]):
            raise ValueError("LLM, index, and chat store must be set before building chat engine")
        # Initialize or update memory with the conversation ID
        self.initialize_memory(conversation_id)
        self.current_conv_id = conversation_id
        # Create the chat engine with the default system prompt
        self.chat_engine = self.index.as_chat_engine(
            chat_mode="context",
            llm=self.llm,
            memory=self.memory,
            system_prompt=SYSTEM_PROMPT
        )
        self.set_rebuild_chat_engine_flag(False)
        return self.chat_engine
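
    # Wiring sketch (illustrative; the user ID and question are made up):
    # constructing the engine ties together the retriever index, the Ollama
    # LLM, and a per-conversation ChatMemoryBuffer keyed by the conversation
    # ID; chat() with no conversation_id creates a new conversation first.
    #   engine = comfitChatEngine("user-123", DEFAULT_LLM, DEFAULT_VECTOR_STORE)
    #   response = engine.chat("What is whiplash?")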

    def save_chat_history(self):
        """Save chat history to file"""
        if self.chat_store and hasattr(self, 'chat_store_file_path'):
            self.chat_store.persist(persist_path=self.chat_store_file_path)

    def add_message(self, conversation_id, message):
        """Add a message to the chat history"""
        if self.chat_store:
            self.chat_store.add_message(conversation_id, message)

    def get_chat_history(self, conversation_id):
        """Get chat history for a specific conversation"""
        if conversation_id is None:
            return []
        if self.chat_store:
            return self.chat_store.to_dict()["store"].get(conversation_id, [])
        return []

    def get_chat_history_for_ui(self, conversation_id):
        """Get chat history for a specific conversation, formatted for the UI"""
        if conversation_id is None:
            return []
        if self.chat_store:
            conv_data = self.chat_store.to_dict()["store"].get(conversation_id, [])
            conv_data_for_ui = []
            for item in conv_data:
                if item["role"] == "user":
                    conv_data_for_ui.append(item)
                else:
                    content = item["content"]
                    time_str = None
                    if "time" in item["additional_kwargs"]:
                        elapsed_time = item["additional_kwargs"]["time"]
                        time_str = f"\n\n[Total time: {elapsed_time:.2f}s]"
                    processed_answer_dict = process_text_with_think_tags(content)
                    if processed_answer_dict["has_two_parts"]:
                        think_content = processed_answer_dict["think_part"]
                        conv_data_for_ui.append({
                            "role": "assistant",
                            "content": think_content,
                            "metadata": {"title": "Thinking...", "status": "done"}
                        })
                        remaining_text = processed_answer_dict["regular_part"]
                        if time_str:
                            remaining_text += time_str
                        conv_data_for_ui.append({"role": "assistant", "content": remaining_text})
                    else:
                        item_copy = copy.deepcopy(item)
                        if time_str:
                            item_copy["content"] += time_str
                        conv_data_for_ui.append(item_copy)
            return conv_data_for_ui
        return []

    def set_rebuild_chat_engine_flag(self, flag):
        self.rebuild_chat_engine_flag = flag

    def chat(self, message, conversation_id=None):
        start_time = time.time()
        create_flag = False
        if conversation_id is None:
            conversation_id = self.create_conversation(message=message)
            create_flag = True
            print(f"Created new conversation {conversation_id}")
            self.set_rebuild_chat_engine_flag(True)
        elif self.current_conv_id != conversation_id:
            self.set_rebuild_chat_engine_flag(True)
        if self.rebuild_chat_engine_flag:
            self.chat_engine = self.build_chat_engine(conversation_id)
            self.rebuild_chat_engine_flag = False
        # Get the response
        response = self.chat_engine.chat(message)
        elapsed_time = time.time() - start_time
        # Attach the elapsed time to the stored assistant message
        answer_dict = self.chat_store.get_messages(conversation_id)[-1].dict()
        answer_dict['additional_kwargs'].update({"time": elapsed_time})
        new_msg = ChatMessage.model_validate(answer_dict)
        self.chat_store.delete_message(conversation_id, -1)
        self.chat_store.add_message(conversation_id, new_msg)
        self.update_convs_metadata(conversation_id, create_flag=create_flag)
        self.save_metadata()
        self.save_chat_history()
        return response

    def create_conversation(self, message=None):
        """
        Create a new conversation with metadata
        Args:
            message: First message, used for generating a title
        Returns:
            conversation_id: ID of the new conversation
        """
        # Generate a new unique conversation ID
        conv_id = str(uuid.uuid4())
        # Set as the current conversation
        self.current_conv_id = conv_id
        # Generate a title from the first message, guarding against a missing one
        if message:
            title = message[:50] + ("..." if len(message) > 50 else "")
        else:
            title = "New conversation"
        # Create the timestamp
        current_time = datetime.now().isoformat()
        # Store metadata with resource information (keys match what the UI reads)
        self.convs_metadata[conv_id] = {
            "title": title,
            "created_at": current_time,
            "updated_at": current_time,
            "llm_name": self.llm_name,
            "vector_store_name": self.vector_store_name,
            "message_count": 0
        }
        return conv_id

    def update_conversation_metadata(self, conv_id, title=None, increment_message_count=True):
        """
        Update conversation metadata
        Args:
            conv_id: Conversation ID to update
            title: Optional new title
            increment_message_count: Whether to increment the message count
        """
        if conv_id not in self.convs_metadata:
            return
        # Update the timestamp
        self.convs_metadata[conv_id]["updated_at"] = datetime.now().isoformat()
        # Update the title if provided
        if title:
            self.convs_metadata[conv_id]["title"] = title
        # Increment the message count if requested
        if increment_message_count:
            self.convs_metadata[conv_id]["message_count"] = self.convs_metadata[conv_id].get("message_count", 0) + 1

    def get_sorted_conversations(self):
        """
        Returns a list of conversation IDs sorted by update time,
        with the most recently updated conversations first.
        """
        # Create a list of (conv_id, updated_at) tuples
        conv_with_timestamps = []
        for conv_id, metadata in self.convs_metadata.items():
            if "updated_at" in metadata:
                update_time = datetime.fromisoformat(metadata["updated_at"])
                conv_with_timestamps.append((conv_id, update_time))
        # Sort by timestamp (newest first)
        sorted_convs = sorted(conv_with_timestamps, key=lambda x: x[1], reverse=True)
        return [conv_id for conv_id, _ in sorted_convs]

    def get_conversation_info(self, conv_id):
        """Get conversation metadata"""
        return self.convs_metadata.get(conv_id, {})

    def save_metadata(self):
        """Save conversation metadata to file"""
        if hasattr(self, 'chat_store_file_path') and self.user_id:
            metadata_path = os.path.join(CONVERSATION_HISTORY_PATH, self.user_id, f"{self.user_id}_metadata.json")
            os.makedirs(os.path.dirname(metadata_path), exist_ok=True)
            with open(metadata_path, 'w') as f:
                json.dump(self.convs_metadata, f)

    def load_metadata(self):
        """Load conversation metadata from file"""
        if self.user_id:
            metadata_path = os.path.join(CONVERSATION_HISTORY_PATH, self.user_id, f"{self.user_id}_metadata.json")
            if os.path.exists(metadata_path):
                try:
                    with open(metadata_path, 'r') as f:
                        self.convs_metadata = json.load(f)
                except Exception as e:
                    print(f"Error loading metadata: {e}")

    def edit_message(self, index, conversation_id):
        """Truncate the stored history from the given index onward"""
        if conversation_id is not None:
            msg_list = self.chat_store.get_messages(conversation_id)
            new_msg_list = msg_list[:index]
            self.chat_store.set_messages(conversation_id, new_msg_list)
            self.save_metadata()
            self.save_chat_history()

    def retry_message(self, conversation_id):
        if conversation_id is not None:
            self.undo_message(conversation_id)
            self.save_metadata()
            self.save_chat_history()

    def undo_message(self, conversation_id):
        if conversation_id is not None:
            msg_list = list(self.chat_store.get_messages(conversation_id))
            # Remove the trailing assistant message, then its preceding user message
            if msg_list and msg_list[-1].role == MessageRole.ASSISTANT:
                self.chat_store.delete_last_message(conversation_id)
                msg_list.pop()
            if msg_list and msg_list[-1].role == MessageRole.USER:
                self.chat_store.delete_last_message(conversation_id)
            self.update_convs_metadata(conversation_id)
            self.save_metadata()
            self.save_chat_history()

    def delete_conversation(self, conversation_id):
        if conversation_id is not None:
            self.chat_store.delete_messages(conversation_id)
            self.convs_metadata.pop(conversation_id, None)
            self.save_metadata()
            self.save_chat_history()
            self.sorted_conversation_list = self.get_sorted_conversation_list()

class SessionManager:
    def __init__(self):
        self.sessions = {}

    def create_session(self, user_id=None):
        if user_id is None:
            return None
        print(f"Creating session for user {user_id}")
        if user_id not in self.sessions:
            self.sessions[user_id] = comfitChatEngine(user_id, llm_name=DEFAULT_LLM, vector_store_name=DEFAULT_VECTOR_STORE)
            print(f"Session created for user {user_id}")
        return self.sessions[user_id]

class ChatbotUI:
    """UI handler for the chatbot application"""
    def __init__(self, comfit_chatbot):
        """Initialize with a chat engine"""
        self.comfit_chatbot = comfit_chatbot
        self.init_attr()

    def init_attr(self):
        self.llm_options = self.comfit_chatbot.llm_options
        self.vector_stores = self.comfit_chatbot.vector_stores
        # self.vector_stores_options = [(v["display_name"], k) for k, v in self.comfit_chatbot.vector_stores.items()]
        # self.init_conversations_history()

    # def init_conversations_history(self):
    #     chat_session = self.comfit_chatbot.session_manager.sessions[USER_NAME]
    #     self.init_convs_list = chat_session.get_sorted_conversation_list_for_ui()
    #     if len(self.init_convs_list) > 0:
    #         self.init_chat_history = chat_session.get_chat_history(chat_session.sorted_conversation_list[0])
    #         self.init_convs_index = 0
    #     else:
    #         self.init_chat_history = []
    #         self.init_convs_index = None

    def create_ui(self):
        with gr.Blocks(title="Comfort and Fit Copilot (ComFit Copilot)") as demo:
            user_id = gr.State(None)
            with gr.Row():
                with gr.Column(scale=6):
                    gr.Markdown("<img src='/gradio_api/file/logo.png' alt='Innovision Logo' height='150' width='390'>")
                with gr.Column(scale=1):
                    login_btn = gr.LoginButton()
            with gr.Row():
                gr.Markdown("# Comfort and Fit Copilot (ComFit Copilot)")
            # Model selection in the top row
            with gr.Row():
                with gr.Column(scale=3):
                    llm_dropdown = gr.Dropdown(
                        label="Select Language Model",
                        choices=list(self.llm_options.values()),
                        value=next(iter(self.llm_options.values()), None)
                    )
                with gr.Column(scale=3):
                    vector_dropdown = gr.Dropdown(
                        label="Comfort and Fit Knowledge Base",
                        choices=[v["display_name"] for v in self.vector_stores.values()],
                        value=next((v["display_name"] for v in self.vector_stores.values()), None)
                    )
            # Main content with sidebar and chat area
            with gr.Row():
                # Left sidebar for conversation history
                with gr.Column(scale=1, elem_classes="sidebar"):
                    new_chat_btn = gr.Button("New Chat", size="sm")
                    # Hidden textbox for conversation data
                    conversation_data = gr.Textbox(visible=False)
                    # Dataset for conversation history
                    conversation_history = gr.Dataset(
                        components=[conversation_data],
                        label="Conversation History",
                        type="index",
                        layout="table"
                    )
                # Main chat area
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        height=500,
                        render_markdown=True,
                        show_copy_button=True,
                        type="messages",
                    )
                    with gr.Row():
                        msg = gr.Textbox(label="Ask me anything", placeholder="Log in to start chatting", interactive=False)

            # def get_auth_id(oauth_token: gr.OAuthToken | None) -> str:
            #     if oauth_token is None:
            #         return None
            #     id = whoami(oauth_token.token)['id']
            #     return id

            def get_auth_id(oauth_token: gr.OAuthToken | None) -> str | None:
                print(oauth_token)
                if oauth_token is None:
                    return None
                try:
                    user_info = whoami(oauth_token.token)
                    print(user_info)
                    return user_info.get('id')
                except Exception as e:
                    print(f"Authentication failed: {e}")
                    return None

            def add_msg(msg, history):
                history.append({"role": "user", "content": msg})
                return history

            def chat_with_comfit(history, user_id, conv_idx):
                start_time = time.time()
                msg = history[-1]["content"]
                user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
                if conv_idx is not None:
                    conv_id = user_engine.sorted_conversation_list[conv_idx]
                else:
                    conv_id = None
                response = user_engine.chat(msg, conv_id)
                answer = response.response
                processed_answer_dict = process_text_with_think_tags(answer)
                if processed_answer_dict["has_two_parts"]:
                    think_content = processed_answer_dict["think_part"]
                    remaining_text = processed_answer_dict["regular_part"]
                    # Stream the think block into a collapsible "Thinking..." message
                    history.append({"role": "assistant", "content": "", "metadata": {"title": "Thinking...", "status": "pending"}})
                    for character in think_content:
                        history[-1]["content"] += character
                        yield history
                    elapsed_time = time.time() - start_time
                    history[-1]["metadata"]["title"] = f"Thinking... [Thinking time: {elapsed_time:.2f}s]"
                    history[-1]["metadata"]["status"] = "done"
                    yield history
                    # Stream the visible answer
                    history.append({"role": "assistant", "content": ""})
                    for character in remaining_text:
                        history[-1]["content"] += character
                        yield history
                    elapsed_time = time.time() - start_time
                    history[-1]["content"] += f"\n\n[Total time: {elapsed_time:.2f}s]"
                    yield history
                else:
                    full_text = processed_answer_dict["full_text"]
                    history.append({"role": "assistant", "content": ""})
                    for character in full_text:
                        history[-1]["content"] += character
                        yield history
                    elapsed_time = time.time() - start_time
                    history[-1]["content"] += f"\n\n[Total time: {elapsed_time:.2f}s]"
                    yield history

            def clear_msg():
                return ""

            def update_conversation_history(user_id):
                user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
                ui_list = user_engine.get_sorted_conversation_list_for_ui()
                if len(ui_list) > 0:
                    idx = 0
                else:
                    idx = None
                return gr.update(samples=ui_list, value=idx)

            msg.submit(
                add_msg,
                [msg, chatbot],
                [chatbot]
            ).then(
                clear_msg,
                None,
                [msg]
            ).then(
                chat_with_comfit,
                [chatbot, user_id, conversation_history],
                [chatbot]
            ).then(
                update_conversation_history,
                [user_id],
                [conversation_history]
            )

            def click_to_select_conversation(conversation_history, user_id):
                user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
                user_engine.set_current_conv_id(conversation_history, type="index")
                chat_history = user_engine.get_chat_history_for_ui(user_engine.current_conv_id)
                llm_name = user_engine.convs_metadata[user_engine.current_conv_id]["llm_name"]
                vector_store_name = user_engine.convs_metadata[user_engine.current_conv_id]["vector_store_name"]
                return gr.update(value=conversation_history), chat_history, gr.update(value=llm_name), gr.update(value=vector_store_name)

            conversation_history.click(
                click_to_select_conversation,
                [conversation_history, user_id],
                [conversation_history, chatbot, llm_dropdown, vector_dropdown]
            )

            # msg.submit(
            #     chat_with_comfit,
            #     [msg, chatbot, user_id_dropdown],
            #     [chatbot]
            # )
            # msg.submit(
            #     clear_msg,
            #     None,
            #     [msg]
            # ).then(
            #     chat_with_comfit,
            #     [msg, chatbot, user_id_dropdown],
            #     [chatbot]
            # )
            # clear_btn.click(
            #     clear_session,
            #     [session_state],
            #     [chatbot, session_state],
            #     queue=False
            # )

            def create_session(user_id):
                if user_id is None:
                    return
                self.comfit_chatbot.session_manager.create_session(user_id)
                user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
                sorted_conversation_list = user_engine.get_sorted_conversation_list_for_ui()
                if len(sorted_conversation_list) > 0:
                    index = 0
                else:
                    index = None
                # Named to avoid shadowing the update_conversation_history() helper above
                conversation_history_update = gr.update(samples=sorted_conversation_list, value=index)
                user_engine.set_current_conv_id(0, type="index")
                chat_history = user_engine.get_chat_history_for_ui(user_engine.current_conv_id)
                if len(chat_history) > 0:
                    llm_name = user_engine.convs_metadata[user_engine.current_conv_id]["llm_name"]
                    vector_store_name = user_engine.convs_metadata[user_engine.current_conv_id]["vector_store_name"]
                else:
                    llm_name = user_engine.llm_name
                    vector_store_name = user_engine.vector_store_name
                yield llm_name, vector_store_name, conversation_history_update, chat_history

            def activate_chat(user_id):
                if user_id is None:
                    return gr.update(placeholder="Log in to start chatting", interactive=False)
                return gr.update(placeholder="", interactive=True)

            demo.load(
                get_auth_id,
                inputs=None,
                outputs=[user_id]
            ).then(
                create_session,
                [user_id],
                [llm_dropdown, vector_dropdown, conversation_history, chatbot]
            ).success(
                activate_chat,
                [user_id],
                [msg]
            )

            def update_llm(user_id, llm_name):
                if user_id is None:
                    return
                user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
                user_engine.set_llm(llm_name)

            llm_dropdown.change(
                update_llm,
                [user_id, llm_dropdown],
                None
            )

            def update_vector_store(user_id, vector_store_name):
                if user_id is None:
                    return
                user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
                user_engine.set_vector_store(vector_store_name)

            vector_dropdown.change(
                update_vector_store,
                [user_id, vector_dropdown],
                None
            )

            def edit_chat(user_id, history, edit_data: EditData):
                user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
                idx = edit_data.index
                # Count how many user messages appear up to this index in the UI history
                user_message_count = 0
                for i in range(idx + 1):
                    if history[i]["role"] == "user":
                        user_message_count += 1
                # In backend storage, user messages sit at positions 0, 2, 4, 6...
                # so the backend index is (user_message_count - 1) * 2
                backend_idx = (user_message_count - 1) * 2
                user_engine.edit_message(backend_idx, user_engine.current_conv_id)
                history = history[:idx + 1]
                return history

            chatbot.edit(
                edit_chat,
                [user_id, chatbot],
                [chatbot]
            ).success(
                chat_with_comfit,
                [chatbot, user_id, conversation_history],
                [chatbot]
            ).success(
                update_conversation_history,
                [user_id],
                [conversation_history]
            )

            def retry_chat(user_id, history):
                user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
                user_engine.retry_message(user_engine.current_conv_id)
                # Drop trailing assistant messages so the last user message is re-sent
                while history and history[-1]["role"] == "assistant":
                    history.pop()
                yield history

            chatbot.retry(
                retry_chat,
                [user_id, chatbot],
                [chatbot]
            ).then(
                chat_with_comfit,
                [chatbot, user_id, conversation_history],
                [chatbot]
            ).then(
                update_conversation_history,
                [user_id],
                [conversation_history]
            )

            def undo_chat(user_id):
                user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
                user_engine.undo_message(user_engine.current_conv_id)
                chat_history = user_engine.get_chat_history_for_ui(user_engine.current_conv_id)
                return chat_history

            chatbot.undo(
                undo_chat,
                [user_id],
                [chatbot]
            )

            def clear_conversation(user_id):
                user_engine = self.comfit_chatbot.session_manager.sessions[user_id]
                user_engine.delete_conversation(user_engine.current_conv_id)
                sorted_conversation_list = user_engine.get_sorted_conversation_list_for_ui()
                if len(sorted_conversation_list) > 0:
                    index = 0
                else:
                    index = None
                conversation_history_update = gr.update(samples=sorted_conversation_list, value=index)
                user_engine.set_current_conv_id(index, type="index")
                chat_history = user_engine.get_chat_history_for_ui(user_engine.current_conv_id)
                yield conversation_history_update, chat_history

            chatbot.clear(
                clear_conversation,
                [user_id],
                [conversation_history, chatbot]
            )

            # The "New Chat" button only clears the chat area; the conversation
            # itself is created on the first message
            def prepare_new_chat():
                print("prepare_new_chat")
                return [], gr.update(value=None)

            def print_dataset(value):
                print(value)

            # Create a new conversation
            new_chat_btn.click(
                prepare_new_chat,
                None,
                [chatbot, conversation_history],
            ).then(
                print_dataset,
                conversation_history,
                None
            )
        return demo

# Deployment settings
if __name__ == "__main__":
    # Check chat store health
    # store_health_ok = check_chat_store_health()
    # if not store_health_ok:
    #     print("WARNING: Chat store health check failed! Some functionality may not work correctly.")
    # # Run warm-up to pre-initialize resources
    # warm_up_resources()
    comfit_chatbot = ComFitChatbot()
    ui = ChatbotUI(comfit_chatbot)
    demo = ui.create_ui()
    demo.queue(max_size=10, default_concurrency_limit=3)
    demo.launch(allowed_paths=["logo.png"])