tomer-deci committed · verified
Commit c141a5b · 1 parent: 979a600

Update tokenization_decicoder.py


Fixed more CodeGen bugs in DeciCoderTokenizer

Files changed (1):

  1. tokenization_decicoder.py (+31 -6)
tokenization_decicoder.py CHANGED

@@ -1,10 +1,14 @@
-from transformers.models.auto.tokenization_auto import get_class_from_dynamic_module
+from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.tokenization_utils import AddedToken
 
-CodeGen25Tokenizer = get_class_from_dynamic_module("tokenization_codegen25.CodeGen25Tokenizer",
-                                                   "Salesforce/codegen25-7b-multi")
-tiktoken_tokenizer = get_class_from_dynamic_module("tokenization_codegen25.tiktoken_tokenizer",
-                                                   "Salesforce/codegen25-7b-multi")
+_codegen_revision = dict(pretrained_model_name_or_path="Salesforce/codegen25-7b-multi",
+                         revision="d4dc9dd90e8b23d5411e6d970e3a11e88dc5c2bc")
+
+CodeGen25Tokenizer = get_class_from_dynamic_module(
+    "tokenization_codegen25.CodeGen25Tokenizer", **_codegen_revision)
+
+tiktoken_tokenizer = get_class_from_dynamic_module(
+    "tokenization_codegen25.tiktoken_tokenizer", **_codegen_revision)
 
 
 class DeciCoderTokenizer(CodeGen25Tokenizer):
@@ -16,8 +20,9 @@ class DeciCoderTokenizer(CodeGen25Tokenizer):
         add_special_tokens=True,
         **kwargs,
     ):
+        self._tiktoken_kwargs = dict(base="gpt2", pad_token=pad_token, add_special=add_special_tokens)
         self.add_eos_token = add_eos_token
-        self.encoder = tiktoken_tokenizer(base="gpt2", pad_token=pad_token, add_special=add_special_tokens)
+        self.encoder = tiktoken_tokenizer(**self._tiktoken_kwargs)
         pad_token_added = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
         eos_token_added = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
         super().__init__(
@@ -29,7 +34,27 @@ class DeciCoderTokenizer(CodeGen25Tokenizer):
         )
 
     def _convert_id_to_token(self, index):
+        """ bug fix in CodeGen25Tokenizer """
         try:
             return super()._convert_id_to_token(index)
         except:
             return None
+
+    def __getstate__(self):
+        """ make the object picklable """
+        return {**self.__dict__, "encoder": None}
+
+    def __setstate__(self, state):
+        """ initialize tiktoken encoder after unpickling """
+        state["encoder"] = tiktoken_tokenizer(**state["_tiktoken_kwargs"])
+        self.__dict__ = state
+
+    def save_pretrained(self, *args, **kwargs):
+        """
+        add_special_tokens is not JSON serializable, which crashes save_pretrained().
+        Removing it from the tokenizer_config.json does not affect from_pretrained().
+        """
+        add_special_tokens = self.add_special_tokens
+        self.add_special_tokens = None
+        super().save_pretrained(*args, **kwargs)
+        self.add_special_tokens = add_special_tokens
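The __getstate__/__setstate__ pair exists because the tiktoken encoder held in self.encoder is not picklable, which previously broke anything that serializes the tokenizer (multiprocessing data loaders, for example). The commit records the encoder's constructor kwargs in _tiktoken_kwargs, drops the encoder when pickling, and rebuilds it after unpickling. Below is a minimal, self-contained sketch of that pattern; UnpicklableEncoder is a hypothetical stand-in for the real tiktoken object:

import pickle

class UnpicklableEncoder:
    """Stand-in for the tiktoken encoder: refuses to be pickled."""
    def __init__(self, base, pad_token, add_special):
        self.config = (base, pad_token, add_special)

    def __reduce__(self):
        raise TypeError("cannot pickle the encoder")

class PicklableTokenizer:
    """Same pattern as DeciCoderTokenizer above."""
    def __init__(self, pad_token="<pad>", add_special_tokens=True):
        # Remember how the encoder was built so it can be rebuilt later.
        self._tiktoken_kwargs = dict(base="gpt2", pad_token=pad_token,
                                     add_special=add_special_tokens)
        self.encoder = UnpicklableEncoder(**self._tiktoken_kwargs)

    def __getstate__(self):
        # Exclude the unpicklable encoder from the pickled state.
        return {**self.__dict__, "encoder": None}

    def __setstate__(self, state):
        # Recreate the encoder from the remembered kwargs.
        state["encoder"] = UnpicklableEncoder(**state["_tiktoken_kwargs"])
        self.__dict__ = state

tok = pickle.loads(pickle.dumps(PicklableTokenizer()))
assert tok.encoder.config == ("gpt2", "<pad>", True)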
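Two other fixes round out the commit: the first hunk pins the Salesforce/codegen25-7b-multi dynamic module to a fixed revision, so upstream edits to that repo can no longer change this tokenizer's behavior, and the save_pretrained() override works around the add_special_tokens attribute that is not JSON serializable. A hedged usage sketch of the resulting round trips; the repo id is assumed for illustration (the page header truncates the model name), and trust_remote_code=True is needed because the class ships inside the repo:

import pickle
import tempfile

from transformers import AutoTokenizer

repo_id = "Deci/DeciCoder-1b"  # assumed repo id for illustration
tok = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)

# save_pretrained() previously crashed while serializing add_special_tokens;
# the override above nulls the attribute around the call.
with tempfile.TemporaryDirectory() as tmp:
    tok.save_pretrained(tmp)
    reloaded = AutoTokenizer.from_pretrained(tmp, trust_remote_code=True)

# Pickling round-trips thanks to __getstate__/__setstate__.
clone = pickle.loads(pickle.dumps(tok))
assert clone("def foo():")["input_ids"] == tok("def foo():")["input_ids"]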