import os
import json
import pickle
import argparse
import logging
import threading
from collections import Counter, defaultdict, OrderedDict
from typing import List, Dict, Set, Optional, Tuple, Union, Iterator, Any
from dataclasses import dataclass, asdict
from pathlib import Path
import re
import unicodedata
import heapq
from functools import lru_cache
import time
from contextlib import contextmanager

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


@dataclass
class TokenizerConfig:
    """Configuration class with validation and serialization support"""
    vocab_size: int = 32000
    min_freq: int = 2
    max_token_length: int = 256
    cache_size: int = 10000
    chunk_size: int = 10000

    # Special tokens (the original default strings are unavailable;
    # conventional values are assumed here)
    pad_token: str = '<pad>'
    unk_token: str = '<unk>'
    bos_token: str = '<bos>'
    eos_token: str = '<eos>'

    # Technical domain specific
    enable_code_detection: bool = True
    enable_math_detection: bool = True
    enable_url_detection: bool = True

    def __post_init__(self):
        """Validate configuration parameters"""
        if self.vocab_size <= 0:
            raise ValueError(f"vocab_size must be positive, got {self.vocab_size}")
        if self.min_freq <= 0:
            raise ValueError(f"min_freq must be positive, got {self.min_freq}")
        if self.max_token_length <= 0:
            raise ValueError(f"max_token_length must be positive, got {self.max_token_length}")
        if self.cache_size <= 0:
            raise ValueError(f"cache_size must be positive, got {self.cache_size}")
        logger.info(f"TokenizerConfig validated: vocab_size={self.vocab_size}")

    def save(self, path: Union[str, Path]) -> None:
        """Save configuration to JSON file"""
        path = Path(path)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(asdict(self), f, indent=2, ensure_ascii=False)
        logger.info(f"Config saved to {path}")

    @classmethod
    def load(cls, path: Union[str, Path]) -> 'TokenizerConfig':
        """Load configuration from JSON file"""
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Config file not found: {path}")
        with open(path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
        logger.info(f"Config loaded from {path}")
        return cls(**config_dict)


class ThreadSafeLRUCache:
    """Thread-safe LRU cache with size limits"""

    def __init__(self, max_size: int = 10000):
        self.max_size = max_size
        self.cache = OrderedDict()
        self.lock = threading.RLock()

    def get(self, key: str) -> Optional[List[str]]:
        """Get value from cache"""
        with self.lock:
            if key in self.cache:
                # Move to end (most recently used)
                value = self.cache.pop(key)
                self.cache[key] = value
                return value
            return None

    def put(self, key: str, value: List[str]) -> None:
        """Add value to cache"""
        with self.lock:
            if key in self.cache:
                self.cache.pop(key)
            elif len(self.cache) >= self.max_size:
                # Remove least recently used item
                self.cache.popitem(last=False)
            self.cache[key] = value

    def clear(self) -> None:
        """Clear all cache entries"""
        with self.lock:
            self.cache.clear()

    def size(self) -> int:
        """Get current cache size"""
        with self.lock:
            return len(self.cache)
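

# Illustrative usage sketch: a minimal round trip through TokenizerConfig
# serialization plus basic ThreadSafeLRUCache behaviour. The file name
# "example_config.json" is arbitrary; the function is not called anywhere.
def _example_config_and_cache_usage() -> None:
    # Save a config to JSON and load it back
    config = TokenizerConfig(vocab_size=8000, min_freq=3)
    config.save("example_config.json")
    restored = TokenizerConfig.load("example_config.json")
    assert restored.vocab_size == 8000

    # LRU behaviour: the least recently used entry is evicted once max_size is reached
    cache = ThreadSafeLRUCache(max_size=2)
    cache.put("foo", ["f", "oo"])
    cache.put("bar", ["b", "ar"])
    cache.put("baz", ["b", "az"])  # evicts "foo"
    assert cache.get("foo") is None
    assert cache.get("bar") == ["b", "ar"]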


class EfficientBPE:
    """Efficient BPE implementation based on greedy merging of the most frequent pair"""

    def __init__(self):
        self.merges: List[Tuple[str, str]] = []
        self.merge_ranks: Dict[Tuple[str, str], int] = {}

    def train(self, word_counts: Dict[str, int], num_merges: int) -> None:
        """Train BPE by repeatedly merging the most frequent adjacent symbol pair"""
        logger.info(f"Training BPE with {num_merges} merges")

        # Convert words to character sequences
        vocab = defaultdict(int)
        for word, count in word_counts.items():
            vocab[tuple(word)] += count

        # Get all possible pairs and their frequencies
        def get_pairs(vocab_dict):
            pairs = defaultdict(int)
            for word, freq in vocab_dict.items():
                if len(word) < 2:
                    continue
                for i in range(len(word) - 1):
                    pair = (word[i], word[i + 1])
                    pairs[pair] += freq
            return pairs

        for i in range(num_merges):
            if i % 1000 == 0:
                logger.info(f"BPE merge progress: {i}/{num_merges}")

            pairs = get_pairs(vocab)
            if not pairs:
                logger.warning(f"No more pairs available at merge {i}")
                break

            # Get most frequent pair
            best_pair = max(pairs.items(), key=lambda x: x[1])[0]

            # Merge the best pair in every word
            new_vocab = {}
            bigram = best_pair
            for word, freq in vocab.items():
                new_word = []
                j = 0
                while j < len(word):
                    if j < len(word) - 1 and (word[j], word[j + 1]) == bigram:
                        new_word.append(word[j] + word[j + 1])
                        j += 2
                    else:
                        new_word.append(word[j])
                        j += 1
                new_vocab[tuple(new_word)] = freq

            vocab = new_vocab
            self.merges.append(best_pair)
            self.merge_ranks[best_pair] = len(self.merges) - 1

        logger.info(f"BPE training completed with {len(self.merges)} merges")

    def apply(self, word: str) -> List[str]:
        """Apply BPE merges to a word efficiently"""
        if len(word) <= 1:
            return list(word)

        # Start with character-level tokens
        word_tokens = list(word)

        # Apply merges in order
        for merge_pair in self.merges:
            if len(word_tokens) == 1:
                break

            new_tokens = []
            i = 0
            while i < len(word_tokens):
                if (i < len(word_tokens) - 1 and
                        word_tokens[i] == merge_pair[0] and
                        word_tokens[i + 1] == merge_pair[1]):
                    new_tokens.append(merge_pair[0] + merge_pair[1])
                    i += 2
                else:
                    new_tokens.append(word_tokens[i])
                    i += 1
            word_tokens = new_tokens

        return word_tokens
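

# Illustrative usage sketch: trains EfficientBPE on a tiny hand-written
# word-count dictionary and segments a word with the learned merges. The toy
# corpus and merge count are arbitrary; the function is not called anywhere.
def _example_bpe_toy_training() -> None:
    word_counts = {"lower": 5, "lowest": 3, "newer": 6, "wider": 2}
    bpe = EfficientBPE()
    bpe.train(word_counts, num_merges=10)
    # Each merge joins the currently most frequent adjacent pair, so frequent
    # character sequences such as "er" tend to become single symbols.
    print(bpe.merges)
    print(bpe.apply("lower"))  # segmentation depends on the merges learned above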


class TechnicalTokenizer:
    """
    Production-quality tokenizer for technical content with:
    - Efficient BPE implementation
    - Thread-safe caching
    - Memory-efficient streaming
    - Comprehensive error handling
    - Proper logging and monitoring
    """

    def __init__(self, config: Optional[TokenizerConfig] = None):
        self.config = config or TokenizerConfig()

        # Core components
        self.vocab: Dict[str, int] = {}
        self.id_to_token: Dict[int, str] = {}
        self.token_frequencies: Counter = Counter()
        self.bpe = EfficientBPE()

        # Thread-safe cache
        self.cache = ThreadSafeLRUCache(self.config.cache_size)

        # Special tokens mapping. The literal strings for IDs 4-6 and 14-16 are
        # unavailable: IDs 4-6 are assumed to be role markers (matching the
        # handling in normalize_text below), and IDs 14-16 are kept as reserved
        # placeholders so the ID layout stays intact.
        self.special_tokens = {
            self.config.pad_token: 0,
            self.config.unk_token: 1,
            self.config.bos_token: 2,
            self.config.eos_token: 3,
            '<system>': 4,
            '<user>': 5,
            '<assistant>': 6,
            '<|endoftext|>': 7,
            '<|newline|>': 8,
            '<|tab|>': 9,
            '<|code|>': 10,
            '<|/code|>': 11,
            '<|math|>': 12,
            '<|/math|>': 13,
            '<reserved_14>': 14,
            '<reserved_15>': 15,
            '<reserved_16>': 16
        }

        # Initialize vocabulary with special tokens
        self._initialize_vocab()

        # Compile regex patterns for efficiency
        self._compile_patterns()

        # Technical terms for priority processing
        self.technical_terms = self._load_technical_terms()

        logger.info(f"TechnicalTokenizer initialized with vocab_size={self.config.vocab_size}")

    def _initialize_vocab(self) -> None:
        """Initialize vocabulary with special tokens"""
        self.vocab = self.special_tokens.copy()
        self.id_to_token = {v: k for k, v in self.special_tokens.items()}

    def _compile_patterns(self) -> None:
        """Compile regex patterns for efficient text processing"""
        patterns = []

        if self.config.enable_code_detection:
            patterns.extend([
                r'```[\s\S]*?```',  # Code blocks
                r'`[^`\n]+`',       # Inline code
            ])

        if self.config.enable_url_detection:
            patterns.append(r'https?://[^\s<>"{}|\\^`[\]]+')

        patterns.extend([
            r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',  # Email
            r'<[^>]+>',             # Special tokens
            r'\b\d+\.?\d*\b',       # Numbers
            r'\b\w+(?:\'\w+)?\b',   # Words with contractions
            r'[^\w\s]',             # Punctuation
        ])

        self.tokenizer_pattern = re.compile('|'.join(f'({pattern})' for pattern in patterns))

        # Additional patterns for normalization
        self.newline_pattern = re.compile(r'\r\n|\r')
        self.tab_pattern = re.compile(r'\t')
        self.multiple_space_pattern = re.compile(r'\s+')

    def _load_technical_terms(self) -> Set[str]:
        """Load technical terms for priority processing"""
        return {
            # Programming
            'function', 'variable', 'array', 'object', 'class', 'method',
            'parameter', 'return', 'import', 'export', 'async', 'await',
            'promise', 'callback', 'algorithm', 'datatype', 'boolean',

            # Languages
            'python', 'javascript', 'java', 'cpp', 'rust', 'go', 'html',
            'css', 'sql', 'typescript', 'kotlin', 'swift',

            # Web/API
            'api', 'rest', 'graphql', 'json', 'xml', 'http', 'https',
            'endpoint', 'request', 'response', 'authentication',

            # Math/ML
            'neural', 'network', 'model', 'training', 'validation',
            'accuracy', 'precision', 'recall', 'loss', 'gradient',
            'derivative', 'integral', 'matrix', 'vector', 'tensor',
            'transformer', 'attention', 'embedding', 'tokenization',

            # Infrastructure
            'docker', 'kubernetes', 'microservice', 'database', 'server',
            'client', 'deployment', 'scalability'
        }

    @contextmanager
    def _error_context(self, operation: str):
        """Context manager for consistent error handling"""
        try:
            yield
        except Exception as e:
            logger.error(f"Error in {operation}: {str(e)}")
            raise

    def normalize_text(self, text: str) -> str:
        """Normalize text with proper error handling"""
        if not isinstance(text, str):
            raise TypeError(f"Expected str, got {type(text)}")

        with self._error_context("text normalization"):
            # Basic normalization
            text = self.newline_pattern.sub('\n', text)
            text = self.tab_pattern.sub('<|tab|>', text)
            text = unicodedata.normalize('NFKC', text)

            # Handle special token markers. Mapping the role markers onto the
            # <system>/<user>/<assistant> special tokens defined above is an
            # assumption; the original replacement strings are unavailable.
            text = re.sub(r'<\|system\|>', ' <system> ', text)
            text = re.sub(r'<\|user\|>', ' <user> ', text)
            text = re.sub(r'<\|assistant\|>', ' <assistant> ', text)
            text = re.sub(r'<\|endoftext\|>', ' <|endoftext|> ', text)

            return text.strip()

    def pre_tokenize(self, text: str) -> List[str]:
        """Pre-tokenize text into words and special tokens"""
        if not text:
            return []

        with self._error_context("pre-tokenization"):
            normalized_text = self.normalize_text(text)

            # Find all tokens using the compiled pattern
            matches = self.tokenizer_pattern.findall(normalized_text)

            # Flatten the match groups and filter empty strings
            tokens = []
            for match_groups in matches:
                for group in match_groups:
                    if group:
                        tokens.append(group)
                        break

            return [token.strip() for token in tokens if token.strip()]
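
    # Behavioural note (illustrative): with code and URL detection enabled,
    # pre_tokenize keeps fenced code blocks, inline code, URLs and <...> markers
    # as single pre-tokens, while prose is split into words, numbers and
    # punctuation. For example, "Call `len(x)` twice." would yield pre-tokens
    # roughly like ['Call', '`len(x)`', 'twice', '.'] under the patterns
    # compiled above.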

    def train_from_iterator(self, text_iterator: Iterator[str],
                            total_texts: Optional[int] = None) -> None:
        """
        Train tokenizer from text iterator for memory efficiency

        Args:
            text_iterator: Iterator yielding text strings
            total_texts: Optional total count for progress tracking
        """
        logger.info("Starting BPE training from iterator")
        start_time = time.time()

        word_counts = Counter()
        processed_texts = 0

        # Process texts in chunks to manage memory
        current_chunk = []
        for text in text_iterator:
            current_chunk.append(text)
            processed_texts += 1

            if len(current_chunk) >= self.config.chunk_size:
                self._process_text_chunk(current_chunk, word_counts)
                current_chunk.clear()

                if processed_texts % 10000 == 0:
                    elapsed = time.time() - start_time
                    logger.info(f"Processed {processed_texts} texts in {elapsed:.1f}s")

        # Process remaining texts
        if current_chunk:
            self._process_text_chunk(current_chunk, word_counts)

        logger.info(f"Pre-processing completed: {len(word_counts)} unique words")

        # Filter by frequency and boost technical terms
        filtered_words = {}
        for word, count in word_counts.items():
            if count >= self.config.min_freq:
                # Boost technical terms
                if word.lower() in self.technical_terms:
                    count *= 5
                filtered_words[word] = count

        logger.info(f"After filtering: {len(filtered_words)} words")

        # Build character vocabulary
        all_chars = set()
        for word in filtered_words:
            all_chars.update(word)

        # Add characters to vocabulary
        for char in sorted(all_chars):
            if char not in self.vocab:
                token_id = len(self.vocab)
                self.vocab[char] = token_id
                self.id_to_token[token_id] = char

        # Calculate number of merges needed
        current_vocab_size = len(self.vocab)
        target_vocab_size = self.config.vocab_size
        num_merges = target_vocab_size - current_vocab_size

        if num_merges > 0:
            # Train BPE
            self.bpe.train(filtered_words, num_merges)

            # Add merged tokens to vocabulary
            for merge_pair in self.bpe.merges:
                merged_token = merge_pair[0] + merge_pair[1]
                if merged_token not in self.vocab:
                    token_id = len(self.vocab)
                    self.vocab[merged_token] = token_id
                    self.id_to_token[token_id] = merged_token

        # Update token frequencies
        for word, count in filtered_words.items():
            tokens = self.apply_bpe(word)
            for token in tokens:
                self.token_frequencies[token] += count

        training_time = time.time() - start_time
        logger.info(f"Training completed in {training_time:.1f}s")
        logger.info(f"Final vocabulary size: {len(self.vocab)}")

    def _process_text_chunk(self, texts: List[str], word_counts: Counter) -> None:
        """Process a chunk of texts and update word counts"""
        for text in texts:
            try:
                tokens = self.pre_tokenize(text)
                for token in tokens:
                    if len(token) <= self.config.max_token_length:
                        word_counts[token] += 1
            except Exception as e:
                logger.warning(f"Error processing text chunk: {e}")
                continue

    def apply_bpe(self, word: str) -> List[str]:
        """Apply BPE to a word with caching"""
        if not word:
            return []

        # Check cache first
        cached_result = self.cache.get(word)
        if cached_result is not None:
            return cached_result

        # Apply BPE
        tokens = self.bpe.apply(word)

        # Cache the result
        self.cache.put(word, tokens)

        return tokens

    def tokenize(self, text: str) -> List[str]:
        """Tokenize text into subword tokens"""
        if not text:
            return []

        with self._error_context("tokenization"):
            pre_tokens = self.pre_tokenize(text)

            final_tokens = []
            for token in pre_tokens:
                if token in self.special_tokens or token in self.vocab:
                    final_tokens.append(token)
                else:
                    bpe_tokens = self.apply_bpe(token)
                    final_tokens.extend(bpe_tokens)

            return final_tokens

    def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
        """Encode text to token IDs"""
        if not isinstance(text, str):
            raise TypeError(f"Expected str, got {type(text)}")

        tokens = self.tokenize(text)

        if add_special_tokens:
            tokens = [self.config.bos_token] + tokens + [self.config.eos_token]

        ids = []
        unk_id = self.vocab[self.config.unk_token]
        for token in tokens:
            token_id = self.vocab.get(token, unk_id)
            ids.append(token_id)

        return ids

    def decode(self, ids: List[int], skip_special_tokens: bool = False) -> str:
        """Decode token IDs to text"""
        if not isinstance(ids, (list, tuple)):
            raise TypeError(f"Expected list or tuple, got {type(ids)}")

        tokens = []
        for token_id in ids:
            if not isinstance(token_id, int):
                raise TypeError(f"Expected int token ID, got {type(token_id)}")

            if token_id not in self.id_to_token:
                logger.warning(f"Unknown token ID: {token_id}")
                continue

            token = self.id_to_token[token_id]
            if skip_special_tokens and token in self.special_tokens:
                continue
            tokens.append(token)

        # Join tokens and clean up
        text = ''.join(tokens)
        text = text.replace('<|tab|>', '\t')
        text = text.replace('<|newline|>', '\n')

        return text
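
    # Behavioural note (illustrative): decode joins subword tokens with '' and
    # only restores tabs/newlines from their marker tokens, while pre_tokenize
    # drops ordinary whitespace, so encode/decode is not an exact inverse for
    # prose: decode(encode("hello world")) typically comes back as "helloworld".
    # The round-trip check in TokenizerTester below therefore compares token
    # counts rather than raw strings.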

    def get_vocab_size(self) -> int:
        """Get vocabulary size"""
        return len(self.vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Get vocabulary dictionary (copy for safety)"""
        return self.vocab.copy()

    def get_cache_info(self) -> Dict[str, int]:
        """Get cache statistics"""
        return {
            'size': self.cache.size(),
            'max_size': self.config.cache_size,
            'hit_rate': getattr(self.cache, 'hit_rate', 0)
        }

    def save(self, save_dir: Union[str, Path]) -> None:
        """Save tokenizer with validation"""
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving tokenizer to {save_dir}")

        try:
            # Save configuration
            self.config.save(save_dir / 'config.json')

            # Save vocabulary
            with open(save_dir / 'vocab.json', 'w', encoding='utf-8') as f:
                json.dump(self.vocab, f, indent=2, ensure_ascii=False)

            # Save BPE merges
            with open(save_dir / 'merges.txt', 'w', encoding='utf-8') as f:
                for merge in self.bpe.merges:
                    f.write(f"{merge[0]} {merge[1]}\n")

            # Save token frequencies
            with open(save_dir / 'frequencies.pkl', 'wb') as f:
                pickle.dump(dict(self.token_frequencies), f)

            # Save metadata
            metadata = {
                'version': '2.0',
                'vocab_size': len(self.vocab),
                'num_merges': len(self.bpe.merges),
                'special_tokens': self.special_tokens
            }
            with open(save_dir / 'metadata.json', 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)

            logger.info("Tokenizer saved successfully")

        except Exception as e:
            logger.error(f"Error saving tokenizer: {e}")
            raise

    @classmethod
    def load(cls, save_dir: Union[str, Path]) -> 'TechnicalTokenizer':
        """Load tokenizer from directory"""
        save_dir = Path(save_dir)
        if not save_dir.exists():
            raise FileNotFoundError(f"Tokenizer directory not found: {save_dir}")

        logger.info(f"Loading tokenizer from {save_dir}")

        try:
            # Load configuration
            config = TokenizerConfig.load(save_dir / 'config.json')

            # Create tokenizer instance
            tokenizer = cls(config)

            # Load vocabulary
            with open(save_dir / 'vocab.json', 'r', encoding='utf-8') as f:
                tokenizer.vocab = json.load(f)
            tokenizer.id_to_token = {v: k for k, v in tokenizer.vocab.items()}

            # Load BPE merges
            merges_file = save_dir / 'merges.txt'
            if merges_file.exists():
                with open(merges_file, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if line:
                            parts = line.split()
                            if len(parts) == 2:
                                tokenizer.bpe.merges.append(tuple(parts))

                # Rebuild merge ranks
                tokenizer.bpe.merge_ranks = {
                    merge: i for i, merge in enumerate(tokenizer.bpe.merges)
                }

            # Load token frequencies
            freq_file = save_dir / 'frequencies.pkl'
            if freq_file.exists():
                with open(freq_file, 'rb') as f:
                    freq_dict = pickle.load(f)
                    tokenizer.token_frequencies = Counter(freq_dict)

            logger.info("Tokenizer loaded successfully")
            logger.info(f"Vocabulary size: {len(tokenizer.vocab)}")
            logger.info(f"Number of BPE merges: {len(tokenizer.bpe.merges)}")

            return tokenizer

        except Exception as e:
            logger.error(f"Error loading tokenizer: {e}")
            raise
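

# Illustrative usage sketch: end-to-end use of TechnicalTokenizer on a small
# in-memory corpus, followed by a save/load round trip. The corpus, vocab_size
# and output directory are arbitrary; the function is not called anywhere.
def _example_tokenizer_roundtrip() -> None:
    config = TokenizerConfig(vocab_size=500, min_freq=1)
    tokenizer = TechnicalTokenizer(config)
    corpus = [
        "def add(a, b): return a + b",
        "The REST api returns a json response.",
        "Gradient descent updates the model parameters.",
    ]
    tokenizer.train_from_iterator(iter(corpus))

    ids = tokenizer.encode("json api response", add_special_tokens=True)
    text = tokenizer.decode(ids, skip_special_tokens=True)
    print(ids, text)

    tokenizer.save("example_tokenizer_dir")
    restored = TechnicalTokenizer.load("example_tokenizer_dir")
    assert restored.get_vocab_size() == tokenizer.get_vocab_size()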


def create_text_iterator(file_paths: List[Union[str, Path]],
                         max_texts: Optional[int] = None) -> Iterator[str]:
    """Create memory-efficient text iterator from multiple files"""
    processed_count = 0

    for file_path in file_paths:
        file_path = Path(file_path)
        if not file_path.exists():
            logger.warning(f"File not found: {file_path}")
            continue

        logger.info(f"Processing file: {file_path}")

        try:
            if file_path.suffix == '.jsonl':
                with open(file_path, 'r', encoding='utf-8') as f:
                    for line_num, line in enumerate(f, 1):
                        try:
                            data = json.loads(line.strip())

                            if 'messages' in data:
                                # Conversation format
                                texts = []
                                for msg in data['messages']:
                                    content = msg.get('content', '').strip()
                                    if content:
                                        texts.append(content)
                                if texts:
                                    yield ' '.join(texts)
                                    processed_count += 1

                            elif 'text' in data:
                                # Simple text format
                                text = data['text'].strip()
                                if text:
                                    yield text
                                    processed_count += 1

                            if max_texts and processed_count >= max_texts:
                                return

                        except json.JSONDecodeError as e:
                            logger.warning(f"JSON decode error at line {line_num} in {file_path}: {e}")
                            continue
            else:
                # Plain text file
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()

                # Split by double newlines or other separators
                chunks = re.split(r'\n\s*\n', content)
                for chunk in chunks:
                    chunk = chunk.strip()
                    if chunk and len(chunk) > 50:  # Skip very short chunks
                        yield chunk
                        processed_count += 1

                        if max_texts and processed_count >= max_texts:
                            return

        except Exception as e:
            logger.error(f"Error processing file {file_path}: {e}")
            continue

    logger.info(f"Total texts processed: {processed_count}")


def train_tokenizer(input_files: List[Union[str, Path]],
                    output_dir: Union[str, Path],
                    config: Optional[TokenizerConfig] = None,
                    max_texts: Optional[int] = None) -> TechnicalTokenizer:
    """Train a new tokenizer from input files"""
    config = config or TokenizerConfig()
    tokenizer = TechnicalTokenizer(config)

    # Create text iterator
    text_iter = create_text_iterator(input_files, max_texts)

    # Train tokenizer
    tokenizer.train_from_iterator(text_iter)

    # Save tokenizer
    tokenizer.save(output_dir)

    return tokenizer
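

# Illustrative usage sketch: writes a tiny .jsonl file in the conversation
# format expected by create_text_iterator and trains a tokenizer from it via
# train_tokenizer. File and directory names are arbitrary; not called anywhere.
def _example_train_from_jsonl() -> None:
    sample = {
        "messages": [
            {"role": "user", "content": "How do I reverse a list in python?"},
            {"role": "assistant", "content": "Use the built-in reversed() function or list slicing."},
        ]
    }
    with open("example_corpus.jsonl", "w", encoding="utf-8") as f:
        for _ in range(50):
            f.write(json.dumps(sample) + "\n")

    config = TokenizerConfig(vocab_size=400, min_freq=1)
    tokenizer = train_tokenizer(
        input_files=["example_corpus.jsonl"],
        output_dir="example_jsonl_tokenizer",
        config=config,
        max_texts=50,
    )
    print(tokenizer.tokenize("reverse a list in python"))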


def main():
    """Main CLI interface"""
    parser = argparse.ArgumentParser(
        description="Production-Quality Technical Tokenizer",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Input/Output
    parser.add_argument('--input_files', nargs='+', help='Input files for training')
    parser.add_argument('--output_dir', default='tokenizer_output', help='Output directory for tokenizer')
    parser.add_argument('--load_from', help='Load existing tokenizer from directory')

    # Training parameters
    parser.add_argument('--vocab_size', type=int, default=32000, help='Target vocabulary size')
    parser.add_argument('--min_freq', type=int, default=2, help='Minimum token frequency')
    parser.add_argument('--max_texts', type=int, help='Maximum number of texts to process')
    parser.add_argument('--cache_size', type=int, default=10000, help='BPE cache size')

    # Testing
    parser.add_argument('--test_text', help='Test text for tokenization analysis')
    parser.add_argument('--benchmark', action='store_true', help='Run performance benchmarks')

    # Logging
    parser.add_argument('--verbose', action='store_true', help='Enable verbose logging')

    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    try:
        if args.load_from:
            # Load existing tokenizer
            tokenizer = TechnicalTokenizer.load(args.load_from)

            if args.test_text:
                print("\nTokenization Analysis:")
                print(f"Text: {args.test_text}")

                tokens = tokenizer.tokenize(args.test_text)
                ids = tokenizer.encode(args.test_text)
                decoded = tokenizer.decode(ids)

                print(f"Tokens: {tokens}")
                print(f"Token IDs: {ids}")
                print(f"Decoded: {decoded}")
                print(f"Token count: {len(tokens)}")
                compression_ratio = len(args.test_text.split()) / len(tokens) if tokens else 0
                print(f"Compression ratio: {compression_ratio:.2f}")

            if args.benchmark:
                run_benchmark(tokenizer)

        else:
            # Train new tokenizer
            if not args.input_files:
                parser.error("--input_files required when not loading existing tokenizer")

            # Create configuration
            config = TokenizerConfig(
                vocab_size=args.vocab_size,
                min_freq=args.min_freq,
                cache_size=args.cache_size
            )

            # Train tokenizer
            tokenizer = train_tokenizer(
                input_files=args.input_files,
                output_dir=args.output_dir,
                config=config,
                max_texts=args.max_texts
            )

            # Test on sample texts
            test_texts = [
                "Hello, how can I help you with your Python programming question?",
                "The neural network architecture uses attention mechanisms for better performance.",
                "```python\ndef fibonacci(n):\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)\n```",
                "The derivative of x² is 2x, and the integral is (x³)/3 + C."
            ]

            print("\nTokenization Analysis on Sample Texts:")
            print("=" * 50)

            for i, text in enumerate(test_texts, 1):
                print(f"\nTest {i}:")
                print(f"Text: {text}")

                tokens = tokenizer.tokenize(text)
                ids = tokenizer.encode(text)

                print(f"Tokens ({len(tokens)}): {tokens}")
                print(f"Token IDs: {ids}")

                word_count = len(text.split())
                compression_ratio = word_count / len(tokens) if tokens else 0
                print(f"Compression ratio: {compression_ratio:.2f}")

            print("\nTokenizer training completed!")
            print(f"Vocabulary size: {tokenizer.get_vocab_size()}")
            print(f"Cache info: {tokenizer.get_cache_info()}")

    except Exception as e:
        logger.error(f"Error in main: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        return 1

    return 0


def run_benchmark(tokenizer: TechnicalTokenizer) -> None:
    """Run performance benchmarks on the tokenizer"""
    import time
    import random
    import string

    print("\nRunning Performance Benchmarks...")
    print("=" * 50)

    # Generate test data
    test_texts = []

    # Short texts
    for _ in range(1000):
        length = random.randint(10, 50)
        text = ' '.join(''.join(random.choices(string.ascii_lowercase, k=random.randint(3, 10)))
                        for _ in range(length))
        test_texts.append(text)

    # Medium texts
    for _ in range(100):
        length = random.randint(100, 500)
        text = ' '.join(''.join(random.choices(string.ascii_lowercase, k=random.randint(3, 10)))
                        for _ in range(length))
        test_texts.append(text)

    # Long texts
    for _ in range(10):
        length = random.randint(1000, 5000)
        text = ' '.join(''.join(random.choices(string.ascii_lowercase, k=random.randint(3, 10)))
                        for _ in range(length))
        test_texts.append(text)

    # Benchmark tokenization
    print("Benchmarking tokenization...")
    start_time = time.time()
    total_tokens = 0
    for text in test_texts:
        tokens = tokenizer.tokenize(text)
        total_tokens += len(tokens)
    tokenization_time = time.time() - start_time

    # Benchmark encoding
    print("Benchmarking encoding...")
    start_time = time.time()
    all_ids = []
    for text in test_texts:
        ids = tokenizer.encode(text)
        all_ids.append(ids)
    encoding_time = time.time() - start_time

    # Benchmark decoding
    print("Benchmarking decoding...")
    start_time = time.time()
    for ids in all_ids:
        decoded = tokenizer.decode(ids)
    decoding_time = time.time() - start_time

    # Print results
    print("\nBenchmark Results:")
    print(f"Texts processed: {len(test_texts)}")
    print(f"Total tokens: {total_tokens:,}")
    print(f"Tokenization time: {tokenization_time:.3f}s")
    print(f"Encoding time: {encoding_time:.3f}s")
    print(f"Decoding time: {decoding_time:.3f}s")
    print(f"Tokenization speed: {total_tokens/tokenization_time:.0f} tokens/sec")
    print(f"Cache info: {tokenizer.get_cache_info()}")
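

# Illustrative usage sketch: exercises the TokenizerTester checks defined below
# against a previously saved tokenizer. The directory name and sample texts are
# arbitrary; the function is not called anywhere.
def _example_run_tokenizer_tests(tokenizer_dir: str = "tokenizer_output") -> None:
    tokenizer = TechnicalTokenizer.load(tokenizer_dir)
    tester = TokenizerTester(tokenizer)

    roundtrip = tester.test_roundtrip_consistency([
        "The api returns json.",
        "```python\nprint('hi')\n```",
    ])
    print("roundtrip:", roundtrip["passed"], "passed,", roundtrip["failed"], "failed")
    print("special tokens:", tester.test_special_tokens())
    print("edge cases:", tester.test_edge_cases())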


class TokenizerTester:
    """Comprehensive testing utilities for the tokenizer"""

    def __init__(self, tokenizer: TechnicalTokenizer):
        self.tokenizer = tokenizer

    def test_roundtrip_consistency(self, texts: List[str]) -> Dict[str, Any]:
        """Test encode/decode roundtrip consistency"""
        results = {
            'total_tests': len(texts),
            'passed': 0,
            'failed': 0,
            'failures': []
        }

        for i, text in enumerate(texts):
            try:
                # Encode then decode
                ids = self.tokenizer.encode(text, add_special_tokens=False)
                decoded = self.tokenizer.decode(ids, skip_special_tokens=True)

                # Check if roundtrip preserves meaning (not exact match due to BPE)
                original_tokens = self.tokenizer.tokenize(text)
                decoded_tokens = self.tokenizer.tokenize(decoded)

                if len(original_tokens) == len(decoded_tokens):
                    results['passed'] += 1
                else:
                    results['failed'] += 1
                    results['failures'].append({
                        'index': i,
                        'original': text,
                        'decoded': decoded,
                        'original_tokens': len(original_tokens),
                        'decoded_tokens': len(decoded_tokens)
                    })

            except Exception as e:
                results['failed'] += 1
                results['failures'].append({
                    'index': i,
                    'error': str(e),
                    'text': text
                })

        return results

    def test_special_tokens(self) -> Dict[str, bool]:
        """Test special token handling"""
        results = {}

        for token_name, token_id in self.tokenizer.special_tokens.items():
            try:
                # Test encoding
                ids = self.tokenizer.encode(token_name, add_special_tokens=False)
                expected_id = self.tokenizer.vocab.get(token_name)

                # Test decoding
                decoded = self.tokenizer.decode([token_id])

                results[token_name] = (
                    expected_id in ids and
                    token_name in decoded
                )
            except Exception:
                results[token_name] = False

        return results

    def test_edge_cases(self) -> Dict[str, bool]:
        """Test edge cases and error conditions"""
        tests = {
            'empty_string': True,
            'whitespace_only': True,
            'very_long_text': True,
            'unicode_text': True,
            'special_chars': True
        }

        try:
            # Empty string
            result = self.tokenizer.encode("")
            tests['empty_string'] = isinstance(result, list)
        except Exception:
            tests['empty_string'] = False

        try:
            # Whitespace only
            result = self.tokenizer.encode(" \n\t ")
            tests['whitespace_only'] = isinstance(result, list)
        except Exception:
            tests['whitespace_only'] = False

        try:
            # Very long text
            long_text = "test " * 10000
            result = self.tokenizer.encode(long_text)
            tests['very_long_text'] = isinstance(result, list)
        except Exception:
            tests['very_long_text'] = False

        try:
            # Unicode text
            unicode_text = "Hello 世界 🌍 café naïve"
            result = self.tokenizer.encode(unicode_text)
            tests['unicode_text'] = isinstance(result, list)
        except Exception:
            tests['unicode_text'] = False

        try:
            # Special characters
            special_text = "!@#$%^&*()_+-=[]{}|;:'\",.<>?/~`"
            result = self.tokenizer.encode(special_text)
            tests['special_chars'] = isinstance(result, list)
        except Exception:
            tests['special_chars'] = False

        return tests


if __name__ == "__main__":
    exit(main())
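

# Example CLI invocations (illustrative; the script name and file names are
# placeholders, the flags are those defined in main() above):
#
#   # Train a new tokenizer from one or more corpus files
#   python tokenizer.py --input_files corpus1.jsonl corpus2.txt \
#       --output_dir tokenizer_output --vocab_size 32000 --min_freq 2
#
#   # Load an existing tokenizer, analyse a sample text and run benchmarks
#   python tokenizer.py --load_from tokenizer_output \
#       --test_text "How do I parse json in python?" --benchmark --verbose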