"""Tests for embedding.py."""

import numpy as np

from ..data.dataset_utils import lilac_embedding
from ..signals.splitters.chunk_splitter import TextChunk
from .embedding import compute_split_embeddings
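
# These tests exercise compute_split_embeddings, which is expected to split each
# document into chunks with the provided splitter, embed the chunk texts in batches
# of at most `batch_size` (batches may span document boundaries), and regroup the
# vectors into one list of lilac_embedding items per input document.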


def char_splitter(text: str) -> list[TextChunk]:
  """Split text into single-character chunks with (start, end) spans."""
  return [(letter, (i, i + 1)) for i, letter in enumerate(text)]


def test_split_and_combine_text_embeddings_batch_across_two_docs() -> None:
  docs = ['This is', '123']
  batch_size = 3

  embed_fn_inputs: list[list[str]] = []

  def embed_fn(batch: list[str]) -> list[np.ndarray]:
    embed_fn_inputs.append(batch)
    return [np.ones(1) for _ in batch]

  result = list(compute_split_embeddings(docs, batch_size, embed_fn, char_splitter))

  # Each input to embed_fn is a batch of at most 3 letters.
  assert embed_fn_inputs == [
    ['T', 'h', 'i'],
    ['s', ' ', 'i'],
    ['s', '1', '2'],
    ['3'],
  ]

  assert result == [
    [
      lilac_embedding(0, 1, np.array(1)),  # T
      lilac_embedding(1, 2, np.array(1)),  # h
      lilac_embedding(2, 3, np.array(1)),  # i
      lilac_embedding(3, 4, np.array(1)),  # s
      lilac_embedding(4, 5, np.array(1)),  # ' '
      lilac_embedding(5, 6, np.array(1)),  # i
      lilac_embedding(6, 7, np.array(1)),  # s
    ],
    [
      lilac_embedding(0, 1, np.array(1)),  # 1
      lilac_embedding(1, 2, np.array(1)),  # 2
      lilac_embedding(2, 3, np.array(1)),  # 3
    ],
  ]


def test_split_and_combine_text_embeddings_no_docs() -> None:
  docs: list[str] = []
  batch_size = 3

  embed_fn_inputs: list[list[str]] = []

  def embed_fn(batch: list[str]) -> list[np.ndarray]:
    embed_fn_inputs.append(batch)
    return [np.ones(1) for _ in batch]

  result = list(compute_split_embeddings(docs, batch_size, embed_fn, char_splitter))
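  # With no documents, the embed function is never invoked and no results are produced.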
  assert embed_fn_inputs == []
  assert result == []


def test_split_and_combine_text_embeddings_empty_docs() -> None:
  docs: list[str] = ['', '', '123']
  batch_size = 3

  embed_fn_inputs: list[list[str]] = []

  def embed_fn(batch: list[str]) -> list[np.ndarray]:
    embed_fn_inputs.append(batch)
    return [np.ones(1) for _ in batch]

  result = list(compute_split_embeddings(docs, batch_size, embed_fn, char_splitter))
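  # Empty documents still occupy a slot in each embedding batch and map to a
  # single embedding with a (0, 0) span.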
  assert embed_fn_inputs == [['', '', '1'], ['2', '3']]
  assert result == [
    [lilac_embedding(0, 0, np.array(1))],  # ''
    [lilac_embedding(0, 0, np.array(1))],  # ''
    [
      lilac_embedding(0, 1, np.array(1)),  # 1
      lilac_embedding(1, 2, np.array(1)),  # 2
      lilac_embedding(2, 3, np.array(1)),  # 3
    ],
  ]