Spaces:
Sleeping
Sleeping
Pocket-conditioned generation
Browse files- app.py +331 -61
- examples/3hz1_fragments.sdf +54 -0
- examples/3hz1_protein.pdb +0 -0
- examples/5ou2_fragments.sdf +56 -0
- examples/5ou2_protein.pdb +0 -0
- output.py +184 -9
- src/datasets.py +14 -4
- src/generation.py +83 -4
- src/lightning.py +7 -4
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import argparse
|
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import numpy as np
|
@@ -9,10 +10,12 @@ import output
|
|
9 |
|
10 |
from rdkit import Chem
|
11 |
from src import const
|
12 |
-
from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule
|
13 |
from src.lightning import DDPM
|
14 |
from src.linker_size_lightning import SizeClassifier
|
15 |
-
from src.generation import N_SAMPLES, generate_linkers, try_to_convert_to_sdf
|
|
|
|
|
16 |
|
17 |
MODELS_METADATA = {
|
18 |
'geom_difflinker': {
|
@@ -85,65 +88,167 @@ def read_molecule(path):
|
|
85 |
raise Exception('Unknown file extension')
|
86 |
|
87 |
|
88 |
-
def
|
89 |
-
if
|
90 |
-
|
91 |
-
if isinstance(input_file, str):
|
92 |
-
path = input_file
|
93 |
else:
|
94 |
-
path =
|
95 |
extension = path.split('.')[-1]
|
96 |
-
|
|
|
97 |
msg = output.INVALID_FORMAT_MSG.format(extension=extension)
|
98 |
-
return
|
99 |
-
output.IFRAME_TEMPLATE.format(html=msg),
|
100 |
-
gr.Radio.update(visible=False),
|
101 |
-
None,
|
102 |
-
]
|
103 |
|
104 |
try:
|
105 |
-
|
106 |
except Exception as e:
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
|
121 |
def draw_sample(idx, out_files):
|
|
|
|
|
122 |
if isinstance(idx, str):
|
123 |
idx = int(idx.strip().split(' ')[-1]) - 1
|
124 |
|
125 |
-
in_file = out_files[
|
126 |
in_sdf = in_file if isinstance(in_file, str) else in_file.name
|
|
|
|
|
127 |
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
-
|
|
|
132 |
generated_molecule_content = read_molecule_content(out_sdf)
|
133 |
-
|
134 |
-
fragments_fmt = in_sdf.split('.')[-1]
|
135 |
molecule_fmt = out_sdf.split('.')[-1]
|
136 |
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
return output.IFRAME_TEMPLATE.format(html=html)
|
144 |
|
145 |
|
146 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
# Parsing selected atoms (javascript output)
|
148 |
selected_atoms = selected_atoms.strip()
|
149 |
if selected_atoms == '':
|
@@ -157,9 +262,6 @@ def generate(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
|
|
157 |
else:
|
158 |
selected_model_name = 'geom_difflinker_given_anchors'
|
159 |
|
160 |
-
if input_file is None:
|
161 |
-
return [None, None, None, None]
|
162 |
-
|
163 |
print(f'Start generating with model {selected_model_name}, selected_atoms:', selected_atoms)
|
164 |
ddpm = diffusion_models[selected_model_name]
|
165 |
path = input_file.name
|
@@ -170,20 +272,25 @@ def generate(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
|
|
170 |
|
171 |
try:
|
172 |
molecule = read_molecule(path)
|
173 |
-
|
|
|
|
|
|
|
174 |
name = '.'.join(path.split('/')[-1].split('.')[:-1])
|
175 |
inp_sdf = f'results/input_{name}.sdf'
|
176 |
except Exception as e:
|
|
|
177 |
error = f'Could not read the molecule: {e}'
|
178 |
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
179 |
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
180 |
|
181 |
-
if molecule.GetNumAtoms() >
|
182 |
-
error = f'Too large molecule: upper limit is
|
183 |
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
184 |
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
185 |
|
186 |
with Chem.SDWriter(inp_sdf) as w:
|
|
|
187 |
w.write(molecule)
|
188 |
|
189 |
positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
|
@@ -227,14 +334,152 @@ def generate(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
|
|
227 |
|
228 |
for data in dataloader:
|
229 |
try:
|
230 |
-
generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name)
|
231 |
except Exception as e:
|
|
|
232 |
error = f'Caught exception while generating linkers: {e}'
|
233 |
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
234 |
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
235 |
|
236 |
out_files = try_to_convert_to_sdf(name)
|
237 |
out_files = [inp_sdf] + out_files
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
|
239 |
return [
|
240 |
draw_sample(radio_samples, out_files),
|
@@ -260,19 +505,34 @@ with demo:
|
|
260 |
with gr.Box():
|
261 |
with gr.Row():
|
262 |
with gr.Column():
|
263 |
-
gr.Markdown('## Input
|
264 |
gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
|
265 |
-
|
266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
n_atoms = gr.Slider(
|
268 |
minimum=0, maximum=20,
|
269 |
label="Linker Size: DiffLinker will predict it if set to 0",
|
270 |
step=1
|
271 |
)
|
272 |
examples = gr.Dataset(
|
273 |
-
components=[gr.File(visible=False)],
|
274 |
-
samples=[
|
275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
)
|
277 |
|
278 |
button = gr.Button('Generate Linker!')
|
@@ -294,24 +554,34 @@ with demo:
|
|
294 |
)
|
295 |
visualization = gr.HTML()
|
296 |
|
297 |
-
|
298 |
fn=show_input,
|
299 |
-
inputs=[
|
300 |
outputs=[visualization, samples, hidden],
|
301 |
)
|
302 |
-
|
303 |
-
fn=
|
304 |
-
inputs=[],
|
305 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
)
|
307 |
examples.click(
|
308 |
-
fn=
|
309 |
inputs=[examples],
|
310 |
-
outputs=[
|
311 |
)
|
312 |
button.click(
|
313 |
fn=generate,
|
314 |
-
inputs=[
|
315 |
outputs=[visualization, output_files, samples, hidden],
|
316 |
_js=output.RETURN_SELECTION_JS,
|
317 |
)
|
|
|
1 |
import argparse
|
2 |
+
import shutil
|
3 |
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
|
|
10 |
|
11 |
from rdkit import Chem
|
12 |
from src import const
|
13 |
+
from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule, MOADDataset
|
14 |
from src.lightning import DDPM
|
15 |
from src.linker_size_lightning import SizeClassifier
|
16 |
+
from src.generation import N_SAMPLES, generate_linkers, try_to_convert_to_sdf, get_pocket
|
17 |
+
from zipfile import ZipFile
|
18 |
+
|
19 |
|
20 |
MODELS_METADATA = {
|
21 |
'geom_difflinker': {
|
|
|
88 |
raise Exception('Unknown file extension')
|
89 |
|
90 |
|
91 |
+
def read_molecule_file(in_file, allowed_extentions):
    """Read a molecule file and serialise it to a text block for rendering.

    Args:
        in_file: Either a path string or an object exposing the path as
            ``.name`` (e.g. an uploaded-file wrapper).
        allowed_extentions: Collection of accepted file extensions, without
            the leading dot.

    Returns:
        A ``(content, extension, message)`` triple. On success ``content`` is
        the PDB or MOL block, ``extension`` is ``'pdb'`` or ``'mol'`` and
        ``message`` is ``None``. On failure the first two items are ``None``
        and ``message`` holds the formatted error page.

    Raises:
        NotImplementedError: If the extension is allowed by the caller but is
            not one of ``pdb``, ``mol``, ``mol2`` or ``sdf``.
    """
    path = in_file if isinstance(in_file, str) else in_file.name
    extension = path.split('.')[-1]

    # Reject unsupported formats before touching the file.
    if extension not in allowed_extentions:
        return None, None, output.INVALID_FORMAT_MSG.format(extension=extension)

    try:
        mol = read_molecule(path)
    except Exception as exc:
        # Single quotes are removed from the error text before formatting --
        # presumably so it can be embedded safely in the page; TODO confirm.
        cleaned = str(exc).replace('\'', '')
        return None, None, output.ERROR_FORMAT_MSG.format(message=cleaned)

    if extension == 'pdb':
        return Chem.MolToPDBBlock(mol), extension, None
    if extension in ['mol', 'mol2', 'sdf']:
        # All of these formats are re-emitted as a MOL block, hence the
        # reported format becomes 'mol'.
        return Chem.MolToMolBlock(mol, kekulize=False), 'mol', None
    raise NotImplementedError
|
118 |
+
|
119 |
+
|
120 |
+
def show_input(in_fragments, in_protein):
    """Build the preview for whichever input files are currently provided.

    Args:
        in_fragments: Fragments file, or ``None`` when not provided.
        in_protein: Target protein file, or ``None`` when not provided.

    Returns:
        A three-item list: the visualisation HTML (empty string when neither
        input is set), an update hiding the samples radio, and ``None``.
    """
    has_fragments = in_fragments is not None
    has_protein = in_protein is not None

    if has_fragments and has_protein:
        vis = show_fragments_and_target(in_fragments, in_protein)
    elif has_fragments:
        vis = show_fragments(in_fragments)
    elif has_protein:
        vis = show_target(in_protein)
    else:
        vis = ''

    return [vis, gr.Radio.update(visible=False), None]
|
129 |
+
|
130 |
+
|
131 |
+
def show_fragments(in_fragments):
    """Return the iframe HTML previewing the input fragments.

    If the file cannot be read, the iframe wraps the error message returned
    by ``read_molecule_file`` instead of the rendering template.
    """
    content, fmt, page = read_molecule_file(
        in_fragments,
        allowed_extentions=['sdf', 'pdb', 'mol', 'mol2'],
    )
    if content is None:
        return output.IFRAME_TEMPLATE.format(html=page)

    rendered = output.FRAGMENTS_RENDERING_TEMPLATE.format(molecule=content, fmt=fmt)
    return output.IFRAME_TEMPLATE.format(html=rendered)
|
137 |
+
|
138 |
+
|
139 |
+
def show_target(in_protein):
    """Return the iframe HTML previewing the target protein (``.pdb`` only).

    If the file cannot be read, the iframe wraps the error message returned
    by ``read_molecule_file`` instead of the rendering template.
    """
    content, fmt, page = read_molecule_file(in_protein, allowed_extentions=['pdb'])
    if content is None:
        return output.IFRAME_TEMPLATE.format(html=page)

    rendered = output.TARGET_RENDERING_TEMPLATE.format(molecule=content, fmt=fmt)
    return output.IFRAME_TEMPLATE.format(html=rendered)
|
145 |
+
|
146 |
+
|
147 |
+
def show_fragments_and_target(in_fragments, in_protein):
    """Return the iframe HTML previewing the fragments together with the protein.

    Args:
        in_fragments: Fragments file (path string or object with ``.name``).
        in_protein: Target protein file (path string or object with ``.name``).

    Returns:
        The iframe HTML string. If either file fails to load, the iframe wraps
        the corresponding error message instead of the combined rendering.
    """
    fragments_molecule, fragments_extension, msg = read_molecule_file(
        in_fragments,
        allowed_extentions=['sdf', 'pdb', 'mol', 'mol2'],
    )
    if fragments_molecule is None:
        return output.IFRAME_TEMPLATE.format(html=msg)

    target_molecule, target_extension, msg = read_molecule_file(in_protein, allowed_extentions=['pdb'])
    # Check the *target* here: previously the fragments result was re-tested,
    # so a broken protein file slipped through and None was rendered.
    if target_molecule is None:
        return output.IFRAME_TEMPLATE.format(html=msg)

    html = output.FRAGMENTS_AND_TARGET_RENDERING_TEMPLATE.format(
        molecule=fragments_molecule,
        fmt=fragments_extension,
        target=target_molecule,
        target_fmt=target_extension,
    )

    return output.IFRAME_TEMPLATE.format(html=html)
|
164 |
+
|
165 |
+
|
166 |
+
def clear_fragments_input(in_protein):
    """Handle clearing of the fragments input.

    Returns a four-item list: ``None`` for the fragments file, the remaining
    visualisation (the protein preview if one is set, otherwise an empty
    string), an update hiding the samples radio, and ``None``.
    """
    remaining_view = '' if in_protein is None else show_target(in_protein)
    return [None, remaining_view, gr.Radio.update(visible=False), None]
|
171 |
+
|
172 |
+
|
173 |
+
def clear_protein_input(in_fragments):
    """Handle clearing of the protein input.

    Returns a four-item list: ``None`` for the protein file, the remaining
    visualisation (the fragments preview if one is set, otherwise an empty
    string), an update hiding the samples radio, and ``None``.
    """
    remaining_view = '' if in_fragments is None else show_fragments(in_fragments)
    return [None, remaining_view, gr.Radio.update(visible=False), None]
|
178 |
+
|
179 |
+
|
180 |
+
def click_on_example(example):
    """Load a bundled example into the inputs and preview it.

    Args:
        example: A two-item sequence ``(fragment_fname, target_fname)``; an
            empty string means that file is absent from the example.

    Returns:
        A list of the fragments path, the target path, the constants ``50``
        and ``0``, followed by the three items produced by ``show_input``.
    """
    print('Clicked:', example)
    fragment_fname, target_fname = example

    # Example files are resolved relative to the examples/ directory.
    fragment_path = None if fragment_fname == '' else f'examples/{fragment_fname}'
    target_path = None if target_fname == '' else f'examples/{target_fname}'

    preview = show_input(fragment_path, target_path)
    return [fragment_path, target_path, 50, 0] + preview
|
186 |
|
187 |
|
188 |
def draw_sample(idx, out_files):
    """Render one generated sample next to the inputs it was generated from.

    Args:
        idx: Zero-based sample index, or a label string whose last
            space-separated token is the one-based sample number.
        out_files: Output file list. Index 1 is read as the input fragments
            file; when the list holds ``N_SAMPLES + 3`` entries, index 2 is
            read as the target protein and the samples start at index 3,
            otherwise the samples start at index 2. Entries may be path
            strings or objects exposing the path as ``.name``.

    Returns:
        The iframe HTML string for the selected sample.
    """
    def _as_path(entry):
        # Entries are either plain paths or file wrappers with a .name path.
        return entry if isinstance(entry, str) else entry.name

    with_protein = len(out_files) == N_SAMPLES + 3

    if isinstance(idx, str):
        # The label ends with a one-based number; convert to a zero-based index.
        idx = int(idx.strip().split(' ')[-1]) - 1

    fragments_path = _as_path(out_files[1])
    fields = {
        'fragments': read_molecule_content(fragments_path),
        'fragments_fmt': fragments_path.split('.')[-1],
    }

    first_sample = 2
    if with_protein:
        first_sample = 3
        target_path = _as_path(out_files[2])
        fields['target'] = read_molecule_content(target_path)
        fields['target_fmt'] = target_path.split('.')[-1]

    sample_path = _as_path(out_files[idx + first_sample])
    fields['molecule'] = read_molecule_content(sample_path)
    fields['molecule_fmt'] = sample_path.split('.')[-1]

    if with_protein:
        html = output.SAMPLES_WITH_TARGET_RENDERING_TEMPLATE.format(**fields)
    else:
        html = output.SAMPLES_RENDERING_TEMPLATE.format(**fields)
    return output.IFRAME_TEMPLATE.format(html=html)
|
230 |
|
231 |
|
232 |
+
def compress(output_fnames, name):
    """Bundle the given files into ``results/all_files_{name}.zip``.

    Args:
        output_fnames: Iterable of file paths to add to the archive.
        name: Identifier interpolated into the archive file name.

    Returns:
        The path of the archive that was written.
    """
    zip_path = f'results/all_files_{name}.zip'
    with ZipFile(zip_path, 'w') as bundle:
        for member in output_fnames:
            bundle.write(member)

    return zip_path
|
239 |
+
|
240 |
+
|
241 |
+
def generate(in_fragments, in_protein, n_steps, n_atoms, radio_samples, selected_atoms):
    """Dispatch linker generation to the pocket-aware or pocket-free path.

    Returns four ``None`` values when no fragments file is provided.
    Otherwise it returns whatever the chosen generator returns:
    ``generate_with_pocket`` if a protein file is given, else
    ``generate_without_pocket``.
    """
    if in_fragments is None:
        return [None, None, None, None]

    if in_protein is not None:
        return generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, radio_samples, selected_atoms)

    return generate_without_pocket(in_fragments, n_steps, n_atoms, radio_samples, selected_atoms)
|
249 |
+
|
250 |
+
|
251 |
+
def generate_without_pocket(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
|
252 |
# Parsing selected atoms (javascript output)
|
253 |
selected_atoms = selected_atoms.strip()
|
254 |
if selected_atoms == '':
|
|
|
262 |
else:
|
263 |
selected_model_name = 'geom_difflinker_given_anchors'
|
264 |
|
|
|
|
|
|
|
265 |
print(f'Start generating with model {selected_model_name}, selected_atoms:', selected_atoms)
|
266 |
ddpm = diffusion_models[selected_model_name]
|
267 |
path = input_file.name
|
|
|
272 |
|
273 |
try:
|
274 |
molecule = read_molecule(path)
|
275 |
+
try:
|
276 |
+
molecule = Chem.RemoveAllHs(molecule)
|
277 |
+
except:
|
278 |
+
pass
|
279 |
name = '.'.join(path.split('/')[-1].split('.')[:-1])
|
280 |
inp_sdf = f'results/input_{name}.sdf'
|
281 |
except Exception as e:
|
282 |
+
e = str(e).replace('\'', '')
|
283 |
error = f'Could not read the molecule: {e}'
|
284 |
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
285 |
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
286 |
|
287 |
+
if molecule.GetNumAtoms() > 100:
|
288 |
+
error = f'Too large molecule: upper limit is 100 heavy atoms'
|
289 |
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
290 |
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
291 |
|
292 |
with Chem.SDWriter(inp_sdf) as w:
|
293 |
+
w.SetKekulize(False)
|
294 |
w.write(molecule)
|
295 |
|
296 |
positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
|
|
|
334 |
|
335 |
for data in dataloader:
|
336 |
try:
|
337 |
+
generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name, with_pocket=False)
|
338 |
except Exception as e:
|
339 |
+
e = str(e).replace('\'', '')
|
340 |
error = f'Caught exception while generating linkers: {e}'
|
341 |
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
342 |
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
343 |
|
344 |
out_files = try_to_convert_to_sdf(name)
|
345 |
out_files = [inp_sdf] + out_files
|
346 |
+
out_files = [compress(out_files, name=name)] + out_files
|
347 |
+
|
348 |
+
return [
|
349 |
+
draw_sample(radio_samples, out_files),
|
350 |
+
out_files,
|
351 |
+
gr.Radio.update(visible=True),
|
352 |
+
None
|
353 |
+
]
|
354 |
+
|
355 |
+
|
356 |
+
def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, radio_samples, selected_atoms):
|
357 |
+
# Parsing selected atoms (javascript output)
|
358 |
+
selected_atoms = selected_atoms.strip()
|
359 |
+
if selected_atoms == '':
|
360 |
+
selected_atoms = []
|
361 |
+
else:
|
362 |
+
selected_atoms = list(map(int, selected_atoms.split(',')))
|
363 |
+
|
364 |
+
# Selecting model
|
365 |
+
if len(selected_atoms) == 0:
|
366 |
+
selected_model_name = 'pockets_difflinker'
|
367 |
+
else:
|
368 |
+
selected_model_name = 'pockets_difflinker_given_anchors'
|
369 |
+
|
370 |
+
print(f'Start generating with model {selected_model_name}, selected_atoms:', selected_atoms)
|
371 |
+
ddpm = diffusion_models[selected_model_name]
|
372 |
+
|
373 |
+
fragments_path = in_fragments.name
|
374 |
+
fragments_extension = fragments_path.split('.')[-1]
|
375 |
+
if fragments_extension not in ['sdf', 'pdb', 'mol', 'mol2']:
|
376 |
+
msg = output.INVALID_FORMAT_MSG.format(extension=fragments_extension)
|
377 |
+
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
378 |
+
|
379 |
+
protein_path = in_protein.name
|
380 |
+
protein_extension = protein_path.split('.')[-1]
|
381 |
+
if protein_extension not in ['pdb']:
|
382 |
+
msg = output.INVALID_FORMAT_MSG.format(extension=protein_extension)
|
383 |
+
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
384 |
+
|
385 |
+
try:
|
386 |
+
fragments_mol = read_molecule(fragments_path)
|
387 |
+
name = '.'.join(fragments_path.split('/')[-1].split('.')[:-1])
|
388 |
+
except Exception as e:
|
389 |
+
e = str(e).replace('\'', '')
|
390 |
+
error = f'Could not read the molecule: {e}'
|
391 |
+
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
392 |
+
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
393 |
+
|
394 |
+
if fragments_mol.GetNumAtoms() > 100:
|
395 |
+
error = f'Too large molecule: upper limit is 100 heavy atoms'
|
396 |
+
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
397 |
+
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
398 |
+
|
399 |
+
inp_sdf = f'results/input_{name}.sdf'
|
400 |
+
with Chem.SDWriter(inp_sdf) as w:
|
401 |
+
w.SetKekulize(False)
|
402 |
+
w.write(fragments_mol)
|
403 |
+
|
404 |
+
inp_pdb = f'results/target_{name}.pdb'
|
405 |
+
shutil.copy(protein_path, inp_pdb)
|
406 |
+
|
407 |
+
frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments_mol, is_geom=True)
|
408 |
+
pocket_pos, pocket_one_hot, pocket_charges = get_pocket(fragments_mol, protein_path)
|
409 |
+
print(f'Detected pocket with {len(pocket_pos)} atoms')
|
410 |
+
|
411 |
+
positions = np.concatenate([frag_pos, pocket_pos], axis=0)
|
412 |
+
one_hot = np.concatenate([frag_one_hot, pocket_one_hot], axis=0)
|
413 |
+
charges = np.concatenate([frag_charges, pocket_charges], axis=0)
|
414 |
+
anchors = np.zeros_like(charges)
|
415 |
+
anchors[selected_atoms] = 1
|
416 |
+
|
417 |
+
fragment_only_mask = np.concatenate([
|
418 |
+
np.ones_like(frag_charges),
|
419 |
+
np.zeros_like(pocket_charges),
|
420 |
+
])
|
421 |
+
pocket_mask = np.concatenate([
|
422 |
+
np.zeros_like(frag_charges),
|
423 |
+
np.ones_like(pocket_charges),
|
424 |
+
])
|
425 |
+
linker_mask = np.concatenate([
|
426 |
+
np.zeros_like(frag_charges),
|
427 |
+
np.zeros_like(pocket_charges),
|
428 |
+
])
|
429 |
+
fragment_mask = np.concatenate([
|
430 |
+
np.ones_like(frag_charges),
|
431 |
+
np.ones_like(pocket_charges),
|
432 |
+
])
|
433 |
+
print('Read and parsed molecule')
|
434 |
+
|
435 |
+
dataset = [{
|
436 |
+
'uuid': '0',
|
437 |
+
'name': '0',
|
438 |
+
'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
|
439 |
+
'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
|
440 |
+
'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
|
441 |
+
'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
|
442 |
+
'fragment_only_mask': torch.tensor(fragment_only_mask, dtype=const.TORCH_FLOAT, device=device),
|
443 |
+
'pocket_mask': torch.tensor(pocket_mask, dtype=const.TORCH_FLOAT, device=device),
|
444 |
+
'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
|
445 |
+
'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
|
446 |
+
'num_atoms': len(positions),
|
447 |
+
}] * N_SAMPLES
|
448 |
+
dataset = MOADDataset(data=dataset)
|
449 |
+
ddpm.val_dataset = dataset
|
450 |
+
|
451 |
+
dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
|
452 |
+
print('Created dataloader')
|
453 |
+
|
454 |
+
ddpm.edm.T = n_steps
|
455 |
+
|
456 |
+
if n_atoms == 0:
|
457 |
+
def sample_fn(_data):
|
458 |
+
out, _ = size_nn.forward(_data, return_loss=False)
|
459 |
+
probabilities = torch.softmax(out, dim=1)
|
460 |
+
distribution = torch.distributions.Categorical(probs=probabilities)
|
461 |
+
samples = distribution.sample()
|
462 |
+
sizes = []
|
463 |
+
for label in samples.detach().cpu().numpy():
|
464 |
+
sizes.append(size_nn.linker_id2size[label])
|
465 |
+
sizes = torch.tensor(sizes, device=samples.device, dtype=torch.long)
|
466 |
+
return sizes
|
467 |
+
else:
|
468 |
+
def sample_fn(_data):
|
469 |
+
return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms
|
470 |
+
|
471 |
+
for data in dataloader:
|
472 |
+
try:
|
473 |
+
generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name, with_pocket=True)
|
474 |
+
except Exception as e:
|
475 |
+
e = str(e).replace('\'', '')
|
476 |
+
error = f'Caught exception while generating linkers: {e}'
|
477 |
+
msg = output.ERROR_FORMAT_MSG.format(message=error)
|
478 |
+
return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
|
479 |
+
|
480 |
+
out_files = try_to_convert_to_sdf(name)
|
481 |
+
out_files = [inp_sdf, inp_pdb] + out_files
|
482 |
+
out_files = [compress(out_files, name=name)] + out_files
|
483 |
|
484 |
return [
|
485 |
draw_sample(radio_samples, out_files),
|
|
|
505 |
with gr.Box():
|
506 |
with gr.Row():
|
507 |
with gr.Column():
|
508 |
+
gr.Markdown('## Input')
|
509 |
gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
|
510 |
+
with gr.Column():
|
511 |
+
input_fragments_file = gr.File(
|
512 |
+
file_count='single',
|
513 |
+
label='Input Fragments',
|
514 |
+
file_types=['.sdf', '.pdb', '.mol', '.mol2']
|
515 |
+
)
|
516 |
+
# gr.Markdown('(Optionally) upload the file of the target protein in .pdb format:')
|
517 |
+
with gr.Column():
|
518 |
+
input_protein_file = gr.File(file_count='single', label='Target Protein', file_types=['.pdb'])
|
519 |
+
|
520 |
+
n_steps = gr.Slider(minimum=50, maximum=500, label="Number of Denoising Steps", step=10)
|
521 |
n_atoms = gr.Slider(
|
522 |
minimum=0, maximum=20,
|
523 |
label="Linker Size: DiffLinker will predict it if set to 0",
|
524 |
step=1
|
525 |
)
|
526 |
examples = gr.Dataset(
|
527 |
+
components=[gr.File(visible=False), gr.File(visible=False)],
|
528 |
+
samples=[
|
529 |
+
['examples/example_1.sdf', None],
|
530 |
+
['examples/example_2.sdf', None],
|
531 |
+
['examples/3hz1_fragments.sdf', 'examples/3hz1_protein.pdb'],
|
532 |
+
['examples/5ou2_fragments.sdf', 'examples/5ou2_protein.pdb'],
|
533 |
+
],
|
534 |
+
headers=['Fragments', 'Target Protein'],
|
535 |
+
type='values',
|
536 |
)
|
537 |
|
538 |
button = gr.Button('Generate Linker!')
|
|
|
554 |
)
|
555 |
visualization = gr.HTML()
|
556 |
|
557 |
+
input_fragments_file.change(
|
558 |
fn=show_input,
|
559 |
+
inputs=[input_fragments_file, input_protein_file],
|
560 |
outputs=[visualization, samples, hidden],
|
561 |
)
|
562 |
+
input_protein_file.change(
|
563 |
+
fn=show_input,
|
564 |
+
inputs=[input_fragments_file, input_protein_file],
|
565 |
+
outputs=[visualization, samples, hidden],
|
566 |
+
)
|
567 |
+
input_fragments_file.clear(
|
568 |
+
fn=clear_fragments_input,
|
569 |
+
inputs=[input_protein_file],
|
570 |
+
outputs=[input_fragments_file, visualization, samples, hidden],
|
571 |
+
)
|
572 |
+
input_protein_file.clear(
|
573 |
+
fn=clear_protein_input,
|
574 |
+
inputs=[input_fragments_file],
|
575 |
+
outputs=[input_protein_file, visualization, samples, hidden],
|
576 |
)
|
577 |
examples.click(
|
578 |
+
fn=click_on_example,
|
579 |
inputs=[examples],
|
580 |
+
outputs=[input_fragments_file, input_protein_file, n_steps, n_atoms, visualization, samples, hidden]
|
581 |
)
|
582 |
button.click(
|
583 |
fn=generate,
|
584 |
+
inputs=[input_fragments_file, input_protein_file, n_steps, n_atoms, samples, hidden],
|
585 |
outputs=[visualization, output_files, samples, hidden],
|
586 |
_js=output.RETURN_SELECTION_JS,
|
587 |
)
|
examples/3hz1_fragments.sdf
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fragments
|
2 |
+
PyMOL2.5 3D 0
|
3 |
+
|
4 |
+
23 25 0 0 0 0 0 0 0 0999 V2000
|
5 |
+
0.7050 10.1160 25.5000 C 0 0 0 0 0 0 0 0 0 0 0 0
|
6 |
+
-0.4250 10.6930 24.7810 C 0 0 0 0 0 0 0 0 0 0 0 0
|
7 |
+
-1.6420 10.9060 25.5370 C 0 0 0 0 0 0 0 0 0 0 0 0
|
8 |
+
-1.7510 10.5210 26.8370 N 0 0 0 0 0 0 0 0 0 0 0 0
|
9 |
+
-0.6900 9.9510 27.4380 C 0 0 0 0 0 0 0 0 0 0 0 0
|
10 |
+
0.4770 9.7630 26.7990 N 0 0 0 0 0 0 0 0 0 0 0 0
|
11 |
+
-0.6830 11.1870 23.5600 N 0 0 0 0 0 0 0 0 0 0 0 0
|
12 |
+
-1.9660 11.6240 23.5390 C 0 0 0 0 0 0 0 0 0 0 0 0
|
13 |
+
-2.5810 11.4250 24.7070 N 0 0 0 0 0 0 0 0 0 0 0 0
|
14 |
+
1.9520 9.8170 24.8700 N 0 0 0 0 0 0 0 0 0 0 0 0
|
15 |
+
3.1230 9.3980 25.6290 C 0 0 0 0 0 0 0 0 0 0 0 0
|
16 |
+
2.1100 9.7530 23.4320 C 0 0 0 0 0 0 0 0 0 0 0 0
|
17 |
+
7.8600 10.1360 22.6040 C 0 0 0 0 0 0 0 0 0 0 0 0
|
18 |
+
6.5530 9.6800 22.8080 C 0 0 0 0 0 0 0 0 0 0 0 0
|
19 |
+
5.8720 10.7150 23.6130 O 0 0 0 0 0 0 0 0 0 0 0 0
|
20 |
+
6.8390 11.6780 23.7840 C 0 0 0 0 0 0 0 0 0 0 0 0
|
21 |
+
8.0580 11.3690 23.2280 C 0 0 0 0 0 0 0 0 0 0 0 0
|
22 |
+
6.6560 12.9400 24.5720 C 0 0 0 0 0 0 0 0 0 0 0 0
|
23 |
+
7.6630 13.4980 25.2340 N 0 0 0 0 0 0 0 0 0 0 0 0
|
24 |
+
7.1190 14.6210 25.8930 N 0 0 0 0 0 0 0 0 0 0 0 0
|
25 |
+
5.8050 14.8140 25.6500 C 0 0 0 0 0 0 0 0 0 0 0 0
|
26 |
+
5.4220 13.6990 24.7720 C 0 0 0 0 0 0 0 0 0 0 0 0
|
27 |
+
4.9170 15.9400 26.1920 C 0 0 0 0 0 0 0 0 0 0 0 0
|
28 |
+
1 2 4 0 0 0 0
|
29 |
+
1 6 4 0 0 0 0
|
30 |
+
1 10 1 0 0 0 0
|
31 |
+
2 3 4 0 0 0 0
|
32 |
+
2 7 4 0 0 0 0
|
33 |
+
3 4 4 0 0 0 0
|
34 |
+
3 9 4 0 0 0 0
|
35 |
+
4 5 4 0 0 0 0
|
36 |
+
5 6 4 0 0 0 0
|
37 |
+
7 8 4 0 0 0 0
|
38 |
+
8 9 4 0 0 0 0
|
39 |
+
10 11 1 0 0 0 0
|
40 |
+
10 12 1 0 0 0 0
|
41 |
+
13 14 4 0 0 0 0
|
42 |
+
13 17 4 0 0 0 0
|
43 |
+
14 15 4 0 0 0 0
|
44 |
+
15 16 4 0 0 0 0
|
45 |
+
16 17 4 0 0 0 0
|
46 |
+
16 18 1 0 0 0 0
|
47 |
+
18 19 4 0 0 0 0
|
48 |
+
18 22 4 0 0 0 0
|
49 |
+
19 20 4 0 0 0 0
|
50 |
+
20 21 4 0 0 0 0
|
51 |
+
21 22 4 0 0 0 0
|
52 |
+
21 23 1 0 0 0 0
|
53 |
+
M END
|
54 |
+
$$$$
|
examples/3hz1_protein.pdb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
examples/5ou2_fragments.sdf
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
5ou2_fragments
|
2 |
+
PyMOL2.5 3D 0
|
3 |
+
|
4 |
+
24 26 0 0 0 0 0 0 0 0999 V2000
|
5 |
+
135.6651 -15.3583 0.1325 N 0 0 0 0 0 0 0 0 0 0 0 0
|
6 |
+
134.8356 -14.4706 -0.4078 C 0 0 0 0 0 0 0 0 0 0 0 0
|
7 |
+
134.5969 -13.5549 0.5236 N 0 0 0 0 0 0 0 0 0 0 0 0
|
8 |
+
135.2672 -13.8787 1.6104 C 0 0 0 0 0 0 0 0 0 0 0 0
|
9 |
+
135.9361 -15.0095 1.3626 C 0 0 0 0 0 0 0 0 0 0 0 0
|
10 |
+
135.2407 -13.1072 2.8878 C 0 0 0 0 0 0 0 0 0 0 0 0
|
11 |
+
135.5339 -13.7328 4.0539 C 0 0 0 0 0 0 0 0 0 0 0 0
|
12 |
+
135.5239 -13.0695 5.2284 C 0 0 0 0 0 0 0 0 0 0 0 0
|
13 |
+
135.1995 -11.7489 5.2810 C 0 0 0 0 0 0 0 0 0 0 0 0
|
14 |
+
134.9023 -11.1173 4.1089 C 0 0 0 0 0 0 0 0 0 0 0 0
|
15 |
+
134.9113 -11.7774 2.9035 C 0 0 0 0 0 0 0 0 0 0 0 0
|
16 |
+
135.1362 -10.8138 6.9517 Br 0 0 0 0 0 0 0 0 0 0 0 0
|
17 |
+
126.8521 -19.0355 0.2522 N 0 0 0 0 0 0 0 0 0 0 0 0
|
18 |
+
126.0921 -18.0299 -0.2360 C 0 0 0 0 0 0 0 0 0 0 0 0
|
19 |
+
126.8721 -17.2548 -1.0322 N 0 0 0 0 0 0 0 0 0 0 0 0
|
20 |
+
128.1098 -17.7707 -1.0325 C 0 0 0 0 0 0 0 0 0 0 0 0
|
21 |
+
128.0889 -18.8815 -0.2256 C 0 0 0 0 0 0 0 0 0 0 0 0
|
22 |
+
129.3145 -17.2106 -1.7791 C 0 0 0 0 0 0 0 0 0 0 0 0
|
23 |
+
130.5850 -17.7185 -1.5264 C 0 0 0 0 0 0 0 0 0 0 0 0
|
24 |
+
131.6879 -17.2095 -2.1865 C 0 0 0 0 0 0 0 0 0 0 0 0
|
25 |
+
131.5211 -16.1844 -3.1052 C 0 0 0 0 0 0 0 0 0 0 0 0
|
26 |
+
130.2586 -15.6644 -3.3699 C 0 0 0 0 0 0 0 0 0 0 0 0
|
27 |
+
129.1548 -16.1795 -2.7058 C 0 0 0 0 0 0 0 0 0 0 0 0
|
28 |
+
133.0656 -15.5029 -4.0086 Br 0 0 0 0 0 0 0 0 0 0 0 0
|
29 |
+
1 2 4 0 0 0 0
|
30 |
+
2 3 4 0 0 0 0
|
31 |
+
3 4 4 0 0 0 0
|
32 |
+
4 6 1 0 0 0 0
|
33 |
+
1 5 4 0 0 0 0
|
34 |
+
4 5 4 0 0 0 0
|
35 |
+
6 7 4 0 0 0 0
|
36 |
+
6 11 4 0 0 0 0
|
37 |
+
7 8 4 0 0 0 0
|
38 |
+
8 9 4 0 0 0 0
|
39 |
+
9 10 4 0 0 0 0
|
40 |
+
9 12 1 0 0 0 0
|
41 |
+
10 11 4 0 0 0 0
|
42 |
+
13 14 4 0 0 0 0
|
43 |
+
14 15 4 0 0 0 0
|
44 |
+
15 16 4 0 0 0 0
|
45 |
+
16 18 1 0 0 0 0
|
46 |
+
13 17 4 0 0 0 0
|
47 |
+
16 17 4 0 0 0 0
|
48 |
+
18 19 4 0 0 0 0
|
49 |
+
18 23 4 0 0 0 0
|
50 |
+
19 20 4 0 0 0 0
|
51 |
+
20 21 4 0 0 0 0
|
52 |
+
21 22 4 0 0 0 0
|
53 |
+
21 24 1 0 0 0 0
|
54 |
+
22 23 4 0 0 0 0
|
55 |
+
M END
|
56 |
+
$$$$
|
examples/5ou2_protein.pdb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
output.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
|
2 |
<html>
|
3 |
<head>
|
4 |
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
|
@@ -26,7 +26,6 @@ INITIAL_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
|
26 |
let defaultStyle = {{ stick: {{ colorscheme: "greenCarbon" }} }};
|
27 |
viewer.addModel(`{molecule}`, "{fmt}");
|
28 |
viewer.getModel(0).setStyle(defaultStyle);
|
29 |
-
// document.cookie = document.cookie + "|selected_atoms:";
|
30 |
|
31 |
viewer.getModel(0).setClickable(
|
32 |
{{}},
|
@@ -38,20 +37,16 @@ INITIAL_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
|
38 |
{{"serial": _atom.serial, "model": 0}},
|
39 |
{{"sphere": {{"color": "magenta", "radius": 0.4}} }}
|
40 |
);
|
41 |
-
// document.cookie = document.cookie + "atom_" + String(_atom.serial) + "-";
|
42 |
window.parent.postMessage({{
|
43 |
name: "atom_selection",
|
44 |
data: {{"atom": _atom.serial, "add": true}}
|
45 |
-
// data: JSON.stringify({{"add": _atom.serial}})
|
46 |
}}, "*");
|
47 |
}} else {{
|
48 |
delete _atom.isClicked;
|
49 |
_viewer.setStyle({{"serial": _atom.serial, "model": 0}}, defaultStyle);
|
50 |
-
// document.cookie = document.cookie.replace("atom_" + String(_atom.serial) + "-", "");
|
51 |
window.parent.postMessage({{
|
52 |
name: "atom_selection",
|
53 |
data: {{"atom": _atom.serial, "add": false}}
|
54 |
-
// data: JSON.stringify({{"remove": _atom.serial}})
|
55 |
}}, "*");
|
56 |
}}
|
57 |
_viewer.render();
|
@@ -67,6 +62,112 @@ INITIAL_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
|
67 |
</html>
|
68 |
"""
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
SAMPLES_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
72 |
<html>
|
@@ -88,6 +189,7 @@ SAMPLES_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
|
88 |
|
89 |
<body>
|
90 |
<div id="container" class="mol-container"></div>
|
|
|
91 |
<button id="fragments">Input Fragments</button>
|
92 |
<button id="molecule">Output Molecule</button>
|
93 |
<script>
|
@@ -120,6 +222,74 @@ SAMPLES_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
|
120 |
</html>
|
121 |
"""
|
122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
INVALID_FORMAT_MSG = """
|
125 |
<!DOCTYPE html>
|
@@ -135,13 +305,18 @@ INVALID_FORMAT_MSG = """
|
|
135 |
|
136 |
<body>
|
137 |
<h3>Invalid file format: {extension}</h3>
|
138 |
-
|
139 |
<ul>
|
140 |
<li>.pdb</li>
|
141 |
<li>.sdf</li>
|
142 |
<li>.mol</li>
|
143 |
<li>.mol2</li>
|
144 |
</ul>
|
|
|
|
|
|
|
|
|
|
|
145 |
</body>
|
146 |
</html>
|
147 |
"""
|
@@ -190,7 +365,7 @@ STARTUP_JS = """
|
|
190 |
"""
|
191 |
|
192 |
RETURN_SELECTION_JS = """
|
193 |
-
(input_file, n_steps, n_atoms, samples, hidden) => {
|
194 |
let selected = []
|
195 |
for (const [atom, add] of Object.entries(window.selected_elements)) {
|
196 |
if (add) {
|
@@ -203,6 +378,6 @@ RETURN_SELECTION_JS = """
|
|
203 |
}
|
204 |
}
|
205 |
console.log("Finished parsing");
|
206 |
-
return [input_file, n_steps, n_atoms, samples, selected.join(",")];
|
207 |
}
|
208 |
"""
|
|
|
1 |
+
FRAGMENTS_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
2 |
<html>
|
3 |
<head>
|
4 |
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
|
|
|
26 |
let defaultStyle = {{ stick: {{ colorscheme: "greenCarbon" }} }};
|
27 |
viewer.addModel(`{molecule}`, "{fmt}");
|
28 |
viewer.getModel(0).setStyle(defaultStyle);
|
|
|
29 |
|
30 |
viewer.getModel(0).setClickable(
|
31 |
{{}},
|
|
|
37 |
{{"serial": _atom.serial, "model": 0}},
|
38 |
{{"sphere": {{"color": "magenta", "radius": 0.4}} }}
|
39 |
);
|
|
|
40 |
window.parent.postMessage({{
|
41 |
name: "atom_selection",
|
42 |
data: {{"atom": _atom.serial, "add": true}}
|
|
|
43 |
}}, "*");
|
44 |
}} else {{
|
45 |
delete _atom.isClicked;
|
46 |
_viewer.setStyle({{"serial": _atom.serial, "model": 0}}, defaultStyle);
|
|
|
47 |
window.parent.postMessage({{
|
48 |
name: "atom_selection",
|
49 |
data: {{"atom": _atom.serial, "add": false}}
|
|
|
50 |
}}, "*");
|
51 |
}}
|
52 |
_viewer.render();
|
|
|
62 |
</html>
|
63 |
"""
|
64 |
|
65 |
+
TARGET_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
66 |
+
<html>
|
67 |
+
<head>
|
68 |
+
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
|
69 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
|
70 |
+
<script src="https://3Dmol.org/build/3Dmol.js"></script>
|
71 |
+
<style>
|
72 |
+
.mol-container {{
|
73 |
+
width: 600px;
|
74 |
+
height: 600px;
|
75 |
+
position: relative;
|
76 |
+
}}
|
77 |
+
.mol-container select{{
|
78 |
+
background-image:None;
|
79 |
+
}}
|
80 |
+
</style>
|
81 |
+
</head>
|
82 |
+
|
83 |
+
<body>
|
84 |
+
<div id="container" class="mol-container"></div>
|
85 |
+
<script>
|
86 |
+
$(document).ready(function() {{
|
87 |
+
let element = $("#container");
|
88 |
+
let config = {{ backgroundColor: "white" }};
|
89 |
+
let viewer = $3Dmol.createViewer(element, config);
|
90 |
+
let proteinStyle = {{ cartoon: {{ colorscheme: "ssPyMOL" }} }};
|
91 |
+
viewer.addModel(`{molecule}`, "{fmt}");
|
92 |
+
viewer.getModel(0).setStyle(proteinStyle);
|
93 |
+
|
94 |
+
viewer.zoomTo();
|
95 |
+
viewer.zoom(0.7);
|
96 |
+
viewer.render();
|
97 |
+
}});
|
98 |
+
</script>
|
99 |
+
</body>
|
100 |
+
</html>
|
101 |
+
"""
|
102 |
+
|
103 |
+
FRAGMENTS_AND_TARGET_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
104 |
+
<html>
|
105 |
+
<head>
|
106 |
+
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
|
107 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
|
108 |
+
<script src="https://3Dmol.org/build/3Dmol.js"></script>
|
109 |
+
<style>
|
110 |
+
.mol-container {{
|
111 |
+
width: 600px;
|
112 |
+
height: 600px;
|
113 |
+
position: relative;
|
114 |
+
}}
|
115 |
+
.mol-container select{{
|
116 |
+
background-image:None;
|
117 |
+
}}
|
118 |
+
</style>
|
119 |
+
</head>
|
120 |
+
|
121 |
+
<body>
|
122 |
+
<div id="container" class="mol-container"></div>
|
123 |
+
<script>
|
124 |
+
$(document).ready(function() {{
|
125 |
+
let element = $("#container");
|
126 |
+
let config = {{ backgroundColor: "white" }};
|
127 |
+
let viewer = $3Dmol.createViewer(element, config);
|
128 |
+
let defaultStyle = {{ stick: {{ colorscheme: "greenCarbon" }} }};
|
129 |
+
let proteinStyle = {{ cartoon: {{ colorscheme: "ssPyMOL" }} }};
|
130 |
+
|
131 |
+
viewer.addModel(`{molecule}`, "{fmt}");
|
132 |
+
viewer.getModel(0).setStyle(defaultStyle);
|
133 |
+
viewer.getModel(0).setClickable(
|
134 |
+
{{}},
|
135 |
+
true,
|
136 |
+
function (_atom, _viewer, _event, _container) {{
|
137 |
+
if (!_atom.isClicked) {{
|
138 |
+
_atom.isClicked = true;
|
139 |
+
_viewer.addStyle(
|
140 |
+
{{"serial": _atom.serial, "model": 0}},
|
141 |
+
{{"sphere": {{"color": "magenta", "radius": 0.4}} }}
|
142 |
+
);
|
143 |
+
window.parent.postMessage({{
|
144 |
+
name: "atom_selection",
|
145 |
+
data: {{"atom": _atom.serial, "add": true}}
|
146 |
+
}}, "*");
|
147 |
+
}} else {{
|
148 |
+
delete _atom.isClicked;
|
149 |
+
_viewer.setStyle({{"serial": _atom.serial, "model": 0}}, defaultStyle);
|
150 |
+
window.parent.postMessage({{
|
151 |
+
name: "atom_selection",
|
152 |
+
data: {{"atom": _atom.serial, "add": false}}
|
153 |
+
}}, "*");
|
154 |
+
}}
|
155 |
+
_viewer.render();
|
156 |
+
}}
|
157 |
+
);
|
158 |
+
|
159 |
+
viewer.addModel(`{target}`, "{target_fmt}");
|
160 |
+
viewer.getModel(1).setStyle(proteinStyle);
|
161 |
+
|
162 |
+
viewer.zoomTo();
|
163 |
+
viewer.zoom(0.7);
|
164 |
+
viewer.render();
|
165 |
+
}});
|
166 |
+
</script>
|
167 |
+
</body>
|
168 |
+
</html>
|
169 |
+
"""
|
170 |
+
|
171 |
|
172 |
SAMPLES_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
173 |
<html>
|
|
|
189 |
|
190 |
<body>
|
191 |
<div id="container" class="mol-container"></div>
|
192 |
+
<br>
|
193 |
<button id="fragments">Input Fragments</button>
|
194 |
<button id="molecule">Output Molecule</button>
|
195 |
<script>
|
|
|
222 |
</html>
|
223 |
"""
|
224 |
|
225 |
+
SAMPLES_WITH_TARGET_RENDERING_TEMPLATE = """<!DOCTYPE html>
|
226 |
+
<html>
|
227 |
+
<head>
|
228 |
+
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
|
229 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
|
230 |
+
<script src="https://3Dmol.org/build/3Dmol.js"></script>
|
231 |
+
<style>
|
232 |
+
.mol-container {{
|
233 |
+
width: 600px;
|
234 |
+
height: 600px;
|
235 |
+
position: relative;
|
236 |
+
}}
|
237 |
+
.mol-container select{{
|
238 |
+
background-image:None;
|
239 |
+
}}
|
240 |
+
</style>
|
241 |
+
</head>
|
242 |
+
|
243 |
+
<body>
|
244 |
+
<div id="container" class="mol-container"></div>
|
245 |
+
<br>
|
246 |
+
<button id="fragments">Input Fragments</button>
|
247 |
+
<button id="molecule">Output Molecule</button>
|
248 |
+
<button id="show-target">Show Target</button>
|
249 |
+
<button id="hide-target">Hide Target</button>
|
250 |
+
<script>
|
251 |
+
let element = $("#container");
|
252 |
+
let config = {{ backgroundColor: "white" }};
|
253 |
+
let viewer = $3Dmol.createViewer( element, config );
|
254 |
+
|
255 |
+
$(document).ready(function() {{
|
256 |
+
viewer.addModel(`{fragments}`, "{fragments_fmt}")
|
257 |
+
viewer.getModel(0).setStyle({{ stick: {{ colorscheme:"greenCarbon" }} }})
|
258 |
+
viewer.getModel(0).hide();
|
259 |
+
|
260 |
+
viewer.addModel(`{molecule}`, "{molecule_fmt}")
|
261 |
+
viewer.getModel(1).setStyle({{ stick: {{ colorscheme:"greenCarbon" }} }})
|
262 |
+
|
263 |
+
viewer.addModel(`{target}`, "{target_fmt}")
|
264 |
+
viewer.getModel(2).setStyle({{ cartoon: {{ colorscheme: "ssPyMOL" }} }})
|
265 |
+
|
266 |
+
viewer.zoomTo();
|
267 |
+
viewer.zoom(0.7);
|
268 |
+
viewer.render();
|
269 |
+
}});
|
270 |
+
$("#fragments").click(function() {{
|
271 |
+
viewer.getModel(0).show();
|
272 |
+
viewer.getModel(1).hide();
|
273 |
+
viewer.render();
|
274 |
+
}});
|
275 |
+
$("#molecule").click(function() {{
|
276 |
+
viewer.getModel(1).show();
|
277 |
+
viewer.getModel(0).hide();
|
278 |
+
viewer.render();
|
279 |
+
}});
|
280 |
+
$("#show-target").click(function() {{
|
281 |
+
viewer.getModel(2).show();
|
282 |
+
viewer.render();
|
283 |
+
}});
|
284 |
+
$("#hide-target").click(function() {{
|
285 |
+
viewer.getModel(2).hide();
|
286 |
+
viewer.render();
|
287 |
+
}});
|
288 |
+
</script>
|
289 |
+
</body>
|
290 |
+
</html>
|
291 |
+
"""
|
292 |
+
|
293 |
|
294 |
INVALID_FORMAT_MSG = """
|
295 |
<!DOCTYPE html>
|
|
|
305 |
|
306 |
<body>
|
307 |
<h3>Invalid file format: {extension}</h3>
|
308 |
+
Allowed formats for the fragments file:
|
309 |
<ul>
|
310 |
<li>.pdb</li>
|
311 |
<li>.sdf</li>
|
312 |
<li>.mol</li>
|
313 |
<li>.mol2</li>
|
314 |
</ul>
|
315 |
+
|
316 |
+
Allowed formats for the optional protein file:
|
317 |
+
<ul>
|
318 |
+
<li>.pdb</li>
|
319 |
+
</ul>
|
320 |
</body>
|
321 |
</html>
|
322 |
"""
|
|
|
365 |
"""
|
366 |
|
367 |
RETURN_SELECTION_JS = """
|
368 |
+
(input_file, input_protein_file, n_steps, n_atoms, samples, hidden) => {
|
369 |
let selected = []
|
370 |
for (const [atom, add] of Object.entries(window.selected_elements)) {
|
371 |
if (add) {
|
|
|
378 |
}
|
379 |
}
|
380 |
console.log("Finished parsing");
|
381 |
+
return [input_file, input_protein_file, n_steps, n_atoms, samples, selected.join(",")];
|
382 |
}
|
383 |
"""
|
src/datasets.py
CHANGED
@@ -101,15 +101,25 @@ class ZincDataset(Dataset):
|
|
101 |
|
102 |
|
103 |
class MOADDataset(Dataset):
|
104 |
-
def __init__(self, data_path, prefix, device):
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
dataset_path = os.path.join(data_path, f'{prefix}_{pocket_mode}.pt')
|
108 |
if os.path.exists(dataset_path):
|
109 |
self.data = torch.load(dataset_path, map_location=device)
|
110 |
else:
|
111 |
print(f'Preprocessing dataset with prefix {prefix}')
|
112 |
-
self.data =
|
113 |
torch.save(self.data, dataset_path)
|
114 |
|
115 |
def __len__(self):
|
@@ -264,7 +274,7 @@ def collate_with_fragment_edges(batch):
|
|
264 |
out = {}
|
265 |
|
266 |
# Filter out big molecules
|
267 |
-
batch = [data for data in batch if data['num_atoms'] <= 50]
|
268 |
|
269 |
for i, data in enumerate(batch):
|
270 |
for key, value in data.items():
|
|
|
101 |
|
102 |
|
103 |
class MOADDataset(Dataset):
|
104 |
+
def __init__(self, data=None, data_path=None, prefix=None, device=None):
|
105 |
+
assert (data is not None) or all(x is not None for x in (data_path, prefix, device))
|
106 |
+
if data is not None:
|
107 |
+
self.data = data
|
108 |
+
return
|
109 |
+
|
110 |
+
if '.' in prefix:
|
111 |
+
prefix, pocket_mode = prefix.split('.')
|
112 |
+
else:
|
113 |
+
parts = prefix.split('_')
|
114 |
+
prefix = '_'.join(parts[:-1])
|
115 |
+
pocket_mode = parts[-1]
|
116 |
|
117 |
dataset_path = os.path.join(data_path, f'{prefix}_{pocket_mode}.pt')
|
118 |
if os.path.exists(dataset_path):
|
119 |
self.data = torch.load(dataset_path, map_location=device)
|
120 |
else:
|
121 |
print(f'Preprocessing dataset with prefix {prefix}')
|
122 |
+
self.data = self.preprocess(data_path, prefix, pocket_mode, device)
|
123 |
torch.save(self.data, dataset_path)
|
124 |
|
125 |
def __len__(self):
|
|
|
274 |
out = {}
|
275 |
|
276 |
# Filter out big molecules
|
277 |
+
# batch = [data for data in batch if data['num_atoms'] <= 50]
|
278 |
|
279 |
for i, data in enumerate(batch):
|
280 |
for key, value in data.items():
|
src/generation.py
CHANGED
@@ -1,24 +1,44 @@
|
|
|
|
1 |
import os.path
|
2 |
import subprocess
|
3 |
import torch
|
4 |
|
|
|
|
|
5 |
from src.visualizer import save_xyz_file
|
|
|
|
|
6 |
|
7 |
N_SAMPLES = 5
|
8 |
|
9 |
|
10 |
-
def generate_linkers(ddpm, data, sample_fn, name):
|
11 |
-
chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
print('Generated linker')
|
13 |
x = chain[0][:, :, :ddpm.n_dims]
|
14 |
h = chain[0][:, :, ddpm.n_dims:]
|
15 |
|
16 |
# Put the molecule back to the initial orientation
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
19 |
mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
|
20 |
x = x + mean * node_mask
|
21 |
|
|
|
|
|
|
|
22 |
names = [f'output_{i + 1}_{name}' for i in range(N_SAMPLES)]
|
23 |
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
24 |
print('Saved XYZ files')
|
@@ -36,3 +56,62 @@ def try_to_convert_to_sdf(name):
|
|
36 |
out_files.append(out_xyz)
|
37 |
|
38 |
return out_files
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
import os.path
|
3 |
import subprocess
|
4 |
import torch
|
5 |
|
6 |
+
from Bio.PDB import PDBParser
|
7 |
+
from src import const
|
8 |
from src.visualizer import save_xyz_file
|
9 |
+
from src.utils import FoundNaNException
|
10 |
+
from src.datasets import get_one_hot
|
11 |
|
12 |
N_SAMPLES = 5
|
13 |
|
14 |
|
15 |
+
def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False):
|
16 |
+
chain = node_mask = None
|
17 |
+
for i in range(5):
|
18 |
+
try:
|
19 |
+
chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
|
20 |
+
break
|
21 |
+
except FoundNaNException:
|
22 |
+
continue
|
23 |
+
|
24 |
print('Generated linker')
|
25 |
x = chain[0][:, :, :ddpm.n_dims]
|
26 |
h = chain[0][:, :, ddpm.n_dims:]
|
27 |
|
28 |
# Put the molecule back to the initial orientation
|
29 |
+
if with_pocket:
|
30 |
+
com_mask = data['fragment_only_mask'] if ddpm.center_of_mass == 'fragments' else data['anchors']
|
31 |
+
else:
|
32 |
+
com_mask = data['fragment_mask'] if ddpm.center_of_mass == 'fragments' else data['anchors']
|
33 |
+
|
34 |
+
pos_masked = data['positions'] * com_mask
|
35 |
+
N = com_mask.sum(1, keepdims=True)
|
36 |
mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
|
37 |
x = x + mean * node_mask
|
38 |
|
39 |
+
if with_pocket:
|
40 |
+
node_mask[torch.where(data['pocket_mask'])] = 0
|
41 |
+
|
42 |
names = [f'output_{i + 1}_{name}' for i in range(N_SAMPLES)]
|
43 |
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
44 |
print('Saved XYZ files')
|
|
|
56 |
out_files.append(out_xyz)
|
57 |
|
58 |
return out_files
|
59 |
+
|
60 |
+
|
61 |
+
def get_pocket(mol, pdb_path, cutoff=6.0):
    """Collect the protein atoms that form the pocket around a molecule.

    A residue is part of the pocket when at least one of its atoms lies within
    ``cutoff`` of any atom of ``mol``. Every atom of such a residue whose
    element is a key of ``const.GEOM_ATOM2IDX`` is returned; atoms of other
    elements are dropped.

    Args:
        mol: Molecule exposing ``GetConformer().GetPositions()``. Its
            coordinates are presumably in the same frame and units as the
            PDB file -- TODO confirm against the caller.
        pdb_path: Path of the protein PDB file, parsed with ``PDBParser``.
        cutoff: Largest atom-atom distance (in coordinate units) at which a
            residue still counts as being in contact with the molecule.

    Returns:
        Tuple ``(pocket_pos, pocket_one_hot, pocket_charges)`` of numpy
        arrays with one entry per retained pocket atom: its coordinates, the
        ``get_one_hot`` encoding of its element and the matching
        ``const.GEOM_CHARGES`` value.
    """
    structure = PDBParser().get_structure('', pdb_path)

    # Residues are tracked by their position in this list rather than by
    # their sequence number: the number alone repeats across chains, models
    # and insertion codes, so it would also select residues that merely share
    # a number with a residue that is actually in contact.
    residues = list(structure.get_residues())

    atom_coords = []
    atom_residue_idx = []
    for idx, residue in enumerate(residues):
        for atom in residue.get_atoms():
            atom_coords.append(atom.get_coord())
            atom_residue_idx.append(idx)

    # reshape keeps the array 2-D even when the structure contains no atoms,
    # so the broadcasting below does not fail on an empty protein.
    atom_coords = np.array(atom_coords).reshape(-1, 3)
    atom_residue_idx = np.array(atom_residue_idx, dtype=int)
    mol_atom_coords = mol.GetConformer().GetPositions()

    # Pairwise distances, shape (n_protein_atoms, n_molecule_atoms).
    distances = np.linalg.norm(atom_coords[:, None, :] - mol_atom_coords[None, :, :], axis=-1)
    # np.unique returns sorted indices, i.e. residues in their file order.
    contact_idx = np.unique(atom_residue_idx[distances.min(1) <= cutoff])

    pocket_pos = []
    pocket_one_hot = []
    pocket_charges = []
    for idx in contact_idx:
        for atom in residues[idx].get_atoms():
            atom_type = atom.element.upper()
            # Skip elements without an entry in the atom vocabulary.
            if atom_type not in const.GEOM_ATOM2IDX:
                continue

            pocket_pos.append(atom.get_coord().tolist())
            pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX))
            pocket_charges.append(const.GEOM_CHARGES[atom_type])

    return np.array(pocket_pos), np.array(pocket_one_hot), np.array(pocket_charges)
|
src/lightning.py
CHANGED
@@ -21,7 +21,6 @@ from pdb import set_trace
|
|
21 |
|
22 |
|
23 |
def get_activation(activation):
|
24 |
-
print(activation)
|
25 |
if activation == 'silu':
|
26 |
return torch.nn.SiLU()
|
27 |
else:
|
@@ -158,7 +157,7 @@ class DDPM(pl.LightningModule):
|
|
158 |
context = fragment_mask
|
159 |
|
160 |
# Add information about pocket to the context
|
161 |
-
if
|
162 |
fragment_pocket_mask = fragment_mask
|
163 |
fragment_only_mask = data['fragment_only_mask']
|
164 |
pocket_only_mask = fragment_pocket_mask - fragment_only_mask
|
@@ -170,6 +169,8 @@ class DDPM(pl.LightningModule):
|
|
170 |
# Removing COM of fragment from the atom coordinates
|
171 |
if self.inpainting:
|
172 |
center_of_mass_mask = node_mask
|
|
|
|
|
173 |
elif self.center_of_mass == 'fragments':
|
174 |
center_of_mass_mask = fragment_mask
|
175 |
elif self.center_of_mass == 'anchors':
|
@@ -423,9 +424,9 @@ class DDPM(pl.LightningModule):
|
|
423 |
context = fragment_mask
|
424 |
|
425 |
# Add information about pocket to the context
|
426 |
-
if
|
427 |
fragment_pocket_mask = fragment_mask
|
428 |
-
fragment_only_mask =
|
429 |
pocket_only_mask = fragment_pocket_mask - fragment_only_mask
|
430 |
if self.anchors_context:
|
431 |
context = torch.cat([anchors, fragment_only_mask, pocket_only_mask], dim=-1)
|
@@ -435,6 +436,8 @@ class DDPM(pl.LightningModule):
|
|
435 |
# Removing COM of fragment from the atom coordinates
|
436 |
if self.inpainting:
|
437 |
center_of_mass_mask = node_mask
|
|
|
|
|
438 |
elif self.center_of_mass == 'fragments':
|
439 |
center_of_mass_mask = fragment_mask
|
440 |
elif self.center_of_mass == 'anchors':
|
|
|
21 |
|
22 |
|
23 |
def get_activation(activation):
|
|
|
24 |
if activation == 'silu':
|
25 |
return torch.nn.SiLU()
|
26 |
else:
|
|
|
157 |
context = fragment_mask
|
158 |
|
159 |
# Add information about pocket to the context
|
160 |
+
if isinstance(self.train_dataset, MOADDataset):
|
161 |
fragment_pocket_mask = fragment_mask
|
162 |
fragment_only_mask = data['fragment_only_mask']
|
163 |
pocket_only_mask = fragment_pocket_mask - fragment_only_mask
|
|
|
169 |
# Removing COM of fragment from the atom coordinates
|
170 |
if self.inpainting:
|
171 |
center_of_mass_mask = node_mask
|
172 |
+
elif isinstance(self.train_dataset, MOADDataset) and self.center_of_mass == 'fragments':
|
173 |
+
center_of_mass_mask = data['fragment_only_mask']
|
174 |
elif self.center_of_mass == 'fragments':
|
175 |
center_of_mass_mask = fragment_mask
|
176 |
elif self.center_of_mass == 'anchors':
|
|
|
424 |
context = fragment_mask
|
425 |
|
426 |
# Add information about pocket to the context
|
427 |
+
if isinstance(self.val_dataset, MOADDataset):
|
428 |
fragment_pocket_mask = fragment_mask
|
429 |
+
fragment_only_mask = template_data['fragment_only_mask']
|
430 |
pocket_only_mask = fragment_pocket_mask - fragment_only_mask
|
431 |
if self.anchors_context:
|
432 |
context = torch.cat([anchors, fragment_only_mask, pocket_only_mask], dim=-1)
|
|
|
436 |
# Removing COM of fragment from the atom coordinates
|
437 |
if self.inpainting:
|
438 |
center_of_mass_mask = node_mask
|
439 |
+
elif isinstance(self.val_dataset, MOADDataset) and self.center_of_mass == 'fragments':
|
440 |
+
center_of_mass_mask = template_data['fragment_only_mask']
|
441 |
elif self.center_of_mass == 'fragments':
|
442 |
center_of_mass_mask = fragment_mask
|
443 |
elif self.center_of_mass == 'anchors':
|