import os
import sys
import json
import pickle
import argparse
import logging
import threading
from collections import Counter, defaultdict, OrderedDict
from typing import List, Dict, Set, Optional, Tuple, Union, Iterator, Any
from dataclasses import dataclass, asdict
from pathlib import Path
import re
import unicodedata
import heapq
from functools import lru_cache
import time
from contextlib import contextmanager

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


@dataclass
class TokenizerConfig:
    """Configuration class with validation and serialization support"""

    # Core vocabulary and performance settings
    vocab_size: int = 32000
    min_freq: int = 2
    max_token_length: int = 256
    cache_size: int = 10000
    chunk_size: int = 10000

    # Special token strings
    pad_token: str = '<pad>'
    unk_token: str = '<unk>'
    bos_token: str = '<bos>'
    eos_token: str = '<eos>'

    # Content-detection feature flags
    enable_code_detection: bool = True
    enable_math_detection: bool = True
    enable_url_detection: bool = True

    def __post_init__(self):
        """Validate configuration parameters"""
        if self.vocab_size <= 0:
            raise ValueError(f"vocab_size must be positive, got {self.vocab_size}")
        if self.min_freq <= 0:
            raise ValueError(f"min_freq must be positive, got {self.min_freq}")
        if self.max_token_length <= 0:
            raise ValueError(f"max_token_length must be positive, got {self.max_token_length}")
        if self.cache_size <= 0:
            raise ValueError(f"cache_size must be positive, got {self.cache_size}")

        logger.info(f"TokenizerConfig validated: vocab_size={self.vocab_size}")

    def save(self, path: Union[str, Path]) -> None:
        """Save configuration to JSON file"""
        path = Path(path)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(asdict(self), f, indent=2, ensure_ascii=False)
        logger.info(f"Config saved to {path}")

    @classmethod
    def load(cls, path: Union[str, Path]) -> 'TokenizerConfig':
        """Load configuration from JSON file"""
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Config file not found: {path}")

        with open(path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)

        logger.info(f"Config loaded from {path}")
        return cls(**config_dict)
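
# Illustrative usage sketch (not part of the original module): round-tripping a
# config through save()/load(). The file name "tokenizer_config.json" is an
# arbitrary example path.
#
#   cfg = TokenizerConfig(vocab_size=8000, min_freq=1)
#   cfg.save("tokenizer_config.json")
#   restored = TokenizerConfig.load("tokenizer_config.json")
#   assert asdict(cfg) == asdict(restored)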


class ThreadSafeLRUCache:
    """Thread-safe LRU cache with size limits"""

    def __init__(self, max_size: int = 10000):
        self.max_size = max_size
        self.cache = OrderedDict()
        self.lock = threading.RLock()

    def get(self, key: str) -> Optional[List[str]]:
        """Get value from cache"""
        with self.lock:
            if key in self.cache:
                # Re-insert to mark the entry as most recently used
                value = self.cache.pop(key)
                self.cache[key] = value
                return value
            return None

    def put(self, key: str, value: List[str]) -> None:
        """Add value to cache"""
        with self.lock:
            if key in self.cache:
                self.cache.pop(key)
            elif len(self.cache) >= self.max_size:
                # Evict the least recently used entry
                self.cache.popitem(last=False)

            self.cache[key] = value

    def clear(self) -> None:
        """Clear all cache entries"""
        with self.lock:
            self.cache.clear()

    def size(self) -> int:
        """Get current cache size"""
        with self.lock:
            return len(self.cache)
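
# Illustrative usage sketch (not part of the original module): once max_size is
# exceeded, the least recently used entry is evicted.
#
#   cache = ThreadSafeLRUCache(max_size=2)
#   cache.put("foo", ["f", "oo"])
#   cache.put("bar", ["b", "ar"])
#   cache.get("foo")               # marks "foo" as most recently used
#   cache.put("baz", ["b", "az"])  # evicts "bar"
#   assert cache.get("bar") is None and cache.size() == 2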


class EfficientBPE:
    """BPE implementation that greedily merges the most frequent adjacent pair"""

    def __init__(self):
        self.merges: List[Tuple[str, str]] = []
        self.merge_ranks: Dict[Tuple[str, str], int] = {}

    def train(self, word_counts: Dict[str, int], num_merges: int) -> None:
        """Train BPE by repeatedly merging the most frequent adjacent symbol pair"""
        logger.info(f"Training BPE with {num_merges} merges")

        # Represent each word as a tuple of symbols (initially characters)
        vocab = defaultdict(int)
        for word, count in word_counts.items():
            vocab[tuple(word)] += count

        def get_pairs(vocab_dict):
            pairs = defaultdict(int)
            for word, freq in vocab_dict.items():
                if len(word) < 2:
                    continue
                for i in range(len(word) - 1):
                    pair = (word[i], word[i + 1])
                    pairs[pair] += freq
            return pairs

        for i in range(num_merges):
            if i % 1000 == 0:
                logger.info(f"BPE merge progress: {i}/{num_merges}")

            pairs = get_pairs(vocab)
            if not pairs:
                logger.warning(f"No more pairs available at merge {i}")
                break

            # Pick the most frequent pair
            best_pair = max(pairs.items(), key=lambda x: x[1])[0]

            # Merge every occurrence of the pair in the working vocabulary
            new_vocab = {}
            bigram = best_pair

            for word, freq in vocab.items():
                new_word = []
                i = 0
                while i < len(word):
                    if i < len(word) - 1 and (word[i], word[i + 1]) == bigram:
                        new_word.append(word[i] + word[i + 1])
                        i += 2
                    else:
                        new_word.append(word[i])
                        i += 1
                new_vocab[tuple(new_word)] = freq

            vocab = new_vocab
            self.merges.append(best_pair)
            self.merge_ranks[best_pair] = len(self.merges) - 1

        logger.info(f"BPE training completed with {len(self.merges)} merges")

    def apply(self, word: str) -> List[str]:
        """Apply learned BPE merges to a word in training order"""
        if len(word) <= 1:
            return list(word)

        # Start from individual characters
        word_tokens = list(word)

        # Replay merges in the order they were learned
        for merge_pair in self.merges:
            if len(word_tokens) == 1:
                break

            new_tokens = []
            i = 0
            while i < len(word_tokens):
                if (i < len(word_tokens) - 1 and
                        word_tokens[i] == merge_pair[0] and
                        word_tokens[i + 1] == merge_pair[1]):
                    new_tokens.append(merge_pair[0] + merge_pair[1])
                    i += 2
                else:
                    new_tokens.append(word_tokens[i])
                    i += 1

            word_tokens = new_tokens

        return word_tokens
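
# Illustrative sketch (not part of the original module): training on a toy word
# count and applying the learned merges. The exact merge sequence depends on
# tie-breaking, so the output is indicative rather than guaranteed.
#
#   bpe = EfficientBPE()
#   bpe.train({"lower": 5, "lowest": 3, "newer": 4}, num_merges=10)
#   print(bpe.apply("lower"))   # e.g. ['lower'] once enough merges are learned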


class TechnicalTokenizer:
    """
    Production-quality tokenizer for technical content with:
    - Efficient BPE implementation
    - Thread-safe caching
    - Memory-efficient streaming
    - Comprehensive error handling
    - Proper logging and monitoring
    """

    def __init__(self, config: Optional[TokenizerConfig] = None):
        self.config = config or TokenizerConfig()

        # Core vocabulary state
        self.vocab: Dict[str, int] = {}
        self.id_to_token: Dict[int, str] = {}
        self.token_frequencies: Counter = Counter()
        self.bpe = EfficientBPE()

        # Thread-safe cache for BPE results
        self.cache = ThreadSafeLRUCache(self.config.cache_size)

        # Special tokens with fixed IDs
        self.special_tokens = {
            self.config.pad_token: 0,
            self.config.unk_token: 1,
            self.config.bos_token: 2,
            self.config.eos_token: 3,
            '<system>': 4,
            '<user>': 5,
            '<assistant>': 6,
            '<|endoftext|>': 7,
            '<|newline|>': 8,
            '<|tab|>': 9,
            '<|code|>': 10,
            '<|/code|>': 11,
            '<|math|>': 12,
            '<|/math|>': 13,
            '<URL>': 14,
            '<EMAIL>': 15,
            '<NUMBER>': 16
        }

        self._initialize_vocab()
        self._compile_patterns()
        self.technical_terms = self._load_technical_terms()

        logger.info(f"TechnicalTokenizer initialized with vocab_size={self.config.vocab_size}")

    def _initialize_vocab(self) -> None:
        """Initialize vocabulary with special tokens"""
        self.vocab = self.special_tokens.copy()
        self.id_to_token = {v: k for k, v in self.special_tokens.items()}

    def _compile_patterns(self) -> None:
        """Compile regex patterns for efficient text processing"""
        patterns = []

        if self.config.enable_code_detection:
            patterns.extend([
                r'```[\s\S]*?```',
                r'`[^`\n]+`',
            ])

        if self.config.enable_url_detection:
            patterns.append(r'https?://[^\s<>"{}|\\^`[\]]+')

        patterns.extend([
            r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
            r'<[^>]+>',
            r'\b\d+\.?\d*\b',
            r'\b\w+(?:\'\w+)?\b',
            r'[^\w\s]',
        ])

        self.tokenizer_pattern = re.compile('|'.join(f'({pattern})' for pattern in patterns))

        # Normalization helpers
        self.newline_pattern = re.compile(r'\r\n|\r')
        self.tab_pattern = re.compile(r'\t')
        self.multiple_space_pattern = re.compile(r'\s+')

    def _load_technical_terms(self) -> Set[str]:
        """Load technical terms for priority processing"""
        return {
            # Programming concepts
            'function', 'variable', 'array', 'object', 'class', 'method',
            'parameter', 'return', 'import', 'export', 'async', 'await',
            'promise', 'callback', 'algorithm', 'datatype', 'boolean',

            # Languages
            'python', 'javascript', 'java', 'cpp', 'rust', 'go',
            'html', 'css', 'sql', 'typescript', 'kotlin', 'swift',

            # Web and APIs
            'api', 'rest', 'graphql', 'json', 'xml', 'http', 'https',
            'endpoint', 'request', 'response', 'authentication',

            # Machine learning and math
            'neural', 'network', 'model', 'training', 'validation',
            'accuracy', 'precision', 'recall', 'loss', 'gradient',
            'derivative', 'integral', 'matrix', 'vector', 'tensor',
            'transformer', 'attention', 'embedding', 'tokenization',

            # Infrastructure
            'docker', 'kubernetes', 'microservice', 'database',
            'server', 'client', 'deployment', 'scalability'
        }

    @contextmanager
    def _error_context(self, operation: str):
        """Context manager for consistent error handling"""
        try:
            yield
        except Exception as e:
            logger.error(f"Error in {operation}: {str(e)}")
            raise

    def normalize_text(self, text: str) -> str:
        """Normalize text with proper error handling"""
        if not isinstance(text, str):
            raise TypeError(f"Expected str, got {type(text)}")

        with self._error_context("text normalization"):
            # Canonicalize line endings, tabs, and Unicode form
            text = self.newline_pattern.sub('\n', text)
            text = self.tab_pattern.sub('<|tab|>', text)
            text = unicodedata.normalize('NFKC', text)

            # Map chat-role markers onto their special tokens
            text = re.sub(r'<\|system\|>', ' <system> ', text)
            text = re.sub(r'<\|user\|>', ' <user> ', text)
            text = re.sub(r'<\|assistant\|>', ' <assistant> ', text)
            text = re.sub(r'<\|endoftext\|>', ' <|endoftext|> ', text)

            return text.strip()

    def pre_tokenize(self, text: str) -> List[str]:
        """Pre-tokenize text into words and special tokens"""
        if not text:
            return []

        with self._error_context("pre-tokenization"):
            normalized_text = self.normalize_text(text)

            # findall returns one tuple of capture groups per match
            matches = self.tokenizer_pattern.findall(normalized_text)

            # Keep the first non-empty group from each match
            tokens = []
            for match_groups in matches:
                for group in match_groups:
                    if group:
                        tokens.append(group)
                        break

            return [token.strip() for token in tokens if token.strip()]

    def train_from_iterator(self, text_iterator: Iterator[str],
                            total_texts: Optional[int] = None) -> None:
        """
        Train tokenizer from text iterator for memory efficiency

        Args:
            text_iterator: Iterator yielding text strings
            total_texts: Optional total count for progress tracking
        """
        logger.info("Starting BPE training from iterator")
        start_time = time.time()

        word_counts = Counter()
        processed_texts = 0

        # Stream texts in chunks to bound memory usage
        current_chunk = []

        for text in text_iterator:
            current_chunk.append(text)
            processed_texts += 1

            if len(current_chunk) >= self.config.chunk_size:
                self._process_text_chunk(current_chunk, word_counts)
                current_chunk.clear()

            if processed_texts % 10000 == 0:
                elapsed = time.time() - start_time
                logger.info(f"Processed {processed_texts} texts in {elapsed:.1f}s")

        # Process any remaining texts
        if current_chunk:
            self._process_text_chunk(current_chunk, word_counts)

        logger.info(f"Pre-processing completed: {len(word_counts)} unique words")

        # Filter by minimum frequency and boost technical terms
        filtered_words = {}
        for word, count in word_counts.items():
            if count >= self.config.min_freq:
                if word.lower() in self.technical_terms:
                    count *= 5
                filtered_words[word] = count

        logger.info(f"After filtering: {len(filtered_words)} words")

        # Seed the vocabulary with every character seen in the corpus
        all_chars = set()
        for word in filtered_words:
            all_chars.update(word)

        for char in sorted(all_chars):
            if char not in self.vocab:
                token_id = len(self.vocab)
                self.vocab[char] = token_id
                self.id_to_token[token_id] = char

        # Learn BPE merges up to the target vocabulary size
        current_vocab_size = len(self.vocab)
        target_vocab_size = self.config.vocab_size
        num_merges = target_vocab_size - current_vocab_size

        if num_merges > 0:
            self.bpe.train(filtered_words, num_merges)

            # Add merged tokens to the vocabulary
            for merge_pair in self.bpe.merges:
                merged_token = merge_pair[0] + merge_pair[1]
                if merged_token not in self.vocab:
                    token_id = len(self.vocab)
                    self.vocab[merged_token] = token_id
                    self.id_to_token[token_id] = merged_token

        # Record token frequencies with the trained BPE
        for word, count in filtered_words.items():
            tokens = self.apply_bpe(word)
            for token in tokens:
                self.token_frequencies[token] += count

        training_time = time.time() - start_time
        logger.info(f"Training completed in {training_time:.1f}s")
        logger.info(f"Final vocabulary size: {len(self.vocab)}")
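
    # Illustrative sketch (not part of the original module): training directly
    # from an in-memory list of strings, which is a valid Iterator source.
    #
    #   tok = TechnicalTokenizer(TokenizerConfig(vocab_size=1000, min_freq=1))
    #   corpus = ["def add(a, b): return a + b", "async functions return a promise"]
    #   tok.train_from_iterator(iter(corpus))
    #   print(tok.get_vocab_size())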

    def _process_text_chunk(self, texts: List[str], word_counts: Counter) -> None:
        """Process a chunk of texts and update word counts"""
        for text in texts:
            try:
                tokens = self.pre_tokenize(text)
                for token in tokens:
                    if len(token) <= self.config.max_token_length:
                        word_counts[token] += 1
            except Exception as e:
                logger.warning(f"Error processing text chunk: {e}")
                continue

    def apply_bpe(self, word: str) -> List[str]:
        """Apply BPE to a word with caching"""
        if not word:
            return []

        cached_result = self.cache.get(word)
        if cached_result is not None:
            return cached_result

        tokens = self.bpe.apply(word)
        self.cache.put(word, tokens)

        return tokens

    def tokenize(self, text: str) -> List[str]:
        """Tokenize text into subword tokens"""
        if not text:
            return []

        with self._error_context("tokenization"):
            pre_tokens = self.pre_tokenize(text)
            final_tokens = []

            for token in pre_tokens:
                if token in self.special_tokens or token in self.vocab:
                    final_tokens.append(token)
                else:
                    bpe_tokens = self.apply_bpe(token)
                    final_tokens.extend(bpe_tokens)

            return final_tokens

    def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
        """Encode text to token IDs"""
        if not isinstance(text, str):
            raise TypeError(f"Expected str, got {type(text)}")

        tokens = self.tokenize(text)

        if add_special_tokens:
            tokens = [self.config.bos_token] + tokens + [self.config.eos_token]

        ids = []
        unk_id = self.vocab[self.config.unk_token]

        for token in tokens:
            token_id = self.vocab.get(token, unk_id)
            ids.append(token_id)

        return ids

    def decode(self, ids: List[int], skip_special_tokens: bool = False) -> str:
        """Decode token IDs to text"""
        if not isinstance(ids, (list, tuple)):
            raise TypeError(f"Expected list or tuple, got {type(ids)}")

        tokens = []
        for token_id in ids:
            if not isinstance(token_id, int):
                raise TypeError(f"Expected int token ID, got {type(token_id)}")

            if token_id not in self.id_to_token:
                logger.warning(f"Unknown token ID: {token_id}")
                continue

            token = self.id_to_token[token_id]

            if skip_special_tokens and token in self.special_tokens:
                continue

            tokens.append(token)

        text = ''.join(tokens)
        text = text.replace('<|tab|>', '\t')
        text = text.replace('<|newline|>', '\n')

        return text
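
    # Illustrative sketch (not part of the original module), given a trained
    # tokenizer `tok`: a tokenize/encode/decode round trip. Note that decode()
    # joins tokens without spaces, so whitespace from the source text is not
    # reconstructed.
    #
    #   ids = tok.encode("import numpy as np", add_special_tokens=True)
    #   text = tok.decode(ids, skip_special_tokens=True)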

    def get_vocab_size(self) -> int:
        """Get vocabulary size"""
        return len(self.vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Get vocabulary dictionary (copy for safety)"""
        return self.vocab.copy()

    def get_cache_info(self) -> Dict[str, int]:
        """Get cache statistics"""
        return {
            'size': self.cache.size(),
            'max_size': self.config.cache_size,
            # ThreadSafeLRUCache does not track hit rate, so this falls back to 0
            'hit_rate': getattr(self.cache, 'hit_rate', 0)
        }

    def save(self, save_dir: Union[str, Path]) -> None:
        """Save tokenizer with validation"""
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving tokenizer to {save_dir}")

        try:
            # Configuration
            self.config.save(save_dir / 'config.json')

            # Vocabulary
            with open(save_dir / 'vocab.json', 'w', encoding='utf-8') as f:
                json.dump(self.vocab, f, indent=2, ensure_ascii=False)

            # BPE merges
            with open(save_dir / 'merges.txt', 'w', encoding='utf-8') as f:
                for merge in self.bpe.merges:
                    f.write(f"{merge[0]} {merge[1]}\n")

            # Token frequencies
            with open(save_dir / 'frequencies.pkl', 'wb') as f:
                pickle.dump(dict(self.token_frequencies), f)

            # Metadata
            metadata = {
                'version': '2.0',
                'vocab_size': len(self.vocab),
                'num_merges': len(self.bpe.merges),
                'special_tokens': self.special_tokens
            }

            with open(save_dir / 'metadata.json', 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)

            logger.info("Tokenizer saved successfully")

        except Exception as e:
            logger.error(f"Error saving tokenizer: {e}")
            raise

    @classmethod
    def load(cls, save_dir: Union[str, Path]) -> 'TechnicalTokenizer':
        """Load tokenizer from directory"""
        save_dir = Path(save_dir)

        if not save_dir.exists():
            raise FileNotFoundError(f"Tokenizer directory not found: {save_dir}")

        logger.info(f"Loading tokenizer from {save_dir}")

        try:
            # Configuration
            config = TokenizerConfig.load(save_dir / 'config.json')
            tokenizer = cls(config)

            # Vocabulary
            with open(save_dir / 'vocab.json', 'r', encoding='utf-8') as f:
                tokenizer.vocab = json.load(f)

            tokenizer.id_to_token = {v: k for k, v in tokenizer.vocab.items()}

            # BPE merges
            merges_file = save_dir / 'merges.txt'
            if merges_file.exists():
                with open(merges_file, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if line:
                            parts = line.split()
                            if len(parts) == 2:
                                tokenizer.bpe.merges.append(tuple(parts))

                tokenizer.bpe.merge_ranks = {
                    merge: i for i, merge in enumerate(tokenizer.bpe.merges)
                }

            # Token frequencies
            freq_file = save_dir / 'frequencies.pkl'
            if freq_file.exists():
                with open(freq_file, 'rb') as f:
                    freq_dict = pickle.load(f)
                tokenizer.token_frequencies = Counter(freq_dict)

            logger.info("Tokenizer loaded successfully")
            logger.info(f"Vocabulary size: {len(tokenizer.vocab)}")
            logger.info(f"Number of BPE merges: {len(tokenizer.bpe.merges)}")

            return tokenizer

        except Exception as e:
            logger.error(f"Error loading tokenizer: {e}")
            raise


def create_text_iterator(file_paths: List[Union[str, Path]],
                         max_texts: Optional[int] = None) -> Iterator[str]:
    """Create memory-efficient text iterator from multiple files"""
    processed_count = 0

    for file_path in file_paths:
        file_path = Path(file_path)

        if not file_path.exists():
            logger.warning(f"File not found: {file_path}")
            continue

        logger.info(f"Processing file: {file_path}")

        try:
            if file_path.suffix == '.jsonl':
                with open(file_path, 'r', encoding='utf-8') as f:
                    for line_num, line in enumerate(f, 1):
                        try:
                            data = json.loads(line.strip())

                            if 'messages' in data:
                                # Chat format: concatenate message contents
                                texts = []
                                for msg in data['messages']:
                                    content = msg.get('content', '').strip()
                                    if content:
                                        texts.append(content)
                                if texts:
                                    yield ' '.join(texts)
                                    processed_count += 1

                            elif 'text' in data:
                                # Plain text format
                                text = data['text'].strip()
                                if text:
                                    yield text
                                    processed_count += 1

                            if max_texts and processed_count >= max_texts:
                                return

                        except json.JSONDecodeError as e:
                            logger.warning(f"JSON decode error at line {line_num} in {file_path}: {e}")
                            continue

            else:
                # Treat everything else as plain text, split into paragraphs
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()

                chunks = re.split(r'\n\s*\n', content)

                for chunk in chunks:
                    chunk = chunk.strip()
                    if chunk and len(chunk) > 50:
                        yield chunk
                        processed_count += 1

                        if max_texts and processed_count >= max_texts:
                            return

        except Exception as e:
            logger.error(f"Error processing file {file_path}: {e}")
            continue

    logger.info(f"Total texts processed: {processed_count}")


def train_tokenizer(input_files: List[Union[str, Path]],
                    output_dir: Union[str, Path],
                    config: Optional[TokenizerConfig] = None,
                    max_texts: Optional[int] = None) -> TechnicalTokenizer:
    """Train a new tokenizer from input files"""
    config = config or TokenizerConfig()
    tokenizer = TechnicalTokenizer(config)

    text_iter = create_text_iterator(input_files, max_texts)
    tokenizer.train_from_iterator(text_iter)
    tokenizer.save(output_dir)

    return tokenizer
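
# Illustrative sketch (not part of the original module); the file paths below are
# placeholders rather than files shipped with this project.
#
#   tokenizer = train_tokenizer(
#       input_files=["data/chat.jsonl", "data/docs.txt"],
#       output_dir="tokenizer_output",
#       config=TokenizerConfig(vocab_size=16000),
#       max_texts=100_000,
#   )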


def main():
    """Main CLI interface"""
    parser = argparse.ArgumentParser(
        description="Production-Quality Technical Tokenizer",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Input/output options
    parser.add_argument('--input_files', nargs='+',
                        help='Input files for training')
    parser.add_argument('--output_dir', default='tokenizer_output',
                        help='Output directory for tokenizer')
    parser.add_argument('--load_from',
                        help='Load existing tokenizer from directory')

    # Training options
    parser.add_argument('--vocab_size', type=int, default=32000,
                        help='Target vocabulary size')
    parser.add_argument('--min_freq', type=int, default=2,
                        help='Minimum token frequency')
    parser.add_argument('--max_texts', type=int,
                        help='Maximum number of texts to process')
    parser.add_argument('--cache_size', type=int, default=10000,
                        help='BPE cache size')

    # Analysis options
    parser.add_argument('--test_text',
                        help='Test text for tokenization analysis')
    parser.add_argument('--benchmark', action='store_true',
                        help='Run performance benchmarks')

    # Logging
    parser.add_argument('--verbose', action='store_true',
                        help='Enable verbose logging')

    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    try:
        if args.load_from:
            # Load an existing tokenizer
            tokenizer = TechnicalTokenizer.load(args.load_from)

            if args.test_text:
                print(f"\nTokenization Analysis:")
                print(f"Text: {args.test_text}")
                tokens = tokenizer.tokenize(args.test_text)
                ids = tokenizer.encode(args.test_text)
                decoded = tokenizer.decode(ids)
                print(f"Tokens: {tokens}")
                print(f"Token IDs: {ids}")
                print(f"Decoded: {decoded}")
                print(f"Token count: {len(tokens)}")
                compression_ratio = len(args.test_text.split()) / len(tokens) if tokens else 0
                print(f"Compression ratio: {compression_ratio:.2f}")

            if args.benchmark:
                run_benchmark(tokenizer)

        else:
            # Train a new tokenizer
            if not args.input_files:
                parser.error("--input_files required when not loading existing tokenizer")

            config = TokenizerConfig(
                vocab_size=args.vocab_size,
                min_freq=args.min_freq,
                cache_size=args.cache_size
            )

            tokenizer = train_tokenizer(
                input_files=args.input_files,
                output_dir=args.output_dir,
                config=config,
                max_texts=args.max_texts
            )

            # Quick qualitative check on sample texts
            test_texts = [
                "Hello, how can I help you with your Python programming question?",
                "The neural network architecture uses attention mechanisms for better performance.",
                "```python\ndef fibonacci(n):\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)\n```",
                "The derivative of x² is 2x, and the integral is (x³)/3 + C."
            ]

            print("\nTokenization Analysis on Sample Texts:")
            print("=" * 50)

            for i, text in enumerate(test_texts, 1):
                print(f"\nTest {i}:")
                print(f"Text: {text}")
                tokens = tokenizer.tokenize(text)
                ids = tokenizer.encode(text)
                print(f"Tokens ({len(tokens)}): {tokens}")
                print(f"Token IDs: {ids}")
                word_count = len(text.split())
                compression_ratio = word_count / len(tokens) if tokens else 0
                print(f"Compression ratio: {compression_ratio:.2f}")

            print(f"\nTokenizer training completed!")
            print(f"Vocabulary size: {tokenizer.get_vocab_size()}")
            print(f"Cache info: {tokenizer.get_cache_info()}")

    except Exception as e:
        logger.error(f"Error in main: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        return 1

    return 0


def run_benchmark(tokenizer: TechnicalTokenizer) -> None:
    """Run performance benchmarks on the tokenizer"""
    import random
    import string

    print("\nRunning Performance Benchmarks...")
    print("=" * 50)

    test_texts = []

    # Short synthetic texts
    for _ in range(1000):
        length = random.randint(10, 50)
        text = ' '.join(''.join(random.choices(string.ascii_lowercase, k=random.randint(3, 10)))
                        for _ in range(length))
        test_texts.append(text)

    # Medium synthetic texts
    for _ in range(100):
        length = random.randint(100, 500)
        text = ' '.join(''.join(random.choices(string.ascii_lowercase, k=random.randint(3, 10)))
                        for _ in range(length))
        test_texts.append(text)

    # Long synthetic texts
    for _ in range(10):
        length = random.randint(1000, 5000)
        text = ' '.join(''.join(random.choices(string.ascii_lowercase, k=random.randint(3, 10)))
                        for _ in range(length))
        test_texts.append(text)

    # Tokenization
    print("Benchmarking tokenization...")
    start_time = time.time()

    total_tokens = 0
    for text in test_texts:
        tokens = tokenizer.tokenize(text)
        total_tokens += len(tokens)

    tokenization_time = time.time() - start_time

    # Encoding
    print("Benchmarking encoding...")
    start_time = time.time()

    all_ids = []
    for text in test_texts:
        ids = tokenizer.encode(text)
        all_ids.append(ids)

    encoding_time = time.time() - start_time

    # Decoding
    print("Benchmarking decoding...")
    start_time = time.time()

    for ids in all_ids:
        decoded = tokenizer.decode(ids)

    decoding_time = time.time() - start_time

    # Results
    print(f"\nBenchmark Results:")
    print(f"Texts processed: {len(test_texts)}")
    print(f"Total tokens: {total_tokens:,}")
    print(f"Tokenization time: {tokenization_time:.3f}s")
    print(f"Encoding time: {encoding_time:.3f}s")
    print(f"Decoding time: {decoding_time:.3f}s")
    print(f"Tokenization speed: {total_tokens/tokenization_time:.0f} tokens/sec")
    print(f"Cache info: {tokenizer.get_cache_info()}")


class TokenizerTester:
    """Comprehensive testing utilities for the tokenizer"""

    def __init__(self, tokenizer: TechnicalTokenizer):
        self.tokenizer = tokenizer

    def test_roundtrip_consistency(self, texts: List[str]) -> Dict[str, Any]:
        """Test encode/decode roundtrip consistency"""
        results = {
            'total_tests': len(texts),
            'passed': 0,
            'failed': 0,
            'failures': []
        }

        for i, text in enumerate(texts):
            try:
                # Encode and decode without special tokens
                ids = self.tokenizer.encode(text, add_special_tokens=False)
                decoded = self.tokenizer.decode(ids, skip_special_tokens=True)

                # Compare token counts rather than raw strings, since decode()
                # does not reconstruct whitespace
                original_tokens = self.tokenizer.tokenize(text)
                decoded_tokens = self.tokenizer.tokenize(decoded)

                if len(original_tokens) == len(decoded_tokens):
                    results['passed'] += 1
                else:
                    results['failed'] += 1
                    results['failures'].append({
                        'index': i,
                        'original': text,
                        'decoded': decoded,
                        'original_tokens': len(original_tokens),
                        'decoded_tokens': len(decoded_tokens)
                    })

            except Exception as e:
                results['failed'] += 1
                results['failures'].append({
                    'index': i,
                    'error': str(e),
                    'text': text
                })

        return results

    def test_special_tokens(self) -> Dict[str, bool]:
        """Test special token handling"""
        results = {}

        for token_name, token_id in self.tokenizer.special_tokens.items():
            try:
                # Encoding should map the token string to its reserved ID
                ids = self.tokenizer.encode(token_name, add_special_tokens=False)
                expected_id = self.tokenizer.vocab.get(token_name)

                # Decoding the reserved ID should reproduce the token string
                decoded = self.tokenizer.decode([token_id])

                results[token_name] = (
                    expected_id in ids and
                    token_name in decoded
                )

            except Exception:
                results[token_name] = False

        return results

    def test_edge_cases(self) -> Dict[str, bool]:
        """Test edge cases and error conditions"""
        tests = {
            'empty_string': True,
            'whitespace_only': True,
            'very_long_text': True,
            'unicode_text': True,
            'special_chars': True
        }

        try:
            # Empty string
            result = self.tokenizer.encode("")
            tests['empty_string'] = isinstance(result, list)
        except Exception:
            tests['empty_string'] = False

        try:
            # Whitespace only
            result = self.tokenizer.encode(" \n\t ")
            tests['whitespace_only'] = isinstance(result, list)
        except Exception:
            tests['whitespace_only'] = False

        try:
            # Very long text
            long_text = "test " * 10000
            result = self.tokenizer.encode(long_text)
            tests['very_long_text'] = isinstance(result, list)
        except Exception:
            tests['very_long_text'] = False

        try:
            # Unicode text
            unicode_text = "Hello 世界 🌍 café naïve"
            result = self.tokenizer.encode(unicode_text)
            tests['unicode_text'] = isinstance(result, list)
        except Exception:
            tests['unicode_text'] = False

        try:
            # Special characters
            special_text = "!@#$%^&*()_+-=[]{}|;:'\",.<>?/~`"
            result = self.tokenizer.encode(special_text)
            tests['special_chars'] = isinstance(result, list)
        except Exception:
            tests['special_chars'] = False

        return tests
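
# Illustrative sketch (not part of the original module): running the tester
# against a trained tokenizer.
#
#   tester = TokenizerTester(tokenizer)
#   print(tester.test_roundtrip_consistency(["def foo(): pass", "hello world"]))
#   print(tester.test_special_tokens())
#   print(tester.test_edge_cases())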


if __name__ == "__main__":
    sys.exit(main())