language:
  - en
license: mit
library_name: peft
tags:
  - transformers
  - biology
  - esm
  - esm2
  - protein
  - protein language model
base_model: facebook/esm2_t12_35M_UR50D

ESM-2 RNA Binding Site LoRA

This is a Parameter-Efficient Fine-Tuning (PEFT) Low-Rank Adaptation (LoRA) of the esm2_t12_35M_UR50D model for the binary token classification task of predicting RNA binding sites in proteins. The GitHub repository with the training script and conda environment YAML can be found here. A version of this model that was fine-tuned without LoRA is available here.
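To give a rough sense of why LoRA is parameter efficient, the sketch below compares the trainable parameters of one square weight matrix under full fine-tuning with those of a rank-r update made of two small factors. The hidden size of 480 is assumed to match esm2_t12_35M_UR50D, and the rank of 8 is a hypothetical value chosen for illustration; this card does not state the rank or the target modules that were actually used.

```python
from typing import Tuple


def lora_param_counts(d_out: int, d_in: int, rank: int) -> Tuple[int, int]:
    """Return (full, lora) trainable-parameter counts for one weight matrix.

    Full fine-tuning updates every entry of the d_out x d_in matrix.
    LoRA freezes it and trains two factors: B (d_out x rank) and A (rank x d_in).
    """
    full = d_out * d_in
    lora = rank * (d_in + d_out)
    return full, lora


hidden_size = 480  # assumption: hidden size of esm2_t12_35M_UR50D
rank = 8           # hypothetical rank, not taken from this model card

full, lora = lora_param_counts(hidden_size, hidden_size, rank)
print(f"full fine-tuning: {full} parameters")
print(f"LoRA (r={rank}): {lora} parameters ({lora / full:.1%} of full)")
```

Under these assumed values the adapter trains about 3% of the parameters of each adapted matrix; the real ratio depends on the configuration used for training.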

Training procedure

This is a Low-Rank Adaptation (LoRA) of esm2_t12_35M_UR50D, trained on 166 protein sequences from the RNA binding sites dataset using a 75/25 train/test split. It achieves an evaluation loss of 0.18801096081733704.
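As an illustration of what a 75/25 split means for 166 sequences, here is a minimal sketch using only the standard library. The function name, the seed, and the choice to round the test set up are assumptions; the training script linked from this card may split the data differently.

```python
import math
import random


def split_indices(n_items, test_fraction=0.25, seed=0):
    """Shuffle indices 0..n_items-1 and split them into (train, test) lists."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_test = math.ceil(n_items * test_fraction)  # assumption: round test size up
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = split_indices(166)
print(len(train_idx), len(test_idx))  # → 124 42
```

With so few sequences, the evaluation loss is computed on only a few dozen proteins, so it may vary noticeably with the random seed.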

Framework versions

  • PEFT 0.4.0

Using the Model

To run inference, load the base model, attach the LoRA adapter, and classify each residue:

from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel
import torch

# Path to the saved LoRA model
model_path = "AmelieSchreiber/esm2_t12_35M_LoRA_RNA_binding"
# ESM2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Load the model
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Ensure the model is in evaluation mode
loaded_model.eval()

# Load the tokenizer
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Protein sequence for inference
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your actual sequence

# Tokenize the sequence
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

# Run the model
with torch.no_grad():
    logits = loaded_model(**inputs).logits

# Get predictions
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
predictions = torch.argmax(logits, dim=2)

# Define labels
id2label = {
    0: "No binding site",
    1: "Binding site"
}

# Print the predicted labels for each token
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
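Printing one label per residue can be hard to scan for long sequences. The sketch below collects only the residues predicted as binding sites, together with their 1-based positions in the sequence. The helper name `binding_site_positions` and the toy inputs are hypothetical and are not part of the original training or inference code.

```python
SPECIAL_TOKENS = {"<pad>", "<cls>", "<eos>"}


def binding_site_positions(tokens, predictions):
    """Return (position, residue) pairs for tokens labeled 1 (binding site).

    Positions are 1-based and count only residues, skipping special tokens.
    """
    positions = []
    residue_index = 0
    for token, label in zip(tokens, predictions):
        if token in SPECIAL_TOKENS:
            continue
        residue_index += 1
        if int(label) == 1:
            positions.append((residue_index, token))
    return positions


# Toy inputs standing in for the tokenizer and model outputs
toy_tokens = ["<cls>", "M", "A", "V", "P", "<eos>", "<pad>"]
toy_predictions = [0, 0, 1, 1, 0, 0, 0]
print(binding_site_positions(toy_tokens, toy_predictions))
# → [(2, 'A'), (3, 'V')]
```

With the variables from the inference example, the call would be `binding_site_positions(tokens, predictions[0].tolist())`.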