hysts HF staff commited on
Commit
af2a8f5
1 Parent(s): 26f3e39

Add an option to show denoising process

Browse files
Files changed (3) hide show
  1. app.py +15 -2
  2. model.py +50 -9
  3. style.css +4 -0
app.py CHANGED
@@ -22,7 +22,7 @@ def create_simple_demo(model: Model) -> gr.Blocks:
22
 
23
 
24
  def create_advanced_demo(model: Model) -> gr.Blocks:
25
- def update_scheduler_type(name: str) -> dict:
26
  visible = name != 'DDPM'
27
  if name == 'PNDM':
28
  minimum = 4
@@ -35,6 +35,9 @@ def create_advanced_demo(model: Model) -> gr.Blocks:
35
  maximum=maximum,
36
  value=20)
37
 
 
 
 
38
  with gr.Blocks() as demo:
39
  gr.Markdown(DESCRIPTION)
40
 
@@ -60,9 +63,14 @@ def create_advanced_demo(model: Model) -> gr.Blocks:
60
  step=1,
61
  value=1234,
62
  label='Seed')
 
 
63
  run_button = gr.Button('Run')
64
  with gr.Column():
65
  result = gr.Image(show_label=False, elem_id='result')
 
 
 
66
 
67
  model_name.change(fn=model.set_pipeline,
68
  inputs=[
@@ -70,7 +78,7 @@ def create_advanced_demo(model: Model) -> gr.Blocks:
70
  scheduler_type,
71
  ],
72
  outputs=None)
73
- scheduler_type.change(fn=update_scheduler_type,
74
  inputs=scheduler_type,
75
  outputs=num_steps,
76
  queue=False)
@@ -80,6 +88,9 @@ def create_advanced_demo(model: Model) -> gr.Blocks:
80
  scheduler_type,
81
  ],
82
  outputs=None)
 
 
 
83
  run_button.click(fn=model.run,
84
  inputs=[
85
  model_name,
@@ -87,10 +98,12 @@ def create_advanced_demo(model: Model) -> gr.Blocks:
87
  num_steps,
88
  randomize_seed,
89
  seed,
 
90
  ],
91
  outputs=[
92
  result,
93
  seed,
 
94
  ])
95
  return demo
96
 
 
22
 
23
 
24
  def create_advanced_demo(model: Model) -> gr.Blocks:
25
+ def update_num_steps(name: str) -> dict:
26
  visible = name != 'DDPM'
27
  if name == 'PNDM':
28
  minimum = 4
 
35
  maximum=maximum,
36
  value=20)
37
 
38
+ def show_denoising_changed(selected: bool) -> dict:
39
+ return gr.Video.update(visible=selected)
40
+
41
  with gr.Blocks() as demo:
42
  gr.Markdown(DESCRIPTION)
43
 
 
63
  step=1,
64
  value=1234,
65
  label='Seed')
66
+ show_denoising = gr.Checkbox(value=False,
67
+ label='Show Denoising')
68
  run_button = gr.Button('Run')
69
  with gr.Column():
70
  result = gr.Image(show_label=False, elem_id='result')
71
+ result_video = gr.Video(show_label=False,
72
+ visible=False,
73
+ elem_id='result-video')
74
 
75
  model_name.change(fn=model.set_pipeline,
76
  inputs=[
 
78
  scheduler_type,
79
  ],
80
  outputs=None)
81
+ scheduler_type.change(fn=update_num_steps,
82
  inputs=scheduler_type,
83
  outputs=num_steps,
84
  queue=False)
 
88
  scheduler_type,
89
  ],
90
  outputs=None)
91
+ show_denoising.change(fn=show_denoising_changed,
92
+ inputs=show_denoising,
93
+ outputs=result_video)
94
  run_button.click(fn=model.run,
95
  inputs=[
96
  model_name,
 
98
  num_steps,
99
  randomize_seed,
100
  seed,
101
+ show_denoising,
102
  ],
103
  outputs=[
104
  result,
105
  seed,
106
+ result_video,
107
  ])
108
  return demo
109
 
model.py CHANGED
@@ -4,10 +4,13 @@ import logging
4
  import os
5
  import random
6
  import sys
 
7
 
 
8
  import numpy as np
9
  import PIL.Image
10
  import torch
 
11
  from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
12
  DiffusionPipeline, PNDMPipeline, PNDMScheduler)
13
 
@@ -101,20 +104,58 @@ class Model:
101
  logger.info('--- done ---')
102
  return res
103
 
104
- def run(
105
- self,
106
- model_name: str,
107
- scheduler_type: str,
108
- num_steps: int,
109
- randomize_seed: bool,
110
- seed: int,
111
- ) -> tuple[PIL.Image.Image, int]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  self.set_pipeline(model_name, scheduler_type)
113
  if scheduler_type == 'PNDM':
114
  num_steps = max(4, min(num_steps, 100))
115
  if randomize_seed:
116
  seed = self.rng.randint(0, 100000)
117
- return self.generate(seed, num_steps)[0], seed
 
 
 
 
 
118
 
119
  @staticmethod
120
  def to_grid(images: list[PIL.Image.Image],
 
4
  import os
5
  import random
6
  import sys
7
+ import tempfile
8
 
9
+ import imageio
10
  import numpy as np
11
  import PIL.Image
12
  import torch
13
+ import tqdm.auto
14
  from diffusers import (DDIMPipeline, DDIMScheduler, DDPMPipeline,
15
  DiffusionPipeline, PNDMPipeline, PNDMScheduler)
16
 
 
104
  logger.info('--- done ---')
105
  return res
106
 
107
+ @staticmethod
108
+ def postprocess(sample: torch.Tensor) -> np.ndarray:
109
+ res = (sample / 2 + 0.5).clamp(0, 1)
110
+ res = (res * 255).to(torch.uint8)
111
+ res = res.cpu().permute(0, 2, 3, 1).numpy()
112
+ return res
113
+
114
+ @torch.inference_mode()
115
+ def generate_with_video(self, seed: int,
116
+ num_steps: int) -> tuple[PIL.Image.Image, str]:
117
+ logger.info('--- generate_with_video ---')
118
+ if self.scheduler_type == 'DDPM':
119
+ num_steps = 1000
120
+ fps = 100
121
+ else:
122
+ fps = 10
123
+ logger.info(f'{seed=}, {num_steps=}')
124
+
125
+ model = self.pipeline.unet.to(self.device)
126
+ scheduler = self.pipeline.scheduler
127
+ scheduler.set_timesteps(num_inference_steps=num_steps)
128
+ input_shape = (1, model.config.in_channels, model.config.sample_size,
129
+ model.config.sample_size)
130
+ torch.manual_seed(seed)
131
+
132
+ out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
133
+ writer = imageio.get_writer(out_file.name, fps=fps)
134
+ sample = torch.randn(input_shape).to(self.device)
135
+ for t in tqdm.auto.tqdm(scheduler.timesteps):
136
+ out = model(sample, t)['sample']
137
+ sample = scheduler.step(out, t, sample)['prev_sample']
138
+ res = self.postprocess(sample)[0]
139
+ writer.append_data(res)
140
+ writer.close()
141
+
142
+ logger.info('--- done ---')
143
+ return res, out_file.name
144
+
145
+ def run(self, model_name: str, scheduler_type: str, num_steps: int,
146
+ randomize_seed: bool, seed: int, visualize_denoising: bool
147
+ ) -> tuple[PIL.Image.Image, int, str | None]:
148
  self.set_pipeline(model_name, scheduler_type)
149
  if scheduler_type == 'PNDM':
150
  num_steps = max(4, min(num_steps, 100))
151
  if randomize_seed:
152
  seed = self.rng.randint(0, 100000)
153
+
154
+ if not visualize_denoising:
155
+ return self.generate(seed, num_steps)[0], seed, None
156
+ else:
157
+ res, filename = self.generate_with_video(seed, num_steps)
158
+ return res, seed, filename
159
 
160
  @staticmethod
161
  def to_grid(images: list[PIL.Image.Image],
style.css CHANGED
@@ -9,6 +9,10 @@ div#result {
9
  max-width: 400px;
10
  max-height: 400px;
11
  }
 
 
 
 
12
  img#visitor-badge {
13
  display: block;
14
  margin: auto;
 
9
  max-width: 400px;
10
  max-height: 400px;
11
  }
12
+ div#result-video {
13
+ max-width: 400px;
14
+ max-height: 400px;
15
+ }
16
  img#visitor-badge {
17
  display: block;
18
  margin: auto;