| """ | |
| Tokenizer utilities for extracting BPE/SentencePiece metadata. | |
| Provides functions to: | |
| - Extract subword pieces from tokens | |
| - Calculate byte lengths | |
| - Identify multi-split identifiers (≥3 subwords) | |
| - Detect tokenization artifacts | |
| """ | |
| from typing import List, Tuple, Dict, Optional | |
| import re | |
| import logging | |
| logger = logging.getLogger(__name__) | |


class TokenizerMetadata:
    """Extracts and analyzes tokenization metadata."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Detect tokenizer type
        self.tokenizer_type = self._detect_tokenizer_type()

    def _detect_tokenizer_type(self) -> str:
        """Detect whether the tokenizer uses BPE, SentencePiece, or something else."""
        tokenizer_name = self.tokenizer.__class__.__name__.lower()
        if 'sentencepiece' in tokenizer_name:
            return 'sentencepiece'
        elif 'gpt2' in tokenizer_name or 'codegen' in tokenizer_name:
            return 'bpe'
        elif 'llama' in tokenizer_name:
            return 'sentencepiece'
        else:
            return 'unknown'

    def get_subword_pieces(self, token_id: int) -> List[str]:
        """
        Extract subword pieces for a token ID.

        For BPE (GPT-2/CodeGen):
        - Tokens may carry a 'Ġ' prefix marking a leading space
        - Example: token_id=1234 → "Ġuser" → ["user"]

        For SentencePiece (Llama):
        - Tokens may carry a '▁' prefix marking a leading space
        - Example: token_id=5678 → "▁name" → ["name"]

        Returns:
            List of subword pieces (cleaned of special characters)
        """
        try:
            # Decode single token
            token_str = self.tokenizer.decode([token_id])

            # Clean special characters. decode() usually renders the space
            # marker as an ordinary leading space, so strip whitespace after
            # removing any literal marker characters.
            if self.tokenizer_type == 'bpe':
                # Remove 'Ġ' (GPT-2 space marker)
                cleaned = token_str.replace('Ġ', '').strip()
            elif self.tokenizer_type == 'sentencepiece':
                # Remove '▁' (SentencePiece space marker)
                cleaned = token_str.replace('▁', '').strip()
            else:
                cleaned = token_str.strip()

            # For compound identifiers, split on underscores/camelCase
            pieces = self._split_identifier(cleaned)
            return pieces if pieces else [cleaned]
        except Exception as e:
            logger.warning(f"Failed to extract subword pieces for token_id {token_id}: {e}")
            return []

    def _split_identifier(self, text: str) -> List[str]:
        """
        Split identifier into components.

        Examples:
        - "get_user_data" → ["get", "user", "data"]
        - "getUserData" → ["get", "User", "Data"]
        - "process" → ["process"]
        """
        # Split on underscores
        if '_' in text:
            return [p for p in text.split('_') if p]

        # Split camelCase (insert _ before capitals, then split)
        camel_split = re.sub(r'([a-z])([A-Z])', r'\1_\2', text)
        if '_' in camel_split:
            return [p for p in camel_split.split('_') if p]

        # Single token
        return [text]

    def get_byte_length(self, token_id: int) -> int:
        """Get byte length of token (UTF-8 encoding)."""
        try:
            token_str = self.tokenizer.decode([token_id])
            return len(token_str.encode('utf-8'))
        except Exception as e:
            logger.warning(f"Failed to get byte length for token_id {token_id}: {e}")
            return 0

    def is_multi_split_identifier(self, token_ids: List[int], window_size: int = 5) -> List[bool]:
        """
        Identify sequences of ≥3 tokens that form a single identifier.

        This detects cases like:
        - ["process", "_", "user"] (3 tokens for process_user)
        - ["get", "User", "Data"] (3 tokens for getUserData)

        Args:
            token_ids: List of token IDs
            window_size: Size of sliding window to check (default 5)

        Returns:
            Boolean list indicating whether each token is part of a multi-split identifier
        """
        flags = [False] * len(token_ids)

        for i in range(len(token_ids)):
            # Look ahead up to window_size tokens
            window_end = min(i + window_size, len(token_ids))
            window_tokens = token_ids[i:window_end]

            # Decode window (strip the leading space decode() often inserts so
            # it does not trip the no-spaces identifier check)
            window_text = self.tokenizer.decode(window_tokens).strip()

            # Check if this looks like an identifier
            # Heuristic: contains underscores or camelCase, no spaces
            if self._is_identifier(window_text):
                # Count pieces
                pieces = self._split_identifier(window_text)
                if len(pieces) >= 3:
                    # Mark all tokens in window as part of multi-split
                    for j in range(i, window_end):
                        flags[j] = True

        return flags

    def _is_identifier(self, text: str) -> bool:
        """Check if text looks like a code identifier."""
        # No spaces (identifiers don't have spaces)
        if ' ' in text:
            return False

        # Contains letters (not just punctuation)
        if not any(c.isalpha() for c in text):
            return False

        # Contains underscore or camelCase
        if '_' in text or any(c.isupper() for c in text):
            return True

        return False

    def analyze_tokens(self, token_ids: List[int]) -> List[Dict[str, Any]]:
        """
        Comprehensive analysis of a token sequence.

        Returns a list of dictionaries with:
        - token_id: int
        - text: str (decoded token)
        - bpe_pieces: List[str] (subword pieces)
        - byte_length: int
        - is_multi_split: bool (part of a multi-split identifier)
        - num_pieces: int (number of subword pieces)
        """
        multi_split_flags = self.is_multi_split_identifier(token_ids)

        results = []
        for i, token_id in enumerate(token_ids):
            pieces = self.get_subword_pieces(token_id)
            byte_len = self.get_byte_length(token_id)
            text = self.tokenizer.decode([token_id])

            results.append({
                'token_id': token_id,
                'text': text,
                'bpe_pieces': pieces,
                'byte_length': byte_len,
                'is_multi_split': multi_split_flags[i],
                'num_pieces': len(pieces)
            })

        return results


def get_tokenizer_stats(tokenizer, text: str) -> Dict[str, Any]:
    """
    Get tokenization statistics for a given text.

    Returns:
        Dictionary with:
        - num_tokens: Total tokens
        - avg_bytes_per_token: Average bytes per token
        - num_multi_split: Number of tokens in multi-split identifiers
        - tokenization_ratio: Characters / tokens
        - analysis: Per-token records from TokenizerMetadata.analyze_tokens()
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    metadata = TokenizerMetadata(tokenizer)
    analysis = metadata.analyze_tokens(token_ids)

    total_bytes = sum(t['byte_length'] for t in analysis)
    num_multi_split = sum(1 for t in analysis if t['is_multi_split'])

    return {
        'num_tokens': len(token_ids),
        'avg_bytes_per_token': total_bytes / len(token_ids) if token_ids else 0,
        'num_multi_split': num_multi_split,
        'tokenization_ratio': len(text) / len(token_ids) if token_ids else 0,
        'analysis': analysis
    }


def flag_risk_hotspots(token_analysis: List[Dict[str, Any]], entropy_threshold: float = 1.5) -> List[int]:
    """
    Flag tokens that are risk hotspots based on tokenization + entropy.

    A token is flagged if:
    - It's part of a multi-split identifier (≥3 subwords)
    - AND has high entropy (model is uncertain)

    Args:
        token_analysis: Output from TokenizerMetadata.analyze_tokens()
        entropy_threshold: Entropy threshold (default 1.5 nats)

    Returns:
        List of indices of flagged tokens

    Note: Entropy must be provided externally (from the instrumentation layer).
    This function only checks the tokenization criterion; entropy_threshold is
    accepted for interface consistency but is not used here.
    """
    flagged = []
    for i, token in enumerate(token_analysis):
        # is_multi_split already encodes the ≥3-subword criterion computed
        # over the sliding window in is_multi_split_identifier().
        if token['is_multi_split']:
            flagged.append(i)
    return flagged
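

# A minimal sketch of how the tokenization criterion could be combined with
# per-token entropies supplied by the external instrumentation layer, as
# described in the docstring above. The function name and the `entropies`
# argument are assumptions for illustration, not an existing interface.
def flag_risk_hotspots_with_entropy(
    token_analysis: List[Dict[str, Any]],
    entropies: List[float],
    entropy_threshold: float = 1.5,
) -> List[int]:
    """Flag tokens that are both multi-split and high-entropy (≥ threshold nats)."""
    flagged = []
    for i, token in enumerate(token_analysis):
        if token['is_multi_split'] and entropies[i] >= entropy_threshold:
            flagged.append(i)
    return flagged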


# Example usage
if __name__ == "__main__":
    # This would be used with an actual tokenizer:
    # from transformers import AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
    #
    # metadata = TokenizerMetadata(tokenizer)
    # stats = get_tokenizer_stats(tokenizer, "def process_user_data(user_name):")
    # print(stats)
    print("Tokenizer utilities module loaded successfully")