Spaces:
Sleeping
Sleeping
File size: 9,868 Bytes
2197ab7 5fe3652 2197ab7 5b40ec9 2197ab7 5b40ec9 2197ab7 5b40ec9 2197ab7 5b40ec9 2197ab7 5b40ec9 2197ab7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 |
"""Database management for fabric-to-espanso."""
from typing import Optional, List, Dict
import logging
import time
from qdrant_client import QdrantClient
from qdrant_client.http import models, exceptions
from qdrant_client.http.models import Distance, VectorParams, PointStruct
from .config import config
from .exceptions import DatabaseConnectionError, CollectionError, DatabaseInitializationError, ConfigurationError
# Module logger. It is fetched by the fixed name 'fabric_to_espanso' rather
# than __name__, so any module asking for that name gets this same logger object.
logger = logging.getLogger('fabric_to_espanso')
def get_dense_vector_name(client: QdrantClient, collection_name: str) -> str:
    """
    Look up the dense vector name configured on a collection.

    Takes the first key of the collection's ``config.params.vectors``. The
    hard-coded fallback name is returned, after logging a warning, in two
    cases:

    - the mapping is empty (IndexError);
    - the mapping has no ``keys()`` (AttributeError — presumably an unnamed
      single-vector config; confirm).

    Any other error, e.g. from ``client.get_collection``, propagates.

    Args:
        client: Initialized Qdrant client
        collection_name: Name of the collection

    Returns:
        Name of the dense vector as used in the collection, or the fallback
        ``"fast-multilingual-e5-large"``.
    """
    fallback = "fast-multilingual-e5-large"
    try:
        info = client.get_collection(collection_name)
        names = list(info.config.params.vectors.keys())
        return names[0]
    except (IndexError, AttributeError) as exc:
        logger.warning(f"Could not get dense vector name: {exc}")
    return fallback
def get_sparse_vector_name(client: QdrantClient, collection_name: str) -> str:
    """
    Look up the sparse vector name configured on a collection.

    Takes the first key of the collection's ``config.params.sparse_vectors``.
    The hard-coded fallback name is returned, after logging a warning, in two
    cases:

    - the mapping is empty (IndexError);
    - it has no ``keys()`` (AttributeError, e.g. when it is None).

    Any other error, e.g. from ``client.get_collection``, propagates.

    Args:
        client: Initialized Qdrant client
        collection_name: Name of the collection

    Returns:
        Name of the sparse vector as used in the collection, or the fallback
        ``"fast-sparse-splade_pp_en_v1"``.
    """
    fallback = "fast-sparse-splade_pp_en_v1"
    try:
        collection_params = client.get_collection(collection_name).config.params
        sparse_config = collection_params.sparse_vectors
        return [*sparse_config.keys()][0]
    except (IndexError, AttributeError) as exc:
        logger.warning(f"Could not get sparse vector name: {exc}")
    return fallback
def _close_client_quietly(client) -> None:
    """Best-effort close of a client left over from a failed connection attempt."""
    close = getattr(client, "close", None)
    if close is None:
        return
    try:
        close()
    except Exception as exc:  # closing must never mask the original connection error
        logger.debug(f"Ignoring error while closing failed client: {exc}")


def create_database_connection(url: Optional[str] = None, api_key: Optional[str] = None) -> QdrantClient:
    """Create a database connection and verify that the server is reachable.

    Makes up to ``config.database.max_retries + 1`` attempts. Each attempt
    builds a client and calls ``get_collections()`` as a connectivity probe.
    Between failed attempts it sleeps ``config.database.retry_delay`` seconds.

    Args:
        url: Optional database URL. If not provided, uses ``config.database.url``.
        api_key: Optional API key, passed through to ``QdrantClient``.

    Returns:
        QdrantClient: Connected database client

    Raises:
        DatabaseConnectionError: If connection fails after retries
    """
    url = url or config.database.url
    # Clamp so a negative setting still makes one attempt, instead of skipping
    # the loop and implicitly returning None.
    total_attempts = max(0, config.database.max_retries) + 1
    for attempt in range(1, total_attempts + 1):
        client = None
        try:
            client = QdrantClient(
                url=url,
                timeout=config.database.timeout,
                api_key=api_key
            )
            # Test connection
            client.get_collections()
            return client
        except Exception as e:
            # Release the failed client before retrying or giving up.
            _close_client_quietly(client)
            if attempt == total_attempts:
                raise DatabaseConnectionError(
                    f"Failed to connect to database at {url} after "
                    f"{total_attempts} attempts: {str(e)}"
                ) from e
            logger.warning(
                f"Connection attempt {attempt} failed, retrying in "
                f"{config.database.retry_delay} seconds..."
            )
            time.sleep(config.database.retry_delay)
def _create_collection_with_indexes(
    client: QdrantClient,
    collection_name: str,
    use_fastembed: bool,
) -> None:
    """Create ``collection_name`` with FastEmbed vector configs plus payload indexes.

    Raises:
        NotImplementedError: If ``use_fastembed`` is False.
        CollectionError: If the server rejects the collection creation.
    """
    logger.info(f"Creating new collection: {collection_name}")
    if not use_fastembed:
        raise NotImplementedError(
            "Creating database without Fastembed not implemented yet."
        )
    # Vector configs matching the models registered via set_model/set_sparse_model.
    vectors_config = client.get_fastembed_vector_params()
    sparse_vectors_config = client.get_fastembed_sparse_vector_params()
    try:
        client.create_collection(
            collection_name=collection_name,
            vectors_config=vectors_config,
            sparse_vectors_config=sparse_vectors_config,
            on_disk_payload=True
        )
    except exceptions.UnexpectedResponse as e:
        raise CollectionError(
            f"Failed to create collection {collection_name}: {str(e)}"
        ) from e
    # Create indexes for efficient searching
    for field_name, field_type in (
        ("filename", models.PayloadSchemaType.KEYWORD),
        ("date", models.PayloadSchemaType.DATETIME),
    ):
        client.create_payload_index(
            collection_name=collection_name,
            field_name=field_name,
            field_schema=field_type
        )
    logger.info(f"Created indexes for collection {collection_name}")


def initialize_qdrant_database(
    url: str = config.database.url,
    api_key: Optional[str] = "",
    collection_name: str = config.embedding.collection_name,
    use_fastembed: bool = config.embedding.use_fastembed,
    dense_model: str = config.embedding.dense_model_name,
    sparse_model: str = config.embedding.sparse_model_name
) -> QdrantClient:
    """Initialize the Qdrant database for storing markdown file information.

    The function performs these steps:

    1. Validates the configuration.
    2. Connects to the database.
    3. Registers the dense and sparse embedding models on the client.
    4. Creates the collection, with payload indexes on ``filename`` and
       ``date``, if it does not exist yet.

    Note: the parameter defaults are read from ``config`` once, when this
    module is imported. Later changes to ``config`` do not affect them.

    Args:
        url: Qdrant server URL.
        api_key: Optional API key. An empty string (the default) means "no key".
        collection_name: Name of the collection to initialize.
        use_fastembed: Whether to use FastEmbed for embeddings. Only ``True``
            is implemented when the collection has to be created.
        dense_model: Dense embedding model name, passed to ``client.set_model``.
        sparse_model: Sparse embedding model name, passed to
            ``client.set_sparse_model``.

    Returns:
        QdrantClient: Initialized database client

    Raises:
        DatabaseConnectionError: If connecting to the database fails.
        CollectionError: If collection creation fails.
        DatabaseInitializationError: For any other failure. This includes an
            invalid configuration (the ``ConfigurationError`` from
            ``config.validate()`` is wrapped). It also includes
            ``use_fastembed=False`` when the collection must be created.
    """
    try:
        # Validate configuration
        config.validate()
        # Map the "" default to None so no empty API key is handed to the client.
        client = create_database_connection(url=url, api_key=api_key or None)
        client.set_model(dense_model)
        client.set_sparse_model(sparse_model)

        existing_names = {c.name for c in client.get_collections().collections}
        if collection_name not in existing_names:
            _create_collection_with_indexes(client, collection_name, use_fastembed)

        # Log collection status
        collection_info = client.get_collection(collection_name)
        logger.info(
            f"Collection {collection_name} ready with "
            f"{collection_info.points_count} points"
        )
        return client
    except Exception as e:
        logger.error(f"Database initialization failed: {str(e)}", exc_info=True)
        if isinstance(e, (DatabaseConnectionError, CollectionError)):
            raise
        raise DatabaseInitializationError(str(e)) from e
def validate_database_payload(
    client: QdrantClient,
    collection_name: str,
    batch_size: int = 5,
) -> None:
    """Validate, and where possible repair, the payload of all points in the Qdrant database.

    Scrolls through the collection ``batch_size`` points at a time and runs
    ``validate_point_payload`` on each payload. When a changed payload comes
    back, only the point's payload is overwritten; its vectors are left
    untouched. Points whose payload cannot be fixed (``ConfigurationError``)
    are logged and skipped.

    Args:
        client: Initialized Qdrant client
        collection_name: Name of the collection to validate
        batch_size: Number of points fetched per scroll request.
    """
    logger.info("Validating existing database points...")
    offset = None
    while True:
        points, offset = client.scroll(
            collection_name=collection_name,
            limit=batch_size,
            offset=offset,
        )
        for point in points:
            # A point without payload is treated as empty, so it is reported
            # as missing critical fields instead of crashing with TypeError.
            payload = point.payload or {}
            try:
                fixed_payload = validate_point_payload(payload, point.id)
            except ConfigurationError as e:
                logger.error(str(e))
                continue
            if fixed_payload != payload:
                # scroll() is called without with_vectors, so point.vector is
                # not populated. Re-upserting a PointStruct would drop or reject
                # the stored embeddings, so only the payload is replaced.
                client.overwrite_payload(
                    collection_name=collection_name,
                    payload=fixed_payload,
                    points=[point.id],
                )
                logger.info(f"Fixed and updated point {point.id} in database")
        # The next-page offset is None after the last page. An integer point
        # id of 0 is a valid offset, so test for None rather than falsiness.
        if offset is None:
            break
    logger.info("Database validation completed")
def validate_point_payload(
    payload: dict,
    point_id: Optional[str] = None,
    defaults: Optional[Dict] = None,
) -> dict:
    """Validate and fix point payload fields.

    Only use if somehow many points have become corrupted.

    Field rules:

    - ``filename`` and ``content`` are required and cannot be defaulted.
    - A missing or empty ``purpose`` is set to the ``content`` value.
    - A missing ``filesize`` or ``trigger`` is filled from the defaults.

    The input dict is not modified; a shallow copy is returned.

    Args:
        payload (dict): Point payload to validate
        point_id (str, optional): ID of the point for logging purposes
        defaults (dict, optional): Values for missing ``filesize`` /
            ``trigger`` fields. Keys given here override the built-in
            fallbacks.

    Returns:
        dict: Validated and potentially fixed payload

    Raises:
        ConfigurationError: If required fields are missing and cannot be fixed
    """
    # NOTE(review): the previous body read these values from
    # ``self.required_fields_defaults``. This is a module-level function with
    # no ``self``, so any point missing 'filesize' or 'trigger' raised
    # NameError. The fallbacks below are placeholders; confirm them against
    # the project's canonical defaults, or pass ``defaults``.
    field_defaults = {'filesize': 0, 'trigger': ''}
    if defaults:
        field_defaults.update(defaults)

    # Render a point id of 0 as "0", not as an empty string.
    label = '' if point_id is None else point_id
    logger.debug(f"Validating point {label}")

    # Check for critical fields
    if 'filename' not in payload or 'content' not in payload:
        error_msg = f"Point {label} is missing critical fields: "
        error_msg += "'filename' and/or 'content' are required and cannot be defaulted"
        raise ConfigurationError(error_msg)

    # Copy payload to avoid modifying the original
    fixed_payload = payload.copy()

    # Apply defaults and fixes for non-critical fields
    if not fixed_payload.get('purpose'):
        fixed_payload['purpose'] = fixed_payload['content']
        logger.warning(f"Point {label}: 'purpose' was missing, set to content value")

    for field in ('filesize', 'trigger'):
        if field not in fixed_payload:
            fixed_payload[field] = field_defaults[field]
            logger.warning(f"Point {label}: '{field}' was missing, set to {field_defaults[field]}")

    return fixed_payload