apolinario commited on
Commit
0abf9df
1 Parent(s): 5156e7a

Add sample and bring back the steps slider

Browse files
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -21,18 +21,15 @@ ddim_pipeline = DDIMPipeline(unet=model, scheduler=ddim_scheduler)
21
  pndm_scheduler = PNDMScheduler.from_config(model_id, subfolder="scheduler")
22
  pndm_pipeline = PNDMPipeline(unet=model, scheduler=pndm_scheduler)
23
  # run pipeline in inference (sample random noise and denoise)
24
- def predict(seed=42,scheduler="ddim"):
25
  torch.cuda.empty_cache()
26
  generator = torch.manual_seed(seed)
27
  if(scheduler == "ddim"):
28
- image = ddim_pipeline(generator=generator, num_inference_steps=100)
29
- image = image["sample"]
30
  elif(scheduler == "ddpm"):
31
- image = ddpm_pipeline(generator=generator)
32
- #["sample"] doesnt work here for some reason
33
  elif(scheduler == "pndm"):
34
- image = pndm_pipeline(generator=generator, num_inference_steps=11)
35
- #["sample"] doesnt work here for some reason
36
 
37
  image_processed = image.cpu().permute(0, 2, 3, 1)
38
  if scheduler == "pndm":
@@ -49,7 +46,7 @@ random_seed = random.randint(0, 2147483647)
49
  gr.Interface(
50
  predict,
51
  inputs=[
52
- #gr.inputs.Slider(1, 1000, label='Inference Steps', default=20, step=1),
53
  gr.inputs.Slider(0, 2147483647, label='Seed', default=random_seed),
54
  gr.inputs.Radio(["ddim", "ddpm", "pndm"], default="ddpm",label="Diffusion scheduler")
55
  ],
 
21
  pndm_scheduler = PNDMScheduler.from_config(model_id, subfolder="scheduler")
22
  pndm_pipeline = PNDMPipeline(unet=model, scheduler=pndm_scheduler)
23
  # run pipeline in inference (sample random noise and denoise)
24
+ def predict(steps=100, seed=42,scheduler="ddim"):
25
  torch.cuda.empty_cache()
26
  generator = torch.manual_seed(seed)
27
  if(scheduler == "ddim"):
28
+ image = ddim_pipeline(generator=generator, num_inference_steps=steps)["sample"]
 
29
  elif(scheduler == "ddpm"):
30
+ image = ddpm_pipeline(generator=generator)["sample"]
 
31
  elif(scheduler == "pndm"):
32
+ image = pndm_pipeline(generator=generator, num_inference_steps=steps)["sample"]
 
33
 
34
  image_processed = image.cpu().permute(0, 2, 3, 1)
35
  if scheduler == "pndm":
 
46
  gr.Interface(
47
  predict,
48
  inputs=[
49
+ gr.inputs.Slider(1, 100, label='Inference Steps', default=20, step=1),
50
  gr.inputs.Slider(0, 2147483647, label='Seed', default=random_seed),
51
  gr.inputs.Radio(["ddim", "ddpm", "pndm"], default="ddpm",label="Diffusion scheduler")
52
  ],