---
license: mit
language:
  - en
library_name: transformers
tags:
  - esm
  - esm-2
  - protein
  - binding-site
  - biology
---

ESM-2 for RNA Binding Site Prediction

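A small RNA binding site predictor, fine-tuned from facebook/esm2_t12_35M_UR50D on dataset "S1" from "Data of protein-RNA binding sites". For each residue of an input protein sequence, the model predicts whether it belongs to an RNA binding site. The dataset is also available on Hugging Face.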

The model reaches a validation loss of 0.1236.
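
When labels are passed in, `EsmForTokenClassification` returns a cross-entropy loss averaged over all labelled token positions. Positions labelled -100 (typically special and padding tokens) are skipped. The sketch below repeats that computation in plain Python on made-up logits, to show what a single number like 0.1236 averages over. This card does not say how the reported validation loss was computed, so treat the match as an assumption. `masked_token_cross_entropy` is a hypothetical name, not part of transformers.

```python
import math
from typing import Sequence

IGNORE_INDEX = -100  # label value skipped by the loss (PyTorch CrossEntropyLoss default)


def masked_token_cross_entropy(
    logits: Sequence[Sequence[float]], labels: Sequence[int]
) -> float:
    """Mean cross-entropy over positions whose label is not IGNORE_INDEX.

    Hypothetical helper. logits holds one row of class scores per token
    (here 0 = Not Binding Site, 1 = Binding Site); labels holds one gold
    class index per token, or IGNORE_INDEX.
    """
    total = 0.0
    count = 0
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # special/padding positions do not contribute
        # Negative log-softmax of the gold class, stabilised by the row maximum
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]
        count += 1
    if count == 0:
        raise ValueError("no labelled positions")
    return total / count


# Made-up example: <cls>, three residues, <eos>; the special tokens are ignored
toy_logits = [[0.0, 0.0], [2.0, -1.0], [-1.0, 1.5], [3.0, 0.0], [0.0, 0.0]]
toy_labels = [IGNORE_INDEX, 0, 1, 0, IGNORE_INDEX]
print(masked_token_cross_entropy(toy_logits, toy_labels))
```

A lower value means the model puts more probability on the correct class at each labelled residue, averaged over residues. It says nothing by itself about how well the rarer "Binding Site" class is recovered.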

To use, try running:

import torch
from transformers import AutoTokenizer, EsmForTokenClassification

# Define the class mapping
class_mapping = {
    0: 'Not Binding Site',
    1: 'Binding Site',
}

# Load the trained model and tokenizer
model = EsmForTokenClassification.from_pretrained("AmelieSchreiber/esm2_t12_35M_UR50D_rna_binding_site_predictor")
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

# Define the new sequences
new_sequences = [
    'VLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTK',
    'SQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWF',
    # ... add more sequences here ...
]

# Iterate over the new sequences
for seq in new_sequences:
    # Tokenize the sequence. The ESM tokenizer prepends <cls> and appends <eos>.
    # A single sequence needs no padding; padding without passing an attention
    # mask would let the model attend to the pad tokens.
    inputs = tokenizer(seq, truncation=True, max_length=1290, return_tensors="pt")

    # Apply the model to get the logits (shape: [1, num_tokens, 2])
    with torch.no_grad():
        outputs = model(**inputs)

    # Get the predictions by picking the label (class) with the highest logit
    predictions = torch.argmax(outputs.logits, dim=-1)

    # Convert the tensor to a list of integers and drop the <cls> and <eos>
    # positions so that each prediction lines up with one residue
    prediction_list = predictions[0].tolist()[1:-1]

    # Convert the predicted class indices to class names
    predicted_labels = [class_mapping[pred] for pred in prediction_list]

    # Pair each amino acid in the sequence with its predicted class label
    residue_to_label = list(zip(seq, predicted_labels))

    # Print out the list
    for i, (residue, predicted_label) in enumerate(residue_to_label):
        print(f"Position {i+1} - {residue}: {predicted_label}")
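
The loop above prints one label per residue. If contiguous binding regions are more convenient, those per-residue labels can be collapsed into 1-based, inclusive position ranges. The helper below is a hypothetical post-processing step, not part of this model or of transformers. It expects a list of (residue, label) pairs like `residue_to_label` and uses the label strings from `class_mapping`. The example labels are made up.

```python
from typing import List, Optional, Tuple


def binding_site_spans(
    residue_to_label: List[Tuple[str, str]], positive: str = "Binding Site"
) -> List[Tuple[int, int]]:
    """Return 1-based inclusive (start, end) ranges of consecutive positive labels.

    Hypothetical helper for post-processing per-residue predictions.
    """
    spans: List[Tuple[int, int]] = []
    start: Optional[int] = None  # start of the currently open run, if any
    for pos, (_residue, label) in enumerate(residue_to_label, start=1):
        if label == positive and start is None:
            start = pos  # a new run of binding-site residues begins
        elif label != positive and start is not None:
            spans.append((start, pos - 1))  # the run ended at the previous residue
            start = None
    if start is not None:
        spans.append((start, len(residue_to_label)))  # run reaches the sequence end
    return spans


# Made-up labels for a five-residue sequence
example = [
    ("M", "Not Binding Site"),
    ("K", "Binding Site"),
    ("R", "Binding Site"),
    ("L", "Not Binding Site"),
    ("G", "Binding Site"),
]
print(binding_site_spans(example))  # → [(2, 3), (5, 5)]
```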