
ESM-2 RNA Binding Site LoRA

This is a Parameter-Efficient Fine-Tuning (PEFT) Low-Rank Adaptation (LoRA) of the esm2_t12_35M_UR50D model for the (binary) token classification task of predicting RNA binding sites of proteins. You can also find a version of this model that was fine-tuned without LoRA here.
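To make "parameter-efficient" concrete: LoRA freezes a pretrained weight matrix W of shape d_out × d_in and learns a low-rank update ΔW = B·A, where B is d_out × r and A is r × d_in, so each adapted matrix adds only r·(d_in + d_out) trainable parameters. The sketch below assumes a hidden size of 480 and 12 layers for esm2_t12_35M_UR50D, and uses a hypothetical rank of 16 applied to the query, key and value projections; the actual rank and target modules used for this adapter are not stated on this card.

```python
def lora_params_per_matrix(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters added by one LoRA pair: A is (rank x d_in), B is (d_out x rank)."""
    return rank * d_in + d_out * rank


def full_params_per_matrix(d_in: int, d_out: int) -> int:
    """Parameters in the frozen dense weight matrix (bias ignored)."""
    return d_in * d_out


# Assumed ESM-2 35M shapes; RANK and TARGETS_PER_LAYER are hypothetical choices.
HIDDEN = 480
NUM_LAYERS = 12
RANK = 16
TARGETS_PER_LAYER = 3  # hypothetical: query, key and value projections

lora_total = NUM_LAYERS * TARGETS_PER_LAYER * lora_params_per_matrix(HIDDEN, HIDDEN, RANK)
full_total = NUM_LAYERS * TARGETS_PER_LAYER * full_params_per_matrix(HIDDEN, HIDDEN)

print(f"LoRA trainable parameters: {lora_total:,}")
print(f"Frozen parameters in the same matrices: {full_total:,}")
print(f"Ratio: {lora_total / full_total:.3f}")
```

Under these assumptions the adapter trains 2r/d = 1/15 of the parameters in the matrices it targets; the counts ignore the token-classification head.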

Training procedure

This is a Low-Rank Adaptation (LoRA) of esm2_t12_35M_UR50D, trained on 166 protein sequences in the RNA binding sites dataset using an 85/15 train/test split. This model was trained with class weighting due to the imbalanced nature of the RNA binding site dataset (there are far fewer binding sites than non-binding sites). This model has slightly improved precision, recall, and F1 score over AmelieSchreiber/esm2_t12_35M_weighted_lora_rna_binding, but it may suffer from mild overfitting, as indicated by the training loss being lower than the eval loss (see metrics below). If you are searching for binding sites and aren't worried about false positives, the higher recall may make this model preferable to the other RNA binding site predictors.
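Class weighting generally means scaling each token's cross-entropy loss by a weight for its true class, so that the rare "binding site" class contributes more to the gradient. The pure-Python sketch below illustrates the idea; the inverse-frequency weighting scheme and the toy labels are assumptions for illustration, not the exact weights used to train this model. The loss mirrors the weighted-mean reduction of PyTorch's `CrossEntropyLoss(weight=...)`, including skipping labels equal to -100, the ignore index commonly assigned to special tokens.

```python
import math
from typing import List, Sequence


def inverse_frequency_weights(labels: Sequence[int], num_classes: int = 2) -> List[float]:
    """Weight each class by total / (num_classes * count), so rarer classes get larger weights."""
    counts = [0] * num_classes
    for y in labels:
        if y != -100:  # skip ignored positions (e.g. special tokens)
            counts[y] += 1
    total = sum(counts)
    return [total / (num_classes * c) for c in counts]


def weighted_cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int],
                           weights: Sequence[float], ignore_index: int = -100) -> float:
    """Weighted mean of per-token cross-entropy: sum(w_y * loss_i) / sum(w_y)."""
    weighted_sum = 0.0
    weight_total = 0.0
    for row, y in zip(logits, labels):
        if y == ignore_index:
            continue
        # Numerically stable log-sum-exp; loss = -log softmax(row)[y].
        m = max(row)
        log_norm = m + math.log(sum(math.exp(v - m) for v in row))
        loss = log_norm - row[y]
        weighted_sum += weights[y] * loss
        weight_total += weights[y]
    return weighted_sum / weight_total


# Toy example: <cls> + 8 residues + <eos>; the special tokens are labeled -100 and ignored.
# Only 2 of the 8 (made-up) residues are binding sites.
labels = [-100, 0, 0, 0, 1, 0, 0, 1, 0, -100]
weights = inverse_frequency_weights(labels)
print(weights)  # [0.6666666666666666, 2.0]: the binding-site class (1) gets the larger weight
```

In a Hugging Face `Trainer`, the same idea is usually implemented by overriding the loss computation to use `torch.nn.CrossEntropyLoss(weight=...)` with a weight tensor like the one above.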

You can train your own version using this notebook! You just need the RNA binding_sites.xml file found here. You may also need to run some pip install statements at the beginning of the notebook. If you are running in Colab, run:

!pip install transformers[torch] datasets peft -q
!pip install accelerate -U -q

Try to improve upon these metrics by adjusting the hyperparameters (evaluation metrics at epoch 11, followed by the training log at epoch 11.43):

{'eval_loss': 0.500779926776886,
'eval_precision': 0.1708695652173913,
'eval_recall': 0.8397435897435898,
'eval_f1': 0.2839595375722543,
'eval_auc': 0.771835775620126,
'epoch': 11.0}
{'loss': 0.4171,
'learning_rate': 0.00032491416877500004,
'epoch': 11.43}
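As a sanity check on the metrics above, F1 is the harmonic mean of precision and recall, F1 = 2PR / (P + R). Plugging in the reported eval precision and recall reproduces the reported eval F1, and shows why F1 stays low: the high recall cannot compensate for the low precision (only about 17% of tokens predicted as binding sites are true binding sites).

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 if both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Values copied from the eval metrics reported above.
eval_precision = 0.1708695652173913
eval_recall = 0.8397435897435898
print(round(f1_score(eval_precision, eval_recall), 4))  # 0.284
```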

A similar model can also be trained using the GitHub repository, which contains a training script and a conda environment YAML file and can be found here. That version uses wandb sweeps for hyperparameter search; however, it does not use class weighting.
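For orientation, a wandb sweep is driven by a configuration dictionary that names a search method, the metric to optimize, and the hyperparameter ranges. The dictionary below is a hypothetical example: the metric name `eval/f1`, the parameter names, the project name and every range are illustrative assumptions, not the configuration used in the GitHub repository.

```python
# Hypothetical wandb sweep configuration; all names and ranges are illustrative.
sweep_config = {
    "method": "bayes",  # Bayesian optimization over the ranges below
    "metric": {"name": "eval/f1", "goal": "maximize"},
    "parameters": {
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-5, "max": 1e-3},
        "lora_r": {"values": [4, 8, 16, 32]},
        "lora_alpha": {"values": [8, 16, 32, 64]},
        "lora_dropout": {"distribution": "uniform", "min": 0.0, "max": 0.3},
    },
}

# With wandb installed and logged in, the sweep could then be launched with e.g.:
#   import wandb
#   sweep_id = wandb.sweep(sweep_config, project="esm2-rna-binding")  # hypothetical project name
#   wandb.agent(sweep_id, function=train)  # `train` is your own (hypothetical) training function
```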

Framework versions

  • PEFT 0.4.0

Using the Model

To use the model, first run the following pip install statement:

!pip install transformers peft -q

Then try running:

from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_UR50D_RNA_LoRA_weighted"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the base model, then apply the LoRA adapter on top of it
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode (disables dropout)
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference (an arbitrary example; replace it with your own sequence)
protein_sequence = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTT"

# Tokenize the sequence (longer sequences are truncated to 1024 tokens)
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model without tracking gradients
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)  # Most likely class for each token

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted labels for each token, skipping special tokens
for token, prediction in zip(tokens, predictions[0].tolist()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
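If you want residue positions rather than a printed list, the per-token output can be collapsed into the 1-based indices of residues predicted as binding sites. The helper below is a hypothetical, pure-Python sketch; it assumes, as the loop above does, that the ESM-2 tokenizer emits one token per residue plus special tokens such as `<cls>`, `<eos>` and `<pad>`.

```python
from typing import List, Sequence, Tuple

SPECIAL_TOKENS = {"<pad>", "<cls>", "<eos>"}


def binding_residues(tokens: Sequence[str], predictions: Sequence[int]) -> List[Tuple[int, str]]:
    """Return (1-based residue position, amino acid) for every token predicted as class 1."""
    hits = []
    position = 0
    for token, label in zip(tokens, predictions):
        if token in SPECIAL_TOKENS:
            continue  # special tokens do not correspond to residues
        position += 1
        if label == 1:
            hits.append((position, token))
    return hits


# Toy input shaped like the tokenizer output (the labels are made up).
tokens = ["<cls>", "M", "K", "R", "G", "<eos>"]
predictions = [0, 0, 1, 1, 0, 0]
print(binding_residues(tokens, predictions))  # [(2, 'K'), (3, 'R')]
```

With the real model output, it could be called as `binding_residues(tokens, predictions[0].tolist())`.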