| """
|
| Tests for Tab & Chord Generation Module.
|
| """
|
|
|
| import pytest
|
| import torch
|
|
|
| from TouchGrass.models.tab_chord_module import TabChordModule
|
|
|
|
|
class TestTabChordModule:
    """Test suite for TabChordModule.

    Each test gets a freshly constructed module from ``setup_method``. The
    global torch RNG is seeded first, so the module's initial weights and the
    random inputs drawn in each test are reproducible across runs.
    """

    def setup_method(self):
        """Set up test fixtures (deterministic module and dimensions)."""
        # Seed before building the module so its initial weights, and every
        # random input drawn later in the test, are identical on every run.
        torch.manual_seed(0)
        self.d_model = 768
        self.batch_size = 4
        self.num_strings = 6
        self.num_frets = 24
        self.module = TabChordModule(
            d_model=self.d_model,
            num_strings=self.num_strings,
            num_frets=self.num_frets,
        )

    def _make_inputs(self, seq_len, batch_size=None, requires_grad=False):
        """Return random ``(hidden_states, string_indices, fret_indices)``.

        Args:
            seq_len: Sequence length of every returned tensor.
            batch_size: Batch dimension; defaults to ``self.batch_size``.
            requires_grad: Whether ``hidden_states`` should track gradients.

        ``string_indices`` are drawn from ``[0, num_strings)`` and
        ``fret_indices`` from ``[0, num_frets + 2)``, which is the full index
        range of the fret embedding table (see ``test_module_initialization``).
        """
        if batch_size is None:
            batch_size = self.batch_size
        hidden_states = torch.randn(
            batch_size, seq_len, self.d_model, requires_grad=requires_grad
        )
        string_indices = torch.randint(0, self.num_strings, (batch_size, seq_len))
        fret_indices = torch.randint(0, self.num_frets + 2, (batch_size, seq_len))
        return hidden_states, string_indices, fret_indices

    def test_module_initialization(self):
        """Test that module initializes correctly."""
        assert self.module.string_embed.num_embeddings == self.num_strings
        # The fret table has two extra slots beyond the regular frets.
        assert self.module.fret_embed.num_embeddings == self.num_frets + 2
        assert isinstance(self.module.tab_validator, torch.nn.Sequential)
        assert isinstance(self.module.difficulty_head, torch.nn.Linear)
        assert self.module.difficulty_head.out_features == 3

    def test_forward_pass(self):
        """Test forward pass with dummy inputs."""
        seq_len = 10
        output = self.module(*self._make_inputs(seq_len))

        assert "tab_validator" in output
        assert "difficulty" in output
        assert output["tab_validator"].shape == (self.batch_size, seq_len, 1)
        assert output["difficulty"].shape == (self.batch_size, seq_len, 3)

    def test_tab_validator_output_range(self):
        """Test that tab validator outputs are in [0, 1] range."""
        seq_len = 5
        validator_output = self.module(*self._make_inputs(seq_len))["tab_validator"]

        # NaN fails both comparisons, so this also rejects NaN outputs.
        assert torch.all(validator_output >= 0)
        assert torch.all(validator_output <= 1)

    def test_difficulty_head_output(self):
        """Test difficulty head produces finite float logits for 3 classes."""
        seq_len = 5
        difficulty_logits = self.module(*self._make_inputs(seq_len))["difficulty"]

        assert difficulty_logits.shape == (self.batch_size, seq_len, 3)
        assert difficulty_logits.is_floating_point()
        assert torch.isfinite(difficulty_logits).all()

    def test_embedding_dimensions(self):
        """Test embedding layer dimensions."""
        assert self.module.string_embed.embedding_dim == 64
        assert self.module.fret_embed.embedding_dim == 64

    def test_forward_with_different_seq_lengths(self):
        """Test forward pass with varying sequence lengths."""
        for seq_len in [1, 5, 20, 50]:
            output = self.module(*self._make_inputs(seq_len))
            assert output["tab_validator"].shape[1] == seq_len, seq_len
            assert output["difficulty"].shape[1] == seq_len, seq_len

    def test_gradient_flow(self):
        """Test that gradients flow through the module."""
        seq_len = 5
        hidden_states, string_indices, fret_indices = self._make_inputs(
            seq_len, requires_grad=True
        )

        output = self.module(hidden_states, string_indices, fret_indices)
        loss = output["tab_validator"].sum() + output["difficulty"].sum()
        loss.backward()

        assert hidden_states.grad is not None
        assert self.module.string_embed.weight.grad is not None
        assert self.module.fret_embed.weight.grad is not None

    def test_different_batch_sizes(self):
        """Test forward pass with different batch sizes."""
        seq_len = 10
        for batch_size in [1, 2, 8, 16]:
            output = self.module(*self._make_inputs(seq_len, batch_size=batch_size))
            assert output["tab_validator"].shape[0] == batch_size, batch_size
            assert output["difficulty"].shape[0] == batch_size, batch_size

    def test_special_fret_tokens(self):
        """Test handling of special/boundary fret tokens (e.g., mute, open).

        The fret embedding has ``num_frets + 2`` entries. Which indices encode
        mute/open is decided by TabChordModule (TODO confirm). This test covers
        both ends of the valid index range, including the highest index
        ``num_frets + 1``, in every batch row.
        """
        boundary_frets = torch.tensor([0, 1, self.num_frets, self.num_frets + 1])
        seq_len = boundary_frets.numel()
        hidden_states = torch.randn(self.batch_size, seq_len, self.d_model)
        string_indices = torch.randint(0, self.num_strings, (self.batch_size, seq_len))
        # Repeat one pattern over the batch so the test no longer depends on
        # batch_size being exactly 4 (the original hard-coded four rows).
        fret_indices = boundary_frets.repeat(self.batch_size, 1)

        output = self.module(hidden_states, string_indices, fret_indices)
        assert output["tab_validator"].shape == (self.batch_size, seq_len, 1)
        assert output["difficulty"].shape == (self.batch_size, seq_len, 3)

    def test_tab_validator_confidence_scores(self):
        """Test that validator produces confidence scores in [0, 1]."""
        seq_len = 1
        confidence = self.module(*self._make_inputs(seq_len))["tab_validator"]

        assert torch.all((confidence >= 0) & (confidence <= 1))
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly fails
    # (non-zero status) when any test fails, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
|