igashov commited on
Commit
ff512d8
1 Parent(s): 2f96d45

Change max batch_size

Browse files
Files changed (2) hide show
  1. app.py +17 -7
  2. src/generation.py +3 -2
app.py CHANGED
@@ -17,6 +17,11 @@ from src.generation import generate_linkers, try_to_convert_to_sdf, get_pocket
17
  from zipfile import ZipFile
18
 
19
 
 
 
 
 
 
20
  MODELS_METADATA = {
21
  'geom_difflinker': {
22
  'link': 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
@@ -329,9 +334,7 @@ def generate_without_pocket(input_file, n_steps, n_atoms, num_samples, selected_
329
 
330
  for data in dataloader:
331
  try:
332
- generate_linkers(
333
- ddpm=ddpm, data=data, num_samples=num_samples, sample_fn=sample_fn, name=name, with_pocket=False
334
- )
335
  except Exception as e:
336
  e = str(e).replace('\'', '')
337
  error = f'Caught exception while generating linkers: {e}'
@@ -450,7 +453,8 @@ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples
450
  dataset = MOADDataset(data=dataset)
451
  ddpm.val_dataset = dataset
452
 
453
- dataloader = get_dataloader(dataset, batch_size=num_samples, collate_fn=collate_with_fragment_edges)
 
454
  print('Created dataloader')
455
 
456
  ddpm.edm.T = n_steps
@@ -470,10 +474,13 @@ def generate_with_pocket(in_fragments, in_protein, n_steps, n_atoms, num_samples
470
  def sample_fn(_data):
471
  return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms
472
 
473
- for data in dataloader:
474
  try:
 
475
  generate_linkers(
476
- ddpm=ddpm, data=data, num_samples=num_samples, sample_fn=sample_fn, name=name, with_pocket=True
 
 
477
  )
478
  except Exception as e:
479
  e = str(e).replace('\'', '')
@@ -520,7 +527,10 @@ with demo:
520
  gr.Markdown('Upload the file of the target protein in .pdb format (optionally):')
521
  input_protein_file = gr.File(file_count='single', label='Target Protein (Optional)')
522
 
523
- n_steps = gr.Slider(minimum=50, maximum=500, label="Number of Denoising Steps", step=10)
 
 
 
524
  n_atoms = gr.Slider(
525
  minimum=0, maximum=20,
526
  label="Linker Size: DiffLinker will predict it if set to 0",
 
17
  from zipfile import ZipFile
18
 
19
 
20
+ MIN_N_STEPS = 100
21
+ MAX_N_STEPS = 500
22
+ MAX_BATCH_SIZE = 5
23
+
24
+
25
  MODELS_METADATA = {
26
  'geom_difflinker': {
27
  'link': 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
 
334
 
335
  for data in dataloader:
336
  try:
337
+ generate_linkers(ddpm=ddpm, data=data, sample_fn=sample_fn, name=name, with_pocket=False)
 
 
338
  except Exception as e:
339
  e = str(e).replace('\'', '')
340
  error = f'Caught exception while generating linkers: {e}'
 
453
  dataset = MOADDataset(data=dataset)
454
  ddpm.val_dataset = dataset
455
 
456
+ batch_size = min(num_samples, MAX_BATCH_SIZE)
457
+ dataloader = get_dataloader(dataset, batch_size=batch_size, collate_fn=collate_with_fragment_edges)
458
  print('Created dataloader')
459
 
460
  ddpm.edm.T = n_steps
 
474
  def sample_fn(_data):
475
  return torch.ones(_data['positions'].shape[0], device=device, dtype=torch.long) * n_atoms
476
 
477
+ for batch_i, data in enumerate(dataloader):
478
  try:
479
+ offset_idx = batch_i * batch_size
480
  generate_linkers(
481
+ ddpm=ddpm, data=data,
482
+ sample_fn=sample_fn, name=name, with_pocket=True,
483
+ offset_idx=offset_idx,
484
  )
485
  except Exception as e:
486
  e = str(e).replace('\'', '')
 
527
  gr.Markdown('Upload the file of the target protein in .pdb format (optionally):')
528
  input_protein_file = gr.File(file_count='single', label='Target Protein (Optional)')
529
 
530
+ n_steps = gr.Slider(
531
+ minimum=MIN_N_STEPS, maximum=MAX_N_STEPS,
532
+ label="Number of Denoising Steps", step=10
533
+ )
534
  n_atoms = gr.Slider(
535
  minimum=0, maximum=20,
536
  label="Linker Size: DiffLinker will predict it if set to 0",
src/generation.py CHANGED
@@ -10,7 +10,7 @@ from src.utils import FoundNaNException
10
  from src.datasets import get_one_hot
11
 
12
 
13
- def generate_linkers(ddpm, data, num_samples, sample_fn, name, with_pocket=False):
14
  chain = node_mask = None
15
  for i in range(5):
16
  try:
@@ -37,7 +37,8 @@ def generate_linkers(ddpm, data, num_samples, sample_fn, name, with_pocket=False
37
  if with_pocket:
38
  node_mask[torch.where(data['pocket_mask'])] = 0
39
 
40
- names = [f'output_{i + 1}_{name}' for i in range(num_samples)]
 
41
  save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
42
  print('Saved XYZ files')
43
 
 
10
  from src.datasets import get_one_hot
11
 
12
 
13
+ def generate_linkers(ddpm, data, sample_fn, name, with_pocket=False, offset_idx=0):
14
  chain = node_mask = None
15
  for i in range(5):
16
  try:
 
37
  if with_pocket:
38
  node_mask[torch.where(data['pocket_mask'])] = 0
39
 
40
+ batch_size = len(data)
41
+ names = [f'output_{offset_idx + i + 1}_{name}' for i in range(batch_size)]
42
  save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
43
  print('Saved XYZ files')
44