File size: 44,584 Bytes
ae2e28c 5ed891e ae2e28c 5ed891e ae2e28c d2c1221 ae2e28c d2c1221 ae2e28c d2c1221 ae2e28c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 |
import itertools
from functools import partial
from typing import Any, Dict, Tuple, Callable
from typing import Union, Optional, List
import numpy as np
import torch
from diffusers import DPMSolverMultistepScheduler
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers import Transformer2DModel, ModelMixin, ConfigMixin, SchedulerMixin
from diffusers import UNet2DConditionModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.resnet import ResnetBlock2D, Downsample2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModelOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import replace_example_docstring
from torch import nn
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    The guided prediction is first rescaled so that its standard deviation, taken over every dim except the
    first, matches that of `noise_pred_text`. It is then blended with the unrescaled guided prediction:
    `guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg`.

    Args:
        noise_cfg: classifier-free-guided noise prediction.
        noise_pred_text: text-conditioned noise prediction whose std is used as the target.
        guidance_rescale: blend factor; 0.0 returns `noise_cfg` unchanged, 1.0 returns the fully rescaled tensor.

    Returns:
        Tensor with the same shape as `noise_cfg`.
    """

    def per_sample_std(tensor):
        # Std over all dims but the first, kept broadcastable against the input.
        return tensor.std(dim=list(range(1, tensor.ndim)), keepdim=True)

    # Matching the text-conditioned std counters over-exposure introduced by guidance.
    rescaled = noise_cfg * (per_sample_std(noise_pred_text) / per_sample_std(noise_cfg))
    # Blending back with the raw guided result avoids "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
def custom_sort_order(obj):
    """
    Sort key giving the execution rank of a block sub-module in the forward methods.

    Resnets rank 0 and transformers rank 1, so a stable sort places all resnets before all transformers.
    Lookup is by exact class; any other class yields `None`.
    """
    execution_rank = {
        ResnetBlock2D: 0,
        Transformer2DModel: 1,
        FlexibleTransformer2DModel: 1,
    }
    return execution_rank.get(type(obj))
def squeeze_to_len_n_starting_from_index_i(n, i, timestep_spacing):
    """
    Squeeze a descending timestep schedule to length `n`, keeping its first `i` entries verbatim.

    The remaining `n - i` entries are replaced by an evenly spaced tail `[(n - i) * k, ..., 2k, k]`.
    Here `k = timestep_spacing[i - 1] // (n - i + 1)`, so the last kept timestep is treated as `(n - i + 1) * k`.
    This uses integer division, so the fit may be approximate.

    :param n: the size of the squeezed array (must be greater than `i`)
    :param i: the index we start squeezing from (must be at least 1)
    :param timestep_spacing: the timestep_spacing array we want to squeeze. Anything `np.asarray` accepts
        (e.g. a NumPy array or a CPU torch tensor), with at least `i` entries.
    :return: squeezed timestep_spacing, an integer NumPy array of length `n`
    :raises ValueError: if `i` is not in `[1, n)` or `timestep_spacing` has fewer than `i` entries
    Example:
        timesteps = np.array([967, 907, 846, 786, 725, 665, 604, 544, 484, 423, 363, 302, 242, 181, 121, 60]) (len=16)
        n = 10, i = 6
        Expected:
        [967, 907, 846, 786, 725, 665, 4k, 3k, 2k, k], and if we define 665=5k => k = 133
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    # i == 0 must be rejected: it would read squeezed[-1] and yield an all-zero tail.
    if not 0 < i < n:
        raise ValueError(f"Expected 0 < i < n, got n={n}, i={i}")
    timestep_spacing = np.asarray(timestep_spacing)
    if len(timestep_spacing) < i:
        raise ValueError(f"timestep_spacing has {len(timestep_spacing)} entries, but at least i={i} are required")
    squeezed = np.flip(np.arange(n)) + 1  # [n, n-1, ..., 2, 1]
    squeezed[:i] = timestep_spacing[:i]
    k = squeezed[i - 1] // (n - i + 1)
    squeezed[i:] *= k
    return squeezed
# Maps "n,i" to a one-argument callable squeezing a timestep schedule to length n, starting at index i.
PREDEFINED_TIMESTEP_SQUEEZERS = {
    # Tested with DPM 16-steps (reduced 16 -> 10 or 11 steps)
    f"{length},{start}": partial(squeeze_to_len_n_starting_from_index_i, length, start)
    for length, start in ((10, 6), (11, 7))
}
# Architecture description consumed by `FlexibleUNet2DConditionModel`.
# Per-block lists are indexed by block position; up-block lists run from the deepest block outwards.
FlexibleUnetConfigurations = dict(
    # General parameters for all blocks
    sample_size=64,
    temb_dim=320 * 4,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    num_attention_heads=8,
    cross_attention_dim=768,
    # Controls modules execute order in unet's forward
    mix_block_in_forward=True,
    # Down blocks parameters
    down_blocks_in_channels=[320, 320, 640],
    down_blocks_out_channels=[320, 640, 1280],
    down_blocks_num_attentions=[0, 1, 3],
    down_blocks_num_resnets=[2, 2, 1],
    add_downsample=[True, True, False],
    # Middle block parameters (0 resnets and 0 attentions -> identity mid block)
    add_upsample_mid_block=None,
    mid_num_resnets=0,
    mid_num_attentions=0,
    # Up block parameters
    prev_output_channels=[1280, 1280, 640],
    up_blocks_num_attentions=[5, 3, 0],
    up_blocks_num_resnets=[2, 3, 3],
    add_upsample=[True, True, False],
)
class SqueezedDPMSolverMultistepScheduler(DPMSolverMultistepScheduler, SchedulerMixin):
    """
    Subclass of Diffusers' `DPMSolverMultistepScheduler`, with minor differences:
    * Defaults are modified to accommodate DeciDiffusion (marked "NOTE THIS DEFAULT VALUE" below)
    * It supports a squeezer to squeeze the number of inference steps to a smaller number
    //!\\ IMPORTANT: the actual number of inference steps is deduced by the squeezer, and not the pipeline!
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "squaredcos_cap_v2",  # NOTE THIS DEFAULT VALUE
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        solver_order: int = 2,
        prediction_type: str = "v_prediction",  # NOTE THIS DEFAULT VALUE
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "heun",  # NOTE THIS DEFAULT VALUE
        lower_order_final: bool = True,
        use_karras_sigmas: Optional[bool] = False,
        lambda_min_clipped: float = -7.5,  # NOTE THIS DEFAULT VALUE
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 1,
        squeeze_mode: Optional[str] = None,  # NOTE THIS ADDITION. Supports keys from `PREDEFINED_TIMESTEP_SQUEEZERS`
    ):
        """
        All arguments except `squeeze_mode` are forwarded to `DPMSolverMultistepScheduler.__init__`.
        `use_karras_sigmas` is always forwarded as False.

        Args:
            use_karras_sigmas: must be falsy; squeezing isn't tested with Karras sigmas.
            squeeze_mode: `None` (no squeezing) or a key of `PREDEFINED_TIMESTEP_SQUEEZERS`.

        Raises:
            ValueError: if `squeeze_mode` is neither `None` nor a key of `PREDEFINED_TIMESTEP_SQUEEZERS`.
            NotImplementedError: if `use_karras_sigmas` is truthy.
        """
        # Reject unknown modes instead of silently disabling squeezing: a typo here would otherwise
        # change the number of inference steps without any warning.
        if squeeze_mode is not None and squeeze_mode not in PREDEFINED_TIMESTEP_SQUEEZERS:
            raise ValueError(
                f"Unknown squeeze_mode {squeeze_mode!r}; expected None or one of {sorted(PREDEFINED_TIMESTEP_SQUEEZERS)}"
            )
        self._squeezer = PREDEFINED_TIMESTEP_SQUEEZERS.get(squeeze_mode)
        if use_karras_sigmas:
            raise NotImplementedError("Squeezing isn't tested with `use_karras_sigmas`. Please provide `use_karras_sigmas=False`")
        super().__init__(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            solver_order=solver_order,
            prediction_type=prediction_type,
            thresholding=thresholding,
            dynamic_thresholding_ratio=dynamic_thresholding_ratio,
            sample_max_value=sample_max_value,
            algorithm_type=algorithm_type,
            solver_type=solver_type,
            lower_order_final=lower_order_final,
            use_karras_sigmas=False,
            lambda_min_clipped=lambda_min_clipped,
            variance_type=variance_type,
            timestep_spacing=timestep_spacing,
            steps_offset=steps_offset,
        )

    def set_timesteps(self, num_inference_steps: Optional[int] = None, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        The parent's schedule is computed first. When a squeezer is configured:
        * its timesteps are replaced by the squeezed schedule;
        * `sigmas` are re-interpolated at the new timesteps and a final sigma computed from `alphas_cumprod[0]`
          is appended;
        * `num_inference_steps` becomes the squeezed length.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
                With a squeezer, this only drives the parent's schedule; the final step count comes from the squeezer.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
                The re-computed sigmas are left on CPU.
        """
        super().set_timesteps(num_inference_steps=num_inference_steps, device=device)
        if self._squeezer is None:
            return
        timesteps = self._squeezer(self.timesteps.cpu())
        # Noise level per training timestep: sqrt((1 - alpha_bar) / alpha_bar), sampled at the squeezed timesteps.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
        self.num_inference_steps = len(timesteps)
class FlexibleIdentityBlock(nn.Module):
    """
    Parameter-free no-op block.

    Accepts the same call signature as the other flexible blocks and hands back `hidden_states` unchanged;
    every other argument is ignored.
    """

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        # The extra arguments exist only for signature compatibility.
        del temb, encoder_hidden_states, attention_mask, cross_attention_kwargs, encoder_attention_mask
        return hidden_states
class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
    """
    `UNet2DConditionModel` whose down, mid and up blocks are rebuilt from the `configurations` mapping.

    The parent constructor runs first with only `sample_size` and `cross_attention_dim` overridden.
    The `down_blocks`, `mid_block` and `up_blocks` it creates are then replaced by flexible blocks.
    All other parent sub-modules are kept as the parent created them.
    Keys missing from `configurations` (other than `sample_size` / `cross_attention_dim`) are read as `None`.
    """

    configurations = FlexibleUnetConfigurations

    @register_to_config
    def __init__(self):
        cfg = self.configurations
        super().__init__(
            sample_size=cfg.get("sample_size", FlexibleUnetConfigurations["sample_size"]),
            cross_attention_dim=cfg.get("cross_attention_dim", FlexibleUnetConfigurations["cross_attention_dim"]),
        )
        # Keyword arguments shared by every flexible block constructor.
        common = {
            "temb_channels": cfg.get("temb_dim"),
            "resnet_eps": cfg.get("resnet_eps"),
            "resnet_act_fn": cfg.get("resnet_act_fn"),
            "num_attention_heads": cfg.get("num_attention_heads"),
            "cross_attention_dim": cfg.get("cross_attention_dim"),
            "mix_block_in_forward": cfg.get("mix_block_in_forward"),
        }
        # Built in down -> mid -> up order.
        self.down_blocks = self._make_down_blocks(cfg, common)
        self.mid_block = self._make_mid_block(cfg, common)
        self.up_blocks = self._make_up_blocks(cfg, common)

    @staticmethod
    def _make_down_blocks(cfg, common):
        """Build one FlexibleCrossAttnDownBlock2D per position of the `down_blocks_*` lists."""
        block_inputs = cfg.get("down_blocks_in_channels")
        specs = zip(
            block_inputs,
            cfg.get("down_blocks_out_channels"),
            cfg.get("down_blocks_num_resnets"),
            cfg.get("down_blocks_num_attentions"),
            cfg.get("add_downsample"),
        )
        final_index = len(block_inputs) - 1
        return nn.ModuleList(
            FlexibleCrossAttnDownBlock2D(
                in_channels=block_in,
                out_channels=block_out,
                num_resnets=resnet_count,
                num_attentions=attention_count,
                add_downsample=downsample,
                last_block=index == final_index,
                **common,
            )
            for index, (block_in, block_out, resnet_count, attention_count, downsample) in enumerate(specs)
        )

    @staticmethod
    def _make_mid_block(cfg, common):
        """Build the middle block: an identity when it has no resnets and no attentions, else a mid block."""
        resnet_count = cfg.get("mid_num_resnets")
        attention_count = cfg.get("mid_num_attentions")
        if resnet_count == attention_count == 0:
            return FlexibleIdentityBlock()
        return FlexibleUNetMidBlock2DCrossAttn(
            in_channels=cfg.get("down_blocks_out_channels")[-1],
            num_resnets=resnet_count,
            num_attentions=attention_count,
            add_upsample=cfg.get("add_upsample_mid_block"),
            **common,
        )

    @staticmethod
    def _make_up_blocks(cfg, common):
        """Build the up blocks, pairing the reversed down-block channel lists with the `up_blocks_*` lists."""
        specs = zip(
            reversed(cfg.get("down_blocks_in_channels")),
            reversed(cfg.get("down_blocks_out_channels")),
            cfg.get("prev_output_channels"),
            cfg.get("up_blocks_num_resnets"),
            cfg.get("up_blocks_num_attentions"),
            cfg.get("add_upsample"),
        )
        return nn.ModuleList(
            FlexibleCrossAttnUpBlock2D(
                in_channels=block_in,
                out_channels=block_out,
                prev_output_channel=prev_out,
                num_resnets=resnet_count,
                num_attentions=attention_count,
                add_upsample=upsample,
                **common,
            )
            for block_in, block_out, prev_out, resnet_count, attention_count, upsample in specs
        )
class FlexibleCrossAttnDownBlock2D(nn.Module):
    """
    Down block built from `num_resnets` ResnetBlock2D and `num_attentions` FlexibleTransformer2DModel modules,
    optionally followed by a Downsample2D.

    Modules are created position by position: a resnet, then a transformer, at each position while both exist.
    When `mix_block_in_forward` is False they are stably sorted so that all resnets run before all transformers.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        downsample_padding: int = 1,
        add_downsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        last_block: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()
        self.last_block = last_block
        self.mix_block_in_forward = mix_block_in_forward
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads

        layers = []
        layout = itertools.zip_longest([True] * num_resnets, [True] * num_attentions, fillvalue=False)
        for position, (wants_resnet, wants_attention) in enumerate(layout):
            # Only position 0 is fed the block's input channels; later positions work at out_channels.
            resnet_in = in_channels if position == 0 else out_channels
            if wants_resnet:
                layers.append(
                    ResnetBlock2D(
                        in_channels=resnet_in,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if wants_attention:
                layers.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )
        if not mix_block_in_forward:
            layers.sort(key=custom_sort_order)  # stable: resnets first, then transformers
        self.modules_list = nn.ModuleList(layers)

        self.downsamplers = (
            nn.ModuleList(
                [Downsample2D(out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op")]
            )
            if add_downsample
            else None
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        Run the block's modules in order, then the optional downsampler.

        Returns:
            `(hidden_states, output_states)`. `output_states` holds the activation right after every resnet,
            plus the downsampled activation when a downsampler exists and this is not the last block.

        Raises:
            ValueError: if `modules_list` contains anything other than a resnet or a transformer.
        """
        collected = []
        for layer in self.modules_list:
            if isinstance(layer, ResnetBlock2D):
                hidden_states = layer(hidden_states, temb)
                collected.append(hidden_states)
                continue
            if not isinstance(layer, (FlexibleTransformer2DModel, Transformer2DModel)):
                raise ValueError(f"Got an unexpected module in modules list! {type(layer)}")
            hidden_states = layer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            if not self.last_block:
                collected.append(hidden_states)
        return hidden_states, tuple(collected)
class FlexibleCrossAttnUpBlock2D(nn.Module):
    """
    Up block built from `num_resnets` ResnetBlock2D and `num_attentions` FlexibleTransformer2DModel modules,
    optionally followed by an Upsample2D.

    Every resnet first concatenates one skip connection, taken from the end of `res_hidden_states_tuple`,
    onto its input along dim 1. When `mix_block_in_forward` is False the modules are stably sorted so
    that all resnets run before all transformers.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1280,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
    ):
        super().__init__()
        modules = []
        # WARNING: only the length of this list matters (one `True` per resnet). It is presumably read outside
        # this file to decide how many skip connections this block consumes — verify before changing.
        self.resnets = []
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for i, (add_resnet, add_cross_attention) in enumerate(
            itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)
        ):
            # The last resnet is sized for a skip connection with `in_channels` channels; earlier ones for `out_channels`.
            res_skip_channels = in_channels if (i == len(add_resnets) - 1) else out_channels
            # The first resnet is sized for `prev_output_channel` input channels; later ones for `out_channels`.
            resnet_in_channels = prev_output_channel if i == 0 else out_channels
            if add_resnet:
                self.resnets += [True]
                modules.append(
                    ResnetBlock2D(
                        in_channels=resnet_in_channels + res_skip_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads=num_attention_heads,
                        attention_head_dim=out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                    )
                )
        if not mix_block_in_forward:
            modules = sorted(modules, key=custom_sort_order)
        self.modules_list = nn.ModuleList(modules)
        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        upsample_size: Optional[int] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        Run the block's modules in order, then the optional upsampler.

        Args:
            res_hidden_states_tuple: skip connections; each resnet consumes the last remaining entry.

        Returns:
            The block's output hidden states.

        Raises:
            ValueError: if `modules_list` contains anything other than a resnet or a transformer
                (same check as `FlexibleCrossAttnDownBlock2D.forward`).
        """
        for module in self.modules_list:
            if isinstance(module, ResnetBlock2D):
                # Take the last remaining skip connection and concatenate it along dim 1.
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = module(hidden_states, temb)
            elif isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            else:
                raise ValueError(f"Got an unexpected module in modules list! {type(module)}")
        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)
        return hidden_states
class FlexibleUNetMidBlock2DCrossAttn(nn.Module):
    """
    Middle block: one mandatory leading resnet, followed by `num_attentions` transformers and `num_resnets`
    additional resnets, then an upsampler (or an identity).

    The extra modules are interleaved transformer-then-resnet per position. When `mix_block_in_forward` is
    False they are stably sorted so all resnets (including the leading one) come first.
    All modules keep `in_channels` channels.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_resnets: int = 1,
        num_attentions: int = 1,
        transformer_layers_per_block: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        num_attention_heads: int = 1,
        output_scale_factor: float = 1.0,
        cross_attention_dim: int = 1280,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        mix_block_in_forward: bool = True,
        add_upsample: bool = True,
    ):
        super().__init__()
        self.has_cross_attention = True
        self.num_attention_heads = num_attention_heads
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        # There is always at least one resnet
        modules = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        add_resnets = [True] * num_resnets
        add_cross_attentions = [True] * num_attentions
        for add_resnet, add_cross_attention in itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False):
            if add_cross_attention:
                modules.append(
                    FlexibleTransformer2DModel(
                        num_attention_heads,
                        in_channels // num_attention_heads,
                        in_channels=in_channels,
                        num_layers=transformer_layers_per_block,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                    )
                )
            if add_resnet:
                modules.append(
                    ResnetBlock2D(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        temb_channels=temb_channels,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                )
        if not mix_block_in_forward:
            # Stable sort: the leading resnet stays at index 0.
            modules = sorted(modules, key=custom_sort_order)
        self.modules_list = nn.ModuleList(modules)
        self.upsamplers = nn.ModuleList([nn.Identity()])
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Apply the leading resnet, then every remaining module in order, then the upsampler(s).

        Raises:
            ValueError: if `modules_list` contains anything other than a resnet or a transformer.
        """
        hidden_states = self.modules_list[0](hidden_states, temb)
        # Start at index 1: the leading resnet was applied just above. Iterating from 0 would run it twice.
        for module in itertools.islice(self.modules_list, 1, None):
            if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
                hidden_states = module(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
            elif isinstance(module, ResnetBlock2D):
                hidden_states = module(hidden_states, temb)
            else:
                raise ValueError(f"Got an unexpected module in modules list! {type(module)}")
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)
        return hidden_states
class FlexibleTransformer2DModel(ModelMixin, ConfigMixin):
    """
    Spatial transformer applied to a 2D feature map, with a residual connection.

    The input goes through these steps:
    1. It is group-normalised.
    2. It is projected to `num_attention_heads * attention_head_dim` channels. This uses a 1x1 conv or a
       linear layer, depending on `use_linear_projection`.
    3. It is flattened to a `(batch, height * width, inner_dim)` token sequence.
    4. It runs through `num_layers` `BasicTransformerBlock`s.
    5. It is projected back to `in_channels` and added to the input.

    Note: `out_channels` is stored as an attribute, but the output projection always maps back to `in_channels`.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        only_cross_attention: bool = False,
        use_linear_projection: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.in_channels = in_channels
        inner_dim = num_attention_heads * attention_head_dim
        # Define input layers
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.use_linear_projection = use_linear_projection
        if self.use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        # Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )
        # Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        if self.use_linear_projection:
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ):
        """
        Args:
            hidden_states: 4D feature map, unpacked below as `(batch, channels, height, width)`.
            encoder_hidden_states, timestep, class_labels, cross_attention_kwargs, attention_mask,
            encoder_attention_mask: forwarded unchanged to every `BasicTransformerBlock`.
            return_dict: if True, return a `Transformer2DModelOutput`; otherwise return the 1-tuple `(output,)`.

        Returns:
            `Transformer2DModelOutput` or `(output,)`. Either form supports `result[0]`.
        """
        # 1. Input
        batch, _, height, width = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            # Here `inner_dim` is the pre-projection channel count. proj_out maps back to that same size,
            # so the reshape in step 3 stays consistent.
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            hidden_states = self.proj_in(hidden_states)
        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )
        # 3. Output
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
        output = hidden_states + residual
        # `return_dict=False` must yield a plain tuple; the flag was previously inverted.
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
class DeciDiffusionPipeline(StableDiffusionPipeline):
deci_default_squeeze_mode = "10,6"
deci_default_number_of_iterations = 16
deci_default_guidance_rescale = 0.8
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    requires_safety_checker: bool = True,
):
    """
    Build the pipeline around Deci's flexible UNet and squeezed DPM scheduler.

    The `unet` and `scheduler` arguments are accepted for signature compatibility with `StableDiffusionPipeline`,
    but they are discarded. Instead, a freshly constructed `FlexibleUNet2DConditionModel()` and a
    `SqueezedDPMSolverMultistepScheduler(squeeze_mode=self.deci_default_squeeze_mode)` are used.
    NOTE(review): the fresh UNet is built by its own constructor. Presumably the caller loads trained weights
    into it afterwards — verify against the usage instructions.
    """
    # Replace UNet with Deci`s unet, and the scheduler with the custom squeezed scheduler.
    del unet, scheduler
    components = dict(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=FlexibleUNet2DConditionModel(),
        scheduler=SqueezedDPMSolverMultistepScheduler(squeeze_mode=self.deci_default_squeeze_mode),
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    super().__init__(**components, requires_safety_checker=requires_safety_checker)
    # Register the same components again after the parent constructor has run.
    self.register_modules(**components)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 16,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.8,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.7):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=len(timesteps)) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|