| """
|
| Tests for Music Tokenizer Extension.
|
| """
|
|
|
| import pytest
|
| from unittest.mock import MagicMock, patch
|
|
|
| from TouchGrass.tokenizer.music_token_extension import MusicTokenizerExtension
|
|
|
|
|
class TestMusicTokenizerExtension:
    """Test suite for MusicTokenizerExtension.

    ``AutoTokenizer`` is patched inside the module under test for the whole
    of every test, so ``from_pretrained`` returns ``self.mock_tokenizer``
    instead of loading a real model. Tests configure that mock first, then
    build the extension with ``_build_extension``.
    """

    MODEL_NAME = "Qwen/Qwen3.5-3B-Instruct"
    AUTO_TOKENIZER_TARGET = "TouchGrass.tokenizer.music_token_extension.AutoTokenizer"
    BASE_VOCAB_SIZE = 32000
    # NOTE(review): used by tests that model an already-extended vocabulary,
    # but 32000 + len(special_tokens) is 32020, not 32021 -- confirm intent.
    EXTENDED_VOCAB_SIZE = 32021

    def setup_method(self):
        """Set up token fixtures and start patching ``AutoTokenizer``."""
        self.special_tokens = {
            "[GUITAR]": 32000,
            "[PIANO]": 32001,
            "[DRUMS]": 32002,
            "[VOCALS]": 32003,
            "[THEORY]": 32004,
            "[PRODUCTION]": 32005,
            "[FRUSTRATED]": 32006,
            "[CONFUSED]": 32007,
            "[EXCITED]": 32008,
            "[CONFIDENT]": 32009,
            "[EASY]": 32010,
            "[MEDIUM]": 32011,
            "[HARD]": 32012,
            "[TAB]": 32013,
            "[CHORD]": 32014,
            "[SCALE]": 32015,
            "[INTERVAL]": 32016,
            "[PROGRESSION]": 32017,
            "[SIMPLIFY]": 32018,
            "[ENCOURAGE]": 32019,
        }
        self.music_vocab_extensions = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]

        self.mock_tokenizer = MagicMock()
        self.mock_tokenizer.vocab_size = self.BASE_VOCAB_SIZE

        # Patch for the whole test rather than only during construction, so
        # any later AutoTokenizer use by the extension also sees the mock.
        # Started at the end of setup so nothing before it can fail and
        # leave the patch active without a matching teardown.
        self._auto_tokenizer_patcher = patch(self.AUTO_TOKENIZER_TARGET)
        self.mock_tokenizer_class = self._auto_tokenizer_patcher.start()
        self.mock_tokenizer_class.from_pretrained.return_value = self.mock_tokenizer

    def teardown_method(self):
        """Undo the ``AutoTokenizer`` patch started in ``setup_method``."""
        self._auto_tokenizer_patcher.stop()

    def _build_extension(self, special_tokens, music_vocab_extensions, vocab_size=None):
        """Construct a MusicTokenizerExtension against the patched tokenizer.

        Args:
            special_tokens: Passed through as ``special_tokens``.
            music_vocab_extensions: Passed through as ``music_vocab_extensions``.
            vocab_size: Value assigned to the mock's ``vocab_size`` before
                construction; ``None`` means ``BASE_VOCAB_SIZE``.

        Returns:
            The constructed extension.
        """
        self.mock_tokenizer.vocab_size = self.BASE_VOCAB_SIZE if vocab_size is None else vocab_size
        return MusicTokenizerExtension(
            self.MODEL_NAME,
            special_tokens=special_tokens,
            music_vocab_extensions=music_vocab_extensions,
        )

    def _track_added_tokens(self):
        """Make the mock tokenizer grow like a Hugging Face tokenizer.

        ``add_tokens`` and ``add_special_tokens`` record each token not seen
        before and return how many were new. ``len(mock)`` reports
        ``BASE_VOCAB_SIZE`` plus the recorded tokens. ``vocab_size`` is left
        untouched because Hugging Face tokenizers exclude added tokens from it.

        Returns:
            The list accumulating every distinct token added, in order.
        """
        added = []

        def add_tokens(new_tokens, *args, **kwargs):
            # Hugging Face accepts a single token as well as a list of them.
            if isinstance(new_tokens, str):
                new_tokens = [new_tokens]
            count = 0
            for token in new_tokens:
                if token not in added:
                    added.append(token)
                    count += 1
            return count

        def add_special_tokens(special_tokens_dict, *args, **kwargs):
            return add_tokens(special_tokens_dict.get("additional_special_tokens", []))

        self.mock_tokenizer.add_tokens.side_effect = add_tokens
        self.mock_tokenizer.add_special_tokens.side_effect = add_special_tokens
        self.mock_tokenizer.__len__.side_effect = lambda: self.BASE_VOCAB_SIZE + len(added)
        return added

    def test_tokenizer_initialization(self):
        """Test that tokenizer initializes correctly with special tokens."""
        ext = self._build_extension(self.special_tokens, self.music_vocab_extensions)

        assert ext.base_tokenizer is self.mock_tokenizer
        self.mock_tokenizer_class.from_pretrained.assert_called_once_with(self.MODEL_NAME)

    def test_special_tokens_added(self):
        """Test that special tokens are added to tokenizer."""
        self._build_extension(self.special_tokens, [])

        self.mock_tokenizer.add_special_tokens.assert_called_once_with(
            {"additional_special_tokens": list(self.special_tokens)}
        )

    def test_music_vocab_extensions_added(self):
        """Test that every music vocabulary extension is added.

        Tokens are collected across all ``add_tokens`` calls, so the check
        holds whether they are added in one batch or one per call.
        """
        added = self._track_added_tokens()

        self._build_extension({}, self.music_vocab_extensions)

        self.mock_tokenizer.add_tokens.assert_called()
        assert set(added) == set(self.music_vocab_extensions)

    def test_tokenizer_vocab_size_increased(self):
        """Test that the tokenizer grows by every special and music token.

        Growth is measured with ``len()``: a Hugging Face tokenizer's
        ``vocab_size`` excludes added tokens, and the mock's static
        ``vocab_size`` attribute would never change on its own.
        """
        added = self._track_added_tokens()
        expected_new_vocab_size = (
            self.BASE_VOCAB_SIZE + len(self.special_tokens) + len(self.music_vocab_extensions)
        )

        ext = self._build_extension(self.special_tokens, self.music_vocab_extensions)

        assert set(added) == set(self.special_tokens) | set(self.music_vocab_extensions)
        assert len(ext.base_tokenizer) == expected_new_vocab_size

    def test_encode_with_music_tokens(self):
        """Test encoding text with music tokens delegates to the base tokenizer."""
        self.mock_tokenizer.encode.return_value = [1, 2, 32000, 3, 4]
        ext = self._build_extension(self.special_tokens, [], vocab_size=self.EXTENDED_VOCAB_SIZE)

        result = ext.encode("Play a [GUITAR] chord")

        assert result == [1, 2, 32000, 3, 4]
        self.mock_tokenizer.encode.assert_called_once_with("Play a [GUITAR] chord")

    def test_decode_with_music_tokens(self):
        """Test decoding token IDs with music tokens delegates to the base tokenizer."""
        self.mock_tokenizer.decode.return_value = "Play a [GUITAR] chord"
        ext = self._build_extension(self.special_tokens, [], vocab_size=self.EXTENDED_VOCAB_SIZE)

        result = ext.decode([1, 2, 32000, 3, 4])

        assert result == "Play a [GUITAR] chord"
        self.mock_tokenizer.decode.assert_called_once_with([1, 2, 32000, 3, 4])

    def test_get_music_token_id(self):
        """Test retrieving token ID for a music token."""
        self.mock_tokenizer.convert_tokens_to_ids.return_value = 32000
        ext = self._build_extension(self.special_tokens, [], vocab_size=self.EXTENDED_VOCAB_SIZE)

        token_id = ext.get_music_token_id("[GUITAR]")

        assert token_id == 32000
        self.mock_tokenizer.convert_tokens_to_ids.assert_called_with("[GUITAR]")

    def test_has_music_token(self):
        """Test checking if a token is a music token."""
        ext = self._build_extension(self.special_tokens, [], vocab_size=self.EXTENDED_VOCAB_SIZE)

        assert ext.has_music_token("[GUITAR]") is True
        assert ext.has_music_token("[UNKNOWN]") is False

    def test_get_music_domain_tokens(self):
        """Test retrieving all domain tokens."""
        ext = self._build_extension(self.special_tokens, [], vocab_size=self.EXTENDED_VOCAB_SIZE)

        expected = ["[GUITAR]", "[PIANO]", "[DRUMS]", "[VOCALS]", "[THEORY]", "[PRODUCTION]"]
        assert ext.get_music_domain_tokens() == expected

    def test_get_emotion_tokens(self):
        """Test retrieving emotion tokens."""
        ext = self._build_extension(self.special_tokens, [], vocab_size=self.EXTENDED_VOCAB_SIZE)

        expected = ["[FRUSTRATED]", "[CONFUSED]", "[EXCITED]", "[CONFIDENT]"]
        assert ext.get_emotion_tokens() == expected

    def test_get_difficulty_tokens(self):
        """Test retrieving difficulty tokens."""
        ext = self._build_extension(self.special_tokens, [], vocab_size=self.EXTENDED_VOCAB_SIZE)

        expected = ["[EASY]", "[MEDIUM]", "[HARD]"]
        assert ext.get_difficulty_tokens() == expected

    def test_get_music_function_tokens(self):
        """Test retrieving music function tokens."""
        ext = self._build_extension(self.special_tokens, [], vocab_size=self.EXTENDED_VOCAB_SIZE)

        expected = ["[TAB]", "[CHORD]", "[SCALE]", "[INTERVAL]", "[PROGRESSION]"]
        assert ext.get_music_function_tokens() == expected

    def test_get_eq_tokens(self):
        """Test retrieving EQ (emotional intelligence) tokens."""
        ext = self._build_extension(self.special_tokens, [], vocab_size=self.EXTENDED_VOCAB_SIZE)

        expected = ["[FRUSTRATED]", "[CONFUSED]", "[EXCITED]", "[CONFIDENT]", "[SIMPLIFY]", "[ENCOURAGE]"]
        assert ext.get_eq_tokens() == expected

    def test_token_count_with_music_tokens(self):
        """Test that token count increases after adding music tokens.

        Uses ``len()`` rather than ``vocab_size`` for the same reason as
        ``test_tokenizer_vocab_size_increased``.
        """
        self._track_added_tokens()
        expected_vocab_size = (
            self.BASE_VOCAB_SIZE + len(self.special_tokens) + len(self.music_vocab_extensions)
        )

        ext = self._build_extension(self.special_tokens, self.music_vocab_extensions)

        assert len(ext.base_tokenizer) == expected_vocab_size
        assert len(ext.base_tokenizer) > self.BASE_VOCAB_SIZE
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns an exit code; propagate it so running this file
    # directly exits non-zero when any test fails (it was discarded before).
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|
|