from __future__ import annotations import os from pathlib import Path from typing import List, Optional, Union from tokenizers import Tokenizer as BaseTokenizer from .aliases import PathOrStr from .config import ModelConfig, TokenizerConfig, TrainConfig, TruncationDirection from .exceptions import OLMoConfigurationError __all__ = ["Tokenizer"] class Tokenizer: """ A :class:`Tokenizer` is a light-weight wrapper around a HuggingFace :class:`tokenizers.Tokenizer`. :param base_tokenizer: The :class:`tokenizers.Tokenizer` to use. :param eos_token_id: The token ID corresponding to the "end-of-sentence" token. :param truncate_to: Truncate when tokenizing to this number of token IDs. :param truncate_direction: The direction to truncate in. "right" means truncate the tokens on the right. "left" means truncate the tokens on the left. If ``truncate_to`` is null, this setting has no effect. """ def __init__( self, base_tokenizer: BaseTokenizer, eos_token_id: int, pad_token_id: Optional[int] = None, truncate_to: Optional[int] = None, truncate_direction: Union[str, TruncationDirection] = TruncationDirection.right, ): self.base_tokenizer = base_tokenizer self.base_tokenizer.no_truncation() self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id if pad_token_id is not None else eos_token_id self.truncate_to = truncate_to self.truncate_direction = TruncationDirection(truncate_direction) @property def vocab_size(self) -> int: return self.base_tokenizer.get_vocab_size() @property def eos_token(self) -> str: return self.decode([self.eos_token_id], skip_special_tokens=False) @property def pad_token(self) -> str: return self.decode([self.pad_token_id], skip_special_tokens=False) @classmethod def from_train_config(cls, config: TrainConfig) -> Tokenizer: tokenizer_identifier = config.tokenizer.identifier if Path(tokenizer_identifier).is_file(): tokenizer = cls.from_file( tokenizer_identifier, eos_token_id=config.model.eos_token_id, pad_token_id=config.model.pad_token_id, ) else: tokenizer = cls.from_pretrained( tokenizer_identifier, eos_token_id=config.model.eos_token_id, pad_token_id=config.model.pad_token_id, ) if config.model.vocab_size != tokenizer.vocab_size: raise OLMoConfigurationError("vocab size mismatch between config and tokenizer") return tokenizer @classmethod def from_pretrained(cls, identifier: str, **kwargs) -> Tokenizer: """ Initialize a tokenizer from a pretrained tokenizer on the HuggingFace Hub. :param identifier: The identifier of a model on the Hub that contains a ``tokenizer.json`` file. :param kwargs: Other key word arguments passed to :class:`Tokenizer`. """ base_tokenizer = BaseTokenizer.from_pretrained(identifier) eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1) return cls(base_tokenizer, eos_token_id, **kwargs) @classmethod def from_file(cls, filename: PathOrStr, **kwargs) -> Tokenizer: """ Initialize a tokenizer from a file. You can create those files with ``BaseTokenizer.save()``. :param filename: The name of a file containing a tokenizer specification. :param kwargs: Other key word arguments passed to :class:`Tokenizer`. """ base_tokenizer = BaseTokenizer.from_file(filename) eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1) return cls(base_tokenizer, eos_token_id, **kwargs) @classmethod def from_checkpoint(cls, checkpoint_dir: PathOrStr) -> Tokenizer: """ Load a tokenizer from a checkpoint. """ from cached_path import cached_path # Load configs. config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml")) tokenizer_config = TokenizerConfig.load(config_path, key="tokenizer") model_config = ModelConfig.load(config_path, key="model") # Initialize tokenizer and validate vocab size. if Path(tokenizer_config.identifier).is_file(): tokenizer = cls.from_file( tokenizer_config.identifier, eos_token_id=model_config.eos_token_id, pad_token_id=model_config.pad_token_id, ) else: tokenizer = cls.from_pretrained( tokenizer_config.identifier, eos_token_id=model_config.eos_token_id, pad_token_id=model_config.pad_token_id, ) if model_config.vocab_size != tokenizer.vocab_size: raise OLMoConfigurationError("vocab size mismatch between config and tokenizer") return tokenizer def add_special_tokens(self, input_ids: List[int]) -> List[int]: """ Add special tokens in-place (if not already present) to the given token IDs. """ if not input_ids or input_ids[-1] != self.eos_token_id: input_ids.append(self.eos_token_id) return input_ids def num_special_tokens_to_add(self, is_pair: bool = False) -> int: return 2 if is_pair else 1 def _truncate( self, input_ids: List[int], truncate_to: Optional[int], direction: TruncationDirection ) -> list[int]: if truncate_to is None or len(input_ids) <= truncate_to: return input_ids elif direction == TruncationDirection.left: return input_ids[len(input_ids) - truncate_to :] else: return input_ids[: -(len(input_ids) - truncate_to)] def encode(self, input: str, add_special_tokens: bool = True) -> List[int]: """ Encode a string into token IDs. """ return self.encode_batch([input], add_special_tokens=add_special_tokens)[0] def encode_batch(self, inputs: List[str], add_special_tokens: bool = True) -> List[List[int]]: """ Encode a batch of strings into token IDs. """ truncate_to = self.truncate_to if truncate_to is not None and add_special_tokens: truncate_to -= self.num_special_tokens_to_add(False) batch_encoding = self.base_tokenizer.encode_batch(inputs) all_input_ids = [] for encoding in batch_encoding: input_ids = self._truncate(encoding.ids, truncate_to, self.truncate_direction) if add_special_tokens: input_ids = self.add_special_tokens(input_ids) all_input_ids.append(input_ids) return all_input_ids def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str: """ Decode a list of token IDs to a string. """ return self.base_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)