Predicting the Effects of Mutations on Protein Function with ESM-2

Community Article · Published December 13, 2023

In the article Language models enable zero-shot prediction of the effects of mutations on protein function, the authors introduce several scoring functions for estimating the effects of mutations on protein sequences. Here, we will re-implement these scoring methods using the Hugging Face Transformers port of the protein language model ESM-2. We will also discuss how to use each one and how to interpret the results.


Introduction

Mutations of protein sequences can be quite complex, and their effects can range from detrimental to function, to neutral and inconsequential, to improving function. It has been shown that even a single point mutation, or a small number of mutations, can cause drastic conformational changes, resulting in "fold-switching" and changes to the 3D structure of the folded protein. Judging the effects of mutations is difficult, but protein language models like the ESM-2 family can provide a lot of information about how mutations affect the fold and function of proteins.

In particular, in Language models enable zero-shot prediction of the effects of mutations on protein function, the authors introduce several scoring functions whose scores correlate strongly with measured effects on function. The first of these is the masked marginal scoring function:

$$\sum_{i \in M} \log p(x_i = x_i^{mt} \mid x_{-M}) - \log p(x_i = x_i^{wt} \mid x_{-M})$$

where $M$ is the set of masked positions at which mutations occur, $x_i^{mt}$ is the mutant-type residue at position $i$, and $x_i^{wt}$ is the wild-type residue at position $i$. This scoring function was shown to perform best.
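To make this concrete, here is a minimal sketch of the masked marginal score using the Hugging Face Transformers API. The function name, the small facebook/esm2_t6_8M_UR50D checkpoint, and the 1-indexed mutation format are our own choices for illustration, not part of the paper's code:

from transformers import AutoTokenizer, EsmForMaskedLM
import torch

def masked_marginal_score(sequence, mutations, model_name="facebook/esm2_t6_8M_UR50D"):
    # Score a set of mutations as sum_i [log p(mt_i) - log p(wt_i)] with all
    # mutated positions masked in a single forward pass
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EsmForMaskedLM.from_pretrained(model_name)
    input_ids = tokenizer.encode(sequence, return_tensors="pt")

    # Parse "AiB"-style strings, e.g. "I4A"; positions are 1-indexed here
    parsed = [(m[0], int(m[1:-1]), m[-1]) for m in mutations]

    # Mask every mutated position (the x_{-M} conditioning in the formula)
    masked_ids = input_ids.clone()
    for wt, pos, mt in parsed:
        assert sequence[pos - 1] == wt, f"wild type mismatch at position {pos}"
        masked_ids[0, pos] = tokenizer.mask_token_id  # token index pos accounts for the BOS token

    with torch.no_grad():
        log_probs = torch.log_softmax(model(masked_ids).logits, dim=-1)

    # Sum the log-probability differences over the mutated positions
    score = 0.0
    for wt, pos, mt in parsed:
        mt_id = tokenizer.convert_tokens_to_ids(mt)
        wt_id = tokenizer.convert_tokens_to_ids(wt)
        score += (log_probs[0, pos, mt_id] - log_probs[0, pos, wt_id]).item()
    return score

# Example: score two simultaneous mutations on a short peptide
print(masked_marginal_score("MKTIIALSYIFCLVFA", ["I4A", "A6M"]))

The script in the Scoring Deep Mutational Scans section below implements this same strategy (along with wt-marginal and pseudo-perplexity scoring) on top of the ESM codebase.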

Log-likelihood Ratios and Point Mutations

We can also understand the effects of mutations using the log-likelihood ratio (LLR) for each single point mutation and represent the results in a heatmap, which shows us hotspots for mutations that are beneficial or detrimental to the function of the protein. This is exemplified in the Hugging Face Space ESM Variants, where the LLR is computed for all point mutations of human proteins. For general proteins, you can try out the Hugging Face Space Variant Effects LLR.

from transformers import AutoTokenizer, EsmForMaskedLM
import torch
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display

def generate_heatmap(protein_sequence, start_pos=1, end_pos=None):
    # Load the model and tokenizer
    model_name = "facebook/esm2_t6_8M_UR50D"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EsmForMaskedLM.from_pretrained(model_name)

    # Tokenize the input sequence
    input_ids = tokenizer.encode(protein_sequence, return_tensors="pt")
    sequence_length = input_ids.shape[1] - 2  # Excluding the special tokens

    # Adjust end position if not specified
    if end_pos is None:
        end_pos = sequence_length

    # List of amino acids
    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")

    # Initialize heatmap
    heatmap = np.zeros((20, end_pos - start_pos + 1))

    # Calculate LLRs for each position and amino acid
    for position in range(start_pos, end_pos + 1):
        # Mask the target position
        masked_input_ids = input_ids.clone()
        masked_input_ids[0, position] = tokenizer.mask_token_id
        
        # Get logits for the masked token
        with torch.no_grad():
            logits = model(masked_input_ids).logits
            
        # Compute log probabilities over the vocabulary at the masked position
        log_probabilities = torch.nn.functional.log_softmax(logits[0, position], dim=0)
        
        # Get the log probability of the wild-type residue
        wt_residue = input_ids[0, position].item()
        log_prob_wt = log_probabilities[wt_residue].item()
        
        # Calculate LLR for each variant
        for i, amino_acid in enumerate(amino_acids):
            log_prob_mt = log_probabilities[tokenizer.convert_tokens_to_ids(amino_acid)].item()
            heatmap[i, position - start_pos] = log_prob_mt - log_prob_wt

    # Visualize the heatmap
    plt.figure(figsize=(15, 5))
    plt.imshow(heatmap, cmap="viridis", aspect="auto")
    plt.xticks(range(end_pos - start_pos + 1), list(protein_sequence[start_pos-1:end_pos]))
    plt.yticks(range(20), amino_acids)
    plt.xlabel("Position in Protein Sequence")
    plt.ylabel("Amino Acid Mutations")
    plt.title("Predicted Effects of Mutations on Protein Sequence (LLR)")
    plt.colorbar(label="Log Likelihood Ratio (LLR)")
    plt.show()

def interactive_heatmap(protein_sequence):
    # Define interactive widgets
    start_slider = widgets.IntSlider(value=1, min=1, max=len(protein_sequence), step=1, description='Start:')
    end_slider = widgets.IntSlider(value=len(protein_sequence), min=1, max=len(protein_sequence), step=1, description='End:')

    ui = widgets.HBox([start_slider, end_slider])

    def update_heatmap(start, end):
        if start <= end:
            generate_heatmap(protein_sequence, start, end)

    out = widgets.interactive_output(update_heatmap, {'start': start_slider, 'end': end_slider})

    # Display the interactive widgets
    display(ui, out)

Below we see an example of how to use this:

# Example usage:
protein_sequence = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
interactive_heatmap(protein_sequence)

This renders a heatmap with adjustable left and right endpoints, in case you want to zoom in on a particular range of positions in the protein sequence. Below, we see the range of positions 40-70:

(Figure: LLR heatmap for positions 40-70 of the wild-type sequence.)

Notice the dark blue regions, where the LLR values are negative, indicating mutations that are likely detrimental to function, and the lighter yellow regions, where the LLR values are positive, indicating mutations that are likely beneficial to the function of the protein. Dark bands running vertically mark positions that are likely evolutionarily conserved, while brighter vertical bands mark positions where many substitutions may in fact be preferable to the wild-type residue. Similarly, dark bands running horizontally along most of the protein indicate amino acids that are likely detrimental at nearly every position, and brighter horizontal bands indicate amino acids that would likely be preferable to the wild type at almost any position. Once we apply one of these mutations, we get a different heatmap for the mutant protein. For example, mutating the D at residue 57 to an L changes the heatmap; visualizing residues 40-70 again, we now see the following:

(Figure: LLR heatmap for positions 40-70 after mutating the D at residue 57 to L.)
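To reproduce this, apply the mutation to the sequence from the usage example above (residue 57 is index 56 with 0-based indexing) and rerun the interactive heatmap:

mutant_sequence = protein_sequence[:56] + "L" + protein_sequence[57:]
interactive_heatmap(mutant_sequence)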

Below, we see a figure from the paper showing how LLR heatmaps can suggest beneficial vs. deleterious mutations for protein function. There, red represents lower LLR values and blue represents higher LLR values, so the color scale is roughly inverted relative to the heatmaps above, where dark blue represents lower LLR values and yellow represents higher ones.

(Figure from the paper: an LLR heatmap with red for lower and blue for higher LLR values.)

Scoring Deep Mutational Scans

We can modify the script found in the ESM GitHub repository (see the predict.py file) for zero-shot scoring of variant effects. In that script, ESM-1v is used, but we will use the newer ESM-2 family of models. Below is a script for scoring the effects of mutations with three different scoring methods: pseudo-perplexity (PPPL), wild-type marginal (wt-marginal), and masked marginal. The script requires a CSV file with a column indicating which mutations to apply to the wild-type protein sequence, and it writes a CSV file of scores predicted by the model once the user has chosen a model and scoring method:

import argparse
import itertools
import pathlib
import string
from typing import List, Tuple

import torch
import pandas as pd
from tqdm import tqdm
from Bio import SeqIO
from esm import pretrained, MSATransformer

def remove_insertions(sequence: str) -> str:
    deletekeys = dict.fromkeys(string.ascii_lowercase)
    deletekeys["."] = None
    deletekeys["*"] = None
    translation = str.maketrans(deletekeys)
    return sequence.translate(translation)
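
# Helper from the upstream ESM predict.py; required by main() when scoring
# with an MSA Transformer
def read_msa(filename: pathlib.Path, nseq: int) -> List[Tuple[str, str]]:
    # Read the first nseq sequences from an MSA file, removing insertions
    return [
        (record.description, remove_insertions(str(record.seq)))
        for record in itertools.islice(SeqIO.parse(str(filename), "fasta"), nseq)
    ]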

def create_parser():
    parser = argparse.ArgumentParser(description="Label a deep mutational scan with predictions from an ESM model.")
    parser.add_argument("--model-location", type=str, help="PyTorch model file OR name of pretrained model to download", nargs="+")
    parser.add_argument("--sequence", type=str, help="Base sequence to which mutations were applied")
    parser.add_argument("--dms-input", type=pathlib.Path, help="CSV file containing the deep mutational scan")
    parser.add_argument("--mutation-col", type=str, default="mutant", help="column in the deep mutational scan labeling the mutation as 'AiB'")
    parser.add_argument("--dms-output", type=pathlib.Path, help="Output file containing the deep mutational scan along with predictions")
    parser.add_argument("--offset-idx", type=int, default=0, help="Offset of the mutation positions in `--mutation-col`")
    parser.add_argument("--scoring-strategy", type=str, default="wt-marginals", choices=["wt-marginals", "pseudo-ppl", "masked-marginals"], help="")
    parser.add_argument("--msa-path", type=pathlib.Path, help="path to MSA in a3m format (required for MSA Transformer)")
    parser.add_argument("--msa-samples", type=int, default=400, help="number of sequences to select from the start of the MSA")
    parser.add_argument("--nogpu", action="store_true", help="Do not use GPU even if available")
    return parser

def label_row(row, sequence, token_probs, alphabet, offset_idx):
    wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]
    assert sequence[idx] == wt, "The listed wildtype does not match the provided sequence"
    wt_encoded, mt_encoded = alphabet.get_idx(wt), alphabet.get_idx(mt)
    score = token_probs[0, 1 + idx, mt_encoded] - token_probs[0, 1 + idx, wt_encoded]
    return score.item()

def compute_pppl(row, sequence, model, alphabet, offset_idx):
    wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]
    assert sequence[idx] == wt, "The listed wildtype does not match the provided sequence"
    # Apply the mutation, then sum the masked log-probability at each position
    sequence = sequence[:idx] + mt + sequence[(idx + 1):]
    data = [("protein1", sequence)]
    batch_converter = alphabet.get_batch_converter()
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    # Run on the same device as the model (so --nogpu works)
    device = next(model.parameters()).device
    batch_tokens = batch_tokens.to(device)
    log_probs = []
    for i in range(1, len(sequence) - 1):
        batch_tokens_masked = batch_tokens.clone()
        batch_tokens_masked[0, i] = alphabet.mask_idx
        with torch.no_grad():
            token_probs = torch.log_softmax(model(batch_tokens_masked)["logits"], dim=-1)
        log_probs.append(token_probs[0, i, alphabet.get_idx(sequence[i])].item())
    return sum(log_probs)

def main(args):
    df = pd.read_csv(args.dms_input)
    device = torch.device("cuda" if torch.cuda.is_available() and not args.nogpu else "cpu")

    for model_location in args.model_location:
        model, alphabet = pretrained.load_model_and_alphabet(model_location)
        model = model.to(device)
        model.eval()
        batch_converter = alphabet.get_batch_converter()

        if isinstance(model, MSATransformer):
            data = [read_msa(args.msa_path, args.msa_samples)]
            assert args.scoring_strategy == "masked-marginals", "MSA Transformer only supports masked marginal strategy"
            batch_labels, batch_strs, batch_tokens = batch_converter(data)
            batch_tokens = batch_tokens.to(device)

            all_token_probs = []
            for i in tqdm(range(batch_tokens.size(2))):
                batch_tokens_masked = batch_tokens.clone()
                batch_tokens_masked[0, 0, i] = alphabet.mask_idx
                with torch.no_grad():
                    token_probs = torch.log_softmax(model(batch_tokens_masked)["logits"], dim=-1)
                all_token_probs.append(token_probs[:, 0, i])
            token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
            df[model_location] = df.apply(
                lambda row: label_row(row[args.mutation_col], args.sequence, token_probs, alphabet, args.offset_idx),
                axis=1,
            )
        else:
            data = [("protein1", args.sequence)]
            batch_labels, batch_strs, batch_tokens = batch_converter(data)
            batch_tokens = batch_tokens.to(device)

            if args.scoring_strategy == "wt-marginals":
                with torch.no_grad():
                    token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1)
                df[model_location] = df.apply(
                    lambda row: label_row(row[args.mutation_col], args.sequence, token_probs, alphabet, args.offset_idx),
                    axis=1,
                )
            elif args.scoring_strategy == "masked-marginals":
                all_token_probs = []
                for i in tqdm(range(batch_tokens.size(1))):
                    batch_tokens_masked = batch_tokens.clone()
                    batch_tokens_masked[0, i] = alphabet.mask_idx
                    with torch.no_grad():
                        token_probs = torch.log_softmax(model(batch_tokens_masked)["logits"], dim=-1)
                    all_token_probs.append(token_probs[:, i])
                token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
                df[model_location] = df.apply(
                    lambda row: label_row(row[args.mutation_col], args.sequence, token_probs, alphabet, args.offset_idx),
                    axis=1,
                )
            elif args.scoring_strategy == "pseudo-ppl":
                tqdm.pandas()
                df[model_location] = df.progress_apply(
                    lambda row: compute_pppl(row[args.mutation_col], args.sequence, model, alphabet, args.offset_idx),
                    axis=1,
                )

    df.to_csv(args.dms_output)

if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    main(args)
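For reference, while the masked-marginals strategy masks each position before scoring it, the wt-marginals strategy needs only a single forward pass over the unmasked wild-type sequence. Following the paper, it scores a set of mutated positions $M$ as

$$\sum_{i \in M} \log p(x_i = x_i^{mt} \mid x^{wt}) - \log p(x_i = x_i^{wt} \mid x^{wt})$$

which is faster, at the cost of not masking out the mutated positions when conditioning.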

Example Usage
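The script expects a CSV file with a column of mutation strings in "AiB" format (named mutant by default, matching --mutation-col). A minimal mutations.csv consistent with the outputs shown below might look like:

mutant
T2B
I3A
A5M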

In your terminal, you can run the script as follows:

python scoring_esm2.py \
    --model-location esm2_t12_35M_UR50D \
    --sequence "MKTIIALSYIFCLVFA" \
    --dms-input "mutations.csv" \
    --mutation-col "mutant" \
    --dms-output "output_2.csv" \
    --offset-idx 0 \
    --scoring-strategy "masked-marginals" \
    --nogpu

Adjust the model, sequence, mutations file, and scoring strategy to suit your needs, but remember that the masked marginal scoring strategy was shown to perform best. Note that with --offset-idx 0, the mutation positions are 0-indexed. This should create an output_2.csv file (as specified by --dms-output) that looks like the following:

,mutant,esm2_t12_35M_UR50D
0,T2B,-10.990091323852539
1,I3A,-0.5448870658874512
2,A5M,-0.8617167472839355

We can see that some of these mutations are individually more detrimental than others: the $T \to B$ mutation at the third residue (the second with our 0-based indexing) has a much lower score than the other two mutations, which are closer to zero, indicating they are more neutral. Using the LLR values, we can also select point mutations that may provide improved function. When using the wild-type marginal scoring strategy, we get the following:

,mutant,esm2_t12_35M_UR50D
0,T2B,-13.739639282226562
1,I3A,-3.976250171661377
2,A5M,-4.413556098937988

As we can see, the wt-marginal and masked marginal scoring strategies both indicate that the effects of these mutations are deleterious to the function of the protein, with the first mutation much more detrimental than the second and third. Running the script on the protein from the LLR heatmap section above,

MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE

for the mutation D56L (0-indexed, corresponding to the D at residue 57 in the heatmaps above), we get the following:

,mutant,facebook/esm2_t12_35M_UR50D
0,D56L,1.3842720985412598

This agrees with the LLR prediction and suggests that the mutation is likely beneficial to the functioning of the protein, i.e., that the mutant is likely more fit than the wild type. This kind of scoring can be used to determine the directionality of protein evolution, providing a kind of vector field or flow description of evolution. This is done, for example, in Evolocity, which finds patterns in the evolution of proteins. Evolocity was introduced in Evolutionary velocity with protein language models predicts evolutionary dynamics of diverse proteins and uses the slightly older ESM-1b protein language model, but these methods could be adapted to ESM-2 or other protein language models.

"The key conceptual advance is that by learning the rules underlying local evolution, we can construct a global evolutionary “vector field” that we show can (1) predict the root (or potentially multiple roots) of observed evolutionary trajectories, (2) order protein sequences in evolutionary time, and (3) identify the mutational strategies that drive these trajectories."

It would be interesting to see whether using the masked marginal score instead of the pseudolikelihood scores provides significantly different or improved results. We should also note that the implementation of PPPL in the script above does not exactly match the definition suggested in section 2.3 of Masked Language Model Scoring. There, the authors define PPPL as:

$$PPPL(T) = \exp\left(- \frac{1}{N} \sum_{t \in T} PLL(t)\right)$$

where $PLL(t)$ denotes the pseudo-log-likelihood for token $t \in T$ and $N$ is the number of tokens. Thus, we would typically see something more like the following as an implementation of PPPL:

from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

def calculate_pppl(model, tokenizer, sequence):
    token_ids = tokenizer.encode(sequence, return_tensors='pt')
    input_length = token_ids.size(1)
    log_likelihood = 0.0

    # Loop over every token position (note: this includes the BOS/EOS special tokens)
    for i in range(input_length):
        # Create a copy of the token IDs
        masked_token_ids = token_ids.clone()
        # Mask a token that we will try to predict back
        masked_token_ids[0, i] = tokenizer.mask_token_id

        with torch.no_grad():
            output = model(masked_token_ids)
            logit_prob = torch.nn.functional.log_softmax(output.logits, dim=-1)
        
        log_likelihood += logit_prob[0, i, token_ids[0, i]]

    # Calculate the average log likelihood per token
    avg_log_likelihood = log_likelihood / input_length

    # Compute and return the pseudo-perplexity
    pppl = torch.exp(-avg_log_likelihood)
    return pppl.item()

# Load the model and tokenizer
model_name = "facebook/esm2_t12_35M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# Protein sequence
protein_sequence = "MKTIIALSYIFCLVFA"

# Calculate PPPL
pppl = calculate_pppl(model, tokenizer, protein_sequence)
print(f"Pseudo-Perplexity of the sequence: {pppl}")

which would return a single value for the entire sequence:

Pseudo-Perplexity of the sequence: 9.073078155517578

Conclusion

In this article, we explored the capabilities of the ESM-2 models, particularly in predicting the effects of mutations on protein functions. Utilizing advanced language models in bioinformatics, we implemented several scoring methods including the masked marginal scoring function, pseudo-perplexity (PPPL), and wild-type marginal (wt-marginal) scoring. These methods allow for a deeper understanding of how mutations can affect protein structure and function, providing valuable insights for research in protein engineering and disease analysis.

The masked marginal scoring function in particular stands out due to its strong correlation with functional effects. By calculating the difference in log probabilities between mutant and wild-type residues, it provides a quantitative measure of the impact of mutations, and summing the scores of individual mutations gives an overall estimate of their collective effect, offering a convenient way to assess multiple mutations simultaneously.

Our script facilitates easy integration and application of these scoring methods. By inputting a protein sequence and a list of desired mutations, users can swiftly obtain scores indicating potential functional changes. The script's flexibility in choosing between different ESM models and scoring strategies allows for tailored analysis suited to specific research needs. Furthermore, the interactive heatmap visualization, implemented using ipywidgets, provides an intuitive graphical representation of the mutational landscape. By highlighting regions of potential functional significance, researchers can quickly identify key areas for further investigation.

In summary, the tools and methods discussed in this article provide a powerful approach for assessing the impact of protein mutations. They offer a blend of computational efficiency and biological insight, which is invaluable in the rapidly evolving field of bioinformatics. As these models continue to improve, we can expect even more accurate and insightful predictions, further advancing our understanding of protein dynamics and disease mechanisms.