---
library_name: transformers
tags: []
---

# Model Card for Model ID

ProtST (ESM-1b protein encoder) for binary protein localization classification.

## Training and evaluation script

The script below loads the `Jiqing/ProtST-BinaryLocalization` dataset and tokenizes it with the ESM-1b tokenizer. It freezes the backbone of `Jiqing/protst-esm1b-for-sequential-classification` and trains only the classification head. It then reports accuracy and Matthews correlation coefficient (MCC) on the test and validation splits.

```python
from transformers import AutoModel, AutoTokenizer, HfArgumentParser, TrainingArguments, Trainer
from transformers.data.data_collator import DataCollatorWithPadding
from transformers.trainer_pt_utils import get_parameter_names
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from datasets import load_dataset
import dataclasses
import functools
import math
import numpy as np
from sklearn.metrics import accuracy_score, matthews_corrcoef
import sys
import torch
import logging
import datasets
import transformers

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def create_optimizer(opt_model, lr_ratio=0.1, args=None, freeze_backbone=True):
    """Create an optimizer for the classification head (and optionally the backbone).

    Parameters whose name contains "classifier" form the head and use
    ``args.learning_rate``. When ``freeze_backbone`` is True (the default),
    every other parameter gets ``requires_grad = False``, so only the head is
    trained and the two backbone groups below stay empty. When it is False,
    trainable backbone parameters use ``args.learning_rate * lr_ratio``.

    Weight decay is applied to every parameter except biases and LayerNorm
    weights, following the Trainer policy:
    https://github.com/huggingface/transformers/blob/50573c648ae953dcc1b94d663651f07fb02268f4/src/transformers/trainer.py#L947

    Args:
        opt_model: the model whose parameters are optimized.
        lr_ratio: learning-rate multiplier for backbone parameters.
        args: ``TrainingArguments`` supplying the learning rate, weight decay
            and optimizer type. Defaults to the module-level ``training_args``
            for backward compatibility.
        freeze_backbone: whether to freeze all non-head parameters.

    Returns:
        The optimizer class/kwargs chosen by
        ``Trainer.get_optimizer_cls_and_kwargs(args)``, instantiated on four
        parameter groups (head/decay, backbone/decay, head/no-decay,
        backbone/no-decay).

    Raises:
        ValueError: if any head parameter does not require gradients.
    """
    if args is None:
        args = training_args  # module-level global defined under __main__

    head_names = set()
    for name, param in opt_model.named_parameters():
        if "classifier" in name:
            head_names.add(name)
        elif freeze_backbone:
            param.requires_grad = False

    # A frozen head parameter would be silently left out of every group below.
    frozen_head = [
        name for name, param in opt_model.named_parameters()
        if name in head_names and not param.requires_grad
    ]
    if frozen_head:
        raise ValueError(f"Classifier parameters do not require gradients: {frozen_head}")

    # Names outside LayerNorm layers, minus every bias, receive weight decay.
    decay_names = {
        name for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
        if "bias" not in name
    }
    trainable = [(name, param) for name, param in opt_model.named_parameters() if param.requires_grad]

    def _group(is_head, use_decay):
        # One optimizer param group selected by (head vs. backbone, decay vs. no decay).
        return {
            "params": [
                param
                for name, param in trainable
                if (name in head_names) == is_head and (name in decay_names) == use_decay
            ],
            "weight_decay": args.weight_decay if use_decay else 0.0,
            "lr": args.learning_rate if is_head else args.learning_rate * lr_ratio,
        }

    optimizer_grouped_parameters = [
        _group(is_head=True, use_decay=True),
        _group(is_head=False, use_decay=True),
        _group(is_head=True, use_decay=False),
        _group(is_head=False, use_decay=False),
    ]
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
    return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)


def compute_num_training_steps(training_args, num_train_examples):
    """Return the number of optimizer updates a training run will perform.

    Uses ``training_args.max_steps`` when it is positive. Otherwise the count
    is derived from the dataset size, the global batch size, gradient
    accumulation and ``num_train_epochs``. This assumes the dataloader keeps
    the last partial batch and approximates the Trainer's own epoch-based
    count.
    """
    if training_args.max_steps > 0:
        return training_args.max_steps
    global_batch_size = training_args.train_batch_size * training_args.world_size
    batches_per_epoch = math.ceil(num_train_examples / global_batch_size)
    updates_per_epoch = max(batches_per_epoch // training_args.gradient_accumulation_steps, 1)
    return math.ceil(training_args.num_train_epochs * updates_per_epoch)


def create_scheduler(training_args, optimizer, num_training_steps=None):
    """Create the learning-rate scheduler configured by ``training_args``.

    Args:
        training_args: ``TrainingArguments`` providing ``lr_scheduler_type``,
            the warmup settings and ``max_steps``.
        optimizer: the optimizer whose learning rate is scheduled.
        num_training_steps: total number of optimizer updates. Defaults to
            ``training_args.max_steps``. That value is -1 when training is
            bounded by ``num_train_epochs``, so pass the real step count in
            that case. Otherwise decaying schedules (e.g. "linear") and
            ``warmup_ratio`` are computed against a negative horizon.
    """
    from transformers.optimization import get_scheduler

    if num_training_steps is None:
        num_training_steps = training_args.max_steps
    return get_scheduler(
        training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
        num_training_steps=num_training_steps,
    )


def compute_metrics(eval_preds):
    """Compute accuracy and Matthews correlation coefficient (MCC).

    ``eval_preds`` is a ``(predictions, labels)`` pair whose predictions are
    the class probabilities produced by ``preprocess_logits_for_metrics``.
    The predicted class is the argmax over the last axis.
    """
    probs, labels = eval_preds
    preds = np.argmax(probs, axis=-1)
    return {"accuracy": accuracy_score(labels, preds), "mcc": matthews_corrcoef(labels, preds)}


def preprocess_logits_for_metrics(logits, labels):
    """Turn logits into class probabilities before they are accumulated for metrics.

    When the model returns several outputs besides the loss, the Trainer
    passes a tuple. The logits are assumed to be its first element.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    return torch.softmax(logits, dim=-1)


if __name__ == "__main__":
    device = torch.device("cpu")
    raw_dataset = load_dataset("Jiqing/ProtST-BinaryLocalization")
    model = AutoModel.from_pretrained(
        "Jiqing/protst-esm1b-for-sequential-classification",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")

    # Relative path so the script runs on any machine; change as needed.
    output_dir = "./output_dir/ProtSTModel/default/ESM-1b_PubMedBERT-abs"

    # Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`.
    # Use whichever field the installed version defines, so that
    # parse_dict(..., allow_extra_keys=False) accepts the config.
    training_arg_names = {field.name for field in dataclasses.fields(TrainingArguments)}
    eval_strategy_key = "eval_strategy" if "eval_strategy" in training_arg_names else "evaluation_strategy"

    training_args = {
        "output_dir": output_dir,
        "overwrite_output_dir": True,
        "do_train": True,
        "per_device_train_batch_size": 32,
        "gradient_accumulation_steps": 1,
        "learning_rate": 5e-05,
        "weight_decay": 0,
        "num_train_epochs": 100,
        "max_steps": -1,
        "lr_scheduler_type": "constant",
        "do_eval": True,
        eval_strategy_key: "epoch",
        "per_device_eval_batch_size": 32,
        "logging_strategy": "epoch",
        "save_strategy": "epoch",
        "save_steps": 820,
        "dataloader_num_workers": 0,
        "run_name": "downstream_esm1b_localization_fix",
        "optim": "adamw_torch",
        "resume_from_checkpoint": False,
        "label_names": ["labels"],
        "load_best_model_at_end": True,
        "metric_for_best_model": "accuracy",
        "bf16": True,
        "save_total_limit": 3,
    }
    training_args = HfArgumentParser(TrainingArguments).parse_dict(training_args, allow_extra_keys=False)[0]

    def tokenize_protein(example, tokenizer=None):
        """Tokenize one protein sequence and copy its localization label to "labels"."""
        encoded = tokenizer(example["prot_seq"], add_special_tokens=True)
        example["input_ids"] = encoded["input_ids"]
        example["attention_mask"] = encoded["attention_mask"]
        example["labels"] = example["localization"]
        return example

    func_tokenize_protein = functools.partial(tokenize_protein, tokenizer=tokenizer)
    for split in ["train", "validation", "test"]:
        raw_dataset[split] = raw_dataset[split].map(
            func_tokenize_protein,
            batched=False,
            remove_columns=["Unnamed: 0", "prot_seq", "localization"],
        )

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    transformers.utils.logging.set_verbosity_info()
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)

    # Only the classifier head is trained; create_optimizer freezes the backbone.
    optimizer = create_optimizer(model, args=training_args)
    # max_steps is -1 (epoch-bounded training), so give the scheduler the real horizon.
    num_training_steps = compute_num_training_steps(training_args, len(raw_dataset["train"]))
    scheduler = create_scheduler(training_args, optimizer, num_training_steps=num_training_steps)

    # build trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=raw_dataset["train"],
        eval_dataset=raw_dataset["validation"],
        data_collator=data_collator,
        optimizers=(optimizer, scheduler),
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    train_result = trainer.train()

    trainer.save_model()
    # The tokenizer was not passed to the Trainer, so save_model() does not
    # store it; save it next to the model explicitly for easy upload.
    tokenizer.save_pretrained(training_args.output_dir)

    metrics = train_result.metrics
    metrics["train_samples"] = len(raw_dataset["train"])

    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    metric = trainer.evaluate(raw_dataset["test"], metric_key_prefix="test")
    print("test metric: ", metric)

    metric = trainer.evaluate(raw_dataset["validation"], metric_key_prefix="valid")
    print("valid metric: ", metric)
```