| | from tokenizers import Tokenizer |
| | from tokenizers.models import BPE |
| | from tokenizers.processors import TemplateProcessing |
| | from transformers import PreTrainedTokenizerFast |
| |
|
| |
|
| | |
# Token vocabulary for protein sequences. ORDER IS SIGNIFICANT: each token's
# list index is its token id (see EsmSequenceTokenizer, which enumerates this
# list to build the id mapping). Layout: four special prefix tokens, the
# amino-acid alphabet (including ambiguity codes X/B/Z and rare residues
# U/O), gap/insertion marks "." and "-", the chain-break token "|", and
# finally "<mask>".
SEQUENCE_VOCAB = (
    ["<cls>", "<pad>", "<eos>", "<unk>"]
    + list("LAGVSERTIDPKQNFYMHWCXBUZO.-")
    + ["|", "<mask>"]
)
| |
|
class EsmSequenceTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer for ESM protein sequences.

    Wraps a `tokenizers` BPE model that has an empty merge list, so it acts
    as a plain token-to-id lookup over ``SEQUENCE_VOCAB`` (one id per
    single-character residue or special token). Encoded sequences are
    framed as ``<cls> ... <eos>`` by a ``TemplateProcessing`` post-processor.
    """

    # Names of the tensors this tokenizer produces for model forward passes.
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chain_break_token="|",
        **kwargs,
    ):
        """Build the underlying fast tokenizer from ``SEQUENCE_VOCAB``.

        Args:
            unk_token: Token substituted for out-of-vocabulary input.
            cls_token: Sequence-start token (also exposed as ``bos_token``).
            pad_token: Padding token.
            mask_token: Mask token for masked-language-model objectives.
            eos_token: Sequence-end token.
            chain_break_token: Separator between protein chains.
            **kwargs: Forwarded to ``PreTrainedTokenizerFast.__init__``.
        """
        all_tokens = SEQUENCE_VOCAB
        # Token ids are the positions in SEQUENCE_VOCAB, so vocab order is
        # part of the model contract.
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # BPE with no merges degenerates to a per-token lookup table;
        # unknown characters map to unk_token.
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        # Tokens to register as "special" on the fast tokenizer. All of
        # them already exist in SEQUENCE_VOCAB, so this marks them special
        # without changing any ids. NOTE(review): unk_token is deliberately
        # absent here — it is handled by the BPE model itself.
        special_tokens = [
            cls_token,
            pad_token,
            mask_token,
            eos_token,
            chain_break_token,
        ]
        # Stored before super().__init__ so the chain_break_token property
        # works regardless of base-class attribute handling.
        self.cb_token = chain_break_token
        additional_special_tokens = [chain_break_token]

        tokenizer.add_special_tokens(special_tokens)

        # Frame every encoded sequence as "<cls> seq <eos>"; for pairs,
        # both segments get their own trailing <eos> and distinct
        # token_type_ids (":0" / ":1").
        tokenizer.post_processor = TemplateProcessing(
            single="<cls> $A <eos>",
            pair="<cls>:0 $A:0 <eos>:0 $B:1 <eos>:1",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # ESM has no dedicated BOS token; <cls> plays that role.
    @property
    def bos_token(self):
        """Alias: the <cls> token doubles as beginning-of-sequence."""
        return self.cls_token

    @property
    def bos_token_id(self):
        """Id of the <cls> token (used as the BOS id)."""
        return self.cls_token_id

    @property
    def chain_break_token(self) -> str:
        """Token separating chains in a multi-chain input (default "|")."""
        return self.cb_token

    @property
    def chain_break_token_id(self):
        """Vocabulary id of the chain-break token."""
        return self.convert_tokens_to_ids(self.chain_break_token)

    @property
    def all_token_ids(self):
        """Every valid token id, i.e. ``range(vocab_size)`` as a list."""
        return list(range(self.vocab_size))

    @property
    def special_token_ids(self):
        """Ids of all registered special tokens (from the base class)."""
        return self.all_special_ids