fffiloni commited on
Commit
9ba0acb
1 Parent(s): e2461d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -182,7 +182,7 @@ def infer(ref_style_file, style_description, caption, progress):
182
  lam_style=1, lam_txt_alignment=1.0,
183
  use_ddim_sampler=True,
184
  )
185
- for (sampled_c, _, _) in progress.tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
186
  #for i, (sampled_c, _, _) in enumerate(sampling_c, 1):
187
  # if i % 5 == 0: # Update progress every 5 steps
188
  # progress(0.4 + 0.3 * (i / extras.sampling_configs['timesteps']), f"Stage C reverse process: step {i}/{extras.sampling_configs['timesteps']}")
@@ -199,7 +199,7 @@ def infer(ref_style_file, style_description, caption, progress):
199
  unconditions_b, device=device, **extras_b.sampling_configs,
200
  )
201
  for i, (sampled_b, _, _) in enumerate(sampling_b, 1):
202
- if i % 5 == 0: # Update progress every 5 steps
203
  progress(0.7 + 0.2 * (i / extras_b.sampling_configs['timesteps']), f"Stage B reverse process: step {i}/{extras_b.sampling_configs['timesteps']}")
204
  sampled_b = sampled_b
205
  sampled = models_b.stage_a.decode(sampled_b).float()
 
182
  lam_style=1, lam_txt_alignment=1.0,
183
  use_ddim_sampler=True,
184
  )
185
+ for (sampled_c, _, _) in progress.tqdm(tqdm(sampling_c, total=extras.sampling_configs['timesteps'])):
186
  #for i, (sampled_c, _, _) in enumerate(sampling_c, 1):
187
  # if i % 5 == 0: # Update progress every 5 steps
188
  # progress(0.4 + 0.3 * (i / extras.sampling_configs['timesteps']), f"Stage C reverse process: step {i}/{extras.sampling_configs['timesteps']}")
 
199
  unconditions_b, device=device, **extras_b.sampling_configs,
200
  )
201
  for i, (sampled_b, _, _) in enumerate(sampling_b, 1):
202
+ if i % 1 == 0: # Update progress every 1 step
203
  progress(0.7 + 0.2 * (i / extras_b.sampling_configs['timesteps']), f"Stage B reverse process: step {i}/{extras_b.sampling_configs['timesteps']}")
204
  sampled_b = sampled_b
205
  sampled = models_b.stage_a.decode(sampled_b).float()