---
license: mit
datasets:
- AmelieSchreiber/interaction_pairs
language:
- en
library_name: transformers
tags:
- ESM-2
- biology
- protein language model
---

# ESM-2 for Interacting Proteins

This model was fine-tuned on concatenated pairs of interacting proteins in much the same way as [PepMLM](https://huggingface.co/spaces/TianlaiChen/PepMLM). It is meant to generate interaction partners for proteins using the masked language modeling capabilities of ESM-2. The model has not been thoroughly tested, so use it with caution. This is just a preliminary experiment.

## Using the Model

To use the model, try running:

```python
from transformers import AutoTokenizer, EsmForMaskedLM
import torch
import pandas as pd
import numpy as np
from torch.distributions import Categorical


def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    """Return the exponentiated masked-LM loss of `binder_seq` given `protein_seq`.

    The two sequences are concatenated, every binder position is replaced by
    the mask token at the same time, and the model's loss on those positions
    is exponentiated. Lower values mean the model finds the binder more
    plausible. Note that, unlike classic pseudo-perplexity, the binder
    residues are masked jointly rather than one at a time.

    Raises:
        ValueError: if `binder_seq` is empty (there would be nothing to score
            and the loss would be NaN).
    """
    if not binder_seq:
        raise ValueError("binder_seq must be non-empty")

    sequence = protein_seq + binder_seq
    tensor_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)

    # The binder is the last len(binder_seq) tokens before the trailing
    # end-of-sequence token (this assumes one token per residue).
    binder_mask = torch.zeros(tensor_input.shape, dtype=torch.bool, device=model.device)
    binder_mask[0, -len(binder_seq) - 1:-1] = True

    # Hide the binder from the model and score only the binder positions
    # (label -100 is ignored by the loss).
    masked_input = tensor_input.clone().masked_fill_(binder_mask, tokenizer.mask_token_id)
    labels = tensor_input.clone().masked_fill_(~binder_mask, -100)

    with torch.no_grad():
        loss = model(masked_input, labels=labels).loss
    return np.exp(loss.item())


def generate_peptide_for_single_sequence(protein_seq, peptide_length=15, top_k=3, num_binders=4):
    """Sample `num_binders` candidate binders for `protein_seq`.

    Uses the module-level `model` and `tokenizer` loaded below. The target is
    followed by `peptide_length` mask tokens, and every masked position is
    sampled independently from the model's top-k predictions.

    Returns:
        A list of [binder, pseudo_perplexity] pairs.

    Raises:
        ValueError: if `peptide_length` or `top_k` is less than 1.
    """
    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)
    if peptide_length < 1:
        raise ValueError("peptide_length must be at least 1")
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    # Use the tokenizer's own mask-token string so the prompt always contains
    # real mask tokens.
    masked_peptide = tokenizer.mask_token * peptide_length
    input_sequence = protein_seq + masked_peptide
    inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)

    # The prompt is identical for every binder, so run the model only once.
    with torch.no_grad():
        logits = model(**inputs).logits
    mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    logits_at_masks = logits[0, mask_token_indices]

    # Never sample special tokens (<cls>, <eos>, <pad>, <mask>, ...). They would
    # be dropped on decoding and silently shorten the binder.
    logits_at_masks[:, tokenizer.all_special_ids] = float('-inf')

    # Top-k sampling distribution, one row per masked position.
    top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
    probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)
    sampler = Categorical(probabilities)

    binders_with_ppl = []
    for _ in range(num_binders):
        predicted_indices = sampler.sample()
        predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)
        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')

        # Score the generated binder against the target.
        ppl_value = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)
        binders_with_ppl.append([generated_binder, ppl_value])
    return binders_with_ppl


def generate_peptide(input_seqs, peptide_length=15, top_k=3, num_binders=4):
    """Generate binders for one target sequence or for several.

    Args:
        input_seqs: a single protein sequence (str) or a list/tuple of them.

    Returns:
        A DataFrame with columns ['Binder', 'Pseudo Perplexity'] for a single
        sequence, or ['Input Sequence', 'Binder', 'Pseudo Perplexity'] for a
        list/tuple of sequences.

    Raises:
        TypeError: if `input_seqs` is neither a string nor a list/tuple.
    """
    if isinstance(input_seqs, str):  # Single sequence
        binders = generate_peptide_for_single_sequence(input_seqs, peptide_length, top_k, num_binders)
        return pd.DataFrame(binders, columns=['Binder', 'Pseudo Perplexity'])
    if isinstance(input_seqs, (list, tuple)):  # Several sequences
        results = []
        for seq in input_seqs:
            binders = generate_peptide_for_single_sequence(seq, peptide_length, top_k, num_binders)
            for binder, ppl in binders:
                results.append([seq, binder, ppl])
        return pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'Pseudo Perplexity'])
    raise TypeError(f"input_seqs must be a str or a list of str, got {type(input_seqs).__name__}")


model = EsmForMaskedLM.from_pretrained("AmelieSchreiber/esm_interact")
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")

protein_seq = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
# Sampling is random; call torch.manual_seed(...) first for reproducible output.
results_df = generate_peptide(protein_seq, peptide_length=15, top_k=3, num_binders=5)
print(results_df)
```