Spaces:
Runtime error
Runtime error
import functools
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoTokenizer, EsmForMaskedLM
# ESM-2 masked-language-model checkpoint used for scoring.
MODEL_NAME = "facebook/esm2_t6_8M_UR50D"

# Substitution alphabet, in heatmap row order (one row per residue letter).
AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY")


@functools.lru_cache(maxsize=2)
def _load_model(model_name):
    """Load the tokenizer and masked-LM for ``model_name`` once and reuse them.

    Loading a checkpoint is far more expensive than scoring a sequence, so it
    must not happen on every request.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EsmForMaskedLM.from_pretrained(model_name)
    model.eval()  # make inference mode explicit (no dropout while scoring)
    return tokenizer, model


def _compute_llr_matrix(tokenizer, model, input_ids, start_pos, end_pos):
    """Return an array of shape (len(AMINO_ACIDS), end_pos - start_pos + 1).

    Entry ``[i, j]`` is ``log P(AMINO_ACIDS[i]) - log P(wild type)`` at
    residue ``start_pos + j``, with that residue replaced by the mask token.
    A 1-based residue position is used directly as the token index because
    the tokenizer is expected to prepend one special token (the caller
    subtracts two special tokens from the encoded length).
    """
    aa_ids = torch.tensor(tokenizer.convert_tokens_to_ids(AMINO_ACIDS))
    heatmap = np.zeros((len(AMINO_ACIDS), end_pos - start_pos + 1))
    with torch.no_grad():
        for col, position in enumerate(range(start_pos, end_pos + 1)):
            masked_ids = input_ids.clone()
            masked_ids[0, position] = tokenizer.mask_token_id
            logits = model(masked_ids).logits
            # log_softmax is numerically stabler than log(softmax(...)).
            log_probs = torch.log_softmax(logits[0, position], dim=-1)
            wt_id = int(input_ids[0, position])
            heatmap[:, col] = (log_probs[aa_ids] - log_probs[wt_id]).cpu().numpy()
    return heatmap


def _plot_heatmap(heatmap, residues):
    """Render ``heatmap`` with ``residues`` as x tick labels; return a PNG path.

    Uses an explicit Figure (not global pyplot state) and a unique temporary
    file per call, so concurrent requests cannot overwrite each other's image.
    NOTE: nothing here deletes the file; it is left for the caller to serve.
    """
    fig, ax = plt.subplots(figsize=(15, 5))
    try:
        image = ax.imshow(heatmap, cmap="viridis_r", aspect="auto")
        ax.set_xticks(range(len(residues)))
        ax.set_xticklabels(list(residues))
        ax.set_yticks(range(len(AMINO_ACIDS)))
        ax.set_yticklabels(AMINO_ACIDS)
        ax.set_xlabel("Position in Protein Sequence")
        ax.set_ylabel("Amino Acid Mutations")
        ax.set_title("Predicted Effects of Mutations on Protein Sequence (LLR)")
        fig.colorbar(image, ax=ax, label="Log Likelihood Ratio (LLR)")
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
            fig.savefig(handle, format="png")
            path = handle.name
    finally:
        plt.close(fig)  # always release the figure, even if rendering fails
    return path


def generate_heatmap(protein_sequence, start_pos=1, end_pos=None, model_name=MODEL_NAME):
    """Plot predicted substitution effects (LLRs) for a protein sequence.

    For each position in ``start_pos..end_pos`` (1-based, inclusive), the
    residue is masked. Every amino acid is then scored by its log-probability
    relative to the wild-type residue.

    Args:
        protein_sequence: Amino-acid sequence. Whitespace is removed and
            letters are upper-cased before tokenization.
        start_pos: First residue to score (1-based).
        end_pos: Last residue to score (inclusive). ``None`` scores through
            the end of the sequence.
        model_name: Checkpoint to load with ``EsmForMaskedLM``.

    Returns:
        Path to a PNG file containing the heatmap.

    Raises:
        ValueError: If the requested range is empty or outside the sequence.
    """
    protein_sequence = "".join(protein_sequence.split()).upper()
    tokenizer, model = _load_model(model_name)

    input_ids = tokenizer.encode(protein_sequence, return_tensors="pt")
    sequence_length = input_ids.shape[1] - 2  # exclude the two special tokens
    if end_pos is None:
        end_pos = sequence_length
    # Without this check, end_pos == sequence_length + 1 would silently mask
    # the trailing special token and plot meaningless values.
    if not 1 <= start_pos <= end_pos <= sequence_length:
        raise ValueError(
            f"Invalid range {start_pos}-{end_pos} for a sequence of "
            f"{sequence_length} residues."
        )

    heatmap = _compute_llr_matrix(tokenizer, model, input_ids, start_pos, end_pos)
    return _plot_heatmap(heatmap, protein_sequence[start_pos - 1:end_pos])
def heatmap_interface(sequence, start, end=None):
    """Gradio callback: validate the form inputs and render the LLR heatmap.

    Args:
        sequence: Protein sequence text. Whitespace is removed and letters
            are upper-cased, so positions refer to residues only.
        start: 1-based first position to score. ``None`` (cleared field) is
            treated as 1.
        end: 1-based last position, inclusive. ``None``, non-positive, or
            past-the-end values mean "through the end of the sequence".

    Returns:
        Path to the rendered PNG, or ``None`` when the sequence is empty
        (nothing to plot).

    Raises:
        gr.Error: If ``start`` lies outside the sequence or after ``end``.
            An error string must not be returned here: the output component
            is an image, so a returned string would be treated as a file path.
    """
    sequence = "".join((sequence or "").split()).upper()
    if not sequence:
        return None

    start = 1 if start is None else int(start)
    end = None if end is None else int(end)

    # Missing, non-positive, or too-large end means "through the end".
    if end is None or end <= 0 or end > len(sequence):
        end = len(sequence)

    if start < 1 or start > len(sequence):
        raise gr.Error("Start position is out of bounds.")
    if start > end:
        raise gr.Error("Start position must not be greater than end position.")

    return generate_heatmap(sequence, start, end)
# Define the Gradio interface.
# ``live`` is off: each run performs one model forward pass per scored
# position, so re-running on every keystroke would queue redundant work.
iface = gr.Interface(
    fn=heatmap_interface,
    inputs=[
        gr.Textbox(
            lines=2,
            label="Protein Sequence",
            placeholder="Enter Protein Sequence Here...",
        ),
        # precision=0 makes the numbers arrive as integers.
        gr.Number(label="Start Position", value=1, precision=0),
        # Leave empty (or <= 0) to score through the end of the sequence.
        gr.Number(label="End Position", precision=0),
    ],
    outputs="image",
    live=False,
)

# Run the Gradio app only when executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()