# hyenadna-small-32k-seqlen-hf / tokenization_hyena.py
# Last commit: "Update config + tokenization" (b0472d0) by matt
# File size: 4.06 kB
from transformers import PreTrainedTokenizer, AddedToken
from typing import List, Optional, Union, Dict, Sequence, Tuple
from pathlib import Path
import json
import os
class HyenaDNATokenizer(PreTrainedTokenizer):
    """Character-level tokenizer with a fixed, built-in DNA vocabulary.

    The vocabulary consists of seven special tokens (ids 0-6) followed by the
    characters ``A``, ``C``, ``G``, ``T`` and ``N`` (ids 7-11). Every input
    character becomes one token; characters outside the vocabulary are
    encoded with the id of ``[UNK]``.
    """

    model_input_names = ["input_ids"]

    def __init__(self,
                 model_max_length: int,
                 bos_token="[BOS]",
                 eos_token="[SEP]",
                 sep_token="[SEP]",
                 cls_token="[CLS]",
                 pad_token="[PAD]",
                 mask_token="[MASK]",
                 unk_token="[UNK]",
                 **kwargs):
        """Build the fixed vocabulary and initialise the base tokenizer.

        The character set is hard-coded to ``('A', 'C', 'G', 'T', 'N')``.
        The special tokens and their ids are:

            "[CLS]": 0
            "[SEP]": 1
            "[BOS]": 2
            "[MASK]": 3
            "[PAD]": 4
            "[RESERVED]": 5
            "[UNK]": 6

        Each character is then assigned an id starting at 7, in the order
        listed above.

        Args:
            model_max_length (int): Model maximum sequence length.
            bos_token, eos_token, sep_token, cls_token, pad_token,
            mask_token, unk_token: Special-token strings forwarded to the
                base class. Note that ``eos_token`` and ``sep_token`` both
                default to ``"[SEP]"``.
            **kwargs: Forwarded to the base class. ``add_prefix_space``
                defaults to ``False`` and ``padding_side`` defaults to
                ``"left"`` when not supplied.
        """
        self.characters = ('A', 'C', 'G', 'T', 'N')
        self.model_max_length = model_max_length

        # The lookup tables are populated before calling the base
        # constructor so they exist if it consults the vocabulary.
        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(self.characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

        # Pop these so that the explicit keyword arguments below do not
        # collide with the same keys inside **kwargs.
        add_prefix_space = kwargs.pop("add_prefix_space", False)
        padding_side = kwargs.pop("padding_side", "left")

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=add_prefix_space,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of entries in the fixed vocabulary (special tokens + characters)."""
        return len(self._vocab_str_to_int)

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into a list of its individual characters."""
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        """Return the id of ``token``, or the id of ``[UNK]`` if it is not in the vocabulary."""
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        """Return the token for ``index``, or ``[UNK]`` if the id is not in the vocabulary.

        Falling back to ``[UNK]`` mirrors ``_convert_token_to_id`` and avoids
        a ``KeyError`` on ids outside the fixed vocabulary range.
        """
        return self._vocab_int_to_str.get(index, "[UNK]")

    def convert_tokens_to_string(self, tokens):
        """Concatenate ``tokens`` into a single string with no separator."""
        return "".join(tokens)

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return a mask with 1 at special-token positions and 0 elsewhere.

        When ``already_has_special_tokens`` is true the base-class
        implementation is used. Otherwise the mask matches the layout made by
        ``build_inputs_with_special_tokens``: each sequence is followed by
        exactly one special token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        result = ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Append the separator id after each sequence and return a new list.

        Single sequence: ``token_ids_0 + [SEP]``.
        Sequence pair: ``token_ids_0 + [SEP] + token_ids_1 + [SEP]``.
        The input lists are not modified.
        """
        sep = [self.sep_token_id]
        result = token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token-to-id mapping.

        A copy is returned so that callers mutating the result cannot
        corrupt the tokenizer's internal lookup table.
        """
        return dict(self._vocab_str_to_int)

    # HyenaDNA has a fixed vocabulary with no vocab file
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple:
        """Write nothing and return an empty tuple, since there is no vocabulary file."""
        return ()