""" | |
Tokenizer Service - Handles tokenizer loading, caching, and management | |
""" | |
import time | |
from typing import Dict, Tuple, Optional, Any | |
from transformers import AutoTokenizer | |
from flask import current_app | |
class TokenizerService:
    """Service for managing tokenizer loading and caching."""

    # Predefined tokenizer models with aliases
    TOKENIZER_MODELS = {
        'qwen3': {
            'name': 'Qwen/Qwen3-0.6B',
            'alias': 'Qwen 3'
        },
        'gemma3-27b': {
            'name': 'google/gemma-3-27b-it',
            'alias': 'Gemma 3 27B'
        },
        'glm4': {
            'name': 'THUDM/GLM-4-32B-0414',
            'alias': 'GLM 4'
        },
        'mistral-small': {
            'name': 'mistralai/Mistral-Small-3.1-24B-Instruct-2503',
            'alias': 'Mistral Small 3.1'
        },
        'llama4': {
            'name': 'meta-llama/Llama-4-Scout-17B-16E-Instruct',
            'alias': 'Llama 4'
        },
        'deepseek-r1': {
            'name': 'deepseek-ai/DeepSeek-R1',
            'alias': 'DeepSeek R1'
        },
        'qwen_25_72b': {
            'name': 'Qwen/Qwen2.5-72B-Instruct',
            'alias': 'Qwen 2.5 72B'
        },
        'llama_33': {
            'name': 'unsloth/Llama-3.3-70B-Instruct-bnb-4bit',
            'alias': 'Llama 3.3 70B'
        },
        'gemma2_2b': {
            'name': 'google/gemma-2-2b-it',
            'alias': 'Gemma 2 2B'
        },
        'bert-large-uncased': {
            'name': 'google-bert/bert-large-uncased',
            'alias': 'BERT Large Uncased'
        },
        'gpt2': {
            'name': 'openai-community/gpt2',
            'alias': 'GPT-2'
        }
    }

    def __init__(self):
        """Initialize the tokenizer service with empty caches."""
        self.tokenizers: Dict[str, Any] = {}
        self.custom_tokenizers: Dict[str, Tuple[Any, float]] = {}
        self.tokenizer_info_cache: Dict[str, Dict] = {}
        self.custom_model_errors: Dict[str, str] = {}

    def get_tokenizer_info(self, tokenizer) -> Dict:
        """Extract useful information from a tokenizer."""
        info = {}
        try:
            # Get vocabulary size (dictionary size)
            if hasattr(tokenizer, 'vocab_size'):
                info['vocab_size'] = tokenizer.vocab_size
            elif hasattr(tokenizer, 'get_vocab'):
                info['vocab_size'] = len(tokenizer.get_vocab())

            # Get model max length if available; extremely large values are
            # the transformers sentinel for "no limit set", so skip those
            if hasattr(tokenizer, 'model_max_length') and tokenizer.model_max_length < 1000000:
                info['model_max_length'] = tokenizer.model_max_length

            # Record the tokenizer class name (e.g. GPT2TokenizerFast)
            info['tokenizer_type'] = tokenizer.__class__.__name__

            # Collect any special tokens that are actually set
            special_tokens = {}
            for token_name in ['pad_token', 'eos_token', 'bos_token', 'sep_token',
                               'cls_token', 'unk_token', 'mask_token']:
                if hasattr(tokenizer, token_name) and getattr(tokenizer, token_name) is not None:
                    token_value = getattr(tokenizer, token_name)
                    if token_value and str(token_value).strip():
                        special_tokens[token_name] = str(token_value)
            info['special_tokens'] = special_tokens
        except Exception as e:
            info['error'] = f"Error extracting tokenizer info: {str(e)}"
        return info

    def load_tokenizer(self, model_id_or_name: str) -> Tuple[Optional[Any], Dict, Optional[str]]:
        """
        Load a tokenizer, reusing a cached instance when one is available.

        Returns:
            Tuple of (tokenizer, tokenizer_info, error_message)
        """
        error_message = None
        tokenizer_info = {}

        # Check if we have cached tokenizer info
        if model_id_or_name in self.tokenizer_info_cache:
            tokenizer_info = self.tokenizer_info_cache[model_id_or_name]

        try:
            # Check if it's a predefined model ID
            if model_id_or_name in self.TOKENIZER_MODELS:
                model_name = self.TOKENIZER_MODELS[model_id_or_name]['name']
                if model_id_or_name not in self.tokenizers:
                    self.tokenizers[model_id_or_name] = AutoTokenizer.from_pretrained(model_name)
                tokenizer = self.tokenizers[model_id_or_name]

                # Get tokenizer info if not already cached
                if model_id_or_name not in self.tokenizer_info_cache:
                    tokenizer_info = self.get_tokenizer_info(tokenizer)
                    self.tokenizer_info_cache[model_id_or_name] = tokenizer_info
                return tokenizer, tokenizer_info, None

            # It's a custom model path: check the custom cache and honor its TTL
            current_time = time.time()
            cache_expiration = current_app.config.get('CACHE_EXPIRATION', 3600)
            if model_id_or_name in self.custom_tokenizers:
                cached_tokenizer, timestamp = self.custom_tokenizers[model_id_or_name]
                if current_time - timestamp < cache_expiration:
                    # Get tokenizer info if not already cached
                    if model_id_or_name not in self.tokenizer_info_cache:
                        tokenizer_info = self.get_tokenizer_info(cached_tokenizer)
                        self.tokenizer_info_cache[model_id_or_name] = tokenizer_info
                    return cached_tokenizer, tokenizer_info, None

            # Not in cache or expired, load it
            tokenizer = AutoTokenizer.from_pretrained(model_id_or_name)

            # Store in cache with timestamp
            self.custom_tokenizers[model_id_or_name] = (tokenizer, current_time)

            # Clear any previous errors for this model
            if model_id_or_name in self.custom_model_errors:
                del self.custom_model_errors[model_id_or_name]

            # Get tokenizer info
            tokenizer_info = self.get_tokenizer_info(tokenizer)
            self.tokenizer_info_cache[model_id_or_name] = tokenizer_info
            return tokenizer, tokenizer_info, None
        except Exception as e:
            error_message = f"Failed to load tokenizer: {str(e)}"
            # Store the error for future reference
            self.custom_model_errors[model_id_or_name] = error_message
            return None, tokenizer_info, error_message

    def get_model_alias(self, model_id: str) -> str:
        """Get the display alias for a model ID."""
        if model_id in self.TOKENIZER_MODELS:
            return self.TOKENIZER_MODELS[model_id]['alias']
        return model_id

    def is_predefined_model(self, model_id: str) -> bool:
        """Check if a model ID is a predefined model."""
        return model_id in self.TOKENIZER_MODELS

    def clear_cache(self):
        """Clear all caches."""
        self.tokenizers.clear()
        self.custom_tokenizers.clear()
        self.tokenizer_info_cache.clear()
        self.custom_model_errors.clear()


# Global instance
tokenizer_service = TokenizerService()
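

# A minimal usage sketch (illustrative only, not part of the service).
# Loading a custom model path reads CACHE_EXPIRATION from the Flask config
# via current_app, so it needs an application context; the throwaway Flask
# app below exists purely for this demo. Predefined IDs such as 'gpt2' are
# downloaded from the Hugging Face Hub on first use.
if __name__ == "__main__":
    from flask import Flask

    demo_app = Flask(__name__)
    with demo_app.app_context():
        tokenizer, info, error = tokenizer_service.load_tokenizer('gpt2')
        if error:
            print(error)
        else:
            print(f"Loaded: {tokenizer_service.get_model_alias('gpt2')}")
            print(f"Vocab size: {info.get('vocab_size')}")
            print(tokenizer.tokenize("Hello, tokenizer!"))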