metadata
library_name: peft
license: mit
language:
- en
tags:
- transformers
- biology
- esm
- esm2
- protein
- protein language model
ESM-2 RNA Binding Site LoRA
This is a Parameter-Efficient Fine-Tuning (PEFT) Low-Rank Adaptation (LoRA) of the esm2_t6_8M_UR50D model for the binary token classification task of predicting RNA binding sites of proteins. The GitHub repository with the training script and conda environment YAML can be found here. A version of this model that was fine-tuned without LoRA can also be found here.
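To give a feel for why LoRA is parameter efficient, here is a minimal sketch that compares the number of trainable parameters when updating one square weight matrix directly versus through a low-rank pair of matrices. The hidden size of 320 matches esm2_t6_8M_UR50D; the rank of 8 is purely an illustrative assumption, since the rank used for this adapter is not stated here, and `lora_param_counts` is a hypothetical helper name.

```python
def lora_param_counts(d_out: int, d_in: int, rank: int) -> tuple[int, int]:
    """Return (full, lora) trainable-parameter counts for one weight matrix."""
    full = d_out * d_in            # updating the d_out x d_in matrix W directly
    lora = rank * (d_out + d_in)   # B is d_out x rank, A is rank x d_in
    return full, lora

# 320 is the hidden size of esm2_t6_8M_UR50D.
# rank=8 is an assumption for illustration, not the value used for this model.
full, lora = lora_param_counts(320, 320, rank=8)
print(full, lora, f"{lora / full:.1%}")  # 102400 5120 5.0%
```

Under these assumptions, the low-rank update trains about 5% of the parameters of the matrix it adapts, while the base weights stay frozen.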
Training procedure
This is a Low-Rank Adaptation (LoRA) of esm2_t6_8M_UR50D, trained on 166 protein sequences in the RNA binding sites dataset using an 80/20 train/test split. It was also trained using class weighting due to the imbalanced nature of the RNA binding sites data (fewer binding sites than non-binding sites).
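As a rough illustration of class weighting, the sketch below derives inverse-frequency weights from per-residue labels. This is one common scheme, not necessarily the one used in the training script; the function name `inverse_frequency_weights`, the toy labels, and the use of -100 as the label for ignored positions are all assumptions made for the example.

```python
from collections import Counter

def inverse_frequency_weights(labels, num_classes=2, ignore_index=-100):
    """Weight each class by total / (num_classes * class_count).

    Assumes every class appears at least once in `labels`.
    """
    counts = Counter(label for label in labels if label != ignore_index)
    total = sum(counts.values())
    return [total / (num_classes * counts[c]) for c in range(num_classes)]

# Toy per-residue labels (hypothetical): 0 = no binding site, 1 = binding site,
# -100 = special tokens that should not contribute to the loss.
labels = [-100, 0, 0, 0, 0, 0, 0, 0, 1, 1, -100]
weights = inverse_frequency_weights(labels)
print(weights)  # the rarer binding-site class gets the larger weight
```

Weights like these could then be passed as the `weight` tensor of `torch.nn.CrossEntropyLoss`, whose default `ignore_index` of -100 skips the special-token positions.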
'eval_loss': 0.6286943554878235,
'eval_precision': 0.17191601049868765,
'eval_recall': 0.4628975265017668,
'eval_f1': 0.2507177033492823,
'eval_auc': 0.6792352754560965,
'epoch': 12.0,
'learning_rate': 0.000433916636807784
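As a quick sanity check on the numbers above, F1 is the harmonic mean of precision and recall, so the reported `eval_f1` can be recomputed from `eval_precision` and `eval_recall`. The snippet only uses the values listed in this card.

```python
# Values copied from the evaluation results above.
precision = 0.17191601049868765
recall = 0.4628975265017668

# F1 is the harmonic mean of precision and recall.
f1 = 2 * precision * recall / (precision + recall)
print(round(f1, 4))  # 0.2507
```

The low precision relative to recall means the model flags many residues that are not labeled as binding sites, which is a typical side effect of up-weighting the rare positive class.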
Framework versions
- PEFT 0.4.0
Using the Model
To use, try running:
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch
# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t6_8M_weighted_lora_rna_binding"
# ESM2 base model
base_model_path = "facebook/esm2_t6_8M_UR50D"
# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)
# Ensure the model is in evaluation mode
loaded_model.eval()
# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your actual sequence
# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits
# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)
# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}
# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
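For downstream use it is often handier to have the predicted binding-site positions than a per-token printout. The following is a minimal sketch of such a post-processing step; `binding_site_positions` is a hypothetical helper, not part of transformers or peft, and the toy tokens and predictions are made up for illustration.

```python
SPECIAL_TOKENS = {"<pad>", "<cls>", "<eos>"}

def binding_site_positions(tokens, predicted_ids, positive_id=1):
    """Return 1-based residue positions predicted as binding sites."""
    positions = []
    residue_index = 0
    for token, label in zip(tokens, predicted_ids):
        if token in SPECIAL_TOKENS:
            continue  # special tokens do not correspond to residues
        residue_index += 1
        if int(label) == positive_id:
            positions.append(residue_index)
    return positions

# Toy example (hypothetical tokens and predictions).
toy_tokens = ["<cls>", "M", "A", "V", "P", "<eos>", "<pad>"]
toy_preds = [0, 0, 1, 1, 0, 0, 1]
print(binding_site_positions(toy_tokens, toy_preds))  # [2, 3]
```

With the variables from the inference code above, this could be called as `binding_site_positions(tokens, predictions[0].tolist())`.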