| | """
|
| | Tests for Music Tokenizer Extension.
|
| | """
|
| |
|
| | import pytest
|
| | from unittest.mock import MagicMock, patch
|
| |
|
| | from TouchGrass.tokenizer.music_token_extension import MusicTokenizerExtension
|
| |
|
| |
|
class TestMusicTokenizerExtension:
    """Test suite for MusicTokenizerExtension.

    Every test patches ``AutoTokenizer`` inside the module under test so no
    real tokenizer is loaded from disk or the network; a ``MagicMock`` stands
    in for the base tokenizer.
    """

    # Patch target: AutoTokenizer as imported by the module under test.
    _AUTO_TOKENIZER = "TouchGrass.tokenizer.music_token_extension.AutoTokenizer"
    # Base model name handed to every extension instance.
    _MODEL = "Qwen/Qwen3.5-3B-Instruct"

    def setup_method(self):
        """Set up test fixtures."""
        self.special_tokens = {
            "[GUITAR]": 32000,
            "[PIANO]": 32001,
            "[DRUMS]": 32002,
            "[VOCALS]": 32003,
            "[THEORY]": 32004,
            "[PRODUCTION]": 32005,
            "[FRUSTRATED]": 32006,
            "[CONFUSED]": 32007,
            "[EXCITED]": 32008,
            "[CONFIDENT]": 32009,
            "[EASY]": 32010,
            "[MEDIUM]": 32011,
            "[HARD]": 32012,
            "[TAB]": 32013,
            "[CHORD]": 32014,
            "[SCALE]": 32015,
            "[INTERVAL]": 32016,
            "[PROGRESSION]": 32017,
            "[SIMPLIFY]": 32018,
            "[ENCOURAGE]": 32019,
        }
        self.music_vocab_extensions = [
            "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B",
        ]

    def _make_extension(self, vocab_size=32000, special_tokens=None,
                        music_vocab_extensions=None, **mock_config):
        """Build a MusicTokenizerExtension around a mocked AutoTokenizer.

        Factors out the mock setup that every test previously repeated
        verbatim.

        Args:
            vocab_size: value exposed as the mock tokenizer's ``vocab_size``.
            special_tokens: special-token map; defaults to the full fixture.
            music_vocab_extensions: extra vocab entries; defaults to ``[]``.
            **mock_config: extra mock configuration applied before the
                extension is constructed, using dotted-name keys as accepted
                by the ``Mock`` constructor, e.g.
                ``**{"encode.return_value": [...]}``.

        Returns:
            Tuple of ``(ext, mock_tokenizer, mock_tokenizer_class)`` — the
            extension under test, the mocked base tokenizer, and the patched
            ``AutoTokenizer`` class object.
        """
        if special_tokens is None:
            special_tokens = self.special_tokens
        if music_vocab_extensions is None:
            music_vocab_extensions = []
        with patch(self._AUTO_TOKENIZER) as mock_tokenizer_class:
            # Dotted keys (e.g. "encode.return_value") configure child mocks
            # before MusicTokenizerExtension ever touches the tokenizer.
            mock_tokenizer = MagicMock(**mock_config)
            mock_tokenizer.vocab_size = vocab_size
            mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
            ext = MusicTokenizerExtension(
                self._MODEL,
                special_tokens=special_tokens,
                music_vocab_extensions=music_vocab_extensions,
            )
        return ext, mock_tokenizer, mock_tokenizer_class

    def test_tokenizer_initialization(self):
        """Test that tokenizer initializes correctly with special tokens."""
        ext, mock_tokenizer, mock_class = self._make_extension(
            music_vocab_extensions=self.music_vocab_extensions
        )
        assert ext.base_tokenizer == mock_tokenizer
        mock_class.from_pretrained.assert_called_once_with(self._MODEL)

    def test_special_tokens_added(self):
        """Test that special tokens are added to tokenizer."""
        _, mock_tokenizer, _ = self._make_extension()
        mock_tokenizer.add_special_tokens.assert_called_once_with(
            {"additional_special_tokens": list(self.special_tokens.keys())}
        )

    def test_music_vocab_extensions_added(self):
        """Test that music vocabulary extensions are added."""
        _, mock_tokenizer, _ = self._make_extension(
            special_tokens={},
            music_vocab_extensions=self.music_vocab_extensions,
        )
        assert mock_tokenizer.add_tokens.called
        added_tokens = mock_tokenizer.add_tokens.call_args[0][0]
        assert set(added_tokens) == set(self.music_vocab_extensions)

    def test_tokenizer_vocab_size_increased(self):
        """Test that vocab size is correctly increased after adding tokens."""
        expected = (
            32000 + len(self.special_tokens) + len(self.music_vocab_extensions)
        )
        ext, _, _ = self._make_extension(
            music_vocab_extensions=self.music_vocab_extensions
        )
        # NOTE(review): this assumes the extension bumps
        # base_tokenizer.vocab_size in place after adding tokens (a plain
        # MagicMock attribute would otherwise stay 32000) — confirm against
        # the implementation.
        assert ext.base_tokenizer.vocab_size == expected

    def test_encode_with_music_tokens(self):
        """Test encoding text with music tokens."""
        ext, mock_tokenizer, _ = self._make_extension(
            vocab_size=32021, **{"encode.return_value": [1, 2, 32000, 3, 4]}
        )
        result = ext.encode("Play a [GUITAR] chord")
        assert result == [1, 2, 32000, 3, 4]
        mock_tokenizer.encode.assert_called_once_with("Play a [GUITAR] chord")

    def test_decode_with_music_tokens(self):
        """Test decoding token IDs with music tokens."""
        ext, mock_tokenizer, _ = self._make_extension(
            vocab_size=32021,
            **{"decode.return_value": "Play a [GUITAR] chord"},
        )
        result = ext.decode([1, 2, 32000, 3, 4])
        assert result == "Play a [GUITAR] chord"
        mock_tokenizer.decode.assert_called_once_with([1, 2, 32000, 3, 4])

    def test_get_music_token_id(self):
        """Test retrieving token ID for a music token."""
        ext, mock_tokenizer, _ = self._make_extension(
            vocab_size=32021,
            **{"convert_tokens_to_ids.return_value": 32000},
        )
        assert ext.get_music_token_id("[GUITAR]") == 32000
        mock_tokenizer.convert_tokens_to_ids.assert_called_with("[GUITAR]")

    def test_has_music_token(self):
        """Test checking if a token is a music token."""
        ext, _, _ = self._make_extension(vocab_size=32021)
        assert ext.has_music_token("[GUITAR]") is True
        assert ext.has_music_token("[UNKNOWN]") is False

    def test_get_music_domain_tokens(self):
        """Test retrieving all domain tokens."""
        ext, _, _ = self._make_extension(vocab_size=32021)
        expected = [
            "[GUITAR]", "[PIANO]", "[DRUMS]",
            "[VOCALS]", "[THEORY]", "[PRODUCTION]",
        ]
        assert ext.get_music_domain_tokens() == expected

    def test_get_emotion_tokens(self):
        """Test retrieving emotion tokens."""
        ext, _, _ = self._make_extension(vocab_size=32021)
        expected = ["[FRUSTRATED]", "[CONFUSED]", "[EXCITED]", "[CONFIDENT]"]
        assert ext.get_emotion_tokens() == expected

    def test_get_difficulty_tokens(self):
        """Test retrieving difficulty tokens."""
        ext, _, _ = self._make_extension(vocab_size=32021)
        assert ext.get_difficulty_tokens() == ["[EASY]", "[MEDIUM]", "[HARD]"]

    def test_get_music_function_tokens(self):
        """Test retrieving music function tokens."""
        ext, _, _ = self._make_extension(vocab_size=32021)
        expected = [
            "[TAB]", "[CHORD]", "[SCALE]", "[INTERVAL]", "[PROGRESSION]",
        ]
        assert ext.get_music_function_tokens() == expected

    def test_get_eq_tokens(self):
        """Test retrieving EQ (emotional intelligence) tokens."""
        ext, _, _ = self._make_extension(vocab_size=32021)
        expected = [
            "[FRUSTRATED]", "[CONFUSED]", "[EXCITED]", "[CONFIDENT]",
            "[SIMPLIFY]", "[ENCOURAGE]",
        ]
        assert ext.get_eq_tokens() == expected

    def test_token_count_with_music_tokens(self):
        """Test that token count increases after adding music tokens."""
        ext, _, _ = self._make_extension(
            music_vocab_extensions=self.music_vocab_extensions
        )
        expected = (
            32000 + len(self.special_tokens) + len(self.music_vocab_extensions)
        )
        assert ext.base_tokenizer.vocab_size == expected
        assert ext.base_tokenizer.vocab_size > 32000
|
| |
|
| |
|
| | if __name__ == "__main__":
|
| | pytest.main([__file__, "-v"])
|
| |
|