# SPARKNET/src/rag/store.py
"""
Vector Store Interface and ChromaDB Implementation
Provides:
- Abstract VectorStore interface
- ChromaDB implementation with local persistence
- Chunk storage with metadata
"""
from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Any, Tuple
from pathlib import Path
from pydantic import BaseModel, Field
from loguru import logger
import hashlib
import json
try:
import chromadb
from chromadb.config import Settings
CHROMADB_AVAILABLE = True
except ImportError:
CHROMADB_AVAILABLE = False
logger.warning("ChromaDB not available. Install with: pip install chromadb")
class VectorStoreConfig(BaseModel):
"""Configuration for vector store."""
# Storage
persist_directory: str = Field(
default="./data/vectorstore",
description="Directory for persistent storage"
)
collection_name: str = Field(
default="sparknet_documents",
description="Name of the collection"
)
# Search settings
default_top_k: int = Field(default=5, ge=1, description="Default number of results")
similarity_threshold: float = Field(
default=0.7,
ge=0.0,
le=1.0,
description="Minimum similarity score"
)
# ChromaDB settings
anonymized_telemetry: bool = Field(default=False)
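
# A minimal configuration sketch; the values below are illustrative overrides,
# not additional project defaults:
#
#     config = VectorStoreConfig(
#         persist_directory="./data/vectorstore",
#         default_top_k=10,
#         similarity_threshold=0.5,
#     )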
class VectorSearchResult(BaseModel):
"""Result from vector search."""
chunk_id: str
document_id: str
text: str
metadata: Dict[str, Any]
similarity: float
# Source information
page: Optional[int] = None
bbox: Optional[Dict[str, float]] = None
chunk_type: Optional[str] = None
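
# Shape of a single search hit (values are illustrative):
#
#     VectorSearchResult(
#         chunk_id="a1b2c3d4e5f60718",
#         document_id="doc-001",
#         text="...",
#         metadata={"page": 2, "chunk_type": "text"},
#         similarity=0.83,
#         page=2,
#         chunk_type="text",
#     )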
class VectorStore(ABC):
"""Abstract interface for vector stores."""
@abstractmethod
def add_chunks(
self,
chunks: List[Dict[str, Any]],
embeddings: List[List[float]],
) -> List[str]:
"""
Add chunks with embeddings to the store.
Args:
chunks: List of chunk dictionaries with text and metadata
embeddings: Corresponding embeddings
Returns:
List of stored chunk IDs
"""
pass
@abstractmethod
def search(
self,
query_embedding: List[float],
top_k: int = 5,
filters: Optional[Dict[str, Any]] = None,
) -> List[VectorSearchResult]:
"""
Search for similar chunks.
Args:
query_embedding: Query vector
top_k: Number of results
filters: Optional metadata filters
Returns:
List of search results
"""
pass
@abstractmethod
def delete_document(self, document_id: str) -> int:
"""
Delete all chunks for a document.
Args:
document_id: Document ID to delete
Returns:
Number of chunks deleted
"""
pass
@abstractmethod
def get_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]:
"""Get a specific chunk by ID."""
pass
@abstractmethod
def count(self, document_id: Optional[str] = None) -> int:
"""Count chunks in store, optionally filtered by document."""
pass
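
# A new backend subclasses VectorStore and implements the five abstract
# methods. Hypothetical skeleton (not part of the framework):
#
#     class InMemoryVectorStore(VectorStore):
#         def add_chunks(self, chunks, embeddings): ...
#         def search(self, query_embedding, top_k=5, filters=None): ...
#         def delete_document(self, document_id): ...
#         def get_chunk(self, chunk_id): ...
#         def count(self, document_id=None): ...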
class ChromaVectorStore(VectorStore):
"""
ChromaDB implementation of vector store.
Features:
- Local persistent storage
- Metadata filtering
- Similarity search with cosine distance
"""
def __init__(self, config: Optional[VectorStoreConfig] = None):
"""Initialize ChromaDB store."""
if not CHROMADB_AVAILABLE:
raise ImportError("ChromaDB is required. Install with: pip install chromadb")
self.config = config or VectorStoreConfig()
# Ensure persist directory exists
persist_path = Path(self.config.persist_directory)
persist_path.mkdir(parents=True, exist_ok=True)
# Initialize ChromaDB client
self._client = chromadb.PersistentClient(
path=str(persist_path),
settings=Settings(
anonymized_telemetry=self.config.anonymized_telemetry,
)
)
# Get or create collection
self._collection = self._client.get_or_create_collection(
name=self.config.collection_name,
metadata={"hnsw:space": "cosine"}
)
logger.info(
f"ChromaDB initialized: {self.config.collection_name} "
f"({self._collection.count()} chunks)"
)
def add_chunks(
self,
chunks: List[Dict[str, Any]],
embeddings: List[List[float]],
) -> List[str]:
"""Add chunks with embeddings."""
if not chunks:
return []
if len(chunks) != len(embeddings):
raise ValueError(
f"Chunks ({len(chunks)}) and embeddings ({len(embeddings)}) "
"must have same length"
)
ids = []
documents = []
metadatas = []
for chunk in chunks:
# Generate or use existing ID
chunk_id = chunk.get("chunk_id")
if not chunk_id:
# Generate deterministic ID
content = f"{chunk.get('document_id', '')}-{chunk.get('text', '')[:100]}"
chunk_id = hashlib.md5(content.encode()).hexdigest()[:16]
ids.append(chunk_id)
documents.append(chunk.get("text", ""))
# Prepare metadata (ChromaDB only supports primitive types)
metadata = {
"document_id": chunk.get("document_id", ""),
"source_path": chunk.get("source_path", ""),
"chunk_type": chunk.get("chunk_type", "text"),
"page": chunk.get("page", 0),
"sequence_index": chunk.get("sequence_index", 0),
"confidence": chunk.get("confidence", 1.0),
}
# Add bbox as JSON string
if "bbox" in chunk and chunk["bbox"]:
bbox = chunk["bbox"]
if hasattr(bbox, "model_dump"):
metadata["bbox_json"] = json.dumps(bbox.model_dump())
elif isinstance(bbox, dict):
metadata["bbox_json"] = json.dumps(bbox)
metadatas.append(metadata)
# Add to collection
self._collection.add(
ids=ids,
embeddings=embeddings,
documents=documents,
metadatas=metadatas,
)
logger.debug(f"Added {len(ids)} chunks to vector store")
return ids
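
    # Example call; chunk fields beyond "text" are optional, and the embedding
    # dimensionality is fixed by the first add to the collection:
    #
    #     store.add_chunks(
    #         chunks=[{"document_id": "doc-001", "text": "hello world", "page": 1}],
    #         embeddings=[[0.1, 0.2, 0.3]],
    #     )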
def search(
self,
query_embedding: List[float],
top_k: int = 5,
filters: Optional[Dict[str, Any]] = None,
) -> List[VectorSearchResult]:
"""Search for similar chunks."""
# Build where clause for filters
where = None
if filters:
where = self._build_where_clause(filters)
# Query
results = self._collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
where=where,
include=["documents", "metadatas", "distances"],
)
# Convert to result objects
search_results = []
if results["ids"] and results["ids"][0]:
for i, chunk_id in enumerate(results["ids"][0]):
                # ChromaDB returns cosine distance; cosine similarity = 1 - distance
                distance = results["distances"][0][i] if results["distances"] else 0.0
                similarity = 1.0 - distance
# Apply threshold
if similarity < self.config.similarity_threshold:
continue
metadata = results["metadatas"][0][i] if results["metadatas"] else {}
# Parse bbox from JSON
bbox = None
if "bbox_json" in metadata:
                    try:
                        bbox = json.loads(metadata["bbox_json"])
                    except (json.JSONDecodeError, TypeError):
                        bbox = None
result = VectorSearchResult(
chunk_id=chunk_id,
document_id=metadata.get("document_id", ""),
text=results["documents"][0][i] if results["documents"] else "",
metadata=metadata,
similarity=similarity,
page=metadata.get("page"),
bbox=bbox,
chunk_type=metadata.get("chunk_type"),
)
search_results.append(result)
return search_results
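
    # Example query; the query embedding must come from the same model used
    # at indexing time (dimensions here are illustrative):
    #
    #     hits = store.search(
    #         query_embedding=[0.1, 0.2, 0.3],
    #         top_k=3,
    #         filters={"document_id": "doc-001"},
    #     )
    #     for hit in hits:
    #         print(hit.similarity, hit.text[:50])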
    def _build_where_clause(self, filters: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Build a ChromaDB where clause from filters; returns None if no conditions apply."""
conditions = []
for key, value in filters.items():
if key == "document_id":
conditions.append({"document_id": {"$eq": value}})
elif key == "chunk_type":
if isinstance(value, list):
conditions.append({"chunk_type": {"$in": value}})
else:
conditions.append({"chunk_type": {"$eq": value}})
elif key == "page":
if isinstance(value, dict):
# Range filter: {"page": {"min": 1, "max": 5}}
if "min" in value:
conditions.append({"page": {"$gte": value["min"]}})
if "max" in value:
conditions.append({"page": {"$lte": value["max"]}})
else:
conditions.append({"page": {"$eq": value}})
elif key == "confidence_min":
conditions.append({"confidence": {"$gte": value}})
if len(conditions) == 0:
return None
elif len(conditions) == 1:
return conditions[0]
else:
return {"$and": conditions}
def delete_document(self, document_id: str) -> int:
"""Delete all chunks for a document."""
# Get chunks for document
results = self._collection.get(
where={"document_id": {"$eq": document_id}},
include=[],
)
if not results["ids"]:
return 0
count = len(results["ids"])
# Delete
self._collection.delete(ids=results["ids"])
logger.info(f"Deleted {count} chunks for document {document_id}")
return count
def get_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]:
"""Get a specific chunk by ID."""
results = self._collection.get(
ids=[chunk_id],
include=["documents", "metadatas"],
)
if not results["ids"]:
return None
metadata = results["metadatas"][0] if results["metadatas"] else {}
return {
"chunk_id": chunk_id,
"text": results["documents"][0] if results["documents"] else "",
**metadata,
}
def count(self, document_id: Optional[str] = None) -> int:
"""Count chunks in store."""
if document_id:
results = self._collection.get(
where={"document_id": {"$eq": document_id}},
include=[],
)
return len(results["ids"]) if results["ids"] else 0
return self._collection.count()
def list_documents(self) -> List[str]:
"""List all unique document IDs in the store."""
results = self._collection.get(include=["metadatas"])
if not results["metadatas"]:
return []
doc_ids = set()
for meta in results["metadatas"]:
if meta and "document_id" in meta:
doc_ids.add(meta["document_id"])
return list(doc_ids)
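
# Housekeeping examples against a ChromaVectorStore instance:
#
#     store.count()                       # total chunks in the collection
#     store.count(document_id="doc-001")  # chunks for one document
#     store.list_documents()              # unique document IDs
#     store.delete_document("doc-001")    # returns the number of chunks removed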
# Global instance and factory
_vector_store: Optional[VectorStore] = None
def get_vector_store(
config: Optional[VectorStoreConfig] = None,
store_type: str = "chromadb",
) -> VectorStore:
"""
Get or create singleton vector store.
Args:
config: Store configuration
store_type: Type of store ("chromadb")
Returns:
VectorStore instance
"""
global _vector_store
if _vector_store is None:
if store_type == "chromadb":
_vector_store = ChromaVectorStore(config)
else:
raise ValueError(f"Unknown store type: {store_type}")
return _vector_store
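
# Typical wiring; the embedding step lives outside this module and the
# variables below are hypothetical:
#
#     store = get_vector_store(VectorStoreConfig(collection_name="my_docs"))
#     ids = store.add_chunks(chunks, embeddings)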
def reset_vector_store():
"""Reset the global vector store instance."""
global _vector_store
_vector_store = None
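
if __name__ == "__main__":
    # Minimal smoke test, assuming chromadb is installed. It uses a throwaway
    # directory and toy 3-dimensional vectors; real embeddings would come from
    # an embedding model with a fixed dimensionality.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        store = ChromaVectorStore(
            VectorStoreConfig(persist_directory=tmp, similarity_threshold=0.0)
        )
        store.add_chunks(
            chunks=[
                {"document_id": "doc-001", "text": "alpha", "page": 1},
                {"document_id": "doc-001", "text": "beta", "page": 2},
            ],
            embeddings=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        )
        for hit in store.search(query_embedding=[1.0, 0.0, 0.0], top_k=2):
            print(f"{hit.similarity:.3f} {hit.text}")
        print(f"{store.count()} chunks from {store.list_documents()}")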