# chat/backend.py
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.llms.anthropic import Anthropic
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler
from llama_index.core import StorageContext, load_index_from_storage
from llama_index.core.memory import ChatMemoryBuffer
import logging
import os
from dotenv import load_dotenv
import time
from typing import Optional, Dict, Any, List
from tqdm import tqdm
import streamlit as st
# Set up logging to track what the chatbot is doing
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Disable tokenizer parallelism warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Create a directory for storing the index
INDEX_DIR = "index"
os.makedirs(INDEX_DIR, exist_ok=True)
# Cache LLM model to be reused across sessions
@st.cache_resource
def load_llm_model(api_key, model_name="claude-3-7-sonnet-20250219", temperature=0.1, max_tokens=2048):
"""Load Language Model once and reuse across all sessions."""
logger.info("Loading Claude language model (cached)...")
return Anthropic(
api_key=api_key,
model=model_name,
temperature=temperature,
max_tokens=max_tokens,
timeout=30.0 # Set a 30-second timeout for API requests
)
# Cache embedding model to be reused across sessions
@st.cache_resource
def load_embedding_model(model_name="sentence-transformers/all-MiniLM-L6-v2", device="cpu", batch_size=8):
"""Load Embedding Model once and reuse across all sessions."""
logger.info("Loading text embedding model (cached)...")
# Try to get HuggingFace token from Streamlit secrets if available
    try:
        hf_token = st.secrets.get("HUGGINGFACE_TOKEN", None)
    except Exception:  # st.secrets raises if no secrets file is configured
        hf_token = None
        logger.info("No HuggingFace token found in secrets; proceeding without authentication")
return HuggingFaceEmbedding(
model_name=model_name,
device=device,
embed_batch_size=batch_size,
token=hf_token # Will be None if not found in secrets
)
# Cache the index loading/creation to be shared across sessions
@st.cache_resource
def load_or_create_index(_documents=None, data_dir="data"):
    """Load an existing index or create a new one, shared across sessions.

    The leading underscore on `_documents` tells st.cache_resource not to hash
    that argument when computing the cache key.
    """
try:
        # Check whether a persisted index already exists. StorageContext.persist
        # writes several JSON files whose names vary by llama_index version, so
        # check for any persisted files rather than a hard-coded "index.json".
        if os.path.isdir(INDEX_DIR) and os.listdir(INDEX_DIR) and _documents is None:
logger.info("Loading existing index (cached)...")
storage_context = StorageContext.from_defaults(persist_dir=INDEX_DIR)
index = load_index_from_storage(storage_context)
logger.info("Index loaded successfully")
return index
# Create a new index
logger.info("Creating new index (cached)...")
# If documents weren't provided, load them
if _documents is None:
documents = SimpleDirectoryReader(data_dir).load_data()
else:
documents = _documents
# Ensure we're using HuggingFace embeddings explicitly before creating the index
embed_model = load_embedding_model()
Settings.embed_model = embed_model
with tqdm(total=1, desc="Creating searchable index") as pbar:
index = VectorStoreIndex.from_documents(documents)
# Save the index
index.storage_context.persist(persist_dir=INDEX_DIR)
pbar.update(1)
logger.info("Index created and saved successfully")
return index
except Exception as e:
logger.error(f"Error in load_or_create_index: {e}")
raise
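# To force a full rebuild (a sketch using Streamlit's cache API): functions
# decorated with st.cache_resource expose .clear(), so clearing the cache and
# deleting the files under INDEX_DIR before the next call drops the old index:
#
#     load_or_create_index.clear()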
# Get a thread-safe callback manager (cached for reuse)
@st.cache_resource
def get_callback_manager(debug_mode=True):
"""Get the appropriate callback manager based on environment."""
if debug_mode:
debug_handler = LlamaDebugHandler(print_trace_on_end=True)
return CallbackManager([debug_handler])
else:
# Use a lightweight callback handler for production
return CallbackManager([])
# Cache memory buffer for chat history.
# NOTE: st.cache_resource shares one object across every session, so this
# buffer (and the chat history inside it) is global to the app. That is fine
# for a single-user demo, but in a multi-user deployment all users would share
# one conversation history.
@st.cache_resource
def get_chat_memory(token_limit=1500):
    """Get a shared chat memory buffer with a token limit to manage the context window."""
    logger.info("Creating chat memory buffer...")
    return ChatMemoryBuffer.from_defaults(token_limit=token_limit)
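# A per-session alternative (sketch): to give each browser session its own
# history, store the buffer in st.session_state instead of the shared resource
# cache. `get_session_chat_memory` is a hypothetical helper and is not used by
# the code below.
def get_session_chat_memory(token_limit=1500):
    """Hypothetical per-session variant of get_chat_memory."""
    if "chat_memory" not in st.session_state:
        st.session_state["chat_memory"] = ChatMemoryBuffer.from_defaults(
            token_limit=token_limit
        )
    return st.session_state["chat_memory"]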
class Chatbot:
def __init__(self, config: Optional[Dict[str, Any]] = None, llm=None, embed_model=None, index=None):
"""Initialize the chatbot with configuration."""
# Set up basic variables and load configuration
self.config = config or {}
self.api_key = self._get_api_key()
# Use provided resources or load them using cached functions
self.llm = llm or load_llm_model(
self.api_key,
self.config.get("model", "claude-3-7-sonnet-20250219"),
self.config.get("temperature", 0.1),
self.config.get("max_tokens", 2048)
)
self.embed_model = embed_model or load_embedding_model(
self.config.get("embedding_model", "sentence-transformers/all-MiniLM-L6-v2"),
self.config.get("device", "cpu"),
self.config.get("embed_batch_size", 8)
)
self.index = index
self.query_engine = None
self.chat_engine = None
self.chat_memory = get_chat_memory()
# Set up debugging tools to help track any issues
self.callback_manager = get_callback_manager()
# Configure settings
self._configure_settings()
def _get_api_key(self) -> str:
"""Get API key from environment or config."""
# Load the API key from environment variables or config file
load_dotenv()
api_key = os.getenv("ANTHROPIC_API_KEY") or self.config.get("api_key")
if not api_key:
raise ValueError("API key not found in environment or config")
return api_key
def _configure_settings(self):
"""Configure all settings for the chatbot."""
try:
# Configure all the settings for the chatbot
logger.info("Configuring chatbot settings...")
Settings.embed_model = self.embed_model
Settings.text_splitter = SentenceSplitter(
chunk_size=self.config.get("chunk_size", 1024),
chunk_overlap=self.config.get("chunk_overlap", 100),
paragraph_separator="\n\n"
)
Settings.llm = self.llm
Settings.callback_manager = self.callback_manager
logger.info("Components initialized successfully")
except Exception as e:
logger.error(f"Error configuring settings: {e}")
raise
def load_documents(self, data_dir: str = "data"):
"""Load documents with retry logic."""
# Try to load documents up to 3 times if there's an error
max_retries = 3
retry_delay = 1
for attempt in range(max_retries):
try:
logger.info(f"Loading documents from {data_dir}...")
documents = SimpleDirectoryReader(data_dir).load_data()
logger.info(f"Loaded {len(documents)} documents")
return documents
except Exception as e:
if attempt < max_retries - 1:
logger.warning(f"Attempt {attempt + 1} failed: {e}. Retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
else:
logger.error(f"Failed to load documents after {max_retries} attempts: {e}")
raise
def create_index(self, documents):
"""Create or load index with error handling."""
try:
if self.index is None:
self.index = load_or_create_index(documents)
return self.index
except Exception as e:
logger.error(f"Error creating/loading index: {e}")
raise
def update_index(self, new_documents: List):
"""Update existing index with new documents without rebuilding."""
try:
if self.index is None:
logger.warning("No existing index found. Creating new index instead.")
self.create_index(new_documents)
return
logger.info(f"Updating index with {len(new_documents)} new documents...")
with tqdm(total=1, desc="Updating searchable index") as pbar:
# Insert the new documents into the existing index
for doc in new_documents:
self.index.insert(doc)
# Persist the updated index
self.index.storage_context.persist(persist_dir=INDEX_DIR)
pbar.update(1)
logger.info("Index updated and saved successfully")
# Reinitialize engines with updated index
self.initialize_query_engine()
self.initialize_chat_engine()
except Exception as e:
logger.error(f"Error updating index: {e}")
raise
def initialize_query_engine(self):
"""Initialize query engine with error handling."""
try:
# Set up the system that will handle questions
logger.info("Initializing query engine...")
if self.index is None:
# Load or create index if needed
documents = self.load_documents()
self.create_index(documents)
self.query_engine = self.index.as_query_engine()
logger.info("Query engine initialized successfully")
except Exception as e:
logger.error(f"Error initializing query engine: {e}")
raise
def initialize_chat_engine(self):
"""Initialize chat engine with memory for conversation context."""
try:
# Set up the chat engine with memory for conversations
logger.info("Initializing chat engine...")
if self.index is None:
# Load or create index if needed
documents = self.load_documents()
self.create_index(documents)
# Create chat engine with the memory buffer for context
self.chat_engine = self.index.as_chat_engine(
                chat_mode="context",  # retrieve documents for every message and prepend chat history
memory=self.chat_memory,
similarity_top_k=3, # Retrieve fewer but more relevant documents
system_prompt=(
"You are a helpful assistant that answers questions based on the provided documents. "
"When answering follow-up questions, use both the conversation history and the retrieved documents. "
"If you don't know the answer, say 'I don't have information about that in my documents.'"
)
)
logger.info("Chat engine initialized successfully")
except Exception as e:
logger.error(f"Error initializing chat engine: {e}")
raise
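    # An alternative worth noting (a sketch, assuming the installed llama_index
    # version supports this mode): "condense_plus_context" first condenses the
    # follow-up and chat history into a standalone question before retrieval,
    # which can help terse queries like "tell me more":
    #
    #     self.chat_engine = self.index.as_chat_engine(
    #         chat_mode="condense_plus_context",
    #         memory=self.chat_memory,
    #         similarity_top_k=3,
    #     )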
def query(self, query_text: str) -> str:
"""Execute a query with error handling and retries."""
# Try to answer questions up to 3 times if there's an error
max_retries = 3
retry_delay = 1
# Special handling for very short follow-up queries
if len(query_text.strip().split()) <= 3 and self.chat_memory:
logger.info(f"Detected potential follow-up question: {query_text}")
# Check if the memory has messages (by safely checking memory attributes)
has_messages = False
try:
# Check if memory has chat history in different possible ways
if hasattr(self.chat_memory, "chat_history") and self.chat_memory.chat_history:
has_messages = True
elif hasattr(self.chat_memory, "messages") and self.chat_memory.messages:
has_messages = True
except Exception as e:
logger.warning(f"Error checking chat memory: {e}")
# Only expand generic follow-ups if there's chat history
if has_messages:
# Check if it's a very generic follow-up like "tell me more" or "continue"
generic_followups = ["tell me more", "more", "continue", "go on", "elaborate", "explain more"]
                if query_text.lower().strip() in generic_followups:
expanded_query = "Please provide more information about the topic we were just discussing."
logger.info(f"Expanded generic follow-up to: {expanded_query}")
query_text = expanded_query
for attempt in range(max_retries):
try:
logger.info(f"Executing query: {query_text}")
print("\nThinking...", end="", flush=True)
# Use chat engine if initialized, otherwise use query engine
if self.chat_engine is not None:
# Make sure we're prioritizing document retrieval
logger.info("Using chat engine with document retrieval")
# Get response from chat engine
response = self.chat_engine.chat(query_text)
# Log sources if available
if hasattr(response, 'source_nodes') and response.source_nodes:
logger.info(f"Retrieved {len(response.source_nodes)} source nodes for context")
else:
logger.warning("No source nodes retrieved for this query")
else:
# Fallback to query engine
logger.info("Using query engine for document retrieval")
response = self.query_engine.query(query_text)
print(" Done!")
logger.info("Query executed successfully")
return str(response)
except Exception as e:
if attempt < max_retries - 1:
logger.warning(f"Attempt {attempt + 1} failed: {e}. Retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
else:
logger.error(f"Failed to execute query after {max_retries} attempts: {e}")
# Provide a graceful error message to the user
return "I'm having trouble processing your request. Could you please rephrase your question or ask something else?"
def reset_chat_history(self):
"""Reset the chat history to start a new conversation."""
logger.info("Resetting chat history")
self.chat_memory.reset()
if self.chat_engine is not None:
            # Rebuild the chat engine so it picks up the cleared memory buffer
self.initialize_chat_engine()
def cleanup(self):
"""Clean up resources."""
try:
# Clean up any resources we used
logger.info("Cleaning up resources...")
# Nothing to clean up since resources are managed by st.cache_resource
logger.info("Cleanup completed successfully")
except Exception as e:
logger.error(f"Error during cleanup: {e}")
# For CLI usage
def main():
# Set up all the configuration settings for the chatbot
config = {
"model": "claude-3-7-sonnet-20250219",
"temperature": 0.1,
"max_tokens": 2048, # Allow for longer responses
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
"device": "cpu",
"embed_batch_size": 8,
"chunk_size": 1024,
"chunk_overlap": 100
}
chatbot = None
try:
# Create and set up the chatbot
print("\nInitializing chatbot...")
chatbot = Chatbot(config)
# Load the documents we want to analyze
documents = chatbot.load_documents()
# Create a searchable index from the documents
chatbot.create_index(documents)
# Set up the system that will handle questions
chatbot.initialize_chat_engine()
print("\nChatbot is ready! You can ask questions about your documents.")
print("Type 'exit' to quit or 'clear' to reset chat history.")
print("-" * 50)
while True:
# Get user input
question = input("\nYour question: ").strip()
# Check if user wants to exit
if question.lower() in ['exit', 'quit', 'bye']:
print("\nGoodbye!")
break
# Check if user wants to clear chat history
if question.lower() == 'clear':
chatbot.reset_chat_history()
print("\nChat history has been cleared.")
continue
# Get the answer
answer = chatbot.query(question)
print("\nAnswer:", answer)
except KeyboardInterrupt:
print("\nExiting...")
except Exception as e:
print(f"\nError: {e}")
finally:
if chatbot:
chatbot.cleanup()
if __name__ == "__main__":
main()
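# Example (sketch) of driving the Chatbot programmatically, assuming
# ANTHROPIC_API_KEY is set in the environment and ./data contains documents:
#
#     bot = Chatbot({"temperature": 0.0})
#     bot.initialize_chat_engine()
#     print(bot.query("Summarize the documents."))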