---
license: mit
language:
- en
library_name: transformers
tags:
- esm
- esm-2
- protein
- binding-site
- biology
---

# ESM-2 for RNA Binding Site Prediction

A small RNA binding site predictor trained on dataset "S1" from [Data of protein-RNA binding sites](https://www.sciencedirect.com/science/article/pii/S2352340916308022#s0035), using [facebook/esm2_t12_35M_UR50D](https://huggingface.co/facebook/esm2_t12_35M_UR50D) as the base model. The dataset is also available on Hugging Face [here](https://huggingface.co/datasets/AmelieSchreiber/data_of_protein-rna_binding_sites). The model reaches a validation loss of `0.12358924768426839`.

To predict binding sites for new sequences, run:

```python3
import torch
from transformers import AutoTokenizer, EsmForTokenClassification

# Define the class mapping
class_mapping = {
    0: 'Not Binding Site',
    1: 'Binding Site',
}

# Load the trained model and tokenizer
model = EsmForTokenClassification.from_pretrained("AmelieSchreiber/esm2_t12_35M_UR50D_rna_binding_site_predictor")
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

# Define the new sequences
new_sequences = [
    'VLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTK',
    'SQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWF',
    # ... add more sequences here ...
]

# Iterate over the new sequences
for seq in new_sequences:
    # Tokenize the sequence without padding, so every token is a real residue
    # plus the <cls> token at the start and the <eos> token at the end
    inputs = tokenizer(seq, truncation=True, max_length=1290, return_tensors="pt")

    # Apply the model to get the logits (passes input_ids and attention_mask)
    with torch.no_grad():
        outputs = model(**inputs)

    # Pick the label (class) with the highest logit for each token,
    # then drop the <cls> and <eos> positions so index i matches residue i
    predictions = torch.argmax(outputs.logits, dim=-1)[0, 1:-1]

    # Convert the predicted class indices to class names
    predicted_labels = [class_mapping[pred] for pred in predictions.tolist()]

    # Pair each amino acid with its predicted class label and print it
    for i, (residue, predicted_label) in enumerate(zip(seq, predicted_labels)):
        print(f"Position {i+1} - {residue}: {predicted_label}")
```
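Per-residue labels are often easier to read as contiguous binding regions. The sketch below is a hypothetical post-processing helper (`binding_site_regions` is not part of `transformers` or of this model) written in plain Python. It assumes you pass one class index per residue, i.e. with the `<cls>`/`<eos>` tokens and any padding positions already removed, and that class `1` means "Binding Site" as in the class mapping above.

```python
from typing import List, Optional, Sequence, Tuple


# Hypothetical helper for illustration; not part of transformers or the model card's API.
def binding_site_regions(
    sequence: str,
    labels: Sequence[int],
    binding_label: int = 1,
) -> List[Tuple[int, int, str]]:
    """Group consecutive residues labelled `binding_label` into regions.

    `labels` must hold exactly one class index per residue (special tokens removed).
    Returns (start, end, subsequence) tuples with 1-based, inclusive positions.
    """
    if len(labels) != len(sequence):
        raise ValueError("expected exactly one label per residue")

    regions: List[Tuple[int, int, str]] = []
    start: Optional[int] = None  # 0-based index where the current binding run began
    for i, label in enumerate(labels):
        if label == binding_label and start is None:
            start = i  # a new binding run begins here
        elif label != binding_label and start is not None:
            # The run covered 0-based indices start..i-1, i.e. 1-based start+1..i
            regions.append((start + 1, i, sequence[start:i]))
            start = None
    # Close a run that extends to the last residue
    if start is not None:
        regions.append((start + 1, len(sequence), sequence[start:]))
    return regions


if __name__ == "__main__":
    # Toy labels, not real model output
    print(binding_site_regions("ACDEFGH", [0, 1, 1, 0, 0, 1, 1]))
```

For the toy input this prints `[(2, 3, 'CD'), (6, 7, 'GH')]`. The length check is deliberate: if special tokens or padding are left in the label list, positions would silently shift by one or more residues, so the helper refuses mismatched input instead.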