# Last commit: 00d2cb3 "Update app.py" by AmelieSchreiber
import functools
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoTokenizer, EsmForMaskedLM
# ESM-2 checkpoint used to score mutations unless the caller picks another.
_DEFAULT_MODEL_NAME = "facebook/esm2_t6_8M_UR50D"


@functools.lru_cache(maxsize=1)
def _load_model(model_name):
    """Load (once) and return the ``(tokenizer, model)`` pair for *model_name*.

    Cached because the Gradio interface in this file is ``live`` and would
    otherwise re-read the weights from disk on every input change.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EsmForMaskedLM.from_pretrained(model_name)
    return tokenizer, model


def _compute_llr_matrix(tokenizer, model, input_ids, start_pos, end_pos, amino_acids):
    """Return an array of shape ``(len(amino_acids), end_pos - start_pos + 1)``.

    Entry ``[i, j]`` is ``log P(amino_acids[i]) - log P(wild type)`` at the
    1-based residue ``start_pos + j``, scored with that residue masked.
    Assumes the tokenizer prepends exactly one special token, so residue k
    sits at ``input_ids[0, k]`` (the original code relied on this too).
    """
    # Look the candidate token ids up once instead of per position.
    aa_ids = torch.tensor(tokenizer.convert_tokens_to_ids(amino_acids))
    llr = np.zeros((len(amino_acids), end_pos - start_pos + 1))
    for col, position in enumerate(range(start_pos, end_pos + 1)):
        masked_input_ids = input_ids.clone()
        masked_input_ids[0, position] = tokenizer.mask_token_id
        with torch.no_grad():
            logits = model(masked_input_ids).logits
            # log_softmax is the numerically stable form of log(softmax(x)).
            log_probs = torch.log_softmax(logits[0, position], dim=-1)
        wt_id = int(input_ids[0, position])
        llr[:, col] = (log_probs[aa_ids] - log_probs[wt_id]).numpy()
    return llr


def _render_heatmap(llr, x_labels, amino_acids):
    """Plot *llr* as a heatmap, save it to a new PNG file and return the path."""
    fig, ax = plt.subplots(figsize=(15, 5))
    try:
        image = ax.imshow(llr, cmap="viridis_r", aspect="auto")
        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels)
        ax.set_yticks(range(len(amino_acids)))
        ax.set_yticklabels(amino_acids)
        ax.set_xlabel("Position in Protein Sequence")
        ax.set_ylabel("Amino Acid Mutations")
        ax.set_title("Predicted Effects of Mutations on Protein Sequence (LLR)")
        fig.colorbar(image, ax=ax, label="Log Likelihood Ratio (LLR)")
        # A unique file per call so concurrent requests never overwrite each
        # other's image (a fixed filename would be shared by all users).
        fd, path = tempfile.mkstemp(prefix="heatmap_", suffix=".png")
        os.close(fd)
        # Save before anything could clear the figure (the old code called
        # plt.show() first, which can leave an empty canvas to save).
        fig.savefig(path)
    finally:
        # Always release the figure so pyplot state does not leak between calls.
        plt.close(fig)
    return path


def generate_heatmap(protein_sequence, start_pos=1, end_pos=None,
                     model_name=_DEFAULT_MODEL_NAME):
    """Score every single-residue substitution in a window and plot the result.

    Each position in ``[start_pos, end_pos]`` is masked in turn. The ESM
    masked-LM log-probability of each of the 20 amino acids in
    ``ACDEFGHIKLMNPQRSTVWY`` is then compared with that of the wild-type
    residue: ``LLR = log P(mutant) - log P(wild type)``.

    Args:
        protein_sequence: amino-acid sequence as one-letter codes.
        start_pos: 1-based first position to score.
        end_pos: 1-based last position to score (inclusive). ``None`` means
            the end of the sequence; values past the end are clamped to it.
        model_name: Hugging Face checkpoint to score with.

    Returns:
        Path to a newly created PNG file containing the heatmap.

    Raises:
        ValueError: if the window starts before position 1 or after end_pos.
    """
    tokenizer, model = _load_model(model_name)
    input_ids = tokenizer.encode(protein_sequence, return_tensors="pt")
    sequence_length = input_ids.shape[1] - 2  # exclude the two special tokens

    start_pos = int(start_pos)
    # Clamp so the EOS token (index sequence_length + 1) is never scored as a residue.
    end_pos = sequence_length if end_pos is None else min(int(end_pos), sequence_length)
    if not 1 <= start_pos <= end_pos:
        raise ValueError(
            f"Invalid window: start_pos={start_pos}, end_pos={end_pos}, "
            f"sequence length={sequence_length}"
        )

    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
    llr = _compute_llr_matrix(tokenizer, model, input_ids, start_pos, end_pos, amino_acids)
    # Label columns with the tokens actually scored, so the label count always
    # matches the heatmap width.
    x_labels = tokenizer.convert_ids_to_tokens(input_ids[0, start_pos:end_pos + 1].tolist())
    return _render_heatmap(llr, x_labels, amino_acids)
def heatmap_interface(sequence, start, end=None):
    """Gradio callback: normalise and validate UI inputs, then render the heatmap.

    Args:
        sequence: protein sequence from the textbox. All whitespace (e.g.
            line breaks from the multi-line box) is removed and letters are
            upper-cased. Presumably lower-case letters would not match the
            model's upper-case residue tokens.
        start: 1-based first position to score; an empty field (``None``)
            means 1.
        end: 1-based last position to score. Empty, non-positive or
            past-the-end values mean "to the end of the sequence".

    Returns:
        Path to the rendered PNG, or ``None`` while the sequence box is empty,
        so the live interface shows no image instead of an error.

    Raises:
        gr.Error: if start lies outside the sequence or after end. Gradio
            shows the message to the user; returning a string would instead
            be misread by the image output as a file path.
    """
    sequence = "".join(sequence.split()).upper() if sequence else ""
    if not sequence:
        return None

    start = 1 if start is None else int(start)
    end = None if end is None else int(end)
    # If end is missing, non-positive or past the end, score to the end.
    if end is None or end <= 0 or end > len(sequence):
        end = len(sequence)

    if start < 1 or start > len(sequence):
        raise gr.Error("Start position is out of bounds.")
    if start > end:
        raise gr.Error("Start position must not be greater than end position.")

    return generate_heatmap(sequence, start, end)
# Gradio UI: a protein sequence plus optional 1-based start/end positions go
# in, and the mutation-effect heatmap image comes out.
# With live=True, Gradio reruns `fn` automatically as the inputs change.
iface = gr.Interface(
    fn=heatmap_interface,
    inputs=[
        gr.Textbox(placeholder="Enter Protein Sequence Here...", lines=2),
        gr.Number(value=1, label="Start Position"),
        # No default: an empty field reaches the callback as "score to the end".
        gr.Number(label="End Position"),
    ],
    outputs="image",
    live=True,
)

# Start the web app.
iface.launch()