File size: 2,583 Bytes
4467900
25fcabc
 
4467900
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25fcabc
 
 
 
 
 
 
 
 
4467900
25fcabc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import abc
from typing import Generic, Hashable, List, Optional, Tuple, TypeVar

# Type of the ids under which embeddings are stored. Bound to Hashable because ids
# are used as dict keys (and inside tuple cache keys) by the store implementations.
T = TypeVar("T", bound=Hashable)


class VectorStore(Generic[T], abc.ABC):
    """Abstract interface for a store of embedding vectors keyed by hashable ids.

    Concrete stores must implement :meth:`save`, :meth:`retrieve_similar`
    and :meth:`__len__`; the class cannot be instantiated until all three exist.
    """

    @abc.abstractmethod
    def save(self, emb_id: T, embedding: List[float]) -> None:
        """Store ``embedding`` under the id ``emb_id``."""

    @abc.abstractmethod
    def retrieve_similar(
        self, ref_id: T, top_k: Optional[int] = None, min_similarity: Optional[float] = None
    ) -> List[Tuple[T, float]]:
        """Return ``(emb_id, similarity)`` pairs measured against the embedding of ``ref_id``.

        ``top_k`` and ``min_similarity`` optionally limit the result; their exact
        semantics are defined by each implementation.
        """

    @abc.abstractmethod
    def __len__(self):
        """Return how many embeddings the store holds."""


def vector_norm(vector: List[float]) -> float:
    """Return the Euclidean (L2) length of ``vector``; an empty vector yields 0.0."""
    squared_total = sum(map(lambda component: component ** 2, vector))
    return squared_total ** 0.5


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Return the cosine of the angle between vectors ``a`` and ``b``.

    Args:
        a: First vector.
        b: Second vector; must have the same number of components as ``a``.

    Returns:
        ``dot(a, b) / (|a| * |b|)``, which lies in ``[-1, 1]`` up to float
        rounding. If either vector has zero norm (including an empty vector),
        the cosine is undefined and 0.0 is returned instead of raising
        ``ZeroDivisionError``.

    Raises:
        ValueError: If ``a`` and ``b`` differ in length. ``zip`` would otherwise
            silently truncate the longer vector and produce a meaningless score.
    """
    if len(a) != len(b):
        raise ValueError(f"Vectors must have the same length, got {len(a)} and {len(b)}.")
    denominator = vector_norm(a) * vector_norm(b)
    if denominator == 0:
        # Zero vector: no direction, so report "no similarity" rather than crash.
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    return dot / denominator


class SimpleVectorStore(VectorStore[T]):
    """In-memory :class:`VectorStore` backed by a plain dict.

    Similarities computed by :meth:`retrieve_similar` are memoized in
    ``self._cache`` under the key ``(emb_id, ref_id)``. Cached entries that
    involve an id are dropped whenever that id is overwritten via :meth:`save`
    or removed via :meth:`delete`; :meth:`clear` empties the cache entirely.

    NOTE(review): embeddings are stored by reference. Mutating a saved list in
    place is not detected and can leave stale cached similarities; call
    :meth:`save` again with the updated vector instead.
    """

    def __init__(self):
        self.vectors: dict[T, List[float]] = {}
        # (emb_id, ref_id) -> similarity. Unbounded: grows with every distinct pair queried.
        self._cache: dict[Tuple[T, T], float] = {}
        # Similarity function, called as self._sim(ref_embedding, other_embedding).
        self._sim = cosine_similarity

    def _invalidate(self, emb_id: T) -> None:
        """Drop every cached similarity whose key pair involves ``emb_id``."""
        self._cache = {pair: sim for pair, sim in self._cache.items() if emb_id not in pair}

    def save(self, emb_id: T, embedding: List[float]) -> None:
        """Store ``embedding`` under ``emb_id``, replacing any previous vector.

        When an existing id is overwritten, its cached similarities are discarded
        so later :meth:`retrieve_similar` calls use the new vector.
        """
        if emb_id in self.vectors:
            self._invalidate(emb_id)
        self.vectors[emb_id] = embedding

    def get(self, emb_id: T) -> Optional[List[float]]:
        """Return the embedding stored under ``emb_id``, or ``None`` if absent."""
        return self.vectors.get(emb_id)

    def delete(self, emb_id: T) -> None:
        """Remove ``emb_id`` and its cached similarities; unknown ids are ignored."""
        if emb_id in self.vectors:
            del self.vectors[emb_id]
            self._invalidate(emb_id)

    def clear(self) -> None:
        """Remove all embeddings and all cached similarities."""
        self.vectors.clear()
        self._cache.clear()

    def __len__(self) -> int:
        return len(self.vectors)

    def retrieve_similar(
        self, ref_id: T, top_k: Optional[int] = None, min_similarity: Optional[float] = None
    ) -> List[Tuple[T, float]]:
        """Rank stored embeddings by similarity to the embedding stored under ``ref_id``.

        Args:
            ref_id: Id of a stored embedding to compare against. Its own entry is
                part of the result, because it is scored like any other entry.
            top_k: If given, keep at most this many results after filtering.
            min_similarity: If given, drop results whose similarity is below it.

        Returns:
            ``(emb_id, similarity)`` pairs, highest similarity first. Ties keep
            the insertion order of ``self.vectors``.

        Raises:
            ValueError: If no embedding is stored under ``ref_id``.
        """
        ref_embedding = self.get(ref_id)
        if ref_embedding is None:
            raise ValueError(f"Reference embedding '{ref_id}' not found.")

        scored: List[Tuple[T, float]] = []
        for emb_id, embedding in self.vectors.items():
            key = (emb_id, ref_id)
            if key not in self._cache:
                self._cache[key] = self._sim(ref_embedding, embedding)
            similarity = self._cache[key]
            # Filter before sorting so the sort only processes surviving entries.
            if min_similarity is None or similarity >= min_similarity:
                scored.append((emb_id, similarity))

        # Stable sort: equal similarities keep their insertion order.
        scored.sort(key=lambda item: item[1], reverse=True)
        if top_k is not None:
            scored = scored[:top_k]
        return scored