Spaces:
Sleeping
Sleeping
from typing import Any, Callable, Dict, List, Optional, Sequence, Set | |
import numpy as np | |
import numpy.typing as npt | |
from chromadb.types import ( | |
EmbeddingRecord, | |
VectorEmbeddingRecord, | |
VectorQuery, | |
VectorQueryResult, | |
) | |
from chromadb.utils import distance_functions | |
import logging | |
logger = logging.getLogger(__name__) | |
class BruteForceIndex:
    """A lightweight, numpy based brute force index that is used for batches that
    have not been indexed into hnsw yet. It is not thread safe and callers should
    ensure that only one thread is accessing it at a time.

    Vectors are stored in a preallocated ``(size, dimensionality)`` array. Each
    stored ID owns one row of that array; rows that are not in use are tracked in
    ``free_indices`` and handed out again on later inserts.
    """

    # Maps a stored ID to the row of ``vectors`` holding its embedding.
    id_to_index: Dict[str, int]
    # Reverse mapping of ``id_to_index``; only occupied rows have an entry.
    index_to_id: Dict[int, str]
    # Maps a stored ID to the seq_id of the record that last wrote it.
    id_to_seq_id: Dict[str, int]
    # IDs removed through ``delete`` and not re-added since.
    deleted_ids: Set[str]
    # Rows of ``vectors`` currently available for new IDs.
    free_indices: List[int]
    # Capacity (number of rows) fixed at construction time.
    size: int
    dimensionality: int
    # Called as ``distance_fn(stored_vector, query_vector)``.
    distance_fn: Callable[[npt.NDArray[Any], npt.NDArray[Any]], float]
    vectors: npt.NDArray[Any]

    def __init__(self, size: int, dimensionality: int, space: str = "l2"):
        """Create an empty index.

        Args:
            size: Maximum number of distinct IDs the index can hold.
            dimensionality: Length of every stored embedding.
            space: Distance function to use: ``"l2"``, ``"ip"`` or ``"cosine"``.

        Raises:
            Exception: If ``space`` is not one of the supported names.
        """
        if space == "l2":
            self.distance_fn = distance_functions.l2
        elif space == "ip":
            self.distance_fn = distance_functions.ip
        elif space == "cosine":
            self.distance_fn = distance_functions.cosine
        else:
            raise Exception(f"Unknown distance function: {space}")

        self.id_to_index = {}
        self.index_to_id = {}
        self.id_to_seq_id = {}
        self.deleted_ids = set()
        self.free_indices = list(range(size))
        self.size = size
        self.dimensionality = dimensionality
        self.vectors = np.zeros((size, dimensionality))

    def __len__(self) -> int:
        """Return the number of IDs currently stored."""
        return len(self.id_to_index)

    def clear(self) -> None:
        """Drop every entry and mark all rows as free again."""
        self.id_to_index = {}
        self.index_to_id = {}
        self.id_to_seq_id = {}
        self.deleted_ids.clear()
        self.free_indices = list(range(self.size))
        self.vectors.fill(0)

    def upsert(self, records: List[EmbeddingRecord]) -> None:
        """Insert new records and overwrite the embedding of existing ones.

        Each record is read through its ``"id"``, ``"embedding"`` and
        ``"seq_id"`` keys. The capacity check runs before any record is
        applied, so a rejected batch leaves the index unchanged.

        Raises:
            Exception: If the batch contains more new IDs than there is room for.
        """
        # Only IDs that are not stored yet consume a row; updates to existing
        # IDs (and repeats of the same new ID within the batch) must not count
        # against the capacity, otherwise a full index could not be updated.
        new_ids = {
            record["id"] for record in records if record["id"] not in self.id_to_index
        }
        if len(new_ids) + len(self) > self.size:
            raise Exception(
                "Index with capacity {} and {} current entries cannot add {} records".format(
                    self.size, len(self), len(records)
                )
            )

        for record in records:
            record_id = record["id"]
            vector = record["embedding"]
            self.id_to_seq_id[record_id] = record["seq_id"]
            # Re-adding a previously deleted ID makes it visible again.
            self.deleted_ids.discard(record_id)

            # TODO: It may be faster to use multi-index selection on the vectors array
            if record_id in self.id_to_index:
                # Update
                self.vectors[self.id_to_index[record_id]] = vector
            else:
                # Add
                next_index = self.free_indices.pop()
                self.id_to_index[record_id] = next_index
                self.index_to_id[next_index] = record_id
                self.vectors[next_index] = vector

    def delete(self, records: List[EmbeddingRecord]) -> None:
        """Remove the given records by ID and release their rows.

        Only the ``"id"`` key of each record is read. IDs that are not stored
        are skipped with a warning.
        """
        for record in records:
            record_id = record["id"]
            if record_id not in self.id_to_index:
                logger.warning(f"Delete of nonexisting embedding ID: {record_id}")
                continue

            index = self.id_to_index[record_id]
            self.deleted_ids.add(record_id)
            del self.id_to_index[record_id]
            del self.index_to_id[index]
            del self.id_to_seq_id[record_id]
            # Overwrite the released row with NaN. ``np.nan`` is used because
            # the ``np.NaN`` alias was removed in NumPy 2.0.
            self.vectors[index].fill(np.nan)
            self.free_indices.append(index)

    def has_id(self, id: str) -> bool:
        """Returns whether the index contains the given ID"""
        return id in self.id_to_index and id not in self.deleted_ids

    def get_vectors(
        self, ids: Optional[Sequence[str]] = None
    ) -> Sequence[VectorEmbeddingRecord]:
        """Return the stored embedding and seq_id for the requested IDs.

        Args:
            ids: IDs to fetch. When ``None`` or empty, every stored entry is
                returned.

        Raises:
            KeyError: If a requested ID is not stored in the index.
        """
        target_ids = ids or self.id_to_index.keys()
        return [
            VectorEmbeddingRecord(
                id=target_id,
                embedding=self.vectors[self.id_to_index[target_id]].tolist(),
                seq_id=self.id_to_seq_id[target_id],
            )
            for target_id in target_ids
        ]

    def query(self, query: VectorQuery) -> Sequence[Sequence[VectorQueryResult]]:
        """Rank the stored entries against each query vector.

        Reads ``query["vectors"]`` and ``query["allowed_ids"]``. For every query
        vector the result list holds all stored, non-deleted entries (restricted
        to ``allowed_ids`` when it is not ``None``) in the order produced by
        ``np.argsort`` over the distances. No truncation is applied here.
        """
        np_query = np.array(query["vectors"])
        allowed_ids = (
            None if query["allowed_ids"] is None else set(query["allowed_ids"])
        )

        # distances[i][j] is distance_fn(self.vectors[j], query vector i). It is
        # computed for every row of the array, including rows that are not
        # occupied; those are filtered out below.
        distances = np.apply_along_axis(
            lambda query_vector: np.apply_along_axis(
                self.distance_fn, 1, self.vectors, query_vector
            ),
            1,
            np_query,
        )
        ranked_rows = np.argsort(distances).tolist()

        # Filter out deleted labels
        filtered_results = []
        for query_pos, row_order in enumerate(ranked_rows):
            curr_results = []
            for row in row_order:
                # If the index is in the index_to_id map, then it has been added
                if row not in self.index_to_id:
                    continue
                result_id = self.index_to_id[row]
                if result_id in self.deleted_ids:
                    continue
                if allowed_ids is not None and result_id not in allowed_ids:
                    continue
                curr_results.append(
                    VectorQueryResult(
                        id=result_id,
                        distance=distances[query_pos][row].item(),
                        seq_id=self.id_to_seq_id[result_id],
                        embedding=self.vectors[row].tolist(),
                    )
                )
            filtered_results.append(curr_results)
        return filtered_results