| """Remote-code tokenizer for Atom/Fusion GPT checkpoints. |
| |
| The tokenizer is intentionally HF-compatible: generic callers can use |
| ``AutoTokenizer.from_pretrained(..., trust_remote_code=True)``. Arithmetic digit |
| spans are reversed before tokenization so the model receives LSD-first numbers, |
| matching pretraining. |
| """ |
|
|
| from __future__ import annotations |
|
|
import re
import threading
from typing import Any

from transformers import PreTrainedTokenizerFast
|
|
|
|
class AtomTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that presents digit runs to the model in reversed order.

    Every run of decimal digits (``\\d+``) in the input text is reversed
    before encoding, and every digit run in decoded text is reversed again,
    so callers read and write numbers in the usual order.

    The public encode entry points (``__call__``, ``encode``,
    ``encode_plus``, ``batch_encode_plus``) can call one another inside the
    base class (for example, the base ``__call__`` routing batched input
    through ``self.batch_encode_plus``).  A per-thread guard makes sure the
    reversal is applied exactly once per outermost call; without it a nested
    override would reverse the digits a second time and silently undo the
    transform.
    """

    vocab_files_names = {"tokenizer_file": "tokenizer.json"}
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = None
    _digit_span_re = re.compile(r"\d+")
    # Thread-local marker: ``active`` is True while an outer encode entry
    # point on the current thread has already reversed its text inputs.
    _encode_guard = threading.local()

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialise the tokenizer, filling in default special tokens.

        Explicitly supplied ``bos_token`` / ``eos_token`` / ``unk_token`` /
        ``pad_token`` values win; the defaults are only used when absent.
        """
        kwargs.setdefault("bos_token", "<|bos|>")
        kwargs.setdefault("eos_token", "<|eos|>")
        kwargs.setdefault("unk_token", "<|unk|>")
        kwargs.setdefault("pad_token", "<|pad|>")
        super().__init__(*args, **kwargs)

    @classmethod
    def _reverse_digit_spans(cls, text: str) -> str:
        """Return *text* with each maximal run of digits reversed in place."""
        return cls._digit_span_re.sub(lambda match: match.group(0)[::-1], text)

    @classmethod
    def _transform_text(cls, value: Any) -> Any:
        """Reverse digit runs in a string or in arbitrarily nested lists/tuples.

        Container types are preserved (tuple stays tuple, list stays list);
        any other value, including ``None``, is returned unchanged.
        """
        if isinstance(value, str):
            return cls._reverse_digit_spans(value)
        if isinstance(value, tuple):
            return tuple(cls._transform_text(item) for item in value)
        if isinstance(value, list):
            return [cls._transform_text(item) for item in value]
        return value

    def _reversal_active(self) -> bool:
        """Return True if an outer encode call on this thread already reversed."""
        return getattr(self._encode_guard, "active", False)

    def _encode_reversed_once(self, encode_fn, texts, args, kwargs):
        """Call ``encode_fn(*texts, *args, **kwargs)`` with digits reversed once.

        If an enclosing entry point on this thread has already reversed its
        inputs, *texts* is forwarded untouched; otherwise each element is
        passed through ``_transform_text`` and the guard is held for the
        duration of the call.
        """
        if self._reversal_active():
            return encode_fn(*texts, *args, **kwargs)
        guard = self._encode_guard
        guard.active = True
        try:
            return encode_fn(
                *(self._transform_text(item) for item in texts),
                *args,
                **kwargs,
            )
        finally:
            # Always release, even if encoding raises, so later calls on
            # this thread are transformed again.
            guard.active = False

    def __call__(self, text=None, text_pair=None, *args: Any, **kwargs: Any):
        """Encode text (and optional pair/targets) with digit runs reversed."""
        if not self._reversal_active():
            # Targets must be reversed here: while the guard is held, nested
            # entry points deliberately skip the transform.
            # NOTE(review): assumes the base ``__call__`` takes text_target
            # and text_pair_target as the two positionals right after
            # text_pair -- confirm against the installed transformers version.
            args = tuple(
                self._transform_text(arg) if index < 2 else arg
                for index, arg in enumerate(args)
            )
            for key in ("text_target", "text_pair_target"):
                if key in kwargs:
                    kwargs[key] = self._transform_text(kwargs[key])
        return self._encode_reversed_once(
            super().__call__, (text, text_pair), args, kwargs
        )

    def encode(self, text, text_pair=None, *args: Any, **kwargs: Any):
        """Encode to token ids with digit runs reversed."""
        return self._encode_reversed_once(
            super().encode, (text, text_pair), args, kwargs
        )

    def encode_plus(self, text, text_pair=None, *args: Any, **kwargs: Any):
        """Encode a single text or pair with digit runs reversed.

        Overridden so direct ``encode_plus`` callers get the same reversal as
        ``__call__`` and ``encode``.
        """
        return self._encode_reversed_once(
            super().encode_plus, (text, text_pair), args, kwargs
        )

    def batch_encode_plus(self, batch_text_or_text_pairs, *args: Any, **kwargs: Any):
        """Encode a batch of texts or text pairs with digit runs reversed."""
        return self._encode_reversed_once(
            super().batch_encode_plus, (batch_text_or_text_pairs,), args, kwargs
        )

    def _decode(self, token_ids, skip_special_tokens: bool = False, **kwargs: Any) -> str:
        """Decode token ids, then reverse digit runs in the resulting text."""
        text = super()._decode(
            token_ids,
            skip_special_tokens=skip_special_tokens,
            **kwargs,
        )
        return self._reverse_digit_spans(text)
|
|