nikhil_staging / src /embeddings /embedding_test.py
nsthorat's picture
Push
e4f9cbe
raw
history blame
No virus
2.51 kB
"""Tests for embedding.py."""
import numpy as np
from ..data.dataset_utils import lilac_embedding
from ..signals.splitters.chunk_splitter import TextChunk
from .embedding import compute_split_embeddings
def char_splitter(text: str) -> list[TextChunk]:
    """Split `text` into one chunk per character.

    Each chunk is a `(character, (start, end))` tuple where `end == start + 1`.
    An empty string yields an empty list.
    """
    chunks = []
    for start, char in enumerate(text):
        span = (start, start + 1)
        chunks.append((char, span))
    return chunks
def test_split_and_combine_text_embeddings_batch_across_two_docs() -> None:
    """Chunks from two docs share batches, yet results are grouped back per doc."""
    recorded_batches: list[list[str]] = []

    def embed_fn(batch: list[str]) -> list[np.ndarray]:
        # Record every batch so the batching behaviour can be asserted below.
        recorded_batches.append(batch)
        return [np.ones(1) for _ in batch]

    first_doc, second_doc = 'This is', '123'
    result = list(
        compute_split_embeddings([first_doc, second_doc], 3, embed_fn, char_splitter)
    )

    # Each input to embed_fn is a batch of at most 3 letters; the third batch
    # holds the last letter of the first doc and the first two of the second.
    expected_batches = [
        ['T', 'h', 'i'],
        ['s', ' ', 'i'],
        ['s', '1', '2'],
        ['3'],
    ]
    assert recorded_batches == expected_batches

    # One embedding per character, spanning (i, i + 1) within its own doc.
    expected_result = [
        [lilac_embedding(i, i + 1, np.array(1)) for i in range(len(doc))]
        for doc in (first_doc, second_doc)
    ]
    assert result == expected_result
def test_split_and_combine_text_embeddings_no_docs() -> None:
    """With no docs, embed_fn is expected to never be called and nothing is produced."""
    recorded_batches: list[list[str]] = []

    def embed_fn(batch: list[str]) -> list[np.ndarray]:
        # Record every batch so we can verify that none were requested.
        recorded_batches.append(batch)
        return [np.ones(1) for _ in batch]

    no_docs: list[str] = []
    result = list(compute_split_embeddings(no_docs, 3, embed_fn, char_splitter))

    assert recorded_batches == []
    assert result == []
def test_split_and_combine_text_embeddings_empty_docs() -> None:
    """Empty docs are expected to be embedded as '' and to get a single (0, 0) span."""
    recorded_batches: list[list[str]] = []

    def embed_fn(batch: list[str]) -> list[np.ndarray]:
        # Record every batch so the batching behaviour can be asserted below.
        recorded_batches.append(batch)
        return [np.ones(1) for _ in batch]

    docs: list[str] = ['', '', '123']
    result = list(compute_split_embeddings(docs, 3, embed_fn, char_splitter))

    # The two empty strings are batched together with the first character of '123'.
    assert recorded_batches == [['', '', '1'], ['2', '3']]

    # Each empty doc yields one zero-width embedding ...
    expected_result = [[lilac_embedding(0, 0, np.array(1))] for _ in range(2)]
    # ... and the non-empty doc yields one embedding per character.
    expected_result.append(
        [lilac_embedding(i, i + 1, np.array(1)) for i in range(len('123'))]
    )
    assert result == expected_result