audit_assistant / src /vectorstore.py
akryldigital's picture
Pilot (#2)
92633a7 verified
raw
history blame
9.23 kB
"""Vector store management and operations."""
from pathlib import Path
from typing import Dict, Any, List, Optional
import torch
from langchain_qdrant import QdrantVectorStore
from langchain.docstore.document import Document
from langchain_core.embeddings import Embeddings
from sentence_transformers import SentenceTransformer
from langchain_huggingface import HuggingFaceEmbeddings
class MatryoshkaEmbeddings(Embeddings):
    """Embeddings wrapper that supports Matryoshka dimension truncation.

    When ``truncate_dim`` is given and the model name contains
    "matryoshka", the underlying ``SentenceTransformer`` is loaded directly
    with ``truncate_dim`` so output vectors are cut to that dimensionality.
    Otherwise it falls back to a standard ``HuggingFaceEmbeddings`` instance.
    """

    def __init__(self, model_name: str, truncate_dim: Optional[int] = None, **kwargs):
        """
        Initialize Matryoshka embeddings.

        Args:
            model_name: Name of the model.
            truncate_dim: Dimension to truncate to (for Matryoshka models).
            **kwargs: Extra arguments forwarded to ``HuggingFaceEmbeddings``
                in the fallback path (ignored for Matryoshka models).
        """
        self.model_name = model_name
        self.truncate_dim = truncate_dim
        # Decide the backend once instead of re-testing the model name on
        # every embed_* call (the original repeated this check three times).
        self._is_matryoshka = bool(truncate_dim) and "matryoshka" in model_name.lower()
        if self._is_matryoshka:
            # Use SentenceTransformer directly so truncate_dim is honored.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = SentenceTransformer(model_name, truncate_dim=truncate_dim, device=device)
            print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")
        else:
            # Standard LangChain embeddings for everything else.
            self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents."""
        if self._is_matryoshka:
            # SentenceTransformer returns a numpy array; convert for LangChain.
            embeddings = self.model.encode(texts, normalize_embeddings=True)
            return embeddings.tolist()
        return self.model.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        if self._is_matryoshka:
            embedding = self.model.encode([text], normalize_embeddings=True)
            return embedding[0].tolist()
        return self.model.embed_query(text)
class VectorStoreManager:
    """Manages vector store operations and connections."""

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize vector store manager.

        Args:
            config: Configuration dictionary; expects "retriever" and
                "qdrant" sections (see the accessors below for the keys read).
        """
        self.config = config
        self.embeddings = self._create_embeddings()
        self.vectorstore: Optional[QdrantVectorStore] = None
        # Metadata fields that need payload indexes so filtering works,
        # especially in Qdrant Cloud. Extend as new filters are added.
        self.metadata_fields = [
            ("metadata.year", "keyword"),
            ("metadata.source", "keyword"),
            ("metadata.filename", "keyword"),
        ]

    def _create_embeddings(self) -> Embeddings:
        """Create the embeddings model from configuration.

        Returns:
            A ``MatryoshkaEmbeddings`` instance when the configured model is
            a Matryoshka model targeting a collection with a pinned
            dimension, otherwise a plain ``HuggingFaceEmbeddings``.
            (The original annotation claimed ``HuggingFaceEmbeddings`` only,
            which was inaccurate for the Matryoshka branch.)
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_name = self.config["retriever"]["model"]
        normalize = self.config["retriever"]["normalize"]
        model_kwargs = {"device": device}
        encode_kwargs = {
            "normalize_embeddings": normalize,
            "batch_size": 100,
        }
        # Matryoshka models may need their dimensionality truncated to match
        # the vectors already stored in a given collection.
        if "matryoshka" in model_name.lower():
            collection_name = self.config.get("qdrant", {}).get("collection_name", "")
            if "modernbert-embed-base-akryl-matryoshka" in collection_name:
                # This collection was built with 768-dimensional vectors.
                truncate_dim = 768
                print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")
                return MatryoshkaEmbeddings(
                    model_name=model_name,
                    truncate_dim=truncate_dim,
                    model_kwargs=model_kwargs,
                    encode_kwargs=encode_kwargs,
                    show_progress=True,
                )
        # Standard HuggingFaceEmbeddings for every other model (including a
        # Matryoshka model pointed at a collection with no pinned dimension).
        return HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
            show_progress=True,
        )

    def _qdrant_kwargs(self) -> Dict[str, Any]:
        """Build the Qdrant connection kwargs shared by connect/create paths.

        Previously duplicated verbatim in ``connect_to_existing`` and
        ``create_from_documents``.
        """
        qdrant_config = self.config["qdrant"]
        return {
            "url": qdrant_config["url"],
            "collection_name": qdrant_config["collection_name"],
            "prefer_grpc": qdrant_config.get("prefer_grpc", True),
            "api_key": qdrant_config.get("api_key", None),
        }

    def ensure_metadata_indexes(self) -> None:
        """
        Create payload indexes for all required metadata fields.

        This ensures filtering works properly, especially in Qdrant Cloud.
        Errors (e.g. "index already exists") are logged and skipped so a
        partial failure never aborts the connection flow.
        """
        if not self.vectorstore:
            return
        qdrant_config = self.config["qdrant"]
        collection_name = qdrant_config["collection_name"]
        for field_name, field_type in self.metadata_fields:
            try:
                # NOTE(review): recent qdrant-client versions name this
                # parameter ``field_schema``; ``field_type`` may raise here
                # and be silently swallowed below — verify against the
                # pinned qdrant-client version.
                self.vectorstore.client.create_payload_index(
                    collection_name=collection_name,
                    field_name=field_name,
                    field_type=field_type
                )
                print(f"Created payload index for {field_name} ({field_type})")
            except Exception as e:
                # Index might already exist or other error - log but continue
                print(f"Index creation for {field_name} ({field_type}): {str(e)}")

    def connect_to_existing(self, force_recreate: bool = False) -> QdrantVectorStore:
        """
        Connect to existing Qdrant collection.

        Args:
            force_recreate: If True, recreate the collection if dimension
                mismatch occurs. NOTE(review): confirm the installed
                ``QdrantVectorStore.from_existing_collection`` accepts a
                ``force_recreate`` kwarg; if not, this raises a TypeError.

        Returns:
            QdrantVectorStore instance
        """
        kwargs_qdrant = self._qdrant_kwargs()
        if force_recreate:
            kwargs_qdrant["force_recreate"] = True
        self.vectorstore = QdrantVectorStore.from_existing_collection(
            embedding=self.embeddings,
            **kwargs_qdrant
        )
        # Ensure payload indexes exist for metadata filtering
        self.ensure_metadata_indexes()
        return self.vectorstore

    def create_from_documents(self, documents: List[Document]) -> QdrantVectorStore:
        """
        Create new Qdrant collection from documents.

        Args:
            documents: List of Document objects

        Returns:
            QdrantVectorStore instance
        """
        self.vectorstore = QdrantVectorStore.from_documents(
            documents=documents,
            embedding=self.embeddings,
            **self._qdrant_kwargs()
        )
        # Ensure payload indexes exist for metadata filtering
        self.ensure_metadata_indexes()
        return self.vectorstore

    def delete_collection(self) -> Optional[QdrantVectorStore]:
        """
        Delete the current Qdrant collection.

        Must be called after a successful connect/create; raises
        AttributeError if no vectorstore has been established yet.

        Returns:
            The (now stale) QdrantVectorStore instance, kept for backward
            compatibility with existing callers. (The original annotated
            ``-> None`` while actually returning the store.)
        """
        qdrant_config = self.config["qdrant"]
        collection_name = qdrant_config.get("collection_name")
        self.vectorstore.client.delete_collection(
            collection_name=collection_name
        )
        return self.vectorstore

    def get_vectorstore(self) -> Optional[QdrantVectorStore]:
        """Get current vectorstore instance (None before connect/create)."""
        return self.vectorstore
def get_local_qdrant(config: Dict[str, Any]) -> QdrantVectorStore:
    """
    Get local Qdrant vector store (legacy function for compatibility).

    Thin wrapper kept for older call sites; new code should use
    ``VectorStoreManager`` directly.

    Args:
        config: Configuration dictionary

    Returns:
        QdrantVectorStore instance
    """
    return VectorStoreManager(config).connect_to_existing()
def create_vectorstore(config: Dict[str, Any], documents: List[Document]) -> QdrantVectorStore:
    """
    Create new vector store from documents.

    Convenience wrapper around ``VectorStoreManager.create_from_documents``.

    Args:
        config: Configuration dictionary
        documents: List of Document objects

    Returns:
        QdrantVectorStore instance
    """
    return VectorStoreManager(config).create_from_documents(documents)
def get_embeddings_model(config: Dict[str, Any]) -> Embeddings:
    """
    Create embeddings model from configuration (legacy function).

    Args:
        config: Configuration dictionary

    Returns:
        The manager's embeddings model: ``MatryoshkaEmbeddings`` for
        truncated Matryoshka setups, otherwise ``HuggingFaceEmbeddings``.
        (The original ``-> HuggingFaceEmbeddings`` annotation was wrong for
        the Matryoshka branch; ``Embeddings`` covers both.)
    """
    manager = VectorStoreManager(config)
    return manager.embeddings