nikhil_staging / lilac /embeddings /vector_store_numpy.py
nsthorat's picture
Push
56cce61
raw
history blame
No virus
2.89 kB
"""NumpyVectorStore class for storing vectors in numpy arrays."""
from typing import Iterable, Optional, cast
import numpy as np
import pandas as pd
from typing_extensions import override
from ..schema import VectorKey
from .vector_store import VectorStore
class NumpyVectorStore(VectorStore):
  """An in-memory vector store backed by a single numpy matrix.

  Embeddings are stored as one contiguous float32 array of shape
  (num_vectors, dim); a pandas Series with a MultiIndex maps each
  `VectorKey` tuple to its row in that array.
  """

  # The (num_vectors, dim) embedding matrix, cast to float32 on add().
  _embeddings: np.ndarray
  # Keys in the same row order as `_embeddings`.
  _keys: list[VectorKey]
  # Maps a `VectorKey` to a row index in `_embeddings`.
  _lookup: pd.Series

  @override
  def keys(self) -> list[VectorKey]:
    """Return all keys in the store, in row (insertion) order."""
    return self._keys

  @override
  def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
    """Add keys and their embeddings to the store.

    Args:
      keys: One key per row of `embeddings`.
      embeddings: An array of shape (len(keys), dim).

    Raises:
      ValueError: If the store already holds embeddings (upsert is not yet
        supported), or if `len(keys)` does not match the number of rows.
    """
    if hasattr(self, '_embeddings') or hasattr(self, '_keys'):
      raise ValueError('Embeddings already exist in this store. Upsert is not yet supported.')

    if len(keys) != embeddings.shape[0]:
      raise ValueError(
        f'Length of keys ({len(keys)}) does not match number of embeddings {embeddings.shape[0]}.')

    self._keys = keys
    # Cast to float32 since dot product with float32 is 40-50x faster than float16 and 2.5x faster
    # than float64.
    self._embeddings = embeddings.astype(np.float32)

    index = pd.MultiIndex.from_tuples(keys)
    row_indices = np.arange(len(self._embeddings), dtype=np.uint32)
    self._lookup = pd.Series(row_indices, index=index)

  @override
  def get(self, keys: Iterable[VectorKey]) -> np.ndarray:
    """Return the embeddings for given keys.

    Args:
      keys: The keys to return the embeddings for.

    Returns:
      The embeddings for the given keys, one row per key, in key order.

    Raises:
      KeyError: If any key is not present in the store.
    """
    return self._embeddings.take(self._lookup.loc[keys], axis=0)

  @override
  def topk(self,
           query: np.ndarray,
           k: int,
           key_prefixes: Optional[Iterable[VectorKey]] = None) -> list[tuple[VectorKey, float]]:
    """Return the `k` keys whose embeddings have the highest dot product with `query`.

    Args:
      query: The query embedding, shape (dim,).
      k: The number of results to return; clamped to the number of candidates.
      key_prefixes: If given, restrict candidates to keys matching these
        (possibly partial) key prefixes via the MultiIndex.

    Returns:
      `(key, similarity)` pairs, sorted by similarity from largest to smallest.
    """
    if key_prefixes is not None:
      # Cast tuples of length 1 to the element itself to avoid a pandas bug.
      key_prefixes = cast(
        list[VectorKey],
        [p[0] if isinstance(p, tuple) and len(p) == 1 else p for p in key_prefixes])
      # This uses the hierarchical index (MultiIndex) to do a prefix lookup.
      row_indices = self._lookup.loc[key_prefixes]
      keys, embeddings = list(row_indices.index), self._embeddings.take(row_indices, axis=0)
    else:
      keys, embeddings = self._keys, self._embeddings

    query = query.astype(embeddings.dtype)
    similarities: np.ndarray = np.dot(embeddings, query).reshape(-1)
    k = min(k, len(similarities))
    # Guard k <= 0: `np.argpartition(similarities, -0)` partitions at kth=0 and
    # the `[-0:]` slice would return *all* indices instead of none; it also
    # raises when there are no candidates at all.
    if k <= 0:
      return []
    # We do a partition + sort only top K to save time: O(n + klogk) instead of O(nlogn).
    indices = np.argpartition(similarities, -k)[-k:]
    # Indices sorted by value from largest to smallest.
    indices = indices[np.argsort(similarities[indices])][::-1]
    topk_similarities = similarities[indices]
    topk_keys = [keys[idx] for idx in indices]
    return list(zip(topk_keys, topk_similarities))