sam-pointer-bart-base-v0.3 / vector_store.py
ArneBinder's picture
Upload 7 files
4467900 verified
raw
history blame
No virus
2.58 kB
import abc
from typing import Generic, Hashable, List, Optional, Tuple, TypeVar
# Identifier type for stored embeddings; constrained to be hashable.
T = TypeVar("T", bound=Hashable)


class VectorStore(Generic[T], abc.ABC):
    """Abstract interface for a collection of embedding vectors keyed by ids of type ``T``.

    Concrete subclasses must implement :meth:`save`, :meth:`retrieve_similar`
    and ``__len__``; the class cannot be instantiated otherwise.
    """

    @abc.abstractmethod
    def save(self, emb_id: T, embedding: List[float]) -> None:
        """Store ``embedding`` under the identifier ``emb_id``."""

    @abc.abstractmethod
    def retrieve_similar(
        self, ref_id: T, top_k: Optional[int] = None, min_similarity: Optional[float] = None
    ) -> List[Tuple[T, float]]:
        """Return ``(emb_id, similarity)`` pairs for entries similar to ``ref_id``.

        ``top_k`` and ``min_similarity`` optionally bound the number of results
        and the minimum accepted score; the exact semantics are left to the
        implementing subclass.
        """

    @abc.abstractmethod
    def __len__(self):
        """Return the number of stored embeddings."""
def vector_norm(vector: List[float]) -> float:
    """Return the Euclidean (L2) norm of ``vector``."""
    return sum(x**2 for x in vector) ** 0.5


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Return the cosine similarity ``dot(a, b) / (|a| * |b|)`` of two vectors.

    Args:
        a: first vector.
        b: second vector; must have the same length as ``a``.

    Returns:
        The cosine similarity, in ``[-1, 1]`` up to rounding. If either vector
        has zero norm (including empty vectors) the similarity is undefined and
        ``0.0`` is returned instead of raising ``ZeroDivisionError``.

    Raises:
        ValueError: if ``a`` and ``b`` have different lengths. ``zip`` would
            otherwise silently truncate the dot product while the norms still
            cover the full vectors, producing a meaningless score.
    """
    if len(a) != len(b):
        raise ValueError(
            f"Cannot compare vectors of different lengths ({len(a)} != {len(b)})."
        )
    denominator = vector_norm(a) * vector_norm(b)
    if denominator == 0:
        # Zero vector: treat as dissimilar to everything rather than crashing.
        return 0.0
    dot_product = sum(x * y for x, y in zip(a, b))
    return dot_product / denominator
class SimpleVectorStore(VectorStore[T]):
    """In-memory :class:`VectorStore` that ranks entries with ``cosine_similarity``.

    Pairwise similarities computed by :meth:`retrieve_similar` are memoised in
    ``self._cache`` under ``(emb_id, ref_id)`` keys. Cached entries involving an
    id are discarded when that id is overwritten or deleted, and the whole cache
    is emptied by :meth:`clear`, so cached scores never outlive the embeddings
    they were computed from.
    """

    def __init__(self):
        # Maps embedding id -> embedding vector.
        self.vectors: dict[T, List[float]] = {}
        # Memoised similarities keyed by (emb_id, ref_id).
        self._cache: dict[Tuple[T, T], float] = {}
        # Similarity function used for scoring.
        self._sim = cosine_similarity

    def save(self, emb_id: T, embedding: List[float]) -> None:
        """Store ``embedding`` under ``emb_id``, replacing any existing embedding.

        When an existing embedding is replaced, cached similarities involving
        ``emb_id`` are dropped so subsequent queries use the new vector.
        """
        if emb_id in self.vectors:
            self._invalidate_cache(emb_id)
        self.vectors[emb_id] = embedding

    def get(self, emb_id: T) -> Optional[List[float]]:
        """Return the embedding stored under ``emb_id``, or ``None`` if absent."""
        return self.vectors.get(emb_id)

    def delete(self, emb_id: T) -> None:
        """Remove ``emb_id`` and its cached similarities; a no-op for unknown ids."""
        if emb_id in self.vectors:
            del self.vectors[emb_id]
        self._invalidate_cache(emb_id)

    def clear(self) -> None:
        """Remove all embeddings and all cached similarities."""
        self.vectors.clear()
        self._cache.clear()

    def __len__(self):
        """Return the number of stored embeddings."""
        return len(self.vectors)

    def _invalidate_cache(self, emb_id: T) -> None:
        """Drop every cached similarity whose key pair contains ``emb_id``."""
        self._cache = {k: v for k, v in self._cache.items() if emb_id not in k}

    def retrieve_similar(
        self, ref_id: T, top_k: Optional[int] = None, min_similarity: Optional[float] = None
    ) -> List[Tuple[T, float]]:
        """Return stored ids ranked by similarity to the embedding of ``ref_id``.

        Every stored embedding is scored, including ``ref_id`` itself.

        Args:
            ref_id: id of the reference embedding; it must already be stored.
            top_k: if given, return at most this many entries.
            min_similarity: if given, drop entries whose score is below it.

        Returns:
            ``(emb_id, similarity)`` pairs sorted by descending similarity; ties
            keep the insertion order of ``self.vectors``.

        Raises:
            ValueError: if no embedding is stored under ``ref_id``.
        """
        ref_embedding = self.get(ref_id)
        if ref_embedding is None:
            raise ValueError(f"Reference embedding '{ref_id}' not found.")

        similarities: List[Tuple[T, float]] = []
        for emb_id, embedding in self.vectors.items():
            key = (emb_id, ref_id)
            if key not in self._cache:
                self._cache[key] = self._sim(ref_embedding, embedding)
            similarities.append((emb_id, self._cache[key]))

        # Filtering first is equivalent (the sort is stable) and sorts fewer items.
        if min_similarity is not None:
            similarities = [
                (emb_id, sim) for emb_id, sim in similarities if sim >= min_similarity
            ]
        similarities.sort(key=lambda item: item[1], reverse=True)
        if top_k is not None:
            similarities = similarities[:top_k]
        return similarities