"""Interface for storing vectors.""" import abc import os import pickle from typing import Iterable, Optional, Type import numpy as np from ..schema import SpanVector, VectorKey from ..utils import open_file class VectorStore(abc.ABC): """Interface for storing and retrieving vectors.""" # The global name of the vector store. name: str @abc.abstractmethod def save(self, base_path: str) -> None: """Save the store to disk.""" pass @abc.abstractmethod def load(self, base_path: str) -> None: """Load the store from disk.""" pass @abc.abstractmethod def size(self) -> int: """Return the number of vectors in the store.""" pass @abc.abstractmethod def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None: """Add or edit the given keyed embeddings to the store. If the keys already exist they will be overwritten, acting as an "upsert". Args: keys: The keys to add the embeddings for. embeddings: The embeddings to add. This should be a 2D matrix with the same length as keys. """ pass @abc.abstractmethod def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray: """Return the embeddings for given keys. Args: keys: The keys to return the embeddings for. If None, return all embeddings. Returns The embeddings for the given keys. """ pass def topk(self, query: np.ndarray, k: int, keys: Optional[Iterable[VectorKey]] = None) -> list[tuple[VectorKey, float]]: """Return the top k most similar vectors. Args: query: The query vector. k: The number of results to return. keys: Optional keys to restrict the search to. Returns A list of (key, score) tuples. """ raise NotImplementedError PathKey = VectorKey _SPANS_PICKLE_NAME = 'spans.pkl' class VectorDBIndex: """Stores and retrives span vectors. This wraps a regular vector store by adding a mapping from path keys, such as (rowid1, 0), to span keys, such as (rowid1, 0, 0), which denotes the first span in the (rowid1, 0) document. """ def __init__(self, vector_store: str) -> None: self._vector_store: VectorStore = get_vector_store_cls(vector_store)() # Map a path key to spans for that path. self._id_to_spans: dict[PathKey, list[tuple[int, int]]] = {} def load(self, base_path: str) -> None: """Load the vector index from disk.""" assert not self._id_to_spans, 'Cannot load into a non-empty index.' with open_file(os.path.join(base_path, _SPANS_PICKLE_NAME), 'rb') as f: self._id_to_spans.update(pickle.load(f)) self._vector_store.load(os.path.join(base_path, self._vector_store.name)) def save(self, base_path: str) -> None: """Save the vector index to disk.""" assert self._id_to_spans, 'Cannot save an empty index.' with open_file(os.path.join(base_path, _SPANS_PICKLE_NAME), 'wb') as f: pickle.dump(list(self._id_to_spans.items()), f) self._vector_store.save(os.path.join(base_path, self._vector_store.name)) def add(self, all_spans: list[tuple[PathKey, list[tuple[int, int]]]], embeddings: np.ndarray) -> None: """Add the given spans and embeddings. Args: all_spans: The spans to initialize the index with. embeddings: The embeddings to initialize the index with. """ assert not self._id_to_spans, 'Cannot add to a non-empty index.' self._id_to_spans.update(all_spans) vector_keys = [(*path_key, i) for path_key, spans in all_spans for i in range(len(spans))] assert len(vector_keys) == len(embeddings), ( f'Number of spans ({len(vector_keys)}) and embeddings ({len(embeddings)}) must match.') self._vector_store.add(vector_keys, embeddings) def get_vector_store(self) -> VectorStore: """Return the underlying vector store.""" return self._vector_store def get(self, keys: Iterable[PathKey]) -> Iterable[list[SpanVector]]: """Return the spans with vectors for each key in `keys`. Args: keys: The keys to return the vectors for. Returns The span vectors for the given keys. """ all_spans: list[list[tuple[int, int]]] = [] vector_keys: list[VectorKey] = [] for path_key in keys: spans = self._id_to_spans[path_key] all_spans.append(spans) vector_keys.extend([(*path_key, i) for i in range(len(spans))]) all_vectors = self._vector_store.get(vector_keys) offset = 0 for spans in all_spans: vectors = all_vectors[offset:offset + len(spans)] yield [{'span': span, 'vector': vector} for span, vector in zip(spans, vectors)] offset += len(spans) def topk(self, query: np.ndarray, k: int, path_keys: Optional[Iterable[PathKey]] = None) -> list[tuple[PathKey, float]]: """Return the top k most similar vectors. Args: query: The query vector. k: The number of results to return. path_keys: Optional key prefixes to restrict the search to. Returns A list of (key, score) tuples. """ span_keys: Optional[list[VectorKey]] = None if path_keys is not None: span_keys = [ (*path_key, i) for path_key in path_keys for i in range(len(self._id_to_spans[path_key])) ] span_k = k path_key_scores: dict[PathKey, float] = {} total_num_span_keys = self._vector_store.size() while (len(path_key_scores) < k and span_k < total_num_span_keys and (not span_keys or span_k < len(span_keys))): span_k += k vector_key_scores = self._vector_store.topk(query, span_k, span_keys) for (*path_key_list, _), score in vector_key_scores: path_key = tuple(path_key_list) if path_key not in path_key_scores: path_key_scores[path_key] = score return list(path_key_scores.items())[:k] VECTOR_STORE_REGISTRY: dict[str, Type[VectorStore]] = {} def register_vector_store(vector_store_cls: Type[VectorStore]) -> None: """Register a vector store in the global registry.""" if vector_store_cls.name in VECTOR_STORE_REGISTRY: raise ValueError(f'Vector store "{vector_store_cls.name}" has already been registered!') VECTOR_STORE_REGISTRY[vector_store_cls.name] = vector_store_cls def get_vector_store_cls(vector_store_name: str) -> Type[VectorStore]: """Return a registered vector store given the name in the registry.""" return VECTOR_STORE_REGISTRY[vector_store_name] def clear_vector_store_registry() -> None: """Clear the vector store registry.""" VECTOR_STORE_REGISTRY.clear()