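# NOTE(review): the text "Spaces: Runtime error / Runtime error" preceded this module;
# it appears to be a status banner captured along with the file, not test code.
"""Tests for embedding.py."""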
import numpy as np | |
from ..data.dataset_utils import lilac_embedding | |
from ..signals.splitters.chunk_splitter import TextChunk | |
from .embedding import compute_split_embeddings | |
def char_splitter(text: str) -> list[TextChunk]:
    """Split `text` into one chunk per character.

    Each chunk is a `(letter, (start, end))` tuple whose span is `(i, i + 1)` for the
    character at index `i`. An empty string yields an empty list.
    """
    return [(letter, (i, i + 1)) for i, letter in enumerate(text)]
def test_split_and_combine_text_embeddings_batch_across_two_docs() -> None:
    """Check that character chunks are batched across document boundaries.

    The two docs split into 7 + 3 single-character chunks. With `batch_size` 3 the
    test expects `embed_fn` to receive batches of 3, 3, 3 and 1 chunks, the third batch
    mixing the last letter of the first doc with the first letters of the second.
    The results are expected grouped per document, with spans relative to each doc.
    """
    docs = ['This is', '123']
    batch_size = 3

    embed_fn_inputs: list[list[str]] = []

    def embed_fn(batch: list[str]) -> list[np.ndarray]:
        # Record every batch so the batching behaviour can be asserted below.
        embed_fn_inputs.append(batch)
        return [np.ones(1) for _ in batch]

    result = list(compute_split_embeddings(docs, batch_size, embed_fn, char_splitter))

    # Each input to embed_fn is a batch of at most 3 letters.
    assert embed_fn_inputs == [
        ['T', 'h', 'i'],
        ['s', ' ', 'i'],
        ['s', '1', '2'],
        ['3'],
    ]

    # NOTE(review): embed_fn returns np.ones(1) (shape (1,)) while the expectation is built
    # from np.array(1) (0-d). The comparison presumably relies on single-element array
    # truthiness or on normalisation inside the code under test -- confirm.
    assert result == [
        [
            lilac_embedding(0, 1, np.array(1)),  # T
            lilac_embedding(1, 2, np.array(1)),  # h
            lilac_embedding(2, 3, np.array(1)),  # i
            lilac_embedding(3, 4, np.array(1)),  # s
            lilac_embedding(4, 5, np.array(1)),  # ' '
            lilac_embedding(5, 6, np.array(1)),  # i
            lilac_embedding(6, 7, np.array(1)),  # s
        ],
        [
            lilac_embedding(0, 1, np.array(1)),  # 1
            lilac_embedding(1, 2, np.array(1)),  # 2
            lilac_embedding(2, 3, np.array(1)),  # 3
        ],
    ]
def test_split_and_combine_text_embeddings_no_docs() -> None:
    """Check that an empty list of docs yields no results and never calls `embed_fn`."""
    docs: list[str] = []
    batch_size = 3

    embed_fn_inputs: list[list[str]] = []

    def embed_fn(batch: list[str]) -> list[np.ndarray]:
        # Record every batch; the test expects this to never be invoked.
        embed_fn_inputs.append(batch)
        return [np.ones(1) for _ in batch]

    result = list(compute_split_embeddings(docs, batch_size, embed_fn, char_splitter))

    assert embed_fn_inputs == []
    assert result == []
def test_split_and_combine_text_embeddings_empty_docs() -> None:
    """Check how empty documents are embedded alongside a non-empty one.

    `char_splitter` produces no chunks for an empty string, yet the test expects each
    empty doc to reach `embed_fn` as `''` and to yield a single `(0, 0)` span, batched
    together with the characters of the following non-empty doc.
    """
    docs: list[str] = ['', '', '123']
    batch_size = 3

    embed_fn_inputs: list[list[str]] = []

    def embed_fn(batch: list[str]) -> list[np.ndarray]:
        # Record every batch so the batching behaviour can be asserted below.
        embed_fn_inputs.append(batch)
        return [np.ones(1) for _ in batch]

    result = list(compute_split_embeddings(docs, batch_size, embed_fn, char_splitter))

    # The two empty docs share the first batch with the first letter of '123'.
    assert embed_fn_inputs == [['', '', '1'], ['2', '3']]

    # NOTE(review): embed_fn returns np.ones(1) (shape (1,)) while the expectation is built
    # from np.array(1) (0-d) -- confirm the comparison is intended to treat these as equal.
    assert result == [
        [lilac_embedding(0, 0, np.array(1))],  # ''
        [lilac_embedding(0, 0, np.array(1))],  # ''
        [
            lilac_embedding(0, 1, np.array(1)),  # 1
            lilac_embedding(1, 2, np.array(1)),  # 2
            lilac_embedding(2, 3, np.array(1)),  # 3
        ],
    ]