nsthorat's picture
Push
56cce61
raw
history blame
No virus
4.27 kB
"""Embedding registry."""
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Generator, Iterable, Optional, Union, cast
import numpy as np
from pydantic import StrictStr
from sklearn.preprocessing import normalize
from ..data.dataset_utils import lilac_embedding
from ..schema import Item, RichData
from ..signals.signal import EMBEDDING_KEY, TextEmbeddingSignal, get_signal_by_type
from ..signals.splitters.chunk_splitter import TextChunk
from ..utils import chunks
# Signature of a batch embedding function: takes an iterable of rich data and returns an ndarray
# (e.g. the closure produced by `get_embed_fn`, which stacks one vector per input).
EmbedFn = Callable[[Iterable[RichData]], np.ndarray]
# An embedding may be identified either by a strict string or by a `TextEmbeddingSignal` instance.
EmbeddingId = Union[StrictStr, TextEmbeddingSignal]
def get_embed_fn(embedding_name: str) -> EmbedFn:
  """Return a function that returns the embedding matrix for the given embedding signal.

  The `TextEmbeddingSignal` registered under `embedding_name` is instantiated with `split=False`
  and set up once; the returned closure reuses that single instance for every call.

  Args:
    embedding_name: Name used to look up a `TextEmbeddingSignal` class.

  Returns:
    A function that embeds an iterable of inputs and returns the per-input vectors stacked with
    `np.array` (one row per input).
  """
  embedding_cls = get_signal_by_type(embedding_name, TextEmbeddingSignal)
  embedding = embedding_cls(split=False)
  embedding.setup()

  def _embed_fn(data: Iterable[RichData]) -> np.ndarray:
    """Embed `data`, requiring exactly one 1-D ndarray embedding per input.

    Raises:
      ValueError: If the signal returns an empty item, a number of embeddings other than 1 for an
        input, a non-ndarray embedding, or an embedding that is not 1-D after removing
        size-1 axes.
    """
    items = embedding.compute(data)
    embedding_vectors: list[np.ndarray] = []
    for item in items:
      if not item:
        raise ValueError('Embedding signal returned None.')
      if len(item) != 1:
        raise ValueError(
          f'Embedding signal returned {len(item)} items, but expected 1 since split was False')
      embedding_vector = item[0][EMBEDDING_KEY]
      if not isinstance(embedding_vector, np.ndarray):
        raise ValueError(
          f'Embedding signal returned {type(embedding_vector)} which is not an ndarray.')
      # Embedding functions can return extra dimensions of size 1 (e.g. shape (1, dim)), so squeeze
      # them away. `atleast_1d` keeps a single-element vector from collapsing to a 0-d scalar.
      # Unlike `reshape(-1)`, this does not silently flatten a genuinely multi-dimensional result
      # (e.g. (2, dim)), so the check below can actually reject it.
      embedding_vector = np.atleast_1d(np.squeeze(embedding_vector))
      if embedding_vector.ndim != 1:
        raise ValueError(f'Expected embeddings to be 1-dimensional, got {embedding_vector.ndim} '
                         f'with shape {embedding_vector.shape}.')
      embedding_vectors.append(embedding_vector)
    return np.array(embedding_vectors)

  return _embed_fn
def compute_split_embeddings(docs: Iterable[str],
                             batch_size: int,
                             embed_fn: Callable[[list[str]], list[np.ndarray]],
                             split_fn: Optional[Callable[[str], list[TextChunk]]] = None,
                             num_parallel_requests: int = 1) -> Generator[Item, None, None]:
  """Compute text embeddings in batches of chunks, using the provided splitter and embedding fn.

  Each document is split into chunks; when `split_fn` is None the whole document is one chunk.
  The chunks of all documents are flattened into one stream and embedded in mega-batches of
  `batch_size * num_parallel_requests` texts. Each mega-batch is sent to `embed_fn` on a thread
  pool, in sub-batches of `batch_size` texts.

  Args:
    docs: The documents to embed.
    batch_size: Size of each sub-batch passed to a single `embed_fn` call.
    embed_fn: Embeds a list of texts; expected to return one vector per text, in the same order.
    split_fn: Optional splitter returning `(text, (start, end))` chunks for a document.
    num_parallel_requests: Number of sub-batches grouped into one mega-batch for concurrent
      `embed_fn` calls.

  Yields:
    One value per input document, in input order. The value is a list of
    `lilac_embedding(start, end, vector)` items, one per chunk, with vectors normalized by
    `sklearn.preprocessing.normalize` and cast to float16. It is None for documents that are empty
    or produced no chunks.

  Raises:
    ValueError: If `embed_fn` returns a different number of vectors than texts it was given.
  """

  def _splitter(doc: str) -> list[TextChunk]:
    """Split one document into chunks; an empty document yields no chunks."""
    if not doc:
      return []
    if split_fn:
      return split_fn(doc)
    # Return a single chunk that spans the entire document.
    return [(doc, (0, len(doc)))]

  num_docs = 0

  def _flat_split_batch_docs(docs: Iterable[str]) -> Generator[tuple[int, TextChunk], None, None]:
    """Split documents into chunks, yielding `(doc_index, chunk)` and counting documents seen."""
    nonlocal num_docs
    for i, doc in enumerate(docs):
      num_docs += 1
      # Iterate the splitter result directly rather than binding it to a local named `chunks`,
      # which would shadow the module-level `chunks` batching helper.
      for chunk in _splitter(doc):
        yield (i, chunk)

  doc_chunks = _flat_split_batch_docs(docs)
  items_to_yield: Optional[list[Item]] = None
  current_index = 0
  mega_batch_size = batch_size * num_parallel_requests

  # The `with` block shuts the worker threads down when the generator finishes, raises, or is
  # closed early by the consumer; previously the pool was never shut down.
  with ThreadPoolExecutor() as pool:
    for batch in chunks(doc_chunks, mega_batch_size):
      texts = [text for _, (text, _) in batch]
      embeddings: list[np.ndarray] = []
      # `Executor.map` yields results in submission order, keeping embeddings aligned with texts.
      for sub_batch_embeddings in pool.map(embed_fn, chunks(texts, batch_size)):
        embeddings.extend(sub_batch_embeddings)
      # A count mismatch would make the `zip` below silently truncate and misalign results.
      if len(embeddings) != len(texts):
        raise ValueError(f'Embedding function returned {len(embeddings)} embeddings for '
                         f'{len(texts)} texts.')
      matrix = normalize(np.array(embeddings)).astype(np.float16)
      # Iterating a 2-D array yields its rows as views, so no per-embedding copies are made.
      for (index, (_, (start, end))), embedding in zip(batch, matrix):
        if index == current_index:
          # Another chunk of the document currently being accumulated.
          if items_to_yield is None:
            items_to_yield = []
          items_to_yield.append(lilac_embedding(start, end, embedding))
        else:
          # A new document started: emit the finished one, then None for every document in
          # between that produced no chunks.
          yield items_to_yield
          current_index += 1
          while current_index < index:
            yield None
            current_index += 1
          items_to_yield = [lilac_embedding(start, end, embedding)]

  # Flush the last accumulated document, then None for any trailing chunk-less documents.
  # `num_docs` is final here, assuming `chunks` iterates the document generator to exhaustion.
  while current_index < num_docs:
    yield items_to_yield
    items_to_yield = None
    current_index += 1