Predicting Protein-Protein Interactions Using a Protein Language Model and Linear Sum Assignment

Community Article Published October 15, 2023

TLDR: This blog post is about using ESM-2, a protein language model, to score pairs of proteins using masked language modeling loss, in order to predict pairs of proteins that have a high likelihood of binding to one another.


In the paper Pairing interacting protein sequences using masked language modeling, the authors propose a method that uses either of two protein language models, MSA Transformer or ESM-2, to predict how likely it is that a pair of proteins bind to one another. In this post, we will focus on ESM-2, as it is faster and easier to use than MSA Transformer, which requires Multiple Sequence Alignments (MSAs). The method is very simple. We take a list of proteins we would like to test for interactions, then concatenate them in pairs. For each pair, we use the masked language modeling capabilities of ESM-2: we randomly mask residues, then compute the MLM loss on the masked positions. Averaging this loss over several random maskings gives a score for each pair, where a lower score indicates that the two proteins are more likely to bind to one another. We then collect these scores in a matrix. Using this matrix, we can solve the associated linear assignment problem to compute optimal binding partners.

Introduction

Predicting protein-protein interactions is a critical task in molecular biology. Here, we'll use the ESM-2 model from Meta AI to compute Masked Language Model (MLM) loss for protein pairs, aiming to find the pairs with the lowest loss. The rationale is that proteins that interact in reality will produce a lower MLM loss than those that don't.
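One way to write the score down explicitly is given below. The notation is introduced here for illustration and is not taken from the paper: x^{AB} is the concatenation of proteins A and B, M_k is the random set of masked positions in the k-th masking round, K is the number of rounds, and p_θ is the model's predicted distribution over residues.

```latex
\[
S(A, B) \;=\; \frac{1}{K} \sum_{k=1}^{K} \frac{1}{|M_k|} \sum_{i \in M_k}
  -\log p_\theta\!\left( x^{AB}_i \,\middle|\, x^{AB}_{\setminus M_k} \right)
\]
```

Here x^{AB}_{\setminus M_k} denotes the concatenated sequence with the positions in M_k replaced by the mask token. A lower S(A, B) means the model finds the masked residues of the pair easier to predict, which this method takes as a sign that A and B are more likely to interact.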

1. Importing Necessary Libraries

import numpy as np
from scipy.optimize import linear_sum_assignment
from transformers import AutoTokenizer, EsmForMaskedLM
import torch
  • numpy: Used for numerical operations, especially for handling matrices.
  • linear_sum_assignment: A SciPy function that solves the linear sum assignment problem. We'll use this to find the optimal pairing of proteins based on the MLM loss.
  • transformers: Hugging Face's library of pre-trained models, originally built for NLP tasks. We're using it to load the ESM-2 model and its tokenizer.
  • torch: The PyTorch library, on which the transformers library is built.

2. Initializing the Model and Tokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
  • tokenizer: This is used to convert protein sequences into a format suitable for the model.
  • model: This is the ESM-2 model, specifically built for protein sequences.

3. Setting up the GPU

Using a GPU can greatly speed up computations. The following code checks if a GPU is available and sets the model to run on it.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

4. Defining the Protein Sequences

For this tutorial, we've chosen a set of protein sequences for our analysis:

all_proteins = [
    "MEESQSELNIDPPLSQETFSELWNLLPENNVLSSELCPAVDELLLPESVVNWLDEDSDDAPRMPATSAPTAPGPAPSWPLSSSVPSPKTYPGTYGFRLGFLHSGTAKSVTWTYSPLLNKLFCQLAKTCPVQLWVSSPPPPNTCVRAMAIYKKSEFVTEVVRRCPHHERCSDSSDGLAPPQHLIRVEGNLRAKYLDDRNTFRHSVVVPYEPPEVGSDYTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNVLGRNSFEVRVCACPGRDRRTEEENFHKKGEPCPEPPPGSTKRALPPSTSSSPPQKKKPLDGEYFTLQIRGRERYEMFRNLNEALELKDAQSGKEPGGSRAHSSHLKAKKGQSTSRHKKLMFKREGLDSD",
    "MCNTNMSVPTDGAVTTSQIPASEQETLVRPKPLLLKLLKSVGAQKDTYTMKEVLFYLGQYIMTKRLYDEKQQHIVYCSNDLLGDLFGVPSFSVKEHRKIYTMIYRNLVVVNQQESSDSGTSVSENRCHLEGGSDQKDLVQELQEEKPSSSHLVSRPSTSSRRRAISETEENSDELSGERQRKRHKSDSISLSFDESLALCVIREICCERSSSSESTGTPSNPDLDAGVSEHSGDWLDQDSVSDQFSVEFEVESLDSEDYSLSEEGQELSDEDDEVYQVTVYQAGESDTDSFEEDPEISLADYWKCTSCNEMNPPLPSHCNRCWALRENWLPEDKGKDKGEISEKAKLENSTQAEEGFDVPDCKKTIVNDSRESCVEENDDKITQASQSQESEDYSQPSTSSSIIYSSQEDVKEFEREETQDKEESVESSLPLNAIEPCVICQGRPKNGCIVHGKTGHLMACFTCAKKLKKRNKPCPVCRQPIQMIVLTYFP",
    "MNRGVPFRHLLLVLQLALLPAATQGKKVVLGKKGDTVELTCTASQKKSIQFHWKNSNQIKILGNQGSFLTKGPSKLNDRADSRRSLWDQGNFPLIIKNLKIEDSDTYICEVEDQKEEVQLLVFGLTANSDTHLLQGQSLTLTLESPPGSSPSVQCRSPRGKNIQGGKTLSVSQLELQDSGTWTCTVLQNQKKVEFKIDIVVLAFQKASSIVYKKEGEQVEFSFPLAFTVEKLTGSGELWWQAERASSSKSWITFDLKNKEVSVKRVTQDPKLQMGKKLPLHLTLPQALPQYAGSGNLTLALEAKTGKLHQEVNLVVMRATQLQKNLTCEVWGPTSPKLMLSLKLENKEAKVSKREKAVWVLNPEAGMWQCLLSDSGQVLLESNIKVLPTWSTPVQPMALIVLGGVAGLLLFIGLGIFFCVRCRHRRRQAERMSQIKRLLSEKKTCQCPHRFQKTCSPI",
    "MRVKEKYQHLWRWGWKWGTMLLGILMICSATEKLWVTVYYGVPVWKEATTTLFCASDAKAYDTEVHNVWATHACVPTDPNPQEVVLVNVTENFNMWKNDMVEQMHEDIISLWDQSLKPCVKLTPLCVSLKCTDLGNATNTNSSNTNSSSGEMMMEKGEIKNCSFNISTSIRGKVQKEYAFFYKLDIIPIDNDTTSYTLTSCNTSVITQACPKVSFEPIPIHYCAPAGFAILKCNNKTFNGTGPCTNVSTVQCTHGIRPVVSTQLLLNGSLAEEEVVIRSANFTDNAKTIIVQLNQSVEINCTRPNNNTRKSIRIQRGPGRAFVTIGKIGNMRQAHCNISRAKWNATLKQIASKLREQFGNNKTIIFKQSSGGDPEIVTHSFNCGGEFFYCNSTQLFNSTWFNSTWSTEGSNNTEGSDTITLPCRIKQFINMWQEVGKAMYAPPISGQIRCSSNITGLLLTRDGGNNNNGSEIFRPGGGDMRDNWRSELYKYKVVKIEPLGVAPTKAKRRVVQREKRAVGIGALFLGFLGAAGSTMGARSMTLTVQARQLLSGIVQQQNNLLRAIEAQQHLLQLTVWGIKQLQARILAVERYLKDQQLLGIWGCSGKLICTTAVPWNASWSNKSLEQIWNNMTWMEWDREINNYTSLIHSLIEESQNQQEKNEQELLELDKWASLWNWFNITNWLWYIKIFIMIVGGLVGLRIVFAVLSIVNRVRQGYSPLSFQTHLPTPRGPDRPEGIEEEGGERDRDRSIRLVNGSLALIWDDLRSLCLFSYHRLRDLLLIVTRIVELLGRRGWEALKYWWNLLQYWSQELKNSAVSLLNATAIAVAEGTDRVIEVVQGACRAIRHIPRRIRQGLERILL",
    "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN", 
    "MATGGRRGAAAAPLLVAVAALLLGAAGHLYPGEVCPGMDIRNNLTRLHELENCSVIEGHLQILLMFKTRPEDFRDLSFPKLIMITDYLLLFRVYGLESLKDLFPNLTVIRGSRLFFNYALVIFEMVHLKELGLYNLMNITRGSVRIEKNNELCYLATIDWSRILDSVEDNYIVLNKDDNEECGDICPGTAKGKTNCPATVINGQFVERCWTHSHCQKVCPTICKSHGCTAEGLCCHSECLGNCSQPDDPTKCVACRNFYLDGRCVETCPPPYYHFQDWRCVNFSFCQDLHHKCKNSRRQGCHQYVIHNNKCIPECPSGYTMNSSNLLCTPCLGPCPKVCHLLEGEKTIDSVTSAQELRGCTVINGSLIINIRGGNNLAAELEANLGLIEEISGYLKIRRSYALVSLSFFRKLRLIRGETLEIGNYSFYALDNQNLRQLWDWSKHNLTITQGKLFFHYNPKLCLSEIHKMEEVSGTKGRQERNDIALKTNGDQASCENELLKFSYIRTSFDKILLRWEPYWPPDFRDLLGFMLFYKEAPYQNVTEFDGQDACGSNSWTVVDIDPPLRSNDPKSQNHPGWLMRGLKPWTQYAIFVKTLVTFSDERRTYGAKSDIIYVQTDATNPSVPLDPISVSNSSSQIILKWKPPSDPNGNITHYLVFWERQAEDSELFELDYCLKGLKLPSRTWSPPFESEDSQKHNQSEYEDSAGECCSCPKTDSQILKELEESSFRKTFEDYLHNVVFVPRKTSSGTGAEDPRPSRKRRSLGDVGNVTVAVPTVAAFPNTSSTSVPTSPEEHRPFEKVVNKESLVISGLRHFTGYRIELQACNQDTPEERCSVAAYVSARTMPEAKADDIVGPVTHEIFENNVVHLMWQEPKEPNGLIVLYEVSYRRYGDEELHLCVSRKHFALERGCRLRGLSPGNYSVRIRATSLAGNGSWTEPTYFYVTDYLDVPSNIAKIIIGPLIFVFLFSVVIGSIYLFLRKRQPDGPLGPLYASSNPEYLSASDVFPCSVYVPDEWEVSR" 
]  # A list of protein sequences

5. Defining the MLM Loss Function

Our goal is to find which protein pairs produce the lowest MLM loss. We'll define a function that computes the MLM loss for a batch of protein pairs:

BATCH_SIZE = 2
NUM_MASKS = 10
P_MASK = 0.15

# Function to compute MLM loss for a batch of protein pairs
def compute_mlm_loss_batch(pairs):
    avg_losses = []
    # IDs of <cls>, <eos>, <pad>, etc., which should never be masked
    special_ids = torch.tensor(tokenizer.all_special_ids, device=device)
    for _ in range(NUM_MASKS):
        # Tokenize the concatenated protein pairs
        inputs = tokenizer(pairs, return_tensors="pt", truncation=True, padding=True, max_length=1022)

        # Move input tensors to GPU if available
        inputs = {k: v.to(device) for k, v in inputs.items()}
        input_ids = inputs["input_ids"]

        # Labels start as -100 (ignored by the loss); only masked positions get real targets
        labels = torch.full_like(input_ids, -100)

        # Randomly mask 15% of the residues for each sequence in the batch
        for idx in range(input_ids.shape[0]):
            residue_positions = (~torch.isin(input_ids[idx], special_ids)).nonzero(as_tuple=True)[0]
            num_to_mask = max(1, int(P_MASK * len(residue_positions)))
            shuffled = torch.randperm(len(residue_positions), device=device)
            chosen = residue_positions[shuffled[:num_to_mask]]
            labels[idx, chosen] = input_ids[idx, chosen]
            input_ids[idx, chosen] = tokenizer.mask_token_id

        # Compute the MLM loss (no gradients are needed for scoring)
        with torch.no_grad():
            outputs = model(**inputs, labels=labels)
        avg_losses.append(outputs.loss.item())

    # Return the average loss for the batch
    return sum(avg_losses) / NUM_MASKS

Inside the function:

  • We tokenize each protein pair.
  • 15% of the residues in each sequence are randomly masked. This can be adjusted per your preferences, and different percentages should be tested.
  • We compute the MLM loss using the ESM-2 model.
  • This process is repeated NUM_MASKS times for each pair, and the average loss is computed.
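The masking step in the list above can be sketched independently of the model. This is a minimal NumPy illustration, not the author's code: `mask_tokens` and `IGNORE_INDEX` are hypothetical names, and the token IDs in the toy example are stand-ins rather than values read from the ESM-2 tokenizer. The one convention it relies on is that PyTorch's cross-entropy loss skips labels equal to -100 by default, so only masked positions contribute to the loss.

```python
import numpy as np

IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss ignores by default


def mask_tokens(token_ids, maskable, mask_id, p_mask=0.15, rng=None):
    """Return (masked_ids, labels) for one sequence of token IDs.

    maskable is a boolean array marking positions that may be masked
    (for example, everything except special tokens and padding).
    """
    rng = np.random.default_rng() if rng is None else rng
    token_ids = np.asarray(token_ids)
    candidates = np.flatnonzero(maskable)
    num_to_mask = max(1, int(p_mask * len(candidates)))
    chosen = rng.choice(candidates, size=num_to_mask, replace=False)

    # Inputs: the chosen positions are replaced by the mask token
    masked_ids = token_ids.copy()
    masked_ids[chosen] = mask_id

    # Labels: original tokens at the chosen positions, "ignore" everywhere else
    labels = np.full_like(token_ids, IGNORE_INDEX)
    labels[chosen] = token_ids[chosen]
    return masked_ids, labels


# Toy example: 0 and 2 stand in for <cls> and <eos>, 32 stands in for <mask>
ids = np.array([0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 2])
maskable = (ids != 0) & (ids != 2)
masked_ids, labels = mask_tokens(
    ids, maskable, mask_id=32, p_mask=0.2, rng=np.random.default_rng(0)
)
print(masked_ids)
print(labels)
```

Keeping the labels separate from the masked inputs matters: if the masked inputs were reused as labels, the model would be scored on reproducing the mask token rather than on recovering the hidden residue.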

Here, it is important to note that using a larger model with a longer context window will likely improve results, but will also require more compute. Currently this can be run in a free Colab instance. If you want to use a larger model, you might consider one of the other ESM-2 checkpoints, for example esm2_t36_3B_UR50D. You should also try adjusting max_length in the above code: it applies to the concatenated pair, so when a pair is longer than the limit, the end of the second protein is truncated and never seen by the model. The above context window isn't really sufficient for the long proteins we have chosen here, and using larger models with a longer context window, or using smaller proteins, will almost certainly yield better results, but this should get you started.
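To see how much truncation costs for a given pair, a small helper like the one below can be useful. It is a sketch under two assumptions: that max_length counts the two special tokens the tokenizer adds around the sequence, and that truncation removes tokens from the right (the usual default in Hugging Face tokenizers). The function name `residues_kept` and the example lengths are made up for illustration.

```python
def residues_kept(len_a, len_b, max_length=1022, num_special_tokens=2):
    """Residues of each protein that survive right-truncation of the pair A + B."""
    budget = max_length - num_special_tokens  # room left for actual residues
    kept_a = min(len_a, budget)
    kept_b = min(len_b, budget - kept_a)
    return kept_a, kept_b


# Hypothetical lengths, chosen only to show the three possible regimes
print(residues_kept(300, 400))    # both proteins fit completely
print(residues_kept(850, 1300))   # the second protein is mostly cut off
print(residues_kept(1400, 110))   # the second protein is lost entirely
```

If the second value is much smaller than the true length of the second protein, the score for that pair mostly reflects the first protein alone, which is worth keeping in mind when reading the loss matrix.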

6. Constructing the Loss Matrix

We need to compute the MLM loss for each possible pairing of proteins:

# Compute loss matrix
loss_matrix = np.zeros((len(all_proteins), len(all_proteins)))

for i in range(len(all_proteins)):
    for j in range(i + 1, len(all_proteins)):  # j > i avoids self-pairing and duplicate pairs
        # Score one pair per call: the model returns a single loss averaged over
        # the whole batch, so batching several pairs would give them all the same value
        pair_loss = compute_mlm_loss_batch([all_proteins[i] + all_proteins[j]])
        loss_matrix[i, j] = pair_loss
        loss_matrix[j, i] = pair_loss  # the matrix is symmetric

# Set the diagonal of the loss matrix to a large value to prevent self-pairings
np.fill_diagonal(loss_matrix, np.inf)

Each pair is scored in its own call, so every off-diagonal entry of loss_matrix holds the MLM loss for that specific pair.
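Because every unordered pair is scored and each score averages several random maskings, the cost grows quadratically with the number of proteins. The following sketch, using a hypothetical helper named `scoring_cost`, counts how many masked sequences have to be run through the model.

```python
from itertools import combinations


def scoring_cost(num_proteins, num_masks):
    """Return (number of unordered pairs, number of masked sequences scored)."""
    num_pairs = sum(1 for _ in combinations(range(num_proteins), 2))
    return num_pairs, num_pairs * num_masks


print(scoring_cost(6, 10))    # six proteins, 10 maskings each -> (15, 150)
print(scoring_cost(100, 10))  # one hundred proteins -> (4950, 49500)
```

For a handful of proteins this is negligible, but for larger sets it is worth estimating the cost before starting the loop.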

7. Finding Optimal Pairs

Finally, we use the linear_sum_assignment function to find the optimal pairing of proteins:

# Use the linear assignment problem to find the optimal pairing based on MLM loss
rows, cols = linear_sum_assignment(loss_matrix)
optimal_pairs = list(zip(rows, cols))

print(optimal_pairs)

This will print something like the following (the exact pairing can vary between runs because the masks are random):

[(0, 1), (1, 0), (2, 5), (3, 4), (4, 3), (5, 2)]

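This function finds the assignment that minimizes the total MLM loss: each protein (row) is assigned exactly one partner (column), and the infinite diagonal rules out self-pairings. Note that the result is an assignment rather than a set of mutual pairs, so (i, j) appearing in the output does not guarantee that (j, i) appears as well, although it does in the output above.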
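A toy example makes the behavior of linear_sum_assignment easier to see. The 4x4 matrix below is made up for illustration (it is not model output): proteins 0 and 1 have a low loss with each other, as do proteins 2 and 3, and the diagonal is infinite to forbid self-pairing. The helper that collapses the assignment into unordered pairs is an addition for readability.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Made-up symmetric loss matrix with an infinite diagonal
toy_loss = np.array([
    [np.inf, 1.0, 5.0, 5.0],
    [1.0, np.inf, 5.0, 5.0],
    [5.0, 5.0, np.inf, 1.0],
    [5.0, 5.0, 1.0, np.inf],
])

rows, cols = linear_sum_assignment(toy_loss)
total = toy_loss[rows, cols].sum()

# Collapse (i, j) and (j, i) into one unordered pair
unordered_pairs = sorted({tuple(sorted((int(r), int(c)))) for r, c in zip(rows, cols)})

print(list(zip(rows.tolist(), cols.tolist())))  # [(0, 1), (1, 0), (2, 3), (3, 2)]
print(total)                                    # 4.0
print(unordered_pairs)                          # [(0, 1), (2, 3)]
```

In this example the optimum happens to consist of mutual pairs. With an odd number of proteins, or with less clear-cut losses, the optimal assignment can instead contain longer cycles such as 0 → 1 → 2 → 0, so it is worth checking for mutual pairs before interpreting the output as a list of binding partners.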


Next Steps

Now, you might consider choosing a protein of interest, then computing the binding sites of your protein following this tutorial. Once you have binding sites, you can head over to the RFDiffusion notebook and design several binding partners for your protein. You can then use this code to rank the designed binders by computing the MLM loss, as demonstrated above, for your protein of interest concatenated with each potential binding partner. A lower loss is taken as a sign of a more plausible binder; it is not a measured binding affinity. Remember, the longer your proteins are, the longer the context window you will need. Here is an example of how to rank some binding partners for a fixed protein using this method:

from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

# Load the base model and tokenizer
base_model_path = "facebook/esm2_t12_35M_UR50D"
model = AutoModelForMaskedLM.from_pretrained(base_model_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Ensure the model is in evaluation mode
model.eval()

# Define the protein of interest and its potential binders
protein_of_interest = "MLTEVMEVWHGLVIAVVSLFLQACFLTAINYLLSRHMAHKSEQILKAASLQVPRPSPGHHHPPAVKEMKETQTERDIPMSDSLYRHDSDTPSDSLDSSCSSPPACQATEDVDYTQVVFSDPGELKNDSPLDYENIKEITDYVNVNPERHKPSFWYFVNPALSEPAEYDQVAM"
potential_binders = [
    "MASPGSGFWSFGSEDGSGDSENPGTARAWCQVAQKFTGGIGNKLCALLYGDAEKPAESGGSQPPRAAARKAACACDQKPCSCSKVDVNYAFLHATDLLPACDGERPTLAFLQDVMNILLQYVVKSFDRSTKVIDFHYPNELLQEYNWELADQPQNLEEILMHCQTTLKYAIKTGHPRYFNQLSTGLDMVGLAADWLTSTANTNMFTYEIAPVFVLLEYVTLKKMREIIGWPGGSGDGIFSPGGAISNMYAMMIARFKMFPEVKEKGMAALPRLIAFTSEHSHFSLKKGAAALGIGTDSVILIKCDERGKMIPSDLERRILEAKQKGFVPFLVSATAGTTVYGAFDPLLAVADICKKYKIWMHVDAAWGGGLLMSRKHKWKLSGVERANSVTWNPHKMMGVPLQCSALLVREEGLMQNCNQMHASYLFQQDKHYDLSYDTGDKALQCGRHVDVFKLWLMWRAKGTTGFEAHVDKCLELAEYLYNIIKNREGYEMVFDGKPQHTNVCFWYIPPSLRTLEDNEERMSRLSKVAPVIKARMMEYGTTMVSYQPLGDKVNFFRMVISNPAATHQDIDFLIEEIERLGQDL", 
    "MAAGVAGWGVEAEEFEDAPDVEPLEPTLSNIIEQRSLKWIFVGGKGGVGKTTCSCSLAVQLSKGRESVLIISTDPAHNISDAFDQKFSKVPTKVKGYDNLFAMEIDPSLGVAELPDEFFEEDNMLSMGKKMMQEAMSAFPGIDEAMSYAEVMRLVKGMNFSVVVFDTAPTGHTLRLLNFPTIVERGLGRLMQIKNQISPFISQMCNMLGLGDMNADQLASKLEETLPVIRSVSEQFKDPEQTTFICVCIAEFLSLYETERLIQELAKCKIDTHNIIVNQLVFPDPEKPCKMCEARHKIQAKYLDQMEDLYEDFHIVKLPLLPHEVRGADKVNTFSALLLEPYKPPSAQ", 
    "EKTGLSIRGAQEEDPPDPQLMRLDNMLLAEGVSGPEKGGGSAAAAAAAAASGGSSDNSIEHSDYRAKLTQIRQIYHTELEKYEQACNEFTTHVMNLLREQSRTRPISPKEIERMVGIIHRKFSSIQMQLKQSTCEAVMILRSRFLDARRKRRNFSKQATEILNEYFYSHLSNPYPSEEAKEELAKKCSITVSQSLVKDPKERGSKGSDIQPTSVVSNWFGNKRIRYKKNIGKFQEEANLYAAKTAVTAAHAVAAAVQNNQTNSPTTPNSGSSGSFNLPNSGDMFMNMQSLNGDSYQGSQVGANVQSQVDTLRHVINQTGGYSDGLGGNSLYSPHNLNANGGWQDATTPSSVTSPTEGPGSVHSDTSN"
]  # Add potential binding sequences here

def compute_mlm_loss(protein, binder, iterations=3):
    total_loss = 0.0

    # Concatenate protein sequences with a separator
    concatenated_sequence = protein + ":" + binder

    # Tokenize the unmasked sequence once; these IDs are the prediction targets
    target_ids = tokenizer(concatenated_sequence, return_tensors="pt", truncation=True, max_length=1024)["input_ids"]

    for _ in range(iterations):
        # Mask a subset of amino acids in the concatenated sequence (excluding the separator)
        tokens = list(concatenated_sequence)
        mask_rate = 0.15  # For instance, masking 15% of the sequence
        num_mask = int(len(tokens) * mask_rate)

        # Exclude the separator from potential mask indices
        available_indices = [i for i, token in enumerate(tokens) if token != ":"]
        probs = torch.ones(len(available_indices))
        mask_indices = torch.multinomial(probs, num_mask, replacement=False)

        for idx in mask_indices:
            tokens[available_indices[idx]] = tokenizer.mask_token

        masked_sequence = "".join(tokens)
        inputs = tokenizer(masked_sequence, return_tensors="pt", truncation=True, max_length=1024)

        # Score only the masked positions: every other label is set to -100,
        # which the loss ignores. Using the masked inputs themselves as labels
        # would reward the model for predicting the mask token.
        labels = target_ids.clone()
        labels[inputs["input_ids"] != tokenizer.mask_token_id] = -100

        # Compute the MLM loss
        with torch.no_grad():
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss

        total_loss += loss.item()

    # Return the average loss
    return total_loss / iterations

# Compute MLM loss for each potential binder
mlm_losses = {}
for binder in potential_binders:
    loss = compute_mlm_loss(protein_of_interest, binder)
    mlm_losses[binder] = loss

# Rank binders based on MLM loss
ranked_binders = sorted(mlm_losses, key=mlm_losses.get)

print("Ranking of Potential Binders:")
for idx, binder in enumerate(ranked_binders, 1):
    print(f"{idx}. {binder} - MLM Loss: {mlm_losses[binder]}")

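which will print a ranking like the following (the exact loss values depend on the random masks, the model, and the details of the loss computation):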

Ranking of Potential Binders:
1. MASPGSGFWSFGSEDGSGDSENPGTARAWCQVAQKFTGGIGNKLCALLYGDAEKPAESGGSQPPRAAARKAACACDQKPCSCSKVDVNYAFLHATDLLPACDGERPTLAFLQDVMNILLQYVVKSFDRSTKVIDFHYPNELLQEYNWELADQPQNLEEILMHCQTTLKYAIKTGHPRYFNQLSTGLDMVGLAADWLTSTANTNMFTYEIAPVFVLLEYVTLKKMREIIGWPGGSGDGIFSPGGAISNMYAMMIARFKMFPEVKEKGMAALPRLIAFTSEHSHFSLKKGAAALGIGTDSVILIKCDERGKMIPSDLERRILEAKQKGFVPFLVSATAGTTVYGAFDPLLAVADICKKYKIWMHVDAAWGGGLLMSRKHKWKLSGVERANSVTWNPHKMMGVPLQCSALLVREEGLMQNCNQMHASYLFQQDKHYDLSYDTGDKALQCGRHVDVFKLWLMWRAKGTTGFEAHVDKCLELAEYLYNIIKNREGYEMVFDGKPQHTNVCFWYIPPSLRTLEDNEERMSRLSKVAPVIKARMMEYGTTMVSYQPLGDKVNFFRMVISNPAATHQDIDFLIEEIERLGQDL
- MLM Loss: 8.279285748799643
2. MAAGVAGWGVEAEEFEDAPDVEPLEPTLSNIIEQRSLKWIFVGGKGGVGKTTCSCSLAVQLSKGRESVLIISTDPAHNISDAFDQKFSKVPTKVKGYDNLFAMEIDPSLGVAELPDEFFEEDNMLSMGKKMMQEAMSAFPGIDEAMSYAEVMRLVKGMNFSVVVFDTAPTGHTLRLLNFPTIVERGLGRLMQIKNQISPFISQMCNMLGLGDMNADQLASKLEETLPVIRSVSEQFKDPEQTTFICVCIAEFLSLYETERLIQELAKCKIDTHNIIVNQLVFPDPEKPCKMCEARHKIQAKYLDQMEDLYEDFHIVKLPLLPHEVRGADKVNTFSALLLEPYKPPSAQ
- MLM Loss: 12.046218872070312
3. EKTGLSIRGAQEEDPPDPQLMRLDNMLLAEGVSGPEKGGGSAAAAAAAAASGGSSDNSIEHSDYRAKLTQIRQIYHTELEKYEQACNEFTTHVMNLLREQSRTRPISPKEIERMVGIIHRKFSSIQMQLKQSTCEAVMILRSRFLDARRKRRNFSKQATEILNEYFYSHLSNPYPSEEAKEELAKKCSITVSQSLVKDPKERGSKGSDIQPTSVVSNWFGNKRIRYKKNIGKFQEEANLYAAKTAVTAAHAVAAAVQNNQTNSPTTPNSGSSGSFNLPNSGDMFMNMQSLNGDSYQGSQVGANVQSQVDTLRHVINQTGGYSDGLGGNSLYSPHNLNANGGWQDATTPSSVTSPTEGPGSVHSDTSN
- MLM Loss: 12.776518821716309

We might also use a trick from the literature on complexes, where a linker of 20 or so G (glycine) residues is used in place of the : to separate the two proteins. This might lead to somewhat different predictions and is worth testing. We might also decide to preferentially mask binding sites predicted by our binding site predictor model (referenced above), because protein language models are known to apply more attention to binding residues and other special residues. Yet another variation on this method might apply masks to only one of the sequences in each pair of concatenated sequences representing a potential protein complex. Any of these may improve the performance of the method. Try these variations and see which works best!
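Two of the variations above, the poly-G linker and masking only one of the two sequences, can be combined in a few lines of string handling. This is a sketch rather than a tested recipe: `mask_binder_only` is a hypothetical helper, the short sequences are toy inputs, and the literal "<mask>" is assumed to match tokenizer.mask_token for the ESM-2 tokenizer.

```python
import random

MASK_TOKEN = "<mask>"  # assumed to equal tokenizer.mask_token for ESM-2


def mask_binder_only(protein, binder, linker_length=20, mask_rate=0.15, seed=None):
    """Join protein and binder with a poly-G linker and mask residues of the binder only."""
    rng = random.Random(seed)
    linker = "G" * linker_length
    tokens = list(protein + linker + binder)

    # Only positions after the protein and the linker are candidates for masking
    binder_start = len(protein) + linker_length
    binder_positions = list(range(binder_start, len(tokens)))
    num_mask = max(1, int(len(binder_positions) * mask_rate))
    masked_positions = sorted(rng.sample(binder_positions, num_mask))

    for pos in masked_positions:
        tokens[pos] = MASK_TOKEN
    return "".join(tokens), masked_positions


# Toy sequences for illustration only
masked_sequence, masked_positions = mask_binder_only(
    "MKTAYIAK", "ACDEFGHIKLMNPQRSTVWY", linker_length=5, mask_rate=0.2, seed=0
)
print(masked_sequence)
print(masked_positions)
```

The returned string can be tokenized in place of the masked sequence used in the ranking code, with the unmasked concatenation providing the labels. Swapping the roles of the two sequences gives the complementary score, where the binder is the context and the protein is predicted.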

Building a PPI Network

We can also build a protein-protein interaction (PPI) network based on this method. We compute the MLM loss for every pair of proteins and pick a threshold: if the loss for a pair is below the threshold, the graph gets an edge between those two proteins. This provides a very fast way of approximating protein interactomes and finding candidate interactions. We can do this as follows.
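Stripped of the model and the plotting, the thresholding rule is just a filter over pairwise losses. The sketch below uses only the standard library; the function names and the loss values are made up for illustration and are not outputs of ESM-2.

```python
def edges_below_threshold(pair_losses, threshold):
    """Return sorted (protein_a, protein_b, loss) edges whose loss is below the threshold."""
    return sorted(
        (a, b, loss) for (a, b), loss in pair_losses.items() if loss < threshold
    )


def degrees(edges):
    """Count how many predicted partners each protein has."""
    counts = {}
    for a, b, _ in edges:
        counts[a] = counts.get(a, 0) + 1
        counts[b] = counts.get(b, 0) + 1
    return counts


# Made-up losses for four hypothetical proteins
toy_losses = {
    ("P1", "P2"): 2.41,
    ("P1", "P3"): 2.78,
    ("P2", "P3"): 2.52,
    ("P3", "P4"): 2.95,
}

edges = edges_below_threshold(toy_losses, threshold=2.6)
print(edges)           # [('P1', 'P2', 2.41), ('P2', 'P3', 2.52)]
print(degrees(edges))  # {'P1': 1, 'P2': 2, 'P3': 1}
```

Raising the threshold adds edges and lowering it removes them, which is exactly what the slider in the full example controls.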

import networkx as nx
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import plotly.graph_objects as go
from ipywidgets import interact
from ipywidgets import widgets

# Check if CUDA is available and set the default device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained (or fine-tuned) ESM-2 model and tokenizer
model_name = "facebook/esm2_t6_8M_UR50D"  # You can change this to your fine-tuned model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# Send the model to the device (GPU or CPU)
model.to(device)

# Ensure the model is in evaluation mode
model.eval()

# Define Protein Sequences (Replace with your list)
all_proteins = [
    "MFLSILVALCLWLHLALGVRGAPCEAVRIPMCRHMPWNITRMPNHLHHSTQENAILAIEQYEELVDVNCSAVLRFFLCAMYAPICTLEFLHDPIKPCKSVCQRARDDCEPLMKMYNHSWPESLACDELPVYDRGVCISPEAIVTDLPEDVKWIDITPDMMVQERPLDVDCKRLSPDRCKCKKVKPTLATYLSKNYSYVIHAKIKAVQRSGCNEVTTVVDVKEIFKSSSPIPRTQVPLITNSSCQCPHILPHQDVLIMCYEWRSRMMLLENCLVEKWRDQLSKRSIQWEERLQEQRRTVQDKKKTAGRTSRSNPPKPKGKPPAPKPASPKKNIKTRSAQKRTNPKRV",
    "MDAVEPGGRGWASMLACRLWKAISRALFAEFLATGLYVFFGVGSVMRWPTALPSVLQIAITFNLVTAMAVQVTWKASGAHANPAVTLAFLVGSHISLPRAVAYVAAQLVGATVGAALLYGVMPGDIRETLGINVVRNSVSTGQAVAVELLLTLQLVLCVFASTDSRQTSGSPATMIGISVALGHLIGIHFTGCSMNPARSFGPAIIIGKFTVHWVFWVGPLMGALLASLIYNFVLFPDTKTLAQRLAILTGTVEVGTGAGAGAEPLKKESQPGSGAVEMESV", 
    "MKFLLDILLLLPLLIVCSLESFVKLFIPKRRKSVTGEIVLITGAGHGIGRLTAYEFAKLKSKLVLWDINKHGLEETAAKCKGLGAKVHTFVVDCSNREDIYSSAKKVKAEIGDVSILVNNAGVVYTSDLFATQDPQIEKTFEVNVLAHFWTTKAFLPAMTKNNHGHIVTVASAAGHVSVPFLLAYCSSKFAAVGFHKTLTDELAALQITGVKTTCLCPNFVNTGFIKNPSTSLGPTLEPEEVVNRLMHGILTEQKMIFIPSSIAFLTTLERILPERFLAVLKRKISVKFDAVIGYKMKAQ", 
    
    "MAAAVPRRPTQQGTVTFEDVAVNFSQEEWCLLSEAQRCLYRDVMLENLALISSLGCWCGSKDEEAPCKQRISVQRESQSRTPRAGVSPKKAHPCEMCGLILEDVFHFADHQETHHKQKLNRSGACGKNLDDTAYLHQHQKQHIGEKFYRKSVREASFVKKRKLRVSQEPFVFREFGKDVLPSSGLCQEEAAVEKTDSETMHGPPFQEGKTNYSCGKRTKAFSTKHSVIPHQKLFTRDGCYVCSDCGKSFSRYVSFSNHQRDHTAKGPYDCGECGKSYSRKSSLIQHQRVHTGQTAYPCEECGKSFSQKGSLISHQLVHTGEGPYECRECGKSFGQKGNLIQHQQGHTGERAYHCGECGKSFRQKFCFINHQRVHTGERPYKCGECGKSFGQKGNLVHHQRGHTGERPYECKECGKSFRYRSHLTEHQRLHTGERPYNCRECGKLFNRKYHLLVHERVHTGERPYACEVCGKLFGNKHSVTIHQRIHTGERPYECSECGKSFLSSSALHVHKRVHSGQKPYKCSECGKSFSECSSLIKHRRIHTGERPYECTKCGKTFQRSSTLLHHQSSHRRKAL", 
    "MGQPWAAGSTDGAPAQLPLVLTALWAAAVGLELAYVLVLGPGPPPLGPLARALQLALAAFQLLNLLGNVGLFLRSDPSIRGVMLAGRGLGQGWAYCYQCQSQVPPRSGHCSACRVCILRRDHHCRLLGRCVGFGNYRPFLCLLLHAAGVLLHVSVLLGPALSALLRAHTPLHMAALLLLPWLMLLTGRVSLAQFALAFVTDTCVAGALLCGAGLLFHGMLLLRGQTTWEWARGQHSYDLGPCHNLQAALGPRWALVWLWPFLASPLPGDGITFQTTADVGHTAS", 
    "MGLRIHFVVDPHGWCCMGLIVFVWLYNIVLIPKIVLFPHYEEGHIPGILIIIFYGISIFCLVALVRASITDPGRLPENPKIPHGEREFWELCNKCNLMRPKRSHHCSRCGHCVRRMDHHCPWINNCVGEDNHWLFLQLCFYTELLTCYALMFSFCHYYYFLPLKKRNLDLFVFRHELAIMRLAAFMGITMLVGITGLFYTQLIGIITDTTSIEKMSNCCEDISRPRKPWQQTFSEVFGTRWKILWFIPFRQRQPLRVPYHFANHV", 
    
    "MLLLGAVLLLLALPGHDQETTTQGPGVLLPLPKGACTGWMAGIPGHPGHNGAPGRDGRDGTPGEKGEKGDPGLIGPKGDIGETGVPGAEGPRGFPGIQGRKGEPGEGAYVYRSAFSVGLETYVTIPNMPIRFTKIFYNQQNHYDGSTGKFHCNIPGLYYFAYHITVYMKDVKVSLFKKDKAMLFTYDQYQENNVDQASGSVLLHLEVGDQVWLQVYGEGERNGLYADNDNDSTFTGFLLYHDTN", 
    "MGLLAFLKTQFVLHLLVGFVFVVSGLVINFVQLCTLALWPVSKQLYRRLNCRLAYSLWSQLVMLLEWWSCTECTLFTDQATVERFGKEHAVIILNHNFEIDFLCGWTMCERFGVLGSSKVLAKKELLYVPLIGWTWYFLEIVFCKRKWEEDRDTVVEGLRRLSDYPEYMWFLLYCEGTRFTETKHRVSMEVAAAKGLPVLKYHLLPRTKGFTTAVKCLRGTVAAVYDVTLNFRGNKNPSLLGILYGKKYEADMCVRRFPLEDIPLDEKEAAQWLHKLYQEKDALQEIYNQKGMFPGEQFKPARRPWTLLNFLSWATILLSPLFSFVLGVFASGSPLLILTFLGFVGAASFGVRRLIGVTEIEKGSSYGNQEFKKKE", 
    "MDLAGLLKSQFLCHLVFCYVFIASGLIINTIQLFTLLLWPINKQLFRKINCRLSYCISSQLVMLLEWWSGTECTIFTDPRAYLKYGKENAIVVLNHKFEIDFLCGWSLSERFGLLGGSKVLAKKELAYVPIIGWMWYFTEMVFCSRKWEQDRKTVATSLQHLRDYPEKYFFLIHCEGTRFTEKKHEISMQVARAKGLPRLKHHLLPRTKGFAITVRSLRNVVSAVYDCTLNFRNNENPTLLGVLNGKKYHADLYVRRIPLEDIPEDDDECSAWLHKLYQEKDAFQEEYYRTGTFPETPMVPPRRPWTLVNWLFWASLVLYPFFQFLVSMIRSGSSLTLASFILVFFVASVGVRWMIGVTEIDKGSAYGNSDSKQKLND", 
    
    "MALLLCFVLLCGVVDFARSLSITTPEEMIEKAKGETAYLPCKFTLSPEDQGPLDIEWLISPADNQKVDQVIILYSGDKIYDDYYPDLKGRVHFTSNDLKSGDASINVTNLQLSDIGTYQCKVKKAPGVANKKIHLVVLVKPSGARCYVDGSEEIGSDFKIKCEPKEGSLPLQYEWQKLSDSQKMPTSWLAEMTSSVISVKNASSEYSGTYSCTVRNRVGSDQCLLRLNVVPPSNKAGLIAGAIIGTLLALALIGLIIFCCRKKRREEKYEKEVHHDIREDVPPPKSRTSTARSYIGSNHSSLGSMSPSNMEGYSKTQYNQVPSEDFERTPQSPTLPPAKVAAPNLSRMGAIPVMIPAQSKDGSIV", 
    "MSYVFVNDSSQTNVPLLQACIDGDFNYSKRLLESGFDPNIRDSRGRTGLHLAAARGNVDICQLLHKFGADLLATDYQGNTALHLCGHVDTIQFLVSNGLKIDICNHQGATPLVLAKRRGVNKDVIRLLESLEEQEVKGFNRGTHSKLETMQTAESESAMESHSLLNPNLQQGEGVLSSFRTTWQEFVEDLGFWRVLLLIFVIALLSLGIAYYVSGVLPFVENQPELVH", 
    "MRVAGAAKLVVAVAVFLLTFYVISQVFEIKMDASLGNLFARSALDTAARSTKPPRYKCGISKACPEKHFAFKMASGAANVVGPKICLEDNVLMSGVKNNVGRGINVALANGKTGEVLDTKYFDMWGGDVAPFIEFLKAIQDGTIVLMGTYDDGATKLNDEARRLIADLGSTSITNLGFRDNWVFCGGKGIKTKSPFEQHIKNNKDTNKYEGWPEVVEMEGCIPQKQD", 
    
    "MAPAAATGGSTLPSGFSVFTTLPDLLFIFEFIFGGLVWILVASSLVPWPLVQGWVMFVSVFCFVATTTLIILYIIGAHGGETSWVTLDAAYHCTAALFYLSASVLEALATITMQDGFTYRHYHENIAAVVFSYIATLLYVVHAVFSLIRWKSS", 
    "MRLQGAIFVLLPHLGPILVWLFTRDHMSGWCEGPRMLSWCPFYKVLLLVQTAIYSVVGYASYLVWKDLGGGLGWPLALPLGLYAVQLTISWTVLVLFFTVHNPGLALLHLLLLYGLVVSTALIWHPINKLAALLLLPYLAWLTVTSALTYHLWRDSLCPVHQPQPTEKSD", 
    "MEESVVRPSVFVVDGQTDIPFTRLGRSHRRQSCSVARVGLGLLLLLMGAGLAVQGWFLLQLHWRLGEMVTRLPDGPAGSWEQLIQERRSHEVNPAAHLTGANSSLTGSGGPLLWETQLGLAFLRGLSYHDGALVVTKAGYYYIYSKVQLGGVGCPLGLASTITHGLYKRTPRYPEELELLVSQQSPCGRATSSSRVWWDSSFLGGVVHLEAGEKVVVRVLDERLVRLRDGTRSYFGAFMV"
]

def compute_average_mlm_loss(protein1, protein2, iterations=10):
    total_loss = 0.0
    connector = "G" * 25  # Connector sequence of G's
    concatenated_sequence = protein1 + connector + protein2
    mask_prob = 0.55

    # Locate the connector in token space; the +1 accounts for the leading <cls> token
    start_connector = len(tokenizer.encode(protein1, add_special_tokens=False)) + 1
    end_connector = start_connector + len(tokenizer.encode(connector, add_special_tokens=False))

    for _ in range(iterations):
        inputs = tokenizer(concatenated_sequence, return_tensors="pt", padding=True, truncation=True, max_length=1024)
        inputs = {k: v.to(device) for k, v in inputs.items()}  # Send inputs to the device
        input_ids = inputs["input_ids"]

        # Create the mask on the same device as the inputs
        mask_indices = torch.rand(input_ids.shape, device=device) < mask_prob

        # Avoid masking the connector 'G's and the special tokens (<cls>, <eos>)
        mask_indices[0, start_connector:end_connector] = False
        mask_indices[0, 0] = False
        mask_indices[0, -1] = False

        # Labels keep the original tokens at masked positions and -100 (ignored) elsewhere
        labels = input_ids.clone()
        labels[~mask_indices] = -100

        # Apply the mask to the input IDs
        input_ids[mask_indices] = tokenizer.mask_token_id

        with torch.no_grad():
            outputs = model(**inputs, labels=labels)

        total_loss += outputs.loss.item()

    return total_loss / iterations

# Compute all average losses to determine the maximum threshold for the slider
all_losses = []
for i, protein1 in enumerate(all_proteins):
    for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):
        avg_loss = compute_average_mlm_loss(protein1, protein2)
        all_losses.append(avg_loss)

# Set the maximum threshold to the maximum loss computed
max_threshold = max(all_losses)
print(f"Maximum loss (maximum threshold for slider): {max_threshold}")

def plot_graph(threshold):
    G = nx.Graph()

    # Add all protein nodes to the graph
    for i, protein in enumerate(all_proteins):
        G.add_node(f"protein {i+1}")

    # Loop through all pairs of proteins and calculate average MLM loss
    loss_idx = 0  # Index to keep track of the position in the all_losses list
    for i, protein1 in enumerate(all_proteins):
        for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):
            avg_loss = all_losses[loss_idx]
            loss_idx += 1
            
            # Add an edge if the loss is below the threshold
            if avg_loss < threshold:
                G.add_edge(f"protein {i+1}", f"protein {j+1}", weight=round(avg_loss, 3))

    # 3D Network Plot
    # Adjust the k parameter to bring nodes closer. This might require some experimentation to find the right value.
    k_value = 2  # Lower value will bring nodes closer together
    pos = nx.spring_layout(G, dim=3, seed=42, k=k_value)

    edge_x = []
    edge_y = []
    edge_z = []
    for edge in G.edges():
        x0, y0, z0 = pos[edge[0]]
        x1, y1, z1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        edge_z.extend([z0, z1, None])
    
    edge_trace = go.Scatter3d(x=edge_x, y=edge_y, z=edge_z, mode='lines', line=dict(width=0.5, color='grey'))
    
    node_x = []
    node_y = []
    node_z = []
    node_text = []
    for node in G.nodes():
        x, y, z = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_z.append(z)
        node_text.append(node)
    
    node_trace = go.Scatter3d(x=node_x, y=node_y, z=node_z, mode='markers', marker=dict(size=5), hoverinfo='text', hovertext=node_text)
    
    layout = go.Layout(title='Protein Interaction Graph', title_x=0.5, scene=dict(xaxis=dict(showbackground=False), yaxis=dict(showbackground=False), zaxis=dict(showbackground=False)))

    fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
    fig.show()

# Create an interactive slider for the threshold value, defaulting to the median observed loss
default_threshold = float(np.median(all_losses))
interact(plot_graph, threshold=widgets.FloatSlider(min=0.0, max=max_threshold, step=0.01, value=default_threshold))

This will display an interactive 3D graph representing the PPI network. You can adjust the threshold slider for the MLM loss value to make the requirements for interactions more or less strict.


So, for example, try a few slider values close to the default and see what kind of predicted interactome results from each choice of MLM loss threshold. You should also adjust the fraction of masked tokens to see how this affects the graph. In general, masking more residues will raise the threshold at which connections appear in the predicted PPI graph. Note that this code may take a few moments to run, since we are computing the loss for all pairs of proteins in our list, but in general the method is a very fast zero-shot method for predicting PPI networks. As a next step you might also consider training the model to predict masked residues of known interacting pairs in order to finetune it to this task further. Another interesting and important question is how the length of the proteins affects this computation. Is the method robust to large variations in length, or do the proteins need to be of similar lengths for the method to work?
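One simple way to start on the length question is to check whether the losses correlate with the combined length of each pair. The sketch below is only a starting point: `length_loss_correlation` is a hypothetical helper, and the numbers are made up to show the call, not measured from the model.

```python
import numpy as np


def length_loss_correlation(pair_lengths, pair_losses):
    """Pearson correlation between combined pair length and average MLM loss."""
    return float(np.corrcoef(pair_lengths, pair_losses)[0, 1])


# Made-up numbers for illustration only; replace them with your own measurements
toy_lengths = [400, 650, 900, 1200]
toy_losses = [2.9, 2.7, 2.6, 2.4]

r = length_loss_correlation(toy_lengths, toy_losses)
print(f"Pearson r between pair length and loss: {r:.3f}")
```

A correlation far from zero on real data would suggest that the threshold is partly picking up sequence length rather than interaction signal, in which case comparing proteins within similar length ranges may be more informative.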

Conclusion

Now you should be able to use the ESM-2 model to predict potential protein-protein interactions by comparing the MLM loss of different protein pairings. This method provides a novel way of inferring interactions using deep learning techniques. Remember, this approach provides a heuristic and should be combined with experimental validation for conclusive results. As a next step, you might try implementing the ideas in PepMLM: Target Sequence-Conditioned Generation of Peptide Binders via Masked Language Modeling, which finetunes ESM-2 for generating binding partners using masked language modeling, or you might try finetuning ESM-2 in a similar fashion on concatenated pairs of binding partners, with some percentage of the tokens in each binding partner masked, to see if performance is improved.