File size: 5,521 Bytes
287a0bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from typing import Any, Callable, Dict, List, Optional, Sequence, Set
import numpy as np
import numpy.typing as npt
from chromadb.types import (
    EmbeddingRecord,
    VectorEmbeddingRecord,
    VectorQuery,
    VectorQueryResult,
)

from chromadb.utils import distance_functions
import logging

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class BruteForceIndex:
    """A lightweight, numpy based brute force index that is used for batches that
    have not been indexed into hnsw yet.

    Vectors are stored in a preallocated ``(size, dimensionality)`` array. Every
    ID is mapped to one row of that array; rows released by ``delete`` go back to
    ``free_indices`` and are reused by later inserts.

    It is not thread safe and callers should ensure that only one thread is
    accessing it at a time.
    """

    # ID -> row of ``vectors`` holding its embedding.
    id_to_index: Dict[str, int]
    # Row of ``vectors`` -> ID stored there (only occupied rows are present).
    index_to_id: Dict[int, str]
    # ID -> seq_id of the most recent record upserted for that ID.
    id_to_seq_id: Dict[str, int]
    # IDs removed via ``delete`` and not upserted again since.
    deleted_ids: Set[str]
    # Rows of ``vectors`` that are currently unoccupied.
    free_indices: List[int]
    # Capacity: number of rows in ``vectors``.
    size: int
    # Number of columns in ``vectors``.
    dimensionality: int
    # Called as ``distance_fn(stored_vector, query_vector)``.
    distance_fn: Callable[[npt.NDArray[Any], npt.NDArray[Any]], float]
    vectors: npt.NDArray[Any]

    def __init__(self, size: int, dimensionality: int, space: str = "l2"):
        """Create an empty index.

        Args:
            size: Maximum number of vectors the index can hold.
            dimensionality: Length of each stored vector.
            space: Name of the distance function: "l2", "ip" or "cosine".

        Raises:
            Exception: If ``space`` is not one of the supported names.
        """
        if space == "l2":
            self.distance_fn = distance_functions.l2
        elif space == "ip":
            self.distance_fn = distance_functions.ip
        elif space == "cosine":
            self.distance_fn = distance_functions.cosine
        else:
            raise Exception(f"Unknown distance function: {space}")

        self.id_to_index = {}
        self.index_to_id = {}
        self.id_to_seq_id = {}
        self.deleted_ids = set()
        self.free_indices = list(range(size))
        self.size = size
        self.dimensionality = dimensionality
        self.vectors = np.zeros((size, dimensionality))

    def __len__(self) -> int:
        """Return the number of IDs currently stored."""
        return len(self.id_to_index)

    def clear(self) -> None:
        """Drop every entry and mark all rows as free again."""
        self.id_to_index = {}
        self.index_to_id = {}
        self.id_to_seq_id = {}
        self.deleted_ids.clear()
        self.free_indices = list(range(self.size))
        self.vectors.fill(0)

    def upsert(self, records: List[EmbeddingRecord]) -> None:
        """Insert new records and overwrite the vectors of existing ones.

        Args:
            records: Records providing "id", "embedding" and "seq_id".

        Raises:
            Exception: If the records would introduce more new IDs than there
                are free rows. The index is left unmodified in that case.
        """
        # Only IDs that are not stored yet consume a free row. Updates, and
        # repeats of the same new ID within this batch, must not count against
        # the remaining capacity.
        new_ids = {
            record["id"] for record in records if record["id"] not in self.id_to_index
        }
        if len(new_ids) + len(self) > self.size:
            raise Exception(
                "Index with capacity {} and {} current entries cannot add {} records".format(
                    self.size, len(self), len(records)
                )
            )

        for record in records:
            record_id = record["id"]
            embedding = record["embedding"]
            self.id_to_seq_id[record_id] = record["seq_id"]
            # An ID that is upserted again is no longer considered deleted.
            self.deleted_ids.discard(record_id)

            # TODO: It may be faster to use multi-index selection on the vectors array
            index = self.id_to_index.get(record_id)
            if index is None:
                # Add: claim a free row for the new ID.
                index = self.free_indices.pop()
                self.id_to_index[record_id] = index
                self.index_to_id[index] = record_id
            self.vectors[index] = embedding

    def delete(self, records: List[EmbeddingRecord]) -> None:
        """Remove the given records' IDs from the index.

        Unknown IDs are logged as a warning and otherwise ignored.

        Args:
            records: Records whose "id" should be removed.
        """
        for record in records:
            record_id = record["id"]
            index = self.id_to_index.pop(record_id, None)
            if index is None:
                logger.warning("Delete of nonexisting embedding ID: %s", record_id)
                continue

            self.deleted_ids.add(record_id)
            del self.index_to_id[index]
            del self.id_to_seq_id[record_id]
            # ``np.nan`` rather than the ``np.NaN`` alias, which NumPy 2.0 removed.
            self.vectors[index].fill(np.nan)
            self.free_indices.append(index)

    def has_id(self, id: str) -> bool:
        """Returns whether the index contains the given ID"""
        return id in self.id_to_index and id not in self.deleted_ids

    def get_vectors(
        self, ids: Optional[Sequence[str]] = None
    ) -> Sequence[VectorEmbeddingRecord]:
        """Return the stored records for the given IDs.

        Args:
            ids: IDs to fetch. ``None`` or an empty sequence selects every
                stored ID.

        Raises:
            KeyError: If a requested ID is not stored in the index.
        """
        target_ids = ids or self.id_to_index.keys()

        return [
            VectorEmbeddingRecord(
                id=target_id,
                embedding=self.vectors[self.id_to_index[target_id]].tolist(),
                seq_id=self.id_to_seq_id[target_id],
            )
            for target_id in target_ids
        ]

    def query(self, query: VectorQuery) -> Sequence[Sequence[VectorQueryResult]]:
        """Rank the stored entries against each query vector.

        Args:
            query: Provides "vectors" (the query vectors) and "allowed_ids"
                (``None`` for no restriction, otherwise the only IDs that may
                be returned).

        Returns:
            One list per query vector, containing every stored, allowed entry
            ordered by ascending distance. No limit on the number of results
            is applied here.
        """
        np_query = np.array(query["vectors"])
        allowed_ids = (
            None if query["allowed_ids"] is None else set(query["allowed_ids"])
        )

        # ``np.apply_along_axis`` cannot iterate over an empty axis, so handle
        # "no query vectors" and "nothing stored" without computing distances.
        if len(np_query) == 0:
            return []
        if not self.index_to_id:
            return [[] for _ in range(len(np_query))]

        # distances[q][row] is the distance between query q and vectors[row].
        distances = np.apply_along_axis(
            lambda query_vector: np.apply_along_axis(
                self.distance_fn, 1, self.vectors, query_vector
            ),
            1,
            np_query,
        )

        ranked_rows = np.argsort(distances).tolist()
        filtered_results = []
        for query_pos, row_order in enumerate(ranked_rows):
            curr_results = []
            for row in row_order:
                # Rows missing from index_to_id are unused or were deleted.
                if row not in self.index_to_id:
                    continue
                result_id = self.index_to_id[row]
                if result_id in self.deleted_ids:
                    continue
                if allowed_ids is not None and result_id not in allowed_ids:
                    continue
                curr_results.append(
                    VectorQueryResult(
                        id=result_id,
                        distance=distances[query_pos][row].item(),
                        seq_id=self.id_to_seq_id[result_id],
                        embedding=self.vectors[row].tolist(),
                    )
                )
            filtered_results.append(curr_results)
        return filtered_results