"""HELM-BERT tokenizer."""
import json
import os
from typing import Dict, List, Optional, Tuple
from transformers import PreTrainedTokenizer
# Default vocabulary for HELM notation
HELM_VOCAB = {
# Special tokens (0-4)
" ": 0, # PAD
"@": 1, # BOS/CLS
"\n": 2, # EOS/SEP
"§": 3, # UNK
"¶": 4, # MASK
# Natural amino acids (5-25)
"A": 5,
"R": 6,
"N": 7,
"D": 8,
"C": 9,
"E": 10,
"Q": 11,
"G": 12,
"H": 13,
"I": 14,
"L": 15,
"K": 16,
"M": 17,
"F": 18,
"P": 19,
"S": 20,
"T": 21,
"W": 22,
"Y": 23,
"V": 24,
"X": 25, # Unknown amino acid
# Structure symbols (26-37)
"[": 26,
"]": 27,
"{": 28,
"}": 29,
"(": 30,
")": 31,
"$": 32,
",": 33,
":": 34,
"|": 35,
"-": 36,
".": 37,
# Numbers (38-47)
"0": 38,
"1": 39,
"2": 40,
"3": 41,
"4": 42,
"5": 43,
"6": 44,
"7": 45,
"8": 46,
"9": 47,
    # Other uppercase letters and '>' (48-50)
"B": 48,
"O": 49,
">": 50,
# Lowercase letters (51-72)
"a": 51,
"b": 52,
"c": 53,
"d": 54,
"e": 55,
"f": 56,
"g": 57,
"h": 58,
"i": 59,
"l": 60,
"m": 61,
"n": 62,
"o": 63,
"p": 64,
"r": 65,
"s": 66,
"t": 67,
"u": 68,
"v": 69,
"x": 70,
"y": 71,
"z": 72,
    # Single-character stand-ins for multi-character HELM tokens (73-76)
"/": 73, # PEPTIDE
"*": 74, # me
"\t": 75, # am
"&": 76, # ac
# Miscellaneous (77)
"_": 77,
}
# Multi-character to single-character encoding
HELM_ENCODE_MAP = {"PEPTIDE": "/", "me": "*", "am": "\t", "ac": "&"}
HELM_DECODE_MAP = {v: k for k, v in HELM_ENCODE_MAP.items()}
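# Illustrative round trip (the example string below is hypothetical, chosen only
# to show how the maps interact with the character-level tokenizer):
#   HELM_ENCODE_MAP: "PEPTIDE1{[meA].G.C}$$$$" -> "/1{[*A].G.C}$$$$"
#   HELM_DECODE_MAP: "/1{[*A].G.C}$$$$"        -> "PEPTIDE1{[meA].G.C}$$$$"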
class HELMBertTokenizer(PreTrainedTokenizer):
"""Tokenizer for HELM-BERT.
This tokenizer handles HELM (Hierarchical Editing Language for Macromolecules)
notation, converting peptide sequences into token IDs for the HELM-BERT model.
The tokenizer uses character-level tokenization with special handling for
multi-character HELM tokens like "PEPTIDE", "me", "am", "ac".
Example:
>>> from helmbert import HELMBertTokenizer
>>> tokenizer = HELMBertTokenizer()
>>> inputs = tokenizer("PEPTIDE1{A.C.D.E}$$$$", return_tensors="pt")
>>> inputs.input_ids
tensor([[ 1, 73, 39, 28, 5, 37, 9, 37, 8, 37, 10, 29, 32, 32, 32, 32, 2]])
"""
vocab_files_names = {"vocab_file": "vocab.json"}
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file: Optional[str] = None,
unk_token: str = "§",
sep_token: str = "\n",
pad_token: str = " ",
cls_token: str = "@",
mask_token: str = "¶",
bos_token: str = "@",
eos_token: str = "\n",
model_max_length: int = 512,
**kwargs,
):
# Load or create vocabulary
if vocab_file is not None and os.path.isfile(vocab_file):
with open(vocab_file, encoding="utf-8") as f:
self.vocab = json.load(f)
else:
self.vocab = HELM_VOCAB.copy()
self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
# HELM encoding/decoding maps
self.encode_map = HELM_ENCODE_MAP.copy()
self.decode_map = HELM_DECODE_MAP.copy()
super().__init__(
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
bos_token=bos_token,
eos_token=eos_token,
model_max_length=model_max_length,
**kwargs,
)
@property
def vocab_size(self) -> int:
"""Return the vocabulary size."""
return len(self.vocab)
def get_vocab(self) -> Dict[str, int]:
"""Return the vocabulary as a dictionary."""
return self.vocab.copy()
def _encode_helm(self, text: str) -> str:
"""Encode multi-character HELM tokens to single characters.
Args:
text: Raw HELM notation string
Returns:
Encoded string with single-character tokens
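        Example (illustrative input, default vocabulary):
            >>> tokenizer = HELMBertTokenizer()
            >>> tokenizer._encode_helm("PEPTIDE1{A.C}$$$$")
            '/1{A.C}$$$$'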
"""
if not text:
return ""
result = text
for seq, tok in self.encode_map.items():
result = result.replace(seq, tok)
return result
def _decode_helm(self, text: str) -> str:
"""Decode single-character tokens back to multi-character HELM tokens.
Args:
text: Encoded string with single-character tokens
Returns:
Decoded HELM notation string
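        Example (illustrative input, default vocabulary):
            >>> tokenizer = HELMBertTokenizer()
            >>> tokenizer._decode_helm("/1{A.C}$$$$")
            'PEPTIDE1{A.C}$$$$'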
"""
if not text:
return ""
result = text
for tok, seq in self.decode_map.items():
result = result.replace(tok, seq)
return result
def _tokenize(self, text: str) -> List[str]:
"""Tokenize a HELM string into a list of tokens.
Args:
text: HELM notation string
Returns:
List of single-character tokens
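        Example (illustrative input):
            >>> tokenizer = HELMBertTokenizer()
            >>> tokenizer._tokenize("PEPTIDE1{A.C}$$$$")
            ['/', '1', '{', 'A', '.', 'C', '}', '$', '$', '$', '$']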
"""
# First encode multi-character tokens to single characters
encoded = self._encode_helm(text)
# Return as list of characters
return list(encoded)
def _convert_token_to_id(self, token: str) -> int:
"""Convert a token to its ID."""
return self.vocab.get(token, self.vocab.get(self.unk_token, 3))
def _convert_id_to_token(self, index: int) -> str:
"""Convert an ID to its token."""
return self.ids_to_tokens.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""Convert a list of tokens to a HELM string.
Args:
tokens: List of tokens
Returns:
Decoded HELM notation string
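        Example (illustrative tokens):
            >>> tokenizer = HELMBertTokenizer()
            >>> tokenizer.convert_tokens_to_string(["/", "1", "{", "A", ".", "C", "}"])
            'PEPTIDE1{A.C}'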
"""
# Join tokens and decode back to HELM notation
joined = "".join(tokens)
return self._decode_helm(joined)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""Build model inputs by adding special tokens.
Args:
token_ids_0: First sequence of token IDs
token_ids_1: Optional second sequence of token IDs
Returns:
List of token IDs with special tokens added
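        Example (illustrative IDs from the default vocabulary):
            >>> tokenizer = HELMBertTokenizer()
            >>> tokenizer.build_inputs_with_special_tokens([5, 9])
            [1, 5, 9, 2]
            >>> tokenizer.build_inputs_with_special_tokens([5, 9], [8])
            [1, 5, 9, 2, 8, 2]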
"""
cls_id = [self.cls_token_id]
sep_id = [self.sep_token_id]
if token_ids_1 is None:
return cls_id + token_ids_0 + sep_id
return cls_id + token_ids_0 + sep_id + token_ids_1 + sep_id
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False,
) -> List[int]:
"""Get a mask identifying special tokens.
Args:
token_ids_0: First sequence of token IDs
token_ids_1: Optional second sequence of token IDs
already_has_special_tokens: Whether the sequences already have special tokens
Returns:
List of 0s and 1s (1 = special token)
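        Example (illustrative IDs):
            >>> tokenizer = HELMBertTokenizer()
            >>> tokenizer.get_special_tokens_mask([5, 9])
            [1, 0, 0, 1]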
"""
if already_has_special_tokens:
return [
1
if x in [self.cls_token_id, self.sep_token_id, self.pad_token_id]
else 0
for x in token_ids_0
]
if token_ids_1 is None:
return [1] + [0] * len(token_ids_0) + [1]
return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""Create token type IDs for sequence pairs.
Args:
token_ids_0: First sequence of token IDs
token_ids_1: Optional second sequence of token IDs
Returns:
List of token type IDs
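        Example (illustrative IDs):
            >>> tokenizer = HELMBertTokenizer()
            >>> tokenizer.create_token_type_ids_from_sequences([5, 9], [8])
            [0, 0, 0, 0, 1, 1]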
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return [0] * len(cls + token_ids_0 + sep)
return [0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep)
def save_vocabulary(
self, save_directory: str, filename_prefix: Optional[str] = None
) -> Tuple[str]:
"""Save the vocabulary to a file.
Args:
save_directory: Directory to save the vocabulary
filename_prefix: Optional prefix for the filename
Returns:
Tuple containing the path to the saved vocabulary file
"""
if not os.path.isdir(save_directory):
os.makedirs(save_directory, exist_ok=True)
vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") + "vocab.json",
)
with open(vocab_file, "w", encoding="utf-8") as f:
json.dump(self.vocab, f, ensure_ascii=False, indent=2)
return (vocab_file,)
@property
def mask_token_id(self) -> int:
"""Return the mask token ID."""
return self.vocab.get(self.mask_token, 4)
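# Minimal usage sketch (assumes `transformers` is installed; the HELM string below
# is illustrative). Running this module directly prints a tokenize/decode round trip.
if __name__ == "__main__":
    tokenizer = HELMBertTokenizer()
    helm = "PEPTIDE1{A.C.D.E}$$$$"
    encoding = tokenizer(helm)
    print("tokens:    ", tokenizer.tokenize(helm))
    print("input_ids: ", encoding["input_ids"])
    print("round trip:", tokenizer.decode(encoding["input_ids"], skip_special_tokens=True))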