Re-enable TQDM (reverts #11)

#17
by cbensimon HF staff - opened
Files changed (1) hide show
  1. pipeline.py +44 -43
pipeline.py CHANGED
@@ -398,50 +398,51 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
398
 
399
  # 11. Denoising loop
400
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
401
-
402
- for i, t in enumerate(timesteps):
403
- latent_model_input = (
404
- torch.cat([latents] * 2) if do_classifier_free_guidance else latents
405
- )
406
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
407
-
408
- if i <= start_merge_step:
409
- current_prompt_embeds = torch.cat(
410
- [negative_prompt_embeds, prompt_embeds_text_only], dim=0
411
- )
412
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0)
413
- else:
414
- current_prompt_embeds = torch.cat(
415
- [negative_prompt_embeds, prompt_embeds], dim=0
416
  )
417
- add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
418
- # predict the noise residual
419
- added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
420
- noise_pred = self.unet(
421
- latent_model_input,
422
- t,
423
- encoder_hidden_states=current_prompt_embeds,
424
- cross_attention_kwargs=cross_attention_kwargs,
425
- added_cond_kwargs=added_cond_kwargs,
426
- return_dict=False,
427
- )[0]
428
-
429
- # perform guidance
430
- if do_classifier_free_guidance:
431
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
432
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
433
-
434
- if do_classifier_free_guidance and guidance_rescale > 0.0:
435
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
436
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
437
-
438
- # compute the previous noisy sample x_t -> x_t-1
439
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
440
-
441
- # call the callback, if provided
442
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
443
- if callback is not None and i % callback_steps == 0:
444
- callback(i, t, latents)
 
 
 
 
 
 
 
 
 
 
 
 
445
 
446
  # make sure the VAE is in float32 mode, as it overflows in float16
447
  if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
 
398
 
399
  # 11. Denoising loop
400
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
401
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
402
+ for i, t in enumerate(timesteps):
403
+ latent_model_input = (
404
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
 
 
 
 
 
 
 
 
 
 
 
405
  )
406
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
407
+
408
+ if i <= start_merge_step:
409
+ current_prompt_embeds = torch.cat(
410
+ [negative_prompt_embeds, prompt_embeds_text_only], dim=0
411
+ )
412
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0)
413
+ else:
414
+ current_prompt_embeds = torch.cat(
415
+ [negative_prompt_embeds, prompt_embeds], dim=0
416
+ )
417
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
418
+ # predict the noise residual
419
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
420
+ noise_pred = self.unet(
421
+ latent_model_input,
422
+ t,
423
+ encoder_hidden_states=current_prompt_embeds,
424
+ cross_attention_kwargs=cross_attention_kwargs,
425
+ added_cond_kwargs=added_cond_kwargs,
426
+ return_dict=False,
427
+ )[0]
428
+
429
+ # perform guidance
430
+ if do_classifier_free_guidance:
431
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
432
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
433
+
434
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
435
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
436
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
437
+
438
+ # compute the previous noisy sample x_t -> x_t-1
439
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
440
+
441
+ # call the callback, if provided
442
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
443
+ progress_bar.update()
444
+ if callback is not None and i % callback_steps == 0:
445
+ callback(i, t, latents)
446
 
447
  # make sure the VAE is in float32 mode, as it overflows in float16
448
  if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: