# Page metadata captured together with this file:
# MPNN-ProGen2-xlarge-CATH42 / tokenization_iPLM.py
# 徐俊德
# init
# c525dff
# raw
# history blame
# No virus
# 3.82 kB
from typing import List, Optional, Union
from transformers import PreTrainedTokenizerFast
from tokenizers.processors import TemplateProcessing
from tokenizers import Tokenizer
from transformers.tokenization_utils_base import BatchEncoding, EncodedInput, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import PaddingStrategy, TensorType
import torch
def create_tokenizer_custom(file):
    """Load a ``tokenizers.Tokenizer`` from a serialized JSON file.

    Args:
        file: path to a serialized tokenizer file (the JSON accepted by
            ``Tokenizer.from_str``).

    Returns:
        The deserialized ``Tokenizer``.

    Raises:
        ValueError: if ``file`` is None (e.g. no ``tokenizer_file`` was supplied).
    """
    if file is None:
        raise ValueError("create_tokenizer_custom requires a tokenizer file path, got None.")
    # Read explicitly as UTF-8: JSON is UTF-8 by convention and the vocabulary may
    # contain non-ASCII entries, while the platform default encoding may differ.
    with open(file, 'r', encoding='utf-8') as f:
        return Tokenizer.from_str(f.read())
class iPLMTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that can prepend a structure-id prefix to each sequence.

    Each input item is either a plain sequence or a string of the form
    ``"<structure_id>|<sequence>"``. Only ``<sequence>`` is tokenized. When
    ``use_structure`` is enabled, every encoded row gets ``n_queries`` extra
    leading positions:

    * ``input_ids``: the structure id's characters as their code points
      (``ord``), zero-padded to ``n_queries``;
    * ``attention_mask``: all ones for rows that carried a structure id,
      all zeros for rows that did not.
    """

    def __init__(self, n_queries, use_structure=True, parallel=False, **kwargs):
        """Build the tokenizer from ``kwargs['tokenizer_file']``.

        Args:
            n_queries: number of prefix positions reserved for the structure id;
                forced to 0 when ``use_structure`` is False.
            use_structure: whether ``__call__`` prepends the structure prefix.
            parallel: stored as ``self.parallel``; not read by this class.
            **kwargs: forwarded to ``PreTrainedTokenizerFast``; ``tokenizer_file``
                is also used to build the underlying ``Tokenizer``.
        """
        super().__init__(
            tokenizer_object=create_tokenizer_custom(kwargs.get('tokenizer_file')),
            **kwargs,
        )
        self.add_special_tokens({'pad_token': '<|pad|>'})
        self.use_structure = use_structure
        self.n_queries = n_queries if use_structure else 0
        self.parallel = parallel

    def _split_structure_inputs(self, text):
        """Strip ``"<structure_id>|"`` prefixes and build the structure prefix tensors.

        Args:
            text: list of input strings.

        Returns:
            ``(sequences, input_ids_prefix, attn_mask_prefix)``. The two tensors
            have shape ``(len(text), n_queries)`` and are None when
            ``use_structure`` is False.

        Raises:
            ValueError: if a structure id has more than ``n_queries`` characters.
        """
        sequences = []
        input_ids_prefix = attn_mask_prefix = None
        if self.use_structure:
            input_ids_prefix = torch.zeros((len(text), self.n_queries), dtype=torch.long)
            attn_mask_prefix = torch.zeros((len(text), self.n_queries), dtype=torch.bool)
        for i, item in enumerate(text):
            if '|' not in item:
                # No structure id: tokenize the item as-is; its prefix stays masked out.
                sequences.append(item)
                continue
            parts = item.split('|')
            sequences.append(parts[1])
            if self.use_structure:
                structure_id = parts[0]
                if len(structure_id) > self.n_queries:
                    raise ValueError(
                        f"Structure id {structure_id!r} has {len(structure_id)} characters, "
                        f"more than n_queries={self.n_queries}."
                    )
                # Convert the structure id to code points and zero-pad it to n_queries.
                input_ids_prefix[i, :len(structure_id)] = torch.tensor(
                    [ord(c) for c in structure_id], dtype=torch.long
                )
                attn_mask_prefix[i] = True
        return sequences, input_ids_prefix, attn_mask_prefix

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        text_pair_target: Optional[
            Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
        ] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """Tokenize ``text`` and, if enabled, prepend the structure-id prefix.

        A single item is treated as a batch of one. All other arguments are
        forwarded unchanged to ``PreTrainedTokenizerFast.__call__``. With
        ``use_structure=True`` the encoded ``input_ids`` and ``attention_mask``
        must be PyTorch tensors (``return_tensors='pt'``), because the prefix is
        concatenated with ``torch.cat``. ``token_type_ids`` is always removed.

        Raises:
            ValueError: if ``text`` is None, if a structure id is longer than
                ``n_queries``, or if ``use_structure`` is True and the encoding
                is not made of PyTorch tensors.
        """
        if text is None:
            raise ValueError("iPLMTokenizer requires `text` to be provided.")
        if not isinstance(text, (list, tuple)):
            text = [text]
        sequences, input_ids_prefix, attn_mask_prefix = self._split_structure_inputs(list(text))

        # Pass everything by keyword: positional forwarding breaks whenever the
        # base-class signature gains a parameter (e.g. ``padding_side``).
        batch = super().__call__(
            text=sequences,
            text_pair=text_pair,
            text_target=text_target,
            text_pair_target=text_pair_target,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )

        if self.use_structure:
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            if not (torch.is_tensor(input_ids) and torch.is_tensor(attention_mask)):
                raise ValueError(
                    "iPLMTokenizer with use_structure=True requires return_tensors='pt'."
                )
            # Cast the prefixes to the batch dtypes so the result has a single,
            # predictable dtype (the bool prefix mask becomes 0/1).
            batch['attention_mask'] = torch.cat(
                [attn_mask_prefix.to(attention_mask.dtype), attention_mask], dim=1
            )
            batch['input_ids'] = torch.cat(
                [input_ids_prefix.to(input_ids.dtype), input_ids], dim=1
            )
        batch.pop("token_type_ids", None)
        return batch