Spaces:
Running
on
Zero
Running
on
Zero
Anton Bushuiev
commited on
Commit
·
a82c17b
1
Parent(s):
5626a5b
Implement basic 3DMol.js
Browse files
app.py
CHANGED
@@ -1,17 +1,22 @@
|
|
1 |
-
import
|
2 |
import tempfile
|
3 |
from pathlib import Path
|
4 |
from functools import partial
|
5 |
|
6 |
import gradio as gr
|
7 |
-
import numpy as np
|
8 |
import torch
|
|
|
|
|
|
|
9 |
|
10 |
from mutils.pdb import download_pdb
|
|
|
11 |
from ppiref.extraction import PPIExtractor
|
12 |
from ppiref.utils.ppi import PPIPath
|
|
|
13 |
from ppiformer.tasks.node import DDGPPIformer
|
14 |
from ppiformer.utils.api import predict_ddg
|
|
|
15 |
from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
|
16 |
|
17 |
|
@@ -63,26 +68,149 @@ def process_inputs(inputs, temp_dir):
|
|
63 |
|
64 |
muts = list(map(lambda x: x.strip(), muts.split(';')))
|
65 |
|
66 |
-
return ppi_path, muts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
|
69 |
def predict(models, temp_dir, *inputs):
|
70 |
# Process input
|
71 |
-
ppi_path, muts = process_inputs(inputs, temp_dir)
|
72 |
|
73 |
print(ppi_path, muts)
|
74 |
|
75 |
# Predict
|
76 |
ddg, attn = predict_ddg(models, ppi_path, muts, return_attn=True)
|
77 |
|
|
|
78 |
ddg = ddg.detach().numpy().tolist()
|
79 |
df = list(zip(muts, ddg))
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
82 |
|
83 |
|
84 |
app = gr.Blocks()
|
85 |
with app:
|
|
|
|
|
|
|
86 |
|
87 |
# Input GUI
|
88 |
with gr.Row():
|
@@ -114,6 +242,7 @@ with app:
|
|
114 |
datatype=["str", "number"],
|
115 |
col_count=(2, "fixed"),
|
116 |
)
|
|
|
117 |
|
118 |
# Load models
|
119 |
models = [
|
@@ -130,7 +259,8 @@ with app:
|
|
130 |
|
131 |
# Main logic
|
132 |
inputs = [pdb_code, pdb_path, partners, muts, muts_path]
|
|
|
133 |
predict = partial(predict, models, temp_dir)
|
134 |
-
predict_button.click(predict, inputs=inputs, outputs=
|
135 |
|
136 |
app.launch()
|
|
|
1 |
+
import copy
|
2 |
import tempfile
|
3 |
from pathlib import Path
|
4 |
from functools import partial
|
5 |
|
6 |
import gradio as gr
|
|
|
7 |
import torch
|
8 |
+
import py3Dmol
|
9 |
+
from biopandas.pdb import PandasPdb
|
10 |
+
from colour import Color
|
11 |
|
12 |
from mutils.pdb import download_pdb
|
13 |
+
from mutils.mutations import Mutation
|
14 |
from ppiref.extraction import PPIExtractor
|
15 |
from ppiref.utils.ppi import PPIPath
|
16 |
+
from ppiref.utils.residue import Residue
|
17 |
from ppiformer.tasks.node import DDGPPIformer
|
18 |
from ppiformer.utils.api import predict_ddg
|
19 |
+
from ppiformer.utils.torch import fill_diagonal
|
20 |
from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
|
21 |
|
22 |
|
|
|
68 |
|
69 |
muts = list(map(lambda x: x.strip(), muts.split(';')))
|
70 |
|
71 |
+
return pdb_path, ppi_path, muts
|
72 |
+
|
73 |
+
|
74 |
+
def plot_3dmol(pdb_path, ppi_path, muts, attn):
|
75 |
+
# 3DMol.js adapted from https://huggingface.co/spaces/huhlim/cg2all/blob/main/app.py
|
76 |
+
|
77 |
+
# Read PDB for 3Dmol.js
|
78 |
+
with open(pdb_path, "r") as fp:
|
79 |
+
lines = fp.readlines()
|
80 |
+
mol = ""
|
81 |
+
for l in lines:
|
82 |
+
mol += l
|
83 |
+
mol = mol.replace("OT1", "O ")
|
84 |
+
mol = mol.replace("OT2", "OXT")
|
85 |
+
|
86 |
+
# Read PPI to customize 3Dmol.js visualization
|
87 |
+
ppi_df = PandasPdb().read_pdb(ppi_path).df['ATOM']
|
88 |
+
ppi_df = ppi_df.groupby(list(Residue._fields)).apply(lambda df: df[df['atom_name'] == 'CA'].iloc[0]).reset_index(drop=True)
|
89 |
+
chains = ppi_df['chain_id'].unique()
|
90 |
+
ppi_df['id'] = ppi_df.apply(lambda row: ':'.join([row['residue_name'], row['chain_id'], str(row['residue_number']), row['insertion']]), axis=1)
|
91 |
+
ppi_df['id'] = ppi_df['id'].apply(lambda x: x[:-1] if x[-1] == ':' else x)
|
92 |
+
muts_id = sum([Mutation(mut).wt_to_graphein() for mut in muts], start=[]) # flatten ids of all sp muts
|
93 |
+
ppi_df['mutated'] = ppi_df.apply(lambda row: row['id'] in muts_id, axis=1)
|
94 |
+
|
95 |
+
# Prepare attention coeffictients per residue (normalized sum of direct attention from mutated residues)
|
96 |
+
attn = torch.nan_to_num(attn, nan=1e-10)
|
97 |
+
attn_sub = attn[:, 0, :, 0, :, :, :] # models, layers, heads, tokens, tokens
|
98 |
+
idx_mutated = torch.from_numpy(ppi_df.index[ppi_df['mutated']].to_numpy())
|
99 |
+
attn_sub = fill_diagonal(attn_sub, 1e-10)
|
100 |
+
attn_mutated = attn_sub[..., idx_mutated, :]
|
101 |
+
attn_mutated.shape
|
102 |
+
attns_per_token = torch.sum(attn_mutated, dim=(0, 1, 2, 3))
|
103 |
+
attns_per_token = (attns_per_token - attns_per_token.min()) / (attns_per_token.max() - attns_per_token.min())
|
104 |
+
attns_per_token += 1e-10
|
105 |
+
ppi_df['attn'] = attns_per_token.numpy()
|
106 |
+
|
107 |
+
# Customize 3Dmol.js visualization https://3dmol.csb.pitt.edu/doc/
|
108 |
+
styles = []
|
109 |
+
zoom_atoms = []
|
110 |
+
|
111 |
+
# Cartoon chains
|
112 |
+
colors = [Color(c) for c in ['LimeGreen', 'HotPink', 'RoyalBlue']]
|
113 |
+
chain_to_color = dict(zip(chains, colors))
|
114 |
+
for chain in chains:
|
115 |
+
styles.append([{"chain": chain}, {"cartoon": {"color": chain_to_color[chain].hex_l, "opacity": 0.6}}])
|
116 |
+
|
117 |
+
# Stick PPI and atoms for zoom
|
118 |
+
# TODO Insertions
|
119 |
+
for _, row in ppi_df.iterrows():
|
120 |
+
color = copy.deepcopy(chain_to_color[row['chain_id']])
|
121 |
+
color.saturation = row['attn']
|
122 |
+
color = color.hex_l
|
123 |
+
if row['mutated']:
|
124 |
+
styles.append([{'chain': row['chain_id'], 'resi': str(row['residue_number'])}, {'stick': {'color': 'red', 'radius': 0.2, 'opacity': 1.0}}])
|
125 |
+
zoom_atoms.append(row['atom_number'])
|
126 |
+
else:
|
127 |
+
styles.append([{'chain': row['chain_id'], 'resi': str(row['residue_number'])}, {'stick': {'color': color, 'radius': row['attn'] / 5, 'opacity': row['attn']}}])
|
128 |
+
|
129 |
+
# Convert style dicts to JS lines
|
130 |
+
styles = '\n'.join(['viewer.addStyle(' + ', '.join([str(s).replace("'", '"') for s in dcts]) + ');' for dcts in styles])
|
131 |
+
|
132 |
+
# Connert zoom atoms to 3DMol.js selection
|
133 |
+
zoom = 'viewer.zoomTo({\"or\": [' + ', '.join(["{\"serial\": " + str(a) + "}" for a in zoom_atoms]) + ']}, 1000);'
|
134 |
+
|
135 |
+
# Construct 3Dmol.js visualization script in HTML
|
136 |
+
html = (
|
137 |
+
"""<!DOCTYPE html>
|
138 |
+
<html>
|
139 |
+
<head>
|
140 |
+
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
|
141 |
+
<style>
|
142 |
+
body{
|
143 |
+
font-family:sans-serif
|
144 |
+
}
|
145 |
+
.mol-container {
|
146 |
+
width: 100%;
|
147 |
+
height: 600px;
|
148 |
+
position: relative;
|
149 |
+
}
|
150 |
+
.mol-container select{
|
151 |
+
background-image:None;
|
152 |
+
}
|
153 |
+
</style>
|
154 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
|
155 |
+
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
|
156 |
+
</head>
|
157 |
+
<body>
|
158 |
+
<div id="container" class="mol-container"></div>
|
159 |
+
|
160 |
+
<script>
|
161 |
+
let pdb = `"""
|
162 |
+
+ mol
|
163 |
+
+ """`
|
164 |
+
|
165 |
+
$(document).ready(function () {
|
166 |
+
let element = $("#container");
|
167 |
+
let config = { backgroundColor: "white" };
|
168 |
+
let viewer = $3Dmol.createViewer(element, config);
|
169 |
+
viewer.addModel(pdb, "pdb");
|
170 |
+
viewer.setBackgroundColor("black");
|
171 |
+
viewer.setStyle({"model": 0}, {"ray_opaque_background": "off"}, {"stick": {"color": "lightgrey", "opacity": 0.5}});
|
172 |
+
"""
|
173 |
+
+ styles
|
174 |
+
+ zoom
|
175 |
+
+ """
|
176 |
+
viewer.render();
|
177 |
+
})
|
178 |
+
</script>
|
179 |
+
</body></html>"""
|
180 |
+
)
|
181 |
+
print(html)
|
182 |
+
|
183 |
+
return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
|
184 |
+
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
|
185 |
+
allow-scripts allow-same-origin allow-popups
|
186 |
+
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
|
187 |
+
allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
|
188 |
|
189 |
|
190 |
def predict(models, temp_dir, *inputs):
|
191 |
# Process input
|
192 |
+
pdb_path, ppi_path, muts = process_inputs(inputs, temp_dir)
|
193 |
|
194 |
print(ppi_path, muts)
|
195 |
|
196 |
# Predict
|
197 |
ddg, attn = predict_ddg(models, ppi_path, muts, return_attn=True)
|
198 |
|
199 |
+
# Create dataframe
|
200 |
ddg = ddg.detach().numpy().tolist()
|
201 |
df = list(zip(muts, ddg))
|
202 |
+
|
203 |
+
# Create 3DMol plot
|
204 |
+
plot = plot_3dmol(pdb_path, ppi_path, muts, attn)
|
205 |
+
|
206 |
+
return df, plot
|
207 |
|
208 |
|
209 |
app = gr.Blocks()
|
210 |
with app:
|
211 |
+
# print('app.theme.background_fill_primary', app.theme.background_fill_primary, type(app.theme.background_fill_primary))
|
212 |
+
# print('app.theme.background_fill_primary', app.theme.background_fill_primary_dark, type(app.theme.background_fill_primary))
|
213 |
+
# print(app.theme.to_dict())
|
214 |
|
215 |
# Input GUI
|
216 |
with gr.Row():
|
|
|
242 |
datatype=["str", "number"],
|
243 |
col_count=(2, "fixed"),
|
244 |
)
|
245 |
+
plot = gr.HTML()
|
246 |
|
247 |
# Load models
|
248 |
models = [
|
|
|
259 |
|
260 |
# Main logic
|
261 |
inputs = [pdb_code, pdb_path, partners, muts, muts_path]
|
262 |
+
outputs = [df, plot]
|
263 |
predict = partial(predict, models, temp_dir)
|
264 |
+
predict_button.click(predict, inputs=inputs, outputs=outputs)
|
265 |
|
266 |
app.launch()
|