Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import AutoTokenizer, EsmForMaskedLM | |
import torch | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import os | |
# Loaded (tokenizer, model) pairs keyed by checkpoint name, so the weights are
# read once per process instead of on every request.
_ESM_CACHE = {}


def _load_esm(model_name):
    """Return the (tokenizer, model) pair for *model_name*, loading it on first use."""
    if model_name not in _ESM_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = EsmForMaskedLM.from_pretrained(model_name)
        # Inference only: make sure dropout and similar layers are disabled.
        model.eval()
        _ESM_CACHE[model_name] = (tokenizer, model)
    return _ESM_CACHE[model_name]


def generate_heatmap(protein_sequence, start_pos=1, end_pos=None):
    """Render a mutation-effect heatmap for a window of a protein sequence.

    Each position in ``start_pos..end_pos`` is masked in turn, the ESM-2
    masked language model is run, and the log-likelihood ratio (LLR) of each
    of the 20 standard amino acids against the wild-type residue at that
    position is recorded.

    Args:
        protein_sequence: Amino-acid sequence as a string of one-letter codes.
        start_pos: 1-based index of the first residue to scan (inclusive).
        end_pos: 1-based index of the last residue to scan (inclusive).
            ``None`` means "up to the end of the sequence".

    Returns:
        Path of the PNG file the heatmap was written to.

    Raises:
        ValueError: If the requested window is empty or falls outside the
            tokenized sequence.
    """
    tokenizer, model = _load_esm("facebook/esm2_t6_8M_UR50D")

    # Tokenize the input sequence; the encoder adds one special token at
    # each end, so token index i corresponds to 1-based residue i.
    input_ids = tokenizer.encode(protein_sequence, return_tensors="pt")
    sequence_length = input_ids.shape[1] - 2  # Excluding the special tokens

    if end_pos is None:
        end_pos = sequence_length

    # Reject windows that would mask a special token or index past the end.
    if not 1 <= start_pos <= end_pos <= sequence_length:
        raise ValueError(
            f"Invalid position range {start_pos}-{end_pos} for a sequence "
            f"of {sequence_length} residues."
        )

    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
    # The vocabulary lookup does not depend on the position: do it once.
    aa_token_ids = tokenizer.convert_tokens_to_ids(amino_acids)

    window_size = end_pos - start_pos + 1
    heatmap = np.zeros((len(amino_acids), window_size))

    for column, position in enumerate(range(start_pos, end_pos + 1)):
        # Mask the target position
        masked_input_ids = input_ids.clone()
        masked_input_ids[0, position] = tokenizer.mask_token_id

        with torch.no_grad():
            logits = model(masked_input_ids).logits
            # log_softmax is the numerically stable form of log(softmax(x)).
            log_probabilities = torch.log_softmax(logits[0, position], dim=0)

            # LLR of every candidate residue relative to the wild type.
            wt_token_id = input_ids[0, position].item()
            llr = log_probabilities[aa_token_ids] - log_probabilities[wt_token_id]

        heatmap[:, column] = llr.numpy()

    # Draw on an explicit figure instead of pyplot's implicit "current
    # figure", and do not call plt.show(): showing before saving can block
    # or leave an empty canvas to be written, depending on the backend.
    fig, ax = plt.subplots(figsize=(15, 5))
    # NOTE: the output path is fixed, so concurrent calls overwrite each
    # other's image.
    temp_file = "temp_heatmap.png"
    try:
        image = ax.imshow(heatmap, cmap="viridis_r", aspect="auto")
        ax.set_xticks(list(range(window_size)))
        ax.set_xticklabels(list(protein_sequence[start_pos - 1:end_pos]))
        ax.set_yticks(list(range(len(amino_acids))))
        ax.set_yticklabels(amino_acids)
        ax.set_xlabel("Position in Protein Sequence")
        ax.set_ylabel("Amino Acid Mutations")
        ax.set_title("Predicted Effects of Mutations on Protein Sequence (LLR)")
        fig.colorbar(image, ax=ax, label="Log Likelihood Ratio (LLR)")
        fig.savefig(temp_file)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    return temp_file
def heatmap_interface(sequence, start, end=None):
    """Validate the UI inputs and produce the heatmap image for Gradio.

    Args:
        sequence: Protein sequence typed by the user; whitespace is ignored.
        start: 1-based start position (numeric; truncated to an integer).
            A missing value is treated as 1.
        end: 1-based end position (numeric; truncated to an integer). A
            missing, non-positive or too-large value means "end of sequence".

    Returns:
        Path of the generated heatmap image, or ``None`` when no sequence
        has been entered yet (which leaves the image output empty).

    Raises:
        gr.Error: If the start position lies outside the sequence or the end
            position precedes the start position.
    """
    # A multi-line textbox can deliver spaces and newlines; they are not
    # residues and would shift every position.
    sequence = "".join((sequence or "").split())
    if not sequence:
        return None

    # A cleared number field arrives as None.
    start = 1 if start is None else int(start)
    if end is not None:
        end = int(end)

    # If end is missing or unusable, fall back to the sequence length
    if end is None or end > len(sequence) or end <= 0:
        end = len(sequence)

    # The output component is an image, so a returned message string would be
    # treated as a file path; raise gr.Error so the message is shown instead.
    if start < 1 or start > len(sequence):
        raise gr.Error("Start position is out of bounds.")
    if end < start:
        raise gr.Error("End position must not be before the start position.")

    return generate_heatmap(sequence, start, end)
# Input widgets, in the order heatmap_interface receives them:
# sequence text, start position, end position.
input_components = [
    gr.Textbox(placeholder="Enter Protein Sequence Here...", lines=2),
    gr.Number(value=1, label="Start Position"),
    # Left without a default; heatmap_interface falls back to the sequence
    # length when no end position is supplied.
    gr.Number(label="End Position"),
]

# Wire the callback to the widgets; live mode re-runs it on input changes.
iface = gr.Interface(
    fn=heatmap_interface,
    inputs=input_components,
    outputs="image",
    live=True,
)

# Start serving the app
iface.launch()