File size: 2,437 Bytes
c141a5b e084f01 c141a5b e084f01 c141a5b e084f01 c141a5b e084f01 c141a5b e084f01 c141a5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.tokenization_utils import AddedToken
# Repository and revision that the remote tokenizer code is loaded from.
# Both dynamic-module lookups below share these arguments so that they
# resolve against the same snapshot of the upstream repository.
_codegen_revision = {
    "pretrained_model_name_or_path": "Salesforce/codegen25-7b-multi",
    "revision": "d4dc9dd90e8b23d5411e6d970e3a11e88dc5c2bc",
}

# Upstream tokenizer class; used as the base class of DeciCoderTokenizer.
CodeGen25Tokenizer = get_class_from_dynamic_module(
    "tokenization_codegen25.CodeGen25Tokenizer",
    **_codegen_revision,
)

# Upstream `tiktoken_tokenizer` object from the same module; called by
# DeciCoderTokenizer to build its encoder.
tiktoken_tokenizer = get_class_from_dynamic_module(
    "tokenization_codegen25.tiktoken_tokenizer",
    **_codegen_revision,
)
class DeciCoderTokenizer(CodeGen25Tokenizer):
    """Tokenizer built on ``CodeGen25Tokenizer`` with a handful of fixes.

    On top of the upstream class this subclass:

    * builds the tiktoken encoder itself and remembers the arguments used,
      so the encoder can be rebuilt after unpickling;
    * returns ``None`` from ``_convert_id_to_token`` when the upstream
      lookup raises;
    * hides ``add_special_tokens`` while ``save_pretrained`` runs, because
      it is not JSON serializable.
    """

    def __init__(
        self,
        pad_token=None,
        eos_token="<|endoftext|>",
        add_eos_token=False,
        add_special_tokens=True,
        **kwargs,
    ):
        """Create the tiktoken encoder and initialize the base tokenizer.

        Args:
            pad_token: Padding token (``str`` or ``AddedToken``), or ``None``.
            eos_token: End-of-sequence token (``str`` or ``AddedToken``).
            add_eos_token: Stored on the instance and forwarded to the base
                class.
            add_special_tokens: Forwarded to ``tiktoken_tokenizer`` as
                ``add_special`` and to the base class.
            **kwargs: Passed through unchanged to the base class initializer.
        """
        # Remembered so __setstate__ can rebuild an equivalent encoder.
        self._tiktoken_kwargs = dict(base="gpt2", pad_token=pad_token, add_special=add_special_tokens)
        self.add_eos_token = add_eos_token
        # NOTE(review): these attributes are assigned *before* calling
        # super().__init__(); presumably the base initializer already relies
        # on them -- keep this ordering unless that is confirmed otherwise.
        self.encoder = tiktoken_tokenizer(**self._tiktoken_kwargs)

        # Plain strings are wrapped in AddedToken with stripping disabled;
        # anything else (None or an existing AddedToken) is passed as is.
        pad_token_added = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
        eos_token_added = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token

        super().__init__(
            pad_token=pad_token_added,
            eos_token=eos_token_added,
            add_eos_token=add_eos_token,
            add_special_tokens=add_special_tokens,
            **kwargs,
        )

    def _convert_id_to_token(self, index):
        """Return the token for ``index``, or ``None`` if the lookup fails.

        Works around a bug in ``CodeGen25Tokenizer``, whose implementation
        can raise instead of returning a value.
        """
        try:
            return super()._convert_id_to_token(index)
        except Exception:
            # Deliberately best-effort, but no longer a bare ``except`` so
            # KeyboardInterrupt / SystemExit still propagate.
            return None

    def __getstate__(self):
        """Return the instance state with the encoder removed.

        The encoder is replaced by ``None`` to make the object picklable;
        ``__setstate__`` rebuilds it from ``_tiktoken_kwargs``.
        """
        return {**self.__dict__, "encoder": None}

    def __setstate__(self, state):
        """Restore the instance state and rebuild the tiktoken encoder.

        Args:
            state: Mapping produced by ``__getstate__``; must contain
                ``"_tiktoken_kwargs"``.

        Raises:
            KeyError: If ``state`` has no ``"_tiktoken_kwargs"`` entry.
        """
        # Build a fresh dict so the caller's ``state`` is neither mutated
        # nor aliased as this instance's ``__dict__``.
        self.__dict__ = {**state, "encoder": tiktoken_tokenizer(**state["_tiktoken_kwargs"])}

    def save_pretrained(self, *args, **kwargs):
        """Save the tokenizer, temporarily hiding ``add_special_tokens``.

        add_special_tokens is not JSON serializable, which crashes
        save_pretrained(). Removing it from the tokenizer_config.json does
        not affect from_pretrained().

        Args:
            *args: Forwarded to the base class ``save_pretrained``.
            **kwargs: Forwarded to the base class ``save_pretrained``.

        Returns:
            Whatever the base class ``save_pretrained`` returns.
        """
        original_value = self.add_special_tokens
        self.add_special_tokens = None
        try:
            return super().save_pretrained(*args, **kwargs)
        finally:
            # Restore the attribute even when saving fails, so the tokenizer
            # is not left in a modified state.
            self.add_special_tokens = original_value
|