Spaces:
Runtime error
Runtime error
"""NumpyVectorStore class for storing vectors in numpy arrays.""" | |
from typing import Iterable, Optional, cast | |
import numpy as np | |
import pandas as pd | |
from typing_extensions import override | |
from ..schema import VectorKey | |
from .vector_store import VectorStore | |
class NumpyVectorStore(VectorStore): | |
"""Stores vectors as in-memory np arrays.""" | |
_embeddings: np.ndarray | |
_keys: list[VectorKey] | |
# Maps a `VectorKey` to a row index in `_embeddings`. | |
_lookup: pd.Series | |
def keys(self) -> list[VectorKey]: | |
return self._keys | |
def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None: | |
if hasattr(self, '_embeddings') or hasattr(self, '_keys'): | |
raise ValueError('Embeddings already exist in this store. Upsert is not yet supported.') | |
if len(keys) != embeddings.shape[0]: | |
raise ValueError( | |
f'Length of keys ({len(keys)}) does not match number of embeddings {embeddings.shape[0]}.') | |
self._keys = keys | |
# Cast to float32 since dot product with float32 is 40-50x faster than float16 and 2.5x faster | |
# than float64. | |
self._embeddings = embeddings.astype(np.float32) | |
index = pd.MultiIndex.from_tuples(keys) | |
row_indices = np.arange(len(self._embeddings), dtype=np.uint32) | |
self._lookup = pd.Series(row_indices, index=index) | |
def get(self, keys: Iterable[VectorKey]) -> np.ndarray: | |
"""Return the embeddings for given keys. | |
Args: | |
keys: The keys to return the embeddings for. | |
Returns | |
The embeddings for the given keys. | |
""" | |
return self._embeddings.take(self._lookup.loc[keys], axis=0) | |
def topk(self, | |
query: np.ndarray, | |
k: int, | |
key_prefixes: Optional[Iterable[VectorKey]] = None) -> list[tuple[VectorKey, float]]: | |
if key_prefixes is not None: | |
# Cast tuples of length 1 to the element itself to avoid a pandas bug. | |
key_prefixes = cast( | |
list[VectorKey], | |
[k[0] if isinstance(k, tuple) and len(k) == 1 else k for k in key_prefixes]) | |
# This uses the hierarchical index (MutliIndex) to do a prefix lookup. | |
row_indices = self._lookup.loc[key_prefixes] | |
keys, embeddings = list(row_indices.index), self._embeddings.take(row_indices, axis=0) | |
else: | |
keys, embeddings = self._keys, self._embeddings | |
query = query.astype(embeddings.dtype) | |
similarities: np.ndarray = np.dot(embeddings, query).reshape(-1) | |
k = min(k, len(similarities)) | |
# We do a partition + sort only top K to save time: O(n + klogk) instead of O(nlogn). | |
indices = np.argpartition(similarities, -k)[-k:] | |
# Indices sorted by value from largest to smallest. | |
indices = indices[np.argsort(similarities[indices])][::-1] | |
topk_similarities = similarities[indices] | |
topk_keys = [keys[idx] for idx in indices] | |
return list(zip(topk_keys, topk_similarities)) | |