sam-pointer-bart-base-v0.3 / vector_store.py
ArneBinder's picture
from https://github.com/ArneBinder/pie-document-level/pull/225
d7a2972 verified
raw
history blame
9.07 kB
import abc
import json
import os
from typing import Any, Generic, List, Optional, Tuple, TypeVar
import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams
# Payload type stored next to each embedding; restricted to dicts with string keys.
T = TypeVar("T", bound=dict[str, Any])
# Embedding type handled by a store.
E = TypeVar("E")
class VectorStore(Generic[T, E], abc.ABC):
    """Abstract base class for stores that keep embeddings together with a payload.

    Every entry is addressed by a string ID. If no ID is passed explicitly, it is derived
    from the payload by dumping it to JSON with sorted keys, so the payload has to be
    JSON-serializable in that case.
    """

    @abc.abstractmethod
    def _add(self, embedding: E, payload: T, emb_id: str) -> None:
        """Save an embedding with payload for a given ID."""

    @abc.abstractmethod
    def _get(self, emb_id: str) -> Optional[E]:
        """Get the embedding for a given ID."""

    def _get_emb_id(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> str:
        """Resolve the ID of an entry.

        Args:
            emb_id: The explicit ID. If provided, it is returned unchanged.
            payload: The payload to derive the ID from if no explicit ID is given.

        Returns:
            The explicit ID, or the payload dumped to JSON with sorted keys.

        Raises:
            ValueError: If neither `emb_id` nor `payload` is provided.
        """
        if emb_id is not None:
            return emb_id
        if payload is None:
            raise ValueError("Either emb_id or payload must be provided.")
        return json.dumps(payload, sort_keys=True)

    def add(self, embedding: E, payload: T, emb_id: Optional[str] = None) -> None:
        """Add an embedding with its payload to the store.

        Args:
            embedding: The embedding to store.
            payload: The payload to store along with the embedding.
            emb_id: The ID of the entry. If not provided, it is derived from the payload.

        Raises:
            ValueError: If neither `emb_id` nor `payload` is provided.
        """
        # resolve the ID the same way as get() and retrieve_similar() do, so that all
        # entry points agree on how a payload maps to an ID
        self._add(
            embedding=embedding,
            payload=payload,
            emb_id=self._get_emb_id(emb_id=emb_id, payload=payload),
        )

    def get(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> Optional[E]:
        """Get the embedding for an entry identified by its ID or by its payload.

        Raises:
            ValueError: If neither `emb_id` nor `payload` is provided.
        """
        return self._get(emb_id=self._get_emb_id(emb_id=emb_id, payload=payload))

    @abc.abstractmethod
    def _retrieve_similar(
        self, ref_id: str, top_k: Optional[int] = None, min_similarity: Optional[float] = None
    ) -> List[Tuple[str, T, float]]:
        """Retrieve IDs, payloads and the respective similarity scores with respect to the
        reference entry. Note that this requires the reference entry to be present in the store.

        Args:
            ref_id: The ID of the reference entry.
            top_k: If provided, only the top-k most similar entries will be returned.
            min_similarity: If provided, only entries with a similarity score greater or equal to
                this value will be returned.

        Returns:
            A list of tuples consisting of the ID, the payload and the similarity score,
            sorted by similarity score in descending order.
        """

    def retrieve_similar(
        self, ref_id: Optional[str] = None, ref_payload: Optional[T] = None, **kwargs
    ) -> List[Tuple[str, T, float]]:
        """Retrieve the entries that are most similar to a reference entry.

        Args:
            ref_id: The ID of the reference entry.
            ref_payload: The payload of the reference entry, used to derive its ID if
                `ref_id` is not provided.
            **kwargs: Passed on to `_retrieve_similar` (e.g. `top_k`, `min_similarity`).

        Returns:
            A list of tuples consisting of the ID, the payload and the similarity score.

        Raises:
            ValueError: If neither `ref_id` nor `ref_payload` is provided.
        """
        return self._retrieve_similar(
            ref_id=self._get_emb_id(emb_id=ref_id, payload=ref_payload), **kwargs
        )

    @abc.abstractmethod
    def __len__(self) -> int:
        """Return the number of entries in the store."""

    def save_to_directory(self, directory: str) -> None:
        """Save the vector store to a directory."""
        raise NotImplementedError

    def load_from_directory(self, directory: str, replace: bool = False) -> None:
        """Load the vector store from a directory.

        If `replace` is True, the current content of the store will be replaced.
        """
        raise NotImplementedError
def vector_norm(vector: List[float]) -> float:
    """Return the Euclidean (L2) norm of `vector`; an empty vector yields 0.0."""
    squares = [component**2 for component in vector]
    return pow(sum(squares), 0.5)
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Return the cosine similarity of the vectors `a` and `b`.

    The dot product is calculated over the pairs produced by `zip`, i.e. if the vectors
    differ in length, the surplus components of the longer one do not contribute to it.

    Returns:
        The dot product divided by the product of both norms, or 0.0 if that product is
        zero (at least one vector is empty or consists of zeros only), because the cosine
        is undefined in that case.
    """
    norm_product = vector_norm(a) * vector_norm(b)
    if norm_product == 0:
        # avoid a ZeroDivisionError for zero-length vectors
        return 0.0
    dot_product = sum(x * y for x, y in zip(a, b))
    return dot_product / norm_product
class SimpleVectorStore(VectorStore[T, List[float]]):
    """In-memory vector store that keeps embeddings and payloads in plain dictionaries.

    Similar entries are found by calculating the cosine similarity between the reference
    embedding and every stored embedding. Calculated scores are memoized per
    (entry ID, reference ID) pair and dropped again when one of the two entries changes.
    """

    INDEX_FILE = "vectors_index.json"
    EMBEDDINGS_FILE = "vectors_data.npy"
    PAYLOADS_FILE = "vectors_payloads.json"

    def __init__(self):
        self.vectors: dict[str, List[float]] = {}
        self.payloads: dict[str, T] = {}
        # memoized similarity scores, keyed by (entry ID, reference ID)
        self._cache: dict[Tuple[str, str], float] = {}
        self._sim = cosine_similarity

    def _invalidate_cache(self, emb_ids: set[str]) -> None:
        """Drop all memoized similarity scores that involve any of the given IDs."""
        if not emb_ids or not self._cache:
            return
        self._cache = {
            pair: sim
            for pair, sim in self._cache.items()
            if pair[0] not in emb_ids and pair[1] not in emb_ids
        }

    def _add(self, embedding: List[float], payload: T, emb_id: str) -> None:
        """Save an embedding with payload for a given ID, replacing an existing entry."""
        if emb_id in self.vectors:
            # the scores memoized for this ID were calculated from the old embedding
            self._invalidate_cache({emb_id})
        self.vectors[emb_id] = embedding
        self.payloads[emb_id] = payload

    def _get(self, emb_id: str) -> Optional[List[float]]:
        """Get the embedding for a given ID, or None if there is no such entry."""
        return self.vectors.get(emb_id)

    def delete(self, emb_id: str) -> None:
        """Remove the entry with the given ID. Unknown IDs are ignored."""
        if emb_id in self.vectors:
            del self.vectors[emb_id]
            del self.payloads[emb_id]
            # remove from cache
            self._invalidate_cache({emb_id})

    def clear(self) -> None:
        """Remove all entries and all memoized similarity scores."""
        self.vectors.clear()
        self._cache.clear()
        self.payloads.clear()

    def __len__(self) -> int:
        """Return the number of stored embeddings."""
        return len(self.vectors)

    def _retrieve_similar(
        self, ref_id: str, top_k: Optional[int] = None, min_similarity: Optional[float] = None
    ) -> List[Tuple[str, T, float]]:
        """Retrieve IDs, payloads and similarity scores with respect to the reference entry.

        All stored entries are scored, so the reference entry itself is part of the result.

        Args:
            ref_id: The ID of the reference entry.
            top_k: If provided, only the top-k most similar entries will be returned.
            min_similarity: If provided, only entries with a similarity score greater or equal to
                this value will be returned.

        Returns:
            A list of tuples consisting of the ID, the payload and the similarity score,
            sorted by similarity score in descending order.

        Raises:
            ValueError: If there is no entry with ID `ref_id`.
        """
        ref_embedding = self._get(ref_id)
        if ref_embedding is None:
            raise ValueError(f"Reference embedding '{ref_id}' not found.")

        # calculate similarity to all embeddings
        similarities: dict[str, float] = {}
        for emb_id, embedding in self.vectors.items():
            cache_key = (emb_id, ref_id)
            if cache_key not in self._cache:
                # use cosine similarity
                self._cache[cache_key] = self._sim(ref_embedding, embedding)
            similarities[emb_id] = self._cache[cache_key]

        # sort by similarity
        similar_entries = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
        if min_similarity is not None:
            similar_entries = [
                (emb_id, sim) for emb_id, sim in similar_entries if sim >= min_similarity
            ]
        if top_k is not None:
            similar_entries = similar_entries[:top_k]
        return [(emb_id, self.payloads[emb_id], sim) for emb_id, sim in similar_entries]

    def save_to_directory(self, directory: str) -> None:
        """Save IDs, embeddings and payloads to three files in `directory`.

        The directory is created if it does not exist. The i-th ID, embedding and payload
        in the respective files belong to the same entry.
        """
        os.makedirs(directory, exist_ok=True)
        indices = list(self.vectors.keys())
        with open(os.path.join(directory, self.INDEX_FILE), "w", encoding="utf-8") as f:
            json.dump(indices, f)
        embeddings_np = np.array([self.vectors[emb_id] for emb_id in indices])
        np.save(os.path.join(directory, self.EMBEDDINGS_FILE), embeddings_np)
        payloads = [self.payloads[emb_id] for emb_id in indices]
        with open(os.path.join(directory, self.PAYLOADS_FILE), "w", encoding="utf-8") as f:
            json.dump(payloads, f)

    def load_from_directory(self, directory: str, replace: bool = False) -> None:
        """Load entries that were written by `save_to_directory`.

        Args:
            directory: The directory to read the files from.
            replace: If True, the current content of the store is replaced. Otherwise the
                loaded entries are added, overwriting existing entries with the same ID.
        """
        # read everything first, so that a missing or broken file does not wipe the store
        with open(os.path.join(directory, self.INDEX_FILE), "r", encoding="utf-8") as f:
            index = json.load(f)
        embeddings_np = np.load(os.path.join(directory, self.EMBEDDINGS_FILE))
        with open(os.path.join(directory, self.PAYLOADS_FILE), "r", encoding="utf-8") as f:
            payloads = json.load(f)

        if replace:
            self.clear()
        else:
            # scores memoized for entries that get overwritten would be stale
            self._invalidate_cache({emb_id for emb_id in index if emb_id in self.vectors})
        for emb_id, emb, payload in zip(index, embeddings_np, payloads):
            self.vectors[emb_id] = emb.tolist()
            self.payloads[emb_id] = payload
class QdrantVectorStore(VectorStore[T, List[float]]):
    """Vector store that keeps embeddings and payloads in a Qdrant collection.

    Points are stored under integer indices; `id2idx` and `idx2id` translate between the
    string IDs used by the `VectorStore` interface and these indices.
    """

    COLLECTION_NAME = "ADUs"
    # number of results returned by _retrieve_similar() if no top_k is requested
    MAX_LIMIT = 100

    def __init__(
        self,
        location: str = ":memory:",
        vector_size: int = 768,
        distance: Distance = Distance.COSINE,
    ):
        """Create the client and the collection.

        Args:
            location: Passed to `QdrantClient` as `location`.
            vector_size: The size of the vectors in the collection.
            distance: The distance function of the collection.
        """
        self.client = QdrantClient(location=location)
        self.id2idx: dict[str, int] = {}
        self.idx2id: dict[int, str] = {}
        # kept to re-create the collection with the very same settings in clear()
        self._vectors_config = VectorParams(size=vector_size, distance=distance)
        self.client.create_collection(
            collection_name=self.COLLECTION_NAME,
            vectors_config=self._vectors_config,
        )

    def __len__(self) -> int:
        """Return the number of points in the collection as reported by Qdrant."""
        return self.client.get_collection(collection_name=self.COLLECTION_NAME).points_count

    def _add(self, emb_id: str, payload: T, embedding: List[float]) -> None:
        """Save an embedding with payload for a given ID, replacing an existing entry."""
        # An already known ID keeps its index, so that the upsert replaces the existing point.
        # Allocating a fresh index here would leave the old point behind and make the next
        # new entry collide with the index that was just handed out.
        _id = self.id2idx.get(emb_id)
        if _id is None:
            # we use the length of the id2idx dict as the index,
            # because we assume that, even when we delete an entry from
            # the store, we do not delete it from the index
            _id = len(self.id2idx)
        self.client.upsert(
            collection_name=self.COLLECTION_NAME,
            points=[PointStruct(id=_id, vector=embedding, payload=payload)],
        )
        self.id2idx[emb_id] = _id
        self.idx2id[_id] = emb_id

    def _get(self, emb_id: str) -> Optional[List[float]]:
        """Get the embedding for a given ID, or None if there is no such entry.

        Raises:
            ValueError: If the collection returns more than one point for the ID.
        """
        idx = self.id2idx.get(emb_id)
        if idx is None:
            # unknown ID: there is nothing to look up in the collection
            return None
        points = self.client.retrieve(
            collection_name=self.COLLECTION_NAME,
            ids=[idx],
            with_vectors=True,
        )
        if not points:
            return None
        if len(points) > 1:
            raise ValueError(f"Multiple points found for ID '{emb_id}'.")
        return points[0].vector

    def _retrieve_similar(
        self, ref_id: str, top_k: Optional[int] = None, min_similarity: Optional[float] = None
    ) -> List[Tuple[str, T, float]]:
        """Retrieve IDs, payloads and similarity scores with respect to the reference entry.

        Args:
            ref_id: The ID of the reference entry.
            top_k: The maximum number of entries to return. If not provided (or 0),
                `MAX_LIMIT` entries are requested.
            min_similarity: If provided, it is passed to Qdrant as score threshold.

        Returns:
            A list of tuples consisting of the ID, the payload and the score, in the order
            returned by Qdrant.

        Raises:
            KeyError: If there is no entry with ID `ref_id`.
        """
        # NOTE(review): Qdrant's recommend API presumably does not return the reference
        # point itself, whereas SimpleVectorStore does - confirm before relying on it.
        similar_entries = self.client.recommend(
            collection_name=self.COLLECTION_NAME,
            positive=[self.id2idx[ref_id]],
            limit=top_k or self.MAX_LIMIT,
            score_threshold=min_similarity,
        )
        return [(self.idx2id[entry.id], entry.payload, entry.score) for entry in similar_entries]

    def clear(self) -> None:
        """Remove all entries by re-creating the collection with its original settings."""
        self.client.delete_collection(collection_name=self.COLLECTION_NAME)
        self.client.create_collection(
            collection_name=self.COLLECTION_NAME,
            vectors_config=self._vectors_config,
        )
        self.id2idx.clear()
        self.idx2id.clear()