# !pip install sentencepiece==0.1.96 transformers==4.10.0
import os
import shutil
from collections import Counter
from typing import List, Optional, Tuple

import sentencepiece as spm
from transformers import PreTrainedTokenizer

# SentencePiece prefixes word-initial pieces with U+2581 ("lower one eighth block").
SPIECE_UNDERLINE = "\u2581"


class RobertaTokenizer(PreTrainedTokenizer):
    """RoBERTa tokenizer: SentencePiece BPE for splitting, fairseq ``dict.txt`` for ids.

    Text is split into pieces with ``sentencepiece.bpe.model``; every piece is
    then mapped to an id through the ``dict.txt`` vocabulary loaded into a
    :class:`Dictionary`, so ids follow that file's ordering rather than the
    SentencePiece model's own piece ids.

    Args:
        pretrained_file: Directory containing ``sentencepiece.bpe.model`` and
            ``dict.txt``.
        bos_token, eos_token, sep_token, cls_token, unk_token, pad_token,
        mask_token: Special-token strings forwarded to ``PreTrainedTokenizer``.
        **kwargs: Forwarded to ``PreTrainedTokenizer``.
    """

    def __init__(
        self,
        pretrained_file,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        **kwargs
    ):
        # Load the BPE model and the vocab file.
        self.sentencepiece_model_file = os.path.join(pretrained_file, "sentencepiece.bpe.model")
        self.vocab_file = os.path.join(pretrained_file, "dict.txt")
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.sentencepiece_model_file)
        # sp_model is used ONLY to split text into pieces. Its own piece ids must not
        # be used: ids come from bpe_dict (dict.txt) instead.
        self.bpe_dict = Dictionary.load(self.vocab_file)

        # Special-token ids; these match the order in which Dictionary.__init__
        # registers its default specials (bos, pad, eos, unk -> 0, 1, 2, 3).
        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
        # bpe_dict already reserves ids 0-3 for the specials above, so its ids
        # are used unshifted.
        self.fairseq_offset = 0
        # <mask> is not in dict.txt; it gets the first id after the dictionary.
        self.fairseq_tokens_to_ids["<mask>"] = len(self.bpe_dict) + self.fairseq_offset
        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}

        # Initialise the base class last: some transformers versions query the
        # vocabulary (get_vocab / _convert_token_to_id) while initialising, which
        # needs sp_model / bpe_dict to exist already.
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            **kwargs,
        )

    def _tokenize(self, text):
        """Split *text* into SentencePiece pieces (strings, not ids)."""
        return self.sp_model.EncodeAsPieces(text)

    def convert_tokens_to_string(self, tokens):
        """Join pieces back into text, turning the U+2581 word marker into spaces."""
        return "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Copy the SentencePiece model and ``dict.txt`` into *save_directory*.

        Args:
            save_directory: Destination directory; created if it does not exist.
            filename_prefix: Optional prefix; files are then written as
                ``{prefix}-sentencepiece.bpe.model`` and ``{prefix}-dict.txt``.
                Note that ``__init__`` only looks for the un-prefixed names.

        Returns:
            The paths of the saved model file and vocab file.
        """
        os.makedirs(save_directory, exist_ok=True)
        prefix = filename_prefix + "-" if filename_prefix else ""
        out_model_file = os.path.join(save_directory, prefix + "sentencepiece.bpe.model")
        out_vocab_file = os.path.join(save_directory, prefix + "dict.txt")
        for src, dst in (
            (self.sentencepiece_model_file, out_model_file),
            (self.vocab_file, out_vocab_file),
        ):
            # Saving back into the source directory: shutil.copyfile would raise
            # SameFileError, and the file is already in place.
            if os.path.abspath(src) != os.path.abspath(dst):
                shutil.copyfile(src, dst)
        return out_model_file, out_vocab_file

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab.

        Special tokens are resolved through ``fairseq_tokens_to_ids``; every other
        token goes through ``bpe_dict``, which returns the ``<unk>`` id for
        unknown tokens.
        """
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        return self.bpe_dict.index(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab.

        Ids past the end of ``bpe_dict`` (other than ``<mask>``) map to ``<unk>``.
        """
        if index in self.fairseq_ids_to_tokens:
            return self.fairseq_ids_to_tokens[index]
        return self.bpe_dict[index]

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by adding special tokens.

        - single sequence: ``<s> A </s>``
        - pair of sequences: ``<s> A </s></s> B </s>``

        Args:
            token_ids_0 (:obj:`List[int]`):
                The first tokenized sequence.
            token_ids_1 (:obj:`List[int]`, `optional`):
                The second tokenized sequence.

        Returns:
            :obj:`List[int]`: The model input with special tokens.
        """
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        if token_ids_1 is None:
            return cls + token_ids_0 + sep
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Return a mask with 1 for special tokens and 0 for sequence tokens.

        When the ids do not yet contain special tokens, the mask mirrors the layout
        produced by :meth:`build_inputs_with_special_tokens`.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether the ids already contain special tokens.

        Returns:
            :obj:`List[int]`: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        if token_ids_1 is None:
            return [1] + [0] * len(token_ids_0) + [1]
        return [1] + [0] * len(token_ids_0) + [1, 1] + [0] * len(token_ids_1) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create token type ids for a sequence or a sequence pair.

        RoBERTa does not make use of token type ids, so a list of zeros as long as
        the output of :meth:`build_inputs_with_special_tokens` is returned.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of zeros.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    @property
    def vocab_size(self):
        """Size of ``bpe_dict`` plus one for the appended ``<mask>`` token."""
        return len(self.bpe_dict) + self.fairseq_offset + 1  # +1 for <mask>

    def get_vocab(self):
        """Return the token -> id mapping, including tokens added after loading."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab


class Dictionary(object):
    """A mapping from symbols to consecutive integers.

    The special symbols are registered first, in the order bos, pad, eos, unk
    (then any ``extra_special_symbols``). With the defaults this gives
    ``<s>`` = 0, ``<pad>`` = 1, ``</s>`` = 2 and ``<unk>`` = 3.
    """

    def __init__(
        self,
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        bos="<s>",
        extra_special_symbols=None,
    ):
        self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}
        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
        if extra_special_symbols:
            for s in extra_special_symbols:
                self.add_symbol(s)
        # Number of leading symbols that are special (never re-sorted by finalize).
        self.nspecial = len(self.symbols)

    def __eq__(self, other):
        if not isinstance(other, Dictionary):
            return NotImplemented
        return self.indices == other.indices

    def __getitem__(self, idx):
        """Return the symbol at *idx*, or the unk word when *idx* is past the end."""
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def __contains__(self, sym):
        return sym in self.indices

    def index(self, sym):
        """Returns the index of the specified symbol, or the unk index if it is unknown."""
        assert isinstance(sym, str)
        if sym in self.indices:
            return self.indices[sym]
        return self.unk_index

    def unk_string(self, escape=False):
        """Return unknown string, optionally escaped as: <<unk>>"""
        if escape:
            return "<{}>".format(self.unk_word)
        else:
            return self.unk_word

    def add_symbol(self, word, n=1):
        """Adds a word to the dictionary (or adds *n* to its count) and returns its index."""
        if word in self.indices:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

    def update(self, new_dict):
        """Updates counts from new dictionary, appending symbols not yet present."""
        for word in new_dict.symbols:
            idx2 = new_dict.indices[word]
            if word in self.indices:
                idx = self.indices[word]
                self.count[idx] = self.count[idx] + new_dict.count[idx2]
            else:
                idx = len(self.symbols)
                self.indices[word] = idx
                self.symbols.append(word)
                self.count.append(new_dict.count[idx2])

    def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
        """Sort symbols by frequency in descending order, ignoring special ones.

        Args:
            - threshold defines the minimum word count
            - nwords defines the total number of words in the final dictionary,
                including special symbols
            - padding_factor can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        if nwords <= 0:
            nwords = len(self)

        new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
        new_symbols = self.symbols[: self.nspecial]
        new_count = self.count[: self.nspecial]

        # Insert symbols in sorted order so equal counts come out of most_common()
        # in a deterministic (symbol-sorted) order.
        c = Counter(dict(sorted(zip(self.symbols[self.nspecial:], self.count[self.nspecial:]))))
        for symbol, count in c.most_common(nwords - self.nspecial):
            if count >= threshold:
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(count)
            else:
                break

        threshold_nwords = len(new_symbols)
        if padding_factor > 1:
            # Pad with placeholder symbols until the size is a multiple of padding_factor.
            i = 0
            while threshold_nwords % padding_factor != 0:
                symbol = "madeupword{:04d}".format(i)
                new_indices[symbol] = len(new_symbols)
                new_symbols.append(symbol)
                new_count.append(0)
                i += 1
                threshold_nwords += 1

        assert len(new_symbols) % padding_factor == 0
        assert len(new_symbols) == len(new_indices)

        self.count = list(new_count)
        self.symbols = list(new_symbols)
        self.indices = new_indices

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.bos_index

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.unk_index

    @classmethod
    def load(cls, f):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```

        The default special symbols are added first, so file entries start at
        index ``nspecial``.
        """
        d = cls()
        d.add_from_file(f)
        return d

    def add_from_file(self, f):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.

        *f* is either a path (opened as UTF-8) or an open text file. Blank lines
        are skipped; every other line must be ``<token> <cnt>``, split on the
        last space.
        """
        if isinstance(f, str):
            try:
                with open(f, "r", encoding="utf-8") as fd:
                    self.add_from_file(fd)
            except UnicodeError as err:
                raise Exception(
                    "Incorrect encoding detected in {}, please "
                    "rebuild the dataset".format(f)
                ) from err
            return

        lines = f.readlines()
        indices_start_line = self._load_meta(lines)
        for line in lines[indices_start_line:]:
            if not line.strip():
                continue
            idx = line.rfind(" ")
            if idx == -1:
                raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
            word = line[:idx]
            try:
                count = int(line[idx + 1:])
            except ValueError as err:
                raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'") from err
            # Entries are appended verbatim (no merging with existing symbols).
            self.indices[word] = len(self.symbols)
            self.symbols.append(word)
            self.count.append(count)

    def _save(self, f, kv_iterator):
        """Write ``key value`` lines to *f* (a path or an open text file)."""
        if isinstance(f, str):
            directory = os.path.dirname(f)
            # A bare filename has no directory part; os.makedirs("") would raise.
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(f, "w", encoding="utf-8") as fd:
                return self._save(fd, kv_iterator)
        for k, v in kv_iterator:
            print("{} {}".format(k, v), file=f)

    def _get_meta(self):
        return [], []

    def _load_meta(self, lines):
        return 0

    def save(self, f):
        """Stores dictionary into a text file (special symbols are not written)."""
        ex_keys, ex_vals = self._get_meta()
        self._save(
            f,
            zip(ex_keys + self.symbols[self.nspecial:], ex_vals + self.count[self.nspecial:]),
        )