| """Advanced Tokenization with Multi-Tokenizer Support and Optimization"""
|
|
|
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from tokenizers import Tokenizer as HFTokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordLevelTrainer
from transformers import AutoTokenizer, PreTrainedTokenizer
|
|
|
# Logger for this module; named after the module so handlers can filter on it.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
@dataclass
class TokenizerConfig:
    """Configuration for ``AdvancedTokenizer``.

    Values are validated in ``__post_init__`` so an invalid configuration
    fails at construction time rather than when the tokenizer is loaded.

    Raises:
        ValueError: If ``padding_side``/``truncation_side`` is not ``"left"``
            or ``"right"``, or ``model_max_length``/``custom_vocab_size`` is
            not positive.
    """

    # Name passed to ``AutoTokenizer.from_pretrained`` when
    # ``use_custom_tokenizer`` is False.
    tokenizer_name: str = "meta-llama/Llama-2-7b-hf"
    # When True, a word-level tokenizer is trained on a dataset instead of
    # loading ``tokenizer_name``.
    use_custom_tokenizer: bool = False
    # Target vocabulary size for the word-level trainer.
    custom_vocab_size: int = 32000
    # Minimum token frequency for the word-level trainer.
    min_frequency: int = 2
    # Special-token strings keyed by role. The bos/eos/pad/unk/mask keys map to
    # the standard tokenizer attributes; the system/user/assistant markers are
    # prefixed to chat messages. ``default_factory`` gives every instance its
    # own dict so mutating one config never leaks into another.
    special_tokens: Dict[str, str] = field(default_factory=lambda: {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "pad_token": "<pad>",
        "unk_token": "<unk>",
        "mask_token": "<mask>",
        "system_token": "<system>",
        "user_token": "<user>",
        "assistant_token": "<assistant>",
        "thought_token": "<thought>",
        "/thought_token": "</thought>",
    })

    # Forwarded to ``AutoTokenizer.from_pretrained``; also gates installation
    # of the chat template.
    use_fast: bool = True
    # Side on which padding / truncation is applied ("left" or "right").
    padding_side: str = "right"
    truncation_side: str = "right"
    # Maximum sequence length; also the default ``max_length`` for tokenize().
    model_max_length: int = 32768

    # NOTE(review): not read by any code shown here; presumably reserved for
    # future multi-modal support — confirm before relying on them.
    enable_image_tokenization: bool = False
    enable_audio_tokenization: bool = False

    def __post_init__(self) -> None:
        """Reject configuration values the tokenizer cannot use."""
        for side_field in ("padding_side", "truncation_side"):
            side = getattr(self, side_field)
            if side not in ("left", "right"):
                raise ValueError(
                    f"{side_field} must be 'left' or 'right', got {side!r}"
                )
        for size_field in ("model_max_length", "custom_vocab_size"):
            size = getattr(self, size_field)
            if size <= 0:
                raise ValueError(f"{size_field} must be positive, got {size}")
|
|
|
|
|
class AdvancedTokenizer:
    """Load or train a tokenizer and attach the configured special tokens.

    Depending on ``config.use_custom_tokenizer`` this either loads
    ``config.tokenizer_name`` through ``AutoTokenizer`` or trains a word-level
    tokenizer on a dataset. It then registers any tokens from
    ``config.special_tokens`` that are missing from the vocabulary and
    installs a chat template built from the configured role tokens.
    """

    # Keys of ``config.special_tokens`` that correspond to the standard
    # ``transformers`` special-token attributes. ``add_special_tokens`` only
    # accepts these (plus ``additional_special_tokens``), so every other entry
    # (role / thought markers) is registered as an additional special token.
    _STANDARD_SPECIAL_KEYS = frozenset(
        {
            "bos_token",
            "eos_token",
            "unk_token",
            "sep_token",
            "pad_token",
            "cls_token",
            "mask_token",
        }
    )

    # Chat role -> key in ``config.special_tokens`` whose token prefixes that
    # role's messages. Shared by the training-text extractor and the chat
    # template so both produce the same "<token> content" layout.
    _ROLE_TOKEN_KEYS = {
        "system": "system_token",
        "user": "user_token",
        "assistant": "assistant_token",
    }

    def __init__(self, config: "TokenizerConfig"):
        self.config = config
        # Populated by load_or_train(); None until then.
        self.tokenizer: Optional[PreTrainedTokenizer] = None
        # Special-token strings in config order; reserved in the vocabulary
        # when a custom tokenizer is trained.
        self._special_tokens = list(config.special_tokens.values())

    def load_or_train(self, dataset: Optional[Any] = None) -> "PreTrainedTokenizer":
        """Load the pretrained tokenizer or train a custom one, then configure it.

        Args:
            dataset: Iterable of samples; required only when
                ``config.use_custom_tokenizer`` is True.

        Returns:
            The configured tokenizer (also stored on ``self.tokenizer``).

        Raises:
            ValueError: If a custom tokenizer is requested without a dataset.
        """
        if self.config.use_custom_tokenizer:
            if dataset is None:
                raise ValueError("Dataset required for custom tokenizer training")
            logger.info("Training custom tokenizer from dataset")
            self.tokenizer = self._train_tokenizer(dataset)
        else:
            logger.info(f"Loading pretrained tokenizer: {self.config.tokenizer_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.tokenizer_name,
                use_fast=self.config.use_fast,
                padding_side=self.config.padding_side,
                truncation_side=self.config.truncation_side,
                model_max_length=self.config.model_max_length,
            )

        self._setup_special_tokens()
        return self.tokenizer

    def _train_tokenizer(self, dataset: Any) -> "PreTrainedTokenizer":
        """Train a word-level tokenizer on *dataset* and wrap it for transformers.

        The extracted text is written to a temporary corpus file, which is
        removed even if extraction or training raises.
        """
        import tempfile

        from transformers import PreTrainedTokenizerFast

        tokens = self.config.special_tokens
        handle = tempfile.NamedTemporaryFile(
            mode="w", suffix=".txt", delete=False, encoding="utf-8"
        )
        corpus_path = Path(handle.name)
        try:
            # Close the corpus before training so the trainer (and Windows)
            # can read it.
            with handle:
                logger.info("Preparing training data...")
                for sample in dataset:
                    text = self._extract_text_for_tokenizer(sample)
                    if text:
                        handle.write(text + "\n")

            # Use the configured unk token rather than a hard-coded literal so
            # the model and the transformers wrapper agree.
            tokenizer = HFTokenizer(WordLevel(unk_token=tokens["unk_token"]))
            tokenizer.pre_tokenizer = Whitespace()
            trainer = WordLevelTrainer(
                vocab_size=self.config.custom_vocab_size,
                min_frequency=self.config.min_frequency,
                special_tokens=self._special_tokens,
            )
            logger.info("Training tokenizer...")
            tokenizer.train([str(corpus_path)], trainer=trainer)
        finally:
            corpus_path.unlink(missing_ok=True)

        fast_tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=tokenizer,
            bos_token=tokens["bos_token"],
            eos_token=tokens["eos_token"],
            pad_token=tokens["pad_token"],
            unk_token=tokens["unk_token"],
            mask_token=tokens["mask_token"],
            padding_side=self.config.padding_side,
            truncation_side=self.config.truncation_side,
            model_max_length=self.config.model_max_length,
        )
        logger.info(f"Trained tokenizer with vocab size: {fast_tokenizer.vocab_size}")
        return fast_tokenizer

    def _extract_text_for_tokenizer(self, sample: Dict[str, Any]) -> str:
        """Flatten one dataset sample into training text.

        Samples with ``"conversations"`` (a list of ``{"role", "content"}``
        dicts, or a JSON string encoding one) become one line per non-empty
        message, prefixed with the role's configured token; unknown roles
        contribute their content only. A ``"conversations"`` string that is
        not valid JSON is returned as-is. Otherwise ``"text"`` or
        ``"content"`` is returned, or ``""`` when none is present.
        """
        if "conversations" in sample:
            conv = sample["conversations"]
            if isinstance(conv, str):
                try:
                    conv = json.loads(conv)
                except ValueError:  # json.JSONDecodeError subclasses ValueError
                    return conv
            try:
                messages = iter(conv)
            except TypeError:  # e.g. JSON "42" / null: nothing to extract
                return ""

            texts = []
            for msg in messages:
                if not isinstance(msg, dict):
                    continue
                content = msg.get("content", "")
                if not content:
                    continue
                token_key = self._ROLE_TOKEN_KEYS.get(msg.get("role", ""))
                if token_key is None:
                    texts.append(content)
                else:
                    texts.append(f"{self.config.special_tokens[token_key]} {content}")
            return "\n".join(texts)
        if "text" in sample:
            return sample["text"]
        if "content" in sample:
            return sample["content"]
        return ""

    @classmethod
    def _build_special_tokens_payload(
        cls,
        special_tokens: Dict[str, str],
        vocab: Dict[str, int],
        existing_additional: List[str],
    ) -> Dict[str, Any]:
        """Build an ``add_special_tokens`` argument for tokens absent from *vocab*.

        Standard keys are passed through by name; all other missing tokens are
        appended (deduplicated) to *existing_additional* under
        ``"additional_special_tokens"`` so previously registered additional
        tokens are preserved.
        """
        payload: Dict[str, Any] = {}
        new_additional: List[str] = []
        for key, token in special_tokens.items():
            if token in vocab:
                continue
            if key in cls._STANDARD_SPECIAL_KEYS:
                payload[key] = token
            elif token not in existing_additional and token not in new_additional:
                new_additional.append(token)
        if new_additional:
            payload["additional_special_tokens"] = (
                list(existing_additional) + new_additional
            )
        return payload

    def _setup_special_tokens(self) -> None:
        """Register missing special tokens and (if ``use_fast``) the chat template.

        Raises:
            ValueError: If no tokenizer has been loaded yet.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized")

        vocab = self.tokenizer.get_vocab()  # fetched once, not per token
        existing = list(getattr(self.tokenizer, "additional_special_tokens", None) or [])
        # NOTE(review): a missing standard token (e.g. eos_token) replaces the
        # tokenizer's current one — confirm that is intended for pretrained models.
        payload = self._build_special_tokens_payload(
            self.config.special_tokens, vocab, existing
        )
        if payload:
            self.tokenizer.add_special_tokens(payload)

        if self.config.use_fast:
            self.tokenizer.chat_template = self._create_chat_template()

    def _create_chat_template(self) -> str:
        """Return a Jinja2 chat template using the configured role tokens.

        Each message renders as ``"<role_token> content\\n"`` (content only for
        unknown roles), matching the layout used for tokenizer training. When
        ``add_generation_prompt`` is set, the assistant token is appended.
        """
        role_tokens = {
            role: self.config.special_tokens[key]
            for role, key in self._ROLE_TOKEN_KEYS.items()
        }
        # json.dumps yields a dict literal Jinja can parse, with any quotes in
        # the tokens escaped.
        return (
            "{% set role_tokens = " + json.dumps(role_tokens) + " %}"
            "{% for message in messages %}"
            "{% if message['role'] in role_tokens %}"
            "{{ role_tokens[message['role']] + ' ' + message['content'] + '\\n' }}"
            "{% else %}"
            "{{ message['content'] + '\\n' }}"
            "{% endif %}"
            "{% endfor %}"
            "{% if add_generation_prompt %}"
            "{{ role_tokens['assistant'] + ' ' }}"
            "{% endif %}"
        )

    def tokenize(
        self,
        text: Union[str, List[str]],
        **kwargs
    ) -> Dict[str, Any]:
        """Tokenize *text*; keyword arguments override the defaults.

        Defaults: truncation to ``config.model_max_length``, padding to that
        same length, and PyTorch tensors. NOTE(review): ``padding="max_length"``
        pads every input to the full ``model_max_length`` (32768 by default);
        pass ``padding="longest"`` for batches of short text.

        Raises:
            ValueError: If no tokenizer has been loaded yet.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized")

        defaults = {
            "truncation": True,
            "max_length": self.config.model_max_length,
            "padding": "max_length",
            "return_tensors": "pt",
        }
        return self.tokenizer(text, **{**defaults, **kwargs})

    def decode(self, token_ids: Union[List[int], Any], **kwargs) -> str:
        """Decode *token_ids* back to text via the wrapped tokenizer.

        Raises:
            ValueError: If no tokenizer has been loaded yet.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized")
        return self.tokenizer.decode(token_ids, **kwargs)

    def save(self, path: str):
        """Save the tokenizer with ``save_pretrained`` into *path*.

        Raises:
            ValueError: If no tokenizer has been loaded yet.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized")
        self.tokenizer.save_pretrained(path)
        logger.info(f"Tokenizer saved to {path}")

    @property
    def vocab_size(self) -> int:
        """Return the wrapped tokenizer's ``vocab_size`` (0 before loading).

        NOTE(review): in transformers, ``vocab_size`` typically excludes tokens
        added by ``add_special_tokens``; size embeddings with
        ``len(self.tokenizer)`` instead — confirm.
        """
        if self.tokenizer is None:
            return 0
        return self.tokenizer.vocab_size
|
|
|
|
|
class TokenizerManager:
    """Registry of named ``AdvancedTokenizer`` wrappers, e.g. one per model size."""

    def __init__(self):
        # name -> wrapper; insertion order is the load/save order.
        self.tokenizers: Dict[str, "AdvancedTokenizer"] = {}

    def register_tokenizer(self, name: str, tokenizer: "AdvancedTokenizer"):
        """Store *tokenizer* under *name*, replacing any earlier entry."""
        self.tokenizers.update({name: tokenizer})

    def get_tokenizer(self, name: str) -> "PreTrainedTokenizer":
        """Return the underlying tokenizer registered as *name*.

        The value is the wrapper's ``.tokenizer`` attribute, which may still
        be None if the wrapper has not been loaded yet.

        Raises:
            KeyError: If nothing is registered under *name*.
        """
        try:
            wrapper = self.tokenizers[name]
        except KeyError:
            raise KeyError(f"Tokenizer '{name}' not found") from None
        return wrapper.tokenizer

    def load_all(self, dataset: Optional[Any] = None):
        """Call ``load_or_train(dataset)`` on every registered wrapper, in order.

        NOTE(review): the same *dataset* object is handed to each wrapper; a
        one-shot iterator would be exhausted by the first custom training run.
        """
        for name in self.tokenizers:
            logger.info(f"Loading tokenizer: {name}")
            self.tokenizers[name].load_or_train(dataset)

    def save_all(self, output_dir: str):
        """Save each wrapper to ``<output_dir>/<name>/tokenizer``."""
        root = Path(output_dir)
        for name, wrapper in self.tokenizers.items():
            wrapper.save(str(root.joinpath(name, "tokenizer")))
|
|
|
|
|
def create_tokenizer_for_model_size(
    model_size: str,
    config: "TokenizerConfig",
) -> "AdvancedTokenizer":
    """Create an ``AdvancedTokenizer`` preset for a model size.

    The preset's ``tokenizer_name`` and ``model_max_length`` are applied to a
    copy of *config*. The caller's object is left untouched, so one base
    config can be reused for several sizes without the tokenizers sharing
    (and overwriting) each other's settings.

    Args:
        model_size: One of ``"7b"``, ``"32b"``, ``"70b"`` (case-insensitive).
        config: Base configuration; all other fields are kept as-is.

    Returns:
        A new, not-yet-loaded ``AdvancedTokenizer``.

    Raises:
        ValueError: If *model_size* is not a known preset.
    """
    from dataclasses import replace

    # size -> (tokenizer_name, model_max_length)
    presets = {
        "7b": ("meta-llama/Llama-2-7b-hf", 8192),
        "32b": ("Qwen/Qwen1.5-32B", 8192),
        "70b": ("meta-llama/Llama-2-70b-hf", 32768),
    }
    try:
        tokenizer_name, max_length = presets[str(model_size).lower()]
    except KeyError:
        raise ValueError(f"Unknown model size: {model_size}") from None

    sized_config = replace(
        config,
        tokenizer_name=tokenizer_name,
        model_max_length=max_length,
    )
    return AdvancedTokenizer(sized_config)
|
|
|