# Hopsakee's picture
# Upload folder using huggingface_hub
# 5fe3652 verified
"""Database management for fabric-to-espanso."""
from typing import Optional, List, Dict
import logging
import time
from qdrant_client import QdrantClient
from qdrant_client.http import models, exceptions
from qdrant_client.http.models import Distance, VectorParams, PointStruct
from .config import config
from .exceptions import DatabaseConnectionError, CollectionError, DatabaseInitializationError, ConfigurationError
logger = logging.getLogger('fabric_to_espanso')
def get_dense_vector_name(client: QdrantClient, collection_name: str) -> str:
    """Return the dense vector name configured on a collection.

    Args:
        client: Initialized Qdrant client
        collection_name: Name of the collection to inspect

    Returns:
        The dense vector name used by the collection, or a default
        model name when it cannot be determined.
    """
    try:
        collection_params = client.get_collection(collection_name).config.params
        return list(collection_params.vectors)[0]
    except (IndexError, AttributeError) as exc:
        logger.warning(f"Could not get dense vector name: {exc}")
        # Fallback to a default name
        return "fast-multilingual-e5-large"
def get_sparse_vector_name(client: QdrantClient, collection_name: str) -> str:
    """Return the sparse vector name configured on a collection.

    Args:
        client: Initialized Qdrant client
        collection_name: Name of the collection to inspect

    Returns:
        The sparse vector name used by the collection, or a default
        model name when it cannot be determined.
    """
    try:
        collection_params = client.get_collection(collection_name).config.params
        return list(collection_params.sparse_vectors)[0]
    except (IndexError, AttributeError) as exc:
        logger.warning(f"Could not get sparse vector name: {exc}")
        # Fallback to a default name
        return "fast-sparse-splade_pp_en_v1"
def create_database_connection(url: Optional[str] = None, api_key: Optional[str] = None) -> QdrantClient:
    """Create a verified database connection, retrying on failure.

    Args:
        url: Optional database URL. If not provided, uses configuration.
        api_key: Optional API key for authenticated Qdrant deployments.

    Returns:
        QdrantClient: Connected database client

    Raises:
        DatabaseConnectionError: If connection fails after retries
    """
    url = url or config.database.url
    max_retries = config.database.max_retries
    for attempt in range(max_retries + 1):
        try:
            client = QdrantClient(
                url=url,
                timeout=config.database.timeout,
                api_key=api_key,
            )
            # Cheap round-trip that proves the server is actually reachable.
            client.get_collections()
        except Exception as exc:
            if attempt == max_retries:
                raise DatabaseConnectionError(
                    f"Failed to connect to database at {url} after "
                    f"{max_retries} attempts: {str(exc)}"
                ) from exc
            logger.warning(
                f"Connection attempt {attempt + 1} failed, retrying in "
                f"{config.database.retry_delay} seconds..."
            )
            time.sleep(config.database.retry_delay)
        else:
            return client
def initialize_qdrant_database(
    url: str = config.database.url,
    api_key: Optional[str] = "",
    collection_name: str = config.embedding.collection_name,
    use_fastembed: bool = config.embedding.use_fastembed,
    dense_model: str = config.embedding.dense_model_name,
    sparse_model: str = config.embedding.sparse_model_name
) -> QdrantClient:
    """Initialize the Qdrant database for storing markdown file information.

    Args:
        url: Database URL to connect to
        api_key: Optional API key for authenticated Qdrant deployments
        collection_name: Name of the collection to initialize
        use_fastembed: Whether to use FastEmbed for embeddings
        dense_model: Name of the dense embedding model to use
        sparse_model: Name of the sparse embedding model to use

    Returns:
        QdrantClient: Initialized database client

    Raises:
        DatabaseInitializationError: If initialization fails
        CollectionError: If collection creation fails
        ConfigurationError: If configuration is invalid
    """
    try:
        # Fail fast on an invalid configuration before touching the network.
        config.validate()
        client = create_database_connection(url=url, api_key=api_key)
        # Register the embedding models so FastEmbed vector params can be derived.
        client.set_model(dense_model)
        client.set_sparse_model(sparse_model)
        # Check if collection exists
        collections = client.get_collections()
        collection_names = [c.name for c in collections.collections]
        if collection_name not in collection_names:
            logger.info(f"Creating new collection: {collection_name}")
            # Create collection with appropriate vector configuration
            if use_fastembed:
                vectors_config = client.get_fastembed_vector_params()
                sparse_vectors_config = client.get_fastembed_sparse_vector_params()
            else:
                # Raise with the message instead of printing to stdout so the
                # failure is visible to callers and ends up in the logs below.
                raise NotImplementedError(
                    "Creating database without Fastembed not implemented yet."
                )
            try:
                client.create_collection(
                    collection_name=collection_name,
                    vectors_config=vectors_config,
                    sparse_vectors_config=sparse_vectors_config,
                    on_disk_payload=True
                )
            except exceptions.UnexpectedResponse as e:
                raise CollectionError(
                    f"Failed to create collection {collection_name}: {str(e)}"
                ) from e
            # Payload indexes make filtering by filename/date efficient.
            for field_name, field_type in [
                ("filename", models.PayloadSchemaType.KEYWORD),
                ("date", models.PayloadSchemaType.DATETIME)
            ]:
                client.create_payload_index(
                    collection_name=collection_name,
                    field_name=field_name,
                    field_schema=field_type
                )
            logger.info(f"Created indexes for collection {collection_name}")
        # Log collection status
        collection_info = client.get_collection(collection_name)
        logger.info(
            f"Collection {collection_name} ready with "
            f"{collection_info.points_count} points"
        )
        return client
    except Exception as e:
        logger.error(f"Database initialization failed: {str(e)}", exc_info=True)
        # Re-raise domain errors untouched; wrap everything else.
        if isinstance(e, (DatabaseConnectionError, CollectionError)):
            raise
        raise DatabaseInitializationError(str(e)) from e
def validate_database_payload(
    client: QdrantClient,
    collection_name: str,
) -> None:
    """Validate and repair the payload of all points in the Qdrant database.

    Scrolls through the collection in small batches, validates each point's
    payload via ``validate_point_payload``, and writes back any payload that
    had to be fixed.

    Args:
        client: Initialized Qdrant client
        collection_name: Name of the collection to validate
    """
    # First validate existing points in database
    logger.info("Validating existing database points...")
    offset = None
    while True:
        # with_vectors=True is required: scroll defaults to omitting vectors,
        # and re-upserting a point without them would erase its embeddings.
        points, offset = client.scroll(
            collection_name=collection_name,
            limit=5,  # Process in batches of 5
            offset=offset,
            with_payload=True,
            with_vectors=True,
        )
        for point in points:
            try:
                fixed_payload = validate_point_payload(point.payload, point.id)
                if fixed_payload != point.payload:
                    # Update point with fixed payload
                    point_struct = PointStruct(
                        id=point.id,
                        vector=point.vector,
                        payload=fixed_payload
                    )
                    client.upsert(collection_name=collection_name, points=[point_struct])
                    logger.info(f"Fixed and updated point {point.id} in database")
            except ConfigurationError as e:
                # Unrecoverable points are logged and skipped, not fatal.
                logger.error(str(e))
        if not offset:  # No more points to process
            break
    logger.info("Database validation completed")
def validate_point_payload(payload: dict, point_id: Optional[str] = None) -> dict:
    """Validate and fix point payload fields.

    Only use if somehow many points have become corrupted.

    Args:
        payload (dict): Point payload to validate
        point_id (str, optional): ID of the point for logging purposes

    Returns:
        dict: Validated and potentially fixed payload

    Raises:
        ConfigurationError: If required fields are missing and cannot be fixed
    """
    # Defaults for non-critical fields that can safely be backfilled.
    # NOTE(review): the original referenced self.required_fields_defaults from
    # a module-level function (a guaranteed NameError); these values are
    # assumed reasonable — confirm against the payload schema used elsewhere.
    required_fields_defaults = {'filesize': 0, 'trigger': ''}
    label = point_id if point_id else ''
    # 'filename' and 'content' have no sensible default: without them the
    # point is unrecoverable.
    if 'filename' not in payload or 'content' not in payload:
        error_msg = f"Point {label} is missing critical fields: "
        error_msg += "'filename' and/or 'content' are required and cannot be defaulted"
        raise ConfigurationError(error_msg)
    # Copy payload to avoid modifying the original
    fixed_payload = payload.copy()
    # Apply defaults and fixes for non-critical fields
    if 'purpose' not in fixed_payload or not fixed_payload['purpose']:
        fixed_payload['purpose'] = fixed_payload['content']
        logger.warning(f"Point {label}: 'purpose' was missing, set to content value")
    for field in ('filesize', 'trigger'):
        if field not in fixed_payload:
            default = required_fields_defaults[field]
            fixed_payload[field] = default
            logger.warning(f"Point {label}: '{field}' was missing, set to {default}")
    return fixed_payload