import json
import os
from typing import List, Optional, Tuple

from transformers import PreTrainedTokenizer


class SPTTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer with a small built-in default vocabulary."""

    def __init__(self, vocab_file=None, pad_token="#", eos_token="#", unk_token="[UNK]", **kwargs):
        # Load the vocab *before* calling super().__init__: PreTrainedTokenizer
        # needs a working get_vocab()/vocab_size during initialization.
        self.vocab = self.load_vocab(vocab_file)
        # Neither "#" (shared pad/EOS token) nor "[UNK]" is in the default
        # character vocab, so make sure every special token gets an id.
        for token in (pad_token, eos_token, unk_token):
            if token not in self.vocab:
                self.vocab[token] = len(self.vocab)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        super().__init__(pad_token=pad_token, eos_token=eos_token, unk_token=unk_token, **kwargs)

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab)

    def _tokenize(self, text):
        # Character-level tokenization: every character is its own token.
        return list(text)

    def _convert_token_to_id(self, token):
        return self.vocab.get(token, self.vocab[self.unk_token])

    def _convert_id_to_token(self, index):
        return self.inv_vocab.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # Append EOS after each sequence: X <eos> or X <eos> Y <eos>.
        if token_ids_1 is None:
            return token_ids_0 + [self.eos_token_id]
        return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            # Delegate to the base class so both sequences are handled.
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        if token_ids_1 is None:
            return [0] * len(token_ids_0) + [1]
        return [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # The "+ 1" accounts for the EOS token appended to each sequence.
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 1)
        return [0] * (len(token_ids_0) + 1) + [1] * (len(token_ids_1) + 1)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False)
        return (vocab_file,)

    def load_vocab(self, vocab_file):
        # Fall back to the built-in character vocabulary when no file is given.
        if vocab_file is None:
            return {
                '\n': 0, ' ': 1, '!': 2, '"': 3, '&': 4, "'": 5, '(': 6, ')': 7,
                '*': 8, ',': 9, '-': 10, '.': 11, '0': 12, '1': 13, '2': 14,
                '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21,
                ':': 22, ';': 23, '?': 24, 'A': 25, 'B': 26, 'C': 27, 'D': 28,
                'E': 29, 'F': 30, 'G': 31, 'H': 32, 'I': 33, 'J': 34, 'K': 35,
                'L': 36, 'M': 37, 'N': 38, 'O': 39, 'P': 40, 'Q': 41, 'R': 42,
                'S': 43, 'T': 44, 'U': 45, 'V': 46, 'W': 47, 'X': 48, 'Y': 49,
                'Z': 50, '[': 51, ']': 52, '`': 53, 'a': 54, 'b': 55, 'c': 56,
                'd': 57, 'e': 58, 'f': 59, 'g': 60, 'h': 61, 'i': 62, 'j': 63,
                'k': 64, 'l': 65, 'm': 66, 'n': 67, 'o': 68, 'p': 69, 'q': 70,
                'r': 71, 's': 72, 't': 73, 'u': 74, 'v': 75, 'w': 76, 'x': 77,
                'y': 78, 'z': 79, '£': 80, '°': 81, 'ß': 82, 'à': 83, 'â': 84,
                'è': 85, 'é': 86, 'ê': 87, 'î': 88, 'ñ': 89, 'ô': 90, 'ö': 91,
                'û': 92, 'ü': 93,
            }
        with open(vocab_file, "r", encoding="utf-8") as f:
            return json.load(f)
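
# --- Minimal usage sketch (illustrative addition, not part of the original
# module). Assumes the default built-in character vocab; encode() appends the
# "#" EOS token via build_inputs_with_special_tokens.
if __name__ == "__main__":
    tokenizer = SPTTokenizer()
    ids = tokenizer.encode("Hello, world!")
    print(ids)  # per-character ids, with the EOS id appended
    print(tokenizer.decode(ids, skip_special_tokens=True))  # -> "Hello, world!"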