| """
|
| TouchGrass tokenizer for HuggingFace.
|
| Wraps extended Qwen tokenizer for HF compatibility.
|
| """
|
|
|
| from typing import List, Optional, Dict, Any
|
| import json
|
| import os
|
|
|
|
|
class TouchGrassTokenizer:
    """
    HuggingFace-compatible tokenizer for TouchGrass.

    Wraps the extended Qwen tokenizer.
    """

    def __init__(
        self,
        tokenizer_file: Optional[str] = None,
        config: Optional[Dict] = None,
        **kwargs,
    ):
        """
        Initialize the tokenizer.

        Args:
            tokenizer_file: Path to the tokenizer JSON file.
            config: Tokenizer configuration dictionary.
        """
        from .tokenizer.music_token_extension import MusicTokenizerExtension

        self.config = config or {}
        self.special_tokens = self.config.get("special_tokens", {})

        if tokenizer_file and os.path.exists(tokenizer_file):
            self.tokenizer_ext = MusicTokenizerExtension.from_pretrained(
                os.path.dirname(tokenizer_file)
            )
            self.tokenizer = self.tokenizer_ext.get_tokenizer()
        else:
            self.tokenizer_ext = None
            self.tokenizer = None

        self.pad_token = "[PAD]"
        self.unk_token = "[UNK]"
        self.bos_token = "[BOS]"
        self.eos_token = "[EOS]"
        self.pad_token_id = self.special_tokens.get("[PAD]", 0)
        self.unk_token_id = self.special_tokens.get("[UNK]", 1)
        self.bos_token_id = self.special_tokens.get("[BOS]", 2)
        self.eos_token_id = self.special_tokens.get("[EOS]", 3)
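
    # For reference, tokenizer_config.json is expected to look roughly like
    # the following (inferred from the keys read above and written in
    # save_pretrained below; not a documented schema):
    #   {
    #       "model_type": "touchgrass",
    #       "max_seq_len": 4096,
    #       "special_tokens": {"[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3}
    #   }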

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        **kwargs,
    ):
        """Load the tokenizer from a pretrained checkpoint directory."""
        tokenizer_path = os.path.join(pretrained_model_name_or_path, "tokenizer.json")
        config_path = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")

        config = {}
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                config = json.load(f)

        return cls(tokenizer_file=tokenizer_path, config=config, **kwargs)
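
    # Usage sketch (illustrative only; "path/to/checkpoint" is a placeholder
    # for a directory produced by save_pretrained below):
    #   tokenizer = TouchGrassTokenizer.from_pretrained("path/to/checkpoint")
    #   batch = tokenizer(["<token stream>"], padding=True)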

    def __call__(
        self,
        text: str | List[str],
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Tokenize text.

        Args:
            text: Input text or list of texts.
            padding: Pad all sequences in the batch to the same length.
            truncation: Truncate sequences to max_length.
            max_length: Maximum sequence length; defaults to the configured
                max_seq_len (4096 if unset).
            return_tensors: "pt" for PyTorch, "np" for NumPy, None for lists.

        Returns:
            Dictionary with input_ids and attention_mask.
        """
        if self.tokenizer is None:
            raise ValueError(
                "Tokenizer not initialized. Load from pretrained or extend a base tokenizer."
            )

        if isinstance(text, str):
            text = [text]

        if max_length is None:
            max_length = self.config.get("max_seq_len", 4096)

        return self.tokenizer(
            text,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            **kwargs,
        )
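
    # The mapping returned by __call__ follows HF conventions; with
    # return_tensors=None it holds plain lists (IDs here are illustrative):
    #   {"input_ids": [[2, 15, 7, 3]], "attention_mask": [[1, 1, 1, 1]]}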

    def encode(
        self,
        text: str,
        add_special_tokens: bool = True,
        **kwargs,
    ) -> List[int]:
        """Encode text to a list of token IDs."""
        # HF tokenizers return a plain list of IDs from encode(); indexing the
        # result like a dict (as before) would raise a TypeError.
        return self.tokenizer.encode(
            text,
            add_special_tokens=add_special_tokens,
        )

    def decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = True,
        **kwargs,
    ) -> str:
        """Decode token IDs to text."""
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
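
    # Encode/decode round trip (illustrative; assumes an initialized
    # tokenizer):
    #   ids = tokenizer.encode("some text")  # -> List[int]
    #   tokenizer.decode(ids)                # -> "some text", specials stripped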

    def save_pretrained(self, save_directory: str):
        """Save the tokenizer and its config to a directory."""
        if self.tokenizer is None:
            raise ValueError("Tokenizer not initialized; nothing to save.")

        os.makedirs(save_directory, exist_ok=True)
        self.tokenizer.save_pretrained(save_directory)

        config_path = os.path.join(save_directory, "tokenizer_config.json")
        with open(config_path, "w") as f:
            json.dump(
                {
                    "model_type": "touchgrass",
                    "special_tokens": self.special_tokens,
                },
                f,
                indent=2,
            )

    @property
    def vocab_size(self) -> int:
        """Vocabulary size, including added music and special tokens."""
        # len() counts tokens added on top of the base Qwen vocab; the wrapped
        # tokenizer's bare vocab_size attribute would exclude them.
        return len(self.tokenizer) if self.tokenizer else 0