File size: 7,516 Bytes
e8051be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
"""

Vector Storage Module



Stores text chunks and their embeddings in on-disk Qdrant vector databases (one database per document).

"""

import numpy as np
from typing import List
from pathlib import Path
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct


class VectorStorage:
    """Store document chunks and their embeddings in local, on-disk Qdrant databases.

    Every document gets its own Qdrant database directory at
    ``base_db_path / "<doc_id>_collection.db"`` holding a single collection
    named ``"<doc_id>_collection"``.
    """

    def __init__(self, base_db_path: Path):
        """
        Initialize the vector storage.

        Args:
            base_db_path: Base directory under which the per-document Qdrant
                databases are created.
        """
        self.base_db_path = base_db_path

    def _collection_paths(self, doc_id: str) -> tuple:
        """Return ``(collection_name, db_path)`` for ``doc_id``.

        Single source of truth for the naming scheme used by every method.
        """
        collection_name = f"{doc_id}_collection"
        return collection_name, self.base_db_path / f"{collection_name}.db"

    async def store_in_qdrant(self, chunks: List[str], embeddings: np.ndarray, doc_id: str):
        """
        Store chunks and embeddings in Qdrant, replacing any existing collection.

        Args:
            chunks: List of text chunks.
            embeddings: 2-D array with one embedding row per chunk.
            doc_id: Document identifier.

        Raises:
            ValueError: If ``embeddings`` is not 2-D or its row count differs
                from ``len(chunks)``.
        """
        # Validate before opening the database so bad input never touches disk.
        if embeddings.ndim != 2:
            raise ValueError(
                f"Embeddings must be a 2-D array (num_chunks, dim), got shape {embeddings.shape}"
            )
        if len(chunks) != embeddings.shape[0]:
            raise ValueError(
                f"Chunk count ({len(chunks)}) doesn't match embedding count ({embeddings.shape[0]})"
            )

        collection_name, db_path = self._collection_paths(doc_id)
        # NOTE(review): QdrantClient is the synchronous client, so the calls below
        # block the event loop despite this coroutine signature.
        client = QdrantClient(path=str(db_path))

        print(f"💾 Storing {len(chunks)} vectors in collection: {collection_name}")

        try:
            # Create or recreate collection
            await self._setup_collection(client, collection_name, embeddings.shape[1])

            # Prepare and upload points
            await self._upload_points(client, collection_name, chunks, embeddings, doc_id)

            print("✅ Successfully stored all vectors in Qdrant")
        finally:
            client.close()

    async def _setup_collection(self, client: QdrantClient, collection_name: str, embedding_dim: int):
        """
        Set up a Qdrant collection, recreating it if it already exists.

        Args:
            client: Qdrant client.
            collection_name: Name of the collection.
            embedding_dim: Dimension of the embeddings (vector size).
        """
        # delete_collection reports whether a collection was actually removed;
        # use that so the log only claims a deletion that really happened.
        try:
            deleted = client.delete_collection(collection_name)
        except Exception:
            deleted = False  # Collection might not exist
        if deleted:
            print(f"🗑️ Deleted existing collection: {collection_name}")

        # Create new collection
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(
                size=embedding_dim,
                distance=Distance.COSINE,
            ),
        )
        print(f"✅ Created new collection: {collection_name}")

    async def _upload_points(self, client: QdrantClient, collection_name: str,
                             chunks: List[str], embeddings: np.ndarray, doc_id: str,
                             batch_size: int = 100):
        """
        Upload chunks as points to a Qdrant collection in batches.

        Point ids are the 0-based chunk indices. Callers must pass exactly one
        embedding row per chunk (``store_in_qdrant`` validates this).

        Args:
            client: Qdrant client.
            collection_name: Name of the collection.
            chunks: Text chunks.
            embeddings: Embedding vectors, one row per chunk.
            doc_id: Document identifier stored in every payload.
            batch_size: Number of points sent per ``upsert`` call.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        total_points = len(chunks)
        total_batches = (total_points + batch_size - 1) // batch_size

        for batch_num, start in enumerate(range(0, total_points, batch_size), start=1):
            stop = start + batch_size
            # Build PointStructs one batch at a time instead of materializing
            # all of them up front (large documents).
            batch = [
                PointStruct(
                    id=chunk_id,
                    vector=vector.tolist(),
                    payload={
                        "text": text,
                        "chunk_id": chunk_id,
                        "doc_id": doc_id,
                        "char_count": len(text),
                        "word_count": len(text.split()),
                    },
                )
                for chunk_id, (text, vector) in enumerate(
                    zip(chunks[start:stop], embeddings[start:stop]), start=start
                )
            ]

            print(f"   Uploading batch {batch_num}/{total_batches} ({len(batch)} points)")
            client.upsert(collection_name=collection_name, points=batch)

        print(f"✅ Uploaded {total_points} points in {total_batches} batches")

    def collection_exists(self, doc_id: str) -> bool:
        """
        Check whether a database directory exists for the given document ID.

        Only the on-disk path is checked; the collection inside it is not opened.

        Args:
            doc_id: Document identifier.

        Returns:
            bool: True if the database path exists, False otherwise.
        """
        _, db_path = self._collection_paths(doc_id)
        return db_path.exists()

    def get_collection_info(self, doc_id: str) -> dict:
        """
        Get information about a document's collection.

        Args:
            doc_id: Document identifier.

        Returns:
            dict: Always has ``collection_name``, ``exists`` and ``path``. When the
            collection could be read it also has ``vectors_count``,
            ``points_count`` and ``status``; when opening it failed it has
            ``error`` instead.
        """
        collection_name, db_path = self._collection_paths(doc_id)

        if not db_path.exists():
            return {
                "collection_name": collection_name,
                "exists": False,
                "path": str(db_path),
            }

        try:
            client = QdrantClient(path=str(db_path))
            try:
                collection_info = client.get_collection(collection_name)
            finally:
                client.close()
        except Exception as e:
            return {
                "collection_name": collection_name,
                "exists": True,
                "path": str(db_path),
                "error": str(e),
            }

        points_count = getattr(collection_info, "points_count", None)
        # Depending on the qdrant-client version, CollectionInfo.vectors_count may
        # be missing or None. Collections created by this class hold one unnamed
        # vector per point, so points_count is the equivalent fallback.
        vectors_count = getattr(collection_info, "vectors_count", None)
        if vectors_count is None:
            vectors_count = points_count

        return {
            "collection_name": collection_name,
            "exists": True,
            "path": str(db_path),
            "vectors_count": vectors_count,
            "points_count": points_count,
            "status": collection_info.status,
        }

    def delete_collection(self, doc_id: str) -> bool:
        """
        Delete a document's collection and its database directory.

        Args:
            doc_id: Document identifier.

        Returns:
            bool: True if nothing is left on disk (including when there was
            nothing to delete), False if removal failed.
        """
        import shutil

        collection_name, db_path = self._collection_paths(doc_id)

        try:
            if not db_path.exists():
                return True  # Nothing to delete

            # Best-effort clean removal through Qdrant first; the directory is
            # removed below regardless, so failures here are not fatal.
            client = None
            try:
                client = QdrantClient(path=str(db_path))
                client.delete_collection(collection_name)
            except Exception:
                pass  # Collection might not exist or be corrupted
            finally:
                # Always release the client before deleting the files under it.
                if client is not None:
                    try:
                        client.close()
                    except Exception:
                        pass

            # Remove database directory
            shutil.rmtree(db_path, ignore_errors=True)
            still_present = db_path.exists()
        except Exception as e:
            print(f"❌ Error deleting collection {collection_name}: {e}")
            return False

        # rmtree(ignore_errors=True) never raises, so verify the outcome explicitly.
        if still_present:
            print(f"❌ Error deleting collection {collection_name}: could not remove {db_path}")
            return False

        print(f"🗑️ Deleted collection: {collection_name}")
        return True