| """Vector store management for document embeddings.""" |
|
|
| import os |
| from typing import List, Optional |
| from pathlib import Path |
|
|
| import chromadb |
| from chromadb.config import Settings |
| from llama_index.core import Document, VectorStoreIndex, StorageContext |
| from llama_index.vector_stores.chroma import ChromaVectorStore |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding |
| from llama_index.core.node_parser import SentenceSplitter |
|
|
| from src.config import config |
|
|
|
|
class VectorStoreManager:
    """Manage ChromaDB vector store for document embeddings.

    Wraps a persistent ChromaDB collection and a LlamaIndex
    ``VectorStoreIndex`` backed by a HuggingFace embedding model.
    All settings (collection name, persist directory, model name,
    retrieval depth) come from ``src.config.config``.
    """

    def __init__(self):
        self.collection_name = config.collection_name
        self.persist_dir = str(config.chroma_persist_dir)
        self.embedding_model = config.embedding_model

        # Load the embedding model once up front; cache_folder keeps the
        # downloaded weights local so repeated runs do not re-fetch them.
        print(f"Loading embedding model: {self.embedding_model}")
        self.embed_model = HuggingFaceEmbedding(
            model_name=self.embedding_model,
            cache_folder="./models"
        )

        # Persistent client so vectors survive process restarts.
        self.chroma_client = chromadb.PersistentClient(
            path=self.persist_dir,
            settings=Settings(anonymized_telemetry=False)
        )

        # Populated lazily by initialize_collection() / load_index().
        self.collection = None
        self.vector_store = None
        self.index = None

    def initialize_collection(self, reset: bool = False) -> None:
        """Initialize the ChromaDB collection and its vector-store wrapper.

        Args:
            reset: If True, drop any existing collection of the same name
                before (re)creating it.
        """
        if reset:
            # Best-effort delete: the collection may not exist yet, in
            # which case chromadb raises and we deliberately continue.
            try:
                self.chroma_client.delete_collection(name=self.collection_name)
                print(f"Deleted existing collection: {self.collection_name}")
            except Exception:
                pass

        self.collection = self.chroma_client.get_or_create_collection(
            name=self.collection_name,
            metadata={"hnsw:space": "cosine"}
        )
        print(f"Using collection: {self.collection_name}")

        # FIX: ChromaVectorStore takes no `embedding_function` keyword —
        # embeddings are supplied via `embed_model` when the index is
        # built/loaded (see create_index/load_index). Passing it either
        # raises a validation error or is silently ignored, depending on
        # the llama-index version.
        self.vector_store = ChromaVectorStore(
            chroma_collection=self.collection
        )

    def create_index(self, documents: List[Document], show_progress: bool = True) -> VectorStoreIndex:
        """Create a vector index from documents, embedding them into Chroma.

        Args:
            documents: Documents to chunk, embed, and store.
            show_progress: Show a progress bar while embedding.

        Returns:
            The newly built index (also cached on ``self.index``).
        """
        if not self.vector_store:
            self.initialize_collection()

        print(f"Creating index from {len(documents)} documents...")

        storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )

        self.index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            embed_model=self.embed_model,
            show_progress=show_progress
        )

        print("Index created successfully!")
        return self.index

    def load_index(self) -> Optional[VectorStoreIndex]:
        """Load an existing index from storage.

        Returns:
            The index wrapping the persisted vectors, or None when the
            collection is empty (nothing has been indexed yet).
        """
        if not self.vector_store:
            self.initialize_collection()

        # Hoist the count: one round-trip instead of two.
        num_vectors = self.collection.count()
        if num_vectors == 0:
            print("No existing index found in ChromaDB")
            return None

        print(f"Loading index with {num_vectors} vectors")

        storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )

        self.index = VectorStoreIndex.from_vector_store(
            self.vector_store,
            storage_context=storage_context,
            embed_model=self.embed_model
        )

        return self.index

    def get_or_create_index(
        self,
        documents: Optional[List[Document]] = None,
        force_recreate: bool = False
    ) -> VectorStoreIndex:
        """Get the existing index, or build a fresh one from *documents*.

        Args:
            documents: Required when no index exists or ``force_recreate``
                is True.
            force_recreate: Drop and rebuild even if an index exists.

        Raises:
            ValueError: If a rebuild is needed but no documents were given.
        """
        if not force_recreate:
            index = self.load_index()
            # Explicit None check — do not rely on index truthiness.
            if index is not None:
                return index

        if not documents:
            raise ValueError("No documents provided for creating index")

        self.initialize_collection(reset=True)
        return self.create_index(documents)

    def query(self, query_text: str, top_k: Optional[int] = None) -> List:
        """Retrieve the top-k most similar chunks for a query string.

        Args:
            query_text: Natural-language query.
            top_k: Number of chunks to return; defaults to
                ``config.top_k_retrieval`` when None.

        Returns:
            A list of scored nodes from the retriever.

        Raises:
            ValueError: If no index has been built or loaded yet.
        """
        if not self.index:
            raise ValueError("Index not initialized. Call get_or_create_index first.")

        if top_k is None:
            top_k = config.top_k_retrieval

        retriever = self.index.as_retriever(
            similarity_top_k=top_k
        )

        nodes = retriever.retrieve(query_text)
        return nodes

    def get_stats(self) -> dict:
        """Return a summary dict describing the vector store.

        Initializes the collection on first use so the counts are live.
        """
        if not self.collection:
            self.initialize_collection()

        stats = {
            "collection_name": self.collection_name,
            "persist_dir": self.persist_dir,
            "embedding_model": self.embedding_model,
            "num_vectors": self.collection.count(),
            "metadata": self.collection.metadata
        }

        return stats
|
|
|
|
def main():
    """Smoke-test the vector store: ingest documents, then run a query."""
    from src.document_processor import HPMORProcessor

    # Ingest the corpus and (re)build the index from scratch.
    documents = HPMORProcessor().process()
    manager = VectorStoreManager()
    manager.get_or_create_index(documents, force_recreate=True)

    # Report what ended up in the store.
    print("\nVector Store Statistics:")
    for key, value in manager.get_stats().items():
        print(f" {key}: {value}")

    # Sanity-check retrieval with a fixed question.
    test_query = "What is Harry's opinion on magic?"
    print(f"\nTest query: '{test_query}'")
    results = manager.query(test_query, top_k=3)

    print(f"\nFound {len(results)} relevant chunks:")
    for rank, node in enumerate(results, start=1):
        print(f"\n{rank}. Score: {node.score:.4f}")
        print(f" Chapter: {node.metadata.get('chapter_title', 'Unknown')}")
        print(f" Text preview: {node.text[:200]}...")
|
|
|
|
# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    main()