"""
Unit tests for TokenizerService
"""
import pytest
from unittest.mock import Mock, patch, MagicMock
from app.services.tokenizer_service import TokenizerService
import time
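# NOTE: these tests rely on two pytest fixtures, `mock_tokenizer` and `app`,
# which are expected to be provided by a conftest.py that is not part of this
# file. The commented-out sketch below only illustrates what such fixtures
# might look like; the exact mock attributes and the `create_app` factory name
# are assumptions, so keep the real definitions in conftest.py.
#
#     import pytest
#     from unittest.mock import MagicMock
#     from app import create_app
#
#     @pytest.fixture
#     def mock_tokenizer():
#         tokenizer = MagicMock()
#         tokenizer.__class__.__name__ = 'MockTokenizer'
#         tokenizer.vocab_size = 50257
#         tokenizer.pad_token = '<pad>'
#         tokenizer.eos_token = '</s>'
#         return tokenizer
#
#     @pytest.fixture
#     def app():
#         app = create_app()
#         app.config['TESTING'] = True
#         return app
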
class TestTokenizerService:
    """Test cases for TokenizerService."""

    def setup_method(self):
        """Set up test fixtures."""
        self.service = TokenizerService()

    def test_is_predefined_model(self):
        """Test predefined model checking."""
        # Test with existing model
        assert self.service.is_predefined_model('gpt2') is True
        # Test with non-existing model
        assert self.service.is_predefined_model('nonexistent-model') is False
        # Test with empty string
        assert self.service.is_predefined_model('') is False
    def test_get_tokenizer_info_basic(self, mock_tokenizer):
        """Test basic tokenizer info extraction."""
        info = self.service.get_tokenizer_info(mock_tokenizer)

        assert 'vocab_size' in info
        assert 'tokenizer_type' in info
        assert 'special_tokens' in info
        assert info['vocab_size'] == 50257
        assert info['tokenizer_type'] == 'MockTokenizer'

        # Check special tokens
        special_tokens = info['special_tokens']
        assert 'pad_token' in special_tokens
        assert 'eos_token' in special_tokens
        assert special_tokens['pad_token'] == '<pad>'
        assert special_tokens['eos_token'] == '</s>'

    def test_get_tokenizer_info_with_max_length(self, mock_tokenizer):
        """Test tokenizer info with model_max_length."""
        mock_tokenizer.model_max_length = 2048

        info = self.service.get_tokenizer_info(mock_tokenizer)

        assert 'model_max_length' in info
        assert info['model_max_length'] == 2048
    def test_get_tokenizer_info_error_handling(self):
        """Test error handling in tokenizer info extraction."""
        # Create a mock whose vocab_size access raises an exception.
        # A PropertyMock must be attached to the mock's type (not the
        # instance) for the side effect to fire on attribute access.
        broken_tokenizer = Mock()
        broken_tokenizer.__class__.__name__ = 'BrokenTokenizer'
        type(broken_tokenizer).vocab_size = PropertyMock(side_effect=Exception("Test error"))

        info = self.service.get_tokenizer_info(broken_tokenizer)

        assert 'error' in info
        assert 'Test error' in info['error']
    @patch('app.services.tokenizer_service.AutoTokenizer')
    def test_load_predefined_tokenizer_success(self, mock_auto_tokenizer, mock_tokenizer):
        """Test successful loading of predefined tokenizer."""
        mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

        tokenizer, info, error = self.service.load_tokenizer('gpt2')

        assert tokenizer is not None
        assert error is None
        assert isinstance(info, dict)
        mock_auto_tokenizer.from_pretrained.assert_called_once()

    @patch('app.services.tokenizer_service.AutoTokenizer')
    def test_load_tokenizer_failure(self, mock_auto_tokenizer):
        """Test tokenizer loading failure."""
        mock_auto_tokenizer.from_pretrained.side_effect = Exception("Failed to load")

        tokenizer, info, error = self.service.load_tokenizer('gpt2')

        assert tokenizer is None
        assert error is not None
        assert "Failed to load" in error

    def test_load_nonexistent_predefined_model(self):
        """Test loading non-existent predefined model."""
        tokenizer, info, error = self.service.load_tokenizer('nonexistent-model')

        assert tokenizer is None
        assert error is not None
        assert "not found" in error.lower()
    @patch('app.services.tokenizer_service.AutoTokenizer')
    @patch('time.time')
    def test_custom_tokenizer_caching(self, mock_time, mock_auto_tokenizer, mock_tokenizer, app):
        """Test custom tokenizer caching behavior."""
        with app.app_context():
            mock_time.return_value = 1000.0
            mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

            # First load
            tokenizer1, info1, error1 = self.service.load_tokenizer('custom/model')

            # Second load (should use cache)
            mock_time.return_value = 1500.0  # Still within cache time
            tokenizer2, info2, error2 = self.service.load_tokenizer('custom/model')

            # Should only call from_pretrained once
            assert mock_auto_tokenizer.from_pretrained.call_count == 1
            assert tokenizer1 is tokenizer2

    @patch('app.services.tokenizer_service.AutoTokenizer')
    @patch('time.time')
    def test_custom_tokenizer_cache_expiration(self, mock_time, mock_auto_tokenizer, mock_tokenizer, app):
        """Test custom tokenizer cache expiration."""
        with app.app_context():
            mock_time.return_value = 1000.0
            mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

            # First load
            self.service.load_tokenizer('custom/model')

            # Second load after cache expiration
            mock_time.return_value = 5000.0  # Beyond cache expiration
            self.service.load_tokenizer('custom/model')

            # Should call from_pretrained twice
            assert mock_auto_tokenizer.from_pretrained.call_count == 2
    def test_tokenizer_models_constant(self):
        """Test that TOKENIZER_MODELS contains expected models."""
        models = self.service.TOKENIZER_MODELS

        assert isinstance(models, dict)
        assert len(models) > 0

        # Check that each model has required fields
        for model_id, model_info in models.items():
            assert isinstance(model_id, str)
            assert isinstance(model_info, dict)
            assert 'name' in model_info
            assert 'alias' in model_info
            assert isinstance(model_info['name'], str)
            assert isinstance(model_info['alias'], str)

    def test_cache_initialization(self):
        """Test that caches are properly initialized."""
        service = TokenizerService()

        assert hasattr(service, 'tokenizers')
        assert hasattr(service, 'custom_tokenizers')
        assert hasattr(service, 'tokenizer_info_cache')
        assert isinstance(service.tokenizers, dict)
        assert isinstance(service.custom_tokenizers, dict)
        assert isinstance(service.tokenizer_info_cache, dict)
    def test_special_tokens_filtering(self, mock_tokenizer):
        """Test that only valid special tokens are included."""
        # Add some None and empty special tokens
        mock_tokenizer.pad_token = '<pad>'
        mock_tokenizer.eos_token = '</s>'
        mock_tokenizer.bos_token = None
        mock_tokenizer.sep_token = ''
        mock_tokenizer.cls_token = ' '  # Whitespace only
        mock_tokenizer.unk_token = '<unk>'
        mock_tokenizer.mask_token = '<mask>'

        info = self.service.get_tokenizer_info(mock_tokenizer)
        special_tokens = info['special_tokens']

        # Should only include non-None, non-empty tokens
        assert 'pad_token' in special_tokens
        assert 'eos_token' in special_tokens
        assert 'unk_token' in special_tokens
        assert 'mask_token' in special_tokens

        # Should not include None or empty tokens
        assert 'bos_token' not in special_tokens
        assert 'sep_token' not in special_tokens
        assert 'cls_token' not in special_tokens