AmelieSchreiber commited on
Commit
f3ea76e
·
1 Parent(s): c88b154

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, EsmForMaskedLM
3
+ import torch
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import os
7
+
8
+ def generate_heatmap(protein_sequence, start_pos=1, end_pos=None):
9
+ # Load the model and tokenizer
10
+ model_name = "facebook/esm2_t6_8M_UR50D"
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ model = EsmForMaskedLM.from_pretrained(model_name)
13
+
14
+ # Tokenize the input sequence
15
+ input_ids = tokenizer.encode(protein_sequence, return_tensors="pt")
16
+ sequence_length = input_ids.shape[1] - 2 # Excluding the special tokens
17
+
18
+ # Adjust end position if not specified
19
+ if end_pos is None:
20
+ end_pos = sequence_length
21
+
22
+ # List of amino acids
23
+ amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
24
+
25
+ # Initialize heatmap
26
+ heatmap = np.zeros((20, end_pos - start_pos + 1))
27
+
28
+ # Calculate LLRs for each position and amino acid
29
+ for position in range(start_pos, end_pos + 1):
30
+ # Mask the target position
31
+ masked_input_ids = input_ids.clone()
32
+ masked_input_ids[0, position] = tokenizer.mask_token_id
33
+
34
+ # Get logits for the masked token
35
+ with torch.no_grad():
36
+ logits = model(masked_input_ids).logits
37
+
38
+ # Calculate log probabilities
39
+ probabilities = torch.nn.functional.softmax(logits[0, position], dim=0)
40
+ log_probabilities = torch.log(probabilities)
41
+
42
+ # Get the log probability of the wild-type residue
43
+ wt_residue = input_ids[0, position].item()
44
+ log_prob_wt = log_probabilities[wt_residue].item()
45
+
46
+ # Calculate LLR for each variant
47
+ for i, amino_acid in enumerate(amino_acids):
48
+ log_prob_mt = log_probabilities[tokenizer.convert_tokens_to_ids(amino_acid)].item()
49
+ heatmap[i, position - start_pos] = log_prob_mt - log_prob_wt
50
+
51
+ # Visualize the heatmap
52
+ plt.figure(figsize=(15, 5))
53
+ plt.imshow(heatmap, cmap="viridis_r", aspect="auto")
54
+ plt.xticks(range(end_pos - start_pos + 1), list(protein_sequence[start_pos-1:end_pos]))
55
+ plt.yticks(range(20), amino_acids)
56
+ plt.xlabel("Position in Protein Sequence")
57
+ plt.ylabel("Amino Acid Mutations")
58
+ plt.title("Predicted Effects of Mutations on Protein Sequence (LLR)")
59
+ plt.colorbar(label="Log Likelihood Ratio (LLR)")
60
+ plt.show()
61
+
62
+ # Save the plot to a temporary file and return the file path
63
+ temp_file = "temp_heatmap.png"
64
+ plt.savefig(temp_file)
65
+ plt.close()
66
+ return temp_file
67
+
68
+ def heatmap_interface(sequence, start, end):
69
+ # Ensure start and end positions are within bounds
70
+ if start < 1 or end > len(sequence):
71
+ return "Start or end position is out of bounds."
72
+
73
+ # Generate heatmap
74
+ heatmap_path = generate_heatmap(sequence, start, end)
75
+ return heatmap_path
76
+
77
+ # Define the Gradio interface
78
+ iface = gr.Interface(
79
+ fn=heatmap_interface,
80
+ inputs=[
81
+ gr.Textbox(lines=2, placeholder="Enter Protein Sequence Here..."),
82
+ gr.Number(label="Start Position", default=1),
83
+ gr.Number(label="End Position")
84
+ ],
85
+ outputs="image",
86
+ live=True
87
+ )
88
+
89
+ # Run the Gradio app
90
+ iface.launch()