Massimo G. Totaro commited on
Commit
ddc1bd3
·
1 Parent(s): fba8f5e

QOL and gradio upgrade

Browse files
Files changed (7) hide show
  1. .gitignore +2 -1
  2. README.md +1 -1
  3. app.py +14 -26
  4. data.py +22 -31
  5. instructions.md +58 -36
  6. model.py +11 -7
  7. requirements.txt +1 -0
.gitignore CHANGED
@@ -1,3 +1,4 @@
1
  Dockerfile
2
  *.ipynb
3
- */
 
 
1
  Dockerfile
2
  *.ipynb
3
+ out.*
4
+ */
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 📈
4
  colorFrom: gray
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.8.0
8
  app_file: app.py
9
  pinned: false
10
  license: bsd-2-clause
 
4
  colorFrom: gray
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 5.5.0
8
  app_file: app.py
9
  pinned: false
10
  license: bsd-2-clause
app.py CHANGED
@@ -1,5 +1,4 @@
1
- from tempfile import NamedTemporaryFile
2
- from gradio import Blocks, Button, Checkbox, Dropdown, Examples, File, HTML, Markdown, Textbox
3
 
4
  from model import get_models
5
  from data import Data
@@ -17,19 +16,14 @@ def app(*argv):
17
  # Unpack the arguments
18
  seq, trg, model_name, *_ = argv
19
  scoring = SCORING[scoring_strategy.value]
20
- try:
21
- # Calculate the data based on the input parameters
22
- data = Data(seq, trg, model_name, scoring, out_file).calculate()
23
- except Exception as e:
24
- # If an error occurs, return an HTML error message
25
- return f'<!DOCTYPE html><html><body><h1 style="background-color:#F70D1A;text-align:center;">Error: {str(e)}</h1></body></html>', None
26
  # If no error occurs, return the calculated data
27
- return repr(data), File(value=out_file.name, visible=True)
28
 
29
  # Create the Gradio interface
30
- with open("instructions.md", "r", encoding="utf-8") as md,\
31
- NamedTemporaryFile(mode='w+') as out_file,\
32
- Blocks() as esm_scan:
33
 
34
  # Define the interface components
35
  Markdown(md.read())
@@ -46,20 +40,14 @@ with open("instructions.md", "r", encoding="utf-8") as md,\
46
  value=""
47
  )
48
  model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
49
- scoring_strategy = Checkbox(value=True, label="Use masked-marginals scoring")
50
- btn = Button(value="Run")
51
- out = HTML()
52
- bto = File(
53
- value=out_file.name,
54
- visible=False,
55
- label="Download",
56
- file_count='single',
57
- interactive=False
58
- )
59
  btn.click(
60
  fn=app,
61
  inputs=[seq, trg, model_name],
62
- outputs=[out, bto]
63
  )
64
  ex = Examples(
65
  examples=[
@@ -87,9 +75,9 @@ with open("instructions.md", "r", encoding="utf-8") as md,\
87
  inputs=[seq,
88
  trg,
89
  model_name],
90
- outputs=[out,
91
- bto],
92
- fn=app
93
  )
94
 
95
  # Launch the Gradio interface
 
1
+ from gradio import Blocks, Button, Checkbox, DownloadButton, Dropdown, Examples, File, Image, Markdown, Textbox
 
2
 
3
  from model import get_models
4
  from data import Data
 
16
  # Unpack the arguments
17
  seq, trg, model_name, *_ = argv
18
  scoring = SCORING[scoring_strategy.value]
19
+ # Calculate the data based on the input parameters
20
+ data = Data(seq, trg, model_name, scoring).calculate()
21
+
 
 
 
22
  # If no error occurs, return the calculated data
23
+ return Image(value=data.image(), type='filepath', visible=True), DownloadButton(value=data.csv(), visible=True)
24
 
25
  # Create the Gradio interface
26
+ with open("instructions.md", "r", encoding="utf-8") as md, Blocks() as esm_scan:
 
 
27
 
28
  # Define the interface components
29
  Markdown(md.read())
 
40
  value=""
41
  )
42
  model_name = Dropdown(MODELS, label="Model", value="facebook/esm2_t30_150M_UR50D")
43
+ scoring_strategy = Checkbox(value=True, label="Use higher accuracy scoring", interactive=True)
44
+ dlb = DownloadButton(label="Download raw data", visible=False)
45
+ out = Image(visible=False)
46
+ btn = Button(value="Run", variant="primary")
 
 
 
 
 
 
47
  btn.click(
48
  fn=app,
49
  inputs=[seq, trg, model_name],
50
+ outputs=[out, dlb]
51
  )
52
  ex = Examples(
53
  examples=[
 
75
  inputs=[seq,
76
  trg,
77
  model_name],
78
+ outputs=[out],
79
+ fn=app,
80
+ cache_examples=False
81
  )
82
 
83
  # Launch the Gradio interface
data.py CHANGED
@@ -1,12 +1,8 @@
 
1
  from math import ceil
2
- from re import match
3
- import seaborn as sns
4
-
5
- from model import Model
6
-
7
-
8
  import matplotlib.pyplot as plt
9
  import pandas as pd
 
10
  import seaborn as sns
11
 
12
  from model import Model
@@ -26,19 +22,18 @@ class Data:
26
  """Parse input substitutions"""
27
  self.mode = None
28
  self.sub = list()
29
- self.trg = trg.strip().upper()
30
  self.resi = list()
31
 
32
  # Identify running mode
33
- if len(self.trg.split()) == 1 and len(self.trg.split()[0]) == len(self.seq) and all(match(r'\w+', x) for x in self.trg):
34
  # If single string of same length as sequence, seq vs seq mode
35
  self.mode = 'MUT'
36
- for resi, (src, trg) in enumerate(zip(self.seq, self.trg), 1):
37
  if src != trg:
38
  self.sub.append(f"{src}{resi}{trg}")
39
  self.resi.append(resi)
40
  else:
41
- self.trg = self.trg.split()
42
  if all(match(r'\d+', x) for x in self.trg):
43
  # If all strings are numbers, deep mutational scanning mode
44
  self.mode = 'DMS'
@@ -64,7 +59,7 @@ class Data:
64
 
65
  self.sub = pd.DataFrame(self.sub, columns=['0'])
66
 
67
- def __init__(self, src:str, trg:str, model_name:str='facebook/esm2_t33_650M_UR50D', scoring_strategy:str='masked-marginals', out_file=None):
68
  "initialise data"
69
  # if model has changed, load new model
70
  if self.model.model_name != model_name:
@@ -76,13 +71,14 @@ class Data:
76
  self.scoring_strategy = scoring_strategy
77
  self.token_probs = None
78
  self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
79
- self.out_str = None
80
- self.out_buffer = out_file.name if 'name' in dir(out_file) else out_file
81
 
82
  def parse_output(self) -> None:
83
  "format output data for visualisation"
84
  if self.mode == 'TMS':
85
  self.process_tms_mode()
 
86
  else:
87
  if self.mode == 'DMS':
88
  self.sort_by_residue_and_score()
@@ -90,14 +86,12 @@ class Data:
90
  self.sort_by_score()
91
  else:
92
  raise RuntimeError(f"Unrecognised mode {self.mode}")
93
- if self.out_buffer:
94
- self.out.round(2).to_csv(self.out_buffer, index=False, header=False)
95
- self.out_str = (self.out.style
96
- .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
97
- .hide(axis=0)
98
- .hide(axis=1)
99
- .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)
100
- .to_html(justify='center'))
101
 
102
  def sort_by_score(self):
103
  self.out = self.out.sort_values(self.model_name, ascending=False)
@@ -155,10 +149,7 @@ class Data:
155
  else:
156
  self.plot_multiple_heatmaps(ncols, nrows)
157
 
158
- if self.out_buffer:
159
- plt.savefig(self.out_buffer, format='svg')
160
- with open(self.out_buffer, 'r', encoding='utf-8') as f:
161
- self.out_str = f.read()
162
 
163
  def plot_single_heatmap(self):
164
  fig = plt.figure(figsize=(12, 6))
@@ -200,10 +191,10 @@ class Data:
200
  self.parse_output()
201
  return self
202
 
203
- def __str__(self):
204
- "return output data in DataFrame format"
205
- return str(self.out)
206
 
207
- def __repr__(self):
208
- "return output data in html format"
209
- return self.out_str
 
1
+ import dataframe_image as dfi
2
  from math import ceil
 
 
 
 
 
 
3
  import matplotlib.pyplot as plt
4
  import pandas as pd
5
+ from re import match
6
  import seaborn as sns
7
 
8
  from model import Model
 
22
  """Parse input substitutions"""
23
  self.mode = None
24
  self.sub = list()
25
+ self.trg = trg.strip().upper().split()
26
  self.resi = list()
27
 
28
  # Identify running mode
29
+ if len(self.trg) == 1 and len(self.trg[0]) == len(self.seq) and match(r'^\w+$', self.trg[0]):
30
  # If single string of same length as sequence, seq vs seq mode
31
  self.mode = 'MUT'
32
+ for resi, (src, trg) in enumerate(zip(self.seq, self.trg[0]), 1):
33
  if src != trg:
34
  self.sub.append(f"{src}{resi}{trg}")
35
  self.resi.append(resi)
36
  else:
 
37
  if all(match(r'\d+', x) for x in self.trg):
38
  # If all strings are numbers, deep mutational scanning mode
39
  self.mode = 'DMS'
 
59
 
60
  self.sub = pd.DataFrame(self.sub, columns=['0'])
61
 
62
+ def __init__(self, src:str, trg:str, model_name:str='facebook/esm2_t33_650M_UR50D', scoring_strategy:str='masked-marginals', out_file='out'):
63
  "initialise data"
64
  # if model has changed, load new model
65
  if self.model.model_name != model_name:
 
71
  self.scoring_strategy = scoring_strategy
72
  self.token_probs = None
73
  self.out = pd.DataFrame(self.sub, columns=['0', self.model_name])
74
+ self.out_img = f'{out_file}.png'
75
+ self.out_csv = f'{out_file}.csv'
76
 
77
  def parse_output(self) -> None:
78
  "format output data for visualisation"
79
  if self.mode == 'TMS':
80
  self.process_tms_mode()
81
+ self.out.to_csv(self.out_csv, float_format='%.2f')
82
  else:
83
  if self.mode == 'DMS':
84
  self.sort_by_residue_and_score()
 
86
  self.sort_by_score()
87
  else:
88
  raise RuntimeError(f"Unrecognised mode {self.mode}")
89
+ out_df = (self.out.style
90
+ .format(lambda x: f'{x:.2f}' if isinstance(x, float) else x)
91
+ .hide(axis=0).hide(axis=1)
92
+ .background_gradient(cmap="RdYlGn", vmax=8, vmin=-8))
93
+ dfi.export(out_df, self.out_img, max_rows=-1, max_cols=-1, dpi=300)
94
+ self.out.to_csv(self.out_csv, float_format='%.2f', index=False, header=False)
 
 
95
 
96
  def sort_by_score(self):
97
  self.out = self.out.sort_values(self.model_name, ascending=False)
 
149
  else:
150
  self.plot_multiple_heatmaps(ncols, nrows)
151
 
152
+ plt.savefig(self.out_img, format='png', dpi=300)
 
 
 
153
 
154
  def plot_single_heatmap(self):
155
  fig = plt.figure(figsize=(12, 6))
 
191
  self.parse_output()
192
  return self
193
 
194
+ def csv(self):
195
+ "return output data"
196
+ return self.out_csv
197
 
198
+ def image(self):
199
+ "return output data"
200
+ return self.out_img
instructions.md CHANGED
@@ -1,39 +1,61 @@
1
  # **ESM-Scan**
 
2
  Calculate the <u>fitness of single amino acid substitutions</u> on proteins, using a [zero-shot](https://doi.org/10.1101/2021.07.09.450648) [language model predictor](https://github.com/facebookresearch/esm)
3
 
4
- <details>
5
- <summary> <b> USAGE INSTRUCTIONS </b> </summary>
6
-
7
- ### **Setup**
8
- No setup is required, just fill the input boxes with the required data and click on the `Run` button.
9
- A list of examples can be found at the bottom of the page, click on them to autofill the fields.
10
- If the server is not used for some time, it will go into standby.
11
- Running a calculation resumes the tool from standby, the first run might take longer due to startup and model loading.
12
-
13
- ### **Input**
14
- - write the protein full amino acid sequence to be analysed in the **Sequence** text box
15
- jolly charachters (e.g. `-X.B`) can be inserted but, at the moment, visualisation cannot handle them
16
- - write the substitutions to test in the **Substitutions** box
17
- there are three running modes that can be used, depending on the input:
18
- + *single substitution* or list thereof (in the form of `R218K R218W`): the single substitution is scored
19
- + *residue position* or list thereof: all possible substitutions will be evaluated
20
- + *same-length sequence*: the differing amino acid substitutions will be evaluated, one by one
21
- + any other *different input*: a deep mutational scan of the full sequence will be performed
22
- - the ESM model to use for the calculations can be chosen among those that are available on Hugging Face Model Hub;
23
- `esm2_t33_650M_UR50D` offers the best expense-accuracy tradeoff[*](https://doi.org/10.1126/science.ade2574)
24
- - the `masked-marginals` scoring strategy considers sequence context at inference time, being slower but more accurate;
25
- in case of long runtimes, you can tick the box off to speed the calculations up significantly, sacrificing accuracy
26
- - when running a deep mutational scan, it is recommended to use smaller models (8M, 35M, 150M parameters), since the runtime is significant, especially for longer sequences and the server might be overloaded;
27
- over 30 min might be necessary for calculating a 300-residue-long sequence with larger models
28
- in general, accuracy is influenced significantly by the scoring strategy and less so by the model size, so it is suggested to reduce the latter first when optimising for runtime;
29
- the scoring strategy computational cost scales with the number of substitutions tested, while the model’s with the wild-type sequence length
30
- - it is possible to calculate the effect of multiple concurrent substitutions, but this has to be done manually, by changing the input sequence and running the calculation again
31
-
32
- ### **Output**
33
- Your results will be shown in a color-coded table, except for the deep mutational scan which will yield a heatmap.
34
- The output data can be downloaded from the box at the bottom.
35
- File extensions are not supported by the server and need to be appended to the filenames after downloading:
36
- - `CSV` for tables
37
- - `SVG` for full-sequence deep mutational scan
38
-
39
- </details>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # **ESM-Scan**
2
+
3
  Calculate the <u>fitness of single amino acid substitutions</u> on proteins, using a [zero-shot](https://doi.org/10.1101/2021.07.09.450648) [language model predictor](https://github.com/facebookresearch/esm)
4
 
5
+ <details>
6
+ <summary> <b> USAGE INSTRUCTIONS </b> </summary>
7
+
8
+ ## Setup
9
+
10
+ No setup is required. Simply fill in the input boxes with the necessary data and click the **Run** button.
11
+ You can find a list of examples at the bottom of the page; clicking on them will autofill the fields for you.
12
+ If the server remains idle for a period, it will enter standby mode. Running a calculation will wake the tool from standby, but note that the first run may take longer due to startup and model loading.
13
+
14
+ ## Input
15
+
16
+ **Sequence**: Enter the full amino acid sequence to be analyzed in the **Sequence** text box.
17
+ Note: While jolly characters (e.g., `-X.B`) can be included, they currently cannot be visualised.
18
+
19
+ **Substitutions**: Specify the substitutions you wish to test in the **Substitutions** box. The tool supports three running modes based on your input:
20
+
21
+ - **Single Substitution**: Input one or more substitutions (e.g. `R218K R218W`) to score specific changes.
22
+ - **Residue Position**: Provide residue positions to evaluate all possible substitutions at those sites.
23
+ - **Same-Length Sequence**: Analyze differing amino acid substitutions one by one within sequences of equal length.
24
+ - **Different Inputs**: For any other input format, a deep mutational scan of the full sequence will be performed.
25
+
26
+ **Model Selection**: Choose an ESM model for calculations from those available on Hugging Face Model Hub.
27
+ The model `esm2_t33_650M_UR50D` offers an optimal balance between cost and accuracy [*](https://doi.org/10.1126/science.ade2574).
28
+
29
+ **Accuracy Option**: The **Use higher accuracy** option applies a masked-marginals scoring strategy, which considers sequence context during inference.
30
+ While this method is slower, it enhances accuracy. If you experience long runtimes, unchecking this option can significantly speed up calculations at the cost of some accuracy.
31
+
32
+ **Deep Mutational Scan Recommendations**: When performing a deep mutational scan, it is advisable to use smaller models (8M, 35M, or 150M parameters) due to significant runtime concerns—especially with longer sequences or during peak server usage times.
33
+ For example, calculating a 300-residue-long sequence with larger models may require over 30 minutes.
34
+ Generally, accuracy is more affected by the scoring strategy than by model size; therefore, prioritise reducing model size when optimizing for runtime.
35
+ The computational cost of the scoring strategy scales with the number of substitutions tested, while model cost scales with wild-type sequence length.
36
+
37
+ **Concurrent Substitutions**: To calculate the effect of multiple concurrent substitutions, you must manually change the input sequence and rerun the calculation. Accuracy is not guaranteed as this use case is yet untested.
38
+
39
+ ## Output
40
+
41
+ Results are displayed in a color-coded table, except for deep mutational scans, which produce a heatmap.
42
+ In the table:
43
+
44
+ - Beneficial substitutions are highlighted in blue with positive values.
45
+ - Detrimental substitutions appear in red with negative values.
46
+
47
+ As a rule of thumb, score differences of *4* or more are considered significant. For instance:
48
+
49
+ - A substitution scoring *-6* is likely detrimental to protein functionality.
50
+ - A score of *+2* is generally regarded as neutral.
51
+
52
+ You can download the output raw data from the **button at the bottom of the page.
53
+
54
+ <b>
55
+ If you use this tool in your research, please cite:
56
+
57
+ - Totaro, M.G. (2023). “ESM-Scan - a tool to guide amino acid substitutions.” bioRxiv. [doi.org/10.1101/2023.12.12.571273](https://doi.org/10.1101/2023.12.12.571273)
58
+ - Meier, J. (2021). “Language Models Enable Zero-Shot Prediction of the Effects of Mutations on Protein Function.” bioRxiv (Cold Spring Harbor Laboratory), July. [doi.org/10.1101/2021.07.09.450648](https://doi.org/10.1101/2021.07.09.450648)
59
+ </b>
60
+
61
+ </details>
model.py CHANGED
@@ -1,5 +1,6 @@
1
- from huggingface_hub import HfApi, ModelFilter
2
  import torch
 
3
  from transformers import AutoTokenizer, AutoModelForMaskedLM
4
  from transformers.tokenization_utils_base import BatchEncoding
5
  from transformers.modeling_outputs import MaskedLMOutput
@@ -10,9 +11,9 @@ def get_models() -> list[None|str]:
10
  if not any(
11
  out := [
12
  m.modelId for m in HfApi().list_models(
13
- filter=ModelFilter(
14
- author="facebook", model_name="esm", task="fill-mask"
15
- ),
16
  sort="lastModified",
17
  direction=-1
18
  )
@@ -34,6 +35,9 @@ class Model:
34
  # Check if CUDA is available and if so, use it
35
  if torch.cuda.is_available():
36
  self.model = self.model.cuda()
 
 
 
37
 
38
  def tokenise(self, input: str) -> BatchEncoding:
39
  """Convert input string to batch of tokens."""
@@ -41,7 +45,7 @@ class Model:
41
 
42
  def __call__(self, batch_tokens: torch.Tensor, **kwargs) -> MaskedLMOutput:
43
  """Run model on batch of tokens."""
44
- return self.model(batch_tokens, **kwargs)
45
 
46
  def __getitem__(self, key: str) -> int:
47
  """Get token ID from character."""
@@ -70,7 +74,7 @@ class Model:
70
  if data.scoring_strategy.startswith("masked-marginals"):
71
  all_token_probs = []
72
  # For each token in the batch
73
- for i in range(batch_tokens.size()[1]):
74
  # If the token is in the list of residues
75
  if i in data.resi:
76
  # Clone the batch tokens and mask the current token
@@ -96,4 +100,4 @@ class Model:
96
  token_probs,
97
  ),
98
  axis=1,
99
- )
 
1
+ from huggingface_hub import HfApi
2
  import torch
3
+ from tqdm import tqdm
4
  from transformers import AutoTokenizer, AutoModelForMaskedLM
5
  from transformers.tokenization_utils_base import BatchEncoding
6
  from transformers.modeling_outputs import MaskedLMOutput
 
11
  if not any(
12
  out := [
13
  m.modelId for m in HfApi().list_models(
14
+ author="facebook",
15
+ model_name="esm",
16
+ task="fill-mask",
17
  sort="lastModified",
18
  direction=-1
19
  )
 
35
  # Check if CUDA is available and if so, use it
36
  if torch.cuda.is_available():
37
  self.model = self.model.cuda()
38
+ self.device = torch.device("cuda")
39
+ else:
40
+ self.device = torch.device("cpu")
41
 
42
  def tokenise(self, input: str) -> BatchEncoding:
43
  """Convert input string to batch of tokens."""
 
45
 
46
  def __call__(self, batch_tokens: torch.Tensor, **kwargs) -> MaskedLMOutput:
47
  """Run model on batch of tokens."""
48
+ return self.model(batch_tokens.to(self.device), **kwargs)
49
 
50
  def __getitem__(self, key: str) -> int:
51
  """Get token ID from character."""
 
74
  if data.scoring_strategy.startswith("masked-marginals"):
75
  all_token_probs = []
76
  # For each token in the batch
77
+ for i in tqdm(range(batch_tokens.size()[1])):
78
  # If the token is in the list of residues
79
  if i in data.resi:
80
  # Clone the batch tokens and mask the current token
 
100
  token_probs,
101
  ),
102
  axis=1,
103
+ )
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  gradio
2
  pandas
3
  seaborn
 
1
+ dataframe-image
2
  gradio
3
  pandas
4
  seaborn