from __future__ import annotations

import re
from unittest.mock import patch

import pytest
import torch

from sentence_transformers.models import Pooling, Transformer
from sentence_transformers.sparse_encoder.models import MLMTransformer, SparseAutoEncoder, SpladePooling
from sentence_transformers.sparse_encoder.SparseEncoder import SparseEncoder

from tests.sparse_encoder.utils import sparse_allclose
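

# SparseEncoder.decode maps sparse embeddings back to (token, weight) pairs:
# a 1D embedding decodes to a flat list of tuples, while a 2D batch decodes to
# one list of tuples per row. Illustrative output (tokens and weights are
# model-dependent, not fixed values):
#   model.decode(model.encode("The weather is nice!"), top_k=2)
#   -> [("weather", 1.8), ("nice", 1.5)]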
@pytest.mark.parametrize(
    ("texts", "top_k", "expected_shape"),
    [
        (["The weather is nice!"], None, 1),
        (["The weather is nice!"], 3, 1),
        ("The weather is nice!", 8, 8),
        (["The weather is nice!", "It's sunny outside"], None, 2),
        (["The weather is nice!", "It's sunny outside"], 3, 2),
    ],
)
def test_decode_shapes(
    splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, top_k: int | None, expected_shape: int
) -> None:
    model = splade_bert_tiny_model
    embeddings = model.encode(texts)
    decoded = model.decode(embeddings, top_k=top_k)

    assert len(decoded) == expected_shape

    if isinstance(texts, list):
        # A batch decodes to one list of (token, weight) tuples per input text
        assert isinstance(decoded, list)
        for item in decoded:
            assert isinstance(item, list)
            if top_k is not None:
                assert len(item) <= top_k
    else:
        # A single string decodes to a flat list of (token, weight) tuples
        assert all(isinstance(item, tuple) for item in decoded)
        if top_k is not None:
            assert len(decoded) <= top_k
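

# Decoded entries should be plain Python types: str tokens paired with float weights.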
@pytest.mark.parametrize(
    ("text", "expected_token_types"),
    [
        ("The weather is nice!", str),
        ("It's sunny outside", str),
    ],
)
def test_decode_token_types(splade_bert_tiny_model: SparseEncoder, text: str, expected_token_types: type) -> None:
    model = splade_bert_tiny_model
    embeddings = model.encode(text)
    decoded = model.decode(embeddings)

    for token, weight in decoded:
        assert isinstance(token, expected_token_types)
        assert isinstance(weight, float)


@pytest.mark.parametrize(
    ("text", "top_k"),
    [
        ("The weather is nice!", 1),
        ("It's sunny outside", 3),
        ("Hello world", 5),
    ],
)
def test_decode_top_k_respects_limit(splade_bert_tiny_model: SparseEncoder, text: str, top_k: int) -> None:
    model = splade_bert_tiny_model
    embeddings = model.encode([text])
    decoded = model.decode(embeddings, top_k=top_k)

    # encode([text]) produces a batch of one, so check the token list of that row
    assert len(decoded) == 1
    assert len(decoded[0]) <= top_k
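

# decode() should accept both layouts of the same embedding: Tensor.to_sparse()
# converts dense to sparse COO and Tensor.to_dense() converts back, so decoding
# either layout is expected to yield the same number of entries per text.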
@pytest.mark.parametrize(
    ("texts", "format_type"),
    [
        ("The weather is nice!", "1d"),
        (["The weather is nice!"], "1d"),
        (["The weather is nice!", "It's sunny outside"], "2d"),
    ],
)
def test_decode_handles_sparse_dense_inputs(
    splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, format_type: str
) -> None:
    model = splade_bert_tiny_model
    embeddings = model.encode(texts)

    # Decode the sparse layout
    embeddings_sparse = embeddings if embeddings.is_sparse else embeddings.to_sparse()
    decoded_sparse = model.decode(embeddings_sparse)

    # Decode the dense layout
    embeddings_dense = embeddings.to_dense() if embeddings.is_sparse else embeddings
    decoded_dense = model.decode(embeddings_dense)

    # Both layouts must decode to the same number of entries
    assert len(decoded_sparse) == len(decoded_dense)
    if format_type == "2d":
        for row_sparse, row_dense in zip(decoded_sparse, decoded_dense):
            # Sort by weight (then token) so entries with tied weights compare stably
            sorted_sparse = sorted(row_sparse, key=lambda x: (x[1], x[0]), reverse=True)
            sorted_dense = sorted(row_dense, key=lambda x: (x[1], x[0]), reverse=True)
            assert len(sorted_sparse) == len(sorted_dense)
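

# Edge case: a sparse COO tensor with zero non-zero entries. torch.sparse_coo_tensor
# takes indices of shape (ndim, nnz) and values of shape (nnz,), so (2, 0) indices
# with (0,) values describe an all-zero 1 x vocab_size embedding.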
def test_decode_empty_tensor(splade_bert_tiny_model: SparseEncoder) -> None:
    model = splade_bert_tiny_model

    empty_sparse = torch.sparse_coo_tensor(
        indices=torch.zeros((2, 0), dtype=torch.long),
        values=torch.zeros((0,), dtype=torch.float),
        size=(1, model.get_sentence_embedding_dimension()),
    )

    decoded = model.decode(empty_sparse)
    assert len(decoded) == 0 or (isinstance(decoded, list) and all(not item for item in decoded))
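

# decode() is expected to return entries sorted by descending weight, whether or
# not a top_k cutoff is applied (a top_k beyond the number of active dimensions,
# such as 1000 here, should simply return every active entry).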
@pytest.mark.parametrize(
    "top_k",
    [None, 5, 1000],
)
@pytest.mark.parametrize(
    "texts",
    [
        "The weather is nice!",
        ["The weather is nice!"],
        ["The weather is nice!", "It's sunny outside", "Hello world"],
        ["Short text", "This is a longer text with more words to encode"],
    ],
)
def test_decode_returns_sorted_weights(
    splade_bert_tiny_model: SparseEncoder, texts: list[str] | str, top_k: int | None
) -> None:
    model = splade_bert_tiny_model
    embeddings = model.encode(texts)
    decoded = model.decode(embeddings, top_k=top_k)

    # Normalize the single-string case to a batch of one for uniform checking
    rows = decoded if isinstance(texts, list) else [decoded]
    for row in rows:
        weights = [weight for _, weight in row]
        assert all(weights[i] >= weights[i + 1] for i in range(len(weights) - 1))
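

# Inference-free SPLADE routes queries and documents through different sub-modules:
# the query side is a lightweight, tokenizer-driven module (hence one decoded entry
# per query token, as asserted below), while the document side runs the full SPLADE
# expansion over the transformer.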
def test_inference_free_splade(inference_free_splade_bert_tiny_model: SparseEncoder):
    model = inference_free_splade_bert_tiny_model
    dimensionality = model.get_sentence_embedding_dimension()

    query = "What is the capital of France?"
    document = "The capital of France is Paris."
    query_embeddings = model.encode_query(query)
    document_embeddings = model.encode_document(document)

    assert query_embeddings.shape == (dimensionality,)
    assert document_embeddings.shape == (dimensionality,)

    decoded_query = model.decode(query_embeddings)
    decoded_document = model.decode(document_embeddings)
    # The query activates exactly its own tokens; the document is heavily expanded
    assert len(decoded_query) == len(model.tokenize(query, task="query")["input_ids"][0])
    assert len(decoded_document) >= 50

    # max_seq_length should propagate to both router sub-modules
    assert model.max_seq_length == 512
    assert model[0].sub_modules["query"][0].max_seq_length == 512
    assert model[0].sub_modules["document"][0].max_seq_length == 512

    model.max_seq_length = 256
    assert model.max_seq_length == 256
    assert model[0].sub_modules["query"][0].max_seq_length == 256
    assert model[0].sub_modules["document"][0].max_seq_length == 256
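

# The encode_query/encode_document tests below mock out the underlying encode(),
# so they only verify which arguments are forwarded. A minimal sketch of the
# pattern (standalone, with a hypothetical `model`):
#   with patch.object(model, "encode", autospec=True) as mock_encode:
#       model.encode_query("text")
#   assert mock_encode.call_args.kwargs["task"] == "query"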
@pytest.mark.parametrize("sentences", ["Hello world", ["Hello world", "This is a test"], [], [""]]) |
|
|
@pytest.mark.parametrize("prompt_name", [None, "query", "custom"]) |
|
|
@pytest.mark.parametrize("prompt", [None, "Custom prompt: "]) |
|
|
@pytest.mark.parametrize("convert_to_tensor", [True, False]) |
|
|
@pytest.mark.parametrize("convert_to_sparse_tensor", [True, False]) |
|
|
def test_encode_query( |
|
|
splade_bert_tiny_model: SparseEncoder, |
|
|
sentences: str | list[str], |
|
|
prompt_name: str | None, |
|
|
prompt: str | None, |
|
|
convert_to_tensor: bool, |
|
|
convert_to_sparse_tensor: bool, |
|
|
): |
|
|
model = splade_bert_tiny_model |
|
|
|
|
|
model.prompts = {"query": "query: ", "custom": "custom: "} |
|
|
|
|
|
|
|
|
with patch.object(model, "encode", autospec=True) as mock_encode: |
|
|
|
|
|
model.encode_query( |
|
|
sentences=sentences, |
|
|
prompt_name=prompt_name, |
|
|
prompt=prompt, |
|
|
batch_size=32, |
|
|
convert_to_tensor=convert_to_tensor, |
|
|
convert_to_sparse_tensor=convert_to_sparse_tensor, |
|
|
) |
|
|
|
|
|
|
|
|
expected_prompt_name = prompt_name if prompt_name else "query" |
|
|
|
|
|
mock_encode.assert_called_once() |
|
|
args, kwargs = mock_encode.call_args |
|
|
|
|
|
|
|
|
assert kwargs["sentences"] == sentences |
|
|
|
|
|
|
|
|
assert kwargs["prompt"] == prompt |
|
|
assert kwargs["prompt_name"] == expected_prompt_name |
|
|
|
|
|
|
|
|
assert kwargs["convert_to_tensor"] == convert_to_tensor |
|
|
assert kwargs["convert_to_sparse_tensor"] == convert_to_sparse_tensor |
|
|
assert kwargs["task"] == "query" |
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("sentences", ["Hello world", ["Hello world", "This is a test"], [], [""]]) |
|
|
@pytest.mark.parametrize("prompt_name", [None, "document", "passage", "corpus", "custom"]) |
|
|
@pytest.mark.parametrize("prompt", [None, "Custom prompt: "]) |
|
|
@pytest.mark.parametrize("convert_to_tensor", [True, False]) |
|
|
@pytest.mark.parametrize("convert_to_sparse_tensor", [True, False]) |
|
|
def test_encode_document( |
|
|
splade_bert_tiny_model: SparseEncoder, |
|
|
sentences: str | list[str], |
|
|
prompt_name: str | None, |
|
|
prompt: str | None, |
|
|
convert_to_tensor: bool, |
|
|
convert_to_sparse_tensor: bool, |
|
|
): |
|
|
|
|
|
model = splade_bert_tiny_model |
|
|
model.prompts = {"document": "document: ", "passage": "passage: ", "corpus": "corpus: ", "custom": "custom: "} |
|
|
|
|
|
|
|
|
with patch.object(model, "encode", autospec=True) as mock_encode: |
|
|
|
|
|
model.encode_document( |
|
|
sentences=sentences, |
|
|
prompt_name=prompt_name, |
|
|
prompt=prompt, |
|
|
batch_size=32, |
|
|
convert_to_tensor=convert_to_tensor, |
|
|
convert_to_sparse_tensor=convert_to_sparse_tensor, |
|
|
) |
|
|
|
|
|
|
|
|
mock_encode.assert_called_once() |
|
|
args, kwargs = mock_encode.call_args |
|
|
|
|
|
expected_prompt_name = prompt_name if prompt_name else "document" |
|
|
|
|
|
|
|
|
assert kwargs["sentences"] == sentences |
|
|
|
|
|
|
|
|
assert kwargs["prompt"] == prompt |
|
|
assert kwargs["prompt_name"] == expected_prompt_name |
|
|
|
|
|
|
|
|
assert kwargs["convert_to_tensor"] == convert_to_tensor |
|
|
assert kwargs["convert_to_sparse_tensor"] == convert_to_sparse_tensor |
|
|
assert kwargs["task"] == "document" |
|
|
|
|
|
|
|
|
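

# When no prompt_name is given, encode_document is expected to pick the first
# available prompt in priority order "document" > "passage" > "corpus", falling
# back to no prompt at all when none of those keys is configured.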
def test_encode_document_prompt_priority(splade_bert_tiny_model: SparseEncoder):
    """Test that the proper prompt priority is respected when multiple options are available"""
    model = splade_bert_tiny_model
    model.prompts = {
        "document": "document: ",
        "passage": "passage: ",
        "corpus": "corpus: ",
    }

    with patch.object(model, "encode", autospec=True) as mock_encode:
        model.encode_document("test")

        _, kwargs = mock_encode.call_args
        assert kwargs["prompt_name"] == "document"

        # Without "document", "passage" takes priority
        mock_encode.reset_mock()
        model.prompts = {
            "passage": "passage: ",
            "corpus": "corpus: ",
        }
        model.encode_document("test")
        _, kwargs = mock_encode.call_args
        assert kwargs["prompt_name"] == "passage"

        # Without "document" and "passage", "corpus" is used
        mock_encode.reset_mock()
        model.prompts = {
            "corpus": "corpus: ",
        }
        model.encode_document("test")
        _, kwargs = mock_encode.call_args
        assert kwargs["prompt_name"] == "corpus"

        # With no document-style prompt configured, no prompt name is passed
        mock_encode.reset_mock()
        model.prompts = {
            "query": "query: ",
        }
        model.encode_document("test")
        _, kwargs = mock_encode.call_args
        assert kwargs["prompt_name"] is None


def test_encode_advanced_parameters(splade_bert_tiny_model: SparseEncoder):
    """Test that additional parameters are correctly passed through to encode"""
    model = splade_bert_tiny_model

    with patch.object(model, "encode", autospec=True) as mock_encode:
        model.encode_query(
            "test",
            normalize_embeddings=True,
            batch_size=64,
            show_progress_bar=True,
            max_active_dims=128,
            chunk_size=10,
            custom_param="value",
        )

    _, kwargs = mock_encode.call_args
    assert kwargs["normalize_embeddings"] is True
    assert kwargs["batch_size"] == 64
    assert kwargs["show_progress_bar"] is True
    assert kwargs["max_active_dims"] == 128
    assert kwargs["chunk_size"] == 10
    assert kwargs["custom_param"] == "value"
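

# Integration check without mocking: encode_query/encode_document should produce
# exactly the same embeddings as encode(..., prompt_name="query"/"document").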
@pytest.mark.parametrize("inputs", ["test sentence", ["test sentence"]]) |
|
|
def test_encode_query_document_vs_encode(splade_bert_tiny_model: SparseEncoder, inputs: str | list[str]): |
|
|
"""Test the actual integration with encode vs encode_query/encode_document""" |
|
|
|
|
|
model = splade_bert_tiny_model |
|
|
model.prompts = {"query": "query: ", "document": "document: "} |
|
|
|
|
|
|
|
|
query_embeddings = model.encode_query(inputs) |
|
|
document_embeddings = model.encode_document(inputs) |
|
|
|
|
|
|
|
|
encode_query_embeddings = model.encode(inputs, prompt_name="query") |
|
|
encode_document_embeddings = model.encode(inputs, prompt_name="document") |
|
|
|
|
|
|
|
|
assert sparse_allclose(query_embeddings, encode_query_embeddings) |
|
|
assert sparse_allclose(document_embeddings, encode_document_embeddings) |
|
|
|
|
|
|
|
|
query_embeddings_without_prompt = model.encode(inputs) |
|
|
document_embeddings_without_prompt = model.encode(inputs) |
|
|
|
|
|
|
|
|
assert not sparse_allclose(query_embeddings_without_prompt, query_embeddings) |
|
|
assert not sparse_allclose(document_embeddings_without_prompt, document_embeddings) |
|
|
|
|
|
|
|
|


def test_default_prompt(splade_bert_tiny_model: SparseEncoder):
    """Test that the default prompt is used when no prompt is specified"""
    model = splade_bert_tiny_model
    model.prompts = {"query": "query: ", "document": "document: "}
    model.default_prompt_name = "query"

    query_embeddings = model.encode_query("test")
    assert query_embeddings.shape == (model.get_sentence_embedding_dimension(),)

    document_embeddings = model.encode_document("test")
    assert document_embeddings.shape == (model.get_sentence_embedding_dimension(),)

    default_embeddings = model.encode("test")
    assert default_embeddings.shape == (model.get_sentence_embedding_dimension(),)

    # With default_prompt_name="query", plain encode() behaves like encode_query()
    assert sparse_allclose(query_embeddings, default_embeddings)
    assert not sparse_allclose(document_embeddings, default_embeddings)

    # Removing the default prompt changes the output of plain encode()
    model.default_prompt_name = None
    default_embeddings_no_default = model.encode("test")
    assert not sparse_allclose(default_embeddings_no_default, default_embeddings)


def test_wrong_prompt(splade_bert_tiny_model: SparseEncoder):
    """Test that using an unknown prompt name raises an error"""
    model = splade_bert_tiny_model
    model.prompts = {"query": "query: ", "document": "document: "}

    for encode_method in [model.encode_query, model.encode_document, model.encode]:
        with pytest.raises(
            ValueError,
            match=re.escape(
                "Prompt name 'invalid_prompt' not found in the configured prompts dictionary with keys ['query', 'document']."
            ),
        ):
            encode_method("test", prompt_name="invalid_prompt")
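

# max_active_dims caps how many dimensions may stay non-zero per embedding. The
# CSR fixture is saved with max_active_dims=16 in its config (per the assertion
# below), the SPLADE fixture with no default; an explicit constructor argument
# should override either.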
def test_max_active_dims_set_init(splade_bert_tiny_model: SparseEncoder, csr_bert_tiny_model: SparseEncoder, tmp_path):
    splade_bert_tiny_model.save_pretrained(str(tmp_path / "splade_bert_tiny"))
    csr_bert_tiny_model.save_pretrained(str(tmp_path / "csr_bert_tiny"))

    # SPLADE: no default, constructor argument wins
    loaded_model = SparseEncoder(str(tmp_path / "splade_bert_tiny"))
    assert loaded_model.max_active_dims is None
    loaded_model = SparseEncoder(str(tmp_path / "splade_bert_tiny"), max_active_dims=13)
    assert loaded_model.max_active_dims == 13

    # CSR: default of 16 from the saved config, constructor argument wins
    loaded_model = SparseEncoder(str(tmp_path / "csr_bert_tiny"))
    assert loaded_model.max_active_dims == 16
    loaded_model = SparseEncoder(str(tmp_path / "csr_bert_tiny"), max_active_dims=13)
    assert loaded_model.max_active_dims == 13
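

# Architecture auto-detection: a checkpoint with a masked-language-modeling head
# should load as MLMTransformer + SpladePooling (SPLADE-style), while a plain
# embedding checkpoint should default to Transformer + Pooling + SparseAutoEncoder
# (CSR-style).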
def test_detect_mlm():
    model = SparseEncoder("distilbert/distilbert-base-uncased")

    assert isinstance(model[0], MLMTransformer)
    assert isinstance(model[1], SpladePooling)


def test_default_to_csr():
    model = SparseEncoder("prajjwal1/bert-tiny")

    assert isinstance(model[0], Transformer)
    assert isinstance(model[1], Pooling)
    assert isinstance(model[2], SparseAutoEncoder)
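

# sparsity() reports summary statistics for a batch of embeddings: "active_dims"
# (mean number of non-zero dimensions) and "sparsity_ratio" (fraction of zero
# dimensions, roughly 1 - active_dims / embedding_dim).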
def test_sparsity(splade_bert_tiny_model: SparseEncoder):
    model = splade_bert_tiny_model

    embeddings = model.encode_query(["What is the capital of France?", "Who has won the World Cup in 2016?"])
    sparsity = model.sparsity(embeddings)
    assert isinstance(sparsity, dict)
    assert "active_dims" in sparsity
    assert "sparsity_ratio" in sparsity
    assert 0 < sparsity["active_dims"] < 100
    assert 0.99 <= sparsity["sparsity_ratio"] < 1.0

    dense_sparsity = model.sparsity(embeddings.to_dense())
    assert dense_sparsity == sparsity, "Sparsity should be the same for dense and sparse tensors"

    # The batch statistic is the mean over the per-row statistics
    sparsity_one = model.sparsity(embeddings[0])
    sparsity_two = model.sparsity(embeddings[1])
    assert (sparsity_one["active_dims"] + sparsity_two["active_dims"]) / 2 == sparsity["active_dims"]


def test_splade_pooling_chunk_size(splade_bert_tiny_model: SparseEncoder):
    model = splade_bert_tiny_model

    assert model.splade_pooling_chunk_size is None

    # Setting the property should propagate to the SpladePooling module
    model.splade_pooling_chunk_size = 13
    assert model.splade_pooling_chunk_size == 13
    assert isinstance(model[1], SpladePooling)
    assert model[1].chunk_size == 13
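

# intersection() keeps only the dimensions active in both inputs; with a 2D second
# argument, the query is intersected with each document row. The result should
# therefore have fewer active dims than either operand, as asserted below.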
def test_intersection(splade_bert_tiny_model: SparseEncoder):
    model = splade_bert_tiny_model

    query = "Where can I deposit my money?"
    document = "I'm sitting by the river."
    query_embeddings = model.encode_query(query)
    document_embeddings = model.encode_document(document)
    query_sparsity = model.sparsity(query_embeddings)
    document_sparsity = model.sparsity(document_embeddings)

    intersection = model.intersection(query_embeddings, document_embeddings)
    assert isinstance(intersection, torch.Tensor)
    assert intersection.shape == (model.get_sentence_embedding_dimension(),)

    # Unrelated texts share few active dimensions, so the intersection is sparser
    # than either input
    intersection_sparsity = model.sparsity(intersection)
    assert (
        intersection_sparsity["active_dims"] < query_sparsity["active_dims"]
        and intersection_sparsity["active_dims"] < document_sparsity["active_dims"]
    )

    # One query against a batch of documents yields one intersection per document
    query = "Who has won the World Cup in 2016?"
    documents = ["The capital of France is Paris.", "Germany won the World Cup in 2014."]
    query_embeddings = model.encode_query(query)
    document_embeddings = model.encode_document(documents)

    intersection_batch = model.intersection(query_embeddings, document_embeddings)
    assert isinstance(intersection_batch, torch.Tensor)
    assert intersection_batch.shape == (len(documents), model.get_sentence_embedding_dimension())

    decoded_intersection_batch = model.decode(intersection_batch)
    assert len(decoded_intersection_batch) == len(documents)