| import pytest |
| import torch |
| from transformers import AutoConfig |
| from mentioned.model import SentenceEncoder, Detector, MentionDetectorCore |
|
|
| |
|
|
|
|
@pytest.fixture
def model_dims():
    """Canonical tensor dimensions shared by the model tests."""
    dims = dict(input_dim=128, hidden_dim=64, seq_len=10, batch_size=2)
    return dims
|
|
|
|
@pytest.fixture
def mock_embeddings(model_dims):
    """Random (batch, seq_len, input_dim) embeddings with grad tracking on."""
    shape = (
        model_dims["batch_size"],
        model_dims["seq_len"],
        model_dims["input_dim"],
    )
    # requires_grad=True so downstream tests can assert gradient flow.
    return torch.randn(*shape, requires_grad=True)
|
|
|
|
@pytest.mark.parametrize("subwords_per_word", [1, 2, 4])
def test_variable_subword_pooling(subwords_per_word):
    """Pooling subwords into words must yield one NaN-free vector per word."""
    encoder = SentenceEncoder(model_name="sshleifer/tiny-distilroberta-base")
    batch, hidden = 1, encoder.dim
    total_subwords = 8
    num_words = total_subwords // subwords_per_word

    input_ids = torch.randint(0, 100, (batch, total_subwords))
    attention_mask = torch.ones(batch, total_subwords)

    # Map each consecutive run of `subwords_per_word` subwords to one word id.
    word_ids = (
        torch.arange(num_words)
        .repeat_interleave(subwords_per_word)
        .unsqueeze(0)
    )

    word_embs = encoder(input_ids, attention_mask, word_ids)

    # One embedding per word, and no NaNs leaking out of the pooling op.
    assert word_embs.shape == (batch, num_words, hidden)
    assert not torch.isnan(word_embs).any()
|
|
|
|
def test_detector_projections(model_dims, mock_embeddings):
    """Verify Detector handles both 3D (starts) and 4D (spans) tensors."""
    detector = Detector(model_dims["input_dim"], model_dims["hidden_dim"])
    B, N = model_dims["batch_size"], model_dims["seq_len"]

    # 3D path: per-token logits, one scalar per position.
    start_out = detector(mock_embeddings)
    assert start_out.shape == (B, N, 1)

    # 4D path: per-pair logits. Derive the shape from the fixture instead of
    # the previously hard-coded (2, 10, 10, ...) so the test stays in sync
    # if model_dims ever changes.
    pair_input = torch.randn(B, N, N, model_dims["input_dim"])
    pair_out = detector(pair_input)
    assert pair_out.shape == (B, N, N, 1)
|
|
|
|
def test_mention_detector_core_logic(model_dims, mock_embeddings):
    """Verify the N x N pair materialization and concatenation."""
    batch = model_dims["batch_size"]
    seq = model_dims["seq_len"]
    dim = model_dims["input_dim"]

    # The end detector consumes concatenated (start, end) pairs -> 2 * dim.
    start_head = Detector(dim, 32)
    end_head = Detector(dim * 2, 32)
    model = MentionDetectorCore(start_head, end_head)

    start_logits, end_logits = model(mock_embeddings)

    # Start head scores each position; end head scores every position pair.
    assert start_logits.shape == (batch, seq)
    assert end_logits.shape == (batch, seq, seq)

    # Gradients must flow back to the input embeddings through both heads.
    (start_logits.sum() + end_logits.sum()).backward()
    assert mock_embeddings.grad is not None
|
|