ArneBinder commited on
Commit
b0174fa
1 Parent(s): fed112f

from https://github.com/ArneBinder/pie-document-level/pull/233

Browse files

use QdrantVectorStore again (since it works with download/upload now)

Files changed (1) hide show
  1. vector_store.py +82 -33
vector_store.py CHANGED
@@ -12,6 +12,16 @@ E = TypeVar("E")
12
 
13
 
14
  class VectorStore(Generic[T, E], abc.ABC):
 
 
 
 
 
 
 
 
 
 
15
  @abc.abstractmethod
16
  def _add(self, embedding: E, payload: T, emb_id: str) -> None:
17
  """Save an embedding with payload for a given ID."""
@@ -22,6 +32,11 @@ class VectorStore(Generic[T, E], abc.ABC):
22
  """Get the embedding for a given ID."""
23
  pass
24
 
 
 
 
 
 
25
  def _get_emb_id(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> str:
26
  if emb_id is None:
27
  if payload is None:
@@ -67,16 +82,41 @@ class VectorStore(Generic[T, E], abc.ABC):
67
  def __len__(self):
68
  pass
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def save_to_directory(self, directory: str) -> None:
71
  """Save the vector store to a directory."""
72
- raise NotImplementedError
 
73
 
74
  def load_from_directory(self, directory: str, replace: bool = False) -> None:
75
  """Load the vector store from a directory.
76
 
77
  If `replace` is True, the current content of the store will be replaced.
78
  """
79
- raise NotImplementedError
 
 
80
 
81
 
82
  def vector_norm(vector: List[float]) -> float:
@@ -88,10 +128,7 @@ def cosine_similarity(a: List[float], b: List[float]) -> float:
88
 
89
 
90
  class SimpleVectorStore(VectorStore[T, List[float]]):
91
-
92
- INDEX_FILE = "vectors_index.json"
93
- EMBEDDINGS_FILE = "vectors_data.npy"
94
- PAYLOADS_FILE = "vectors_payloads.json"
95
 
96
  def __init__(self):
97
  self.vectors: dict[str, List[float]] = {}
@@ -148,8 +185,7 @@ class SimpleVectorStore(VectorStore[T, List[float]]):
148
 
149
  return [(emb_id, self.payloads[emb_id], sim) for emb_id, sim in similar_entries]
150
 
151
- def save_to_directory(self, directory: str) -> None:
152
- os.makedirs(directory, exist_ok=True)
153
  indices = list(self.vectors.keys())
154
  with open(os.path.join(directory, self.INDEX_FILE), "w") as f:
155
  json.dump(indices, f)
@@ -159,20 +195,15 @@ class SimpleVectorStore(VectorStore[T, List[float]]):
159
  with open(os.path.join(directory, self.PAYLOADS_FILE), "w") as f:
160
  json.dump(payloads, f)
161
 
162
- def load_from_directory(self, directory: str, replace: bool = False) -> None:
163
- if replace:
164
- self.clear()
165
- with open(os.path.join(directory, self.INDEX_FILE), "r") as f:
166
- index = json.load(f)
167
- embeddings_np = np.load(os.path.join(directory, self.EMBEDDINGS_FILE))
168
- with open(os.path.join(directory, self.PAYLOADS_FILE), "r") as f:
169
- payloads = json.load(f)
170
- for emb_id, emb, payload in zip(index, embeddings_np, payloads):
171
- self.vectors[emb_id] = emb.tolist()
172
- self.payloads[emb_id] = payload
173
 
174
 
175
  class QdrantVectorStore(VectorStore[T, List[float]]):
 
176
 
177
  COLLECTION_NAME = "ADUs"
178
  MAX_LIMIT = 100
@@ -184,8 +215,8 @@ class QdrantVectorStore(VectorStore[T, List[float]]):
184
  distance: Distance = Distance.COSINE,
185
  ):
186
  self.client = QdrantClient(location=location)
187
- self.id2idx = {}
188
- self.idx2id = {}
189
  self.client.create_collection(
190
  collection_name=self.COLLECTION_NAME,
191
  vectors_config=VectorParams(size=vector_size, distance=distance),
@@ -196,21 +227,26 @@ class QdrantVectorStore(VectorStore[T, List[float]]):
196
 
197
  def _add(self, emb_id: str, payload: T, embedding: List[float]) -> None:
198
 
199
- # we use the length of the id2idx dict as the index,
200
- # because we assume that, even when we delete an entry from
201
- # the store, we do not delete it from the index
202
- _id = len(self.id2idx)
 
 
 
 
 
 
 
203
  self.client.upsert(
204
  collection_name=self.COLLECTION_NAME,
205
- points=[PointStruct(id=_id, vector=embedding, payload=payload)],
206
  )
207
- self.id2idx[emb_id] = _id
208
- self.idx2id[_id] = emb_id
209
 
210
  def _get(self, emb_id: str) -> Optional[List[float]]:
211
  points = self.client.retrieve(
212
  collection_name=self.COLLECTION_NAME,
213
- ids=[self.id2idx[emb_id]],
214
  with_vectors=True,
215
  )
216
  if len(points) == 0:
@@ -225,11 +261,14 @@ class QdrantVectorStore(VectorStore[T, List[float]]):
225
  ) -> List[Tuple[str, T, float]]:
226
  similar_entries = self.client.recommend(
227
  collection_name=self.COLLECTION_NAME,
228
- positive=[self.id2idx[ref_id]],
229
  limit=top_k or self.MAX_LIMIT,
230
  score_threshold=min_similarity,
231
  )
232
- return [(self.idx2id[entry.id], entry.payload, entry.score) for entry in similar_entries]
 
 
 
233
 
234
  def clear(self) -> None:
235
  vectors_config = self.client.get_collection(
@@ -240,5 +279,15 @@ class QdrantVectorStore(VectorStore[T, List[float]]):
240
  collection_name=self.COLLECTION_NAME,
241
  vectors_config=vectors_config,
242
  )
243
- self.id2idx.clear()
244
- self.idx2id.clear()
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  class VectorStore(Generic[T, E], abc.ABC):
15
+ """Abstract base class for a vector store.
16
+
17
+ A vector store is a key-value store that maps an ID to a vector embedding and a payload. The
18
+ payload can be any JSON-serializable object, e.g. a dictionary.
19
+ """
20
+
21
+ INDEX_FILE = "vectors_index.json"
22
+ EMBEDDINGS_FILE = "vectors_data.npy"
23
+ PAYLOADS_FILE = "vectors_payloads.json"
24
+
25
  @abc.abstractmethod
26
  def _add(self, embedding: E, payload: T, emb_id: str) -> None:
27
  """Save an embedding with payload for a given ID."""
 
32
  """Get the embedding for a given ID."""
33
  pass
34
 
35
+ @abc.abstractmethod
36
+ def clear(self) -> None:
37
+ """Clear the store."""
38
+ pass
39
+
40
  def _get_emb_id(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> str:
41
  if emb_id is None:
42
  if payload is None:
 
82
  def __len__(self):
83
  pass
84
 
85
+ def _add_from_directory(self, directory: str) -> None:
86
+ with open(os.path.join(directory, self.INDEX_FILE), "r") as f:
87
+ index = json.load(f)
88
+ embeddings_np = np.load(os.path.join(directory, self.EMBEDDINGS_FILE))
89
+ with open(os.path.join(directory, self.PAYLOADS_FILE), "r") as f:
90
+ payloads = json.load(f)
91
+ for emb_id, emb, payload in zip(index, embeddings_np, payloads):
92
+ self._add(emb_id=emb_id, payload=payload, embedding=emb.tolist())
93
+
94
+ @abc.abstractmethod
95
+ def as_indices_vectors_payloads(self) -> Tuple[List[str], np.ndarray, List[T]]:
96
+ """Return a tuple of indices, vectors and payloads."""
97
+ pass
98
+
99
+ def _save_to_directory(self, directory: str) -> None:
100
+ indices, vectors, payloads = self.as_indices_vectors_payloads()
101
+ np.save(os.path.join(directory, self.EMBEDDINGS_FILE), vectors)
102
+ with open(os.path.join(directory, self.PAYLOADS_FILE), "w") as f:
103
+ json.dump(payloads, f)
104
+ with open(os.path.join(directory, self.INDEX_FILE), "w") as f:
105
+ json.dump(indices, f)
106
+
107
  def save_to_directory(self, directory: str) -> None:
108
  """Save the vector store to a directory."""
109
+ os.makedirs(directory, exist_ok=True)
110
+ self._save_to_directory(directory)
111
 
112
  def load_from_directory(self, directory: str, replace: bool = False) -> None:
113
  """Load the vector store from a directory.
114
 
115
  If `replace` is True, the current content of the store will be replaced.
116
  """
117
+ if replace:
118
+ self.clear()
119
+ self._add_from_directory(directory)
120
 
121
 
122
  def vector_norm(vector: List[float]) -> float:
 
128
 
129
 
130
  class SimpleVectorStore(VectorStore[T, List[float]]):
131
+ """Simple in-memory vector store using a dictionary."""
 
 
 
132
 
133
  def __init__(self):
134
  self.vectors: dict[str, List[float]] = {}
 
185
 
186
  return [(emb_id, self.payloads[emb_id], sim) for emb_id, sim in similar_entries]
187
 
188
+ def _save_to_directory(self, directory: str) -> None:
 
189
  indices = list(self.vectors.keys())
190
  with open(os.path.join(directory, self.INDEX_FILE), "w") as f:
191
  json.dump(indices, f)
 
195
  with open(os.path.join(directory, self.PAYLOADS_FILE), "w") as f:
196
  json.dump(payloads, f)
197
 
198
+ def as_indices_vectors_payloads(self) -> Tuple[List[str], np.ndarray, List[T]]:
199
+ indices = list(self.vectors.keys())
200
+ embeddings_np = np.array(list(self.vectors.values()))
201
+ payloads = [self.payloads[idx] for idx in indices]
202
+ return indices, embeddings_np, payloads
 
 
 
 
 
 
203
 
204
 
205
  class QdrantVectorStore(VectorStore[T, List[float]]):
206
+ """Vector store using Qdrant as a backend."""
207
 
208
  COLLECTION_NAME = "ADUs"
209
  MAX_LIMIT = 100
 
215
  distance: Distance = Distance.COSINE,
216
  ):
217
  self.client = QdrantClient(location=location)
218
+ self.emb_id2point_id = {}
219
+ self.point_id2emb_id = {}
220
  self.client.create_collection(
221
  collection_name=self.COLLECTION_NAME,
222
  vectors_config=VectorParams(size=vector_size, distance=distance),
 
227
 
228
  def _add(self, emb_id: str, payload: T, embedding: List[float]) -> None:
229
 
230
+ if emb_id in self.emb_id2point_id:
231
+ # update existing entry
232
+ point_id = self.emb_id2point_id[emb_id]
233
+ else:
234
+ # we use the length of the emb_id2point_id dict as the index,
235
+ # because we assume that, even when we delete an entry from
236
+ # the store, we do not delete it from the index
237
+ point_id = len(self.emb_id2point_id)
238
+ self.emb_id2point_id[emb_id] = point_id
239
+ self.point_id2emb_id[point_id] = emb_id
240
+
241
  self.client.upsert(
242
  collection_name=self.COLLECTION_NAME,
243
+ points=[PointStruct(id=point_id, vector=embedding, payload=payload)],
244
  )
 
 
245
 
246
  def _get(self, emb_id: str) -> Optional[List[float]]:
247
  points = self.client.retrieve(
248
  collection_name=self.COLLECTION_NAME,
249
+ ids=[self.emb_id2point_id[emb_id]],
250
  with_vectors=True,
251
  )
252
  if len(points) == 0:
 
261
  ) -> List[Tuple[str, T, float]]:
262
  similar_entries = self.client.recommend(
263
  collection_name=self.COLLECTION_NAME,
264
+ positive=[self.emb_id2point_id[ref_id]],
265
  limit=top_k or self.MAX_LIMIT,
266
  score_threshold=min_similarity,
267
  )
268
+ return [
269
+ (self.point_id2emb_id[entry.id], entry.payload, entry.score)
270
+ for entry in similar_entries
271
+ ]
272
 
273
  def clear(self) -> None:
274
  vectors_config = self.client.get_collection(
 
279
  collection_name=self.COLLECTION_NAME,
280
  vectors_config=vectors_config,
281
  )
282
+ self.emb_id2point_id.clear()
283
+ self.point_id2emb_id.clear()
284
+
285
+ def as_indices_vectors_payloads(self) -> Tuple[List[str], np.ndarray, List[T]]:
286
+ num_entries = self.client.get_collection(collection_name=self.COLLECTION_NAME).points_count
287
+ data, point_ids = self.client.scroll(
288
+ collection_name=self.COLLECTION_NAME, with_vectors=True, limit=num_entries
289
+ )
290
+ vectors_np = np.array([point.vector for point in data])
291
+ payloads = [point.payload for point in data]
292
+ emb_ids = [self.point_id2emb_id[point.id] for point in data]
293
+ return emb_ids, vectors_np, payloads