igashov committed on
Commit
c104a99
1 Parent(s): 76db25b

Pocket-conditioned generation

Browse files
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import argparse
 
2
 
3
  import gradio as gr
4
  import numpy as np
@@ -9,10 +10,12 @@ import output
9
 
10
  from rdkit import Chem
11
  from src import const
12
- from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule
13
  from src.lightning import DDPM
14
  from src.linker_size_lightning import SizeClassifier
15
- from src.generation import N_SAMPLES, generate_linkers, try_to_convert_to_sdf
 
 
16
 
17
  MODELS_METADATA = {
18
  'geom_difflinker': {
@@ -85,65 +88,167 @@ def read_molecule(path):
85
  raise Exception('Unknown file extension')
86
 
87
 
88
- def show_input(input_file):
89
- if input_file is None:
90
- return ['', gr.Radio.update(visible=False, value='Sample 1'), None]
91
- if isinstance(input_file, str):
92
- path = input_file
93
  else:
94
- path = input_file.name
95
  extension = path.split('.')[-1]
96
- if extension not in ['sdf', 'pdb', 'mol', 'mol2']:
 
97
  msg = output.INVALID_FORMAT_MSG.format(extension=extension)
98
- return [
99
- output.IFRAME_TEMPLATE.format(html=msg),
100
- gr.Radio.update(visible=False),
101
- None,
102
- ]
103
 
104
  try:
105
- molecule = read_molecule_content(path)
106
  except Exception as e:
107
- return [
108
- f'Could not read the molecule: {e}',
109
- gr.Radio.update(visible=False),
110
- None,
111
- ]
 
 
 
 
 
 
112
 
113
- html = output.INITIAL_RENDERING_TEMPLATE.format(molecule=molecule, fmt=extension)
114
- return [
115
- output.IFRAME_TEMPLATE.format(html=html),
116
- gr.Radio.update(visible=False),
117
- None,
118
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
 
121
  def draw_sample(idx, out_files):
 
 
122
  if isinstance(idx, str):
123
  idx = int(idx.strip().split(' ')[-1]) - 1
124
 
125
- in_file = out_files[0]
126
  in_sdf = in_file if isinstance(in_file, str) else in_file.name
 
 
127
 
128
- out_file = out_files[idx + 1]
129
- out_sdf = out_file if isinstance(out_file, str) else out_file.name
 
 
 
 
 
 
130
 
131
- input_fragments_content = read_molecule_content(in_sdf)
 
132
  generated_molecule_content = read_molecule_content(out_sdf)
133
-
134
- fragments_fmt = in_sdf.split('.')[-1]
135
  molecule_fmt = out_sdf.split('.')[-1]
136
 
137
- html = output.SAMPLES_RENDERING_TEMPLATE.format(
138
- fragments=input_fragments_content,
139
- fragments_fmt=fragments_fmt,
140
- molecule=generated_molecule_content,
141
- molecule_fmt=molecule_fmt,
142
- )
 
 
 
 
 
 
 
 
 
 
143
  return output.IFRAME_TEMPLATE.format(html=html)
144
 
145
 
146
- def generate(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  # Parsing selected atoms (javascript output)
148
  selected_atoms = selected_atoms.strip()
149
  if selected_atoms == '':
@@ -157,9 +262,6 @@ def generate(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
157
  else:
158
  selected_model_name = 'geom_difflinker_given_anchors'
159
 
160
- if input_file is None:
161
- return [None, None, None, None]
162
-
163
  print(f'Start generating with model {selected_model_name}, selected_atoms:', selected_atoms)
164
  ddpm = diffusion_models[selected_model_name]
165
  path = input_file.name
@@ -170,20 +272,25 @@ def generate(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
170
 
171
  try:
172
  molecule = read_molecule(path)
173
- molecule = Chem.RemoveAllHs(molecule)
 
 
 
174
  name = '.'.join(path.split('/')[-1].split('.')[:-1])
175
  inp_sdf = f'results/input_{name}.sdf'
176
  except Exception as e:
 
177
  error = f'Could not read the molecule: {e}'
178
  msg = output.ERROR_FORMAT_MSG.format(message=error)
179
  return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
180
 
181
- if molecule.GetNumAtoms() > 50:
182
- error = f'Too large molecule: upper limit is 50 heavy atoms'
183
  msg = output.ERROR_FORMAT_MSG.format(message=error)
184
  return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
185
 
186
  with Chem.SDWriter(inp_sdf) as w:
 
187
  w.write(molecule)
188
 
189
  positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
@@ -227,14 +334,152 @@ def generate(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
227
 
228
  for data in dataloader:
229
  try:
230
- generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name)
231
  except Exception as e:
 
232
  error = f'Caught exception while generating linkers: {e}'
233
  msg = output.ERROR_FORMAT_MSG.format(message=error)
234
  return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
235
 
236
  out_files = try_to_convert_to_sdf(name)
237
  out_files = [inp_sdf] + out_files
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  return [
240
  draw_sample(radio_samples, out_files),
@@ -260,19 +505,34 @@ with demo:
260
  with gr.Box():
261
  with gr.Row():
262
  with gr.Column():
263
- gr.Markdown('## Input Fragments')
264
  gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
265
- input_file = gr.File(file_count='single', label='Input Fragments')
266
- n_steps = gr.Slider(minimum=10, maximum=500, label="Number of Denoising Steps", step=10)
 
 
 
 
 
 
 
 
 
267
  n_atoms = gr.Slider(
268
  minimum=0, maximum=20,
269
  label="Linker Size: DiffLinker will predict it if set to 0",
270
  step=1
271
  )
272
  examples = gr.Dataset(
273
- components=[gr.File(visible=False)],
274
- samples=[['examples/example_1.sdf'], ['examples/example_2.sdf']],
275
- type='index',
 
 
 
 
 
 
276
  )
277
 
278
  button = gr.Button('Generate Linker!')
@@ -294,24 +554,34 @@ with demo:
294
  )
295
  visualization = gr.HTML()
296
 
297
- input_file.change(
298
  fn=show_input,
299
- inputs=[input_file],
300
  outputs=[visualization, samples, hidden],
301
  )
302
- input_file.clear(
303
- fn=lambda: [None, '', gr.Radio.update(visible=False), None],
304
- inputs=[],
305
- outputs=[input_file, visualization, samples, hidden],
 
 
 
 
 
 
 
 
 
 
306
  )
307
  examples.click(
308
- fn=lambda idx: [f'examples/example_{idx+1}.sdf', 10, 0] + show_input(f'examples/example_{idx+1}.sdf'),
309
  inputs=[examples],
310
- outputs=[input_file, n_steps, n_atoms, visualization, samples, hidden]
311
  )
312
  button.click(
313
  fn=generate,
314
- inputs=[input_file, n_steps, n_atoms, samples, hidden],
315
  outputs=[visualization, output_files, samples, hidden],
316
  _js=output.RETURN_SELECTION_JS,
317
  )
 
1
  import argparse
2
+ import shutil
3
 
4
  import gradio as gr
5
  import numpy as np
 
10
 
11
  from rdkit import Chem
12
  from src import const
13
+ from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule, MOADDataset
14
  from src.lightning import DDPM
15
  from src.linker_size_lightning import SizeClassifier
16
+ from src.generation import N_SAMPLES, generate_linkers, try_to_convert_to_sdf, get_pocket
17
+ from zipfile import ZipFile
18
+
19
 
20
  MODELS_METADATA = {
21
  'geom_difflinker': {
 
88
  raise Exception('Unknown file extension')
89
 
90
 
91
def read_molecule_file(in_file, allowed_extentions):
    """Load a molecule file and serialise it for the in-browser viewer.

    Args:
        in_file: Either a path string or an object exposing the path as
            ``.name`` (e.g. an uploaded-file handle).
        allowed_extentions: Collection of accepted file extensions, without
            the leading dot.

    Returns:
        A ``(content, extension, error_html)`` triple. On success ``content``
        is a PDB block (for ``pdb``) or a MOL block (for ``mol``/``mol2``/
        ``sdf``, in which case ``extension`` is reported as ``'mol'``) and
        ``error_html`` is ``None``. On failure the first two items are
        ``None`` and ``error_html`` holds a formatted message.

    Raises:
        NotImplementedError: If the extension is allowed by the caller but
            has no serialiser here.
    """
    path = in_file if isinstance(in_file, str) else in_file.name
    extension = path.split('.')[-1]

    if extension not in allowed_extentions:
        return None, None, output.INVALID_FORMAT_MSG.format(extension=extension)

    try:
        mol = read_molecule(path)
    except Exception as exc:
        # Single quotes are removed from the message before it is put into
        # the template (presumably so they cannot break the generated
        # HTML/JS -- confirm against output.ERROR_FORMAT_MSG).
        reason = str(exc).replace('\'', '')
        return None, None, output.ERROR_FORMAT_MSG.format(message=reason)

    if extension == 'pdb':
        return Chem.MolToPDBBlock(mol), extension, None

    if extension in ('mol', 'mol2', 'sdf'):
        # All small-molecule formats are re-emitted as a MOL block, so the
        # format reported to the caller is 'mol' regardless of the input.
        return Chem.MolToMolBlock(mol, kekulize=False), 'mol', None

    raise NotImplementedError
118
+
119
+
120
def show_input(in_fragments, in_protein):
    """Build the preview for whichever input files are currently set.

    Args:
        in_fragments: Fragments file (path or file handle) or ``None``.
        in_protein: Target protein file (path or file handle) or ``None``.

    Returns:
        A three-item list: the visualisation HTML (empty string when neither
        input is set), an update hiding the samples radio, and ``None``.
    """
    has_fragments = in_fragments is not None
    has_protein = in_protein is not None

    if has_fragments and has_protein:
        vis = show_fragments_and_target(in_fragments, in_protein)
    elif has_fragments:
        vis = show_fragments(in_fragments)
    elif has_protein:
        vis = show_target(in_protein)
    else:
        vis = ''

    return [vis, gr.Radio.update(visible=False), None]
129
+
130
+
131
def show_fragments(in_fragments):
    """Return iframe HTML previewing the fragments file alone.

    If the file cannot be loaded, the iframe wraps the error message
    produced by ``read_molecule_file`` instead of a rendering.
    """
    content, fmt, error_html = read_molecule_file(
        in_fragments,
        allowed_extentions=['sdf', 'pdb', 'mol', 'mol2'],
    )

    if content is None:
        page = error_html
    else:
        page = output.FRAGMENTS_RENDERING_TEMPLATE.format(molecule=content, fmt=fmt)

    return output.IFRAME_TEMPLATE.format(html=page)
137
+
138
+
139
def show_target(in_protein):
    """Return iframe HTML previewing the target protein alone.

    Only ``.pdb`` files are accepted; on failure the iframe wraps the error
    message produced by ``read_molecule_file``.
    """
    content, fmt, error_html = read_molecule_file(
        in_protein,
        allowed_extentions=['pdb'],
    )

    if content is None:
        page = error_html
    else:
        page = output.TARGET_RENDERING_TEMPLATE.format(molecule=content, fmt=fmt)

    return output.IFRAME_TEMPLATE.format(html=page)
145
+
146
+
147
def show_fragments_and_target(in_fragments, in_protein):
    """Return iframe HTML previewing the fragments together with the target.

    Args:
        in_fragments: Fragments file (path or file handle); ``sdf``, ``pdb``,
            ``mol`` and ``mol2`` are accepted.
        in_protein: Target protein file (path or file handle); only ``pdb``
            is accepted.

    Returns:
        The iframe HTML. If either file fails to load, the iframe wraps the
        corresponding error message instead of a rendering.
    """
    fragments_molecule, fragments_extension, msg = read_molecule_file(
        in_fragments,
        allowed_extentions=['sdf', 'pdb', 'mol', 'mol2'],
    )
    if fragments_molecule is None:
        return output.IFRAME_TEMPLATE.format(html=msg)

    target_molecule, target_extension, msg = read_molecule_file(
        in_protein,
        allowed_extentions=['pdb'],
    )
    # Check the *target* here: re-checking the fragments would let a failed
    # protein load fall through and render the template with target=None.
    if target_molecule is None:
        return output.IFRAME_TEMPLATE.format(html=msg)

    html = output.FRAGMENTS_AND_TARGET_RENDERING_TEMPLATE.format(
        molecule=fragments_molecule,
        fmt=fragments_extension,
        target=target_molecule,
        target_fmt=target_extension,
    )

    return output.IFRAME_TEMPLATE.format(html=html)
164
+
165
+
166
def clear_fragments_input(in_protein):
    """Handle clearing of the fragments upload.

    Returns a four-item list: ``None`` for the fragments file, the preview
    of the remaining protein (or an empty string if there is none), an
    update hiding the samples radio, and ``None``.
    """
    vis = show_target(in_protein) if in_protein is not None else ''
    return [None, vis, gr.Radio.update(visible=False), None]
171
+
172
+
173
def clear_protein_input(in_fragments):
    """Handle clearing of the target protein upload.

    Returns a four-item list: ``None`` for the protein file, the preview of
    the remaining fragments (or an empty string if there are none), an
    update hiding the samples radio, and ``None``.
    """
    vis = show_fragments(in_fragments) if in_fragments is not None else ''
    return [None, vis, gr.Radio.update(visible=False), None]
178
+
179
+
180
def click_on_example(example):
    """Populate the inputs from a clicked example row.

    Args:
        example: A ``(fragments_fname, target_fname)`` pair. Each entry is a
            file name relative to ``examples/``; an empty string or ``None``
            means the example has no such file.
            NOTE(review): the exact value the examples component passes for
            a missing file is not visible here -- confirm.

    Returns:
        A list with the fragments path, the target path, the values 50 and 0
        (wired to the denoising-steps and linker-size sliders by the click
        handler), followed by the three items returned by ``show_input``.
    """
    print('Clicked:', example)
    fragment_fname, target_fname = example

    # Truthiness check so that both '' and None mean "no file"; comparing
    # against '' only would turn None into the bogus path 'examples/None'.
    fragment_path = f'examples/{fragment_fname}' if fragment_fname else None
    target_path = f'examples/{target_fname}' if target_fname else None

    return [fragment_path, target_path, 50, 0] + show_input(fragment_path, target_path)
186
 
187
 
188
def draw_sample(idx, out_files):
    """Render one generated sample next to the inputs it was built from.

    Args:
        idx: Zero-based sample index, or a label string whose last
            space-separated token is the one-based sample number.
        out_files: Output file list (paths or objects with ``.name``).
            Item 0 is skipped, item 1 is the input fragments file, item 2 is
            the target protein when one is present, and the generated
            samples follow. A protein is assumed present exactly when the
            list has ``N_SAMPLES + 3`` items.

    Returns:
        Iframe HTML with the rendering.
    """
    def _as_path(entry):
        # Entries may be plain paths or file handles exposing ``.name``.
        return entry if isinstance(entry, str) else entry.name

    with_protein = len(out_files) == N_SAMPLES + 3

    if isinstance(idx, str):
        idx = int(idx.strip().split(' ')[-1]) - 1

    fragments_path = _as_path(out_files[1])
    fields = {
        'fragments': read_molecule_content(fragments_path),
        'fragments_fmt': fragments_path.split('.')[-1],
    }

    if with_protein:
        target_path = _as_path(out_files[2])
        target_fields = {
            'target': read_molecule_content(target_path),
            'target_fmt': target_path.split('.')[-1],
        }
        first_sample = 3
    else:
        first_sample = 2

    sample_path = _as_path(out_files[idx + first_sample])
    fields['molecule'] = read_molecule_content(sample_path)
    fields['molecule_fmt'] = sample_path.split('.')[-1]

    if with_protein:
        html = output.SAMPLES_WITH_TARGET_RENDERING_TEMPLATE.format(**fields, **target_fields)
    else:
        html = output.SAMPLES_RENDERING_TEMPLATE.format(**fields)

    return output.IFRAME_TEMPLATE.format(html=html)
230
 
231
 
232
def compress(output_fnames, name):
    """Bundle the given files into ``results/all_files_<name>.zip``.

    Args:
        output_fnames: Iterable of paths to add to the archive; each is
            stored under the path it was given as.
        name: Identifier used in the archive file name.

    Returns:
        The path of the archive that was written.
    """
    destination = f'results/all_files_{name}.zip'

    bundle = ZipFile(destination, 'w')
    with bundle:
        for member in output_fnames:
            bundle.write(member)

    return destination
239
+
240
+
241
def generate(in_fragments, in_protein, n_steps, n_atoms, radio_samples, selected_atoms):
    """Entry point for generation: dispatch on whether a protein is given.

    Returns four ``None`` values when no fragments file is set; otherwise
    returns whatever the pocket-conditioned or pocket-free generator returns.
    """
    if in_fragments is None:
        return [None] * 4

    if in_protein is not None:
        return generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, radio_samples, selected_atoms)

    return generate_without_pocket(in_fragments, n_steps, n_atoms, radio_samples, selected_atoms)
249
+
250
+
251
+ def generate_without_pocket(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
252
  # Parsing selected atoms (javascript output)
253
  selected_atoms = selected_atoms.strip()
254
  if selected_atoms == '':
 
262
  else:
263
  selected_model_name = 'geom_difflinker_given_anchors'
264
 
 
 
 
265
  print(f'Start generating with model {selected_model_name}, selected_atoms:', selected_atoms)
266
  ddpm = diffusion_models[selected_model_name]
267
  path = input_file.name
 
272
 
273
  try:
274
  molecule = read_molecule(path)
275
+ try:
276
+ molecule = Chem.RemoveAllHs(molecule)
277
+ except:
278
+ pass
279
  name = '.'.join(path.split('/')[-1].split('.')[:-1])
280
  inp_sdf = f'results/input_{name}.sdf'
281
  except Exception as e:
282
+ e = str(e).replace('\'', '')
283
  error = f'Could not read the molecule: {e}'
284
  msg = output.ERROR_FORMAT_MSG.format(message=error)
285
  return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
286
 
287
+ if molecule.GetNumAtoms() > 100:
288
+ error = f'Too large molecule: upper limit is 100 heavy atoms'
289
  msg = output.ERROR_FORMAT_MSG.format(message=error)
290
  return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
291
 
292
  with Chem.SDWriter(inp_sdf) as w:
293
+ w.SetKekulize(False)
294
  w.write(molecule)
295
 
296
  positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
 
334
 
335
  for data in dataloader:
336
  try:
337
+ generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name, with_pocket=False)
338
  except Exception as e:
339
+ e = str(e).replace('\'', '')
340
  error = f'Caught exception while generating linkers: {e}'
341
  msg = output.ERROR_FORMAT_MSG.format(message=error)
342
  return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
343
 
344
  out_files = try_to_convert_to_sdf(name)
345
  out_files = [inp_sdf] + out_files
346
+ out_files = [compress(out_files, name=name)] + out_files
347
+
348
+ return [
349
+ draw_sample(radio_samples, out_files),
350
+ out_files,
351
+ gr.Radio.update(visible=True),
352
+ None
353
+ ]
354
+
355
+
356
+ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, radio_samples, selected_atoms):
357
+ # Parsing selected atoms (javascript output)
358
+ selected_atoms = selected_atoms.strip()
359
+ if selected_atoms == '':
360
+ selected_atoms = []
361
+ else:
362
+ selected_atoms = list(map(int, selected_atoms.split(',')))
363
+
364
+ # Selecting model
365
+ if len(selected_atoms) == 0:
366
+ selected_model_name = 'pockets_difflinker'
367
+ else:
368
+ selected_model_name = 'pockets_difflinker_given_anchors'
369
+
370
+ print(f'Start generating with model {selected_model_name}, selected_atoms:', selected_atoms)
371
+ ddpm = diffusion_models[selected_model_name]
372
+
373
+ fragments_path = in_fragments.name
374
+ fragments_extension = fragments_path.split('.')[-1]
375
+ if fragments_extension not in ['sdf', 'pdb', 'mol', 'mol2']:
376
+ msg = output.INVALID_FORMAT_MSG.format(extension=fragments_extension)
377
+ return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
378
+
379
+ protein_path = in_protein.name
380
+ protein_extension = protein_path.split('.')[-1]
381
+ if protein_extension not in ['pdb']:
382
+ msg = output.INVALID_FORMAT_MSG.format(extension=protein_extension)
383
+ return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
384
+
385
+ try:
386
+ fragments_mol = read_molecule(fragments_path)
387
+ name = '.'.join(fragments_path.split('/')[-1].split('.')[:-1])
388
+ except Exception as e:
389
+ e = str(e).replace('\'', '')
390
+ error = f'Could not read the molecule: {e}'
391
+ msg = output.ERROR_FORMAT_MSG.format(message=error)
392
+ return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
393
+
394
+ if fragments_mol.GetNumAtoms() > 100:
395
+ error = f'Too large molecule: upper limit is 100 heavy atoms'
396
+ msg = output.ERROR_FORMAT_MSG.format(message=error)
397
+ return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
398
+
399
+ inp_sdf = f'results/input_{name}.sdf'
400
+ with Chem.SDWriter(inp_sdf) as w:
401
+ w.SetKekulize(False)
402
+ w.write(fragments_mol)
403
+
404
+ inp_pdb = f'results/target_{name}.pdb'
405
+ shutil.copy(protein_path, inp_pdb)
406
+
407
+ frag_pos, frag_one_hot, frag_charges = parse_molecule(fragments_mol, is_geom=True)
408
+ pocket_pos, pocket_one_hot, pocket_charges = get_pocket(fragments_mol, protein_path)
409
+ print(f'Detected pocket with {len(pocket_pos)} atoms')
410
+
411
+ positions = np.concatenate([frag_pos, pocket_pos], axis=0)
412
+ one_hot = np.concatenate([frag_one_hot, pocket_one_hot], axis=0)
413
+ charges = np.concatenate([frag_charges, pocket_charges], axis=0)
414
+ anchors = np.zeros_like(charges)
415
+ anchors[selected_atoms] = 1
416
+
417
+ fragment_only_mask = np.concatenate([
418
+ np.ones_like(frag_charges),
419
+ np.zeros_like(pocket_charges),
420
+ ])
421
+ pocket_mask = np.concatenate([
422
+ np.zeros_like(frag_charges),
423
+ np.ones_like(pocket_charges),
424
+ ])
425
+ linker_mask = np.concatenate([
426
+ np.zeros_like(frag_charges),
427
+ np.zeros_like(pocket_charges),
428
+ ])
429
+ fragment_mask = np.concatenate([
430
+ np.ones_like(frag_charges),
431
+ np.ones_like(pocket_charges),
432
+ ])
433
+ print('Read and parsed molecule')
434
+
435
+ dataset = [{
436
+ 'uuid': '0',
437
+ 'name': '0',
438
+ 'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
439
+ 'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
440
+ 'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
441
+ 'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
442
+ 'fragment_only_mask': torch.tensor(fragment_only_mask, dtype=const.TORCH_FLOAT, device=device),
443
+ 'pocket_mask': torch.tensor(pocket_mask, dtype=const.TORCH_FLOAT, device=device),
444
+ 'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
445
+ 'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
446
+ 'num_atoms': len(positions),
447
+ }] * N_SAMPLES
448
+ dataset = MOADDataset(data=dataset)
449
+ ddpm.val_dataset = dataset
450
+
451
+ dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
452
+ print('Created dataloader')
453
+
454
+ ddpm.edm.T = n_steps
455
+
456
+ if n_atoms == 0:
457
+ def sample_fn(_data):
458
+ out, _ = size_nn.forward(_data, return_loss=False)
459
+ probabilities = torch.softmax(out, dim=1)
460
+ distribution = torch.distributions.Categorical(probs=probabilities)
461
+ samples = distribution.sample()
462
+ sizes = []
463
+ for label in samples.detach().cpu().numpy():
464
+ sizes.append(size_nn.linker_id2size[label])
465
+ sizes = torch.tensor(sizes, device=samples.device, dtype=torch.long)
466
+ return sizes
467
+ else:
468
+ def sample_fn(_data):
469
+ return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms
470
+
471
+ for data in dataloader:
472
+ try:
473
+ generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name, with_pocket=True)
474
+ except Exception as e:
475
+ e = str(e).replace('\'', '')
476
+ error = f'Caught exception while generating linkers: {e}'
477
+ msg = output.ERROR_FORMAT_MSG.format(message=error)
478
+ return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
479
+
480
+ out_files = try_to_convert_to_sdf(name)
481
+ out_files = [inp_sdf, inp_pdb] + out_files
482
+ out_files = [compress(out_files, name=name)] + out_files
483
 
484
  return [
485
  draw_sample(radio_samples, out_files),
 
505
  with gr.Box():
506
  with gr.Row():
507
  with gr.Column():
508
+ gr.Markdown('## Input')
509
  gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol2 or .sdf format:')
510
+ with gr.Column():
511
+ input_fragments_file = gr.File(
512
+ file_count='single',
513
+ label='Input Fragments',
514
+ file_types=['.sdf', '.pdb', '.mol', '.mol2']
515
+ )
516
+ # gr.Markdown('(Optionally) upload the file of the target protein in .pdb format:')
517
+ with gr.Column():
518
+ input_protein_file = gr.File(file_count='single', label='Target Protein', file_types=['.pdb'])
519
+
520
+ n_steps = gr.Slider(minimum=50, maximum=500, label="Number of Denoising Steps", step=10)
521
  n_atoms = gr.Slider(
522
  minimum=0, maximum=20,
523
  label="Linker Size: DiffLinker will predict it if set to 0",
524
  step=1
525
  )
526
  examples = gr.Dataset(
527
+ components=[gr.File(visible=False), gr.File(visible=False)],
528
+ samples=[
529
+ ['examples/example_1.sdf', None],
530
+ ['examples/example_2.sdf', None],
531
+ ['examples/3hz1_fragments.sdf', 'examples/3hz1_protein.pdb'],
532
+ ['examples/5ou2_fragments.sdf', 'examples/5ou2_protein.pdb'],
533
+ ],
534
+ headers=['Fragments', 'Target Protein'],
535
+ type='values',
536
  )
537
 
538
  button = gr.Button('Generate Linker!')
 
554
  )
555
  visualization = gr.HTML()
556
 
557
+ input_fragments_file.change(
558
  fn=show_input,
559
+ inputs=[input_fragments_file, input_protein_file],
560
  outputs=[visualization, samples, hidden],
561
  )
562
+ input_protein_file.change(
563
+ fn=show_input,
564
+ inputs=[input_fragments_file, input_protein_file],
565
+ outputs=[visualization, samples, hidden],
566
+ )
567
+ input_fragments_file.clear(
568
+ fn=clear_fragments_input,
569
+ inputs=[input_protein_file],
570
+ outputs=[input_fragments_file, visualization, samples, hidden],
571
+ )
572
+ input_protein_file.clear(
573
+ fn=clear_protein_input,
574
+ inputs=[input_fragments_file],
575
+ outputs=[input_protein_file, visualization, samples, hidden],
576
  )
577
  examples.click(
578
+ fn=click_on_example,
579
  inputs=[examples],
580
+ outputs=[input_fragments_file, input_protein_file, n_steps, n_atoms, visualization, samples, hidden]
581
  )
582
  button.click(
583
  fn=generate,
584
+ inputs=[input_fragments_file, input_protein_file, n_steps, n_atoms, samples, hidden],
585
  outputs=[visualization, output_files, samples, hidden],
586
  _js=output.RETURN_SELECTION_JS,
587
  )
examples/3hz1_fragments.sdf ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fragments
2
+ PyMOL2.5 3D 0
3
+
4
+ 23 25 0 0 0 0 0 0 0 0999 V2000
5
+ 0.7050 10.1160 25.5000 C 0 0 0 0 0 0 0 0 0 0 0 0
6
+ -0.4250 10.6930 24.7810 C 0 0 0 0 0 0 0 0 0 0 0 0
7
+ -1.6420 10.9060 25.5370 C 0 0 0 0 0 0 0 0 0 0 0 0
8
+ -1.7510 10.5210 26.8370 N 0 0 0 0 0 0 0 0 0 0 0 0
9
+ -0.6900 9.9510 27.4380 C 0 0 0 0 0 0 0 0 0 0 0 0
10
+ 0.4770 9.7630 26.7990 N 0 0 0 0 0 0 0 0 0 0 0 0
11
+ -0.6830 11.1870 23.5600 N 0 0 0 0 0 0 0 0 0 0 0 0
12
+ -1.9660 11.6240 23.5390 C 0 0 0 0 0 0 0 0 0 0 0 0
13
+ -2.5810 11.4250 24.7070 N 0 0 0 0 0 0 0 0 0 0 0 0
14
+ 1.9520 9.8170 24.8700 N 0 0 0 0 0 0 0 0 0 0 0 0
15
+ 3.1230 9.3980 25.6290 C 0 0 0 0 0 0 0 0 0 0 0 0
16
+ 2.1100 9.7530 23.4320 C 0 0 0 0 0 0 0 0 0 0 0 0
17
+ 7.8600 10.1360 22.6040 C 0 0 0 0 0 0 0 0 0 0 0 0
18
+ 6.5530 9.6800 22.8080 C 0 0 0 0 0 0 0 0 0 0 0 0
19
+ 5.8720 10.7150 23.6130 O 0 0 0 0 0 0 0 0 0 0 0 0
20
+ 6.8390 11.6780 23.7840 C 0 0 0 0 0 0 0 0 0 0 0 0
21
+ 8.0580 11.3690 23.2280 C 0 0 0 0 0 0 0 0 0 0 0 0
22
+ 6.6560 12.9400 24.5720 C 0 0 0 0 0 0 0 0 0 0 0 0
23
+ 7.6630 13.4980 25.2340 N 0 0 0 0 0 0 0 0 0 0 0 0
24
+ 7.1190 14.6210 25.8930 N 0 0 0 0 0 0 0 0 0 0 0 0
25
+ 5.8050 14.8140 25.6500 C 0 0 0 0 0 0 0 0 0 0 0 0
26
+ 5.4220 13.6990 24.7720 C 0 0 0 0 0 0 0 0 0 0 0 0
27
+ 4.9170 15.9400 26.1920 C 0 0 0 0 0 0 0 0 0 0 0 0
28
+ 1 2 4 0 0 0 0
29
+ 1 6 4 0 0 0 0
30
+ 1 10 1 0 0 0 0
31
+ 2 3 4 0 0 0 0
32
+ 2 7 4 0 0 0 0
33
+ 3 4 4 0 0 0 0
34
+ 3 9 4 0 0 0 0
35
+ 4 5 4 0 0 0 0
36
+ 5 6 4 0 0 0 0
37
+ 7 8 4 0 0 0 0
38
+ 8 9 4 0 0 0 0
39
+ 10 11 1 0 0 0 0
40
+ 10 12 1 0 0 0 0
41
+ 13 14 4 0 0 0 0
42
+ 13 17 4 0 0 0 0
43
+ 14 15 4 0 0 0 0
44
+ 15 16 4 0 0 0 0
45
+ 16 17 4 0 0 0 0
46
+ 16 18 1 0 0 0 0
47
+ 18 19 4 0 0 0 0
48
+ 18 22 4 0 0 0 0
49
+ 19 20 4 0 0 0 0
50
+ 20 21 4 0 0 0 0
51
+ 21 22 4 0 0 0 0
52
+ 21 23 1 0 0 0 0
53
+ M END
54
+ $$$$
examples/3hz1_protein.pdb ADDED
The diff for this file is too large to render. See raw diff
 
examples/5ou2_fragments.sdf ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 5ou2_fragments
2
+ PyMOL2.5 3D 0
3
+
4
+ 24 26 0 0 0 0 0 0 0 0999 V2000
5
+ 135.6651 -15.3583 0.1325 N 0 0 0 0 0 0 0 0 0 0 0 0
6
+ 134.8356 -14.4706 -0.4078 C 0 0 0 0 0 0 0 0 0 0 0 0
7
+ 134.5969 -13.5549 0.5236 N 0 0 0 0 0 0 0 0 0 0 0 0
8
+ 135.2672 -13.8787 1.6104 C 0 0 0 0 0 0 0 0 0 0 0 0
9
+ 135.9361 -15.0095 1.3626 C 0 0 0 0 0 0 0 0 0 0 0 0
10
+ 135.2407 -13.1072 2.8878 C 0 0 0 0 0 0 0 0 0 0 0 0
11
+ 135.5339 -13.7328 4.0539 C 0 0 0 0 0 0 0 0 0 0 0 0
12
+ 135.5239 -13.0695 5.2284 C 0 0 0 0 0 0 0 0 0 0 0 0
13
+ 135.1995 -11.7489 5.2810 C 0 0 0 0 0 0 0 0 0 0 0 0
14
+ 134.9023 -11.1173 4.1089 C 0 0 0 0 0 0 0 0 0 0 0 0
15
+ 134.9113 -11.7774 2.9035 C 0 0 0 0 0 0 0 0 0 0 0 0
16
+ 135.1362 -10.8138 6.9517 Br 0 0 0 0 0 0 0 0 0 0 0 0
17
+ 126.8521 -19.0355 0.2522 N 0 0 0 0 0 0 0 0 0 0 0 0
18
+ 126.0921 -18.0299 -0.2360 C 0 0 0 0 0 0 0 0 0 0 0 0
19
+ 126.8721 -17.2548 -1.0322 N 0 0 0 0 0 0 0 0 0 0 0 0
20
+ 128.1098 -17.7707 -1.0325 C 0 0 0 0 0 0 0 0 0 0 0 0
21
+ 128.0889 -18.8815 -0.2256 C 0 0 0 0 0 0 0 0 0 0 0 0
22
+ 129.3145 -17.2106 -1.7791 C 0 0 0 0 0 0 0 0 0 0 0 0
23
+ 130.5850 -17.7185 -1.5264 C 0 0 0 0 0 0 0 0 0 0 0 0
24
+ 131.6879 -17.2095 -2.1865 C 0 0 0 0 0 0 0 0 0 0 0 0
25
+ 131.5211 -16.1844 -3.1052 C 0 0 0 0 0 0 0 0 0 0 0 0
26
+ 130.2586 -15.6644 -3.3699 C 0 0 0 0 0 0 0 0 0 0 0 0
27
+ 129.1548 -16.1795 -2.7058 C 0 0 0 0 0 0 0 0 0 0 0 0
28
+ 133.0656 -15.5029 -4.0086 Br 0 0 0 0 0 0 0 0 0 0 0 0
29
+ 1 2 4 0 0 0 0
30
+ 2 3 4 0 0 0 0
31
+ 3 4 4 0 0 0 0
32
+ 4 6 1 0 0 0 0
33
+ 1 5 4 0 0 0 0
34
+ 4 5 4 0 0 0 0
35
+ 6 7 4 0 0 0 0
36
+ 6 11 4 0 0 0 0
37
+ 7 8 4 0 0 0 0
38
+ 8 9 4 0 0 0 0
39
+ 9 10 4 0 0 0 0
40
+ 9 12 1 0 0 0 0
41
+ 10 11 4 0 0 0 0
42
+ 13 14 4 0 0 0 0
43
+ 14 15 4 0 0 0 0
44
+ 15 16 4 0 0 0 0
45
+ 16 18 1 0 0 0 0
46
+ 13 17 4 0 0 0 0
47
+ 16 17 4 0 0 0 0
48
+ 18 19 4 0 0 0 0
49
+ 18 23 4 0 0 0 0
50
+ 19 20 4 0 0 0 0
51
+ 20 21 4 0 0 0 0
52
+ 21 22 4 0 0 0 0
53
+ 21 24 1 0 0 0 0
54
+ 22 23 4 0 0 0 0
55
+ M END
56
+ $$$$
examples/5ou2_protein.pdb ADDED
The diff for this file is too large to render. See raw diff
 
output.py CHANGED
@@ -1,4 +1,4 @@
1
- INITIAL_RENDERING_TEMPLATE = """<!DOCTYPE html>
2
  <html>
3
  <head>
4
  <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
@@ -26,7 +26,6 @@ INITIAL_RENDERING_TEMPLATE = """<!DOCTYPE html>
26
  let defaultStyle = {{ stick: {{ colorscheme: "greenCarbon" }} }};
27
  viewer.addModel(`{molecule}`, "{fmt}");
28
  viewer.getModel(0).setStyle(defaultStyle);
29
- // document.cookie = document.cookie + "|selected_atoms:";
30
 
31
  viewer.getModel(0).setClickable(
32
  {{}},
@@ -38,20 +37,16 @@ INITIAL_RENDERING_TEMPLATE = """<!DOCTYPE html>
38
  {{"serial": _atom.serial, "model": 0}},
39
  {{"sphere": {{"color": "magenta", "radius": 0.4}} }}
40
  );
41
- // document.cookie = document.cookie + "atom_" + String(_atom.serial) + "-";
42
  window.parent.postMessage({{
43
  name: "atom_selection",
44
  data: {{"atom": _atom.serial, "add": true}}
45
- // data: JSON.stringify({{"add": _atom.serial}})
46
  }}, "*");
47
  }} else {{
48
  delete _atom.isClicked;
49
  _viewer.setStyle({{"serial": _atom.serial, "model": 0}}, defaultStyle);
50
- // document.cookie = document.cookie.replace("atom_" + String(_atom.serial) + "-", "");
51
  window.parent.postMessage({{
52
  name: "atom_selection",
53
  data: {{"atom": _atom.serial, "add": false}}
54
- // data: JSON.stringify({{"remove": _atom.serial}})
55
  }}, "*");
56
  }}
57
  _viewer.render();
@@ -67,6 +62,112 @@ INITIAL_RENDERING_TEMPLATE = """<!DOCTYPE html>
67
  </html>
68
  """
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  SAMPLES_RENDERING_TEMPLATE = """<!DOCTYPE html>
72
  <html>
@@ -88,6 +189,7 @@ SAMPLES_RENDERING_TEMPLATE = """<!DOCTYPE html>
88
 
89
  <body>
90
  <div id="container" class="mol-container"></div>
 
91
  <button id="fragments">Input Fragments</button>
92
  <button id="molecule">Output Molecule</button>
93
  <script>
@@ -120,6 +222,74 @@ SAMPLES_RENDERING_TEMPLATE = """<!DOCTYPE html>
120
  </html>
121
  """
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  INVALID_FORMAT_MSG = """
125
  <!DOCTYPE html>
@@ -135,13 +305,18 @@ INVALID_FORMAT_MSG = """
135
 
136
  <body>
137
  <h3>Invalid file format: {extension}</h3>
138
- Please upload the file in one of the following formats:
139
  <ul>
140
  <li>.pdb</li>
141
  <li>.sdf</li>
142
  <li>.mol</li>
143
  <li>.mol2</li>
144
  </ul>
 
 
 
 
 
145
  </body>
146
  </html>
147
  """
@@ -190,7 +365,7 @@ STARTUP_JS = """
190
  """
191
 
192
  RETURN_SELECTION_JS = """
193
- (input_file, n_steps, n_atoms, samples, hidden) => {
194
  let selected = []
195
  for (const [atom, add] of Object.entries(window.selected_elements)) {
196
  if (add) {
@@ -203,6 +378,6 @@ RETURN_SELECTION_JS = """
203
  }
204
  }
205
  console.log("Finished parsing");
206
- return [input_file, n_steps, n_atoms, samples, selected.join(",")];
207
  }
208
  """
 
1
+ FRAGMENTS_RENDERING_TEMPLATE = """<!DOCTYPE html>
2
  <html>
3
  <head>
4
  <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
 
26
  let defaultStyle = {{ stick: {{ colorscheme: "greenCarbon" }} }};
27
  viewer.addModel(`{molecule}`, "{fmt}");
28
  viewer.getModel(0).setStyle(defaultStyle);
 
29
 
30
  viewer.getModel(0).setClickable(
31
  {{}},
 
37
  {{"serial": _atom.serial, "model": 0}},
38
  {{"sphere": {{"color": "magenta", "radius": 0.4}} }}
39
  );
 
40
  window.parent.postMessage({{
41
  name: "atom_selection",
42
  data: {{"atom": _atom.serial, "add": true}}
 
43
  }}, "*");
44
  }} else {{
45
  delete _atom.isClicked;
46
  _viewer.setStyle({{"serial": _atom.serial, "model": 0}}, defaultStyle);
 
47
  window.parent.postMessage({{
48
  name: "atom_selection",
49
  data: {{"atom": _atom.serial, "add": false}}
 
50
  }}, "*");
51
  }}
52
  _viewer.render();
 
62
  </html>
63
  """
64
 
65
+ TARGET_RENDERING_TEMPLATE = """<!DOCTYPE html>
66
+ <html>
67
+ <head>
68
+ <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
69
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
70
+ <script src="https://3Dmol.org/build/3Dmol.js"></script>
71
+ <style>
72
+ .mol-container {{
73
+ width: 600px;
74
+ height: 600px;
75
+ position: relative;
76
+ }}
77
+ .mol-container select{{
78
+ background-image:None;
79
+ }}
80
+ </style>
81
+ </head>
82
+
83
+ <body>
84
+ <div id="container" class="mol-container"></div>
85
+ <script>
86
+ $(document).ready(function() {{
87
+ let element = $("#container");
88
+ let config = {{ backgroundColor: "white" }};
89
+ let viewer = $3Dmol.createViewer(element, config);
90
+ let proteinStyle = {{ cartoon: {{ colorscheme: "ssPyMOL" }} }};
91
+ viewer.addModel(`{molecule}`, "{fmt}");
92
+ viewer.getModel(0).setStyle(proteinStyle);
93
+
94
+ viewer.zoomTo();
95
+ viewer.zoom(0.7);
96
+ viewer.render();
97
+ }});
98
+ </script>
99
+ </body>
100
+ </html>
101
+ """
102
+
103
+ FRAGMENTS_AND_TARGET_RENDERING_TEMPLATE = """<!DOCTYPE html>
104
+ <html>
105
+ <head>
106
+ <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
107
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
108
+ <script src="https://3Dmol.org/build/3Dmol.js"></script>
109
+ <style>
110
+ .mol-container {{
111
+ width: 600px;
112
+ height: 600px;
113
+ position: relative;
114
+ }}
115
+ .mol-container select{{
116
+ background-image:None;
117
+ }}
118
+ </style>
119
+ </head>
120
+
121
+ <body>
122
+ <div id="container" class="mol-container"></div>
123
+ <script>
124
+ $(document).ready(function() {{
125
+ let element = $("#container");
126
+ let config = {{ backgroundColor: "white" }};
127
+ let viewer = $3Dmol.createViewer(element, config);
128
+ let defaultStyle = {{ stick: {{ colorscheme: "greenCarbon" }} }};
129
+ let proteinStyle = {{ cartoon: {{ colorscheme: "ssPyMOL" }} }};
130
+
131
+ viewer.addModel(`{molecule}`, "{fmt}");
132
+ viewer.getModel(0).setStyle(defaultStyle);
133
+ viewer.getModel(0).setClickable(
134
+ {{}},
135
+ true,
136
+ function (_atom, _viewer, _event, _container) {{
137
+ if (!_atom.isClicked) {{
138
+ _atom.isClicked = true;
139
+ _viewer.addStyle(
140
+ {{"serial": _atom.serial, "model": 0}},
141
+ {{"sphere": {{"color": "magenta", "radius": 0.4}} }}
142
+ );
143
+ window.parent.postMessage({{
144
+ name: "atom_selection",
145
+ data: {{"atom": _atom.serial, "add": true}}
146
+ }}, "*");
147
+ }} else {{
148
+ delete _atom.isClicked;
149
+ _viewer.setStyle({{"serial": _atom.serial, "model": 0}}, defaultStyle);
150
+ window.parent.postMessage({{
151
+ name: "atom_selection",
152
+ data: {{"atom": _atom.serial, "add": false}}
153
+ }}, "*");
154
+ }}
155
+ _viewer.render();
156
+ }}
157
+ );
158
+
159
+ viewer.addModel(`{target}`, "{target_fmt}");
160
+ viewer.getModel(1).setStyle(proteinStyle);
161
+
162
+ viewer.zoomTo();
163
+ viewer.zoom(0.7);
164
+ viewer.render();
165
+ }});
166
+ </script>
167
+ </body>
168
+ </html>
169
+ """
170
+
171
 
172
  SAMPLES_RENDERING_TEMPLATE = """<!DOCTYPE html>
173
  <html>
 
189
 
190
  <body>
191
  <div id="container" class="mol-container"></div>
192
+ <br>
193
  <button id="fragments">Input Fragments</button>
194
  <button id="molecule">Output Molecule</button>
195
  <script>
 
222
  </html>
223
  """
224
 
225
+ SAMPLES_WITH_TARGET_RENDERING_TEMPLATE = """<!DOCTYPE html>
226
+ <html>
227
+ <head>
228
+ <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
229
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
230
+ <script src="https://3Dmol.org/build/3Dmol.js"></script>
231
+ <style>
232
+ .mol-container {{
233
+ width: 600px;
234
+ height: 600px;
235
+ position: relative;
236
+ }}
237
+ .mol-container select{{
238
+ background-image:None;
239
+ }}
240
+ </style>
241
+ </head>
242
+
243
+ <body>
244
+ <div id="container" class="mol-container"></div>
245
+ <br>
246
+ <button id="fragments">Input Fragments</button>
247
+ <button id="molecule">Output Molecule</button>
248
+ <button id="show-target">Show Target</button>
249
+ <button id="hide-target">Hide Target</button>
250
+ <script>
251
+ let element = $("#container");
252
+ let config = {{ backgroundColor: "white" }};
253
+ let viewer = $3Dmol.createViewer( element, config );
254
+
255
+ $(document).ready(function() {{
256
+ viewer.addModel(`{fragments}`, "{fragments_fmt}")
257
+ viewer.getModel(0).setStyle({{ stick: {{ colorscheme:"greenCarbon" }} }})
258
+ viewer.getModel(0).hide();
259
+
260
+ viewer.addModel(`{molecule}`, "{molecule_fmt}")
261
+ viewer.getModel(1).setStyle({{ stick: {{ colorscheme:"greenCarbon" }} }})
262
+
263
+ viewer.addModel(`{target}`, "{target_fmt}")
264
+ viewer.getModel(2).setStyle({{ cartoon: {{ colorscheme: "ssPyMOL" }} }})
265
+
266
+ viewer.zoomTo();
267
+ viewer.zoom(0.7);
268
+ viewer.render();
269
+ }});
270
+ $("#fragments").click(function() {{
271
+ viewer.getModel(0).show();
272
+ viewer.getModel(1).hide();
273
+ viewer.render();
274
+ }});
275
+ $("#molecule").click(function() {{
276
+ viewer.getModel(1).show();
277
+ viewer.getModel(0).hide();
278
+ viewer.render();
279
+ }});
280
+ $("#show-target").click(function() {{
281
+ viewer.getModel(2).show();
282
+ viewer.render();
283
+ }});
284
+ $("#hide-target").click(function() {{
285
+ viewer.getModel(2).hide();
286
+ viewer.render();
287
+ }});
288
+ </script>
289
+ </body>
290
+ </html>
291
+ """
292
+
293
 
294
  INVALID_FORMAT_MSG = """
295
  <!DOCTYPE html>
 
305
 
306
  <body>
307
  <h3>Invalid file format: {extension}</h3>
308
+ Allowed formats for the fragments file:
309
  <ul>
310
  <li>.pdb</li>
311
  <li>.sdf</li>
312
  <li>.mol</li>
313
  <li>.mol2</li>
314
  </ul>
315
+
316
+ Allowed formats for the optional protein file:
317
+ <ul>
318
+ <li>.pdb</li>
319
+ </ul>
320
  </body>
321
  </html>
322
  """
 
365
  """
366
 
367
  RETURN_SELECTION_JS = """
368
+ (input_file, input_protein_file, n_steps, n_atoms, samples, hidden) => {
369
  let selected = []
370
  for (const [atom, add] of Object.entries(window.selected_elements)) {
371
  if (add) {
 
378
  }
379
  }
380
  console.log("Finished parsing");
381
+ return [input_file, input_protein_file, n_steps, n_atoms, samples, selected.join(",")];
382
  }
383
  """
src/datasets.py CHANGED
@@ -101,15 +101,25 @@ class ZincDataset(Dataset):
101
 
102
 
103
  class MOADDataset(Dataset):
104
- def __init__(self, data_path, prefix, device):
105
- prefix, pocket_mode = prefix.split('.')
 
 
 
 
 
 
 
 
 
 
106
 
107
  dataset_path = os.path.join(data_path, f'{prefix}_{pocket_mode}.pt')
108
  if os.path.exists(dataset_path):
109
  self.data = torch.load(dataset_path, map_location=device)
110
  else:
111
  print(f'Preprocessing dataset with prefix {prefix}')
112
- self.data = MOADDataset.preprocess(data_path, prefix, pocket_mode, device)
113
  torch.save(self.data, dataset_path)
114
 
115
  def __len__(self):
@@ -264,7 +274,7 @@ def collate_with_fragment_edges(batch):
264
  out = {}
265
 
266
  # Filter out big molecules
267
- batch = [data for data in batch if data['num_atoms'] <= 50]
268
 
269
  for i, data in enumerate(batch):
270
  for key, value in data.items():
 
101
 
102
 
103
  class MOADDataset(Dataset):
104
def __init__(self, data=None, data_path=None, prefix=None, device=None):
    """Create the dataset from in-memory data or from a cached/preprocessed file.

    Two construction modes are supported:

    * ``data`` is given: it is stored as ``self.data`` and nothing is read
      from disk.
    * ``data`` is ``None``: ``data_path``, ``prefix`` and ``device`` are all
      required. ``prefix`` encodes the pocket mode either as
      ``'<name>.<pocket_mode>'`` or as ``'<name>_<pocket_mode>'``. The file
      ``<data_path>/<name>_<pocket_mode>.pt`` is loaded when it exists;
      otherwise the dataset is preprocessed and saved to that path.

    Raises:
        ValueError: if ``data`` is ``None`` and any of ``data_path``,
            ``prefix`` or ``device`` is missing.
    """
    if data is not None:
        self.data = data
        return

    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O`` and would let a confusing failure surface further down.
    if data_path is None or prefix is None or device is None:
        raise ValueError(
            'MOADDataset requires either `data` or all of '
            '`data_path`, `prefix` and `device`'
        )

    if '.' in prefix:
        # Split on the LAST dot so that names containing dots still work.
        prefix, pocket_mode = prefix.rsplit('.', 1)
    else:
        # The pocket mode is the part after the last underscore.
        prefix, _, pocket_mode = prefix.rpartition('_')

    dataset_path = os.path.join(data_path, f'{prefix}_{pocket_mode}.pt')
    if os.path.exists(dataset_path):
        self.data = torch.load(dataset_path, map_location=device)
    else:
        print(f'Preprocessing dataset with prefix {prefix}')
        self.data = self.preprocess(data_path, prefix, pocket_mode, device)
        torch.save(self.data, dataset_path)
124
 
125
  def __len__(self):
 
274
  out = {}
275
 
276
  # Filter out big molecules
277
+ # batch = [data for data in batch if data['num_atoms'] <= 50]
278
 
279
  for i, data in enumerate(batch):
280
  for key, value in data.items():
src/generation.py CHANGED
@@ -1,24 +1,44 @@
 
1
  import os.path
2
  import subprocess
3
  import torch
4
 
 
 
5
  from src.visualizer import save_xyz_file
 
 
6
 
7
  N_SAMPLES = 5
8
 
9
 
10
- def generate_linkers(ddpm, data, sample_fn, name):
11
- chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
 
 
 
 
 
 
 
12
  print('Generated linker')
13
  x = chain[0][:, :, :ddpm.n_dims]
14
  h = chain[0][:, :, ddpm.n_dims:]
15
 
16
  # Put the molecule back to the initial orientation
17
- pos_masked = data['positions'] * data['fragment_mask']
18
- N = data['fragment_mask'].sum(1, keepdims=True)
 
 
 
 
 
19
  mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
20
  x = x + mean * node_mask
21
 
 
 
 
22
  names = [f'output_{i + 1}_{name}' for i in range(N_SAMPLES)]
23
  save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
24
  print('Saved XYZ files')
@@ -36,3 +56,62 @@ def try_to_convert_to_sdf(name):
36
  out_files.append(out_xyz)
37
 
38
  return out_files
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
  import os.path
3
  import subprocess
4
  import torch
5
 
6
+ from Bio.PDB import PDBParser
7
+ from src import const
8
  from src.visualizer import save_xyz_file
9
+ from src.utils import FoundNaNException
10
+ from src.datasets import get_one_hot
11
 
12
  N_SAMPLES = 5
13
 
14
 
15
def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False, max_attempts=5):
    """Sample linkers with ``ddpm`` and write them to ``results`` as XYZ files.

    Sampling is retried when ``ddpm.sample_chain`` raises
    ``FoundNaNException``. The sampled coordinates are then shifted back by
    the centre of mass that was used for centring (fragments or anchors,
    depending on ``ddpm.center_of_mass``), and, in pocket mode, pocket atoms
    are masked out before saving.

    Args:
        ddpm: model exposing ``sample_chain``, ``n_dims`` and ``center_of_mass``.
        data: batch dictionary; this function reads ``'positions'``,
            ``'anchors'``, ``'fragment_mask'`` and, when ``with_pocket`` is
            set, ``'fragment_only_mask'`` and ``'pocket_mask'``.
        sample_fn: passed through to ``ddpm.sample_chain``.
        name: suffix used in the output file names ``output_<i>_<name>``.
        with_pocket: whether ``data`` contains pocket atoms.
        max_attempts: how many times sampling is attempted before giving up.

    Raises:
        RuntimeError: if every sampling attempt produced NaNs.
    """
    last_error = None
    for _ in range(max_attempts):
        try:
            chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
            break
        except FoundNaNException as e:
            last_error = e
    else:
        # Previously this fell through with ``chain = None`` and crashed with
        # an unrelated "NoneType is not subscriptable" error below.
        raise RuntimeError(
            f'Sampling produced NaNs in all {max_attempts} attempts'
        ) from last_error

    print('Generated linker')
    x = chain[0][:, :, :ddpm.n_dims]
    h = chain[0][:, :, ddpm.n_dims:]

    # Put the molecule back to the initial orientation
    if ddpm.center_of_mass == 'fragments':
        # With a pocket, 'fragment_mask' is not used for centring;
        # the fragment-only mask is.
        com_mask = data['fragment_only_mask'] if with_pocket else data['fragment_mask']
    else:
        com_mask = data['anchors']

    pos_masked = data['positions'] * com_mask
    N = com_mask.sum(1, keepdim=True)
    mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
    x = x + mean * node_mask

    if with_pocket:
        # Exclude pocket atoms from the saved molecules.
        node_mask[torch.where(data['pocket_mask'])] = 0

    names = [f'output_{i + 1}_{name}' for i in range(N_SAMPLES)]
    save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
    print('Saved XYZ files')
 
56
  out_files.append(out_xyz)
57
 
58
  return out_files
59
+
60
+
61
def get_pocket(mol, pdb_path, cutoff=6.0):
    """Extract the protein pocket atoms located around a molecule.

    A residue is part of the pocket when at least one of its atoms lies
    within ``cutoff`` of any atom of ``mol``. Every atom of such a residue
    whose element is present in ``const.GEOM_ATOM2IDX`` is returned.

    Residues are tracked by their position in the structure's residue
    iteration rather than by sequence number, so residues that share a
    number across chains or models are not confused with each other.

    Args:
        mol: molecule exposing ``GetConformer().GetPositions()``
            (presumably an RDKit ``Mol`` with a 3D conformer -- confirm
            against the caller).
        pdb_path: path to the PDB file, parsed with ``Bio.PDB.PDBParser``.
        cutoff: contact distance threshold, in the coordinate units of the
            inputs. Defaults to 6.

    Returns:
        Tuple ``(pocket_pos, pocket_one_hot, pocket_charges)`` of numpy
        arrays with one entry per retained pocket atom, in structure order.
    """
    struct = PDBParser().get_structure('', pdb_path)
    residues = list(struct.get_residues())

    # Flatten all protein atoms, remembering which residue each belongs to.
    atom_coords = []
    atom_residue_idx = []
    for idx, residue in enumerate(residues):
        for atom in residue.get_atoms():
            atom_coords.append(atom.get_coord())
            atom_residue_idx.append(idx)

    atom_coords = np.array(atom_coords)
    atom_residue_idx = np.array(atom_residue_idx)
    mol_atom_coords = mol.GetConformer().GetPositions()

    # Pairwise protein-atom / molecule-atom distances.
    distances = np.linalg.norm(atom_coords[:, None, :] - mol_atom_coords[None, :, :], axis=-1)
    # np.unique returns sorted indices, which keeps the original residue order.
    contact_residues = np.unique(atom_residue_idx[distances.min(1) <= cutoff])

    pocket_pos = []
    pocket_one_hot = []
    pocket_charges = []
    for idx in contact_residues:
        for atom in residues[idx].get_atoms():
            atom_type = atom.element.upper()
            # Skip elements the model has no encoding for.
            if atom_type not in const.GEOM_ATOM2IDX:
                continue

            pocket_pos.append(atom.get_coord().tolist())
            pocket_one_hot.append(get_one_hot(atom_type, const.GEOM_ATOM2IDX))
            pocket_charges.append(const.GEOM_CHARGES[atom_type])

    pocket_pos = np.array(pocket_pos)
    pocket_one_hot = np.array(pocket_one_hot)
    pocket_charges = np.array(pocket_charges)

    return pocket_pos, pocket_one_hot, pocket_charges
src/lightning.py CHANGED
@@ -21,7 +21,6 @@ from pdb import set_trace
21
 
22
 
23
  def get_activation(activation):
24
- print(activation)
25
  if activation == 'silu':
26
  return torch.nn.SiLU()
27
  else:
@@ -158,7 +157,7 @@ class DDPM(pl.LightningModule):
158
  context = fragment_mask
159
 
160
  # Add information about pocket to the context
161
- if '.' in self.train_data_prefix:
162
  fragment_pocket_mask = fragment_mask
163
  fragment_only_mask = data['fragment_only_mask']
164
  pocket_only_mask = fragment_pocket_mask - fragment_only_mask
@@ -170,6 +169,8 @@ class DDPM(pl.LightningModule):
170
  # Removing COM of fragment from the atom coordinates
171
  if self.inpainting:
172
  center_of_mass_mask = node_mask
 
 
173
  elif self.center_of_mass == 'fragments':
174
  center_of_mass_mask = fragment_mask
175
  elif self.center_of_mass == 'anchors':
@@ -423,9 +424,9 @@ class DDPM(pl.LightningModule):
423
  context = fragment_mask
424
 
425
  # Add information about pocket to the context
426
- if '.' in self.train_data_prefix:
427
  fragment_pocket_mask = fragment_mask
428
- fragment_only_mask = data['fragment_only_mask']
429
  pocket_only_mask = fragment_pocket_mask - fragment_only_mask
430
  if self.anchors_context:
431
  context = torch.cat([anchors, fragment_only_mask, pocket_only_mask], dim=-1)
@@ -435,6 +436,8 @@ class DDPM(pl.LightningModule):
435
  # Removing COM of fragment from the atom coordinates
436
  if self.inpainting:
437
  center_of_mass_mask = node_mask
 
 
438
  elif self.center_of_mass == 'fragments':
439
  center_of_mass_mask = fragment_mask
440
  elif self.center_of_mass == 'anchors':
 
21
 
22
 
23
  def get_activation(activation):
 
24
  if activation == 'silu':
25
  return torch.nn.SiLU()
26
  else:
 
157
  context = fragment_mask
158
 
159
  # Add information about pocket to the context
160
+ if isinstance(self.train_dataset, MOADDataset):
161
  fragment_pocket_mask = fragment_mask
162
  fragment_only_mask = data['fragment_only_mask']
163
  pocket_only_mask = fragment_pocket_mask - fragment_only_mask
 
169
  # Removing COM of fragment from the atom coordinates
170
  if self.inpainting:
171
  center_of_mass_mask = node_mask
172
+ elif isinstance(self.train_dataset, MOADDataset) and self.center_of_mass == 'fragments':
173
+ center_of_mass_mask = data['fragment_only_mask']
174
  elif self.center_of_mass == 'fragments':
175
  center_of_mass_mask = fragment_mask
176
  elif self.center_of_mass == 'anchors':
 
424
  context = fragment_mask
425
 
426
  # Add information about pocket to the context
427
+ if isinstance(self.val_dataset, MOADDataset):
428
  fragment_pocket_mask = fragment_mask
429
+ fragment_only_mask = template_data['fragment_only_mask']
430
  pocket_only_mask = fragment_pocket_mask - fragment_only_mask
431
  if self.anchors_context:
432
  context = torch.cat([anchors, fragment_only_mask, pocket_only_mask], dim=-1)
 
436
  # Removing COM of fragment from the atom coordinates
437
  if self.inpainting:
438
  center_of_mass_mask = node_mask
439
+ elif isinstance(self.val_dataset, MOADDataset) and self.center_of_mass == 'fragments':
440
+ center_of_mass_mask = template_data['fragment_only_mask']
441
  elif self.center_of_mass == 'fragments':
442
  center_of_mass_mask = fragment_mask
443
  elif self.center_of_mass == 'anchors':