|
import abc |
|
import json |
|
import os |
|
from typing import Any, Generic, List, Optional, Tuple, TypeVar |
|
|
|
import numpy as np |
|
from qdrant_client import QdrantClient |
|
from qdrant_client.models import Distance, PointStruct, VectorParams |
|
|
|
# Payload type stored alongside each embedding; bounded to string-keyed dicts.
T = TypeVar("T", bound=dict[str, Any])

# Embedding type handled by a store.
E = TypeVar("E")
|
|
|
|
|
class VectorStore(Generic[T, E], abc.ABC):
    """Abstract base class for stores that map string IDs to embeddings and payloads.

    Subclasses implement the storage primitives (`_add`, `_get`, `_retrieve_similar`
    and `__len__`). The public methods resolve the entry ID - either passed explicitly
    or derived from the payload - and then delegate to those primitives.
    """

    @abc.abstractmethod
    def _add(self, embedding: E, payload: T, emb_id: str) -> None:
        """Save an embedding with payload for a given ID."""
        pass

    @abc.abstractmethod
    def _get(self, emb_id: str) -> Optional[E]:
        """Get the embedding for a given ID."""
        pass

    def _get_emb_id(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> str:
        """Return `emb_id` if it is given, otherwise derive the ID from `payload`.

        The derived ID is the JSON serialization of the payload with sorted keys, so
        the key order of the payload does not influence the resulting ID.

        Args:
            emb_id: An explicit ID. Takes precedence over `payload`.
            payload: The payload to derive the ID from if `emb_id` is None.

        Returns:
            The resolved ID.

        Raises:
            ValueError: If neither `emb_id` nor `payload` is provided.
        """
        if emb_id is not None:
            return emb_id
        if payload is None:
            raise ValueError("Either emb_id or payload must be provided.")
        return json.dumps(payload, sort_keys=True)

    def add(self, embedding: E, payload: T, emb_id: Optional[str] = None) -> None:
        """Store an embedding together with its payload.

        Args:
            embedding: The embedding to store.
            payload: The payload to store alongside the embedding.
            emb_id: The ID to store the entry under. If None, the ID is derived from
                the payload (see `_get_emb_id`).

        Raises:
            ValueError: If neither `emb_id` nor `payload` is provided.
        """
        # Resolve the ID with the same helper as `get` and `retrieve_similar`, so all
        # entry points derive IDs identically.
        resolved_id = self._get_emb_id(emb_id=emb_id, payload=payload)
        self._add(embedding=embedding, payload=payload, emb_id=resolved_id)

    def get(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> Optional[E]:
        """Return the embedding for an ID, or for the ID derived from `payload`.

        Raises:
            ValueError: If neither `emb_id` nor `payload` is provided.
        """
        return self._get(emb_id=self._get_emb_id(emb_id=emb_id, payload=payload))

    @abc.abstractmethod
    def _retrieve_similar(
        self, ref_id: str, top_k: Optional[int] = None, min_similarity: Optional[float] = None
    ) -> List[Tuple[str, T, float]]:
        """Retrieve IDs, payloads and the respective similarity scores with respect to the
        reference entry. Note that this requires the reference entry to be present in the store.

        Args:
            ref_id: The ID of the reference entry.
            top_k: If provided, only the top-k most similar entries will be returned.
            min_similarity: If provided, only entries with a similarity score greater or equal to
                this value will be returned.

        Returns:
            A list of tuples consisting of the ID, the payload and the similarity score,
            sorted by similarity score in descending order.
        """
        pass

    def retrieve_similar(
        self, ref_id: Optional[str] = None, ref_payload: Optional[T] = None, **kwargs
    ) -> List[Tuple[str, T, float]]:
        """Retrieve the entries most similar to a reference entry.

        Args:
            ref_id: The ID of the reference entry.
            ref_payload: The payload to derive the reference ID from if `ref_id` is None.
            **kwargs: Forwarded unchanged to `_retrieve_similar`.

        Returns:
            Whatever `_retrieve_similar` returns for the resolved reference ID.

        Raises:
            ValueError: If neither `ref_id` nor `ref_payload` is provided.
        """
        return self._retrieve_similar(
            ref_id=self._get_emb_id(emb_id=ref_id, payload=ref_payload), **kwargs
        )

    @abc.abstractmethod
    def __len__(self) -> int:
        """Return the number of entries in the store."""
        pass

    def save_to_directory(self, directory: str) -> None:
        """Save the vector store to a directory.

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError

    def load_from_directory(self, directory: str, replace: bool = False) -> None:
        """Load the vector store from a directory.

        If `replace` is True, the current content of the store will be replaced.

        Raises:
            NotImplementedError: Unless overridden by a subclass.
        """
        raise NotImplementedError
|
|
|
|
|
def vector_norm(vector: List[float]) -> float:
    """Return the Euclidean (L2) norm of `vector`."""
    squared_components = [component**2 for component in vector]
    return sum(squared_components) ** 0.5
|
|
|
|
|
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Return the cosine similarity of the vectors `a` and `b`.

    The dot product is taken over the pairs produced by `zip(a, b)` and divided by
    the product of the norms of both vectors.

    Args:
        a: The first vector.
        b: The second vector.

    Returns:
        The cosine similarity, or 0.0 if the product of the norms is zero (e.g. one
        of the vectors contains only zeros), where the similarity is undefined.
    """
    norm_product = vector_norm(a) * vector_norm(b)
    if norm_product == 0:
        # Undefined for zero-norm vectors; report "not similar" rather than raising
        # ZeroDivisionError.
        return 0.0
    dot_product = sum(x * y for x, y in zip(a, b))
    return dot_product / norm_product
|
|
|
|
|
class SimpleVectorStore(VectorStore[T, List[float]]):
    """In-memory vector store backed by plain dictionaries.

    Similarities are computed with `cosine_similarity` and memoized per
    (entry ID, reference ID) pair. The store can be written to and read from a
    directory holding three files: the IDs, the embeddings and the payloads.
    """

    INDEX_FILE = "vectors_index.json"
    EMBEDDINGS_FILE = "vectors_data.npy"
    PAYLOADS_FILE = "vectors_payloads.json"

    def __init__(self):
        self.vectors: dict[str, List[float]] = {}
        self.payloads: dict[str, T] = {}
        # Memoized similarity scores, keyed by (entry ID, reference ID).
        self._cache: dict[Tuple[str, str], float] = {}
        self._sim = cosine_similarity

    def _invalidate_cache(self, emb_id: str) -> None:
        """Drop every memoized similarity that involves `emb_id`."""
        self._cache = {key: sim for key, sim in self._cache.items() if emb_id not in key}

    def _add(self, embedding: List[float], payload: T, emb_id: str) -> None:
        """Store the embedding and payload under `emb_id`, replacing any existing entry."""
        if emb_id in self.vectors:
            # The embedding is about to change, so similarities memoized for the old
            # embedding would be stale.
            self._invalidate_cache(emb_id)
        self.vectors[emb_id] = embedding
        self.payloads[emb_id] = payload

    def _get(self, emb_id: str) -> Optional[List[float]]:
        """Return the embedding stored under `emb_id`, or None if there is none."""
        return self.vectors.get(emb_id)

    def delete(self, emb_id: str) -> None:
        """Remove the entry with the given ID (if present) and its memoized similarities."""
        if emb_id in self.vectors:
            del self.vectors[emb_id]
            del self.payloads[emb_id]
        self._invalidate_cache(emb_id)

    def clear(self) -> None:
        """Remove all entries and all memoized similarities."""
        self.vectors.clear()
        self._cache.clear()
        self.payloads.clear()

    def __len__(self) -> int:
        """Return the number of stored embeddings."""
        return len(self.vectors)

    def _retrieve_similar(
        self, ref_id: str, top_k: Optional[int] = None, min_similarity: Optional[float] = None
    ) -> List[Tuple[str, T, float]]:
        """Rank all stored entries by their similarity to the reference entry.

        Every stored entry is scored, so the reference entry itself is part of the
        candidates.

        Args:
            ref_id: The ID of the reference entry.
            top_k: If provided, only the top-k most similar entries will be returned.
            min_similarity: If provided, only entries with a similarity score greater or
                equal to this value will be returned.

        Returns:
            A list of (ID, payload, similarity score) tuples, sorted by similarity score
            in descending order.

        Raises:
            ValueError: If there is no entry with ID `ref_id`.
        """
        ref_embedding = self._get(ref_id)
        if ref_embedding is None:
            raise ValueError(f"Reference embedding '{ref_id}' not found.")

        scored: List[Tuple[str, float]] = []
        for emb_id, embedding in self.vectors.items():
            cache_key = (emb_id, ref_id)
            if cache_key not in self._cache:
                self._cache[cache_key] = self._sim(ref_embedding, embedding)
            sim = self._cache[cache_key]
            if min_similarity is None or sim >= min_similarity:
                scored.append((emb_id, sim))

        # The sort is stable, so entries with equal scores keep their insertion order.
        scored.sort(key=lambda entry: entry[1], reverse=True)
        if top_k is not None:
            scored = scored[:top_k]

        return [(emb_id, self.payloads[emb_id], sim) for emb_id, sim in scored]

    def save_to_directory(self, directory: str) -> None:
        """Write IDs, embeddings and payloads to `directory`, creating it if necessary.

        The three files are aligned by position: the i-th ID, the i-th embedding row
        and the i-th payload belong to the same entry.
        """
        os.makedirs(directory, exist_ok=True)
        indices = list(self.vectors)
        with open(os.path.join(directory, self.INDEX_FILE), "w", encoding="utf-8") as f:
            json.dump(indices, f)
        embeddings_np = np.array([self.vectors[emb_id] for emb_id in indices])
        np.save(os.path.join(directory, self.EMBEDDINGS_FILE), embeddings_np)
        payloads = [self.payloads[emb_id] for emb_id in indices]
        with open(os.path.join(directory, self.PAYLOADS_FILE), "w", encoding="utf-8") as f:
            json.dump(payloads, f)

    def load_from_directory(self, directory: str, replace: bool = False) -> None:
        """Load entries previously written by `save_to_directory`.

        Args:
            directory: The directory to read the three store files from.
            replace: If True, the current content of the store is replaced. Otherwise
                the loaded entries are added, overwriting entries with the same ID.

        Raises:
            ValueError: If the files do not contain the same number of entries.
        """
        # Read everything before touching the store, so a failing load does not wipe
        # the current content when `replace` is True.
        with open(os.path.join(directory, self.INDEX_FILE), "r", encoding="utf-8") as f:
            index = json.load(f)
        embeddings_np = np.load(os.path.join(directory, self.EMBEDDINGS_FILE))
        with open(os.path.join(directory, self.PAYLOADS_FILE), "r", encoding="utf-8") as f:
            payloads = json.load(f)

        if not (len(index) == len(embeddings_np) == len(payloads)):
            raise ValueError(
                f"Inconsistent vector store files in '{directory}': {len(index)} IDs, "
                f"{len(embeddings_np)} embeddings and {len(payloads)} payloads."
            )

        if replace:
            self.clear()
        for emb_id, emb, payload in zip(index, embeddings_np, payloads):
            # Go through _add so that memoized similarities of overwritten entries are
            # invalidated.
            self._add(embedding=emb.tolist(), payload=payload, emb_id=emb_id)
|
|
|
|
|
class QdrantVectorStore(VectorStore[T, List[float]]):
    """Vector store that keeps embeddings and payloads in a Qdrant collection.

    Points are stored under integer IDs; `id2idx` and `idx2id` translate between the
    string IDs used by the `VectorStore` interface and those integer point IDs.
    """

    COLLECTION_NAME = "ADUs"
    # Number of results requested from the client when no `top_k` is given.
    MAX_LIMIT = 100

    def __init__(
        self,
        location: str = ":memory:",
        vector_size: int = 768,
        distance: Distance = Distance.COSINE,
    ):
        """Create the client and the collection.

        Args:
            location: Passed to `QdrantClient` as `location`.
            vector_size: The size of the stored vectors.
            distance: The distance used by the collection.
        """
        self.client = QdrantClient(location=location)
        self.id2idx: dict[str, int] = {}
        self.idx2id: dict[int, str] = {}
        # Remembered so that `clear` can recreate the collection with the same
        # configuration without reading it back from the client.
        self._vectors_config = VectorParams(size=vector_size, distance=distance)
        self.client.create_collection(
            collection_name=self.COLLECTION_NAME,
            vectors_config=self._vectors_config,
        )

    def __len__(self) -> int:
        """Return the point count reported by the client for the collection."""
        return self.client.get_collection(collection_name=self.COLLECTION_NAME).points_count

    def _add(self, emb_id: str, payload: T, embedding: List[float]) -> None:
        """Upsert the embedding and payload under `emb_id`.

        Note that the parameter order differs from `VectorStore._add`; `VectorStore.add`
        passes all arguments by keyword.
        """
        # Reuse the point ID of a known entry so that re-adding replaces its point.
        # Always taking len(self.id2idx) would leave the old point behind and let the
        # next new entry collide with the re-added one.
        _id = self.id2idx.get(emb_id)
        if _id is None:
            _id = len(self.id2idx)
        self.client.upsert(
            collection_name=self.COLLECTION_NAME,
            points=[PointStruct(id=_id, vector=embedding, payload=payload)],
        )
        self.id2idx[emb_id] = _id
        self.idx2id[_id] = emb_id

    def _get(self, emb_id: str) -> Optional[List[float]]:
        """Return the vector stored under `emb_id`, or None if the ID is unknown.

        Raises:
            ValueError: If the client returns more than one point for the ID.
        """
        idx = self.id2idx.get(emb_id)
        if idx is None:
            return None
        points = self.client.retrieve(
            collection_name=self.COLLECTION_NAME,
            ids=[idx],
            with_vectors=True,
        )
        if not points:
            return None
        if len(points) > 1:
            raise ValueError(f"Multiple points found for ID '{emb_id}'.")
        return points[0].vector

    def _retrieve_similar(
        self, ref_id: str, top_k: Optional[int] = None, min_similarity: Optional[float] = None
    ) -> List[Tuple[str, T, float]]:
        """Retrieve the entries most similar to the reference entry via `client.recommend`.

        Args:
            ref_id: The ID of the reference entry.
            top_k: The maximum number of entries to return. If None, `MAX_LIMIT` is used.
            min_similarity: Passed to the client as `score_threshold`.

        Returns:
            A list of (ID, payload, score) tuples in the order returned by the client.

        Raises:
            ValueError: If there is no entry with ID `ref_id`.
        """
        ref_idx = self.id2idx.get(ref_id)
        if ref_idx is None:
            raise ValueError(f"Reference embedding '{ref_id}' not found.")
        if top_k == 0:
            # An explicit zero means "no results"; it must not fall back to MAX_LIMIT.
            return []
        similar_entries = self.client.recommend(
            collection_name=self.COLLECTION_NAME,
            positive=[ref_idx],
            limit=self.MAX_LIMIT if top_k is None else top_k,
            score_threshold=min_similarity,
        )
        return [(self.idx2id[entry.id], entry.payload, entry.score) for entry in similar_entries]

    def clear(self) -> None:
        """Remove all entries by deleting and recreating the collection."""
        self.client.delete_collection(collection_name=self.COLLECTION_NAME)
        self.client.create_collection(
            collection_name=self.COLLECTION_NAME,
            vectors_config=self._vectors_config,
        )
        self.id2idx.clear()
        self.idx2id.clear()
|
|