| """
|
| Tests for Dataset Loader.
|
| """
|
|
|
| import pytest
|
| import torch
|
| from unittest.mock import MagicMock, patch
|
|
|
| from TouchGrass.data.dataset_loader import TouchGrassDataset
|
|
|
|
|
class TestTouchGrassDataset:
    """Test suite for TouchGrassDataset.

    Uses a MagicMock tokenizer whose ``encode`` always returns a fixed
    5-token sequence, so every test controls tokenization deterministically.
    """

    def setup_method(self):
        """Create a fresh mock tokenizer and default max_length before each test."""
        self.tokenizer = MagicMock()
        # Every encode() call returns the same 5 non-pad token ids.
        self.tokenizer.encode.return_value = [1, 2, 3, 4, 5]
        self.tokenizer.pad_token_id = 0
        self.max_length = 512

    def test_dataset_initialization(self):
        """Test dataset initialization with samples."""
        samples = [
            {"text": "Sample 1"},
            {"text": "Sample 2"},
            {"text": "Sample 3"},
        ]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        assert len(dataset) == 3

    def test_dataset_length(self):
        """Test dataset __len__ method."""
        samples = [{"text": f"Sample {i}"} for i in range(100)]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        assert len(dataset) == 100

    def test_getitem_returns_correct_keys(self):
        """Test that __getitem__ returns the expected keys."""
        samples = [{"text": "Test sample"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        assert "input_ids" in item
        assert "attention_mask" in item
        assert "labels" in item

    def test_tokenization(self):
        """Test that text is tokenized during dataset construction.

        NOTE(review): asserting right after construction assumes the dataset
        tokenizes eagerly in __init__ (consistent with
        test_tokenizer_not_called_multiple_times below).
        """
        samples = [{"text": "Hello world"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)

        self.tokenizer.encode.assert_called_with("Hello world")

    def test_padding_to_max_length(self):
        """Test that sequences are padded to max_length."""
        samples = [{"text": "Short"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        assert len(item["input_ids"]) == self.max_length
        assert len(item["attention_mask"]) == self.max_length
        assert len(item["labels"]) == self.max_length

    def test_attention_mask_correct(self):
        """Test that attention mask is 1 for real tokens, 0 for padding."""
        samples = [{"text": "Test"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        # The mocked encoding is a plain Python list, so count non-pad tokens
        # with a generator expression. (The previous version compared the
        # whole list to an int — producing a bool — and then called .sum()
        # on it, which raised AttributeError instead of testing anything.)
        real_token_count = sum(
            1
            for token in self.tokenizer.encode.return_value
            if token != self.tokenizer.pad_token_id
        )
        attention_sum = item["attention_mask"].sum()
        assert attention_sum == real_token_count

    def test_labels_shifted(self):
        """Test that labels align with input_ids for language modeling.

        NOTE(review): the exact label convention (shift-in-dataset vs. the
        common shift-inside-the-model) is not observable from this test
        module, so assert the structural invariant instead of a tautology:
        labels must have exactly the same shape as input_ids. (The previous
        version ended in ``... or True``, which always passed.)
        """
        samples = [{"text": "Test sample"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        assert item["labels"].shape == item["input_ids"].shape

    def test_truncation(self):
        """Test that sequences longer than max_length are truncated."""
        long_text = "word " * 200
        samples = [{"text": long_text}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        assert len(item["input_ids"]) <= self.max_length

    def test_multiple_samples(self):
        """Test accessing multiple samples."""
        samples = [{"text": f"Sample {i}"} for i in range(10)]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)

        for i in range(10):
            item = dataset[i]
            assert "input_ids" in item
            assert "attention_mask" in item
            assert "labels" in item

    def test_empty_dataset(self):
        """Test dataset with an empty samples list."""
        samples = []
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        assert len(dataset) == 0

    def test_special_tokens_handling(self):
        """Test that bracketed special tokens pass through to the tokenizer verbatim."""
        samples = [{"text": "Play [GUITAR] chord"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        self.tokenizer.encode.assert_called_with("Play [GUITAR] chord")

    def test_tensor_types(self):
        """Test that returned values are torch.Tensor instances."""
        samples = [{"text": "Test"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        assert isinstance(item["input_ids"], torch.Tensor)
        assert isinstance(item["attention_mask"], torch.Tensor)
        assert isinstance(item["labels"], torch.Tensor)

    def test_dtype(self):
        """Test that all returned tensors use torch.long."""
        samples = [{"text": "Test"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        assert item["input_ids"].dtype == torch.long
        assert item["attention_mask"].dtype == torch.long
        assert item["labels"].dtype == torch.long

    def test_with_music_tokens(self):
        """Test handling of music-specific tokens."""
        samples = [{"text": "Use [TAB] for guitar"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        assert item["input_ids"].shape[0] == self.max_length

    def test_batch_consistency(self):
        """Test that repeated access to the same index returns identical tensors."""
        samples = [{"text": "Consistent"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)

        item1 = dataset[0]
        item2 = dataset[0]

        assert torch.equal(item1["input_ids"], item2["input_ids"])
        assert torch.equal(item1["attention_mask"], item2["attention_mask"])
        assert torch.equal(item1["labels"], item2["labels"])

    def test_different_max_lengths(self):
        """Test that padding respects a range of max_length values."""
        for max_len in [128, 256, 512, 1024]:
            samples = [{"text": "Test"}]
            dataset = TouchGrassDataset(samples, self.tokenizer, max_len)
            item = dataset[0]
            assert len(item["input_ids"]) == max_len

    def test_tokenizer_not_called_multiple_times(self):
        """Test that the tokenizer runs exactly once per sample at construction."""
        samples = [{"text": "Test 1"}, {"text": "Test 2"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)

        # setup_method creates a fresh mock per test, so the count is clean.
        assert self.tokenizer.encode.call_count == 2

    def test_labels_ignore_padding(self):
        """Test label tensor shape for a short (heavily padded) sample.

        NOTE(review): the docstring historically claimed padded labels are
        set to -100, but only the shape is asserted here — confirm the -100
        masking behavior against the dataset implementation before
        strengthening this test.
        """
        samples = [{"text": "Short"}]
        dataset = TouchGrassDataset(samples, self.tokenizer, self.max_length)
        item = dataset[0]

        labels = item["labels"]

        assert labels.shape[0] == self.max_length

    def test_with_actual_tokenizer_mock(self):
        """Test with a tokenizer mock whose output length depends on the input."""

        def mock_encode(text, **kwargs):
            # One token per word, capped at 10 — mimics a real tokenizer's
            # input-dependent output length.
            tokens = [1] * min(len(text.split()), 10)
            return tokens

        tokenizer = MagicMock()
        tokenizer.encode.side_effect = mock_encode
        tokenizer.pad_token_id = 0

        samples = [{"text": "This is a longer text sample with more words"}]
        dataset = TouchGrassDataset(samples, tokenizer, self.max_length)
        item = dataset[0]

        assert item["input_ids"].shape[0] == self.max_length
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests verbosely when executed directly.
    pytest.main(["-v", __file__])
|
|
|