File size: 2,886 Bytes
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""NumpyVectorStore class for storing vectors in numpy arrays."""

from typing import Iterable, Optional, cast

import numpy as np
import pandas as pd
from typing_extensions import override

from ..schema import VectorKey
from .vector_store import VectorStore


class NumpyVectorStore(VectorStore):
  """Stores vectors as in-memory np arrays."""
  _embeddings: np.ndarray
  _keys: list[VectorKey]
  # Maps a `VectorKey` to a row index in `_embeddings`.
  _lookup: pd.Series

  @override
  def keys(self) -> list[VectorKey]:
    return self._keys

  @override
  def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
    if hasattr(self, '_embeddings') or hasattr(self, '_keys'):
      raise ValueError('Embeddings already exist in this store. Upsert is not yet supported.')

    if len(keys) != embeddings.shape[0]:
      raise ValueError(
        f'Length of keys ({len(keys)}) does not match number of embeddings {embeddings.shape[0]}.')

    self._keys = keys
    # Cast to float32 since dot product with float32 is 40-50x faster than float16 and 2.5x faster
    # than float64.
    self._embeddings = embeddings.astype(np.float32)

    index = pd.MultiIndex.from_tuples(keys)
    row_indices = np.arange(len(self._embeddings), dtype=np.uint32)
    self._lookup = pd.Series(row_indices, index=index)

  @override
  def get(self, keys: Iterable[VectorKey]) -> np.ndarray:
    """Return the embeddings for given keys.

    Args:
      keys: The keys to return the embeddings for.

    Returns
      The embeddings for the given keys.
    """
    return self._embeddings.take(self._lookup.loc[keys], axis=0)

  @override
  def topk(self,
           query: np.ndarray,
           k: int,
           key_prefixes: Optional[Iterable[VectorKey]] = None) -> list[tuple[VectorKey, float]]:
    if key_prefixes is not None:
      # Cast tuples of length 1 to the element itself to avoid a pandas bug.
      key_prefixes = cast(
        list[VectorKey],
        [k[0] if isinstance(k, tuple) and len(k) == 1 else k for k in key_prefixes])
      # This uses the hierarchical index (MutliIndex) to do a prefix lookup.
      row_indices = self._lookup.loc[key_prefixes]
      keys, embeddings = list(row_indices.index), self._embeddings.take(row_indices, axis=0)
    else:
      keys, embeddings = self._keys, self._embeddings

    query = query.astype(embeddings.dtype)
    similarities: np.ndarray = np.dot(embeddings, query).reshape(-1)
    k = min(k, len(similarities))

    # We do a partition + sort only top K to save time: O(n + klogk) instead of O(nlogn).
    indices = np.argpartition(similarities, -k)[-k:]
    # Indices sorted by value from largest to smallest.
    indices = indices[np.argsort(similarities[indices])][::-1]

    topk_similarities = similarities[indices]
    topk_keys = [keys[idx] for idx in indices]
    return list(zip(topk_keys, topk_similarities))