ashawkey commited on
Commit
218cd4f
1 Parent(s): c5edef0

optimize mvdream pipeline

Browse files
Files changed (1) hide show
  1. mvdream/pipeline_mvdream.py +12 -5
mvdream/pipeline_mvdream.py CHANGED
@@ -499,6 +499,13 @@ class MVDreamPipeline(DiffusionPipeline):
499
  # Prepare extra step kwargs.
500
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
501
 
 
 
 
 
 
 
 
502
  # Denoising loop
503
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
504
  with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -511,17 +518,17 @@ class MVDreamPipeline(DiffusionPipeline):
511
  unet_inputs = {
512
  'x': latent_model_input,
513
  'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device),
514
- 'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames),
515
  'num_frames': actual_num_frames,
516
- 'camera': torch.cat([camera] * multiplier),
517
  }
518
 
519
  if image is not None:
520
- unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames)
521
- unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat
522
 
523
  # predict the noise residual
524
- noise_pred = self.unet.forward(**unet_inputs)
525
 
526
  # perform guidance
527
  if do_classifier_free_guidance:
 
499
  # Prepare extra step kwargs.
500
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
501
 
502
+ context = torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames)
503
+ torch.cat([camera] * multiplier)
504
+
505
+ if image is not None:
506
+ ip = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames)
507
+ ip_img = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat
508
+
509
  # Denoising loop
510
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
511
  with self.progress_bar(total=num_inference_steps) as progress_bar:
 
518
  unet_inputs = {
519
  'x': latent_model_input,
520
  'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device),
521
+ 'context': context,
522
  'num_frames': actual_num_frames,
523
+ 'camera': camera,
524
  }
525
 
526
  if image is not None:
527
+ unet_inputs['ip'] = ip
528
+ unet_inputs['ip_img'] = ip_img
529
 
530
  # predict the noise residual
531
+ noise_pred = self.unet(**unet_inputs)
532
 
533
  # perform guidance
534
  if do_classifier_free_guidance: