back to original

pipeline.py (+108 -112)
@@ -625,89 +625,89 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
 
         return timesteps, num_inference_steps - t_start
 
-    # def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
-    #     if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
-    #         raise ValueError(
-    #             f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
-    #         )
-
-    #     image = image.to(device=device, dtype=dtype)
-
-    #     batch_size = batch_size * num_images_per_prompt
-    #     if isinstance(generator, list) and len(generator) != batch_size:
-    #         raise ValueError(
-    #             f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-    #             f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-    #         )
-
-    #     if isinstance(generator, list):
-    #         init_latents = [
-    #             self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-    #         ]
-    #         init_latents = torch.cat(init_latents, dim=0)
-    #     else:
-    #         init_latents = self.vae.encode(image).latent_dist.sample(generator)
-
-    #     init_latents = self.vae.config.scaling_factor * init_latents
-
-    #     if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
-    #         raise ValueError(
-    #             f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
-    #         )
-    #     else:
-    #         init_latents = torch.cat([init_latents], dim=0)
-
-    #     shape = init_latents.shape
-    #     noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-
-    #     # get latents
-    #     init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
-    #     latents = init_latents
-
-    #     return latents
-
     def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
         if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
             raise ValueError(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
-
-        if isinstance(image, list):
-            image_tensors = []
-            for img in image:
-                img_tensor = prepare_image(img)
-                img_tensor = img_tensor.to(device=device, dtype=dtype)
-                image_tensors.append(img_tensor)
-            image = torch.stack(image_tensors, dim=0)
-        else:
-            image = prepare_image(image)
-            image = image.to(device=device, dtype=dtype)
-
+
+        image = image.to(device=device, dtype=dtype)
+
         batch_size = batch_size * num_images_per_prompt
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-
+
         if isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(image.shape[0])
+                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)
+
         init_latents = self.vae.config.scaling_factor * init_latents
-
+
+        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+            raise ValueError(
+                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+            )
+        else:
+            init_latents = torch.cat([init_latents], dim=0)
+
         shape = init_latents.shape
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-
+
         # get latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
         latents = init_latents
-
+
         return latents
 
+    # def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+    #     if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+    #         raise ValueError(
+    #             f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
+    #         )
+
+    #     if isinstance(image, list):
+    #         image_tensors = []
+    #         for img in image:
+    #             img_tensor = prepare_image(img)
+    #             img_tensor = img_tensor.to(device=device, dtype=dtype)
+    #             image_tensors.append(img_tensor)
+    #         image = torch.stack(image_tensors, dim=0)
+    #     else:
+    #         image = prepare_image(image)
+    #         image = image.to(device=device, dtype=dtype)
+
+    #     batch_size = batch_size * num_images_per_prompt
+    #     if isinstance(generator, list) and len(generator) != batch_size:
+    #         raise ValueError(
+    #             f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+    #             f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+    #         )
+
+    #     if isinstance(generator, list):
+    #         init_latents = [
+    #             self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(image.shape[0])
+    #         ]
+    #         init_latents = torch.cat(init_latents, dim=0)
+    #     else:
+    #         init_latents = self.vae.encode(image).latent_dist.sample(generator)
+    #     init_latents = self.vae.config.scaling_factor * init_latents
+
+    #     shape = init_latents.shape
+    #     noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+    #     # get latents
+    #     init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+    #     latents = init_latents
+
+    #     return latents
+
 
     def _default_height_width(self, height, width, image):
         if isinstance(image, list):
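For reference, the restored prepare_latents path above boils down to: VAE-encode a single preprocessed image, scale by vae.config.scaling_factor, then noise it to the starting timestep with the scheduler. A minimal standalone sketch of those steps follows; the randomly initialized AutoencoderKL/DDPMScheduler, the tensor sizes, and the chosen timestep are illustrative assumptions, not part of this commit.

# Sketch only: mirrors the single-image latent preparation restored above,
# using randomly initialized components instead of pretrained weights.
import torch
from diffusers import AutoencoderKL, DDPMScheduler

vae = AutoencoderKL()                        # assumption: a real pipeline loads pretrained weights
scheduler = DDPMScheduler()

image = torch.rand(1, 3, 64, 64) * 2 - 1     # one preprocessed image in [-1, 1]
with torch.no_grad():
    init_latents = vae.encode(image).latent_dist.sample()
init_latents = vae.config.scaling_factor * init_latents

latent_timestep = torch.tensor([500])        # in the pipeline this comes from strength via get_timesteps
noise = torch.randn(init_latents.shape)
latents = scheduler.add_noise(init_latents, noise, latent_timestep)
print(latents.shape)                         # the batched latent the denoising loop starts from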
@@ -940,27 +940,27 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
 
-
-        # latents = self.prepare_latents(
-        #     image,
-        #     latent_timestep,
-        #     batch_size,
-        #     num_images_per_prompt,
-        #     prompt_embeds.dtype,
-        #     device,
-        #     generator,
-        # )
-
-        latents = [self.prepare_latents(
-            img,
+        # 6. Prepare latent variables
+        latents = self.prepare_latents(
+            image,
             latent_timestep,
             batch_size,
             num_images_per_prompt,
             prompt_embeds.dtype,
             device,
             generator,
-        ) for img in images]
-        latents = torch.cat(latents)
+        )
+
+        # latents = [self.prepare_latents(
+        #     img,
+        #     latent_timestep,
+        #     batch_size,
+        #     num_images_per_prompt,
+        #     prompt_embeds.dtype,
+        #     device,
+        #     generator,
+        # ) for img in images]
+        # latents = torch.cat(latents)
 
 
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -980,24 +980,6 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
                 # compute the percentage of total steps we are at
                 current_sampling_percent = i / len(timesteps)
 
-                # if (
-                #     current_sampling_percent < controlnet_guidance_start
-                #     or current_sampling_percent > controlnet_guidance_end
-                # ):
-                #     # do not apply the controlnet
-                #     down_block_res_samples = None
-                #     mid_block_res_sample = None
-                # else:
-                #     # apply the controlnet
-                #     down_block_res_samples, mid_block_res_sample = self.controlnet(
-                #         latent_model_input,
-                #         t,
-                #         encoder_hidden_states=prompt_embeds,
-                #         controlnet_cond=controlnet_conditioning_image,
-                #         conditioning_scale=controlnet_conditioning_scale,
-                #         return_dict=False,
-                #     )
-
                 if (
                     current_sampling_percent < controlnet_guidance_start
                     or current_sampling_percent > controlnet_guidance_end
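The guidance window used above is plain arithmetic on the loop index: the ControlNet is applied only while i / len(timesteps) falls between controlnet_guidance_start and controlnet_guidance_end. A tiny worked example with assumed values (50 steps, a 0.1 to 0.8 window):

# Assumed example values; the formula matches the check in the loop above.
num_inference_steps = 50
controlnet_guidance_start, controlnet_guidance_end = 0.1, 0.8

controlled_steps = []
for i in range(num_inference_steps):
    current_sampling_percent = i / num_inference_steps
    if controlnet_guidance_start <= current_sampling_percent <= controlnet_guidance_end:
        controlled_steps.append(i)           # ControlNet residuals are passed to the UNet on these steps

print(controlled_steps[0], controlled_steps[-1])   # 5 40: steps outside the window skip the ControlNet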
@@ -1006,28 +988,42 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversi
                     down_block_res_samples = None
                     mid_block_res_sample = None
                 else:
-                    down_block_res_samples = []
-                    mid_block_res_samples = []
-                    for i in range(batch_size):
-                        # apply the controlnet
-                        down_block_res_sample, mid_block_res_sample = self.controlnet(
-                            latent_model_input[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
-                            t,
-                            encoder_hidden_states=prompt_embeds[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
-                            controlnet_cond=controlnet_conditioning_image[i],
-                            conditioning_scale=controlnet_conditioning_scale,
-                            return_dict=False,
-                        )
-
-                        down_block_res_samples.append(down_block_res_sample)
-                        mid_block_res_samples.append(mid_block_res_sample)
-
-                    down_block_res_samples = tuple(down_block_res_samples)
-                    mid_block_res_sample = torch.cat(mid_block_res_samples, dim=0)
-
-                    # down_block_res_samples = torch.cat(down_block_res_samples, dim=0)
-                    # mid_block_res_sample = torch.cat(mid_block_res_samples, dim=0)
+                    # apply the controlnet
+                    down_block_res_samples, mid_block_res_sample = self.controlnet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        controlnet_cond=controlnet_conditioning_image,
+                        conditioning_scale=controlnet_conditioning_scale,
+                        return_dict=False,
+                    )
 
+                # if (
+                #     current_sampling_percent < controlnet_guidance_start
+                #     or current_sampling_percent > controlnet_guidance_end
+                # ):
+                #     # do not apply the controlnet
+                #     down_block_res_samples = None
+                #     mid_block_res_sample = None
+                # else:
+                #     down_block_res_samples = []
+                #     mid_block_res_samples = []
+                #     for i in range(batch_size):
+                #         # apply the controlnet
+                #         down_block_res_sample, mid_block_res_sample = self.controlnet(
+                #             latent_model_input[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
+                #             t,
+                #             encoder_hidden_states=prompt_embeds[i * num_images_per_prompt:(i + 1) * num_images_per_prompt],
+                #             controlnet_cond=controlnet_conditioning_image[i],
+                #             conditioning_scale=controlnet_conditioning_scale,
+                #             return_dict=False,
+                #         )
+
+                #         down_block_res_samples.append(down_block_res_sample)
+                #         mid_block_res_samples.append(mid_block_res_sample)
+
+                #     down_block_res_samples = tuple(down_block_res_samples)
+                #     mid_block_res_sample = torch.cat(mid_block_res_samples, dim=0)
 
                 # predict the noise residual
                 noise_pred = self.unet(
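With the revert in place, prepare_latents and the ControlNet call both take a single batched input again, so a call to the pipeline presumably looks like the sketch below. Only the argument names (image, controlnet_conditioning_image, controlnet_conditioning_scale, controlnet_guidance_start, controlnet_guidance_end, strength) come from this file; the checkpoint IDs, the custom_pipeline identifier, and the file paths are placeholders, not part of this commit.

# Hypothetical usage sketch; model IDs and paths are placeholders.
import torch
from diffusers import ControlNetModel, DiffusionPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    custom_pipeline="stable_diffusion_controlnet_img2img",  # assumption: this file loaded as a community pipeline
    torch_dtype=torch.float16,
).to("cuda")

init_image = load_image("input.png")         # placeholder init image
control_image = load_image("canny.png")      # placeholder conditioning image

result = pipe(
    prompt="a photo of a cat",
    image=init_image,                         # a single image again, not a list of images
    controlnet_conditioning_image=control_image,
    strength=0.8,
    num_inference_steps=50,
    controlnet_conditioning_scale=1.0,
    controlnet_guidance_start=0.0,
    controlnet_guidance_end=1.0,
).images[0]
result.save("out.png")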