igashov commited on
Commit
7a7c7ad
1 Parent(s): 6f4a6fd
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -15,6 +15,7 @@ from src.lightning import DDPM
15
  from src.linker_size_lightning import SizeClassifier
16
 
17
  N_SAMPLES = 5
 
18
 
19
  parser = argparse.ArgumentParser()
20
  parser.add_argument('--ip', type=str, default=None)
@@ -38,6 +39,7 @@ if not os.path.exists(diffusion_path):
38
  link = 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1'
39
  subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
40
  ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
 
41
  print('Loaded diffusion model')
42
 
43
 
@@ -172,10 +174,7 @@ def generate(input_file):
172
  return [
173
  output.IFRAME_TEMPLATE.format(html=html),
174
  [inp_sdf] + out_files,
175
- gr.Radio.update(
176
- choices=['Sample 1', 'Sample 2', 'Sample 3', 'Sample 4', 'Sample 5'],
177
- value='Sample 1',
178
- )
179
  ]
180
 
181
 
@@ -203,7 +202,14 @@ with demo:
203
  gr.Markdown('## Visualization')
204
  gr.Markdown('Below you will see input and output molecules')
205
  visualization = gr.HTML()
206
- samples = gr.Radio(interactive=True, type='index', label='Samples')
 
 
 
 
 
 
 
207
 
208
  input_file.change(
209
  fn=show_input,
 
15
  from src.linker_size_lightning import SizeClassifier
16
 
17
  N_SAMPLES = 5
18
+ N_STEPS = 10
19
 
20
  parser = argparse.ArgumentParser()
21
  parser.add_argument('--ip', type=str, default=None)
 
39
  link = 'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1'
40
  subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
41
  ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
42
+ ddpm.edm.T = N_STEPS
43
  print('Loaded diffusion model')
44
 
45
 
 
174
  return [
175
  output.IFRAME_TEMPLATE.format(html=html),
176
  [inp_sdf] + out_files,
177
+ gr.Radio.update(visible=True)
 
 
 
178
  ]
179
 
180
 
 
202
  gr.Markdown('## Visualization')
203
  gr.Markdown('Below you will see input and output molecules')
204
  visualization = gr.HTML()
205
+ samples = gr.Radio(
206
+ choices=['Sample 1', 'Sample 2', 'Sample 3', 'Sample 4', 'Sample 5'],
207
+ value='Sample 1',
208
+ type='index',
209
+ show_label=False,
210
+ visible=False,
211
+ interactive=True,
212
+ )
213
 
214
  input_file.change(
215
  fn=show_input,