|
|
"""Vector store management and operations.""" |
|
|
from pathlib import Path |
|
|
from typing import Dict, Any, List, Optional |
|
|
|
|
|
|
|
|
import torch |
|
|
from langchain_qdrant import QdrantVectorStore |
|
|
from langchain.docstore.document import Document |
|
|
from langchain_core.embeddings import Embeddings |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
|
|
|
|
|
|
class MatryoshkaEmbeddings(Embeddings):
    """Embeddings wrapper with optional Matryoshka dimension truncation.

    When ``truncate_dim`` is truthy and ``model_name`` contains "matryoshka"
    (case-insensitive), the model is loaded directly through
    ``SentenceTransformer`` with that ``truncate_dim``. In every other case the
    class delegates to ``HuggingFaceEmbeddings``.
    """

    def __init__(self, model_name: str, truncate_dim: Optional[int] = None, **kwargs):
        """
        Initialize Matryoshka embeddings.

        Args:
            model_name: Name of the model
            truncate_dim: Dimension to truncate to (for Matryoshka models).
                ``None`` (or 0) disables truncation and selects the
                ``HuggingFaceEmbeddings`` backend.
            **kwargs: Forwarded to ``HuggingFaceEmbeddings`` on the
                non-truncating path; ignored on the Matryoshka path.
        """
        self.model_name = model_name
        self.truncate_dim = truncate_dim

        # Decide the backend exactly once. The two backends expose different
        # APIs (``encode`` vs ``embed_documents``/``embed_query``), so the
        # embed methods must dispatch on what was actually built here rather
        # than re-deriving the answer from attributes that may change later.
        self._use_matryoshka = bool(truncate_dim) and "matryoshka" in model_name.lower()

        if self._use_matryoshka:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = SentenceTransformer(model_name, truncate_dim=truncate_dim, device=device)
            print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")
        else:
            self.model = HuggingFaceEmbeddings(model_name=model_name, **kwargs)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents and return one vector per input text."""
        if self._use_matryoshka:
            # NOTE: this path always normalizes; any ``encode_kwargs`` given to
            # the constructor are not applied here.
            return self.model.encode(texts, normalize_embeddings=True).tolist()
        return self.model.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its vector."""
        if self._use_matryoshka:
            # ``encode`` is called with a one-element batch; unwrap the row.
            return self.model.encode([text], normalize_embeddings=True)[0].tolist()
        return self.model.embed_query(text)
|
|
|
|
|
|
|
|
class VectorStoreManager:
    """Manages vector store operations and connections."""

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize vector store manager.

        The embeddings model is created immediately from ``config``; the
        vector store itself stays ``None`` until ``connect_to_existing`` or
        ``create_from_documents`` is called.

        Args:
            config: Configuration dictionary. Reads ``config["retriever"]``
                (``model``, ``normalize``) and ``config["qdrant"]``
                (``url``, ``collection_name``, optional ``prefer_grpc`` and
                ``api_key``).
        """
        self.config = config
        self.embeddings = self._create_embeddings()
        self.vectorstore = None

        # (payload field name, payload schema type) pairs to index in Qdrant.
        self.metadata_fields = [
            ("metadata.year", "keyword"),
            ("metadata.source", "keyword"),
            ("metadata.filename", "keyword"),
        ]

    def _create_embeddings(self) -> Embeddings:
        """
        Create embeddings model from configuration.

        Returns:
            A ``MatryoshkaEmbeddings`` when the model name contains
            "matryoshka" (case-insensitive), otherwise a
            ``HuggingFaceEmbeddings``.

        Raises:
            KeyError: If ``config["retriever"]`` lacks ``model`` or
                ``normalize``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"

        retriever_config = self.config["retriever"]
        model_name = retriever_config["model"]
        normalize = retriever_config["normalize"]

        model_kwargs = {"device": device}
        encode_kwargs = {
            "normalize_embeddings": normalize,
            "batch_size": 100,
        }

        if "matryoshka" in model_name.lower():
            collection_name = self.config.get("qdrant", {}).get("collection_name", "") or ""

            # Default to "no truncation". Without this default the name was
            # unbound whenever the collection did not match the known one,
            # which raised UnboundLocalError below.
            truncate_dim = None
            if "modernbert-embed-base-akryl-matryoshka" in collection_name:
                truncate_dim = 768
                print(f"🔧 Matryoshka model configured for {truncate_dim} dimensions")

            return MatryoshkaEmbeddings(
                model_name=model_name,
                truncate_dim=truncate_dim,
                model_kwargs=model_kwargs,
                encode_kwargs=encode_kwargs,
                show_progress=True,
            )

        return HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
            show_progress=True,
        )

    def _qdrant_connection_kwargs(self) -> Dict[str, Any]:
        """Build the connection keyword arguments shared by all Qdrant calls."""
        qdrant_config = self.config["qdrant"]
        return {
            "url": qdrant_config["url"],
            "collection_name": qdrant_config["collection_name"],
            "prefer_grpc": qdrant_config.get("prefer_grpc", True),
            "api_key": qdrant_config.get("api_key", None),
        }

    def ensure_metadata_indexes(self) -> None:
        """
        Create payload indexes for all required metadata fields.
        This ensures filtering works properly, especially in Qdrant Cloud.

        Does nothing when no vector store is connected. Index creation is
        best-effort: a failure for one field is printed and the remaining
        fields are still attempted.
        """
        if not self.vectorstore:
            return

        collection_name = self.config["qdrant"]["collection_name"]

        for field_name, field_type in self.metadata_fields:
            try:
                # ``field_schema`` is the supported keyword; ``field_type`` is
                # its deprecated alias in qdrant-client.
                self.vectorstore.client.create_payload_index(
                    collection_name=collection_name,
                    field_name=field_name,
                    field_schema=field_type,
                )
                print(f"Created payload index for {field_name} ({field_type})")
            except Exception as e:
                # Deliberately broad: one failed index must not abort the
                # connection or the other indexes.
                print(f"Index creation for {field_name} ({field_type}): {str(e)}")

    def connect_to_existing(self, force_recreate: bool = False) -> QdrantVectorStore:
        """
        Connect to existing Qdrant collection.

        Args:
            force_recreate: If True, ``force_recreate=True`` is forwarded to
                ``QdrantVectorStore.from_existing_collection``.

        Returns:
            QdrantVectorStore instance
        """
        kwargs_qdrant = self._qdrant_connection_kwargs()

        if force_recreate:
            # NOTE(review): this is forwarded as an extra keyword argument;
            # confirm the installed langchain_qdrant version accepts it on
            # from_existing_collection. TODO verify.
            kwargs_qdrant["force_recreate"] = True

        self.vectorstore = QdrantVectorStore.from_existing_collection(
            embedding=self.embeddings,
            **kwargs_qdrant,
        )

        self.ensure_metadata_indexes()

        return self.vectorstore

    def create_from_documents(self, documents: List[Document]) -> QdrantVectorStore:
        """
        Create new Qdrant collection from documents.

        Args:
            documents: List of Document objects

        Returns:
            QdrantVectorStore instance
        """
        self.vectorstore = QdrantVectorStore.from_documents(
            documents=documents,
            embedding=self.embeddings,
            **self._qdrant_connection_kwargs(),
        )

        self.ensure_metadata_indexes()

        return self.vectorstore

    def delete_collection(self) -> QdrantVectorStore:
        """
        Delete the current Qdrant collection.

        Returns:
            The manager's QdrantVectorStore instance. It is returned
            unchanged, so it still refers to the deleted collection's name.

        Raises:
            RuntimeError: If no vector store has been connected or created
                yet, since there is no client to issue the delete with.
        """
        if self.vectorstore is None:
            raise RuntimeError(
                "No vector store is connected; call connect_to_existing() or "
                "create_from_documents() before delete_collection()."
            )

        collection_name = self.config["qdrant"].get("collection_name")
        self.vectorstore.client.delete_collection(collection_name=collection_name)

        return self.vectorstore

    def get_vectorstore(self) -> Optional[QdrantVectorStore]:
        """Get current vectorstore instance (``None`` if not yet connected)."""
        return self.vectorstore
|
|
|
|
|
|
|
|
def get_local_qdrant(config: Dict[str, Any]) -> QdrantVectorStore:
    """
    Connect to the configured Qdrant collection (legacy compatibility helper).

    Builds a throwaway ``VectorStoreManager`` from ``config`` and returns the
    result of its ``connect_to_existing()`` call with default arguments.

    Args:
        config: Configuration dictionary

    Returns:
        QdrantVectorStore instance
    """
    return VectorStoreManager(config).connect_to_existing()
|
|
|
|
|
|
|
|
def create_vectorstore(config: Dict[str, Any], documents: List[Document]) -> QdrantVectorStore:
    """
    Build a new vector store from the given documents.

    Builds a throwaway ``VectorStoreManager`` from ``config`` and returns the
    result of its ``create_from_documents(documents)`` call.

    Args:
        config: Configuration dictionary
        documents: List of Document objects

    Returns:
        QdrantVectorStore instance
    """
    return VectorStoreManager(config).create_from_documents(documents)
|
|
|
|
|
|
|
|
def get_embeddings_model(config: Dict[str, Any]) -> Embeddings:
    """
    Create embeddings model from configuration (legacy function).

    Builds a ``VectorStoreManager`` from ``config`` and returns the embeddings
    object it created during initialization.

    Args:
        config: Configuration dictionary

    Returns:
        The manager's embeddings instance. Depending on the configured model
        this is either a ``HuggingFaceEmbeddings`` or a
        ``MatryoshkaEmbeddings``, so the return type is the common
        ``Embeddings`` base.
    """
    manager = VectorStoreManager(config)
    return manager.embeddings
|
|
|