"""
Core chatbot implementation for document question answering.
"""
import logging
import os
import time
from typing import Any, Dict, List, Optional

from tqdm import tqdm
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.anthropic import Anthropic

import config

# Module-level logger; handler/level configuration is left to the application.
logger = logging.getLogger(__name__)
class Chatbot:
    """Chatbot for document question answering using LlamaIndex.

    Wires together an Anthropic LLM, a HuggingFace embedding model, and a
    persisted vector index so plain-text questions can be answered from a
    local document collection.
    """

    def __init__(self, config_dict: Optional[Dict[str, Any]] = None):
        """Initialize the chatbot with configuration.

        Args:
            config_dict: Optional configuration dictionary. If not provided,
                configuration is loaded from environment variables.
        """
        # Resolve configuration and credentials first so a missing API key
        # fails fast, before any heavy model setup.
        self.config = config_dict or config.get_chatbot_config()
        self.api_key = self._get_api_key()
        self.index = None
        self.query_engine = None
        self.llm = None
        self.embed_model = None
        # Debug handler prints a trace of each LlamaIndex operation, which
        # helps diagnose slow or failing pipeline stages.
        self.debug_handler = LlamaDebugHandler(print_trace_on_end=True)
        self.callback_manager = CallbackManager([self.debug_handler])
        # Set up the LLM, embedding model, and global Settings.
        self._initialize_components()

    def _get_api_key(self) -> str:
        """Get the API key from environment or config.

        Returns:
            API key as string.

        Raises:
            ValueError: If API key is not found in either location.
        """
        # Environment variable takes precedence over the config dict.
        api_key = config.ANTHROPIC_API_KEY or self.config.get("api_key")
        if not api_key:
            raise ValueError("API key not found in environment or config")
        return api_key

    def _initialize_components(self):
        """Initialize all components with proper error handling.

        Sets up the LLM, embedding model, text splitter, and the global
        LlamaIndex ``Settings``.

        Raises:
            Exception: If component initialization fails.
        """
        try:
            # Language model (Claude); per-instance config overrides defaults.
            logger.info("Setting up Claude language model...")
            self.llm = Anthropic(
                api_key=self.api_key,
                model=self.config.get("model", config.LLM_MODEL),
                temperature=self.config.get("temperature", config.LLM_TEMPERATURE),
                max_tokens=self.config.get("max_tokens", config.LLM_MAX_TOKENS),
            )
            # Embedding model that converts text into vectors for retrieval.
            logger.info("Setting up text embedding model...")
            self.embed_model = HuggingFaceEmbedding(
                model_name=self.config.get("embedding_model", config.EMBEDDING_MODEL),
                device=self.config.get("device", config.EMBEDDING_DEVICE),
                embed_batch_size=self.config.get("embed_batch_size", config.EMBEDDING_BATCH_SIZE),
            )
            # Publish everything through the process-global Settings object,
            # which LlamaIndex components read implicitly.
            logger.info("Configuring chatbot settings...")
            Settings.embed_model = self.embed_model
            Settings.text_splitter = SentenceSplitter(
                chunk_size=self.config.get("chunk_size", config.CHUNK_SIZE),
                chunk_overlap=self.config.get("chunk_overlap", config.CHUNK_OVERLAP),
                paragraph_separator="\n\n",
            )
            Settings.llm = self.llm
            Settings.callback_manager = self.callback_manager
            logger.info("Components initialized successfully")
        except Exception as e:
            logger.error("Error initializing components: %s", e)
            raise

    def _retry(self, operation, description: str):
        """Run *operation* with up to 3 attempts, sleeping 1s between tries.

        Args:
            operation: Zero-argument callable to execute.
            description: Human-readable action name used in the final
                error-log message (e.g. ``"load documents"``).

        Returns:
            Whatever *operation* returns on the first successful attempt.

        Raises:
            Exception: The last exception raised, once all attempts fail.
        """
        max_retries = 3
        retry_delay = 1
        for attempt in range(max_retries):
            try:
                return operation()
            except Exception as e:
                if attempt < max_retries - 1:
                    logger.warning(
                        "Attempt %s failed: %s. Retrying in %s seconds...",
                        attempt + 1, e, retry_delay,
                    )
                    time.sleep(retry_delay)
                else:
                    logger.error(
                        "Failed to %s after %s attempts: %s",
                        description, max_retries, e,
                    )
                    raise

    def load_documents(self, data_dir: Optional[str] = None) -> List:
        """Load documents with retry logic.

        Args:
            data_dir: Directory containing documents to load. If None, uses
                the configured default (``config.DATA_DIR``).

        Returns:
            List of loaded documents.

        Raises:
            Exception: If document loading fails after retries.
        """
        data_dir = data_dir or config.DATA_DIR

        def _load():
            logger.info("Loading documents from %s...", data_dir)
            documents = SimpleDirectoryReader(data_dir).load_data()
            logger.info("Loaded %s documents", len(documents))
            return documents

        return self._retry(_load, "load documents")

    def create_index(self, documents, index_dir: Optional[str] = None):
        """Create a vector index, or load a previously persisted one.

        Args:
            documents: Documents to index (ignored when a persisted index
                already exists in ``index_dir``).
            index_dir: Directory used to persist/load the index. If None,
                uses the configured default (``config.INDEX_DIR``).

        Raises:
            Exception: If index creation or loading fails.
        """
        index_dir = index_dir or config.INDEX_DIR
        try:
            # Reuse a persisted index when present: re-embedding every
            # document is the most expensive part of start-up.
            if os.path.exists(os.path.join(index_dir, "index_store.json")):
                logger.info("Loading existing index...")
                storage_context = StorageContext.from_defaults(persist_dir=index_dir)
                self.index = load_index_from_storage(storage_context)
                logger.info("Index loaded successfully")
                return
            logger.info("Creating new index...")
            with tqdm(total=1, desc="Creating searchable index") as pbar:
                self.index = VectorStoreIndex.from_documents(documents)
                # Persist immediately so the next start-up can take the
                # fast load path above.
                self.index.storage_context.persist(persist_dir=index_dir)
                pbar.update(1)
            logger.info("Index created and saved successfully")
        except Exception as e:
            logger.error("Error creating/loading index: %s", e)
            raise

    def initialize_query_engine(self):
        """Initialize the query engine from the index.

        Raises:
            Exception: If initialization fails (e.g. when called before an
                index has been created or loaded).
        """
        try:
            logger.info("Initializing query engine...")
            self.query_engine = self.index.as_query_engine()
            logger.info("Query engine initialized successfully")
        except Exception as e:
            logger.error("Error initializing query engine: %s", e)
            raise

    def query(self, query_text: str) -> str:
        """Execute a query with error handling and retries.

        Args:
            query_text: The question to answer.

        Returns:
            Response as string.

        Raises:
            Exception: If the query fails after retries.
        """
        def _ask():
            logger.info("Executing query: %s", query_text)
            response = self.query_engine.query(query_text)
            logger.info("Query executed successfully")
            return str(response)

        return self._retry(_ask, "execute query")

    def cleanup(self):
        """Clean up resources.

        Currently only logs; kept as an explicit lifecycle hook for callers.
        Errors are logged and deliberately swallowed so cleanup never
        propagates failures at shutdown.
        """
        try:
            logger.info("Cleaning up resources...")
            logger.info("Cleanup completed successfully")
        except Exception as e:
            logger.error("Error during cleanup: %s", e)