xxie committed
Commit 531dfb5
1 Parent(s): a6d435a

add ddim support

Files changed (4):
  1. app.py +15 -8
  2. configs/structured.py +1 -0
  3. demo.py +6 -1
  4. model/model_hoattn.py +20 -6
app.py CHANGED
@@ -127,7 +127,7 @@ def plot_points(colors, coords):
     return fig
 
 
-def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, std_coverage, input_seed, input_cls):
+def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, std_coverage, input_seed, input_cls, input_scheduler):
     """
     given user input, run inference
     :param runner:
@@ -138,6 +138,7 @@ def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, s
     :param std_coverage: float value, used to estimate camera translation
     :param input_seed: random seed
     :param input_cls: the object category of the input image
+    :param input_scheduler: reverse sampling scheduler, ddim or ddpm
     :return: path to the 3D reconstruction, and an interactive 3D figure for visualizing the point cloud
     """
     log = ""
@@ -153,6 +154,8 @@ def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, s
     log += f"Reloading fine-tuned checkpoint of category {input_cls}\n"
     runner.reload_checkpoint(input_cls)
 
+    cfg.run.diffusion_scheduler = input_scheduler
+    cfg.run.num_inference_steps = 1000 if input_scheduler == 'ddpm' else 100
     out_stage1, out_stage2 = runner.forward_batch(batch, cfg)
     points = out_stage2.points_packed().cpu().numpy()
     colors = out_stage2.features_packed().cpu().numpy()
@@ -204,6 +207,10 @@ def main(cfg: ProjectConfig):
                                      'chair', 'skateboard', 'suitcase', 'table'],
                             value='general')
             input_seed = gr.Number(label='Random seed', value=42)
+            input_scheduler = gr.Dropdown(label='Diffusion scheduler',
+                                          info='Reverse diffusion scheduler: DDIM is 10x faster',
+                                          choices=['ddpm', 'ddim'],
+                                          value='ddim')
         # Output visualization
         with gr.Row():
             pc_plot = gr.Plot(label="Reconstructed point cloud")
@@ -217,20 +224,20 @@ def main(cfg: ProjectConfig):
         with gr.Row():
             button_recon = gr.Button("Start Reconstruction", interactive=True, variant='secondary')
         button_recon.click(fn=partial(inference, runner, cfg),
-                           inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed, input_cls],
+                           inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed, input_cls, input_scheduler],
                            outputs=[pc_plot, out_pc_download, out_log])
         gr.HTML("""<br/>""")
         # Example input
         example_dir = cfg.run.code_dir_abs+"/examples"
         rgb, ps, obj = 'k1_color.jpg', 'k1_person_mask.png', 'k1_obj_rend_mask.png'
         example_images = gr.Examples([
-            [f"{example_dir}/017450/{rgb}", f"{example_dir}/017450/{ps}", f"{example_dir}/017450/{obj}", 3.0, 42, 'skateboard'],
-            [f"{example_dir}/205904/{rgb}", f"{example_dir}/205904/{ps}", f"{example_dir}/205904/{obj}", 3.2, 42, 'suitcase'],
-            [f"{example_dir}/066241/{rgb}", f"{example_dir}/066241/{ps}", f"{example_dir}/066241/{obj}", 3.5, 42, 'backpack'],
-            [f"{example_dir}/053431/{rgb}", f"{example_dir}/053431/{ps}", f"{example_dir}/053431/{obj}", 3.8, 42, 'chair'],
-            [f"{example_dir}/158107/{rgb}", f"{example_dir}/158107/{ps}", f"{example_dir}/158107/{obj}", 3.8, 42, 'chair'],
+            [f"{example_dir}/017450/{rgb}", f"{example_dir}/017450/{ps}", f"{example_dir}/017450/{obj}", 3.0, 42, 'skateboard', 'ddim'],
+            [f"{example_dir}/205904/{rgb}", f"{example_dir}/205904/{ps}", f"{example_dir}/205904/{obj}", 3.2, 42, 'suitcase', 'ddim'],
+            [f"{example_dir}/066241/{rgb}", f"{example_dir}/066241/{ps}", f"{example_dir}/066241/{obj}", 3.5, 42, 'backpack', 'ddim'],
+            [f"{example_dir}/053431/{rgb}", f"{example_dir}/053431/{ps}", f"{example_dir}/053431/{obj}", 3.8, 42, 'chair', 'ddim'],
+            [f"{example_dir}/158107/{rgb}", f"{example_dir}/158107/{ps}", f"{example_dir}/158107/{obj}", 3.8, 42, 'chair', 'ddim'],
 
-        ], inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed, input_cls],)
+        ], inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed, input_cls, input_scheduler],)
 
         gr.Markdown(citation_str)
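Note on the step counts above: DDPM reverses every one of the 1000 training timesteps, while DDIM subsamples them, so 100 DDIM steps means roughly 10x fewer denoising network calls. A minimal sketch with the diffusers scheduler API (the train-step count is illustrative; the repo's actual beta schedule lives in configs/structured.py):

    from diffusers.schedulers import DDIMScheduler

    # Illustrative: 1000 training steps, 100 inference steps, as in the demo defaults.
    ddim = DDIMScheduler(num_train_timesteps=1000)
    ddim.set_timesteps(num_inference_steps=100)
    print(ddim.timesteps[:3], len(ddim.timesteps))  # tensor([990, 980, 970]) 100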
 
configs/structured.py CHANGED
@@ -127,6 +127,7 @@ class PointCloudDiffusionModelConfig(PointCloudProjectionModelConfig):
     beta_end: float = 8e-3  # 0.012
     beta_schedule: str = 'linear'  # 'custom'
     dm_pred_type: str = 'epsilon'  # diffusion model prediction type, sample (x0) or noise
+    ddim_eta: float = 1.0  # DDIM eta parameter: eta=0 is the scheduler's default and gives deterministic generation
 
     # Point cloud model arguments
     point_cloud_model: str = 'pvcnn'
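For reference on the new flag: in DDIM, eta scales the noise injected at each reverse step, interpolating between fully deterministic sampling (eta=0, the standard DDIM setting) and DDPM-like stochastic sampling (eta=1, the default chosen here). A minimal sketch of how eta enters a single diffusers DDIM step (the tensors are illustrative stand-ins):

    import torch
    from diffusers.schedulers import DDIMScheduler

    scheduler = DDIMScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(100)

    sample = torch.randn(1, 3, 1024)        # stand-in for a noisy point cloud
    noise_pred = torch.randn_like(sample)   # stand-in for the model's epsilon output
    t = scheduler.timesteps[0]

    # eta=0.0 -> deterministic DDIM update; eta=1.0 -> DDPM-like stochastic update
    sample = scheduler.step(noise_pred, t, sample, eta=1.0).prev_sample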
demo.py CHANGED
@@ -180,6 +180,7 @@ class DemoRunner:
             mask=torch.stack(batch['masks']).to('cuda'),
             scheduler=cfg.run.diffusion_scheduler,
             num_inference_steps=cfg.run.num_inference_steps,
+            eta=cfg.model.ddim_eta,
         )
         # segment and normalize human/object
         bs = len(out_stage1)
@@ -254,7 +255,11 @@ class DemoRunner:
             radius_hum=radius_hum.unsqueeze(-1),
             radius_obj=radius_obj.unsqueeze(-1),
             sample_from_interm=True,
-            noise_step=cfg.run.sample_noise_step)
+            noise_step=cfg.run.sample_noise_step,
+            scheduler=cfg.run.diffusion_scheduler,
+            num_inference_steps=cfg.run.num_inference_steps,
+            eta=cfg.model.ddim_eta,
+        )
         return out_stage1, out_stage2
 
     def upsample_predicted_pc(self, num_samples, pc_obj):
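A hedged sketch of what these forwarded kwargs imply downstream (build_scheduler and its signature are hypothetical, not the repo's actual helper): the string in cfg.run.diffusion_scheduler selects a scheduler class, and eta only has an effect on the DDIM path:

    from diffusers.schedulers import DDPMScheduler, DDIMScheduler

    def build_scheduler(name: str, num_train_timesteps: int = 1000):
        # Hypothetical helper: map the config string to a diffusers scheduler.
        if name == 'ddpm':
            return DDPMScheduler(num_train_timesteps=num_train_timesteps)
        if name == 'ddim':
            return DDIMScheduler(num_train_timesteps=num_train_timesteps)
        raise ValueError(f'unknown scheduler: {name}')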
model/model_hoattn.py CHANGED
@@ -11,6 +11,7 @@ import numpy as np
 
 from pytorch3d.structures import Pointclouds
 from pytorch3d.renderer import CamerasBase
+from diffusers.schedulers import DDPMScheduler, DDIMScheduler
 from .model_diff_data import ConditionalPCDiffusionBehave
 from .pvcnn.pvcnn_ho import PVCNN2HumObj
 import torch.nn.functional as F
@@ -375,17 +376,30 @@ class CrossAttenHODiffusionModel(ConditionalPCDiffusionBehave):
 
         return (output, all_outputs) if return_all_outputs else output
 
-    def get_reverse_timesteps(self, scheduler, interm_steps:int):
+    def get_reverse_timesteps(self, scheduler, interm_steps: int):
         """
-
+        get the timesteps to run reverse diffusion
         :param scheduler:
-        :param interm_steps: start from some intermediate steps
+        :param interm_steps: start from some intermediate step; the step number refers to the DDPM scheduler
+        and, for DDIM, is recomputed accordingly
         :return:
         """
-        if interm_steps > 0:
-            timesteps = torch.from_numpy(np.arange(0, interm_steps)[::-1].copy()).to(self.device)
+        if isinstance(scheduler, DDPMScheduler):
+            # DDPM: reverse the first interm_steps steps directly
+            if interm_steps > 0:
+                timesteps = torch.from_numpy(np.arange(0, interm_steps)[::-1].copy()).to(self.device)
+            else:
+                timesteps = scheduler.timesteps.to(self.device)
+        elif isinstance(scheduler, DDIMScheduler):
+            if interm_steps > 0:
+                # compute a step ratio, and find the intermediate steps for DDIM
+                step_ratio = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
+                timesteps = (np.arange(0, interm_steps, step_ratio)).round()[::-1].copy().astype(np.int64)
+                timesteps = torch.from_numpy(timesteps).to(self.device)
+            else:
+                timesteps = scheduler.timesteps.to(self.device)
         else:
-            timesteps = scheduler.timesteps.to(self.device)
+            raise NotImplementedError
         return timesteps
 
     def pack_norm_params(self, kwargs:dict, scale=True):
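To sanity-check the DDIM branch of get_reverse_timesteps: with 1000 training steps and 100 inference steps, step_ratio is 10, so resuming from intermediate DDPM step 500 yields 50 DDIM steps at t = 490, 480, ..., 0 rather than 500 DDPM steps. The same arithmetic as a standalone snippet (the numbers mirror the demo's defaults):

    import numpy as np

    num_train_timesteps = 1000  # DDPM training steps
    num_inference_steps = 100   # DDIM inference steps (demo default)
    interm_steps = 500          # resume reverse diffusion from this DDPM step

    step_ratio = num_train_timesteps // num_inference_steps              # -> 10
    timesteps = np.arange(0, interm_steps, step_ratio).round()[::-1].astype(np.int64)
    print(timesteps[:3], timesteps.size)                                 # -> [490 480 470] 50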