PepMLM / app.py
TianlaiChen's picture
Update app.py
c6b3e88
raw
history blame
No virus
3.9 kB
import os
import tempfile

import gradio as gr
import numpy as np
import pandas as pd
import torch
from torch.distributions.categorical import Categorical
from transformers import AutoTokenizer, AutoModelForMaskedLM
# The tokenizer and the masked-LM weights are both resolved from the same
# pretrained checkpoint identifier, so keep that identifier in one place.
PEPMLM_CHECKPOINT = "ChatterjeeLab/PepMLM-650M"

tokenizer = AutoTokenizer.from_pretrained(PEPMLM_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(PEPMLM_CHECKPOINT)
def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    """Return the pseudo-perplexity of ``binder_seq`` given ``protein_seq``.

    The two sequences are concatenated and encoded once. Each binder position
    is then masked in turn, the model is asked to recover the original token,
    and the per-position losses are averaged and exponentiated.

    Args:
        model: Masked-LM callable as ``model(input_ids, labels=...)`` whose
            result exposes a scalar ``.loss`` tensor, and which has a
            ``.device`` attribute.
        tokenizer: Tokenizer providing ``encode(..., return_tensors='pt')``
            and ``mask_token_id``.
        protein_seq: Target protein sequence used as context.
        binder_seq: Peptide sequence to score. Must be non-empty.

    Returns:
        ``exp(mean masked-token loss)`` as a NumPy float.

    Raises:
        ValueError: If ``binder_seq`` is empty (the average loss would be
            undefined).
    """
    if not binder_seq:
        # Fail early with a clear message instead of dividing by zero after
        # the (potentially expensive) tokenisation step.
        raise ValueError("binder_seq must be non-empty to compute a pseudo-perplexity")

    num_positions = len(binder_seq)
    sequence = protein_seq + binder_seq
    tensor_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)

    total_loss = 0.0
    # Positions are addressed from the end of the encoded sequence. The range
    # stops before -1, so the final token is never masked (presumably the
    # tokenizer's end-of-sequence marker - confirm for other tokenizers).
    # NOTE: this treats every character of binder_seq as exactly one token.
    for position in range(-num_positions - 1, -1):
        # Mask a single token in a fresh copy of the input.
        masked_input = tensor_input.clone()
        masked_input[0, position] = tokenizer.mask_token_id

        # Label only the masked position; every other entry is -100, the
        # value Hugging Face losses conventionally ignore.
        labels = torch.full_like(tensor_input, -100)
        labels[0, position] = tensor_input[0, position]

        with torch.no_grad():
            outputs = model(masked_input, labels=labels)
        total_loss += outputs.loss.item()

    avg_loss = total_loss / num_positions
    return np.exp(avg_loss)
def generate_peptide(protein_seq, peptide_length, top_k, num_binders):
    """Sample peptide binders for a target protein and score each of them.

    The target sequence is followed by ``peptide_length`` ``<mask>`` tokens,
    the module-level ``model`` predicts logits for those positions, and each
    binder is drawn by independent top-k sampling at every masked position.
    Every binder is then scored with ``compute_pseudo_perplexity``.

    Args:
        protein_seq: Target protein sequence.
        peptide_length: Number of masked positions to generate (cast to int).
        top_k: Number of highest-scoring tokens to sample from per position
            (cast to int).
        num_binders: How many binders to sample (cast to int).

    Returns:
        A tuple ``(rows, csv_path)`` where ``rows`` is a list of
        ``[binder, perplexity]`` pairs and ``csv_path`` is the path of a CSV
        file containing the same rows under the columns ``Binder`` and
        ``Perplexity``.
    """
    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)

    # The prompt (target + fully masked peptide) is the same for every binder,
    # so tokenise and run the forward pass once instead of once per binder.
    # Only the sampling step below differs between binders. This assumes the
    # model is in inference mode (the `from_pretrained` default), i.e. the
    # forward pass is deterministic.
    masked_peptide = '<mask>' * peptide_length
    input_sequence = protein_seq + masked_peptide
    inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    logits_at_masks = logits[0, mask_token_indices]

    # Restrict each masked position to its top-k candidates and renormalise.
    top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
    probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)
    sampler = Categorical(probabilities)

    binders_with_ppl = []
    for _ in range(num_binders):
        # Draw one candidate per masked position, then map the position in
        # the top-k list back to the vocabulary id.
        predicted_indices = sampler.sample()
        predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)
        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')

        # Score the generated binder against the target.
        ppl_value = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)
        binders_with_ppl.append([generated_binder, ppl_value])

    df = pd.DataFrame(binders_with_ppl, columns=["Binder", "Perplexity"])

    # Write into a fresh directory per call so overlapping requests cannot
    # overwrite each other's download; the basename stays "output.csv".
    output_filename = os.path.join(tempfile.mkdtemp(prefix="pepmlm_"), "output.csv")
    df.to_csv(output_filename, index=False)

    return binders_with_ppl, output_filename
# Define the Gradio interface: four inputs feeding generate_peptide, and two
# outputs matching its (rows, csv_path) return value.
interface = gr.Interface(
    fn=generate_peptide,
    inputs=[
        gr.Textbox(label="Protein Sequence", info="Enter protein sequence here", type="text"),
        gr.Slider(3, 50, value=15, label="Peptide Length", step=1, info='Default value is 15'),
        gr.Slider(1, 10, value=3, label="Top K Value", step=1, info='Default value is 3'),
        gr.Dropdown(choices=[1, 2, 4, 8, 16, 32], label="Number of Binders", value=1),
    ],
    outputs=[
        gr.Dataframe(
            headers=["Binder", "Perplexity"],
            datatype=["str", "number"],
            col_count=(2, "fixed"),
        ),
        # Use the top-level component like the other widgets here; the legacy
        # `gr.outputs` namespace is deprecated.
        gr.File(label="Download CSV"),
    ],
    title="PepMLM: Target Sequence-Conditioned Generation of Peptide Binders via Masked Language Modeling",
)

interface.launch()