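"""Document Q&A chatbot built on LlamaIndex.

Uses Anthropic Claude as the LLM, a local HuggingFace model for embeddings,
a persisted vector index, and conversational memory. Can run as a CLI (see
main()) or be embedded in a Streamlit app via the cached loader functions.
"""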
import logging
import os
import time
from typing import Any, Dict, List, Optional

import streamlit as st
from dotenv import load_dotenv
from tqdm import tqdm

from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.anthropic import Anthropic
# Set up logging to track what the chatbot is doing
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Disable tokenizer parallelism warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Create a directory for storing the persisted index
INDEX_DIR = "index"
os.makedirs(INDEX_DIR, exist_ok=True)
# Cache the LLM so it is loaded once and reused across sessions
@st.cache_resource
def load_llm_model(api_key, model_name="claude-3-7-sonnet-20250219", temperature=0.1, max_tokens=2048):
    """Load the language model once and reuse it across all sessions."""
    logger.info("Loading Claude language model (cached)...")
    return Anthropic(
        api_key=api_key,
        model=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        timeout=30.0  # 30-second timeout for API requests
    )
# Cache the embedding model so it is loaded once and reused across sessions
@st.cache_resource
def load_embedding_model(model_name="sentence-transformers/all-MiniLM-L6-v2", device="cpu", batch_size=8):
    """Load the embedding model once and reuse it across all sessions."""
    logger.info("Loading text embedding model (cached)...")
    # Try to get a HuggingFace token from Streamlit secrets if available
    try:
        hf_token = st.secrets.get("HUGGINGFACE_TOKEN", None)
    except Exception:
        hf_token = None
        logger.info("No HuggingFace token found in secrets, proceeding without authentication")
    return HuggingFaceEmbedding(
        model_name=model_name,
        device=device,
        embed_batch_size=batch_size,
        token=hf_token  # None if not found in secrets
    )
# Cache index loading/creation so the index is shared across sessions.
# The leading underscore on _documents tells Streamlit not to hash the argument.
@st.cache_resource
def load_or_create_index(_documents=None, data_dir="data"):
    """Load the existing index or create a new one, shared across sessions."""
    try:
        # Check whether a persisted index already exists
        # (StorageContext.persist writes docstore.json, among other files)
        if os.path.exists(os.path.join(INDEX_DIR, "docstore.json")) and _documents is None:
            logger.info("Loading existing index (cached)...")
            storage_context = StorageContext.from_defaults(persist_dir=INDEX_DIR)
            index = load_index_from_storage(storage_context)
            logger.info("Index loaded successfully")
            return index

        # Create a new index
        logger.info("Creating new index (cached)...")
        # If documents weren't provided, load them
        if _documents is None:
            documents = SimpleDirectoryReader(data_dir).load_data()
        else:
            documents = _documents

        # Ensure we're using HuggingFace embeddings explicitly before creating the index
        embed_model = load_embedding_model()
        Settings.embed_model = embed_model

        with tqdm(total=1, desc="Creating searchable index") as pbar:
            index = VectorStoreIndex.from_documents(documents)
            # Save the index
            index.storage_context.persist(persist_dir=INDEX_DIR)
            pbar.update(1)

        logger.info("Index created and saved successfully")
        return index
    except Exception as e:
        logger.error(f"Error in load_or_create_index: {e}")
        raise
# Get a thread-safe callback manager (cached for reuse)
@st.cache_resource
def get_callback_manager(debug_mode=True):
    """Get the appropriate callback manager based on environment."""
    if debug_mode:
        debug_handler = LlamaDebugHandler(print_trace_on_end=True)
        return CallbackManager([debug_handler])
    # Use an empty callback manager for production
    return CallbackManager([])
# Create a fresh chat memory buffer (one per chatbot instance, so that
# concurrent sessions don't share conversation history)
def get_chat_memory(token_limit=1500):
    """Get a chat memory buffer with a token limit to manage the context window."""
    logger.info("Creating chat memory buffer...")
    return ChatMemoryBuffer.from_defaults(token_limit=token_limit)
class Chatbot:
    def __init__(self, config: Optional[Dict[str, Any]] = None, llm=None, embed_model=None, index=None):
        """Initialize the chatbot with configuration."""
        # Set up basic variables and load configuration
        self.config = config or {}
        self.api_key = self._get_api_key()

        # Use provided resources or load them via the cached functions
        self.llm = llm or load_llm_model(
            self.api_key,
            self.config.get("model", "claude-3-7-sonnet-20250219"),
            self.config.get("temperature", 0.1),
            self.config.get("max_tokens", 2048)
        )
        self.embed_model = embed_model or load_embedding_model(
            self.config.get("embedding_model", "sentence-transformers/all-MiniLM-L6-v2"),
            self.config.get("device", "cpu"),
            self.config.get("embed_batch_size", 8)
        )
        self.index = index
        self.query_engine = None
        self.chat_engine = None
        self.chat_memory = get_chat_memory()

        # Set up debugging tools to help track any issues
        self.callback_manager = get_callback_manager()

        # Configure global LlamaIndex settings
        self._configure_settings()
    def _get_api_key(self) -> str:
        """Get the API key from the environment or config."""
        # Load the API key from environment variables or the config dict
        load_dotenv()
        api_key = os.getenv("ANTHROPIC_API_KEY") or self.config.get("api_key")
        if not api_key:
            raise ValueError("API key not found in environment or config")
        return api_key
    def _configure_settings(self):
        """Configure global LlamaIndex settings for the chatbot."""
        try:
            logger.info("Configuring chatbot settings...")
            Settings.embed_model = self.embed_model
            Settings.text_splitter = SentenceSplitter(
                chunk_size=self.config.get("chunk_size", 1024),
                chunk_overlap=self.config.get("chunk_overlap", 100),
                paragraph_separator="\n\n"
            )
            Settings.llm = self.llm
            Settings.callback_manager = self.callback_manager
            logger.info("Components initialized successfully")
        except Exception as e:
            logger.error(f"Error configuring settings: {e}")
            raise
    def load_documents(self, data_dir: str = "data"):
        """Load documents with retry logic."""
        # Try to load documents up to 3 times, pausing between attempts
        max_retries = 3
        retry_delay = 1
        for attempt in range(max_retries):
            try:
                logger.info(f"Loading documents from {data_dir}...")
                documents = SimpleDirectoryReader(data_dir).load_data()
                logger.info(f"Loaded {len(documents)} documents")
                return documents
            except Exception as e:
                if attempt < max_retries - 1:
                    logger.warning(f"Attempt {attempt + 1} failed: {e}. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                else:
                    logger.error(f"Failed to load documents after {max_retries} attempts: {e}")
                    raise
    def create_index(self, documents):
        """Create or load the index with error handling."""
        try:
            if self.index is None:
                self.index = load_or_create_index(documents)
            return self.index
        except Exception as e:
            logger.error(f"Error creating/loading index: {e}")
            raise
    def update_index(self, new_documents: List):
        """Update the existing index with new documents without rebuilding it."""
        try:
            if self.index is None:
                logger.warning("No existing index found. Creating new index instead.")
                self.create_index(new_documents)
                return
            logger.info(f"Updating index with {len(new_documents)} new documents...")
            with tqdm(total=1, desc="Updating searchable index") as pbar:
                # Insert the new documents into the existing index
                for doc in new_documents:
                    self.index.insert(doc)
                # Persist the updated index
                self.index.storage_context.persist(persist_dir=INDEX_DIR)
                pbar.update(1)
            logger.info("Index updated and saved successfully")
            # Reinitialize engines with the updated index
            self.initialize_query_engine()
            self.initialize_chat_engine()
        except Exception as e:
            logger.error(f"Error updating index: {e}")
            raise
    def initialize_query_engine(self):
        """Initialize the query engine with error handling."""
        try:
            logger.info("Initializing query engine...")
            if self.index is None:
                # Load or create the index if needed
                documents = self.load_documents()
                self.create_index(documents)
            self.query_engine = self.index.as_query_engine()
            logger.info("Query engine initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing query engine: {e}")
            raise
    def initialize_chat_engine(self):
        """Initialize the chat engine with memory for conversation context."""
        try:
            logger.info("Initializing chat engine...")
            if self.index is None:
                # Load or create the index if needed
                documents = self.load_documents()
                self.create_index(documents)
            # Create a chat engine backed by the memory buffer for context
            self.chat_engine = self.index.as_chat_engine(
                chat_mode="context",  # Simpler mode that's more stable
                memory=self.chat_memory,
                similarity_top_k=3,  # Retrieve fewer but more relevant documents
                system_prompt=(
                    "You are a helpful assistant that answers questions based on the provided documents. "
                    "When answering follow-up questions, use both the conversation history and the retrieved documents. "
                    "If you don't know the answer, say 'I don't have information about that in my documents.'"
                )
            )
            logger.info("Chat engine initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing chat engine: {e}")
            raise
    def query(self, query_text: str) -> str:
        """Execute a query with error handling and retries."""
        max_retries = 3
        retry_delay = 1

        # Special handling for very short follow-up queries
        if len(query_text.strip().split()) <= 3 and self.chat_memory:
            logger.info(f"Detected potential follow-up question: {query_text}")
            # Check whether the memory holds any messages (attribute names
            # differ between versions, so check both safely)
            has_messages = False
            try:
                if hasattr(self.chat_memory, "chat_history") and self.chat_memory.chat_history:
                    has_messages = True
                elif hasattr(self.chat_memory, "messages") and self.chat_memory.messages:
                    has_messages = True
            except Exception as e:
                logger.warning(f"Error checking chat memory: {e}")

            # Only expand generic follow-ups if there's chat history
            if has_messages:
                # Expand very generic follow-ups like "tell me more" or "continue"
                generic_followups = ["tell me more", "more", "continue", "go on", "elaborate", "explain more"]
                if query_text.strip().lower() in generic_followups:
                    query_text = "Please provide more information about the topic we were just discussing."
                    logger.info(f"Expanded generic follow-up to: {query_text}")

        for attempt in range(max_retries):
            try:
                logger.info(f"Executing query: {query_text}")
                print("\nThinking...", end="", flush=True)
                # Use the chat engine if initialized, otherwise fall back to the query engine
                if self.chat_engine is not None:
                    logger.info("Using chat engine with document retrieval")
                    response = self.chat_engine.chat(query_text)
                    # Log sources if available
                    if hasattr(response, 'source_nodes') and response.source_nodes:
                        logger.info(f"Retrieved {len(response.source_nodes)} source nodes for context")
                    else:
                        logger.warning("No source nodes retrieved for this query")
                else:
                    logger.info("Using query engine for document retrieval")
                    response = self.query_engine.query(query_text)
                print(" Done!")
                logger.info("Query executed successfully")
                return str(response)
            except Exception as e:
                if attempt < max_retries - 1:
                    logger.warning(f"Attempt {attempt + 1} failed: {e}. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                else:
                    logger.error(f"Failed to execute query after {max_retries} attempts: {e}")
                    # Return a graceful error message to the user
                    return "I'm having trouble processing your request. Could you please rephrase your question or ask something else?"
    def reset_chat_history(self):
        """Reset the chat history to start a new conversation."""
        logger.info("Resetting chat history")
        self.chat_memory.reset()
        if self.chat_engine is not None:
            # Reinitialize the chat engine with the freshly reset memory
            self.initialize_chat_engine()
    def cleanup(self):
        """Clean up resources."""
        try:
            logger.info("Cleaning up resources...")
            # Nothing to release here: shared resources are managed by st.cache_resource
            logger.info("Cleanup completed successfully")
        except Exception as e:
            logger.error(f"Error during cleanup: {e}")
# For CLI usage
def main():
    # Configuration for the chatbot
    config = {
        "model": "claude-3-7-sonnet-20250219",
        "temperature": 0.1,
        "max_tokens": 2048,  # Allow for longer responses
        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
        "device": "cpu",
        "embed_batch_size": 8,
        "chunk_size": 1024,
        "chunk_overlap": 100
    }
    chatbot = None
    try:
        # Create and set up the chatbot
        print("\nInitializing chatbot...")
        chatbot = Chatbot(config)
        # Load the documents we want to analyze
        documents = chatbot.load_documents()
        # Create a searchable index from the documents
        chatbot.create_index(documents)
        # Set up the engine that will handle questions
        chatbot.initialize_chat_engine()

        print("\nChatbot is ready! You can ask questions about your documents.")
        print("Type 'exit' to quit or 'clear' to reset chat history.")
        print("-" * 50)

        while True:
            # Get user input
            question = input("\nYour question: ").strip()
            # Check if the user wants to exit
            if question.lower() in ['exit', 'quit', 'bye']:
                print("\nGoodbye!")
                break
            # Check if the user wants to clear the chat history
            if question.lower() == 'clear':
                chatbot.reset_chat_history()
                print("\nChat history has been cleared.")
                continue
            # Get the answer
            answer = chatbot.query(question)
            print("\nAnswer:", answer)
    except KeyboardInterrupt:
        print("\nExiting...")
    except Exception as e:
        print(f"\nError: {e}")
    finally:
        if chatbot:
            chatbot.cleanup()
if __name__ == "__main__":
    main()
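# Usage sketch (the file and package names below are assumptions, not from
# the original source). With documents in ./data and ANTHROPIC_API_KEY set
# in the environment or a .env file:
#
#   pip install llama-index llama-index-llms-anthropic \
#       llama-index-embeddings-huggingface python-dotenv tqdm streamlit
#   python chatbot.py   # hypothetical filename for this module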