"""

Unit tests for TokenizerService

"""
import pytest
from unittest.mock import Mock, patch, PropertyMock

from app.services.tokenizer_service import TokenizerService

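# The tests below depend on two pytest fixtures, `mock_tokenizer` and `app`,
# which are assumed to be provided by a conftest.py that is not shown here.
# A minimal sketch of what they might look like (the fixture bodies and the
# Flask app factory `create_app` are assumptions, not taken from this file):
#
#     from unittest.mock import MagicMock
#
#     @pytest.fixture
#     def mock_tokenizer():
#         tok = MagicMock()
#         tok.__class__.__name__ = 'MockTokenizer'
#         tok.vocab_size = 50257
#         tok.pad_token = '<pad>'
#         tok.eos_token = '</s>'
#         return tok
#
#     @pytest.fixture
#     def app():
#         from app import create_app  # hypothetical application factory
#         flask_app = create_app()
#         flask_app.config['TESTING'] = True
#         return flask_app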

class TestTokenizerService:
    """Test cases for TokenizerService."""
    
    def setup_method(self):
        """Set up test fixtures."""
        self.service = TokenizerService()
    
    def test_is_predefined_model(self):
        """Test predefined model checking."""
        # Test with existing model
        assert self.service.is_predefined_model('gpt2') is True
        
        # Test with non-existing model
        assert self.service.is_predefined_model('nonexistent-model') is False
        
        # Test with empty string
        assert self.service.is_predefined_model('') is False
    
    def test_get_tokenizer_info_basic(self, mock_tokenizer):
        """Test basic tokenizer info extraction."""
        info = self.service.get_tokenizer_info(mock_tokenizer)
        
        assert 'vocab_size' in info
        assert 'tokenizer_type' in info
        assert 'special_tokens' in info
        assert info['vocab_size'] == 50257
        assert info['tokenizer_type'] == 'MockTokenizer'
        
        # Check special tokens
        special_tokens = info['special_tokens']
        assert 'pad_token' in special_tokens
        assert 'eos_token' in special_tokens
        assert special_tokens['pad_token'] == '<pad>'
        assert special_tokens['eos_token'] == '</s>'
    
    def test_get_tokenizer_info_with_max_length(self, mock_tokenizer):
        """Test tokenizer info with model_max_length."""
        mock_tokenizer.model_max_length = 2048
        
        info = self.service.get_tokenizer_info(mock_tokenizer)
        
        assert 'model_max_length' in info
        assert info['model_max_length'] == 2048
    
    def test_get_tokenizer_info_error_handling(self):
        """Test error handling in tokenizer info extraction."""
        # Create a mock whose vocab_size raises when accessed. A PropertyMock
        # must be attached to the mock's type, not the instance, for the
        # exception to be raised on attribute access.
        broken_tokenizer = Mock()
        broken_tokenizer.__class__.__name__ = 'BrokenTokenizer'
        type(broken_tokenizer).vocab_size = PropertyMock(side_effect=Exception("Test error"))
        
        info = self.service.get_tokenizer_info(broken_tokenizer)
        
        assert 'error' in info
        assert 'Test error' in info['error']
    
    @patch('app.services.tokenizer_service.AutoTokenizer')
    def test_load_predefined_tokenizer_success(self, mock_auto_tokenizer, mock_tokenizer):
        """Test successful loading of predefined tokenizer."""
        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
        
        tokenizer, info, error = self.service.load_tokenizer('gpt2')
        
        assert tokenizer is not None
        assert error is None
        assert isinstance(info, dict)
        mock_auto_tokenizer.from_pretrained.assert_called_once()
    
    @patch('app.services.tokenizer_service.AutoTokenizer')
    def test_load_tokenizer_failure(self, mock_auto_tokenizer):
        """Test tokenizer loading failure."""
        mock_auto_tokenizer.from_pretrained.side_effect = Exception("Failed to load")
        
        tokenizer, info, error = self.service.load_tokenizer('gpt2')
        
        assert tokenizer is None
        assert error is not None
        assert "Failed to load" in error
    
    def test_load_nonexistent_predefined_model(self):
        """Test loading non-existent predefined model."""
        tokenizer, info, error = self.service.load_tokenizer('nonexistent-model')
        
        assert tokenizer is None
        assert error is not None
        assert "not found" in error.lower()
    
    @patch('app.services.tokenizer_service.AutoTokenizer')
    @patch('time.time')
    def test_custom_tokenizer_caching(self, mock_time, mock_auto_tokenizer, mock_tokenizer, app):
        """Test custom tokenizer caching behavior."""
        with app.app_context():
            mock_time.return_value = 1000.0
            mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
            
            # First load
            tokenizer1, info1, error1 = self.service.load_tokenizer('custom/model')
            
            # Second load (should use cache)
            mock_time.return_value = 1500.0  # Still within cache time
            tokenizer2, info2, error2 = self.service.load_tokenizer('custom/model')
            
            # Should only call from_pretrained once
            assert mock_auto_tokenizer.from_pretrained.call_count == 1
            assert tokenizer1 is tokenizer2
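            # If the cache key is the model name (an assumption about the
            # service internals), the cache hit could also be asserted as:
            #     assert 'custom/model' in self.service.custom_tokenizers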
    
    @patch('app.services.tokenizer_service.AutoTokenizer')
    @patch('time.time')
    def test_custom_tokenizer_cache_expiration(self, mock_time, mock_auto_tokenizer, mock_tokenizer, app):
        """Test custom tokenizer cache expiration."""
        with app.app_context():
            mock_time.return_value = 1000.0
            mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
            
            # First load
            self.service.load_tokenizer('custom/model')
            
            # Second load after cache expiration
            mock_time.return_value = 5000.0  # Beyond cache expiration
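            # The jump from t=1000 to t=5000 assumes the service's cache TTL
            # is shorter than 4000 seconds; the exact TTL is internal to
            # TokenizerService and is not asserted here.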
            self.service.load_tokenizer('custom/model')
            
            # Should call from_pretrained twice
            assert mock_auto_tokenizer.from_pretrained.call_count == 2
    
    def test_tokenizer_models_constant(self):
        """Test that TOKENIZER_MODELS contains expected models."""
        models = self.service.TOKENIZER_MODELS
        
        assert isinstance(models, dict)
        assert len(models) > 0
        
        # Check that each model has required fields
        for model_id, model_info in models.items():
            assert isinstance(model_id, str)
            assert isinstance(model_info, dict)
            assert 'name' in model_info
            assert 'alias' in model_info
            assert isinstance(model_info['name'], str)
            assert isinstance(model_info['alias'], str)
    
    def test_cache_initialization(self):
        """Test that caches are properly initialized."""
        service = TokenizerService()
        
        assert hasattr(service, 'tokenizers')
        assert hasattr(service, 'custom_tokenizers')
        assert hasattr(service, 'tokenizer_info_cache')
        
        assert isinstance(service.tokenizers, dict)
        assert isinstance(service.custom_tokenizers, dict)
        assert isinstance(service.tokenizer_info_cache, dict)
    
    def test_special_tokens_filtering(self, mock_tokenizer):
        """Test that only valid special tokens are included."""
        # Add some None and empty special tokens
        mock_tokenizer.pad_token = '<pad>'
        mock_tokenizer.eos_token = '</s>'
        mock_tokenizer.bos_token = None
        mock_tokenizer.sep_token = ''
        mock_tokenizer.cls_token = '   '  # Whitespace only
        mock_tokenizer.unk_token = '<unk>'
        mock_tokenizer.mask_token = '<mask>'
        
        info = self.service.get_tokenizer_info(mock_tokenizer)
        special_tokens = info['special_tokens']
        
        # Should only include non-None, non-empty tokens
        assert 'pad_token' in special_tokens
        assert 'eos_token' in special_tokens
        assert 'unk_token' in special_tokens
        assert 'mask_token' in special_tokens
        
        # Should not include None or empty tokens
        assert 'bos_token' not in special_tokens
        assert 'sep_token' not in special_tokens
        assert 'cls_token' not in special_tokens
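

# A typical invocation for just this module (the file path is an assumption
# about the repository layout) is:
#     pytest tests/unit/test_tokenizer_service.py -v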