"""HELM-BERT tokenizer."""

import json
import os
from typing import Dict, List, Optional, Tuple

from transformers import PreTrainedTokenizer

HELM_VOCAB = {
    " ": 0,
    "@": 1,
    "\n": 2,
    "§": 3,
    "¶": 4,

    "A": 5,
    "R": 6,
    "N": 7,
    "D": 8,
    "C": 9,
    "E": 10,
    "Q": 11,
    "G": 12,
    "H": 13,
    "I": 14,
    "L": 15,
    "K": 16,
    "M": 17,
    "F": 18,
    "P": 19,
    "S": 20,
    "T": 21,
    "W": 22,
    "Y": 23,
    "V": 24,
    "X": 25,

    "[": 26,
    "]": 27,
    "{": 28,
    "}": 29,
    "(": 30,
    ")": 31,
    "$": 32,
    ",": 33,
    ":": 34,
    "|": 35,
    "-": 36,
    ".": 37,

    "0": 38,
    "1": 39,
    "2": 40,
    "3": 41,
    "4": 42,
    "5": 43,
    "6": 44,
    "7": 45,
    "8": 46,
    "9": 47,

    "B": 48,
    "O": 49,
    ">": 50,

    "a": 51,
    "b": 52,
    "c": 53,
    "d": 54,
    "e": 55,
    "f": 56,
    "g": 57,
    "h": 58,
    "i": 59,
    "l": 60,
    "m": 61,
    "n": 62,
    "o": 63,
    "p": 64,
    "r": 65,
    "s": 66,
    "t": 67,
    "u": 68,
    "v": 69,
    "x": 70,
    "y": 71,
    "z": 72,

    "/": 73,
    "*": 74,
    "\t": 75,
    "&": 76,

    "_": 77,
}
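
# Multi-character HELM tokens are mapped to single-character stand-ins before
# character-level tokenization and expanded back when decoding.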
HELM_ENCODE_MAP = {"PEPTIDE": "/", "me": "*", "am": "\t", "ac": "&"}
HELM_DECODE_MAP = {v: k for k, v in HELM_ENCODE_MAP.items()}


class HELMBertTokenizer(PreTrainedTokenizer):
    """Tokenizer for HELM-BERT.

    This tokenizer handles HELM (Hierarchical Editing Language for Macromolecules)
    notation, converting peptide sequences into token IDs for the HELM-BERT model.

    Tokenization is character-level, with special handling for multi-character
    HELM tokens such as "PEPTIDE", "me", "am", and "ac".

    Example:
        >>> from helmbert import HELMBertTokenizer
        >>> tokenizer = HELMBertTokenizer()
        >>> inputs = tokenizer("PEPTIDE1{A.C.D.E}$$$$", return_tensors="pt")
        >>> inputs.input_ids
        tensor([[ 1, 73, 39, 28, 5, 37, 9, 37, 8, 37, 10, 29, 32, 32, 32, 32, 2]])
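
        Decoding restores the original HELM string:

        >>> tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True)
        'PEPTIDE1{A.C.D.E}$$$$'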
    """

    vocab_files_names = {"vocab_file": "vocab.json"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        unk_token: str = "§",
        sep_token: str = "\n",
        pad_token: str = " ",
        cls_token: str = "@",
        mask_token: str = "¶",
        bos_token: str = "@",
        eos_token: str = "\n",
        model_max_length: int = 512,
        **kwargs,
    ):
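        # The vocabulary is prepared before super().__init__(), which may query
        # get_vocab() while registering the special tokens.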
        if vocab_file is not None and os.path.isfile(vocab_file):
            with open(vocab_file, encoding="utf-8") as f:
                self.vocab = json.load(f)
        else:
            self.vocab = HELM_VOCAB.copy()

        self.ids_to_tokens = {v: k for k, v in self.vocab.items()}

        self.encode_map = HELM_ENCODE_MAP.copy()
        self.decode_map = HELM_DECODE_MAP.copy()

        super().__init__(
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            bos_token=bos_token,
            eos_token=eos_token,
            model_max_length=model_max_length,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Return the vocabulary size."""
        return len(self.vocab)

    def get_vocab(self) -> Dict[str, int]:
        """Return the vocabulary as a dictionary."""
        return self.vocab.copy()

    def _encode_helm(self, text: str) -> str:
        """Encode multi-character HELM tokens to single characters.

        Args:
            text: Raw HELM notation string

        Returns:
            Encoded string with single-character tokens
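
        Example (with the default encode map):
            >>> HELMBertTokenizer()._encode_helm("PEPTIDE1{A.C}$$$$")
            '/1{A.C}$$$$'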
        """
        if not text:
            return ""
        result = text
        for seq, tok in self.encode_map.items():
            result = result.replace(seq, tok)
        return result

    def _decode_helm(self, text: str) -> str:
        """Decode single-character tokens back to multi-character HELM tokens.

        Args:
            text: Encoded string with single-character tokens

        Returns:
            Decoded HELM notation string
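
        Example (inverse of _encode_helm, with the default decode map):
            >>> HELMBertTokenizer()._decode_helm("/1{A.C}$$$$")
            'PEPTIDE1{A.C}$$$$'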
        """
        if not text:
            return ""
        result = text
        for tok, seq in self.decode_map.items():
            result = result.replace(tok, seq)
        return result

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a HELM string into a list of tokens.

        Args:
            text: HELM notation string

        Returns:
            List of single-character tokens
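
        Example (multi-character tokens appear as their single-character stand-ins):
            >>> HELMBertTokenizer()._tokenize("PEPTIDE1{A.C}$$$$")
            ['/', '1', '{', 'A', '.', 'C', '}', '$', '$', '$', '$']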
        """
        encoded = self._encode_helm(text)
        return list(encoded)

    def _convert_token_to_id(self, token: str) -> int:
        """Convert a token to its ID."""
        # Unknown tokens fall back to the ID of the unk token ("§" -> 3 by default).
        return self.vocab.get(token, self.vocab.get(self.unk_token, 3))

    def _convert_id_to_token(self, index: int) -> str:
        """Convert an ID to its token."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Convert a list of tokens to a HELM string.

        Args:
            tokens: List of tokens

        Returns:
            Decoded HELM notation string
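
        Example:
            >>> HELMBertTokenizer().convert_tokens_to_string(["/", "1", "{", "A", "}"])
            'PEPTIDE1{A}'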
        """
        joined = "".join(tokens)
        return self._decode_helm(joined)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Build model inputs by adding special tokens.

        Args:
            token_ids_0: First sequence of token IDs
            token_ids_1: Optional second sequence of token IDs

        Returns:
            List of token IDs with special tokens added
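
        Example (single sequence, producing [CLS] X [SEP] with the default vocabulary):
            >>> HELMBertTokenizer().build_inputs_with_special_tokens([5, 9, 8])
            [1, 5, 9, 8, 2]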
        """
        cls_id = [self.cls_token_id]
        sep_id = [self.sep_token_id]

        if token_ids_1 is None:
            return cls_id + token_ids_0 + sep_id

        return cls_id + token_ids_0 + sep_id + token_ids_1 + sep_id

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Get a mask identifying special tokens.

        Args:
            token_ids_0: First sequence of token IDs
            token_ids_1: Optional second sequence of token IDs
            already_has_special_tokens: Whether the sequences already have special tokens

        Returns:
            List of 0s and 1s (1 = special token)
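
        Example (single sequence without existing special tokens):
            >>> HELMBertTokenizer().get_special_tokens_mask([5, 9, 8])
            [1, 0, 0, 0, 1]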
        """
        if already_has_special_tokens:
            return [
                1 if x in (self.cls_token_id, self.sep_token_id, self.pad_token_id) else 0
                for x in token_ids_0
            ]

        if token_ids_1 is None:
            return [1] + [0] * len(token_ids_0) + [1]

        return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Create token type IDs for sequence pairs.

        Args:
            token_ids_0: First sequence of token IDs
            token_ids_1: Optional second sequence of token IDs

        Returns:
            List of token type IDs
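
        Example (pair of sequences; the first segment and its separator get 0,
        the second segment and its separator get 1):
            >>> HELMBertTokenizer().create_token_type_ids_from_sequences([5, 9], [8])
            [0, 0, 0, 0, 1, 1]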
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return [0] * len(cls + token_ids_0 + sep)

        return [0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep)

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Save the vocabulary to a file.

        Args:
            save_directory: Directory to save the vocabulary
            filename_prefix: Optional prefix for the filename

        Returns:
            Tuple containing the path to the saved vocabulary file
        """
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)

        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)

    @property
    def mask_token_id(self) -> int:
        """Return the mask token ID."""
        return self.vocab.get(self.mask_token, 4)