import json
import os

from transformers import PreTrainedTokenizer


class CognitivessTokenizer(PreTrainedTokenizer):
    def __init__(self, vocab_file, merges_file, **kwargs):
        self.vocab_file = vocab_file
        self.merges_file = merges_file
        # Load the vocabulary before calling the parent constructor:
        # PreTrainedTokenizer.__init__ may invoke get_vocab()/vocab_size,
        # which depend on self.encoder being populated.
        self.encoder = self.load_vocab(vocab_file)
        self.decoder = {v: k for k, v in self.encoder.items()}
        with open(merges_file, encoding="utf-8") as merges_handle:
            bpe_data = merges_handle.read()
        # Skip the "#version: ..." header line and the trailing empty entry.
        bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
        # Assumes a local directory containing vocab.json and merges.txt.
        vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
        merges_file = os.path.join(pretrained_model_name_or_path, "merges.txt")
        return cls(vocab_file, merges_file, **kwargs)

    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        return dict(self.encoder)

    def _tokenize(self, text):
        # Whitespace pre-tokenization only; the loaded BPE merges are not
        # applied here and would require a separate bpe() merging step.
        return text.split()

    def _convert_token_to_id(self, token):
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        return " ".join(tokens)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        # filename_prefix is part of the signature the base class uses when
        # save_pretrained() delegates to this method.
        os.makedirs(save_directory, exist_ok=True)
        prefix = filename_prefix + "-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")
        merges_file = os.path.join(save_directory, prefix + "merges.txt")
        with open(vocab_file, "w", encoding="utf-8") as vocab_handle:
            json.dump(self.encoder, vocab_handle, ensure_ascii=False)
        with open(merges_file, "w", encoding="utf-8") as merges_handle:
            # Write a version header and a trailing newline so the file
            # round-trips through the [1:-1] slice used in __init__.
            merges_handle.write("#version: 0.2\n")
            for pair, _ in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                merges_handle.write(" ".join(pair) + "\n")
        return (vocab_file, merges_file)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        bos_token_id = [self.bos_token_id]
        eos_token_id = [self.eos_token_id]
        output = bos_token_id + token_ids_0 + eos_token_id
        if token_ids_1 is not None:
            # Append the second sequence for sentence pairs instead of
            # silently dropping it.
            output = output + token_ids_1 + eos_token_id
        return output

    def load_vocab(self, vocab_file):
        with open(vocab_file, "r", encoding="utf-8") as f:
            return json.load(f)
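
# Minimal usage sketch. Assumptions (not from the original source): a local
# directory "./cognitivess_tokenizer" containing vocab.json and merges.txt,
# and that the vocabulary defines the special tokens passed below; adjust
# the path and token strings to match your actual checkpoint.
if __name__ == "__main__":
    tokenizer = CognitivessTokenizer.from_pretrained(
        "./cognitivess_tokenizer",  # hypothetical local path
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
    )
    # encode() whitespace-splits the text, maps tokens to ids, and wraps the
    # sequence with BOS/EOS via build_inputs_with_special_tokens.
    ids = tokenizer.encode("hello world")
    print(ids)
    print(tokenizer.decode(ids))
    # Persist vocab.json and merges.txt in a round-trippable format.
    tokenizer.save_vocabulary("./saved_tokenizer")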