"""Embedding registry."""
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Generator, Iterable, Optional, Union, cast

import numpy as np
from pydantic import StrictStr
from sklearn.preprocessing import normalize

from ..data.dataset_utils import lilac_embedding
from ..schema import Item, RichData
from ..signals.signal import EMBEDDING_KEY, TextEmbeddingSignal, get_signal_by_type
from ..signals.splitters.chunk_splitter import TextChunk
from ..utils import chunks

EmbeddingId = Union[StrictStr, TextEmbeddingSignal]

EmbedFn = Callable[[Iterable[RichData]], np.ndarray]


def get_embed_fn(embedding_name: str) -> EmbedFn:
  """Return a function that returns the embedding matrix for the given embedding signal."""
  embedding_cls = get_signal_by_type(embedding_name, TextEmbeddingSignal)
  embedding = embedding_cls(split=False)
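  # Set up the signal once (this may load model weights) so the returned fn is cheap to call.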
  embedding.setup()

  def _embed_fn(data: Iterable[RichData]) -> np.ndarray:
    items = embedding.compute(data)

    embedding_vectors: list[np.ndarray] = []
    for item in items:
      if not item:
        raise ValueError('Embedding signal returned no embedding for an input.')
      if len(item) != 1:
        raise ValueError(
          f'Embedding signal returned {len(item)} items, but expected 1 since split was False')
      embedding_vector = item[0][EMBEDDING_KEY]
      if not isinstance(embedding_vector, np.ndarray):
        raise ValueError(
          f'Embedding signal returned {type(embedding_vector)} which is not an ndarray.')
      # Squeeze out outer dimensions of 1 (e.g. a (1, dim) matrix) that some embedding
      # functions return. Unlike reshape(-1), squeeze keeps a genuinely multi-dimensional
      # result intact so the check below can reject it.
      embedding_vector = embedding_vector.squeeze()
      if embedding_vector.ndim != 1:
        raise ValueError(f'Expected embeddings to be 1-dimensional, got {embedding_vector.ndim} '
                         f'with shape {embedding_vector.shape}.')
      embedding_vectors.append(embedding_vector)
    return np.array(embedding_vectors)

  return _embed_fn
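
# Example usage (a sketch; 'gte-small' assumes such an embedding signal is registered):
#   embed_fn = get_embed_fn('gte-small')
#   matrix = embed_fn(['hello world'])  # np.ndarray of shape (num_docs, embedding_dim)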


def compute_split_embeddings(docs: Iterable[str],
                             batch_size: int,
                             embed_fn: Callable[[list[str]], list[np.ndarray]],
                             split_fn: Optional[Callable[[str], list[TextChunk]]] = None,
                             num_parallel_requests: int = 1) -> Generator[Item, None, None]:
  """Compute text embeddings in batches of chunks, using the provided splitter and embedding fn."""
  pool = ThreadPoolExecutor()

  def _splitter(doc: str) -> list[TextChunk]:
    if not doc:
      return []
    if split_fn:
      return split_fn(doc)
    else:
      # Return a single chunk that spans the entire document.
      return [(doc, (0, len(doc)))]

  # Counted lazily as the chunk stream below consumes `docs`; final once that stream is
  # exhausted, which the trailing flush loop relies on.
  num_docs = 0

  def _flat_split_batch_docs(docs: Iterable[str]) -> Generator[tuple[int, TextChunk], None, None]:
    """Split a batch of documents into chunks and yield them."""
    nonlocal num_docs
    for i, doc in enumerate(docs):
      num_docs += 1
      text_chunks = _splitter(doc)
      for chunk in text_chunks:
        yield (i, chunk)

  doc_chunks = _flat_split_batch_docs(docs)
  items_to_yield: Optional[list[Item]] = None
  current_index = 0

  # Each mega-batch is split into `num_parallel_requests` batches of `batch_size` that are
  # embedded concurrently on the thread pool.
  mega_batch_size = batch_size * num_parallel_requests

  for batch in chunks(doc_chunks, mega_batch_size):
    texts = [text for _, (text, _) in batch]
    embeddings: list[np.ndarray] = []

    for batch_embeddings in pool.map(embed_fn, chunks(texts, batch_size)):
      embeddings.extend(batch_embeddings)
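    # L2-normalize the vectors and quantize to float16 to halve the memory footprint.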
    matrix = normalize(np.array(embeddings)).astype(np.float16)
    # np.split returns views into the matrix, so this doesn't increase the memory footprint.
    embeddings_batch = cast(list[np.ndarray], np.split(matrix, matrix.shape[0]))
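    # Regroup the flat (doc_index, chunk) stream back into one item list per document,
    # yielding None for any document that produced no chunks.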
    for (index, (_, (start, end))), embedding in zip(batch, embeddings_batch):
      embedding = embedding.reshape(-1)
      if index == current_index:
        if items_to_yield is None:
          items_to_yield = []
        items_to_yield.append(lilac_embedding(start, end, embedding))
      else:
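        # Moved past the current document: flush its accumulated items, then yield
        # None for any intervening documents that had no chunks.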
        yield items_to_yield
        current_index += 1
        while current_index < index:
          yield None
          current_index += 1
        items_to_yield = [lilac_embedding(start, end, embedding)]

  # Flush the final document, then emit None for any trailing documents without chunks.
  while current_index < num_docs:
    yield items_to_yield
    items_to_yield = None
    current_index += 1

  pool.shutdown()
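
# Example usage (a sketch; `my_embed_fn` below is a toy placeholder, not a real model):
#   def my_embed_fn(texts: list[str]) -> list[np.ndarray]:
#     return [np.ones(8, dtype=np.float32) for _ in texts]
#
#   for item in compute_split_embeddings(['first doc', 'second doc'],
#                                        batch_size=2, embed_fn=my_embed_fn):
#     ...  # one list of embedding items per doc, or None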