Atom2.7m / tokenization_atom.py
ucr-max's picture
Update Atom2.7m submission
2fd4f23 verified
Raw
History Blame Contribute Delete
2.51 kB
"""Remote-code tokenizer for Atom/Fusion GPT checkpoints.
The tokenizer is intentionally HF-compatible: generic callers can use
``AutoTokenizer.from_pretrained(..., trust_remote_code=True)``. Every run of
digits is reversed before tokenization so the model receives LSD-first numbers,
matching pretraining; decoding reverses the digit runs again.
"""
from __future__ import annotations

import re
import threading
from contextlib import contextmanager
from typing import Any

from transformers import PreTrainedTokenizerFast
class AtomTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that reverses every digit run around encode/decode.

    Text passed to the public encoding entry points has each run of digits
    reversed before it reaches the underlying tokenizer, and ``_decode``
    reverses the digit runs of the decoded text again.

    The base class chains its public entry points (in transformers 4.x,
    ``__call__`` dispatches to ``self.batch_encode_plus`` / ``self.encode_plus``
    and ``encode`` dispatches to ``self.encode_plus``). Reversing twice is a
    no-op, so only the *outermost* entry point on the current thread applies
    the reversal; nested calls forward their arguments untouched.
    """

    vocab_files_names = {"tokenizer_file": "tokenizer.json"}
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = None

    _digit_span_re = re.compile(r"\d+")

    # Per-thread marker: set while an encoding entry point of this class is
    # running. Kept on the class (not the instance) so instances stay
    # picklable and copyable; thread-local so concurrent callers cannot make
    # each other skip the reversal.
    _encode_guard = threading.local()

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Create the tokenizer, filling in default special tokens.

        Caller-supplied ``bos_token`` / ``eos_token`` / ``unk_token`` /
        ``pad_token`` values take precedence over the defaults below.
        """
        kwargs.setdefault("bos_token", "<|bos|>")
        kwargs.setdefault("eos_token", "<|eos|>")
        kwargs.setdefault("unk_token", "<|unk|>")
        kwargs.setdefault("pad_token", "<|pad|>")
        super().__init__(*args, **kwargs)

    @classmethod
    def _reverse_digit_spans(cls, text: str) -> str:
        """Return ``text`` with every maximal run of digits reversed in place."""
        return cls._digit_span_re.sub(lambda match: match.group(0)[::-1], text)

    @classmethod
    def _transform_text(cls, value: Any) -> Any:
        """Reverse digit runs in a string or in (nested) lists/tuples of strings.

        Lists stay lists and tuples stay tuples; any other value (``None``,
        booleans, ...) is returned unchanged.
        """
        if isinstance(value, str):
            return cls._reverse_digit_spans(value)
        if isinstance(value, tuple):
            return tuple(cls._transform_text(item) for item in value)
        if isinstance(value, list):
            return [cls._transform_text(item) for item in value]
        return value

    @contextmanager
    def _outermost_encode(self):
        """Yield ``True`` only for the outermost encoding call on this thread.

        Nested entry points (reached because the base class routes one public
        method through another) get ``False`` and must not transform their
        inputs again. The marker is cleared even if encoding raises.
        """
        state = self._encode_guard
        if getattr(state, "active", False):
            yield False
            return
        state.active = True
        try:
            yield True
        finally:
            state.active = False

    def __call__(self, text=None, text_pair=None, *args: Any, **kwargs: Any):
        """Encode ``text`` / ``text_pair`` (single or batched) with digits reversed.

        Target texts are transformed here as well, because the nested
        ``batch_encode_plus`` / ``encode_plus`` calls made by the base class
        are deliberately left untransformed.
        """
        with self._outermost_encode() as outermost:
            if outermost:
                text = self._transform_text(text)
                text_pair = self._transform_text(text_pair)
                # In the transformers 4.x base signature the next two
                # positional slots are text_target and text_pair_target;
                # _transform_text leaves non-text values untouched.
                args = tuple(self._transform_text(a) for a in args[:2]) + args[2:]
                for key in ("text_target", "text_pair_target"):
                    if key in kwargs:
                        kwargs[key] = self._transform_text(kwargs[key])
            return super().__call__(text, text_pair, *args, **kwargs)

    def encode(self, text, text_pair=None, *args: Any, **kwargs: Any):
        """Return token ids for ``text`` (and ``text_pair``) with digits reversed."""
        with self._outermost_encode() as outermost:
            if outermost:
                text = self._transform_text(text)
                text_pair = self._transform_text(text_pair)
            return super().encode(text, text_pair, *args, **kwargs)

    def encode_plus(self, text, text_pair=None, *args: Any, **kwargs: Any):
        """Encode one text or text pair with digits reversed.

        Overridden so that direct callers get the same reversal as
        ``__call__`` and ``encode``; when reached from one of those, the
        guard makes this a plain pass-through.
        """
        with self._outermost_encode() as outermost:
            if outermost:
                text = self._transform_text(text)
                text_pair = self._transform_text(text_pair)
            return super().encode_plus(text, text_pair, *args, **kwargs)

    def batch_encode_plus(self, batch_text_or_text_pairs, *args: Any, **kwargs: Any):
        """Encode a batch of texts or text pairs with digits reversed.

        When reached from ``__call__`` the batch has already been reversed,
        so the guard turns this into a plain pass-through instead of undoing
        the reversal.
        """
        with self._outermost_encode() as outermost:
            if outermost:
                batch_text_or_text_pairs = self._transform_text(batch_text_or_text_pairs)
            return super().batch_encode_plus(batch_text_or_text_pairs, *args, **kwargs)

    def _decode(self, token_ids, skip_special_tokens: bool = False, **kwargs: Any) -> str:
        """Decode ``token_ids`` and reverse every digit run in the result."""
        text = super()._decode(
            token_ids,
            skip_special_tokens=skip_special_tokens,
            **kwargs,
        )
        return self._reverse_digit_spans(text)