Spaces:
Sleeping
Sleeping
"""Database management for fabric-to-espanso.""" | |
from typing import Optional, List, Dict | |
import logging | |
import time | |
from qdrant_client import QdrantClient | |
from qdrant_client.http import models, exceptions | |
from qdrant_client.http.models import Distance, VectorParams, PointStruct | |
from .config import config | |
from .exceptions import DatabaseConnectionError, CollectionError, DatabaseInitializationError, ConfigurationError | |
logger = logging.getLogger('fabric_to_espanso') | |
def get_dense_vector_name(client: QdrantClient, collection_name: str) -> str:
    """
    Return the name of the dense vector configured on a collection.

    Args:
        client: Initialized Qdrant client
        collection_name: Name of the collection

    Returns:
        Name of the dense vector as used in the collection
    """
    try:
        collection_info = client.get_collection(collection_name)
        vector_names = list(collection_info.config.params.vectors.keys())
        return vector_names[0]
    except (IndexError, AttributeError) as e:
        logger.warning(f"Could not get dense vector name: {e}")
        # Fallback to a default name
        return "fast-multilingual-e5-large"
def get_sparse_vector_name(client: QdrantClient, collection_name: str) -> str:
    """
    Return the name of the sparse vector configured on a collection.

    Args:
        client: Initialized Qdrant client
        collection_name: Name of the collection

    Returns:
        Name of the sparse vector as used in the collection
    """
    try:
        collection_info = client.get_collection(collection_name)
        sparse_names = list(collection_info.config.params.sparse_vectors.keys())
        return sparse_names[0]
    except (IndexError, AttributeError) as e:
        logger.warning(f"Could not get sparse vector name: {e}")
        # Fallback to a default name
        return "fast-sparse-splade_pp_en_v1"
def create_database_connection(url: Optional[str] = None, api_key: Optional[str] = None) -> QdrantClient:
    """Create a database connection, retrying on failure.

    Args:
        url: Optional database URL. If not provided, uses configuration.
        api_key: Optional API key passed through to the Qdrant client.

    Returns:
        QdrantClient: Connected database client

    Raises:
        DatabaseConnectionError: If connection fails after retries
    """
    url = url or config.database.url
    for attempt in range(config.database.max_retries + 1):
        try:
            client = QdrantClient(
                url=url,
                timeout=config.database.timeout,
                api_key=api_key
            )
            # Probe the server so a bad URL or key fails here,
            # not on the first real query.
            client.get_collections()
            return client
        except Exception as e:
            # Last attempt: surface the failure to the caller.
            if attempt == config.database.max_retries:
                raise DatabaseConnectionError(
                    f"Failed to connect to database at {url} after "
                    f"{config.database.max_retries} attempts: {str(e)}"
                ) from e
            logger.warning(
                f"Connection attempt {attempt + 1} failed, retrying in "
                f"{config.database.retry_delay} seconds..."
            )
            time.sleep(config.database.retry_delay)
def initialize_qdrant_database(
    url: Optional[str] = None,
    api_key: Optional[str] = "",
    collection_name: Optional[str] = None,
    use_fastembed: Optional[bool] = None,
    dense_model: Optional[str] = None,
    sparse_model: Optional[str] = None
) -> QdrantClient:
    """Initialize the Qdrant database for storing markdown file information.

    Defaults are resolved from ``config`` at call time (rather than being
    baked into the signature at import time), so configuration changes made
    after module import are respected.

    Args:
        url: Database URL. Defaults to config.database.url.
        api_key: API key for the Qdrant server.
        collection_name: Name of the collection to initialize.
        use_fastembed: Whether to use FastEmbed for embeddings.
        dense_model: Name of the dense embedding model to use.
        sparse_model: Name of the sparse embedding model to use.

    Returns:
        QdrantClient: Initialized database client

    Raises:
        DatabaseInitializationError: If initialization fails
        CollectionError: If collection creation fails
        ConfigurationError: If configuration is invalid
    """
    # Resolve defaults at call time; `is None` keeps explicitly passed
    # falsy values (e.g. use_fastembed=False) intact.
    if url is None:
        url = config.database.url
    if collection_name is None:
        collection_name = config.embedding.collection_name
    if use_fastembed is None:
        use_fastembed = config.embedding.use_fastembed
    if dense_model is None:
        dense_model = config.embedding.dense_model_name
    if sparse_model is None:
        sparse_model = config.embedding.sparse_model_name
    try:
        # Validate configuration before touching the network.
        config.validate()
        # Create database connection (retries handled inside).
        client = create_database_connection(url=url, api_key=api_key)
        client.set_model(dense_model)
        client.set_sparse_model(sparse_model)
        # Check if collection exists
        collections = client.get_collections()
        collection_names = [c.name for c in collections.collections]
        if collection_name not in collection_names:
            logger.info(f"Creating new collection: {collection_name}")
            # Create collection with appropriate vector configuration
            if use_fastembed:
                vectors_config = client.get_fastembed_vector_params()
                sparse_vectors_config = client.get_fastembed_sparse_vector_params()
            else:
                # Wrapped into DatabaseInitializationError by the outer handler.
                raise NotImplementedError(
                    "Creating database without Fastembed not implemented yet."
                )
            try:
                client.create_collection(
                    collection_name=collection_name,
                    vectors_config=vectors_config,
                    sparse_vectors_config=sparse_vectors_config,
                    on_disk_payload=True
                )
            except exceptions.UnexpectedResponse as e:
                raise CollectionError(
                    f"Failed to create collection {collection_name}: {str(e)}"
                ) from e
            # Create indexes for efficient searching
            for field_name, field_type in [
                ("filename", models.PayloadSchemaType.KEYWORD),
                ("date", models.PayloadSchemaType.DATETIME)
            ]:
                client.create_payload_index(
                    collection_name=collection_name,
                    field_name=field_name,
                    field_schema=field_type
                )
            logger.info(f"Created indexes for collection {collection_name}")
        # Log collection status
        collection_info = client.get_collection(collection_name)
        logger.info(
            f"Collection {collection_name} ready with "
            f"{collection_info.points_count} points"
        )
        return client
    except Exception as e:
        logger.error(f"Database initialization failed: {str(e)}", exc_info=True)
        # Re-raise our own specific errors untouched; wrap everything else.
        if isinstance(e, (DatabaseConnectionError, CollectionError)):
            raise
        raise DatabaseInitializationError(str(e)) from e
def validate_database_payload(
    client: QdrantClient,
    collection_name: str,
) -> None:
    """Validate (and repair in place) the payloads of all points in a collection.

    Scrolls through every point in batches, runs ``validate_point_payload``
    on its payload, and upserts the repaired version back when anything
    changed. Points missing critical fields are logged and skipped.

    Args:
        client: Initialized Qdrant client
        collection_name: Name of the collection to validate
    """
    # First validate existing points in database
    logger.info("Validating existing database points...")
    offset = None
    while True:
        points, offset = client.scroll(
            collection_name=collection_name,
            limit=100,  # batch size; larger batches mean fewer round trips
            offset=offset,
            # Vectors must be fetched explicitly: the upsert below re-writes
            # the point, so writing without its vector would corrupt it.
            with_vectors=True,
        )
        for point in points:
            try:
                fixed_payload = validate_point_payload(point.payload, point.id)
                if fixed_payload != point.payload:
                    # Update point with fixed payload, preserving the vector.
                    point_struct = PointStruct(
                        id=point.id,
                        vector=point.vector,
                        payload=fixed_payload
                    )
                    client.upsert(collection_name=collection_name, points=[point_struct])
                    logger.info(f"Fixed and updated point {point.id} in database")
            except ConfigurationError as e:
                # Critical fields missing and not repairable: log and move on.
                logger.error(str(e))
        if not offset:  # No more points to process
            break
    logger.info("Database validation completed")
def validate_point_payload(payload: dict, point_id: Optional[str] = None) -> dict:
    """Validate and fix point payload fields.

    Only use if somehow many points have become corrupted.

    Args:
        payload (dict): Point payload to validate
        point_id (str, optional): ID of the point for logging purposes

    Returns:
        dict: Validated and potentially fixed payload (the input dict is
            never modified)

    Raises:
        ConfigurationError: If required fields are missing and cannot be fixed
    """
    # Fallback values for repairable fields.
    # NOTE(review): the original code referenced an undefined
    # ``self.required_fields_defaults`` here (a NameError at runtime in a
    # module-level function); these conservative defaults replace it —
    # confirm against the project's intended defaults.
    required_fields_defaults = {
        'filesize': 0,
        'trigger': '',
    }
    # Check for critical fields — these cannot be defaulted.
    if 'filename' not in payload or 'content' not in payload:
        error_msg = f"Point {point_id if point_id else ''} is missing critical fields: "
        error_msg += "'filename' and/or 'content' are required and cannot be defaulted"
        raise ConfigurationError(error_msg)
    # Copy payload to avoid modifying the original
    fixed_payload = payload.copy()
    # Apply defaults and fixes for non-critical fields
    if 'purpose' not in fixed_payload or not fixed_payload['purpose']:
        fixed_payload['purpose'] = fixed_payload['content']
        logger.warning(f"Point {point_id if point_id else ''}: 'purpose' was missing, set to content value")
    if 'filesize' not in fixed_payload:
        fixed_payload['filesize'] = required_fields_defaults['filesize']
        logger.warning(f"Point {point_id if point_id else ''}: 'filesize' was missing, set to {required_fields_defaults['filesize']}")
    if 'trigger' not in fixed_payload:
        fixed_payload['trigger'] = required_fields_defaults['trigger']
        logger.warning(f"Point {point_id if point_id else ''}: 'trigger' was missing, set to {required_fields_defaults['trigger']}")
    return fixed_payload