"""Tokenization classes."""

import collections
import logging
import os
import unicodedata

from transformers import PreTrainedTokenizer

logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": os.getenv("VOCAB_NAME")}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "dna": os.getenv("VOCAB_PATH"),
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"dna": os.getenv("POSITIONAL_EMBEDDINGS_SIZE")}
PRETRAINED_INIT_CONFIGURATION = {"dna": {"do_lower_case": False}}
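
# Maps vocabulary size (as a string) to k-mer length. VOCAB_KMER is referenced
# in DNATokenizer.__init__ but was missing from this file; the entries below
# assume the DNABERT convention of 4**k k-mer tokens plus 5 special tokens
# (e.g. 4**6 + 5 = 4101 for k = 6). Adjust if your vocabulary differs.
VOCAB_KMER = {
    "69": "3",
    "261": "4",
    "1029": "5",
    "4101": "6",
}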


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip("\n")
        vocab[token] = index
    return vocab
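
# The vocabulary file holds one token per line, and the line number becomes
# the token id. A hypothetical 3-mer vocabulary might begin (the ordering of
# special tokens shown here is illustrative, not prescribed by this module):
#
#     [PAD]
#     [UNK]
#     [CLS]
#     [SEP]
#     [MASK]
#     AAA
#     AAC
#     ...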


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class DNATokenizer(PreTrainedTokenizer):
    r"""
    Constructs a DNATokenizer.
    :class:`DNATokenizer` runs end-to-end tokenization of DNA sequences that have
    already been converted to space-separated k-mers: whitespace cleaning plus
    punctuation splitting (there is no wordpiece step).

    Args:
        vocab_file: Path to a one-k-mer-per-line vocabulary file.
        do_lower_case: Whether to lower case the input.
        never_split: List of tokens which will never be split during tokenization.
        max_len: An artificial maximum length to truncate tokenized sequences to; passed through
            ``**kwargs`` to the base class. The effective maximum length is the minimum of this
            value (if specified) and the underlying model's sequence length.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=False,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        **kwargs
    ):
        """Constructs a DNATokenizer.

        Args:
            **vocab_file**: Path to a one-k-mer-per-line vocabulary file.
            **do_lower_case**: (`optional`) boolean (default False)
                Whether to lower case the input.
            **never_split**: (`optional`) list of string
                List of tokens which will never be split during tokenization.
            **tokenize_chinese_chars**: (`optional`) boolean (default True)
                Accepted for interface compatibility with BertTokenizer; it has
                no effect on DNA input.
        """
        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            **kwargs,
        )
        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a pretrained "
                "model use `tokenizer = DNATokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        # Infer the k-mer length from the vocabulary size (see VOCAB_KMER above).
        self.kmer = VOCAB_KMER[str(len(self.vocab))]
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.basic_tokenizer = BasicTokenizer(
            do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
        )

    @property
    def vocab_size(self):
        return len(self.vocab)

    def _tokenize(self, text):
        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
            split_tokens.append(token)
        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings) to a single string."""
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
            single sequence: [CLS] X [SEP]
            pair of sequences: [CLS] A [SEP] B [SEP]
        Single sequences of 510 tokens or more are split into 510-token pieces,
        each framed by its own [CLS] and [SEP].
        """
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]

        if token_ids_1 is None:
            if len(token_ids_0) < 510:
                return cls + token_ids_0 + sep
            output = []
            # Ceiling division, so an input whose length is an exact multiple
            # of 510 does not produce a trailing empty [CLS] [SEP] piece.
            num_pieces = (len(token_ids_0) + 509) // 510
            for i in range(num_pieces):
                output.extend(cls + token_ids_0[510 * i : min(len(token_ids_0), 510 * (i + 1))] + sep)
            return output

        return cls + token_ids_0 + sep + token_ids_1 + sep
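
    # Illustrative note: with the piece-wise framing above, a hypothetical
    # 1020-token single sequence is emitted as two framed windows,
    #     [CLS] t_0 ... t_509 [SEP] [CLS] t_510 ... t_1019 [SEP]
    # so downstream code must tolerate multiple [CLS]/[SEP] markers within
    # one encoded example.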

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.

        Args:
            token_ids_0: list of ids (must not contain special tokens)
            token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
                for sequence pairs
            already_has_special_tokens: (default False) Set to True if the token list is already formatted with
                special tokens for the model

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

        if len(token_ids_0) < 510:
            return [1] + ([0] * len(token_ids_0)) + [1]
        output = []
        # Mirror the piece-wise layout of build_inputs_with_special_tokens.
        num_pieces = (len(token_ids_0) + 509) // 510
        for i in range(num_pieces):
            output.extend([1] + ([0] * (min(len(token_ids_0), 510 * (i + 1)) - 510 * i)) + [1])
        return output
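
    # Example (illustrative): for token_ids_0 = [5, 6, 7] and no pair, the
    # mask is [1, 0, 0, 0, 1], matching the layout [CLS] 5 6 7 [SEP].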

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence      | second sequence |
        If token_ids_1 is None, only returns the first portion of the mask (0s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            if len(token_ids_0) < 510:
                return len(cls + token_ids_0 + sep) * [0]
            # Each additional piece contributes its own [CLS] and [SEP], i.e.
            # two extra positions (see build_inputs_with_special_tokens).
            num_pieces = (len(token_ids_0) + 509) // 510
            return (len(cls + token_ids_0 + sep) + 2 * (num_pieces - 1)) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory or file."""
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)


class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=False, never_split=None, tokenize_chinese_chars=True):
        """Constructs a BasicTokenizer.

        Args:
            **do_lower_case**: Whether to lower case the input.
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes.
                Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`).
                List of tokens not to split.
            **tokenize_chinese_chars**: (`optional`) boolean (default True)
                Accepted for interface compatibility with BertTokenizer; it has
                no effect in this tokenizer, since DNA input contains no CJK
                characters.
        """
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = never_split
        # Stored for interface compatibility; not used by this class.
        self.tokenize_chinese_chars = tokenize_chinese_chars

    def tokenize(self, text, never_split=None):
        """Basic tokenization of a piece of text.

        Splits on whitespace and punctuation only; no sub-word tokenization
        is performed.

        Args:
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes.
                Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`).
                List of tokens not to split.
        """
        never_split = self.never_split + (never_split if never_split is not None else [])
        text = self._clean_text(text)

        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                # Honor do_lower_case, which was previously accepted but ignored.
                if self.do_lower_case:
                    token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens
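
    # Example (illustrative): tokenize("ATCGAT, TCGATC") yields
    # ['ATCGAT', ',', 'TCGATC'] after punctuation splitting.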

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters, but they are treated
    # as whitespace here since that is how they are generally used.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `char` is a control character."""
    # \t, \n, and \r are technically control characters, but they are counted
    # as whitespace (see _is_whitespace above).
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # Treat all non-letter/number ASCII as punctuation: characters such as
    # "^", "$", and "`" are not in the Unicode Punctuation category, but they
    # are treated as punctuation here for consistency.
    if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
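

# --- Usage sketch (illustrative) ---
# Assumes VOCAB_NAME / VOCAB_PATH / POSITIONAL_EMBEDDINGS_SIZE are configured
# and that the vocabulary file named below exists; the 6-mer input string is
# hypothetical.
#
#     tokenizer = DNATokenizer("vocab_6mer.txt")
#     seq = "ATCGAT TCGATC CGATCG"      # DNA pre-split into 6-mers
#     tokens = tokenizer.tokenize(seq)  # ['ATCGAT', 'TCGATC', 'CGATCG']
#     ids = tokenizer.encode(seq)       # adds [CLS]/[SEP] framing by default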