File size: 16,310 Bytes
e523134 513d4d1 e523134 513d4d1 e523134 513d4d1 e523134 ea15fc2 e523134 ea15fc2 e523134 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 |
import torch
from diffusers import DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    FusedAttnProcessor2_0,
    XFormersAttnProcessor,
)
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from huggingface_hub import PyTorchModelHubMixin
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
class CombinedStableDiffusionXL(
    DiffusionPipeline,
    PyTorchModelHubMixin,
):
    """
    Stable Diffusion XL text-to-image pipeline that denoises with two UNets.

    The first ``original_unet_steps`` scheduler steps run ``original_unet``
    with classifier-free guidance; the remaining steps run ``fine_tuned_unet``
    without guidance. The final latents are decoded with the VAE unless
    ``output_type == "latent"``.

    The lower-level helpers (``get_noise_prediction``, ``predict_next_latents``,
    ``do_k_diffusion_steps``) read ``self.guidance_scale``, ``self.resolution``
    and ``self._use_original_unet``; ``__call__`` sets all three on every
    invocation.
    """

    def __init__(
        self,
        original_unet: torch.nn.Module,
        fine_tuned_unet: torch.nn.Module,
        scheduler: DDPMScheduler,
        vae: torch.nn.Module,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
    ) -> None:
        """
        Register the pipeline components.

        Args:
            original_unet: UNet used for the first (guided) denoising steps.
            fine_tuned_unet: UNet used for the remaining (unguided) steps.
            scheduler: noise scheduler shared by both UNets.
            vae: autoencoder used to decode latents into images.
            tokenizer: tokenizer feeding ``text_encoder``.
            tokenizer_2: tokenizer feeding ``text_encoder_2``.
            text_encoder: first text encoder (only its hidden states are used).
            text_encoder_2: second text encoder; also supplies the pooled
                prompt embedding.
        """
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            original_unet=original_unet,
            fine_tuned_unet=fine_tuned_unet,
            scheduler=scheduler,
            vae=vae,
        )
        # Ratio between pixel resolution and latent resolution.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        # Runtime state normally set by __call__; initialised here so the
        # public step helpers are usable before the first __call__.
        self.resolution = 1024
        self.guidance_scale = 5
        self._use_original_unet = True

    @staticmethod
    def _tokenize(tokenizer, prompt: str | list[str]) -> torch.Tensor:
        """Return ``prompt`` token ids padded/truncated to ``tokenizer.model_max_length``."""
        return tokenizer(
            prompt,
            max_length=tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

    def _get_negative_prompts(
        self, batch_size: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Tokenize ``batch_size`` empty prompts with both tokenizers.

        The empty prompt is the unconditional input for classifier-free
        guidance. Each tokenizer pads to its own ``model_max_length`` so the
        ids can be concatenated with prompt ids produced the same way.

        Args:
            batch_size (int): number of empty prompts.
        Returns:
            tuple[torch.Tensor, torch.Tensor]: ids for ``tokenizer`` and
            ``tokenizer_2``.
        """
        empty_prompts = [""] * batch_size
        return (
            self._tokenize(self.tokenizer, empty_prompts),
            self._tokenize(self.tokenizer_2, empty_prompts),
        )

    def _get_encoder_hidden_states(
        self,
        tokenized_prompts_1: torch.Tensor,
        tokenized_prompts_2: torch.Tensor,
        do_classifier_free_guidance: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode token ids with both text encoders.

        Args:
            tokenized_prompts_1 (torch.Tensor): token ids for ``text_encoder``.
            tokenized_prompts_2 (torch.Tensor): token ids for ``text_encoder_2``.
            do_classifier_free_guidance (bool): if True, empty-prompt ids are
                prepended along dim 0, so the output batch is
                ``[negative..., positive...]``.
        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(prompt_embeds,
            pooled_prompt_embeds)``. ``prompt_embeds`` is the second-to-last
            hidden state of each encoder, concatenated on the last dim;
            ``pooled_prompt_embeds`` is the first output of ``text_encoder_2``,
            flattened per sample.
        """
        text_input_ids_list = [tokenized_prompts_1, tokenized_prompts_2]
        if do_classifier_free_guidance:
            batch_size = tokenized_prompts_1.size(0)
            device = tokenized_prompts_1.device
            negative_ids_list = [
                ids.to(device) for ids in self._get_negative_prompts(batch_size)
            ]
            # Negative first: get_noise_prediction relies on chunk(2) -> (uncond, text).
            text_input_ids_list = [
                torch.cat([negative_ids, text_ids])
                for text_ids, negative_ids in zip(text_input_ids_list, negative_ids_list)
            ]

        prompt_embeds_list = []
        pooled_prompt_embeds = None
        for text_encoder, text_input_ids in zip(
            [self.text_encoder, self.text_encoder_2], text_input_ids_list
        ):
            outputs = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
                return_dict=False,
            )
            # Overwritten each iteration: only text_encoder_2's value is kept.
            pooled_prompt_embeds = outputs[0]
            # outputs[-1] is the tuple of hidden states; take the penultimate layer.
            prompt_embeds_list.append(outputs[-1][-2])

        prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(prompt_embeds.size(0), -1)
        return prompt_embeds, pooled_prompt_embeds

    def _get_unet_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: int,
        encoder_hidden_states: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """
        Return the UNet noise prediction.

        Uses ``original_unet`` when ``self._use_original_unet`` is True,
        otherwise ``fine_tuned_unet``.

        Args:
            latent_model_input (torch.Tensor): scaled UNet input latents.
            timestep (int): noise scheduler timestep.
            encoder_hidden_states (tuple[torch.Tensor, torch.Tensor]):
                ``(prompt_embeds, pooled_prompt_embeds)``.
        Returns:
            torch.Tensor: noise prediction.
        """
        unet = self.original_unet if self._use_original_unet else self.fine_tuned_unet
        prompt_embeds, pooled_prompt_embeds = encoder_hidden_states
        batch_size = latent_model_input.size(0)
        target_size = torch.tensor(
            [[self.resolution, self.resolution]] * batch_size,
            device=latent_model_input.device,
            dtype=torch.float32,
        )
        # Size conditioning: [original size, crop offset (0, 0), target size],
        # all tied to the output resolution. Assumes the UNet expects this
        # order, as in diffusers' SDXL pipeline.
        add_time_ids = torch.cat(
            [target_size, torch.zeros_like(target_size), target_size], dim=1
        )
        unet_added_conditions = {
            "time_ids": add_time_ids,
            "text_embeds": pooled_prompt_embeds,
        }
        return unet(
            latent_model_input,
            timestep,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs=unet_added_conditions,
        ).sample

    def get_noise_prediction(
        self,
        latents: torch.Tensor,
        timestep_index: int,
        encoder_hidden_states: tuple[torch.Tensor, torch.Tensor],
        do_classifier_free_guidance: bool = False,
        detach_main_path: bool = False,
    ) -> torch.Tensor:
        """
        Return the (optionally guided) noise prediction for one step.

        Args:
            latents (torch.Tensor): image latents.
            timestep_index (int): index into ``self.scheduler.timesteps``.
            encoder_hidden_states (tuple[torch.Tensor, torch.Tensor]): text
                encoder outputs. With CFG they must hold the negative batch
                first.
            do_classifier_free_guidance (bool): whether to apply CFG with
                ``self.guidance_scale``.
            detach_main_path (bool): with CFG, detach the text-conditioned
                prediction from the autograd graph.
        Returns:
            torch.Tensor: noise prediction.
        """
        timestep = self.scheduler.timesteps[timestep_index]
        latent_model_input = self.scheduler.scale_model_input(
            sample=torch.cat([latents] * 2) if do_classifier_free_guidance else latents,
            timestep=timestep,
        )
        noise_pred = self._get_unet_prediction(
            latent_model_input=latent_model_input,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
        )
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            if detach_main_path:
                noise_pred_text = noise_pred_text.detach()
            noise_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        return noise_pred

    def sample_next_latents(
        self,
        latents: torch.Tensor,
        timestep_index: int,
        noise_pred: torch.Tensor,
        return_pred_original: bool = False,
    ) -> torch.Tensor:
        """
        Apply one scheduler step.

        Args:
            latents (torch.Tensor): image latents.
            timestep_index (int): index into ``self.scheduler.timesteps``.
            noise_pred (torch.Tensor): noise prediction.
            return_pred_original (bool): return the scheduler's
                ``pred_original_sample`` instead of ``prev_sample``.
        Returns:
            torch.Tensor: latent prediction.
        """
        timestep = self.scheduler.timesteps[timestep_index]
        sample = self.scheduler.step(
            model_output=noise_pred, timestep=timestep, sample=latents
        )
        return sample.pred_original_sample if return_pred_original else sample.prev_sample

    def predict_next_latents(
        self,
        latents: torch.Tensor,
        timestep_index: int,
        encoder_hidden_states: tuple[torch.Tensor, torch.Tensor],
        return_pred_original: bool = False,
        do_classifier_free_guidance: bool = False,
        detach_main_path: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict noise and advance the latents by one scheduler step.

        Args:
            latents (torch.Tensor): current latents.
            timestep_index (int): index of the current timestep.
            encoder_hidden_states (tuple[torch.Tensor, torch.Tensor]): text
                encoder outputs.
            return_pred_original (bool): return the predicted original sample.
            do_classifier_free_guidance (bool): whether to apply CFG.
            detach_main_path (bool): see ``get_noise_prediction``.
        Returns:
            tuple[torch.Tensor, torch.Tensor]: next latents and the noise
            prediction.
        """
        noise_pred = self.get_noise_prediction(
            latents=latents,
            timestep_index=timestep_index,
            encoder_hidden_states=encoder_hidden_states,
            do_classifier_free_guidance=do_classifier_free_guidance,
            detach_main_path=detach_main_path,
        )
        latents = self.sample_next_latents(
            latents=latents,
            noise_pred=noise_pred,
            timestep_index=timestep_index,
            return_pred_original=return_pred_original,
        )
        return latents, noise_pred

    def get_latents(
        self,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype | None = None,
    ) -> torch.Tensor:
        """
        Sample standard-normal initial latents for ``self.resolution``.

        Args:
            batch_size (int): number of latents.
            device (torch.device): target device.
            dtype (torch.dtype | None): latent dtype; ``None`` uses torch's
                default dtype.
        Returns:
            torch.Tensor: tensor of shape ``(batch_size, in_channels, r, r)``
            with ``r = resolution // vae_scale_factor``.
        """
        latent_resolution = int(self.resolution) // self.vae_scale_factor
        return torch.randn(
            (
                batch_size,
                self.original_unet.config.in_channels,
                latent_resolution,
                latent_resolution,
            ),
            device=device,
            dtype=dtype,
        )

    def do_k_diffusion_steps(
        self,
        start_timestep_index: int,
        end_timestep_index: int,
        latents: torch.Tensor,
        encoder_hidden_states: tuple[torch.Tensor, torch.Tensor],
        return_pred_original: bool = False,
        do_classifier_free_guidance: bool = False,
        detach_main_path: bool = False,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Run scheduler steps for indices ``[start_timestep_index, end_timestep_index)``.

        Args:
            start_timestep_index (int): first timestep index (inclusive).
            end_timestep_index (int): last timestep index (exclusive).
            latents (torch.Tensor): initial latents.
            encoder_hidden_states (tuple[torch.Tensor, torch.Tensor]): text
                encoder outputs.
            return_pred_original (bool): on the final step, return the
                predicted original sample. Ignored when the range is empty.
            do_classifier_free_guidance (bool): whether to apply CFG.
            detach_main_path (bool): see ``get_noise_prediction``. Applied to
                every step except the final one.
        Returns:
            tuple: resulting latents (unchanged if the range is empty) and
            ``encoder_hidden_states``.
        """
        assert start_timestep_index <= end_timestep_index
        if start_timestep_index == end_timestep_index:
            # Empty range: previously this still stepped at index end - 1,
            # re-applying an already-used timestep (or index -1 when end == 0).
            return latents, encoder_hidden_states

        for timestep_index in range(start_timestep_index, end_timestep_index - 1):
            latents, _ = self.predict_next_latents(
                latents=latents,
                timestep_index=timestep_index,
                encoder_hidden_states=encoder_hidden_states,
                return_pred_original=False,
                do_classifier_free_guidance=do_classifier_free_guidance,
                detach_main_path=detach_main_path,
            )
        # NOTE(review): the final step does not receive detach_main_path, so
        # its text branch always keeps gradients. This is kept as-is because it
        # may be intentional (gradient through the last step only); confirm.
        latents, _ = self.predict_next_latents(
            latents=latents,
            timestep_index=end_timestep_index - 1,
            encoder_hidden_states=encoder_hidden_states,
            return_pred_original=return_pred_original,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )
        return latents, encoder_hidden_states

    def upcast_vae(self):
        """
        Cast the VAE to float32.

        When the decoder's mid-block attention uses a torch-2.0/xformers
        processor, ``post_quant_conv``, ``decoder.conv_in`` and
        ``decoder.mid_block`` are cast back to the VAE's previous dtype.
        """
        dtype = self.vae.dtype
        self.vae.to(dtype=torch.float32)
        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
                FusedAttnProcessor2_0,
            ),
        )
        if use_torch_2_0_or_xformers:
            self.vae.post_quant_conv.to(dtype)
            self.vae.decoder.conv_in.to(dtype)
            self.vae.decoder.mid_block.to(dtype)

    def _vae_decode(self, latents: torch.Tensor) -> torch.Tensor:
        """
        Decode latents with the VAE.

        A float16 VAE whose config sets ``force_upcast`` is temporarily
        upcast for decoding and cast back to float16 afterwards.
        """
        # make sure the VAE is in float32 mode, as it overflows in float16
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
        elif latents.dtype != self.vae.dtype:
            if torch.backends.mps.is_available():
                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                self.vae = self.vae.to(latents.dtype)
        latents = latents / self.vae.config.scaling_factor
        image = self.vae.decode(latents).sample
        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
        return image

    @torch.no_grad()
    def __call__(
        self,
        prompt: str | list[str],
        num_inference_steps=40,
        original_unet_steps=35,
        resolution=1024,
        guidance_scale=5,
        output_type: str = "pil",
        return_dict: bool = True,
    ):
        """
        Generate images from ``prompt``.

        Args:
            prompt (str | list[str]): prompt(s); the batch size is 1 for a
                ``str``, otherwise ``len(prompt)``.
            num_inference_steps (int): total number of scheduler steps.
            original_unet_steps (int): number of leading steps run with
                ``original_unet`` and CFG; the rest use ``fine_tuned_unet``
                without CFG.
            resolution (int): output height and width in pixels.
            guidance_scale (float): CFG weight for the original-UNet phase.
            output_type (str): ``"latent"`` returns undecoded latents; any
                other value is forwarded to ``image_processor.postprocess``.
            return_dict (bool): return a ``StableDiffusionXLPipelineOutput``
                instead of a 1-tuple.
        Returns:
            StableDiffusionXLPipelineOutput | tuple: the generated images.
        Raises:
            ValueError: if ``original_unet_steps`` is outside
                ``[0, num_inference_steps]``.
        """
        if not 0 <= original_unet_steps <= num_inference_steps:
            raise ValueError(
                f"original_unet_steps must be between 0 and num_inference_steps "
                f"({num_inference_steps}), got {original_unet_steps}."
            )
        self.guidance_scale = guidance_scale
        self.resolution = resolution
        batch_size = 1 if isinstance(prompt, str) else len(prompt)

        tokenized_prompts_1 = self._tokenize(self.tokenizer, prompt)
        tokenized_prompts_2 = self._tokenize(self.tokenizer_2, prompt)
        original_encoder_hidden_states = self._get_encoder_hidden_states(
            tokenized_prompts_1=tokenized_prompts_1,
            tokenized_prompts_2=tokenized_prompts_2,
            do_classifier_free_guidance=True,
        )
        fine_tuned_encoder_hidden_states = self._get_encoder_hidden_states(
            tokenized_prompts_1=tokenized_prompts_1,
            tokenized_prompts_2=tokenized_prompts_2,
            do_classifier_free_guidance=False,
        )

        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        # Match the latent dtype to the prompt embeddings so half-precision
        # models receive half-precision inputs.
        latents = self.get_latents(
            batch_size,
            self.device,
            dtype=original_encoder_hidden_states[0].dtype,
        )
        # Scale the initial noise as the scheduler expects (1.0 for DDPM).
        latents = latents * self.scheduler.init_noise_sigma

        # Phase 1: original UNet with classifier-free guidance.
        self._use_original_unet = True
        latents, _ = self.do_k_diffusion_steps(
            start_timestep_index=0,
            end_timestep_index=original_unet_steps,
            latents=latents,
            encoder_hidden_states=original_encoder_hidden_states,
            return_pred_original=False,
            do_classifier_free_guidance=True,
        )
        # Phase 2: fine-tuned UNet without guidance.
        self._use_original_unet = False
        latents, _ = self.do_k_diffusion_steps(
            start_timestep_index=original_unet_steps,
            end_timestep_index=num_inference_steps,
            latents=latents,
            encoder_hidden_states=fine_tuned_encoder_hidden_states,
            return_pred_original=False,
            do_classifier_free_guidance=False,
        )

        if output_type == "latent":
            image = latents
        else:
            image = self._vae_decode(latents)
            image = self.image_processor.postprocess(
                image,
                output_type=output_type,
                do_denormalize=[True] * image.shape[0],
            )

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)
        return StableDiffusionXLPipelineOutput(images=image)
|