ESMBind (ESMB) Ensemble Models

Community Article Published September 22, 2023

TLDR: In this post we discuss how to build a basic ensemble model from ESMBind (ESMB) models, using both a "hard" and a "soft" voting strategy. We show how to compute the train and test metrics on a preprocessed train/test split dataset of protein sequences. Please note that the following is for demonstration purposes only. These models are not well tested, and they appear to be overfit (see the precision, F1 score, and MCC below).


Introduction

Note: due to the memory constraints imposed by using ensembles, you may need to run this code example locally or on a Google Colab Pro instance. You might also try a Kaggle notebook using the P100 GPU. Another option would be to use our previous post to train two or more smaller ESMB models using esm2_t6_8M_UR50D as the base model. Recall that we showed how to finetune a binding site predictor using Low Rank Adaptation (LoRA) in this post. We will recall some of the information here, but you should probably read that post before continuing unless you are already familiar with LoRA and the basics of ensemble models. Also, note that this post is for demonstration purposes only. To get a better ensemble, you should train your own models with different hyperparameters using the examples given in the previous post.

ESMBind (or ESMB) is a collection of models that use Low Rank Adaptation (LoRA) on top of the base model ESM-2, finetuned to predict the binding sites of a protein from its sequence alone. It does not require a Multiple Sequence Alignment (MSA) or any structural information about the protein's 3D fold or backbone structure. This makes ESMB models accessible and simple to use. They also require less domain knowledge to apply and understand, which makes them more interpretable. However, this may come at the cost of performance.
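To give a rough sense of why LoRA is parameter efficient, here is a minimal sketch that counts trainable parameters for one weight matrix. It assumes the usual LoRA formulation, where a frozen weight matrix of shape `d_out x d_in` receives a low-rank update built from two small matrices of shapes `d_out x r` and `r x d_in`. The sizes and the rank used below are hypothetical and chosen only for illustration; they are not the settings of the ESMB models.

```python
def lora_param_counts(d_out: int, d_in: int, rank: int) -> tuple[int, int]:
    """Return (full, lora) trainable-parameter counts for one weight matrix."""
    full = d_out * d_in            # finetuning the whole matrix W
    lora = rank * (d_in + d_out)   # training only B (d_out x r) and A (r x d_in)
    return full, lora


# Hypothetical layer size and rank, for illustration only
full, lora = lora_param_counts(d_out=480, d_in=480, rank=8)
print(f"full: {full}, lora: {lora}, ratio: {lora / full:.1%}")
```

The smaller the rank relative to the layer size, the smaller the fraction of parameters that needs to be trained.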

Remember, we showed how to use Low Rank Adaptation (LoRA) of the protein language model (pLM) ESM-2 in the post linked to above. LoRA is a technique that was shown to dramatically reduce overfitting in the pLM esm2_t12_35M_UR50D (see also ESM on Hugging Face). It also allows us to finetune larger models in a parameter efficient way. Below, we provide code to get the train/test metrics on the preprocessed dataset that was used as the train/test split for the individual models in the ensemble, as well as code to run inference on your own protein sequences.
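Before diving into the full pipeline, it may help to see the two voting strategies on a toy example. The sketch below is a plain-Python illustration, not the code used later in this post: each hypothetical model produces one pair of scores (non-binding, binding) per residue. Hard voting takes each model's predicted label and picks the majority; soft voting averages the scores first and then picks the label. Ties in the hard vote are broken toward the smaller label, which is also what `scipy.stats.mode` does.

```python
from collections import Counter


def hard_vote(per_model_scores):
    """Majority vote over each model's argmax label, per residue."""
    votes_per_model = [
        [max(range(len(scores)), key=scores.__getitem__) for scores in model]
        for model in per_model_scores
    ]
    ensemble = []
    for residue_votes in zip(*votes_per_model):
        counts = Counter(residue_votes)
        best = max(counts.values())
        # Break ties toward the smallest label
        ensemble.append(min(label for label, c in counts.items() if c == best))
    return ensemble


def soft_vote(per_model_scores):
    """Average the scores across models, then take the argmax per residue."""
    ensemble = []
    for residue_scores in zip(*per_model_scores):
        n_classes = len(residue_scores[0])
        mean = [sum(s[c] for s in residue_scores) / len(residue_scores)
                for c in range(n_classes)]
        ensemble.append(max(range(n_classes), key=mean.__getitem__))
    return ensemble


# Two hypothetical models, three residues, two classes (non-binding, binding)
model_a = [[2.0, 0.1], [0.2, 2.0], [0.1, 3.0]]
model_b = [[1.5, 0.2], [0.9, 0.1], [0.3, 2.0]]
print("hard:", hard_vote([model_a, model_b]))
print("soft:", soft_vote([model_a, model_b]))
```

On the middle residue the two models disagree, so the hard vote is a tie and falls back to label 0, while the soft vote follows the more confident model. With only two models in an ensemble, this tie-breaking can noticeably bias hard voting toward the non-binding class.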

Train/Test Datasets

Before getting started, download the following pickle files, then adjust the paths below in the code to match your local file paths.

Getting the Train/Test Metrics on a Large Preprocessed Dataset

Note: this will run in a Google Colab or Kaggle instance. However, if you use the free GPU available in Colab, the first part of the code will require several hours to run. The inference part of the code, which tests an ensemble model on a single protein sequence or a small collection of proteins, will run in seconds. So, if you only want to test the model on a few protein sequences, you can skip to the last section on "Inference".

Step 0: Installation and Imports

!pip install transformers -q 
!pip install datasets -q 
!pip install accelerate -q 
!pip install scipy -q
!pip install scikit-learn -q
!pip install peft -q 
import os
import pickle
import numpy as np
from scipy import stats
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from transformers import AutoModelForTokenClassification, Trainer, AutoTokenizer, DataCollatorForTokenClassification
from datasets import Dataset, concatenate_datasets
from accelerate import Accelerator
from peft import PeftModel
import gc

Step 1: Loading Data

In this step, you are loading sequences and labels for both training and test datasets from pickle files. These datasets are used to train and evaluate your models respectively.

# Step 1: Load train/test data and labels from pickle files
with open("/content/drive/MyDrive/train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)
with open("/content/drive/MyDrive/test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)
with open("/content/drive/MyDrive/train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)
with open("/content/drive/MyDrive/test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)
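Once the files are loaded, a quick sanity check can catch path mix-ups early. The sketch below assumes that each pickle holds a list, with sequences stored as strings and labels stored as one list of per-residue integers per sequence; this layout is an assumption, so adapt it if your files differ. The helper name `check_alignment` is hypothetical.

```python
def check_alignment(sequences, labels):
    """Return the indices where a sequence and its label list differ in length."""
    if len(sequences) != len(labels):
        raise ValueError(f"{len(sequences)} sequences but {len(labels)} label lists")
    return [i for i, (seq, lab) in enumerate(zip(sequences, labels))
            if len(seq) != len(lab)]


# Toy data standing in for the unpickled lists
toy_sequences = ["MAVPE", "TRPN"]
toy_labels = [[0, 0, 1, 1, 0], [0, 1, 0]]
print(check_alignment(toy_sequences, toy_labels))  # index 1 is misaligned
```

On the real data you would call `check_alignment(train_sequences, train_labels)` and expect an empty list.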

Step 2: Batch Tokenization and Dataset Creation

In this step, we load a pre-trained tokenizer. Tokenization is the process of converting each input sequence into tokens, which are represented as integer IDs. To save memory, the sequences are tokenized and turned into datasets one batch at a time, inside the function defined in Step 3.

# Step 2: Define the Tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
max_sequence_length = tokenizer.model_max_length
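To make the idea of tokenization concrete without downloading anything, here is a toy tokenizer. The vocabulary and IDs are made up for illustration and do not match the real ESM-2 vocabulary. What it does share with the ESM-2 tokenizer is that a start token and an end token are added around the residues, so a tokenized sequence is two positions longer than the protein itself.

```python
# Hypothetical toy vocabulary; the real ESM-2 tokenizer uses different IDs
SPECIAL = ["<cls>", "<pad>", "<eos>", "<unk>"]
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
VOCAB = {tok: i for i, tok in enumerate(SPECIAL + list(AMINO_ACIDS))}


def toy_tokenize(sequence, max_length=None):
    """Map residues to integer IDs and wrap them in <cls> ... <eos>."""
    ids = [VOCAB.get(res, VOCAB["<unk>"]) for res in sequence]
    if max_length is not None:
        ids = ids[: max_length - 2]  # leave room for the two special tokens
    return [VOCAB["<cls>"]] + ids + [VOCAB["<eos>"]]


ids = toy_tokenize("MAVPE")
print(ids, len(ids))  # five residues become seven token IDs
```

This offset of one position at the start is worth keeping in mind whenever predictions are mapped back to residue positions.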

Step 3: Compute Metrics in Batches to Save Memory

# Step 3: Define a `compute_metrics_for_batch` function.
def compute_metrics_for_batch(sequences_batch, labels_batch, models, voting='hard'):
    # Tokenize batch
    batch_tokenized = tokenizer(sequences_batch, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
    
    batch_dataset = Dataset.from_dict({k: v for k, v in batch_tokenized.items()})
    batch_dataset = batch_dataset.add_column("labels", labels_batch[:len(batch_dataset)])
    
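    # Pad each label list with -100 up to the tokenized sequence length so that
    # the labels line up with the (batch_size, seq_len) predictions
    seq_len = batch_tokenized["input_ids"].shape[1]
    labels_array = np.array([np.pad(label, (0, seq_len - len(label)), constant_values=-100) for label in batch_dataset["labels"]])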
    
    # Initialize a trainer for each model
    data_collator = DataCollatorForTokenClassification(tokenizer)
    trainers = [Trainer(model=model, data_collator=data_collator) for model in models]
    
    # Get the predictions from each model
    all_predictions = [trainer.predict(test_dataset=batch_dataset)[0] for trainer in trainers]
    
    if voting == 'hard':
        # Hard voting
        hard_predictions = [np.argmax(predictions, axis=2) for predictions in all_predictions]
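        # keepdims=False (SciPy >= 1.9) gives a (batch_size, seq_len) array regardless of SciPy version
        ensemble_predictions = stats.mode(np.stack(hard_predictions), axis=0, keepdims=False).mode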
    elif voting == 'soft':
        # Soft voting
        avg_predictions = np.mean(all_predictions, axis=0)
        ensemble_predictions = np.argmax(avg_predictions, axis=2)
    else:
        raise ValueError("Voting must be either 'hard' or 'soft'")
    
    print("Shape of ensemble_predictions:", ensemble_predictions.shape)  # Debug print
    
    # Use broadcasting to create 2D mask
    mask_2d = labels_array != -100
    
    # Filter true labels and predictions using the mask
    true_labels_list = [label[mask_2d[idx]] for idx, label in enumerate(labels_array)]
    true_labels = np.concatenate(true_labels_list)
    flat_predictions_list = [ensemble_predictions[idx][mask_2d[idx]] for idx in range(ensemble_predictions.shape[0])]
    flat_predictions = np.concatenate(flat_predictions_list).tolist()

    # Compute the metrics
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)  # Compute MCC
    
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}
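The masking and scoring logic above can be hard to follow with real arrays, so here is a small plain-Python sketch of the same idea: drop every position whose label is -100 (padding), then compute the Matthews correlation coefficient (MCC) from the confusion counts. The function names and the toy data are hypothetical; in the pipeline itself the score comes from scikit-learn's `matthews_corrcoef`.

```python
import math


def masked_pairs(labels, predictions, ignore_index=-100):
    """Yield (label, prediction) pairs, skipping padded positions."""
    for label_row, pred_row in zip(labels, predictions):
        for y, p in zip(label_row, pred_row):
            if y != ignore_index:
                yield y, p


def binary_mcc(pairs):
    """Matthews correlation coefficient for binary labels in {0, 1}."""
    tp = fp = tn = fn = 0
    for y, p in pairs:
        if y == 1 and p == 1:
            tp += 1
        elif y == 0 and p == 1:
            fp += 1
        elif y == 0 and p == 0:
            tn += 1
        else:
            fn += 1
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0


# Toy batch: two sequences padded to length 4
toy_labels = [[0, 1, 1, -100], [1, 0, -100, -100]]
toy_preds = [[0, 1, 0, 1], [1, 1, 0, 0]]
print(binary_mcc(masked_pairs(toy_labels, toy_preds)))
```

Because binding residues are rare, accuracy alone can look excellent even when precision is poor, which is why MCC and F1 are the more informative numbers in the results below.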

Step 4: Define a Function to Evaluate in Batches

# Step 4: Evaluate in Batches
def evaluate_in_batches(sequences, labels, models, dataset_name="", voting='hard', batch_size=1000):
    # Number of batches, rounding up so that a final partial batch is included
    num_batches = len(sequences) // batch_size + int(len(sequences) % batch_size != 0)
    metrics_list = []
    
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = start_idx + batch_size
        batch_metrics = compute_metrics_for_batch(sequences[start_idx:end_idx], labels[start_idx:end_idx], models, voting)
        
        # Print metrics for the first five batches, prefixed with the dataset name if given
        if i < 5:
            prefix = f"{dataset_name} - " if dataset_name else ""
            print(f"{prefix}Batch {i+1}/{num_batches} metrics: {batch_metrics}")
        
        metrics_list.append(batch_metrics)
    
    # Average metrics over all batches (each batch is weighted equally)
    avg_metrics = {key: np.mean([metrics[key] for metrics in metrics_list]) for key in metrics_list[0]}
    return avg_metrics
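One design choice in the function above is worth spelling out: taking a plain mean of per-batch metrics weights every batch equally, even if the last batch is smaller. The sketch below illustrates the difference with made-up accuracy values; `iter_batches` and `weighted_mean` are hypothetical helpers, not part of the pipeline.

```python
def iter_batches(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def weighted_mean(values, weights):
    """Mean of values where each value counts proportionally to its weight."""
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


# 2500 dummy items split into batches of 1000
batch_sizes = [len(batch) for batch in iter_batches(list(range(2500)), 1000)]

# Hypothetical per-batch accuracies, for illustration only
accuracies = [0.9, 0.9, 0.6]
plain = sum(accuracies) / len(accuracies)
weighted = weighted_mean(accuracies, batch_sizes)
print(batch_sizes, round(plain, 4), round(weighted, 4))
```

For accuracy, weighting by the number of evaluated residues recovers the pooled value. For precision, F1, and MCC, an exact pooled value would require accumulating the confusion counts across batches rather than averaging the per-batch scores.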

Step 5: Define the Ensemble Model

# Step 5: Load the pre-trained base model and the fine-tuned LoRA models
accelerator = Accelerator()
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_paths = [
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1",
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp1",
    # Add more models or swap out for your own models
]
# Load a separate copy of the base model for each adapter. PEFT injects the LoRA
# layers into the base model in place, so wrapping one shared instance several
# times can leave the ensemble members sharing the same weights.
models = [
    PeftModel.from_pretrained(AutoModelForTokenClassification.from_pretrained(base_model_path), path)
    for path in lora_model_paths
]
models = [accelerator.prepare(model) for model in models]

Step 6: Ensemble Voting and Metric Calculation

# Step 6: Compute and print the metrics
train_metrics_hard = evaluate_in_batches(train_sequences, train_labels, models, "train", voting='hard')
test_metrics_hard = evaluate_in_batches(test_sequences, test_labels, models, "test", voting='hard')
train_metrics_soft = evaluate_in_batches(train_sequences, train_labels, models, "train", voting='soft')
test_metrics_soft = evaluate_in_batches(test_sequences, test_labels, models, "test", voting='soft')

train_metrics_hard, test_metrics_hard, train_metrics_soft, test_metrics_soft

This will then print something like the following:

train - Batch 1/451 metrics: {'accuracy': 0.9907783025067246, 'precision': 0.7792440817271516, 'recall': 0.9714265098491954, 'f1': 0.8647867420349434, 'auc': 0.9814053346312887, 'mcc': 0.8656769123429833}

train - Batch 2/451 metrics: {'accuracy': 0.9906862419735746, 'precision': 0.7686626071267478, 'recall': 0.9822046109510086, 'f1': 0.8624114372469636, 'auc': 0.9865753167670478, 'mcc': 0.8645747724704963}

train - Batch 3/451 metrics: {'accuracy': 0.9907034630406232, 'precision': 0.7662082514734774, 'recall': 0.9884141926140478, 'f1': 0.8632411067193676, 'auc': 0.9895938451445732, 'mcc': 0.8659743174909746}

train - Batch 4/451 metrics: {'accuracy': 0.991028787153535, 'precision': 0.7751964275620372, 'recall': 0.9881115354132142, 'f1': 0.8687994931897371, 'auc': 0.9896153675458282, 'mcc': 0.871052392709521}

train - Batch 5/451 metrics: {'accuracy': 0.9901174908557153, 'precision': 0.7585922916437905, 'recall': 0.9865762227775794, 'f1': 0.8576926658183058, 'auc': 0.988401969496207, 'mcc': 0.8605718730416185}

There will then be a long wait for the train batches to finish, and then the first five test batch metrics will be printed, which will look similar to the train metrics.

test - Batch 1/114 metrics: {'accuracy': 0.9410464672512716, 'precision': 0.37514282087088996, 'recall': 0.8439481350317016, 'f1': 0.5194051887787388, 'auc': 0.8944018149939027, 'mcc': 0.5392923907809524}

test - Batch 2/114 metrics: {'accuracy': 0.938214353140821, 'precision': 0.361414131305044, 'recall': 0.8304587788892721, 'f1': 0.5036435270736724, 'auc': 0.886450001724052, 'mcc': 0.5233747173742583}

test - Batch 3/114 metrics: {'accuracy': 0.9411384591024733, 'precision': 0.3683750578316969, 'recall': 0.8300225864365552, 'f1': 0.5102807398572268, 'auc': 0.8877119446522322, 'mcc': 0.5294666106367614}

test - Batch 4/114 metrics: {'accuracy': 0.9403683315585174, 'precision': 0.369614054572532, 'recall': 0.8394290300389818, 'f1': 0.5132402166102942, 'auc': 0.8918623875782199, 'mcc': 0.5334084101768152}

test - Batch 5/114 metrics: {'accuracy': 0.9400765476285562, 'precision': 0.37219051467245823, 'recall': 0.8356296422294041, 'f1': 0.514999563204333, 'auc': 0.8899200984461443, 'mcc': 0.5337721026971387}

This will repeat for the soft voting strategy as well. After another long wait for each of the train and test batches for the soft voting strategy, you should get the average of all of the batches printed for both the train and test metrics.
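To quantify the overfitting visible in the printed batches, one simple option is to compute the difference between train and test metrics. The sketch below uses the first train batch and the first test batch from the output above, rounded to four decimal places. A single batch is only a rough indicator, so treat this as a quick diagnostic rather than a proper evaluation.

```python
# Values copied (rounded) from "train - Batch 1/451" and "test - Batch 1/114" above
train_batch_1 = {"precision": 0.7792, "recall": 0.9714, "f1": 0.8648, "mcc": 0.8657}
test_batch_1 = {"precision": 0.3751, "recall": 0.8439, "f1": 0.5194, "mcc": 0.5393}

# Train minus test for each metric; larger gaps suggest stronger overfitting
gaps = {name: round(train_batch_1[name] - test_batch_1[name], 4) for name in train_batch_1}
worst = max(gaps, key=gaps.get)

for name, gap in gaps.items():
    print(f"{name}: {gap}")
print("largest gap:", worst)
```

Recall holds up reasonably well on the test set, while precision drops the most, meaning the ensemble predicts many binding residues that are not labeled as binding sites in the test data.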

Inference

Lastly, we can run inference on a protein of interest as in the code below. This can be run independently of the rest of the code in this post and should only take a few seconds.

from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, Trainer
from datasets import Dataset
from peft import PeftModel
import numpy as np
from scipy import stats

# ESM-2 base model
base_model_path = "facebook/esm2_t12_35M_UR50D"

# Paths to the saved LoRA models
lora_model_paths = [
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3",
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1",
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp1",
    # add paths to other models
]

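# Load a separate copy of the base model for each LoRA adapter. PEFT injects the
# LoRA layers into the base model in place, so wrapping one shared instance
# several times can leave the ensemble members sharing the same weights.
models = [
    PeftModel.from_pretrained(AutoModelForTokenClassification.from_pretrained(base_model_path), path)
    for path in lora_model_paths
]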

# Define the new protein sequence
new_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"

# Step 1 and 2: Tokenization and Dataset creation
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
tokenized_inputs = tokenizer(new_sequence, return_tensors="pt", truncation=True, padding=True, is_split_into_words=False)
new_dataset = Dataset.from_dict({k: v for k, v in tokenized_inputs.items()})

# Step 3: Create trainer objects for each model in the ensemble
data_collator = DataCollatorForTokenClassification(tokenizer)
trainers = [Trainer(model=model, data_collator=data_collator) for model in models]

# Step 4: Getting predictions from each model and applying voting strategies
all_predictions = [trainer.predict(test_dataset=new_dataset)[0] for trainer in trainers]

# Hard voting
hard_predictions = [np.argmax(predictions, axis=2) for predictions in all_predictions]
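# keepdims=False (SciPy >= 1.9) gives a (1, seq_len) array regardless of SciPy version
ensemble_predictions_hard = stats.mode(np.stack(hard_predictions), axis=0, keepdims=False).mode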

# Soft voting
avg_predictions = np.mean(all_predictions, axis=0)
ensemble_predictions_soft = np.argmax(avg_predictions, axis=2)

# Print the final predictions obtained using hard and soft voting
print("Hard voting predictions:", ensemble_predictions_hard)
print("Soft voting predictions:", ensemble_predictions_soft)

This will print something like the following:

Hard voting predictions: [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 1 1 0 0 1 1 1 1 1 1 1 1 0 1 0 0 0 0 0 0 0 0 0 0]
Soft voting predictions: [[0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 1 1 0 0 1 1 1 1 1 1 1 1 0 1 0 0 0 0 0 0 0 0 0 0]]

Here, 1 represents the binding sites predicted by the ensemble model, and 0 represents non-binding sites predicted by the ensemble model. Next, to get something more amenable for designing a binding partner for your protein, run the following code:

# Convert token IDs back to amino acid residues
residues = tokenizer.convert_ids_to_tokens(tokenized_inputs["input_ids"][0])

# Print the amino acid residues and their positions for binding sites using hard voting
binding_sites_hard = [(idx, residue) for idx, (label, residue) in enumerate(zip(ensemble_predictions_hard[0], residues)) if label == 1]
print("Binding sites (Hard voting):")
for position, residue in binding_sites_hard:
    print(f"{residue}{position}")

# Print the amino acid residues and their positions for binding sites using soft voting
binding_sites_soft = [(idx, residue) for idx, (label, residue) in enumerate(zip(ensemble_predictions_soft[0], residues)) if label == 1]
print("\nBinding sites (Soft voting):")
for position, residue in binding_sites_soft:
    print(f"{residue}{position}")

This will print something like the following:

Binding sites (Hard voting):
P8
N9
H10
I12
Y13
I14
N15
N16
L17
N18
E19
K20
K22
F34
G38
L41
L44
V45
S46
R47
S48
L49
K50
M51
R52
G53
Q54
A55
F59
Q73
G74
Y78
D79
K80
P81
M82
I84
Q85
Y86
A87
K88
T89
D90

Binding sites (Soft voting):
P8
N9
H10
I12
Y13
I14
N15
N16
L17
N18
E19
K20
K22
F34
G38
L41
L44
V45
S46
R47
S48
L49
K50
M51
R52
G53
Q54
A55
F59
Q73
G74
Y78
D79
K80
P81
M82
I84
Q85
Y86
A87
K88
T89
D90
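A long list of single residues can be easier to reason about when grouped into contiguous stretches. The sketch below is a small helper (the name `contiguous_runs` is hypothetical) that collapses sorted positions into inclusive (start, end) runs. The example input is the first part of the predicted positions printed above.

```python
def contiguous_runs(positions):
    """Group integer positions into inclusive (start, end) runs."""
    runs = []
    for pos in sorted(positions):
        if runs and pos == runs[-1][1] + 1:
            runs[-1] = (runs[-1][0], pos)  # extend the current run
        else:
            runs.append((pos, pos))        # start a new run
    return runs


# First predicted positions from the output above (P8 ... K22)
positions = [8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22]
print(contiguous_runs(positions))
```

Longer runs may be more convincing candidates for a binding interface than isolated residues, although that is a heuristic and not something the model guarantees.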

Designing a Binder for your Protein with RFDiffusion

RFDiffusion is a diffusion model that generates 3D protein structures. It is conceptually similar to image diffusion models like Stable Diffusion and DALL-E, but for proteins. Its architecture is different from that of Stable Diffusion: it uses RoseTTAFold as the backbone model, as opposed to the UNet used in Stable Diffusion.

Once you have your binding site predictions, you should head over to the RFDiffusion Notebook and design a binder for your protein using some subset of the binding sites predicted by the model as "hotspots" for the binder. You'll need a PDB file for your protein first. To get one, head over to the ESMFold tool at the ESM Metagenomic Atlas website. Select "Fold Sequence", and paste in your protein sequence to fold it and press enter. Once your protein is folded you should get a 3D structure:

[Image: the folded 3D structure of the protein]

You can now download your PDB file. Once you have it, upload it to the RFDiffusion Google Colab notebook and use the path to your uploaded PDB file when designing a binder for your protein, with the following settings:

%%time
#@title run **RFdiffusion** to generate a backbone
name = "test" #@param {type:"string"}
contigs = "100" #@param {type:"string"}
pdb = "/content/unnamed.pdb" #@param {type:"string"}
iterations = 50 #@param ["25", "50", "100", "150", "200"] {type:"raw"}
hotspot = "A41,A44,A45,A46" #@param {type:"string"}
num_designs = 1 #@param ["1", "2", "4", "8", "16", "32"] {type:"raw"}
visual = "interactive" #@param ["none", "image", "interactive"]
#@markdown ---
#@markdown **symmetry** settings
#@markdown ---
symmetry = "cyclic" #@param ["none", "auto", "cyclic", "dihedral"]
order = 3 #@param ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"] {type:"raw"}
chains = "" #@param {type:"string"}
add_potential = True #@param {type:"boolean"}
#@markdown - `symmetry='auto'` enables automatic symmetry detection with [AnAnaS](https://team.inria.fr/nano-d/software/ananas/).
#@markdown - `chains="A,B"` filter PDB input to these chains (may help auto-symm detector)
#@markdown - `add_potential` to discourage clashes between chains

# determine where to save
path = name
while os.path.exists(f"outputs/{path}_0.pdb"):
  path = name + "_" + ''.join(random.choices(string.ascii_lowercase + string.digits, k=5))

flags = {"contigs":contigs,
         "pdb":pdb,
         "order":order,
         "iterations":iterations,
         "symmetry":symmetry,
         "hotspot":hotspot,
         "path":path,
         "chains":chains,
         "add_potential":add_potential,
         "num_designs":num_designs,
         "visual":visual}

for k,v in flags.items():
  if isinstance(v,str):
    flags[k] = v.replace("'","").replace('"','')

contigs, copies = run_diffusion(**flags)
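The `hotspot` value used in the settings above is a comma-separated string of chain ID plus residue number. A small helper can build that string from the predicted positions. This sketch assumes that your protein is chain "A" in the PDB file and that the PDB residue numbering matches the 1-based positions printed in the inference section; both are assumptions you should verify against your own file. The function name `format_hotspots` is hypothetical.

```python
def format_hotspots(positions, chain="A"):
    """Build a hotspot string such as 'A41,A44' from residue positions."""
    return ",".join(f"{chain}{pos}" for pos in sorted(set(positions)))


# A subset of the predicted binding sites (L41, L44, V45, S46)
print(format_hotspots([41, 44, 45, 46]))
```

Choosing a small, spatially close subset of the predictions tends to be more practical than passing every predicted residue as a hotspot.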

You'll get a cyclic protein like the following:

[Image: the cyclic binder generated by RFDiffusion]

You can run the rest of the RFDiffusion Colab notebook to get a sequence that folds to the structure you've generated and to validate it. That's it! You've successfully designed a protein that is predicted to bind to your protein of interest along the "hotspots", that is, along the sites of interest given by selecting a subset of the binding sites predicted by an ESMBind model or ensemble of models. Be sure to give the RFDiffusion paper linked to on the RFDiffusion GitHub a read, and send the RFDiffusion people some love by giving their GitHub repository a star. They've built an amazing protein diffusion model! You can also interact with more protein-related models, including RFDiffusion, on Neurosnap.