mikeleske committed on
Commit 2512b3e
1 Parent(s): ba641a4

Create tokenizer.py

Files changed (1):
  tokenizer.py +129 -0
tokenizer.py ADDED
# based on https://github.com/EleutherAI/gpt-neox/blob/main/megatron/tokenizer/tokenizer.py
from __future__ import annotations

import numpy as np

from os import PathLike
from typing import Dict, List, Tuple

from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
from transformers.utils.generic import PaddingStrategy


EMPTY: str = ""


class ByteTokenizer(PreTrainedTokenizer):

    """UTF-8 encoder: token ids are raw UTF-8 byte values."""

    @classmethod
    def from_pretrained(cls, model_id: str | PathLike, **kwargs) -> ByteTokenizer:
        # The byte vocabulary is fixed, so there are no files to load;
        # model_id is accepted for API compatibility and ignored.
        return cls(**kwargs, byte_level=True)

    @property
    def vocab_size(self) -> int:
        # All 256 byte values, with headroom above 255 for extra ids.
        return 512

    @property
    def byte_level(self) -> bool:
        return self.init_kwargs.get('byte_level', True)

    def get_vocab(self) -> Dict[str, int]:
        # Identity mapping: each token is the character whose code point is its id.
        return {chr(i): i for i in range(self.vocab_size)}

    def __len__(self) -> int:
        return self.vocab_size

    def clamp(self, n: int) -> int:
        # Clamp to [32, vocab_size - 1]: ids below 32 are ASCII control codes,
        # and vocab_size - 1 is the largest valid id.
        return max(32, min(n, self.vocab_size - 1))

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        # Character-level split; byte-level encoding happens in byte_tokenize.
        return list(text)

    def byte_tokenize(self, text: str) -> np.ndarray:
        # Encode to UTF-8 and expose the raw bytes as uint8 token ids.
        return np.frombuffer(text.encode('utf-8'), dtype=np.uint8)

    def _convert_token_to_id(self, token: str) -> int:
        return self.clamp(ord(token))

    def _convert_id_to_token(self, index: int) -> str:
        return chr(self.clamp(index))

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return EMPTY.join(tokens)

    def _decode(self, token_ids: List[int], **kwargs) -> str:
        # Clamp ids into the byte range and reinterpret them as a UTF-8 string;
        # invalid byte sequences are substituted rather than raising.
        indices = np.asarray(token_ids, dtype=np.uint8)
        return (
            indices.clip(min=32, max=255, out=indices)
            .tobytes()
            .decode('utf-8', errors='replace')
        )

    def _encode_plus(self, text: str, **kwargs) -> BatchEncoding:
        first_ids = self.byte_tokenize(text).tolist()
        return self.prepare_for_model(
            first_ids,
            pair_ids=None,
            add_special_tokens=kwargs.get('add_special_tokens', False),
            padding=kwargs.get('padding_strategy', PaddingStrategy.DO_NOT_PAD).value,
            truncation=kwargs.get('truncation_strategy', TruncationStrategy.DO_NOT_TRUNCATE).value,
            max_length=kwargs.get('max_length'),
            stride=kwargs.get('stride', 0),
            pad_to_multiple_of=kwargs.get('pad_to_multiple_of'),
            return_tensors=kwargs.get('return_tensors'),
            prepend_batch_axis=True,
            return_attention_mask=kwargs.get('return_attention_mask'),
            return_token_type_ids=kwargs.get('return_token_type_ids'),
            return_overflowing_tokens=kwargs.get('return_overflowing_tokens', False),
            return_special_tokens_mask=kwargs.get('return_special_tokens_mask', False),
            return_length=kwargs.get('return_length', False),
            verbose=kwargs.get('verbose', True),
        )

    def _batch_encode_plus(self, batch_text_or_text_pairs: List[str], **kwargs) -> BatchEncoding:
        input_ids = [(self.byte_tokenize(text).tolist(), None) for text in batch_text_or_text_pairs]
        return self._batch_prepare_for_model(
            input_ids,
            add_special_tokens=kwargs.get('add_special_tokens', False),
            padding_strategy=kwargs.get('padding_strategy', PaddingStrategy.DO_NOT_PAD),
            truncation_strategy=kwargs.get('truncation_strategy', TruncationStrategy.DO_NOT_TRUNCATE),
            max_length=kwargs.get('max_length'),
            stride=kwargs.get('stride', 0),
            pad_to_multiple_of=kwargs.get('pad_to_multiple_of'),
            return_attention_mask=kwargs.get('return_attention_mask'),
            return_token_type_ids=kwargs.get('return_token_type_ids'),
            return_overflowing_tokens=kwargs.get('return_overflowing_tokens', False),
            return_special_tokens_mask=kwargs.get('return_special_tokens_mask', False),
            return_length=kwargs.get('return_length', False),
            return_tensors=kwargs.get('return_tensors'),
            verbose=kwargs.get('verbose', True),
        )

    def _save_pretrained(
        self, save_directory: str | PathLike, file_names: Tuple[str], **kwargs
    ) -> Tuple[str]:
        return file_names
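
For context, a minimal round trip with this tokenizer might look as follows. This sketch is illustrative and not part of the commit; the model id is a placeholder, since from_pretrained ignores it and always builds the fixed byte vocabulary.

# Usage sketch (illustrative, not part of the commit); the model id below
# is a placeholder, since from_pretrained ignores it.
from tokenizer import ByteTokenizer

tokenizer = ByteTokenizer.from_pretrained('placeholder-model-id')

enc = tokenizer('hello world', return_tensors='np')
print(enc['input_ids'])
# expected: [[104 101 108 108 111  32 119 111 114 108 100]]  (raw UTF-8 bytes)

print(tokenizer.decode(enc['input_ids'][0]))
# expected: hello world

Because every id is just a byte value, encoding needs no vocabulary lookup and decoding is a byte-string reinterpretation, which is why _save_pretrained has nothing to write to disk.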