|
from __future__ import annotations |
|
|
|
import os |
|
from pathlib import Path |
|
from typing import List, Optional, Union |
|
|
|
from tokenizers import Tokenizer as BaseTokenizer |
|
|
|
from .aliases import PathOrStr |
|
from .config import ModelConfig, TokenizerConfig, TrainConfig, TruncationDirection |
|
from .exceptions import OLMoConfigurationError |
|
|
|
# Public names exported by ``from <this module> import *``.
__all__ = ["Tokenizer"]
|
|
|
|
|
class Tokenizer:
    """
    A :class:`Tokenizer` is a light-weight wrapper around a HuggingFace :class:`tokenizers.Tokenizer`.

    :param base_tokenizer: The :class:`tokenizers.Tokenizer` to use. Its own truncation is
        disabled on construction; truncation is handled by this wrapper instead.
    :param eos_token_id: The token ID corresponding to the "end-of-sentence" token.
    :param pad_token_id: The token ID used for padding. Defaults to ``eos_token_id``.
    :param truncate_to: Truncate when tokenizing to this number of token IDs. When special
        tokens are added during encoding, room for them is reserved within this budget.
    :param truncate_direction: The direction to truncate in. "right" means truncate the tokens
        on the right. "left" means truncate the tokens on the left. If ``truncate_to`` is null,
        this setting has no effect.
    """

    def __init__(
        self,
        base_tokenizer: BaseTokenizer,
        eos_token_id: int,
        pad_token_id: Optional[int] = None,
        truncate_to: Optional[int] = None,
        truncate_direction: Union[str, TruncationDirection] = TruncationDirection.right,
    ):
        self.base_tokenizer = base_tokenizer
        # Truncation is applied by this wrapper (see ``encode_batch``), so make sure the
        # base tokenizer does not also truncate on its own.
        self.base_tokenizer.no_truncation()
        self.eos_token_id = eos_token_id
        # Fall back to the EOS token for padding when no dedicated pad token is given.
        self.pad_token_id = pad_token_id if pad_token_id is not None else eos_token_id
        self.truncate_to = truncate_to
        # Normalize a plain string value to the enum.
        self.truncate_direction = TruncationDirection(truncate_direction)

    @property
    def vocab_size(self) -> int:
        """The vocabulary size reported by the base tokenizer."""
        return self.base_tokenizer.get_vocab_size()

    @property
    def eos_token(self) -> str:
        """The text form of the EOS token (decoded without skipping special tokens)."""
        return self.decode([self.eos_token_id], skip_special_tokens=False)

    @property
    def pad_token(self) -> str:
        """The text form of the pad token (decoded without skipping special tokens)."""
        return self.decode([self.pad_token_id], skip_special_tokens=False)

    @classmethod
    def from_train_config(cls, config: TrainConfig) -> Tokenizer:
        """
        Initialize a tokenizer from a training config.

        :param config: ``config.tokenizer.identifier`` selects the tokenizer (a local file or a
            HuggingFace Hub identifier); ``config.model`` supplies the EOS/pad token IDs and the
            expected vocab size.

        :raises OLMoConfigurationError: If ``config.model.vocab_size`` does not match the
            tokenizer's vocab size.
        """
        return cls._from_identifier(config.tokenizer.identifier, config.model)

    @classmethod
    def from_pretrained(cls, identifier: str, **kwargs) -> Tokenizer:
        """
        Initialize a tokenizer from a pretrained tokenizer on the HuggingFace Hub.

        :param identifier: The identifier of a model on the Hub that contains a
            ``tokenizer.json`` file.
        :param kwargs: Other keyword arguments passed to :class:`Tokenizer`. If ``eos_token_id``
            is omitted or ``None``, the last ID in the vocabulary is used.
        """
        base_tokenizer = BaseTokenizer.from_pretrained(identifier)
        eos_token_id = cls._pop_eos_token_id(base_tokenizer, kwargs)
        return cls(base_tokenizer, eos_token_id, **kwargs)

    @classmethod
    def from_file(cls, filename: PathOrStr, **kwargs) -> Tokenizer:
        """
        Initialize a tokenizer from a file.

        You can create those files with ``BaseTokenizer.save()``.

        :param filename: The name of a file containing a tokenizer specification. May be a
            ``str`` or a path-like object.
        :param kwargs: Other keyword arguments passed to :class:`Tokenizer`. If ``eos_token_id``
            is omitted or ``None``, the last ID in the vocabulary is used.
        """
        # The ``tokenizers`` binding expects a ``str`` path, so convert path-like objects
        # (e.g. ``pathlib.Path``) before handing them over.
        base_tokenizer = BaseTokenizer.from_file(os.fspath(filename))
        eos_token_id = cls._pop_eos_token_id(base_tokenizer, kwargs)
        return cls(base_tokenizer, eos_token_id, **kwargs)

    @classmethod
    def from_checkpoint(cls, checkpoint_dir: PathOrStr) -> Tokenizer:
        """
        Load a tokenizer from a checkpoint.

        Reads ``config.yaml`` inside ``checkpoint_dir`` (resolved through ``cached_path``) and
        builds the tokenizer from its ``tokenizer`` and ``model`` sections.

        :raises OLMoConfigurationError: If the model config's vocab size does not match the
            tokenizer's vocab size.
        """
        from cached_path import cached_path

        config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml"))
        tokenizer_config = TokenizerConfig.load(config_path, key="tokenizer")
        model_config = ModelConfig.load(config_path, key="model")
        return cls._from_identifier(tokenizer_config.identifier, model_config)

    @classmethod
    def _from_identifier(cls, identifier: str, model_config: ModelConfig) -> Tokenizer:
        """
        Build a tokenizer from a local tokenizer file or a HuggingFace Hub identifier, taking the
        EOS/pad token IDs from ``model_config``, and check that the vocab sizes agree.

        :raises OLMoConfigurationError: If ``model_config.vocab_size`` differs from the
            tokenizer's vocab size.
        """
        # An existing local file wins; anything else is treated as a Hub identifier.
        loader = cls.from_file if Path(identifier).is_file() else cls.from_pretrained
        tokenizer = loader(
            identifier,
            eos_token_id=model_config.eos_token_id,
            pad_token_id=model_config.pad_token_id,
        )
        if model_config.vocab_size != tokenizer.vocab_size:
            raise OLMoConfigurationError("vocab size mismatch between config and tokenizer")
        return tokenizer

    @staticmethod
    def _pop_eos_token_id(base_tokenizer: BaseTokenizer, kwargs: dict) -> int:
        """
        Remove ``eos_token_id`` from ``kwargs`` and return it, defaulting to the last ID in
        ``base_tokenizer``'s vocabulary when it is missing or ``None``.
        """
        eos_token_id = kwargs.pop("eos_token_id", None)
        if eos_token_id is None:
            eos_token_id = base_tokenizer.get_vocab_size() - 1
        return eos_token_id

    def add_special_tokens(self, input_ids: List[int]) -> List[int]:
        """
        Add special tokens in-place (if not already present) to the given token IDs.

        Currently this appends the EOS token unless the sequence already ends with it.
        Returns the same (mutated) list.
        """
        if not input_ids or input_ids[-1] != self.eos_token_id:
            input_ids.append(self.eos_token_id)
        return input_ids

    def num_special_tokens_to_add(self, is_pair: bool = False) -> int:
        """The number of special tokens added for a single sequence (1) or a pair (2)."""
        return 2 if is_pair else 1

    def _truncate(
        self, input_ids: List[int], truncate_to: Optional[int], direction: TruncationDirection
    ) -> List[int]:
        """
        Keep at most ``truncate_to`` token IDs from ``input_ids``.

        Right truncation drops tokens from the end; left truncation drops them from the start.
        ``None`` disables truncation, and a negative budget keeps nothing. ``input_ids`` itself
        is returned when nothing needs to be removed.
        """
        if truncate_to is None or len(input_ids) <= truncate_to:
            return input_ids
        # The budget can go negative when ``truncate_to`` is smaller than the number of
        # reserved special tokens; clamp so slicing below never wraps around.
        keep = max(truncate_to, 0)
        if direction == TruncationDirection.left:
            return input_ids[len(input_ids) - keep :]
        return input_ids[:keep]

    def encode(self, input: str, add_special_tokens: bool = True) -> List[int]:
        """
        Encode a string into token IDs.
        """
        return self.encode_batch([input], add_special_tokens=add_special_tokens)[0]

    def encode_batch(self, inputs: List[str], add_special_tokens: bool = True) -> List[List[int]]:
        """
        Encode a batch of strings into token IDs.

        Each sequence is truncated according to ``truncate_to`` / ``truncate_direction``. When
        ``add_special_tokens`` is true, room for the special token is reserved before truncating
        and the EOS token is then appended (unless the sequence already ends with it).
        """
        truncate_to = self.truncate_to
        if truncate_to is not None and add_special_tokens:
            # Reserve room so the appended special token stays within the budget.
            truncate_to -= self.num_special_tokens_to_add(False)

        batch_encoding = self.base_tokenizer.encode_batch(inputs)

        all_input_ids = []
        for encoding in batch_encoding:
            input_ids = self._truncate(encoding.ids, truncate_to, self.truncate_direction)
            if add_special_tokens:
                input_ids = self.add_special_tokens(input_ids)
            all_input_ids.append(input_ids)

        return all_input_ids

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """
        Decode a list of token IDs to a string.
        """
        return self.base_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
|