igashov committed on
Commit
abdd514
1 Parent(s): cd2152f

Variable number of samples

Browse files
Files changed (3) hide show
  1. app.py +49 -38
  2. output.py +2 -2
  3. src/generation.py +4 -6
app.py CHANGED
@@ -13,7 +13,7 @@ from src import const
13
  from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule, MOADDataset
14
  from src.lightning import DDPM
15
  from src.linker_size_lightning import SizeClassifier
16
- from src.generation import N_SAMPLES, generate_linkers, try_to_convert_to_sdf, get_pocket
17
  from zipfile import ZipFile
18
 
19
 
@@ -125,7 +125,7 @@ def show_input(in_fragments, in_protein):
125
  vis = show_target(in_protein)
126
  elif in_fragments is not None and in_protein is not None:
127
  vis = show_fragments_and_target(in_fragments, in_protein)
128
- return [vis, gr.Radio.update(visible=False), None]
129
 
130
 
131
  def show_fragments(in_fragments):
@@ -167,28 +167,25 @@ def clear_fragments_input(in_protein):
167
  vis = ''
168
  if in_protein is not None:
169
  vis = show_target(in_protein)
170
- return [None, vis, gr.Radio.update(visible=False), None]
171
 
172
 
173
  def clear_protein_input(in_fragments):
174
  vis = ''
175
  if in_fragments is not None:
176
  vis = show_fragments(in_fragments)
177
- return [None, vis, gr.Radio.update(visible=False), None]
178
 
179
 
180
  def click_on_example(example):
181
  fragment_fname, target_fname = example
182
  fragment_path = f'examples/{fragment_fname}' if fragment_fname != '' else None
183
  target_path = f'examples/{target_fname}' if target_fname != '' else None
184
- return [fragment_path, target_path, 50, 0] + show_input(fragment_path, target_path)
185
 
186
 
187
- def draw_sample(idx, out_files):
188
- with_protein = (len(out_files) == N_SAMPLES + 3)
189
-
190
- if isinstance(idx, str):
191
- idx = int(idx.strip().split(' ')[-1]) - 1
192
 
193
  in_file = out_files[1]
194
  in_sdf = in_file if isinstance(in_file, str) else in_file.name
@@ -204,8 +201,7 @@ def draw_sample(idx, out_files):
204
  input_target_content = read_molecule_content(in_pdb)
205
  target_fmt = in_pdb.split('.')[-1]
206
 
207
- out_file = out_files[idx + offset]
208
- out_sdf = out_file if isinstance(out_file, str) else out_file.name
209
  generated_molecule_content = read_molecule_content(out_sdf)
210
  molecule_fmt = out_sdf.split('.')[-1]
211
 
@@ -237,17 +233,17 @@ def compress(output_fnames, name):
237
  return archive_path
238
 
239
 
240
- def generate(in_fragments, in_protein, n_steps, n_atoms, radio_samples, selected_atoms):
241
  if in_fragments is None:
242
  return [None, None, None, None]
243
 
244
  if in_protein is None:
245
- return generate_without_pocket(in_fragments, n_steps, n_atoms, radio_samples, selected_atoms)
246
  else:
247
- return generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, radio_samples, selected_atoms)
248
 
249
 
250
- def generate_without_pocket(input_file, n_steps, n_atoms, radio_samples, selected_atoms):
251
  # Parsing selected atoms (javascript output)
252
  selected_atoms = selected_atoms.strip()
253
  if selected_atoms == '':
@@ -310,8 +306,8 @@ def generate_without_pocket(input_file, n_steps, n_atoms, radio_samples, selecte
310
  'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
311
  'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
312
  'num_atoms': len(positions),
313
- }] * N_SAMPLES
314
- dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
315
  print('Created dataloader')
316
 
317
  ddpm.edm.T = n_steps
@@ -333,26 +329,33 @@ def generate_without_pocket(input_file, n_steps, n_atoms, radio_samples, selecte
333
 
334
  for data in dataloader:
335
  try:
336
- generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name, with_pocket=False)
 
 
337
  except Exception as e:
338
  e = str(e).replace('\'', '')
339
  error = f'Caught exception while generating linkers: {e}'
340
  msg = output.ERROR_FORMAT_MSG.format(message=error)
341
  return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
342
 
343
- out_files = try_to_convert_to_sdf(name)
344
  out_files = [inp_sdf] + out_files
345
  out_files = [compress(out_files, name=name)] + out_files
 
346
 
347
  return [
348
- draw_sample(radio_samples, out_files),
349
  out_files,
350
- gr.Radio.update(visible=True),
 
 
 
 
351
  None
352
  ]
353
 
354
 
355
- def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, radio_samples, selected_atoms):
356
  # Parsing selected atoms (javascript output)
357
  selected_atoms = selected_atoms.strip()
358
  if selected_atoms == '':
@@ -443,11 +446,11 @@ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, radio_sampl
443
  'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
444
  'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
445
  'num_atoms': len(positions),
446
- }] * N_SAMPLES
447
  dataset = MOADDataset(data=dataset)
448
  ddpm.val_dataset = dataset
449
 
450
- dataloader = get_dataloader(dataset, batch_size=N_SAMPLES, collate_fn=collate_with_fragment_edges)
451
  print('Created dataloader')
452
 
453
  ddpm.edm.T = n_steps
@@ -469,21 +472,28 @@ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, radio_sampl
469
 
470
  for data in dataloader:
471
  try:
472
- generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name, with_pocket=True)
 
 
473
  except Exception as e:
474
  e = str(e).replace('\'', '')
475
  error = f'Caught exception while generating linkers: {e}'
476
  msg = output.ERROR_FORMAT_MSG.format(message=error)
477
  return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
478
 
479
- out_files = try_to_convert_to_sdf(name)
480
  out_files = [inp_sdf, inp_pdb] + out_files
481
  out_files = [compress(out_files, name=name)] + out_files
 
482
 
483
  return [
484
- draw_sample(radio_samples, out_files),
485
  out_files,
486
- gr.Radio.update(visible=True),
 
 
 
 
487
  None
488
  ]
489
 
@@ -516,6 +526,7 @@ with demo:
516
  label="Linker Size: DiffLinker will predict it if set to 0",
517
  step=1
518
  )
 
519
  examples = gr.Dataset(
520
  components=[gr.File(visible=False), gr.File(visible=False)],
521
  samples=[
@@ -524,7 +535,6 @@ with demo:
524
  ['examples/3hz1_fragments.sdf', 'examples/3hz1_protein.pdb'],
525
  ['examples/5ou2_fragments.sdf', 'examples/5ou2_protein.pdb'],
526
  ],
527
- # headers=['Fragments', 'Target Protein'],
528
  type='values',
529
  )
530
 
@@ -537,13 +547,14 @@ with demo:
537
  with gr.Column():
538
  gr.Markdown('## Visualization')
539
  gr.Markdown('**Hint:** click on atoms to select anchor points (optionally)')
540
- samples = gr.Radio(
541
- choices=['Sample 1', 'Sample 2', 'Sample 3', 'Sample 4', 'Sample 5'],
542
- value='Sample 1',
543
  type='value',
544
- show_label=False,
545
  visible=False,
546
  interactive=True,
 
547
  )
548
  visualization = gr.HTML()
549
 
@@ -570,17 +581,17 @@ with demo:
570
  examples.click(
571
  fn=click_on_example,
572
  inputs=[examples],
573
- outputs=[input_fragments_file, input_protein_file, n_steps, n_atoms, visualization, samples, hidden]
574
  )
575
  button.click(
576
  fn=generate,
577
- inputs=[input_fragments_file, input_protein_file, n_steps, n_atoms, samples, hidden],
578
  outputs=[visualization, output_files, samples, hidden],
579
  _js=output.RETURN_SELECTION_JS,
580
  )
581
- samples.change(
582
  fn=draw_sample,
583
- inputs=[samples, output_files],
584
  outputs=[visualization],
585
  )
586
  demo.load(_js=output.STARTUP_JS)
 
13
  from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule, MOADDataset
14
  from src.lightning import DDPM
15
  from src.linker_size_lightning import SizeClassifier
16
+ from src.generation import generate_linkers, try_to_convert_to_sdf, get_pocket
17
  from zipfile import ZipFile
18
 
19
 
 
125
  vis = show_target(in_protein)
126
  elif in_fragments is not None and in_protein is not None:
127
  vis = show_fragments_and_target(in_fragments, in_protein)
128
+ return [vis, gr.Dropdown.update(choices=[], value=None, visible=False), None]
129
 
130
 
131
  def show_fragments(in_fragments):
 
167
  vis = ''
168
  if in_protein is not None:
169
  vis = show_target(in_protein)
170
+ return [None, vis, gr.Dropdown.update(choices=[], value=None, visible=False), None]
171
 
172
 
173
  def clear_protein_input(in_fragments):
174
  vis = ''
175
  if in_fragments is not None:
176
  vis = show_fragments(in_fragments)
177
+ return [None, vis, gr.Dropdown.update(choices=[], value=None, visible=False), None]
178
 
179
 
180
  def click_on_example(example):
181
  fragment_fname, target_fname = example
182
  fragment_path = f'examples/{fragment_fname}' if fragment_fname != '' else None
183
  target_path = f'examples/{target_fname}' if target_fname != '' else None
184
+ return [fragment_path, target_path] + show_input(fragment_path, target_path)
185
 
186
 
187
+ def draw_sample(sample_path, out_files, num_samples):
188
+ with_protein = (len(out_files) == num_samples + 3)
 
 
 
189
 
190
  in_file = out_files[1]
191
  in_sdf = in_file if isinstance(in_file, str) else in_file.name
 
201
  input_target_content = read_molecule_content(in_pdb)
202
  target_fmt = in_pdb.split('.')[-1]
203
 
204
+ out_sdf = sample_path if isinstance(sample_path, str) else sample_path.name
 
205
  generated_molecule_content = read_molecule_content(out_sdf)
206
  molecule_fmt = out_sdf.split('.')[-1]
207
 
 
233
  return archive_path
234
 
235
 
236
+ def generate(in_fragments, in_protein, n_steps, n_atoms, num_samples, selected_atoms):
237
  if in_fragments is None:
238
  return [None, None, None, None]
239
 
240
  if in_protein is None:
241
+ return generate_without_pocket(in_fragments, n_steps, n_atoms, num_samples, selected_atoms)
242
  else:
243
+ return generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples, selected_atoms)
244
 
245
 
246
+ def generate_without_pocket(input_file, n_steps, n_atoms, num_samples, selected_atoms):
247
  # Parsing selected atoms (javascript output)
248
  selected_atoms = selected_atoms.strip()
249
  if selected_atoms == '':
 
306
  'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
307
  'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
308
  'num_atoms': len(positions),
309
+ }] * num_samples
310
+ dataloader = get_dataloader(dataset, batch_size=num_samples, collate_fn=collate_with_fragment_edges)
311
  print('Created dataloader')
312
 
313
  ddpm.edm.T = n_steps
 
329
 
330
  for data in dataloader:
331
  try:
332
+ generate_linkers(
333
+ ddpm=ddpm, data=data, num_samples=num_samples, sample_fn=sample_fn, name=name, with_pocket=False
334
+ )
335
  except Exception as e:
336
  e = str(e).replace('\'', '')
337
  error = f'Caught exception while generating linkers: {e}'
338
  msg = output.ERROR_FORMAT_MSG.format(message=error)
339
  return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
340
 
341
+ out_files = try_to_convert_to_sdf(name, num_samples)
342
  out_files = [inp_sdf] + out_files
343
  out_files = [compress(out_files, name=name)] + out_files
344
+ choice = out_files[2]
345
 
346
  return [
347
+ draw_sample(choice, out_files, num_samples),
348
  out_files,
349
+ gr.Dropdown.update(
350
+ choices=out_files[2:],
351
+ value=choice,
352
+ visible=True,
353
+ ),
354
  None
355
  ]
356
 
357
 
358
+ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples, selected_atoms):
359
  # Parsing selected atoms (javascript output)
360
  selected_atoms = selected_atoms.strip()
361
  if selected_atoms == '':
 
446
  'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
447
  'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
448
  'num_atoms': len(positions),
449
+ }] * num_samples
450
  dataset = MOADDataset(data=dataset)
451
  ddpm.val_dataset = dataset
452
 
453
+ dataloader = get_dataloader(dataset, batch_size=num_samples, collate_fn=collate_with_fragment_edges)
454
  print('Created dataloader')
455
 
456
  ddpm.edm.T = n_steps
 
472
 
473
  for data in dataloader:
474
  try:
475
+ generate_linkers(
476
+ ddpm=ddpm, data=data, num_samples=num_samples, sample_fn=sample_fn, name=name, with_pocket=True
477
+ )
478
  except Exception as e:
479
  e = str(e).replace('\'', '')
480
  error = f'Caught exception while generating linkers: {e}'
481
  msg = output.ERROR_FORMAT_MSG.format(message=error)
482
  return [output.IFRAME_TEMPLATE.format(html=msg), None, None, None]
483
 
484
+ out_files = try_to_convert_to_sdf(name, num_samples)
485
  out_files = [inp_sdf, inp_pdb] + out_files
486
  out_files = [compress(out_files, name=name)] + out_files
487
+ choice = out_files[3]
488
 
489
  return [
490
+ draw_sample(choice, out_files, num_samples),
491
  out_files,
492
+ gr.Dropdown.update(
493
+ choices=out_files[3:],
494
+ value=choice,
495
+ visible=True,
496
+ ),
497
  None
498
  ]
499
 
 
526
  label="Linker Size: DiffLinker will predict it if set to 0",
527
  step=1
528
  )
529
+ n_samples = gr.Slider(minimum=5, maximum=50, label="Number of Samples", step=5)
530
  examples = gr.Dataset(
531
  components=[gr.File(visible=False), gr.File(visible=False)],
532
  samples=[
 
535
  ['examples/3hz1_fragments.sdf', 'examples/3hz1_protein.pdb'],
536
  ['examples/5ou2_fragments.sdf', 'examples/5ou2_protein.pdb'],
537
  ],
 
538
  type='values',
539
  )
540
 
 
547
  with gr.Column():
548
  gr.Markdown('## Visualization')
549
  gr.Markdown('**Hint:** click on atoms to select anchor points (optionally)')
550
+ samples = gr.Dropdown(
551
+ choices=[],
552
+ value=None,
553
  type='value',
554
+ multiselect=False,
555
  visible=False,
556
  interactive=True,
557
+ label='Samples'
558
  )
559
  visualization = gr.HTML()
560
 
 
581
  examples.click(
582
  fn=click_on_example,
583
  inputs=[examples],
584
+ outputs=[input_fragments_file, input_protein_file, visualization, samples, hidden]
585
  )
586
  button.click(
587
  fn=generate,
588
+ inputs=[input_fragments_file, input_protein_file, n_steps, n_atoms, n_samples, hidden],
589
  outputs=[visualization, output_files, samples, hidden],
590
  _js=output.RETURN_SELECTION_JS,
591
  )
592
+ samples.select(
593
  fn=draw_sample,
594
+ inputs=[samples, output_files, n_samples],
595
  outputs=[visualization],
596
  )
597
  demo.load(_js=output.STARTUP_JS)
output.py CHANGED
@@ -365,7 +365,7 @@ STARTUP_JS = """
365
  """
366
 
367
  RETURN_SELECTION_JS = """
368
- (input_file, input_protein_file, n_steps, n_atoms, samples, hidden) => {
369
  let selected = []
370
  for (const [atom, add] of Object.entries(window.selected_elements)) {
371
  if (add) {
@@ -378,6 +378,6 @@ RETURN_SELECTION_JS = """
378
  }
379
  }
380
  console.log("Finished parsing");
381
- return [input_file, input_protein_file, n_steps, n_atoms, samples, selected.join(",")];
382
  }
383
  """
 
365
  """
366
 
367
  RETURN_SELECTION_JS = """
368
+ (input_file, input_protein_file, n_steps, n_atoms, n_samples, hidden) => {
369
  let selected = []
370
  for (const [atom, add] of Object.entries(window.selected_elements)) {
371
  if (add) {
 
378
  }
379
  }
380
  console.log("Finished parsing");
381
+ return [input_file, input_protein_file, n_steps, n_atoms, n_samples, selected.join(",")];
382
  }
383
  """
src/generation.py CHANGED
@@ -9,10 +9,8 @@ from src.visualizer import save_xyz_file
9
  from src.utils import FoundNaNException
10
  from src.datasets import get_one_hot
11
 
12
- N_SAMPLES = 5
13
 
14
-
15
- def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False):
16
  chain = node_mask = None
17
  for i in range(5):
18
  try:
@@ -39,14 +37,14 @@ def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False):
39
  if with_pocket:
40
  node_mask[torch.where(data['pocket_mask'])] = 0
41
 
42
- names = [f'output_{i + 1}_{name}' for i in range(N_SAMPLES)]
43
  save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
44
  print('Saved XYZ files')
45
 
46
 
47
- def try_to_convert_to_sdf(name):
48
  out_files = []
49
- for i in range(N_SAMPLES):
50
  out_xyz = f'results/output_{i + 1}_{name}_.xyz'
51
  out_sdf = f'results/output_{i + 1}_{name}_.sdf'
52
  subprocess.run(f'obabel {out_xyz} -O {out_sdf}', shell=True)
 
9
  from src.utils import FoundNaNException
10
  from src.datasets import get_one_hot
11
 
 
12
 
13
+ def generate_linkers(ddpm, data, num_samples, sample_fn, name, with_pocket=False):
 
14
  chain = node_mask = None
15
  for i in range(5):
16
  try:
 
37
  if with_pocket:
38
  node_mask[torch.where(data['pocket_mask'])] = 0
39
 
40
+ names = [f'output_{i + 1}_{name}' for i in range(num_samples)]
41
  save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
42
  print('Saved XYZ files')
43
 
44
 
45
+ def try_to_convert_to_sdf(name, num_samples):
46
  out_files = []
47
+ for i in range(num_samples):
48
  out_xyz = f'results/output_{i + 1}_{name}_.xyz'
49
  out_sdf = f'results/output_{i + 1}_{name}_.sdf'
50
  subprocess.run(f'obabel {out_xyz} -O {out_sdf}', shell=True)