from __future__ import annotations

import os
from pathlib import Path
from typing import List, Optional, Union

from tokenizers import Tokenizer as BaseTokenizer

from .aliases import PathOrStr
from .config import ModelConfig, TokenizerConfig, TrainConfig, TruncationDirection
from .exceptions import OLMoConfigurationError

__all__ = ["Tokenizer"]


class Tokenizer:
"""
A :class:`Tokenizer` is a light-weight wrapper around a HuggingFace :class:`tokenizers.Tokenizer`.
:param base_tokenizer: The :class:`tokenizers.Tokenizer` to use.
:param eos_token_id: The token ID corresponding to the "end-of-sentence" token.
:param truncate_to: Truncate when tokenizing to this number of token IDs.
:param truncate_direction: The direction to truncate in. "right" means truncate the tokens
on the right. "left" means truncate the tokens on the left. If ``truncate_to`` is null,
this setting has no effect.
"""
def __init__(
self,
base_tokenizer: BaseTokenizer,
eos_token_id: int,
pad_token_id: Optional[int] = None,
truncate_to: Optional[int] = None,
truncate_direction: Union[str, TruncationDirection] = TruncationDirection.right,
):
self.base_tokenizer = base_tokenizer
self.base_tokenizer.no_truncation()
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id if pad_token_id is not None else eos_token_id
self.truncate_to = truncate_to
self.truncate_direction = TruncationDirection(truncate_direction)
@property
def vocab_size(self) -> int:
return self.base_tokenizer.get_vocab_size()
@property
def eos_token(self) -> str:
return self.decode([self.eos_token_id], skip_special_tokens=False)
@property
def pad_token(self) -> str:
return self.decode([self.pad_token_id], skip_special_tokens=False)
@classmethod
def from_train_config(cls, config: TrainConfig) -> Tokenizer:
tokenizer_identifier = config.tokenizer.identifier
if Path(tokenizer_identifier).is_file():
tokenizer = cls.from_file(
tokenizer_identifier,
eos_token_id=config.model.eos_token_id,
pad_token_id=config.model.pad_token_id,
)
else:
tokenizer = cls.from_pretrained(
tokenizer_identifier,
eos_token_id=config.model.eos_token_id,
pad_token_id=config.model.pad_token_id,
)
if config.model.vocab_size != tokenizer.vocab_size:
raise OLMoConfigurationError("vocab size mismatch between config and tokenizer")
return tokenizer
    @classmethod
    def from_pretrained(cls, identifier: str, **kwargs) -> Tokenizer:
        """
        Initialize a tokenizer from a pretrained tokenizer on the HuggingFace Hub.

        :param identifier: The identifier of a model on the Hub that contains a
            ``tokenizer.json`` file.
        :param kwargs: Other keyword arguments passed to :class:`Tokenizer`.
        """
        base_tokenizer = BaseTokenizer.from_pretrained(identifier)
        # If no EOS id is given, fall back to the last id in the vocabulary.
        eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1)
        return cls(base_tokenizer, eos_token_id, **kwargs)

    @classmethod
    def from_file(cls, filename: PathOrStr, **kwargs) -> Tokenizer:
        """
        Initialize a tokenizer from a file.

        You can create those files with ``BaseTokenizer.save()``.

        :param filename: The name of a file containing a tokenizer specification.
        :param kwargs: Other keyword arguments passed to :class:`Tokenizer`.
        """
        base_tokenizer = BaseTokenizer.from_file(filename)
        # If no EOS id is given, fall back to the last id in the vocabulary.
        eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1)
        return cls(base_tokenizer, eos_token_id, **kwargs)

    @classmethod
    def from_checkpoint(cls, checkpoint_dir: PathOrStr) -> Tokenizer:
        """
        Load a tokenizer from a checkpoint.
        """
        from cached_path import cached_path

        # Load configs.
        config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml"))
        tokenizer_config = TokenizerConfig.load(config_path, key="tokenizer")
        model_config = ModelConfig.load(config_path, key="model")

        # Initialize tokenizer and validate vocab size.
        if Path(tokenizer_config.identifier).is_file():
            tokenizer = cls.from_file(
                tokenizer_config.identifier,
                eos_token_id=model_config.eos_token_id,
                pad_token_id=model_config.pad_token_id,
            )
        else:
            tokenizer = cls.from_pretrained(
                tokenizer_config.identifier,
                eos_token_id=model_config.eos_token_id,
                pad_token_id=model_config.pad_token_id,
            )
        if model_config.vocab_size != tokenizer.vocab_size:
            raise OLMoConfigurationError("vocab size mismatch between config and tokenizer")
        return tokenizer
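    # Illustrative only: the identifier, file name, and directory below are placeholders
    # showing how the three factory methods are meant to be called, not defaults of this
    # module (GPT-2's EOS id of 50256 is used as an example value).
    #
    #     tok = Tokenizer.from_pretrained("gpt2", eos_token_id=50256)         # tokenizer.json on the Hub
    #     tok = Tokenizer.from_file("my-tokenizer.json", eos_token_id=50256)  # saved with BaseTokenizer.save()
    #     tok = Tokenizer.from_checkpoint("runs/my-run")                      # directory containing config.yaml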
    def add_special_tokens(self, input_ids: List[int]) -> List[int]:
        """
        Add special tokens in-place (if not already present) to the given token IDs.
        """
        if not input_ids or input_ids[-1] != self.eos_token_id:
            input_ids.append(self.eos_token_id)
        return input_ids

    def num_special_tokens_to_add(self, is_pair: bool = False) -> int:
        return 2 if is_pair else 1

    def _truncate(
        self, input_ids: List[int], truncate_to: Optional[int], direction: TruncationDirection
    ) -> List[int]:
        if truncate_to is None or len(input_ids) <= truncate_to:
            return input_ids
        elif direction == TruncationDirection.left:
            # Keep the last `truncate_to` ids.
            return input_ids[len(input_ids) - truncate_to :]
        else:
            # Keep the first `truncate_to` ids (equivalent to `input_ids[:truncate_to]`).
            return input_ids[: -(len(input_ids) - truncate_to)]

    def encode(self, input: str, add_special_tokens: bool = True) -> List[int]:
        """
        Encode a string into token IDs.
        """
        return self.encode_batch([input], add_special_tokens=add_special_tokens)[0]

    def encode_batch(self, inputs: List[str], add_special_tokens: bool = True) -> List[List[int]]:
        """
        Encode a batch of strings into token IDs.
        """
        truncate_to = self.truncate_to
        if truncate_to is not None and add_special_tokens:
            # Reserve room for the EOS token appended below.
            truncate_to -= self.num_special_tokens_to_add(False)

        batch_encoding = self.base_tokenizer.encode_batch(inputs)

        all_input_ids = []
        for encoding in batch_encoding:
            input_ids = self._truncate(encoding.ids, truncate_to, self.truncate_direction)
            if add_special_tokens:
                input_ids = self.add_special_tokens(input_ids)
            all_input_ids.append(input_ids)

        return all_input_ids

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """
        Decode a list of token IDs to a string.
        """
        return self.base_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
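

# A minimal usage sketch, not part of the library API. It assumes the "gpt2"
# tokenizer.json on the HuggingFace Hub (network access required) and GPT-2's
# EOS id of 50256; both are illustrative choices, not defaults of this module.
if __name__ == "__main__":
    tok = Tokenizer.from_pretrained("gpt2", eos_token_id=50256, truncate_to=8)
    ids = tok.encode("This sentence is longer than eight tokens, so it will be truncated.")
    # With truncate_to=8 and add_special_tokens=True, one slot is reserved for EOS,
    # so `ids` holds at most 7 content ids followed by the EOS id.
    print(ids)
    print(tok.decode(ids))  # special tokens are skipped by default when decoding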