# FloraBERT: module/transformers_utility.py
from pathlib import PosixPath
from typing import Union, Optional
import torch
from transformers import (
    RobertaConfig,
    RobertaTokenizerFast,
    RobertaForMaskedLM,
    RobertaForSequenceClassification,
)
from .models import (
    RobertaMeanPoolConfig,
    RobertaForSequenceClassificationMeanPool,
)

# Tokenizer-init defaults shared by every model variant.
RobertaSettings = dict(padding_side='left')

# Registry mapping a model name to its
# (config class, tokenizer class, model class, tokenizer settings).
MODELS = {
    "roberta-lm": (RobertaConfig, RobertaTokenizerFast, RobertaForMaskedLM, RobertaSettings),
    "roberta-pred": (RobertaConfig, RobertaTokenizerFast, RobertaForSequenceClassification, RobertaSettings),
    "roberta-pred-mean-pool": (
        RobertaMeanPoolConfig,
        RobertaTokenizerFast,
        RobertaForSequenceClassificationMeanPool,
        RobertaSettings,
    ),
}


def load_model(model_name: str,
               tokenizer_dir: Union[str, PosixPath],
               max_tokenized_len: int = 254,
               pretrained_model: Optional[Union[str, PosixPath]] = None,
               k: Optional[int] = None,
               do_lower_case: Optional[bool] = None,
               padding_side: Optional[str] = 'left',
               **config_settings) -> tuple:
    """Load the specified model, config, and tokenizer.

    Args:
        model_name (str): Name of the model. Acceptable options are
            - 'roberta-lm'
            - 'roberta-pred'
            - 'roberta-pred-mean-pool'
        tokenizer_dir (Union[str, PosixPath]): Directory containing the tokenizer
            files merges.txt and vocab.json.
        max_tokenized_len (int, optional): Maximum tokenized length,
            not including SOS and EOS. Defaults to 254.
        pretrained_model (Union[str, PosixPath], optional): Path to a saved
            pretrained RoBERTa model checkpoint (state dict). Defaults to None.
        k (int, optional): Size of k-mers (for the DNABERT model). Defaults to None.
        do_lower_case (bool, optional): Whether to convert all inputs to lower case.
            Defaults to None.
        padding_side (str, optional): Which side to pad on. Defaults to 'left'.
        **config_settings: Additional keyword arguments forwarded to the config class.

    Returns:
        tuple: (config_obj, tokenizer, model)
    """
    config_settings = config_settings or {}
    max_position_embeddings = max_tokenized_len + 2  # To include SOS and EOS
    config_class, tokenizer_class, model_class, tokenizer_settings = MODELS[model_name]

    # Build the tokenizer kwargs from the model-specific defaults, then layer on
    # any explicitly provided overrides.
    kwargs = dict(
        max_len=max_tokenized_len,
        truncate=True,
        padding="max_length",
        **tokenizer_settings
    )
    if k is not None:
        kwargs.update(dict(k=k))
    if do_lower_case is not None:
        kwargs.update(dict(do_lower_case=do_lower_case))
    if padding_side is not None:
        kwargs.update(dict(padding_side=padding_side))
    tokenizer = tokenizer_class.from_pretrained(str(tokenizer_dir), **kwargs)

    name_or_path = str(pretrained_model) if pretrained_model else ""
    config_obj = config_class(
        vocab_size=len(tokenizer),
        max_position_embeddings=max_position_embeddings,
        name_or_path=name_or_path,
        output_hidden_states=True,
        **config_settings
    )

    if pretrained_model:
        # print(f"Loading from pretrained model {pretrained_model}")
        model = model_class(config=config_obj)
        # Load the saved state dict, stripping the 'module.' prefix added by
        # (Distributed)DataParallel and dropping any position_ids buffers that
        # the model classes do not expect.
        state_dict = torch.load(pretrained_model)
        state_dict = {key.replace("module.", ""): value for key, value in state_dict.items()}
        unexpected_keys = [key for key in state_dict if "position_ids" in key]
        for key in unexpected_keys:
            del state_dict[key]
        model.load_state_dict(state_dict)
    else:
        print("Loading untrained model")
        model = model_class(config=config_obj)

    model.resize_token_embeddings(len(tokenizer))
    return config_obj, tokenizer, model
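

# Minimal usage sketch (illustrative, not part of the original module). The
# tokenizer directory below is a hypothetical placeholder assumed to contain
# the byte-level BPE files (merges.txt / vocab.json). With an untrained
# 'roberta-lm' model, this simply demonstrates how the returned config,
# tokenizer, and model fit together.
if __name__ == "__main__":
    config_obj, tokenizer, model = load_model(
        "roberta-lm",
        tokenizer_dir="models/byte-level-bpe-tokenizer",  # hypothetical path
    )
    model.eval()
    encoded = tokenizer("ACGTACGTACGT", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoded)
    # RobertaForMaskedLM returns per-token vocabulary logits:
    # shape (batch_size, sequence_length, vocab_size).
    print(outputs.logits.shape)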