File size: 9,474 Bytes
c778cd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 |
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RWKV6."""
import os
import re
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from transformers.utils import logging
if TYPE_CHECKING:
pass
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
"vocab_file": "rwkv_vocab_v20230424.txt",
}
class TRIE:
__slots__ = tuple("ch,to,values,front".split(","))
to: list
values: set
def __init__(self, front=None, ch=None):
self.ch = ch
self.to = [None for ch in range(256)]
self.values = set()
self.front = front
def __repr__(self):
fr = self
ret = []
while fr != None:
if fr.ch != None:
ret.append(fr.ch)
fr = fr.front
return "<TRIE %s %s>" % (ret[::-1], self.values)
def add(self, key: bytes, idx: int = 0, val=None):
if idx == len(key):
if val is None:
val = key
self.values.add(val)
return self
ch = key[idx]
if self.to[ch] is None:
self.to[ch] = TRIE(front=self, ch=ch)
return self.to[ch].add(key, idx=idx + 1, val=val)
def find_longest(self, key: bytes, idx: int = 0):
u: TRIE = self
ch: int = key[idx]
while u.to[ch] is not None:
u = u.to[ch]
idx += 1
if u.values:
ret = idx, u, u.values
if idx == len(key):
break
ch = key[idx]
return ret
class RWKV_TOKENIZER:
def __init__(self, file_name):
self.idx2token = {}
sorted = [] # must be already sorted
with open(file_name, "r", encoding="utf-8") as f:
lines = f.readlines()
for l in lines:
idx = int(l[: l.index(" ")])
x = eval(l[l.index(" ") : l.rindex(" ")])
x = x.encode("utf-8") if isinstance(x, str) else x
assert isinstance(x, bytes)
assert len(x) == int(l[l.rindex(" ") :])
sorted += [x]
self.idx2token[idx] = x
self.token2idx = {}
for k, v in self.idx2token.items():
self.token2idx[v] = int(k)
self.root = TRIE()
for t, i in self.token2idx.items():
_ = self.root.add(t, val=(t, i))
def encodeBytes(self, src: bytes):
idx: int = 0
tokens = []
while idx < len(src):
_idx: int = idx
idx, _, values = self.root.find_longest(src, idx)
assert idx != _idx
_, token = next(iter(values))
tokens.append(token)
return tokens
def decodeBytes(self, tokens):
return b"".join(map(lambda i: self.idx2token[i], tokens))
def encode(self, src):
if isinstance(src, str):
return [self.encodeBytes(src.encode("utf-8"))]
elif isinstance(src, list):
return [self.encodeBytes(s.encode("utf-8")) for s in src]
def decode(self, tokens):
return [self.decodeBytes(batch).decode("utf-8") for batch in tokens]
# try:
# return self.decodeBytes(tokens).decode('utf-8')
# except:
# return '\ufffd' # bad utf-8
def printTokens(self, tokens):
for i in tokens:
s = self.idx2token[i]
try:
s = s.decode("utf-8")
except:
pass
print(f"{repr(s)}{i}", end=" ")
print()
class Rwkv6Tokenizer(PreTrainedTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self, vocab_file, bos_token="<s>", eos_token="<s>", unk_token="<s>", **kwargs
):
if not os.path.isfile(vocab_file):
raise ValueError(
f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
" model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines()
if "add_bos_token" in kwargs:
self.add_bos_token = kwargs["add_bos_token"]
else:
self.add_bos_token = False
self.trie_tokenizer = RWKV_TOKENIZER(vocab_file)
vocab = self.trie_tokenizer.token2idx
self.encoder = vocab
self.decoder = {v: k for k, v in vocab.items()}
self._added_tokens_decoder = {0: AddedToken(str(bos_token))}
super().__init__(
bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs
)
@property
def vocab_size(self):
return len(self.encoder)
def get_vocab(self):
vocab = {str(self.convert_ids_to_tokens(i)): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text, split_special_tokens=False):
# return self.wordpiece_tokenizer.tokenize(text.encode("utf-8"))
return self.trie_tokenizer.encode(text)[0]
def _convert_token_to_id(self, token):
return token
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (byte) using the vocab."""
token = self.decoder.get(index, self.unk_token)
if isinstance(token, (bytes)):
token = token.decode("utf-8", errors="replace")
return token
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (bytes) in a single string. Additional tokens are encoded to bytes"""
out_string = b"".join(
[k.encode(errors="replace") if isinstance(k, str) else k for k in tokens]
).decode("utf-8")
return out_string
def save_vocabulary(
self, save_directory: str, filename_prefix: Optional[str] = None
) -> Tuple[str]:
index = 0
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
)
else:
vocab_file = (
filename_prefix + "-" if filename_prefix else ""
) + save_directory
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(
self.encoder.items(), key=lambda kv: kv[1]
):
if index != token_index:
logger.warning(
f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!"
)
index = token_index
writer.write(str(token) + "\n")
index += 1
return (vocab_file,)
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
if self.add_bos_token:
bos_token_ids = [self.bos_token_id]
else:
bos_token_ids = []
output = bos_token_ids + token_ids_0
if token_ids_1 is None:
return output
return output + bos_token_ids + token_ids_1
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False,
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=True,
)
if not self.add_bos_token:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=False,
)
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0))
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
|