Anton Bushuiev committed on
Commit
a82c17b
·
1 Parent(s): 5626a5b

Implement basic 3DMol.js

Browse files
Files changed (1) hide show
  1. app.py +137 -7
app.py CHANGED
@@ -1,17 +1,22 @@
1
- import shutil
2
  import tempfile
3
  from pathlib import Path
4
  from functools import partial
5
 
6
  import gradio as gr
7
- import numpy as np
8
  import torch
 
 
 
9
 
10
  from mutils.pdb import download_pdb
 
11
  from ppiref.extraction import PPIExtractor
12
  from ppiref.utils.ppi import PPIPath
 
13
  from ppiformer.tasks.node import DDGPPIformer
14
  from ppiformer.utils.api import predict_ddg
 
15
  from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
16
 
17
 
@@ -63,26 +68,149 @@ def process_inputs(inputs, temp_dir):
63
 
64
  muts = list(map(lambda x: x.strip(), muts.split(';')))
65
 
66
- return ppi_path, muts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
 
69
  def predict(models, temp_dir, *inputs):
70
  # Process input
71
- ppi_path, muts = process_inputs(inputs, temp_dir)
72
 
73
  print(ppi_path, muts)
74
 
75
  # Predict
76
  ddg, attn = predict_ddg(models, ppi_path, muts, return_attn=True)
77
 
 
78
  ddg = ddg.detach().numpy().tolist()
79
  df = list(zip(muts, ddg))
80
-
81
- return df
 
 
 
82
 
83
 
84
  app = gr.Blocks()
85
  with app:
 
 
 
86
 
87
  # Input GUI
88
  with gr.Row():
@@ -114,6 +242,7 @@ with app:
114
  datatype=["str", "number"],
115
  col_count=(2, "fixed"),
116
  )
 
117
 
118
  # Load models
119
  models = [
@@ -130,7 +259,8 @@ with app:
130
 
131
  # Main logic
132
  inputs = [pdb_code, pdb_path, partners, muts, muts_path]
 
133
  predict = partial(predict, models, temp_dir)
134
- predict_button.click(predict, inputs=inputs, outputs=df)
135
 
136
  app.launch()
 
1
+ import copy
2
  import tempfile
3
  from pathlib import Path
4
  from functools import partial
5
 
6
  import gradio as gr
 
7
  import torch
8
+ import py3Dmol
9
+ from biopandas.pdb import PandasPdb
10
+ from colour import Color
11
 
12
  from mutils.pdb import download_pdb
13
+ from mutils.mutations import Mutation
14
  from ppiref.extraction import PPIExtractor
15
  from ppiref.utils.ppi import PPIPath
16
+ from ppiref.utils.residue import Residue
17
  from ppiformer.tasks.node import DDGPPIformer
18
  from ppiformer.utils.api import predict_ddg
19
+ from ppiformer.utils.torch import fill_diagonal
20
  from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
21
 
22
 
 
68
 
69
  muts = list(map(lambda x: x.strip(), muts.split(';')))
70
 
71
+ return pdb_path, ppi_path, muts
72
+
73
+
74
def plot_3dmol(pdb_path, ppi_path, muts, attn):
    """Build an embeddable 3Dmol.js view of a structure coloured by attention.

    The full structure at ``pdb_path`` is loaded into the viewer; residues of
    the interface at ``ppi_path`` are drawn as sticks whose colour saturation,
    radius and opacity scale with the attention they receive from the mutated
    residues. Mutated residues are drawn as red sticks and the camera zooms
    onto them.

    Args:
        pdb_path: Path to the PDB file shown in the viewer.
        ppi_path: Path to the PDB file of the interface; read with
            ``PandasPdb`` and reduced to one ``CA`` row per residue.
        muts: Iterable of mutation strings accepted by ``Mutation``.
        attn: Attention tensor indexed below as ``attn[:, 0, :, 0, :, :, :]``
            (the remaining axes are treated as models, layers, heads, tokens,
            tokens). NOTE(review): assumes the token order matches the row
            order of the grouped interface dataframe -- confirm against
            ``predict_ddg``.

    Returns:
        str: ``<iframe>`` markup whose ``srcdoc`` holds the complete viewer
        page.
    """
    # Function-scope imports keep this block self-contained; both are stdlib.
    import html
    import json

    # 3DMol.js adapted from https://huggingface.co/spaces/huhlim/cg2all/blob/main/app.py

    # Read PDB for 3Dmol.js
    with open(pdb_path, "r") as fp:
        mol = fp.read()
    # Rename terminal oxygens; replacements are kept three characters wide so
    # the fixed-width PDB columns stay aligned.
    mol = mol.replace("OT1", "O  ")
    mol = mol.replace("OT2", "OXT")

    # Read PPI to customize 3Dmol.js visualization (one CA row per residue)
    ppi_df = PandasPdb().read_pdb(ppi_path).df['ATOM']
    ppi_df = ppi_df.groupby(list(Residue._fields)).apply(
        lambda df: df[df['atom_name'] == 'CA'].iloc[0]
    ).reset_index(drop=True)
    chains = ppi_df['chain_id'].unique()
    ppi_df['id'] = ppi_df.apply(
        lambda row: ':'.join([
            row['residue_name'],
            row['chain_id'],
            str(row['residue_number']),
            row['insertion'],
        ]),
        axis=1,
    )
    # Drop the trailing separator left behind by an empty insertion code
    ppi_df['id'] = ppi_df['id'].apply(lambda x: x[:-1] if x[-1] == ':' else x)
    # Flatten ids of all single-point mutations
    muts_id = [
        res_id
        for mut in muts
        for res_id in Mutation(mut).wt_to_graphein()
    ]
    ppi_df['mutated'] = ppi_df.apply(lambda row: row['id'] in muts_id, axis=1)

    # Prepare attention coefficients per residue
    # (normalized sum of direct attention from mutated residues)
    attn = torch.nan_to_num(attn, nan=1e-10)
    attn_sub = attn[:, 0, :, 0, :, :, :]  # models, layers, heads, tokens, tokens
    idx_mutated = torch.from_numpy(ppi_df.index[ppi_df['mutated']].to_numpy())
    attn_sub = fill_diagonal(attn_sub, 1e-10)
    attn_mutated = attn_sub[..., idx_mutated, :]
    attns_per_token = torch.sum(attn_mutated, dim=(0, 1, 2, 3))
    lowest = attns_per_token.min()
    spread = attns_per_token.max() - lowest
    if spread > 0:
        attns_per_token = (attns_per_token - lowest) / spread
    else:
        # All residues received the same attention (e.g. no mutated residue
        # matched): avoid 0/0 -> NaN, which is not a valid saturation/opacity.
        attns_per_token = torch.zeros_like(attns_per_token)
    attns_per_token = attns_per_token + 1e-10
    ppi_df['attn'] = attns_per_token.detach().cpu().numpy()

    # Customize 3Dmol.js visualization https://3dmol.csb.pitt.edu/doc/
    styles = []
    zoom_atoms = []

    # Cartoon chains; the palette is cycled so more than three chains still
    # get a colour instead of raising KeyError.
    palette = [Color(c) for c in ['LimeGreen', 'HotPink', 'RoyalBlue']]
    chain_to_color = {
        chain: palette[i % len(palette)] for i, chain in enumerate(chains)
    }
    for chain in chains:
        styles.append([
            {"chain": str(chain)},
            {"cartoon": {"color": chain_to_color[chain].hex_l, "opacity": 0.6}},
        ])

    # Stick PPI and atoms for zoom
    # TODO Insertions
    for _, row in ppi_df.iterrows():
        selection = {
            'chain': str(row['chain_id']),
            'resi': str(row['residue_number']),
        }
        if row['mutated']:
            styles.append([
                selection,
                {'stick': {'color': 'red', 'radius': 0.2, 'opacity': 1.0}},
            ])
            zoom_atoms.append(int(row['atom_number']))
        else:
            # Plain Python float so JSON serialisation does not depend on
            # how NumPy scalars are printed.
            weight = float(row['attn'])
            color = copy.deepcopy(chain_to_color[row['chain_id']])
            color.saturation = weight
            styles.append([
                selection,
                {'stick': {
                    'color': color.hex_l,
                    'radius': weight / 5,
                    'opacity': weight,
                }},
            ])

    # Convert style dicts to JS lines (JSON is valid JS object syntax)
    styles_js = '\n'.join(
        'viewer.addStyle(' + ', '.join(json.dumps(part) for part in spec) + ');'
        for spec in styles
    )

    # Convert zoom atoms to 3DMol.js selection; with nothing to focus on,
    # zoom to the whole model instead of an empty selection.
    if zoom_atoms:
        zoom_selection = json.dumps(
            {"or": [{"serial": serial} for serial in zoom_atoms]}
        )
        zoom = 'viewer.zoomTo(' + zoom_selection + ', 1000);'
    else:
        zoom = 'viewer.zoomTo();'

    # Embed the PDB text as a JSON string literal so quotes, backticks,
    # backslashes and newlines cannot break the script; "</" is split so the
    # text can never close the surrounding <script> element.
    pdb_literal = json.dumps(mol).replace("</", "<\\/")

    # Construct 3Dmol.js visualization script in HTML
    page = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
    font-family:sans-serif
}
.mol-container {
    width: 100%;
    height: 600px;
    position: relative;
}
.mol-container select{
    background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>

<script>
let pdb = """
        + pdb_literal
        + """;

$(document).ready(function () {
    let element = $("#container");
    let config = { backgroundColor: "white" };
    let viewer = $3Dmol.createViewer(element, config);
    viewer.addModel(pdb, "pdb");
    viewer.setBackgroundColor("black");
    viewer.setStyle({"model": 0}, {"ray_opaque_background": "off"}, {"stick": {"color": "lightgrey", "opacity": 0.5}});
"""
        + styles_js
        + "\n"
        + zoom
        + """
    viewer.render();
})
</script>
</body></html>"""
    )

    # The page is placed inside a single-quoted attribute, so it must be
    # HTML-escaped (including quotes); the browser decodes it back for srcdoc.
    srcdoc = html.escape(page, quote=True)

    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
188
 
189
 
190
def predict(models, temp_dir, *inputs):
    """Handle one prediction request from the GUI.

    Args:
        models: Models forwarded to ``predict_ddg``.
        temp_dir: Working directory forwarded to ``process_inputs``.
        *inputs: Raw widget values, forwarded unchanged to ``process_inputs``.

    Returns:
        tuple: ``(rows, plot)`` where ``rows`` is a list of
        ``(mutation, ddg)`` pairs and ``plot`` is the value returned by
        ``plot_3dmol``.
    """
    # Resolve the widget values into file paths and a list of mutations
    pdb_path, ppi_path, muts = process_inputs(inputs, temp_dir)
    print(ppi_path, muts)

    # Run the models, keeping the attention maps for the 3D view
    ddg, attn = predict_ddg(models, ppi_path, muts, return_attn=True)

    # Pair every mutation with its predicted value for the results table
    ddg_values = ddg.detach().numpy().tolist()
    rows = [(mut, value) for mut, value in zip(muts, ddg_values)]

    # Render the attention-coloured structure
    view = plot_3dmol(pdb_path, ppi_path, muts, attn)

    return rows, view
207
 
208
 
209
  app = gr.Blocks()
210
  with app:
211
+ # print('app.theme.background_fill_primary', app.theme.background_fill_primary, type(app.theme.background_fill_primary))
212
+ # print('app.theme.background_fill_primary', app.theme.background_fill_primary_dark, type(app.theme.background_fill_primary))
213
+ # print(app.theme.to_dict())
214
 
215
  # Input GUI
216
  with gr.Row():
 
242
  datatype=["str", "number"],
243
  col_count=(2, "fixed"),
244
  )
245
+ plot = gr.HTML()
246
 
247
  # Load models
248
  models = [
 
259
 
260
  # Main logic
261
  inputs = [pdb_code, pdb_path, partners, muts, muts_path]
262
+ outputs = [df, plot]
263
  predict = partial(predict, models, temp_dir)
264
+ predict_button.click(predict, inputs=inputs, outputs=outputs)
265
 
266
  app.launch()