"""Embedding registry."""
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Generator, Iterable, Optional, Union, cast

import numpy as np
from pydantic import StrictStr
from sklearn.preprocessing import normalize

from ..data.dataset_utils import lilac_embedding
from ..schema import Item, RichData
from ..signals.signal import EMBEDDING_KEY, TextEmbeddingSignal, get_signal_by_type
from ..signals.splitters.chunk_splitter import TextChunk
from ..utils import chunks

EmbeddingId = Union[StrictStr, TextEmbeddingSignal]

EmbedFn = Callable[[Iterable[RichData]], np.ndarray]


def get_embed_fn(embedding_name: str) -> EmbedFn:
  """Return a function that returns the embedding matrix for the given embedding signal."""
  embedding_cls = get_signal_by_type(embedding_name, TextEmbeddingSignal)
  embedding = embedding_cls(split=False)
  embedding.setup()

  def _embed_fn(data: Iterable[RichData]) -> np.ndarray:
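    # Run the embedding signal over the batch; with split=False each item should hold
    # exactly one embedding vector, and the vectors are stacked into a single matrix.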
    items = embedding.compute(data)

    embedding_vectors: list[np.ndarray] = []
    for item in items:
      if not item:
        raise ValueError('Embedding signal returned None.')
      if len(item) != 1:
        raise ValueError(
          f'Embedding signal returned {len(item)} items, but expected 1 since split was False')
      embedding_vector = item[0][EMBEDDING_KEY]
      if not isinstance(embedding_vector, np.ndarray):
        raise ValueError(
          f'Embedding signal returned {type(embedding_vector)} which is not an ndarray.')
      # Squeeze out outer dimensions of size 1 that some embedding functions return.
      embedding_vector = embedding_vector.squeeze()
      if embedding_vector.ndim != 1:
        raise ValueError(f'Expected embeddings to be 1-dimensional, got {embedding_vector.ndim} '
                         f'with shape {embedding_vector.shape}.')
      embedding_vectors.append(embedding_vector)
    return np.array(embedding_vectors)

  return _embed_fn


def compute_split_embeddings(docs: Iterable[str],
                             batch_size: int,
                             embed_fn: Callable[[list[str]], list[np.ndarray]],
                             split_fn: Optional[Callable[[str], list[TextChunk]]] = None,
                             num_parallel_requests: int = 1) -> Generator[Item, None, None]:
  """Compute text embeddings in batches of chunks, using the provided splitter and embedding fn."""
  pool = ThreadPoolExecutor()

  def _splitter(doc: Optional[str]) -> list[TextChunk]:
    if doc is None:
      return []
    if split_fn:
      return split_fn(doc)
    else:
      # Return a single chunk that spans the entire document.
      return [(doc, (0, len(doc)))]

  def _flat_split_batch_docs(docs: Iterable[str]) -> Generator[tuple[int, TextChunk], None, None]:
    """Split a batch of documents into chunks and yield them."""
    for i, doc in enumerate(docs):
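      # A document that produces no chunks still yields one empty chunk so it emits an item.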
      text_chunks = _splitter(doc) or [cast(TextChunk, ('', (0, 0)))]
      for chunk in text_chunks:
        yield (i, chunk)

  doc_chunks = _flat_split_batch_docs(docs)
  items_to_yield: list[Item] = []
  current_index = 0

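  # Each iteration pulls enough chunks to feed `num_parallel_requests` batches to the pool.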
  mega_batch_size = batch_size * num_parallel_requests

  for batch in chunks(doc_chunks, mega_batch_size):
    texts = [text for _, (text, _) in batch]
    embeddings: list[np.ndarray] = []
    for batch_embeddings in pool.map(embed_fn, chunks(texts, batch_size)):
      embeddings.extend(batch_embeddings)
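    # L2-normalize the embeddings and downcast to float16 to reduce the memory footprint.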
    matrix = normalize(np.array(embeddings)).astype(np.float16)
    # np.split returns a shallow copy of each embedding so we don't increase the mem footprint.
    embeddings_batch = cast(list[np.ndarray], np.split(matrix, matrix.shape[0]))
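    # Regroup chunk embeddings by source document so each yielded list covers one document.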
    for (index, (_, (start, end))), embedding in zip(batch, embeddings_batch):
      if index == current_index:
        items_to_yield.append(lilac_embedding(start, end, embedding))
      else:
        yield items_to_yield
        items_to_yield = [lilac_embedding(start, end, embedding)]
        current_index = index

  # Yield the last batch.
  if items_to_yield:
    yield items_to_yield
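
# Illustrative usage sketch (not part of the module API): `fake_embed` below is a
# hypothetical stand-in for a real batch embedding function.
#
#   def fake_embed(texts: list[str]) -> list[np.ndarray]:
#     return [np.ones(16, dtype=np.float32) for _ in texts]
#
#   docs = ['first document', 'second document']
#   for doc_items in compute_split_embeddings(
#       docs, batch_size=32, embed_fn=fake_embed, num_parallel_requests=2):
#     # Each `doc_items` is a list of lilac_embedding items, one per chunk of a document.
#     print(len(doc_items))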