File size: 9,661 Bytes
220fc34 d738856 220fc34 d738856 220fc34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
# Modular Diffusers blocks for SDXL differential diffusion (diff-diff)
from diffusers.modular_pipelines import (
PipelineBlock,
SequentialPipelineBlocks,
PipelineState,
InputParam,
OutputParam,
ComponentSpec,
AutoPipelineBlocks
)
from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
from diffusers.schedulers import EulerDiscreteScheduler
from diffusers.models import AutoencoderKL
from diffusers.configuration_utils import FrozenDict
from diffusers.modular_pipelines.stable_diffusion_xl.before_denoise import prepare_latents_img2img
from diffusers.modular_pipelines.stable_diffusion_xl.denoise import (
StableDiffusionXLDenoiseLoopWrapper,
StableDiffusionXLDenoiseLoopDenoiser,
StableDiffusionXLControlNetDenoiseLoopDenoiser,
StableDiffusionXLDenoiseLoopAfterDenoiser
)
from diffusers.modular_pipelines.stable_diffusion_xl.modular_pipeline_block_mappings import (
IMAGE2IMAGE_BLOCKS,
TEXT2IMAGE_BLOCKS
)
import torch
from typing import List, Tuple, Any, Optional
class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
    """Prepare latents and per-step masks for SDXL differential diffusion (diff-diff).

    Besides the starting latents, this block turns the ``diffdiff_map`` input into one
    boolean mask per inference step: the mask for step ``i`` is
    ``diffdiff_map > i / num_inference_steps + (denoising_start or 0)``.
    The denoise loop uses these masks to blend ``original_latents`` into the running latents.
    """

    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return (
            "Step that prepares the latents for the differential diffusion generation process"
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec("scheduler", EulerDiscreteScheduler),
            ComponentSpec(
                "mask_processor",
                VaeImageProcessor,
                config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("diffdiff_map", type_hint=PipelineImageInput, required=True),
            InputParam(
                "latents",
                type_hint=Optional[torch.Tensor],
                description="Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`."
            ),
            InputParam(
                "num_images_per_prompt",
                default=1,
                type_hint=int,
                description="The number of images to generate per prompt"
            ),
            InputParam(
                "denoising_start",
                type_hint=Optional[float],
                description="When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. The initial part of the denoising process is skipped and it is assumed that the passed `image` is a partly denoised image. Note that when this is specified, strength will be ignored. Useful for 'Mixture of Denoisers' multi-pipeline setups."
            ),
        ]

    @property
    def intermediates_inputs(self) -> List[InputParam]:
        return [
            InputParam("generator"),
            InputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for sampling. Can be generated in set_timesteps step."),
            InputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step."),
            InputParam("batch_size", type_hint=int, description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step."),
            InputParam("num_inference_steps", type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step."),
        ]

    @property
    def intermediates_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"),
            OutputParam("original_latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"),
            OutputParam("diffdiff_masks", type_hint=torch.Tensor, description="The masks used for the differential diffusion denoising process"),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState):
        """Compute ``latents``, ``original_latents`` and ``diffdiff_masks`` and store them in ``state``.

        Returns:
            The ``(components, state)`` pair, with the updated block state written back to ``state``.
        """
        block_state = self.get_block_state(state)

        block_state.dtype = components.vae.dtype
        block_state.device = components._execution_device
        # Noise is only added when the run starts from scratch. With `denoising_start`, the
        # reference image is treated as already partly denoised (see the input description).
        block_state.add_noise = block_state.denoising_start is None

        # Clear any begin index left on the scheduler before noising the latents below.
        components.scheduler.set_begin_index(None)

        if block_state.latents is None:
            # The whole `timesteps` tensor is passed, not a single timestep. Presumably this
            # yields one noised copy of the image latents per timestep, which the denoise loop
            # indexes as `original_latents[i]`. TODO confirm against prepare_latents_img2img.
            block_state.latents = prepare_latents_img2img(
                components.vae,
                components.scheduler,
                block_state.image_latents,
                block_state.timesteps,
                block_state.batch_size,
                block_state.num_images_per_prompt,
                block_state.dtype,
                block_state.device,
                block_state.generator,
                block_state.add_noise,
            )
        # NOTE(review): user-supplied `latents` are reused as `original_latents` below, so they
        # must be indexable per step in the same way. Confirm this is the intended contract.

        latent_height, latent_width = block_state.image_latents.shape[-2:]
        diffdiff_map = components.mask_processor.preprocess(
            block_state.diffdiff_map, height=latent_height, width=latent_width
        )
        # VaeImageProcessor rounds the requested size down to a multiple of its
        # `vae_scale_factor`. For latent sizes that are not such a multiple, the map comes
        # back smaller than the latents and the masks would not broadcast in the denoise
        # loop, so resize it to the exact latent grid.
        if tuple(diffdiff_map.shape[-2:]) != (latent_height, latent_width):
            diffdiff_map = torch.nn.functional.interpolate(
                diffdiff_map,
                size=(latent_height, latent_width),
                mode="bilinear",
                align_corners=False,
            )
        # Drop the leading batch dimension (presumably a single map -> (1, H, W); verify).
        diffdiff_map = diffdiff_map.squeeze(0).to(block_state.device)

        num_steps = block_state.num_inference_steps
        # Threshold for step i is i / num_steps. It is shaped (num_steps, 1, 1) so the
        # comparison below yields one mask per step.
        thresholds = torch.arange(num_steps, dtype=diffdiff_map.dtype) / num_steps
        thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(block_state.device)
        block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))

        block_state.original_latents = block_state.latents

        self.add_block_state(state, block_state)
        return components, state
class SDXLDiffDiffDenoiseLoopBeforeDenoiser(PipelineBlock):
    """Per-step diff-diff blending, run inside the denoising loop before the denoiser.

    Where the mask for step ``i`` is set, the latents are replaced by
    ``original_latents[i]``; elsewhere the current latents are kept. The blended latents
    are then scaled with ``scheduler.scale_model_input`` for the denoiser.
    """

    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("denoising_start"),
        ]

    @property
    def intermediates_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "latents",
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
            ),
            InputParam(
                "original_latents",
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step."
            ),
            InputParam(
                "diffdiff_masks",
                type_hint=torch.Tensor,
                description="The masks used for the differential diffusion denoising process, can be generated in DiffDiffInput step."
            ),
        ]

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", EulerDiscreteScheduler)
        ]

    @torch.no_grad()
    def __call__(self, components, block_state, i: int, t) -> Tuple[Any, Any]:
        """Blend the reference latents into ``block_state.latents`` for step ``i`` at timestep ``t``.

        Sets ``block_state.latents`` and ``block_state.scaled_latents``, and also
        ``block_state.mask`` whenever a mask is applied.

        Returns:
            The ``(components, block_state)`` pair. This is a tuple, not a ``PipelineState``.
        """
        if i == 0 and block_state.denoising_start is None:
            # First step of a from-scratch run: start from the first reference latent.
            block_state.latents = block_state.original_latents[:1]
        else:
            # Mask for step i, cast to the latents' dtype so it can act as 0/1 blend weights,
            # and given two leading singleton dims (presumably (1, 1, H, W); verify).
            mask = block_state.diffdiff_masks[i].unsqueeze(0)
            mask = mask.to(block_state.latents.dtype)
            block_state.mask = mask.unsqueeze(1)
            block_state.latents = (
                block_state.original_latents[i] * block_state.mask
                + block_state.latents * (1 - block_state.mask)
            )
        # Scale the (possibly blended) latents for the current timestep before the denoiser runs.
        block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
        return components, block_state
class SDXLDiffDiffDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper):
    """SDXL denoising loop whose first sub-step applies the diff-diff latent blending.

    Sub-blocks, presumably run in this order on every loop iteration: diff-diff
    before-denoiser, SDXL denoiser, SDXL after-denoiser.
    """

    block_classes = [
        SDXLDiffDiffDenoiseLoopBeforeDenoiser,
        StableDiffusionXLDenoiseLoopDenoiser,
        StableDiffusionXLDenoiseLoopAfterDenoiser,
    ]
    block_names = [
        "before_denoiser",
        "denoiser",
        "after_denoiser",
    ]
# Same loop as above, but using the ControlNet denoiser (chosen when `controlnet_cond` is given)
class SDXLDiffDiffControlNetDenoiseLoop(StableDiffusionXLDenoiseLoopWrapper):
    """ControlNet variant of the diff-diff SDXL denoising loop.

    Identical to the plain diff-diff loop except that the middle sub-step is the SDXL
    ControlNet denoiser.
    """

    block_classes = [
        SDXLDiffDiffDenoiseLoopBeforeDenoiser,
        StableDiffusionXLControlNetDenoiseLoopDenoiser,
        StableDiffusionXLDenoiseLoopAfterDenoiser,
    ]
    block_names = [
        "before_denoiser",
        "denoiser",
        "after_denoiser",
    ]
class SDXLDiffDiffDenoiseStep(AutoPipelineBlocks):
    """Denoise step that picks between the ControlNet and plain diff-diff loops.

    Entries are matched by position across the three lists below. Presumably the
    ControlNet loop is selected when ``controlnet_cond`` is provided, and the ``None``
    trigger marks the plain loop as the default. Confirm against ``AutoPipelineBlocks``.
    """

    block_classes = [
        SDXLDiffDiffControlNetDenoiseLoop,  # trigger: controlnet_cond
        SDXLDiffDiffDenoiseLoop,            # trigger: None
    ]
    block_names = [
        "controlnet_denoise",
        "denoise",
    ]
    block_trigger_inputs = [
        "controlnet_cond",
        None,
    ]
# Start from a copy of the image-to-image preset (so the preset itself is not mutated) and
# swap in the diff-diff specific steps. Keys that already exist keep their position.
DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
DIFFDIFF_BLOCKS.update(
    {
        "denoise": SDXLDiffDiffDenoiseStep,
        "prepare_latents": SDXLDiffDiffPrepareLatentsStep,
        # Take the text-to-image timestep step instead of the image-to-image one
        # (presumably so the full schedule is used; confirm).
        "set_timesteps": TEXT2IMAGE_BLOCKS["set_timesteps"],
    }
)
class DiffDiffBlocks(SequentialPipelineBlocks):
    """SDXL differential-diffusion pipeline built from the DIFFDIFF_BLOCKS mapping.

    Names and classes are taken from the mapping in its iteration order, so the two
    lists stay aligned.
    """

    block_classes = [*DIFFDIFF_BLOCKS.values()]
    block_names = [*DIFFDIFF_BLOCKS]
|