jon-tow commited on
Commit
589adbf
·
verified ·
1 Parent(s): 4ae0672

fix: make `eos_token`/`pad_token` overridable and add `pickle` support

Browse files
Files changed (1) hide show
  1. tokenization_arcade100k.py +17 -3
tokenization_arcade100k.py CHANGED
@@ -124,8 +124,12 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
124
 
125
  self.decoder = {i: n for n, i in self.tokenizer._mergeable_ranks.items()}
126
  self.decoder.update({i: n for n, i in self.tokenizer._special_tokens.items()})
127
- self.eos_token = self.decoder[self.tokenizer.eot_token]
128
- self.pad_token = self.decoder[self.tokenizer.eot_token]
 
 
 
 
129
  # Expose for convenience
130
  self.mergeable_ranks = self.tokenizer._mergeable_ranks
131
  self.special_tokens = self.tokenizer._special_tokens
@@ -133,6 +137,16 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
133
  def __len__(self):
134
  return self.tokenizer.n_vocab
135
 
 
 
 
 
 
 
 
 
 
 
136
  @property
137
  def vocab_size(self):
138
  return self.tokenizer.n_vocab
@@ -273,4 +287,4 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
273
  token_ids = [token_ids]
274
  if skip_special_tokens:
275
  token_ids = [i for i in token_ids if i < self.tokenizer.eot_token]
276
- return self.tokenizer.decode(token_ids)
 
124
 
125
  self.decoder = {i: n for n, i in self.tokenizer._mergeable_ranks.items()}
126
  self.decoder.update({i: n for n, i in self.tokenizer._special_tokens.items()})
127
+ # Provide default `eos_token` and `pad_token`
128
+ if self.eos_token is None:
129
+ self.eos_token = self.decoder[self.tokenizer.eot_token]
130
+ if self.pad_token is None:
131
+ self.pad_token = self.decoder[self.tokenizer.pad_token]
132
+
133
  # Expose for convenience
134
  self.mergeable_ranks = self.tokenizer._mergeable_ranks
135
  self.special_tokens = self.tokenizer._special_tokens
 
137
  def __len__(self):
138
  return self.tokenizer.n_vocab
139
 
140
+ def __getstate__(self):
141
+ # Required for `pickle` support
142
+ state = self.__dict__.copy()
143
+ del state["tokenizer"]
144
+ return state
145
+
146
+ def __setstate__(self, state):
147
+ self.__dict__.update(state)
148
+ self.tokenizer = tiktoken.Encoding(**self._tiktoken_config)
149
+
150
  @property
151
  def vocab_size(self):
152
  return self.tokenizer.n_vocab
 
287
  token_ids = [token_ids]
288
  if skip_special_tokens:
289
  token_ids = [i for i in token_ids if i < self.tokenizer.eot_token]
290
+ return self.tokenizer.decode(token_ids)