MrAlex committed on
Commit 8987924
1 Parent(s): 90804b7

back to original

Files changed (1):
  1. pipeline.py +108 -112
pipeline.py CHANGED
@@ -625,89 +625,89 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi

         return timesteps, num_inference_steps - t_start

-    # def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
-    #     if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
-    #         raise ValueError(
-    #             f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
-    #         )
-
-    #     image = image.to(device=device, dtype=dtype)
-
-    #     batch_size = batch_size * num_images_per_prompt
-    #     if isinstance(generator, list) and len(generator) != batch_size:
-    #         raise ValueError(
-    #             f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-    #             f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-    #         )
-
-    #     if isinstance(generator, list):
-    #         init_latents = [
-    #             self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-    #         ]
-    #         init_latents = torch.cat(init_latents, dim=0)
-    #     else:
-    #         init_latents = self.vae.encode(image).latent_dist.sample(generator)
-
-    #     init_latents = self.vae.config.scaling_factor * init_latents
-
-    #     if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
-    #         raise ValueError(
-    #             f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
-    #         )
-    #     else:
-    #         init_latents = torch.cat([init_latents], dim=0)
-
-    #     shape = init_latents.shape
-    #     noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-
-    #     # get latents
-    #     init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
-    #     latents = init_latents
-
-    #     return latents
-
     def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
         if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
             raise ValueError(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
-
-        if isinstance(image, list):
-            image_tensors = []
-            for img in image:
-                img_tensor = prepare_image(img)
-                img_tensor = img_tensor.to(device=device, dtype=dtype)
-                image_tensors.append(img_tensor)
-            image = torch.stack(image_tensors, dim=0)
-        else:
-            image = prepare_image(image)
-            image = image.to(device=device, dtype=dtype)
-
+
+        image = image.to(device=device, dtype=dtype)
+
         batch_size = batch_size * num_images_per_prompt
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-
+
         if isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(image.shape[0])
+                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
         init_latents = self.vae.config.scaling_factor * init_latents
-
+
+        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+            raise ValueError(
+                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+            )
+        else:
+            init_latents = torch.cat([init_latents], dim=0)
+
         shape = init_latents.shape
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-
+
         # get latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
         latents = init_latents
-
+
         return latents

+    # def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+    #     if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+    #         raise ValueError(
+    #             f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+    #         )
+
+    #     if isinstance(image, list):
+    #         image_tensors = []
+    #         for img in image:
+    #             img_tensor = prepare_image(img)
+    #             img_tensor = img_tensor.to(device=device, dtype=dtype)
+    #             image_tensors.append(img_tensor)
+    #         image = torch.stack(image_tensors, dim=0)
+    #     else:
+    #         image = prepare_image(image)
+    #         image = image.to(device=device, dtype=dtype)
+
+    #     batch_size = batch_size * num_images_per_prompt
+    #     if isinstance(generator, list) and len(generator) != batch_size:
+    #         raise ValueError(
+    #             f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+    #             f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+    #         )
+
+    #     if isinstance(generator, list):
+    #         init_latents = [
+    #             self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(image.shape[0])
+    #         ]
+    #         init_latents = torch.cat(init_latents, dim=0)
+    #     else:
+    #         init_latents = self.vae.encode(image).latent_dist.sample(generator)
+    #     init_latents = self.vae.config.scaling_factor * init_latents
+
+    #     shape = init_latents.shape
+    #     noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+    #     # get latents
+    #     init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+    #     latents = init_latents
+
+    #     return latents
+

     def _default_height_width(self, height, width, image):
         if isinstance(image, list):
@@ -940,27 +940,27 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

-        # 6. Prepare latent variables
-        # latents = self.prepare_latents(
-        #     image,
-        #     latent_timestep,
-        #     batch_size,
-        #     num_images_per_prompt,
-        #     prompt_embeds.dtype,
-        #     device,
-        #     generator,
-        # )
-
-        latents = [self.prepare_latents(
-            img,
+        # 6. Prepare latent variables
+        latents = self.prepare_latents(
+            image,
             latent_timestep,
             batch_size,
             num_images_per_prompt,
             prompt_embeds.dtype,
             device,
             generator,
-        ) for img in images]
-        latents = torch.cat(latents)
+        )
+
+        # latents = [self.prepare_latents(
+        #     img,
+        #     latent_timestep,
+        #     batch_size,
+        #     num_images_per_prompt,
+        #     prompt_embeds.dtype,
+        #     device,
+        #     generator,
+        # ) for img in images]
+        # latents = torch.cat(latents)


         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -980,24 +980,6 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
                 # compute the percentage of total steps we are at
                 current_sampling_percent = i / len(timesteps)

-                # if (
-                #     current_sampling_percent < controlnet_guidance_start
-                #     or current_sampling_percent > controlnet_guidance_end
-                # ):
-                #     # do not apply the controlnet
-                #     down_block_res_samples = None
-                #     mid_block_res_sample = None
-                # else:
-                #     # apply the controlnet
-                #     down_block_res_samples, mid_block_res_sample = self.controlnet(
-                #         latent_model_input,
-                #         t,
-                #         encoder_hidden_states=prompt_embeds,
-                #         controlnet_cond=controlnet_conditioning_image,
-                #         conditioning_scale=controlnet_conditioning_scale,
-                #         return_dict=False,
-                #     )
-
                 if (
                     current_sampling_percent < controlnet_guidance_start
                     or current_sampling_percent > controlnet_guidance_end
@@ -1006,28 +988,42 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
                     down_block_res_samples = None
                     mid_block_res_sample = None
                 else:
-                    down_block_res_samples = []
-                    mid_block_res_samples = []
-                    for i in range(batch_size):
-                        # apply the controlnet
-                        down_block_res_sample, mid_block_res_sample = self.controlnet(
-                            latent_model_input[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
-                            t,
-                            encoder_hidden_states=prompt_embeds[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
-                            controlnet_cond=controlnet_conditioning_image[i],
-                            conditioning_scale=controlnet_conditioning_scale,
-                            return_dict=False,
-                        )
-
-                        down_block_res_samples.append(down_block_res_sample)
-                        mid_block_res_samples.append(mid_block_res_sample)
-
-                    down_block_res_samples = tuple(down_block_res_samples)
-                    mid_block_res_sample = torch.cat(mid_block_res_samples, dim=0)
-
-                    # down_block_res_samples = torch.cat(down_block_res_samples, dim=0)
-                    # mid_block_res_sample = torch.cat(mid_block_res_samples, dim=0)
+                    # apply the controlnet
+                    down_block_res_samples, mid_block_res_sample = self.controlnet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        controlnet_cond=controlnet_conditioning_image,
+                        conditioning_scale=controlnet_conditioning_scale,
+                        return_dict=False,
+                    )

+                # if (
+                #     current_sampling_percent < controlnet_guidance_start
+                #     or current_sampling_percent > controlnet_guidance_end
+                # ):
+                #     # do not apply the controlnet
+                #     down_block_res_samples = None
+                #     mid_block_res_sample = None
+                # else:
+                #     down_block_res_samples = []
+                #     mid_block_res_samples = []
+                #     for i in range(batch_size):
+                #         # apply the controlnet
+                #         down_block_res_sample, mid_block_res_sample = self.controlnet(
+                #             latent_model_input[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
+                #             t,
+                #             encoder_hidden_states=prompt_embeds[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
+                #             controlnet_cond=controlnet_conditioning_image[i],
+                #             conditioning_scale=controlnet_conditioning_scale,
+                #             return_dict=False,
+                #         )
+
+                #         down_block_res_samples.append(down_block_res_sample)
+                #         mid_block_res_samples.append(mid_block_res_sample)
+
+                #     down_block_res_samples = tuple(down_block_res_samples)
+                #     mid_block_res_sample = torch.cat(mid_block_res_samples, dim=0)

                 # predict the noise residual
                 noise_pred = self.unet(
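
Note (not part of the diff): the `prepare_latents` restored above follows the usual img2img recipe of encoding the init image with the VAE, scaling by `vae.config.scaling_factor`, and noising the result to the timestep implied by `strength`. The sketch below illustrates that recipe in isolation, under assumptions: the SD 1.5 checkpoint ID is only an example, `torch.randn` stands in for the pipeline's `randn_tensor` helper, and the strength-to-timestep arithmetic mirrors the pipeline's `get_timesteps`.

# Standalone sketch of the restored img2img latent preparation (illustrative only).
import torch
from diffusers import AutoencoderKL, DDIMScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "runwayml/stable-diffusion-v1-5"  # example checkpoint, not mandated by the commit

vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
generator = torch.Generator(device=device).manual_seed(0)

# Stand-in for a preprocessed init image, scaled to [-1, 1] as the pipeline expects.
image = torch.rand(1, 3, 512, 512, device=device) * 2 - 1

# 1. Encode to latent space and apply the VAE scaling factor (as in prepare_latents).
init_latents = vae.encode(image).latent_dist.sample(generator)
init_latents = vae.config.scaling_factor * init_latents

# 2. Derive the starting timestep from `strength`, mirroring get_timesteps.
num_inference_steps, strength = 50, 0.8
scheduler.set_timesteps(num_inference_steps, device=device)
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
latent_timestep = scheduler.timesteps[t_start : t_start + 1]

# 3. Noise the encoded image to that timestep; these latents seed the denoising loop.
noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=init_latents.dtype)
latents = scheduler.add_noise(init_latents, noise, latent_timestep)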
 
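A second note on the denoising-loop change: the revert drops the per-prompt ControlNet loop and returns to a single batched `self.controlnet(...)` call, gated by the `controlnet_guidance_start` / `controlnet_guidance_end` window on `current_sampling_percent = i / len(timesteps)`. The snippet below only checks that window arithmetic with example values; the names echo the pipeline's parameters, but the numbers are arbitrary.

# Quick check of the guidance-window gating used in the restored loop (example values).
num_inference_steps = 50
controlnet_guidance_start, controlnet_guidance_end = 0.1, 0.8

applied = [
    i
    for i in range(num_inference_steps)
    if controlnet_guidance_start <= i / num_inference_steps <= controlnet_guidance_end
]
# The ControlNet is skipped whenever the percentage falls outside [start, end]:
print(applied[0], applied[-1], len(applied))  # -> 5 40 36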