File size: 9,868 Bytes
2197ab7
 
 
 
 
 
 
 
 
 
 
 
 
 
5fe3652
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2197ab7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b40ec9
 
2197ab7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b40ec9
 
 
2197ab7
 
 
 
 
 
 
 
 
 
5b40ec9
 
2197ab7
5b40ec9
 
2197ab7
 
 
 
5b40ec9
 
2197ab7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""Database management for fabric-to-espanso."""
from typing import Optional, List, Dict
import logging
import time

from qdrant_client import QdrantClient
from qdrant_client.http import models, exceptions
from qdrant_client.http.models import Distance, VectorParams, PointStruct

from .config import config
from .exceptions import DatabaseConnectionError, CollectionError, DatabaseInitializationError, ConfigurationError

logger = logging.getLogger('fabric_to_espanso')

def get_dense_vector_name(client: QdrantClient, collection_name: str) -> str:
    """
    Look up the dense vector name configured for a collection.

    Args:
        client: Initialized Qdrant client
        collection_name: Name of the collection

    Returns:
        The first key of the collection's ``vectors`` config, or
        ``"fast-multilingual-e5-large"`` when it cannot be determined
        (the vectors config is empty or has no ``keys()``).
    """
    fallback = "fast-multilingual-e5-large"
    try:
        collection_params = client.get_collection(collection_name).config.params
        dense_names = [*collection_params.vectors.keys()]
        return dense_names[0]
    except (IndexError, AttributeError) as err:
        logger.warning(f"Could not get dense vector name: {err}")
        return fallback

def get_sparse_vector_name(client: QdrantClient, collection_name: str) -> str:
    """
    Look up the sparse vector name configured for a collection.

    Args:
        client: Initialized Qdrant client
        collection_name: Name of the collection

    Returns:
        The first key of the collection's ``sparse_vectors`` config, or
        ``"fast-sparse-splade_pp_en_v1"`` when it cannot be determined
        (the sparse config is empty, missing, or has no ``keys()``).
    """
    fallback = "fast-sparse-splade_pp_en_v1"
    try:
        collection_params = client.get_collection(collection_name).config.params
        sparse_names = [*collection_params.sparse_vectors.keys()]
        return sparse_names[0]
    except (IndexError, AttributeError) as err:
        logger.warning(f"Could not get sparse vector name: {err}")
        return fallback

def create_database_connection(url: Optional[str] = None, api_key: Optional[str] = None) -> QdrantClient:
    """Create a Qdrant client and verify that the server answers.

    Makes up to ``config.database.max_retries + 1`` attempts, sleeping
    ``config.database.retry_delay`` seconds between failed attempts.

    Args:
        url: Optional database URL. If not provided, ``config.database.url`` is used.
        api_key: Optional API key passed straight to ``QdrantClient``.

    Returns:
        QdrantClient: Connected database client

    Raises:
        DatabaseConnectionError: If every connection attempt fails
    """
    url = url or config.database.url
    # Always make at least one attempt; a negative max_retries previously
    # produced an empty loop and an implicit ``None`` return.
    total_attempts = max(config.database.max_retries, 0) + 1

    for attempt in range(1, total_attempts + 1):
        try:
            client = QdrantClient(
                url=url,
                timeout=config.database.timeout,
                api_key=api_key
            )
            # Issue a cheap request so an unreachable server fails here
            client.get_collections()
            return client
        except Exception as e:
            if attempt == total_attempts:
                # Report the real number of attempts made (retries + the first try)
                raise DatabaseConnectionError(
                    f"Failed to connect to database at {url} after "
                    f"{total_attempts} attempts: {str(e)}"
                ) from e
            logger.warning(
                f"Connection attempt {attempt} failed ({e}), retrying in "
                f"{config.database.retry_delay} seconds..."
            )
            time.sleep(config.database.retry_delay)

def _create_collection_with_indexes(
    client: QdrantClient,
    collection_name: str,
    use_fastembed: bool,
) -> None:
    """Create ``collection_name`` with FastEmbed vectors and payload indexes.

    Raises:
        NotImplementedError: If ``use_fastembed`` is False (not supported yet).
        CollectionError: If Qdrant rejects the collection creation request.
    """
    logger.info(f"Creating new collection: {collection_name}")

    if not use_fastembed:
        message = "Creating database without Fastembed not implemented yet."
        logger.error(message)
        # Carry the message in the exception so the wrapping
        # DatabaseInitializationError is not empty.
        raise NotImplementedError(message)

    # Vector parameters for the FastEmbed models set on the client
    # (set_model / set_sparse_model in the caller).
    vectors_config = client.get_fastembed_vector_params()
    sparse_vectors_config = client.get_fastembed_sparse_vector_params()

    try:
        client.create_collection(
            collection_name=collection_name,
            vectors_config=vectors_config,
            sparse_vectors_config=sparse_vectors_config,
            on_disk_payload=True
        )
    except exceptions.UnexpectedResponse as e:
        raise CollectionError(
            f"Failed to create collection {collection_name}: {str(e)}"
        ) from e

    # Create indexes for efficient searching
    for field_name, field_type in [
        ("filename", models.PayloadSchemaType.KEYWORD),
        ("date", models.PayloadSchemaType.DATETIME)
    ]:
        client.create_payload_index(
            collection_name=collection_name,
            field_name=field_name,
            field_schema=field_type
        )
    logger.info(f"Created indexes for collection {collection_name}")


def initialize_qdrant_database(
    url: str = config.database.url,
    api_key: Optional[str] = "",
    collection_name: str = config.embedding.collection_name,
    use_fastembed: bool = config.embedding.use_fastembed,
    dense_model: str = config.embedding.dense_model_name,
    sparse_model: str = config.embedding.sparse_model_name
) -> QdrantClient:
    """Connect to Qdrant and make sure the target collection exists.

    If ``collection_name`` does not exist yet it is created with FastEmbed
    dense and sparse vector parameters, on-disk payload storage, and payload
    indexes on ``filename`` (keyword) and ``date`` (datetime).

    Note: the default argument values are read from ``config`` once, when
    this module is imported, not on every call.

    Args:
        url: Qdrant server URL.
        api_key: Optional API key. An empty string is treated as "no key".
        collection_name: Name of the collection to initialize.
        use_fastembed: Whether to use FastEmbed for embeddings. Only ``True``
            is supported when the collection has to be created.
        dense_model: Dense embedding model name passed to ``client.set_model``.
        sparse_model: Sparse embedding model name passed to ``client.set_sparse_model``.

    Returns:
        QdrantClient: Initialized database client with both models set.

    Raises:
        DatabaseConnectionError: If the database cannot be reached.
        CollectionError: If collection creation fails.
        DatabaseInitializationError: For any other failure, including a failing
            ``config.validate()`` and ``use_fastembed=False`` for a new collection.
    """
    try:
        # Validate configuration
        config.validate()

        # Treat an empty string as "no API key": QdrantClient only ignores
        # the key when it is None.
        client = create_database_connection(url=url, api_key=api_key or None)

        client.set_model(dense_model)
        client.set_sparse_model(sparse_model)

        existing_names = {c.name for c in client.get_collections().collections}
        if collection_name not in existing_names:
            _create_collection_with_indexes(client, collection_name, use_fastembed)

        # Log collection status
        collection_info = client.get_collection(collection_name)
        logger.info(
            f"Collection {collection_name} ready with "
            f"{collection_info.points_count} points"
        )

        return client

    except Exception as e:
        logger.error(f"Database initialization failed: {str(e)}", exc_info=True)
        # Connection and collection errors propagate unchanged; everything
        # else is wrapped so callers have a single initialization error type.
        if isinstance(e, (DatabaseConnectionError, CollectionError)):
            raise
        raise DatabaseInitializationError(str(e)) from e

def validate_database_payload(
    client: QdrantClient,
    collection_name: str,
    batch_size: int = 5,
    ) -> Dict:
    """Validate, and repair where possible, the payload of every point in a collection.

    Points are scrolled in batches of ``batch_size``. Each payload is passed to
    ``validate_point_payload``; when that returns a different payload, only the
    payload of the point is overwritten, and its vectors are left untouched.
    Points whose payload cannot be repaired are logged and skipped.

    Args:
        client: Initialized Qdrant client
        collection_name: Name of the collection to validate
        batch_size: Number of points fetched per scroll request (must be >= 1).

    Returns:
        Dict with the counts ``checked``, ``fixed`` and ``invalid``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    logger.info("Validating existing database points...")
    stats = {"checked": 0, "fixed": 0, "invalid": 0}
    offset = None

    while True:
        points, offset = client.scroll(
            collection_name=collection_name,
            limit=batch_size,
            offset=offset,
            with_payload=True,
        )

        for point in points:
            stats["checked"] += 1
            payload = point.payload or {}
            try:
                fixed_payload = validate_point_payload(payload, point.id)
            except ConfigurationError as e:
                stats["invalid"] += 1
                logger.error(str(e))
                continue

            if fixed_payload != payload:
                # scroll() does not return vectors by default, so rebuilding the
                # point with PointStruct(vector=point.vector) would pass None.
                # overwrite_payload replaces only the payload of this point.
                client.overwrite_payload(
                    collection_name=collection_name,
                    payload=fixed_payload,
                    points=[point.id],
                )
                stats["fixed"] += 1
                logger.info(f"Fixed and updated point {point.id} in database")

        # scroll() returns None as the next offset once the last page is read
        if offset is None:
            break

    logger.info("Database validation completed")
    return stats

# Fallback values for non-critical payload fields that are missing.
# NOTE(review): the previous implementation referenced
# ``self.required_fields_defaults``, which does not exist in this module-level
# function (it raised NameError). These values are a placeholder — confirm
# they match the defaults the project uses when it parses pattern files.
_PAYLOAD_FIELD_DEFAULTS: Dict = {
    'filesize': 0,
    'trigger': '',
}


def validate_point_payload(
    payload: dict,
    point_id: Optional[str] = None,
    defaults: Optional[dict] = None,
) -> dict:
    """Validate and fix point payload fields.
    Only use if somehow many points have become corrupted.

    Args:
        payload (dict): Point payload to validate. It is not modified.
        point_id (str, optional): ID of the point for logging purposes
        defaults (dict, optional): Overrides for the fallback values of
            ``filesize`` and ``trigger``; merged over ``_PAYLOAD_FIELD_DEFAULTS``.

    Returns:
        dict: A copy of the payload with defaults filled in. If nothing was
        missing, the copy compares equal to the input.

    Raises:
        ConfigurationError: If 'filename' or 'content' is missing (these cannot be defaulted)
    """
    # Use an explicit None check so an id of 0 is still shown in messages
    label = '' if point_id is None else point_id
    logger.debug(f"Validating point {label}")

    # Check for critical fields
    if 'filename' not in payload or 'content' not in payload:
        error_msg = f"Point {label} is missing critical fields: "
        error_msg += "'filename' and/or 'content' are required and cannot be defaulted"
        raise ConfigurationError(error_msg)

    field_defaults = {**_PAYLOAD_FIELD_DEFAULTS, **(defaults or {})}

    # Copy payload to avoid modifying the original
    fixed_payload = payload.copy()

    # An absent or empty purpose falls back to the content text
    if not fixed_payload.get('purpose'):
        fixed_payload['purpose'] = fixed_payload['content']
        logger.warning(f"Point {label}: 'purpose' was missing, set to content value")

    for field in ('filesize', 'trigger'):
        if field not in fixed_payload:
            fixed_payload[field] = field_defaults[field]
            logger.warning(f"Point {label}: '{field}' was missing, set to {field_defaults[field]}")

    return fixed_payload