---
license: mit
---

# ESM-2 Finetuned for PPI Prediction

This is a version of `facebook/esm2_t33_650M_UR50D` finetuned with the masked language modeling (MLM) objective. The model was finetuned for four epochs on concatenated pairs of interacting proteins, clustered using persistent homology landscapes as explained in [this post](https://huggingface.co/blog/AmelieSchreiber/faster-pha). The dataset consists of 10,000 protein pairs, which can be [found here](https://huggingface.co/datasets/AmelieSchreiber/pha_clustered_protein_complexes). This is a very new method for clustering protein-protein complexes. Using the MLM loss to predict pairs of interacting proteins was inspired by [this paper](https://arxiv.org/abs/2308.07136). However, the authors of that paper do not finetune the models for this task. We therefore reasoned that the performance of this method could be improved by finetuning the model on pairs of interacting proteins.

## Using the Model

To use the model, we follow [this blog post](https://huggingface.co/blog/AmelieSchreiber/protein-binding-partners-with-esm2). Below we show how to use the model to rank potential binders for a target protein of interest. The lower the average MLM loss, the more likely the two proteins are to interact with one another.

```python
import numpy as np
from transformers import AutoTokenizer, EsmForMaskedLM
import torch

# Load the base tokenizer and the finetuned model
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmForMaskedLM.from_pretrained("AmelieSchreiber/esm_mlmppi_ph50_v3")

# Ensure the model is in evaluation mode
model.eval()

# Define the protein of interest and its potential binders
protein_of_interest = "MLTEVMEVWHGLVIAVVSLFLQACFLTAINYLLSRHMAHKSEQILKAASLQVPRPSPGHHHPPAVKEMKETQTERDIPMSDSLYRHDSDTPSDSLDSSCSSPPACQATEDVDYTQVVFSDPGELKNDSPLDYENIKEITDYVNVNPERHKPSFWYFVNPALSEPAEYDQVAM"
potential_binders = [
    # Known to interact
    "MASPGSGFWSFGSEDGSGDSENPGTARAWCQVAQKFTGGIGNKLCALLYGDAEKPAESGGSQPPRAAARKAACACDQKPCSCSKVDVNYAFLHATDLLPACDGERPTLAFLQDVMNILLQYVVKSFDRSTKVIDFHYPNELLQEYNWELADQPQNLEEILMHCQTTLKYAIKTGHPRYFNQLSTGLDMVGLAADWLTSTANTNMFTYEIAPVFVLLEYVTLKKMREIIGWPGGSGDGIFSPGGAISNMYAMMIARFKMFPEVKEKGMAALPRLIAFTSEHSHFSLKKGAAALGIGTDSVILIKCDERGKMIPSDLERRILEAKQKGFVPFLVSATAGTTVYGAFDPLLAVADICKKYKIWMHVDAAWGGGLLMSRKHKWKLSGVERANSVTWNPHKMMGVPLQCSALLVREEGLMQNCNQMHASYLFQQDKHYDLSYDTGDKALQCGRHVDVFKLWLMWRAKGTTGFEAHVDKCLELAEYLYNIIKNREGYEMVFDGKPQHTNVCFWYIPPSLRTLEDNEERMSRLSKVAPVIKARMMEYGTTMVSYQPLGDKVNFFRMVISNPAATHQDIDFLIEEIERLGQDL",
    "MAAGVAGWGVEAEEFEDAPDVEPLEPTLSNIIEQRSLKWIFVGGKGGVGKTTCSCSLAVQLSKGRESVLIISTDPAHNISDAFDQKFSKVPTKVKGYDNLFAMEIDPSLGVAELPDEFFEEDNMLSMGKKMMQEAMSAFPGIDEAMSYAEVMRLVKGMNFSVVVFDTAPTGHTLRLLNFPTIVERGLGRLMQIKNQISPFISQMCNMLGLGDMNADQLASKLEETLPVIRSVSEQFKDPEQTTFICVCIAEFLSLYETERLIQELAKCKIDTHNIIVNQLVFPDPEKPCKMCEARHKIQAKYLDQMEDLYEDFHIVKLPLLPHEVRGADKVNTFSALLLEPYKPPSAQ",
    "EKTGLSIRGAQEEDPPDPQLMRLDNMLLAEGVSGPEKGGGSAAAAAAAAASGGSSDNSIEHSDYRAKLTQIRQIYHTELEKYEQACNEFTTHVMNLLREQSRTRPISPKEIERMVGIIHRKFSSIQMQLKQSTCEAVMILRSRFLDARRKRRNFSKQATEILNEYFYSHLSNPYPSEEAKEELAKKCSITVSQSLVKDPKERGSKGSDIQPTSVVSNWFGNKRIRYKKNIGKFQEEANLYAAKTAVTAAHAVAAAVQNNQTNSPTTPNSGSSGSFNLPNSGDMFMNMQSLNGDSYQGSQVGANVQSQVDTLRHVINQTGGYSDGLGGNSLYSPHNLNANGGWQDATTPSSVTSPTEGPGSVHSDTSN",
    # Not known to interact
    "MRQRLLPSVTSLLLVALLFPGSSQARHVNHSATEALGELRERAPGQGTNGFQLLRHAVKRDLLPPRTPPYQVHISHREARGPSFRICVDFLGPRWARGCSTGN",
    "MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS",
]  # Add potential binding sequences here


def compute_mlm_loss(protein, binder, iterations=5, mask_rate=0.35):
    """Return the average masked-LM loss of `protein` concatenated with `binder`.

    The two sequences are joined directly (no separator), tokenized once, and
    then `iterations` independent random masks are drawn. For each mask the
    loss is computed only over the masked positions, using the original
    residues as targets.

    Args:
        protein: Amino-acid sequence of the protein of interest.
        binder: Amino-acid sequence of the candidate binding partner.
        iterations: Number of random maskings to average over (must be >= 1).
        mask_rate: Fraction of tokens to mask in each iteration, in (0, 1].

    Returns:
        The mean loss over all iterations, as a Python float.

    Raises:
        ValueError: If `iterations` < 1, `mask_rate` is outside (0, 1], or the
            concatenated sequence tokenizes to nothing.
    """
    if iterations < 1:
        raise ValueError("iterations must be at least 1")
    if not 0.0 < mask_rate <= 1.0:
        raise ValueError("mask_rate must be in the interval (0, 1]")

    # Tokenize the *unmasked* sequence once; these ids are the prediction targets.
    # No padding is requested: a single sequence needs none.
    encoded = tokenizer(
        protein + binder,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
        add_special_tokens=False,
    )
    original_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)

    seq_len = original_ids.shape[1]
    if seq_len == 0:
        raise ValueError("protein + binder must contain at least one residue")

    # Always mask at least one token so the loss is defined.
    num_mask = min(seq_len, max(1, int(seq_len * mask_rate)))
    uniform_weights = torch.ones(seq_len)

    total_loss = 0.0
    for _ in range(iterations):
        # Sample distinct positions uniformly at random.
        mask_positions = torch.multinomial(
            uniform_weights, num_mask, replacement=False
        ).to(original_ids.device)

        input_ids = original_ids.clone()
        input_ids[0, mask_positions] = tokenizer.mask_token_id

        # -100 marks positions the loss must ignore; only masked positions
        # keep their original token id as the target.
        labels = torch.full_like(original_ids, -100)
        labels[0, mask_positions] = original_ids[0, mask_positions]

        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
        total_loss += outputs.loss.item()

    # Return the average loss
    return total_loss / iterations


# Compute MLM loss for each potential binder
mlm_losses = {
    binder: compute_mlm_loss(protein_of_interest, binder)
    for binder in potential_binders
}

# Rank binders based on MLM loss (lowest loss first)
ranked_binders = sorted(mlm_losses, key=mlm_losses.get)

print("Ranking of Potential Binders:")
for idx, binder in enumerate(ranked_binders, 1):
    print(f"{idx}. {binder} - MLM Loss: {mlm_losses[binder]}")
```

## PPI Networks

To construct a protein-protein interaction network, try running the following code. Try adjusting the length of the connector (0-25 for example).
Try also adjusting the number of iterations and the masking percentage.

```python
import networkx as nx
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, EsmForMaskedLM
import plotly.graph_objects as go
from ipywidgets import interact
from ipywidgets import widgets

# Check if CUDA is available and set the default device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the base tokenizer and the finetuned ESM-2 model
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmForMaskedLM.from_pretrained("AmelieSchreiber/esm_mlmppi_ph50_v3")

# Send the model to the device (GPU or CPU)
model.to(device)

# Ensure the model is in evaluation mode
model.eval()

# Define Protein Sequences (Replace with your list)
all_proteins = [
    "MFLSILVALCLWLHLALGVRGAPCEAVRIPMCRHMPWNITRMPNHLHHSTQENAILAIEQYEELVDVNCSAVLRFFLCAMYAPICTLEFLHDPIKPCKSVCQRARDDCEPLMKMYNHSWPESLACDELPVYDRGVCISPEAIVTDLPEDVKWIDITPDMMVQERPLDVDCKRLSPDRCKCKKVKPTLATYLSKNYSYVIHAKIKAVQRSGCNEVTTVVDVKEIFKSSSPIPRTQVPLITNSSCQCPHILPHQDVLIMCYEWRSRMMLLENCLVEKWRDQLSKRSIQWEERLQEQRRTVQDKKKTAGRTSRSNPPKPKGKPPAPKPASPKKNIKTRSAQKRTNPKRV",
    "MDAVEPGGRGWASMLACRLWKAISRALFAEFLATGLYVFFGVGSVMRWPTALPSVLQIAITFNLVTAMAVQVTWKASGAHANPAVTLAFLVGSHISLPRAVAYVAAQLVGATVGAALLYGVMPGDIRETLGINVVRNSVSTGQAVAVELLLTLQLVLCVFASTDSRQTSGSPATMIGISVALGHLIGIHFTGCSMNPARSFGPAIIIGKFTVHWVFWVGPLMGALLASLIYNFVLFPDTKTLAQRLAILTGTVEVGTGAGAGAEPLKKESQPGSGAVEMESV",
    "MKFLLDILLLLPLLIVCSLESFVKLFIPKRRKSVTGEIVLITGAGHGIGRLTAYEFAKLKSKLVLWDINKHGLEETAAKCKGLGAKVHTFVVDCSNREDIYSSAKKVKAEIGDVSILVNNAGVVYTSDLFATQDPQIEKTFEVNVLAHFWTTKAFLPAMTKNNHGHIVTVASAAGHVSVPFLLAYCSSKFAAVGFHKTLTDELAALQITGVKTTCLCPNFVNTGFIKNPSTSLGPTLEPEEVVNRLMHGILTEQKMIFIPSSIAFLTTLERILPERFLAVLKRKISVKFDAVIGYKMKAQ",
    "MAAAVPRRPTQQGTVTFEDVAVNFSQEEWCLLSEAQRCLYRDVMLENLALISSLGCWCGSKDEEAPCKQRISVQRESQSRTPRAGVSPKKAHPCEMCGLILEDVFHFADHQETHHKQKLNRSGACGKNLDDTAYLHQHQKQHIGEKFYRKSVREASFVKKRKLRVSQEPFVFREFGKDVLPSSGLCQEEAAVEKTDSETMHGPPFQEGKTNYSCGKRTKAFSTKHSVIPHQKLFTRDGCYVCSDCGKSFSRYVSFSNHQRDHTAKGPYDCGECGKSYSRKSSLIQHQRVHTGQTAYPCEECGKSFSQKGSLISHQLVHTGEGPYECRECGKSFGQKGNLIQHQQGHTGERAYHCGECGKSFRQKFCFINHQRVHTGERPYKCGECGKSFGQKGNLVHHQRGHTGERPYECKECGKSFRYRSHLTEHQRLHTGERPYNCRECGKLFNRKYHLLVHERVHTGERPYACEVCGKLFGNKHSVTIHQRIHTGERPYECSECGKSFLSSSALHVHKRVHSGQKPYKCSECGKSFSECSSLIKHRRIHTGERPYECTKCGKTFQRSSTLLHHQSSHRRKAL",
    "MGQPWAAGSTDGAPAQLPLVLTALWAAAVGLELAYVLVLGPGPPPLGPLARALQLALAAFQLLNLLGNVGLFLRSDPSIRGVMLAGRGLGQGWAYCYQCQSQVPPRSGHCSACRVCILRRDHHCRLLGRCVGFGNYRPFLCLLLHAAGVLLHVSVLLGPALSALLRAHTPLHMAALLLLPWLMLLTGRVSLAQFALAFVTDTCVAGALLCGAGLLFHGMLLLRGQTTWEWARGQHSYDLGPCHNLQAALGPRWALVWLWPFLASPLPGDGITFQTTADVGHTAS",
    "MGLRIHFVVDPHGWCCMGLIVFVWLYNIVLIPKIVLFPHYEEGHIPGILIIIFYGISIFCLVALVRASITDPGRLPENPKIPHGEREFWELCNKCNLMRPKRSHHCSRCGHCVRRMDHHCPWINNCVGEDNHWLFLQLCFYTELLTCYALMFSFCHYYYFLPLKKRNLDLFVFRHELAIMRLAAFMGITMLVGITGLFYTQLIGIITDTTSIEKMSNCCEDISRPRKPWQQTFSEVFGTRWKILWFIPFRQRQPLRVPYHFANHV",
    "MLLLGAVLLLLALPGHDQETTTQGPGVLLPLPKGACTGWMAGIPGHPGHNGAPGRDGRDGTPGEKGEKGDPGLIGPKGDIGETGVPGAEGPRGFPGIQGRKGEPGEGAYVYRSAFSVGLETYVTIPNMPIRFTKIFYNQQNHYDGSTGKFHCNIPGLYYFAYHITVYMKDVKVSLFKKDKAMLFTYDQYQENNVDQASGSVLLHLEVGDQVWLQVYGEGERNGLYADNDNDSTFTGFLLYHDTN",
    "MGLLAFLKTQFVLHLLVGFVFVVSGLVINFVQLCTLALWPVSKQLYRRLNCRLAYSLWSQLVMLLEWWSCTECTLFTDQATVERFGKEHAVIILNHNFEIDFLCGWTMCERFGVLGSSKVLAKKELLYVPLIGWTWYFLEIVFCKRKWEEDRDTVVEGLRRLSDYPEYMWFLLYCEGTRFTETKHRVSMEVAAAKGLPVLKYHLLPRTKGFTTAVKCLRGTVAAVYDVTLNFRGNKNPSLLGILYGKKYEADMCVRRFPLEDIPLDEKEAAQWLHKLYQEKDALQEIYNQKGMFPGEQFKPARRPWTLLNFLSWATILLSPLFSFVLGVFASGSPLLILTFLGFVGAASFGVRRLIGVTEIEKGSSYGNQEFKKKE",
    "MDLAGLLKSQFLCHLVFCYVFIASGLIINTIQLFTLLLWPINKQLFRKINCRLSYCISSQLVMLLEWWSGTECTIFTDPRAYLKYGKENAIVVLNHKFEIDFLCGWSLSERFGLLGGSKVLAKKELAYVPIIGWMWYFTEMVFCSRKWEQDRKTVATSLQHLRDYPEKYFFLIHCEGTRFTEKKHEISMQVARAKGLPRLKHHLLPRTKGFAITVRSLRNVVSAVYDCTLNFRNNENPTLLGVLNGKKYHADLYVRRIPLEDIPEDDDECSAWLHKLYQEKDAFQEEYYRTGTFPETPMVPPRRPWTLVNWLFWASLVLYPFFQFLVSMIRSGSSLTLASFILVFFVASVGVRWMIGVTEIDKGSAYGNSDSKQKLND",
    "MALLLCFVLLCGVVDFARSLSITTPEEMIEKAKGETAYLPCKFTLSPEDQGPLDIEWLISPADNQKVDQVIILYSGDKIYDDYYPDLKGRVHFTSNDLKSGDASINVTNLQLSDIGTYQCKVKKAPGVANKKIHLVVLVKPSGARCYVDGSEEIGSDFKIKCEPKEGSLPLQYEWQKLSDSQKMPTSWLAEMTSSVISVKNASSEYSGTYSCTVRNRVGSDQCLLRLNVVPPSNKAGLIAGAIIGTLLALALIGLIIFCCRKKRREEKYEKEVHHDIREDVPPPKSRTSTARSYIGSNHSSLGSMSPSNMEGYSKTQYNQVPSEDFERTPQSPTLPPAKVAAPNLSRMGAIPVMIPAQSKDGSIV",
    "MSYVFVNDSSQTNVPLLQACIDGDFNYSKRLLESGFDPNIRDSRGRTGLHLAAARGNVDICQLLHKFGADLLATDYQGNTALHLCGHVDTIQFLVSNGLKIDICNHQGATPLVLAKRRGVNKDVIRLLESLEEQEVKGFNRGTHSKLETMQTAESESAMESHSLLNPNLQQGEGVLSSFRTTWQEFVEDLGFWRVLLLIFVIALLSLGIAYYVSGVLPFVENQPELVH",
    "MRVAGAAKLVVAVAVFLLTFYVISQVFEIKMDASLGNLFARSALDTAARSTKPPRYKCGISKACPEKHFAFKMASGAANVVGPKICLEDNVLMSGVKNNVGRGINVALANGKTGEVLDTKYFDMWGGDVAPFIEFLKAIQDGTIVLMGTYDDGATKLNDEARRLIADLGSTSITNLGFRDNWVFCGGKGIKTKSPFEQHIKNNKDTNKYEGWPEVVEMEGCIPQKQD",
    "MAPAAATGGSTLPSGFSVFTTLPDLLFIFEFIFGGLVWILVASSLVPWPLVQGWVMFVSVFCFVATTTLIILYIIGAHGGETSWVTLDAAYHCTAALFYLSASVLEALATITMQDGFTYRHYHENIAAVVFSYIATLLYVVHAVFSLIRWKSS",
    "MRLQGAIFVLLPHLGPILVWLFTRDHMSGWCEGPRMLSWCPFYKVLLLVQTAIYSVVGYASYLVWKDLGGGLGWPLALPLGLYAVQLTISWTVLVLFFTVHNPGLALLHLLLLYGLVVSTALIWHPINKLAALLLLPYLAWLTVTSALTYHLWRDSLCPVHQPQPTEKSD",
    "MEESVVRPSVFVVDGQTDIPFTRLGRSHRRQSCSVARVGLGLLLLLMGAGLAVQGWFLLQLHWRLGEMVTRLPDGPAGSWEQLIQERRSHEVNPAAHLTGANSSLTGSGGPLLWETQLGLAFLRGLSYHDGALVVTKAGYYYIYSKVQLGGVGCPLGLASTITHGLYKRTPRYPEELELLVSQQSPCGRATSSSRVWWDSSFLGGVVHLEAGEKVVVRVLDERLVRLRDGTRSYFGAFMV",
]


def compute_average_mlm_loss(protein1, protein2, iterations=10,
                             connector_length=25, mask_prob=0.35):
    """Return the average masked-LM loss for two proteins joined by a G-linker.

    The sequences are concatenated as `protein1 + "G" * connector_length +
    protein2` and tokenized once. In each iteration every residue outside the
    connector is masked independently with probability `mask_prob`; special
    tokens and connector residues are never masked. The loss is computed only
    over the masked positions, using the original residues as targets.

    Args:
        protein1: Amino-acid sequence of the first protein.
        protein2: Amino-acid sequence of the second protein.
        iterations: Number of random maskings to average over (must be >= 1).
        connector_length: Number of 'G' residues placed between the proteins.
        mask_prob: Per-residue masking probability.

    Returns:
        The mean loss over all iterations, as a Python float.

    Raises:
        ValueError: If `iterations` < 1 or there is no residue that may be
            masked (for example, both proteins are empty).
    """
    if iterations < 1:
        raise ValueError("iterations must be at least 1")

    connector = "G" * connector_length  # Connector sequence of G's
    encoded = tokenizer(
        protein1 + connector + protein2,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    )
    # Move everything to the model's device *before* any indexing so that the
    # mask and the indexed tensor always live on the same device.
    original_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Special tokens added by the tokenizer must never be masked.
    special_ids = torch.tensor(tokenizer.all_special_ids, device=device)
    maskable = ~torch.isin(original_ids, special_ids)

    # The tokenizer may prepend a <cls> token, which shifts every residue
    # position by one; account for it when locating the connector.
    has_leading_cls = (
        tokenizer.cls_token_id is not None
        and original_ids.shape[1] > 0
        and original_ids[0, 0].item() == tokenizer.cls_token_id
    )
    start_connector = int(has_leading_cls) + len(
        tokenizer.encode(protein1, add_special_tokens=False)
    )
    end_connector = start_connector + len(
        tokenizer.encode(connector, add_special_tokens=False)
    )
    # Avoid masking the connector 'G's
    maskable[0, start_connector:end_connector] = False

    candidate_positions = maskable[0].nonzero(as_tuple=True)[0]
    if candidate_positions.numel() == 0:
        raise ValueError("no residues available for masking")

    total_loss = 0.0
    for _ in range(iterations):
        random_mask = torch.rand(original_ids.shape, device=device) < mask_prob
        mask_indices = random_mask & maskable

        # With no masked position the loss would be undefined, so force one.
        if not mask_indices.any():
            pick = torch.randint(
                candidate_positions.numel(), (1,), device=device
            )
            mask_indices[0, candidate_positions[pick]] = True

        # Inputs see <mask>; labels keep the original ids at masked positions
        # and -100 (ignored by the loss) everywhere else.
        input_ids = original_ids.masked_fill(mask_indices, tokenizer.mask_token_id)
        labels = original_ids.masked_fill(~mask_indices, -100)

        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
        total_loss += outputs.loss.item()

    return total_loss / iterations


# Every unordered pair of proteins, as index pairs (i, j) with i < j
protein_pairs = [
    (i, j)
    for i in range(len(all_proteins))
    for j in range(i + 1, len(all_proteins))
]

# Compute all average losses once; plot_graph reuses them for every threshold
all_losses = [
    compute_average_mlm_loss(all_proteins[i], all_proteins[j])
    for i, j in protein_pairs
]

# Set the maximum threshold to the maximum loss computed
# (0.0 when there are fewer than two proteins and therefore no pairs)
max_threshold = max(all_losses, default=0.0)
print(f"Maximum loss (maximum threshold for slider): {max_threshold}")


def plot_graph(threshold):
    """Draw a 3D interaction graph with an edge for every pair below `threshold`.

    Every protein is a node labelled "protein N" (1-based). An edge is added
    between two proteins when their precomputed average loss in `all_losses`
    is strictly below `threshold`; the rounded loss is stored as edge weight.

    Args:
        threshold: Loss cut-off below which two proteins are connected.
    """
    G = nx.Graph()

    # Add all protein nodes to the graph
    for i in range(len(all_proteins)):
        G.add_node(f"protein {i+1}")

    # Add an edge for each pair whose loss is below the threshold
    for (i, j), avg_loss in zip(protein_pairs, all_losses):
        if avg_loss < threshold:
            G.add_edge(f"protein {i+1}", f"protein {j+1}", weight=round(avg_loss, 3))

    # 3D Network Plot
    # Adjust the k parameter to bring nodes closer. This might require some
    # experimentation to find the right value.
    k_value = 2  # Lower value will bring nodes closer together
    pos = nx.spring_layout(G, dim=3, seed=42, k=k_value)

    # Each edge contributes (start, end, None); None breaks the line between
    # consecutive edges.
    edge_x = []
    edge_y = []
    edge_z = []
    for source, target in G.edges():
        x0, y0, z0 = pos[source]
        x1, y1, z1 = pos[target]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        edge_z.extend([z0, z1, None])

    edge_trace = go.Scatter3d(x=edge_x, y=edge_y, z=edge_z, mode='lines',
                              line=dict(width=0.5, color='grey'))

    node_text = list(G.nodes())
    node_x = [pos[node][0] for node in node_text]
    node_y = [pos[node][1] for node in node_text]
    node_z = [pos[node][2] for node in node_text]

    node_trace = go.Scatter3d(x=node_x, y=node_y, z=node_z, mode='markers',
                              marker=dict(size=5), hoverinfo='text',
                              hovertext=node_text)

    layout = go.Layout(title='Protein Interaction Graph', title_x=0.5,
                       scene=dict(xaxis=dict(showbackground=False),
                                  yaxis=dict(showbackground=False),
                                  zaxis=dict(showbackground=False)))

    fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
    fig.show()


# Create an interactive slider for the threshold value. The default is the
# median pairwise loss, so it always lies inside the slider's range.
default_threshold = float(np.median(all_losses)) if all_losses else 0.0
interact(plot_graph, threshold=widgets.FloatSlider(min=0.0, max=max_threshold,
                                                   step=0.05,
                                                   value=default_threshold))
```