|
import itertools |
|
from functools import partial |
|
from typing import Any, Dict, Tuple, Callable |
|
from typing import Union, Optional, List |
|
|
|
import numpy as np |
|
import torch |
|
from diffusers import DPMSolverMultistepScheduler |
|
from diffusers import StableDiffusionPipeline, AutoencoderKL |
|
from diffusers import Transformer2DModel, ModelMixin, ConfigMixin, SchedulerMixin |
|
from diffusers import UNet2DConditionModel |
|
from diffusers.configuration_utils import register_to_config |
|
from diffusers.models.attention import BasicTransformerBlock |
|
from diffusers.models.resnet import ResnetBlock2D, Downsample2D, Upsample2D |
|
from diffusers.models.transformer_2d import Transformer2DModelOutput |
|
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput |
|
from diffusers.schedulers import KarrasDiffusionSchedulers |
|
from diffusers.utils import replace_example_docstring |
|
from torch import nn |
|
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor |
|
|
|
|
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)

    # Rescale the guided prediction so its per-sample std matches the text-conditioned prediction (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)

    # Mix with the original guided prediction by `guidance_rescale` to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
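
# Illustrative behaviour (comments only): with `guidance_rescale=0.0` the guided prediction is returned
# unchanged, while `guidance_rescale=1.0` rescales it so that its per-sample standard deviation matches
# that of `noise_pred_text`; values in between interpolate between the two.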
|
|
|
|
|
def custom_sort_order(obj):
    """
    Key function for sorting the order of execution in forward methods: resnet blocks first, then transformer blocks
    """
    return {ResnetBlock2D: 0, Transformer2DModel: 1, FlexibleTransformer2DModel: 1}.get(obj.__class__)
|
|
|
|
|
def squeeze_to_len_n_starting_from_index_i(n, i, timestep_spacing):
    """
    :param n: the size of the squeezed array
    :param i: the index we start squeezing from
    :param timestep_spacing: the timestep_spacing array we want to squeeze
    :return: squeezed timestep_spacing

    Example:
        timestep_spacing = np.array([967, 907, 846, 786, 725, 665, 604, 544, 484, 423, 363, 302, 242, 181, 121, 60])  # len=16
        n = 10, i = 6
        Expected:
        [967, 907, 846, 786, 725, 665, 4k, 3k, 2k, k]; defining 665 = 5k gives k = 133, so the tail becomes [532, 399, 266, 133]
    """
    assert i < n
    squeezed = np.flip(np.arange(n)) + 1
    squeezed[:i] = timestep_spacing[:i]
    k = squeezed[i - 1] // (n - i + 1)
    squeezed[i:] *= k

    return squeezed
|
|
|
|
|
PREDEFINED_TIMESTEP_SQUEEZERS = { |
|
|
|
"10,6": partial(squeeze_to_len_n_starting_from_index_i, 10, 6), |
|
"11,7": partial(squeeze_to_len_n_starting_from_index_i, 11, 7), |
|
} |
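
# Illustrative usage (comments only, not executed at import time): the "10,6" squeezer applied to the
# 16-step schedule from the docstring above keeps the first 6 timesteps and replaces the tail with
# multiples of k = 665 // 5 = 133:
#
#   squeezer = PREDEFINED_TIMESTEP_SQUEEZERS["10,6"]
#   squeezer(np.array([967, 907, 846, 786, 725, 665, 604, 544, 484, 423, 363, 302, 242, 181, 121, 60]))
#   # -> array([967, 907, 846, 786, 725, 665, 532, 399, 266, 133])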
|
|
|
FlexibleUnetConfigurations = {
    # Shared settings
    "sample_size": 64,
    "temb_dim": 320 * 4,
    "resnet_eps": 1e-5,
    "resnet_act_fn": "silu",
    "num_attention_heads": 8,
    "cross_attention_dim": 768,

    # If False, resnets are grouped before attention blocks inside each block's forward pass
    "mix_block_in_forward": True,

    # Down blocks
    "down_blocks_in_channels": [320, 320, 640],
    "down_blocks_out_channels": [320, 640, 1280],
    "down_blocks_num_attentions": [0, 1, 3],
    "down_blocks_num_resnets": [2, 2, 1],
    "add_downsample": [True, True, False],

    # Mid block (zero resnets and zero attentions -> replaced by an identity block)
    "add_upsample_mid_block": None,
    "mid_num_resnets": 0,
    "mid_num_attentions": 0,

    # Up blocks
    "prev_output_channels": [1280, 1280, 640],
    "up_blocks_num_attentions": [5, 3, 0],
    "up_blocks_num_resnets": [2, 3, 3],
    "add_upsample": [True, True, False],
}
|
|
|
|
|
class SqueezedDPMSolverMultistepScheduler(DPMSolverMultistepScheduler, SchedulerMixin): |
|
""" |
|
This is a copy-paste from Diffuser's `DPMSolverMultistepScheduler`, with minor differences: |
|
* Defaults are modified to accommodate DeciDiffusion |
|
* It supports a squeezer to squeeze the number of inference steps to a smaller number |
|
//!\\ IMPORTANT: the actual number of inference steps is deduced by the squeezer, and not the pipeline! |
|
""" |
|
|
|
@register_to_config |
|
def __init__( |
|
self, |
|
num_train_timesteps: int = 1000, |
|
beta_start: float = 0.0001, |
|
beta_end: float = 0.02, |
|
beta_schedule: str = "squaredcos_cap_v2", |
|
trained_betas: Optional[Union[np.ndarray, List[float]]] = None, |
|
solver_order: int = 2, |
|
prediction_type: str = "v_prediction", |
|
thresholding: bool = False, |
|
dynamic_thresholding_ratio: float = 0.995, |
|
sample_max_value: float = 1.0, |
|
algorithm_type: str = "dpmsolver++", |
|
solver_type: str = "heun", |
|
lower_order_final: bool = True, |
|
use_karras_sigmas: Optional[bool] = False, |
|
lambda_min_clipped: float = -3.0, |
|
variance_type: Optional[str] = None, |
|
timestep_spacing: str = "linspace", |
|
steps_offset: int = 1, |
|
squeeze_mode: Optional[str] = None, |
|
): |
|
self._squeezer = PREDEFINED_TIMESTEP_SQUEEZERS.get(squeeze_mode) |
|
|
|
if use_karras_sigmas: |
|
raise NotImplementedError("Squeezing isn't tested with `use_karras_sigmas`. Please provide `use_karras_sigmas=False`") |
|
|
|
super().__init__( |
|
num_train_timesteps=num_train_timesteps, |
|
beta_start=beta_start, |
|
beta_end=beta_end, |
|
beta_schedule=beta_schedule, |
|
trained_betas=trained_betas, |
|
solver_order=solver_order, |
|
prediction_type=prediction_type, |
|
thresholding=thresholding, |
|
dynamic_thresholding_ratio=dynamic_thresholding_ratio, |
|
sample_max_value=sample_max_value, |
|
algorithm_type=algorithm_type, |
|
solver_type=solver_type, |
|
lower_order_final=lower_order_final, |
|
use_karras_sigmas=False, |
|
lambda_min_clipped=lambda_min_clipped, |
|
variance_type=variance_type, |
|
timestep_spacing=timestep_spacing, |
|
steps_offset=steps_offset, |
|
) |
|
|
|
    def set_timesteps(self, num_inference_steps: Optional[int] = None, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        """
        super().set_timesteps(num_inference_steps=num_inference_steps, device=device)
        if self._squeezer is not None:
            # Squeeze the timesteps and recompute the matching sigmas; the effective number of inference
            # steps becomes the length of the squeezed schedule.
            timesteps = self._squeezer(self.timesteps.cpu())
            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
            self.sigmas = torch.from_numpy(sigmas)
            self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
            self.num_inference_steps = len(timesteps)
|
|
|
|
|
class FlexibleIdentityBlock(nn.Module): |
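    """A no-op block that simply returns `hidden_states`; used in place of the mid block when it is configured with zero resnets and zero attentions."""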
|
def forward( |
|
self, |
|
hidden_states: torch.FloatTensor, |
|
temb: Optional[torch.FloatTensor] = None, |
|
encoder_hidden_states: Optional[torch.FloatTensor] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
encoder_attention_mask: Optional[torch.FloatTensor] = None, |
|
): |
|
return hidden_states |
|
|
|
|
|
class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin): |
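    """
    A `UNet2DConditionModel` whose down, mid and up blocks are rebuilt according to `FlexibleUnetConfigurations`,
    allowing a different number of resnet/attention blocks per stage (and an identity mid block when both counts are zero).
    """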
|
configurations = FlexibleUnetConfigurations |
|
|
|
@register_to_config |
|
def __init__(self): |
|
        # Initialize a standard `UNet2DConditionModel`; its down/mid/up blocks are then replaced below
        # with the flexible variants described by `FlexibleUnetConfigurations`.
        super().__init__(
            sample_size=self.configurations.get("sample_size", FlexibleUnetConfigurations["sample_size"]),
            cross_attention_dim=self.configurations.get("cross_attention_dim", FlexibleUnetConfigurations["cross_attention_dim"]),
        )
|
|
|
num_attention_heads = self.configurations.get("num_attention_heads") |
|
cross_attention_dim = self.configurations.get("cross_attention_dim") |
|
mix_block_in_forward = self.configurations.get("mix_block_in_forward") |
|
resnet_act_fn = self.configurations.get("resnet_act_fn") |
|
resnet_eps = self.configurations.get("resnet_eps") |
|
temb_dim = self.configurations.get("temb_dim") |
|
|
|
|
|
|
|
|
|
        # Down blocks
        down_blocks_num_attentions = self.configurations.get("down_blocks_num_attentions")
|
down_blocks_out_channels = self.configurations.get("down_blocks_out_channels") |
|
down_blocks_in_channels = self.configurations.get("down_blocks_in_channels") |
|
down_blocks_num_resnets = self.configurations.get("down_blocks_num_resnets") |
|
add_downsample = self.configurations.get("add_downsample") |
|
|
|
self.down_blocks = nn.ModuleList() |
|
|
|
for i, (in_c, out_c, n_res, n_att, add_down) in enumerate( |
|
zip(down_blocks_in_channels, down_blocks_out_channels, down_blocks_num_resnets, down_blocks_num_attentions, add_downsample) |
|
): |
|
last_block = i == len(down_blocks_in_channels) - 1 |
|
self.down_blocks.append( |
|
FlexibleCrossAttnDownBlock2D( |
|
in_channels=in_c, |
|
out_channels=out_c, |
|
temb_channels=temb_dim, |
|
num_resnets=n_res, |
|
num_attentions=n_att, |
|
resnet_eps=resnet_eps, |
|
resnet_act_fn=resnet_act_fn, |
|
num_attention_heads=num_attention_heads, |
|
cross_attention_dim=cross_attention_dim, |
|
add_downsample=add_down, |
|
last_block=last_block, |
|
mix_block_in_forward=mix_block_in_forward, |
|
) |
|
) |
|
|
|
|
|
|
|
|
|
|
|
        # Mid block (replaced by an identity block when it has no resnets and no attentions)
        mid_block_add_upsample = self.configurations.get("add_upsample_mid_block")
|
mid_num_attentions = self.configurations.get("mid_num_attentions") |
|
mid_num_resnets = self.configurations.get("mid_num_resnets") |
|
|
|
if mid_num_resnets == mid_num_attentions == 0: |
|
self.mid_block = FlexibleIdentityBlock() |
|
else: |
|
self.mid_block = FlexibleUNetMidBlock2DCrossAttn( |
|
in_channels=down_blocks_out_channels[-1], |
|
temb_channels=temb_dim, |
|
resnet_act_fn=resnet_act_fn, |
|
resnet_eps=resnet_eps, |
|
cross_attention_dim=cross_attention_dim, |
|
num_attention_heads=num_attention_heads, |
|
num_resnets=mid_num_resnets, |
|
num_attentions=mid_num_attentions, |
|
mix_block_in_forward=mix_block_in_forward, |
|
add_upsample=mid_block_add_upsample, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
        # Up blocks
        up_blocks_num_attentions = self.configurations.get("up_blocks_num_attentions")
|
up_blocks_num_resnets = self.configurations.get("up_blocks_num_resnets") |
|
prev_output_channels = self.configurations.get("prev_output_channels") |
|
up_upsample = self.configurations.get("add_upsample") |
|
|
|
self.up_blocks = nn.ModuleList() |
|
for in_c, out_c, prev_out, n_res, n_att, add_up in zip( |
|
reversed(down_blocks_in_channels), |
|
reversed(down_blocks_out_channels), |
|
prev_output_channels, |
|
up_blocks_num_resnets, |
|
up_blocks_num_attentions, |
|
up_upsample, |
|
): |
|
self.up_blocks.append( |
|
FlexibleCrossAttnUpBlock2D( |
|
in_channels=in_c, |
|
out_channels=out_c, |
|
prev_output_channel=prev_out, |
|
temb_channels=temb_dim, |
|
num_resnets=n_res, |
|
num_attentions=n_att, |
|
resnet_eps=resnet_eps, |
|
resnet_act_fn=resnet_act_fn, |
|
num_attention_heads=num_attention_heads, |
|
cross_attention_dim=cross_attention_dim, |
|
add_upsample=add_up, |
|
mix_block_in_forward=mix_block_in_forward, |
|
) |
|
) |
|
|
|
|
|
class FlexibleCrossAttnDownBlock2D(nn.Module): |
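    """
    Down block built from a configurable number of `ResnetBlock2D` and `FlexibleTransformer2DModel` modules.
    When `mix_block_in_forward` is False the resnets are grouped before the attention blocks; otherwise the
    modules are kept in their interleaved build order.
    """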
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
temb_channels: int, |
|
dropout: float = 0.0, |
|
num_resnets: int = 1, |
|
num_attentions: int = 1, |
|
transformer_layers_per_block: int = 1, |
|
resnet_eps: float = 1e-6, |
|
resnet_time_scale_shift: str = "default", |
|
resnet_act_fn: str = "swish", |
|
resnet_groups: int = 32, |
|
resnet_pre_norm: bool = True, |
|
num_attention_heads: int = 1, |
|
cross_attention_dim: int = 1280, |
|
output_scale_factor: float = 1.0, |
|
downsample_padding: int = 1, |
|
add_downsample: bool = True, |
|
use_linear_projection: bool = False, |
|
only_cross_attention: bool = False, |
|
upcast_attention: bool = False, |
|
last_block: bool = False, |
|
mix_block_in_forward: bool = True, |
|
): |
|
super().__init__() |
|
|
|
self.last_block = last_block |
|
self.mix_block_in_forward = mix_block_in_forward |
|
self.has_cross_attention = True |
|
self.num_attention_heads = num_attention_heads |
|
|
|
modules = [] |
|
|
|
add_resnets = [True] * num_resnets |
|
add_cross_attentions = [True] * num_attentions |
|
for i, (add_resnet, add_cross_attention) in enumerate(itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)): |
|
in_channels = in_channels if i == 0 else out_channels |
|
if add_resnet: |
|
modules.append( |
|
ResnetBlock2D( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
temb_channels=temb_channels, |
|
eps=resnet_eps, |
|
groups=resnet_groups, |
|
dropout=dropout, |
|
time_embedding_norm=resnet_time_scale_shift, |
|
non_linearity=resnet_act_fn, |
|
output_scale_factor=output_scale_factor, |
|
pre_norm=resnet_pre_norm, |
|
) |
|
) |
|
if add_cross_attention: |
|
modules.append( |
|
FlexibleTransformer2DModel( |
|
num_attention_heads=num_attention_heads, |
|
attention_head_dim=out_channels // num_attention_heads, |
|
in_channels=out_channels, |
|
num_layers=transformer_layers_per_block, |
|
cross_attention_dim=cross_attention_dim, |
|
norm_num_groups=resnet_groups, |
|
use_linear_projection=use_linear_projection, |
|
only_cross_attention=only_cross_attention, |
|
upcast_attention=upcast_attention, |
|
) |
|
) |
|
|
|
if not mix_block_in_forward: |
|
modules = sorted(modules, key=custom_sort_order) |
|
|
|
self.modules_list = nn.ModuleList(modules) |
|
|
|
if add_downsample: |
|
self.downsamplers = nn.ModuleList([Downsample2D(out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op")]) |
|
else: |
|
self.downsamplers = None |
|
|
|
self.gradient_checkpointing = False |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.FloatTensor, |
|
temb: Optional[torch.FloatTensor] = None, |
|
encoder_hidden_states: Optional[torch.FloatTensor] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
encoder_attention_mask: Optional[torch.FloatTensor] = None, |
|
): |
|
output_states = () |
|
|
|
for module in self.modules_list: |
|
if isinstance(module, ResnetBlock2D): |
|
hidden_states = module(hidden_states, temb) |
|
elif isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)): |
|
hidden_states = module( |
|
hidden_states, |
|
encoder_hidden_states=encoder_hidden_states, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
attention_mask=attention_mask, |
|
encoder_attention_mask=encoder_attention_mask, |
|
return_dict=False, |
|
)[0] |
|
else: |
|
raise ValueError(f"Got an unexpected module in modules list! {type(module)}") |
|
if isinstance(module, ResnetBlock2D): |
|
output_states = output_states + (hidden_states,) |
|
|
|
if self.downsamplers is not None: |
|
for downsampler in self.downsamplers: |
|
hidden_states = downsampler(hidden_states) |
|
|
|
if not self.last_block: |
|
output_states = output_states + (hidden_states,) |
|
|
|
return hidden_states, output_states |
|
|
|
|
|
class FlexibleCrossAttnUpBlock2D(nn.Module): |
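    """
    Up block built from a configurable number of `ResnetBlock2D` and `FlexibleTransformer2DModel` modules.
    Each resnet consumes one skip connection from `res_hidden_states_tuple`; when `mix_block_in_forward` is
    False the resnets are grouped before the attention blocks.
    """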
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
prev_output_channel: int, |
|
temb_channels: int, |
|
dropout: float = 0.0, |
|
num_resnets: int = 1, |
|
num_attentions: int = 1, |
|
transformer_layers_per_block: int = 1, |
|
resnet_eps: float = 1e-6, |
|
resnet_time_scale_shift: str = "default", |
|
resnet_act_fn: str = "swish", |
|
resnet_groups: int = 32, |
|
resnet_pre_norm: bool = True, |
|
num_attention_heads: int = 1, |
|
cross_attention_dim: int = 1280, |
|
output_scale_factor: float = 1.0, |
|
add_upsample: bool = True, |
|
use_linear_projection: bool = False, |
|
only_cross_attention: bool = False, |
|
upcast_attention: bool = False, |
|
mix_block_in_forward: bool = True, |
|
): |
|
super().__init__() |
|
modules = [] |
|
|
|
|
|
        self.resnets = []  # bookkeeping only: tracks the resnets added below; the actual modules live in `modules_list`
|
|
|
self.has_cross_attention = True |
|
self.num_attention_heads = num_attention_heads |
|
|
|
add_resnets = [True] * num_resnets |
|
add_cross_attentions = [True] * num_attentions |
|
for i, (add_resnet, add_cross_attention) in enumerate(itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)): |
|
res_skip_channels = in_channels if (i == len(add_resnets) - 1) else out_channels |
|
resnet_in_channels = prev_output_channel if i == 0 else out_channels |
|
|
|
if add_resnet: |
|
self.resnets += [True] |
|
modules.append( |
|
ResnetBlock2D( |
|
in_channels=resnet_in_channels + res_skip_channels, |
|
out_channels=out_channels, |
|
temb_channels=temb_channels, |
|
eps=resnet_eps, |
|
groups=resnet_groups, |
|
dropout=dropout, |
|
time_embedding_norm=resnet_time_scale_shift, |
|
non_linearity=resnet_act_fn, |
|
output_scale_factor=output_scale_factor, |
|
pre_norm=resnet_pre_norm, |
|
) |
|
) |
|
if add_cross_attention: |
|
modules.append( |
|
FlexibleTransformer2DModel( |
|
num_attention_heads, |
|
out_channels // num_attention_heads, |
|
in_channels=out_channels, |
|
num_layers=transformer_layers_per_block, |
|
cross_attention_dim=cross_attention_dim, |
|
norm_num_groups=resnet_groups, |
|
use_linear_projection=use_linear_projection, |
|
only_cross_attention=only_cross_attention, |
|
upcast_attention=upcast_attention, |
|
) |
|
) |
|
|
|
if not mix_block_in_forward: |
|
modules = sorted(modules, key=custom_sort_order) |
|
|
|
self.modules_list = nn.ModuleList(modules) |
|
|
|
self.upsamplers = None |
|
if add_upsample: |
|
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) |
|
|
|
self.gradient_checkpointing = False |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.FloatTensor, |
|
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], |
|
temb: Optional[torch.FloatTensor] = None, |
|
encoder_hidden_states: Optional[torch.FloatTensor] = None, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
upsample_size: Optional[int] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
encoder_attention_mask: Optional[torch.FloatTensor] = None, |
|
): |
|
|
|
for module in self.modules_list: |
|
if isinstance(module, ResnetBlock2D): |
|
res_hidden_states = res_hidden_states_tuple[-1] |
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1] |
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) |
|
hidden_states = module(hidden_states, temb) |
|
if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)): |
|
hidden_states = module( |
|
hidden_states, |
|
encoder_hidden_states=encoder_hidden_states, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
attention_mask=attention_mask, |
|
encoder_attention_mask=encoder_attention_mask, |
|
return_dict=False, |
|
)[0] |
|
|
|
if self.upsamplers is not None: |
|
for upsampler in self.upsamplers: |
|
hidden_states = upsampler(hidden_states, upsample_size) |
|
|
|
return hidden_states |
|
|
|
|
|
class FlexibleUNetMidBlock2DCrossAttn(nn.Module): |
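    """
    Mid block with one leading `ResnetBlock2D` followed by a configurable number of attention/resnet modules,
    optionally ending with an upsampler.
    """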
|
def __init__( |
|
self, |
|
in_channels: int, |
|
temb_channels: int, |
|
dropout: float = 0.0, |
|
num_resnets: int = 1, |
|
num_attentions: int = 1, |
|
transformer_layers_per_block: int = 1, |
|
resnet_eps: float = 1e-6, |
|
resnet_time_scale_shift: str = "default", |
|
resnet_act_fn: str = "swish", |
|
resnet_groups: int = 32, |
|
resnet_pre_norm: bool = True, |
|
num_attention_heads: int = 1, |
|
output_scale_factor: float = 1.0, |
|
cross_attention_dim: int = 1280, |
|
use_linear_projection: bool = False, |
|
upcast_attention: bool = False, |
|
mix_block_in_forward: bool = True, |
|
add_upsample: bool = True, |
|
): |
|
super().__init__() |
|
|
|
self.has_cross_attention = True |
|
self.num_attention_heads = num_attention_heads |
|
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) |
|
|
|
modules = [ |
|
ResnetBlock2D( |
|
in_channels=in_channels, |
|
out_channels=in_channels, |
|
temb_channels=temb_channels, |
|
eps=resnet_eps, |
|
groups=resnet_groups, |
|
dropout=dropout, |
|
time_embedding_norm=resnet_time_scale_shift, |
|
non_linearity=resnet_act_fn, |
|
output_scale_factor=output_scale_factor, |
|
pre_norm=resnet_pre_norm, |
|
) |
|
] |
|
|
|
add_resnets = [True] * num_resnets |
|
add_cross_attentions = [True] * num_attentions |
|
for i, (add_resnet, add_cross_attention) in enumerate(itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)): |
|
if add_cross_attention: |
|
modules.append( |
|
FlexibleTransformer2DModel( |
|
num_attention_heads, |
|
in_channels // num_attention_heads, |
|
in_channels=in_channels, |
|
num_layers=transformer_layers_per_block, |
|
cross_attention_dim=cross_attention_dim, |
|
norm_num_groups=resnet_groups, |
|
use_linear_projection=use_linear_projection, |
|
upcast_attention=upcast_attention, |
|
) |
|
) |
|
|
|
if add_resnet: |
|
modules.append( |
|
ResnetBlock2D( |
|
in_channels=in_channels, |
|
out_channels=in_channels, |
|
temb_channels=temb_channels, |
|
eps=resnet_eps, |
|
groups=resnet_groups, |
|
dropout=dropout, |
|
time_embedding_norm=resnet_time_scale_shift, |
|
non_linearity=resnet_act_fn, |
|
output_scale_factor=output_scale_factor, |
|
pre_norm=resnet_pre_norm, |
|
) |
|
) |
|
if not mix_block_in_forward: |
|
modules = sorted(modules, key=custom_sort_order) |
|
|
|
self.modules_list = nn.ModuleList(modules) |
|
|
|
self.upsamplers = nn.ModuleList([nn.Identity()]) |
|
if add_upsample: |
|
self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)]) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.FloatTensor, |
|
temb: Optional[torch.FloatTensor] = None, |
|
encoder_hidden_states: Optional[torch.FloatTensor] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
encoder_attention_mask: Optional[torch.FloatTensor] = None, |
|
) -> torch.FloatTensor: |
|
        # The first module is always a ResnetBlock2D: apply it once, then run the remaining modules in order
        hidden_states = self.modules_list[0](hidden_states, temb)

        for module in self.modules_list[1:]:
|
if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)): |
|
hidden_states = module( |
|
hidden_states, |
|
encoder_hidden_states=encoder_hidden_states, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
attention_mask=attention_mask, |
|
encoder_attention_mask=encoder_attention_mask, |
|
return_dict=False, |
|
)[0] |
|
elif isinstance(module, ResnetBlock2D): |
|
hidden_states = module(hidden_states, temb) |
|
|
|
for upsampler in self.upsamplers: |
|
hidden_states = upsampler(hidden_states) |
|
|
|
return hidden_states |
|
|
|
|
|
class FlexibleTransformer2DModel(ModelMixin, ConfigMixin): |
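    """
    A slimmed-down `Transformer2DModel` that only supports continuous (channel-first feature map) inputs:
    GroupNorm, a linear/conv projection in, a stack of `BasicTransformerBlock`s, a projection out and a
    residual connection.
    """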
|
@register_to_config |
|
def __init__( |
|
self, |
|
num_attention_heads: int = 16, |
|
attention_head_dim: int = 88, |
|
in_channels: Optional[int] = None, |
|
out_channels: Optional[int] = None, |
|
num_layers: int = 1, |
|
dropout: float = 0.0, |
|
norm_num_groups: int = 32, |
|
cross_attention_dim: Optional[int] = None, |
|
attention_bias: bool = False, |
|
activation_fn: str = "geglu", |
|
num_embeds_ada_norm: Optional[int] = None, |
|
only_cross_attention: bool = False, |
|
use_linear_projection: bool = False, |
|
upcast_attention: bool = False, |
|
norm_type: str = "layer_norm", |
|
norm_elementwise_affine: bool = True, |
|
): |
|
super().__init__() |
|
self.num_attention_heads = num_attention_heads |
|
self.attention_head_dim = attention_head_dim |
|
self.in_channels = in_channels |
|
inner_dim = num_attention_heads * attention_head_dim |
|
|
|
|
|
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) |
|
self.use_linear_projection = use_linear_projection |
|
if self.use_linear_projection: |
|
self.proj_in = nn.Linear(in_channels, inner_dim) |
|
else: |
|
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) |
|
|
|
|
|
self.transformer_blocks = nn.ModuleList( |
|
[ |
|
BasicTransformerBlock( |
|
inner_dim, |
|
num_attention_heads, |
|
attention_head_dim, |
|
dropout=dropout, |
|
cross_attention_dim=cross_attention_dim, |
|
activation_fn=activation_fn, |
|
num_embeds_ada_norm=num_embeds_ada_norm, |
|
attention_bias=attention_bias, |
|
only_cross_attention=only_cross_attention, |
|
upcast_attention=upcast_attention, |
|
norm_type=norm_type, |
|
norm_elementwise_affine=norm_elementwise_affine, |
|
) |
|
for _ in range(num_layers) |
|
] |
|
) |
|
|
|
|
|
self.out_channels = in_channels if out_channels is None else out_channels |
|
if self.use_linear_projection: |
|
self.proj_out = nn.Linear(inner_dim, in_channels) |
|
else: |
|
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
encoder_hidden_states: Optional[torch.Tensor] = None, |
|
timestep: Optional[torch.LongTensor] = None, |
|
class_labels: Optional[torch.LongTensor] = None, |
|
cross_attention_kwargs: Dict[str, Any] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
encoder_attention_mask: Optional[torch.Tensor] = None, |
|
return_dict: bool = False, |
|
): |
|
|
|
batch, _, height, width = hidden_states.shape |
|
residual = hidden_states |
|
|
|
hidden_states = self.norm(hidden_states) |
|
if not self.use_linear_projection: |
|
hidden_states = self.proj_in(hidden_states) |
|
inner_dim = hidden_states.shape[1] |
|
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) |
|
else: |
|
inner_dim = hidden_states.shape[1] |
|
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) |
|
hidden_states = self.proj_in(hidden_states) |
|
|
|
|
|
for block in self.transformer_blocks: |
|
hidden_states = block( |
|
hidden_states, |
|
attention_mask=attention_mask, |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=encoder_attention_mask, |
|
timestep=timestep, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
class_labels=class_labels, |
|
) |
|
|
|
|
|
if not self.use_linear_projection: |
|
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() |
|
hidden_states = self.proj_out(hidden_states) |
|
else: |
|
hidden_states = self.proj_out(hidden_states) |
|
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() |
|
|
|
output = hidden_states + residual |
|
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
|
|
|
|
|
class DeciDiffusionPipeline(StableDiffusionPipeline): |
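    """
    A `StableDiffusionPipeline` that discards the passed-in `unet` and `scheduler` and replaces them with
    `FlexibleUNet2DConditionModel` and `SqueezedDPMSolverMultistepScheduler` (using the default squeeze mode).
    """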
|
deci_default_squeeze_mode = "10,6" |
|
deci_default_number_of_iterations = 16 |
|
deci_default_guidance_rescale = 0.7 |
|
|
|
def __init__( |
|
self, |
|
vae: AutoencoderKL, |
|
text_encoder: CLIPTextModel, |
|
tokenizer: CLIPTokenizer, |
|
unet: UNet2DConditionModel, |
|
scheduler: KarrasDiffusionSchedulers, |
|
safety_checker: StableDiffusionSafetyChecker, |
|
feature_extractor: CLIPImageProcessor, |
|
requires_safety_checker: bool = True, |
|
): |
|
|
|
        # Discard the passed-in UNet and replace it with DeciDiffusion's flexible UNet
        del unet
        unet = FlexibleUNet2DConditionModel()
|
|
|
|
|
        # Discard the passed-in scheduler and replace it with the squeezed DPM-Solver++ scheduler
        del scheduler
        scheduler = SqueezedDPMSolverMultistepScheduler(squeeze_mode=self.deci_default_squeeze_mode)
|
|
|
super().__init__( |
|
vae=vae, |
|
text_encoder=text_encoder, |
|
tokenizer=tokenizer, |
|
unet=unet, |
|
scheduler=scheduler, |
|
safety_checker=safety_checker, |
|
feature_extractor=feature_extractor, |
|
requires_safety_checker=requires_safety_checker, |
|
) |
|
|
|
self.register_modules( |
|
vae=vae, |
|
text_encoder=text_encoder, |
|
tokenizer=tokenizer, |
|
unet=unet, |
|
scheduler=scheduler, |
|
safety_checker=safety_checker, |
|
feature_extractor=feature_extractor, |
|
) |
|
|
|
@torch.no_grad() |
|
def __call__( |
|
self, |
|
prompt: Union[str, List[str]] = None, |
|
height: Optional[int] = None, |
|
width: Optional[int] = None, |
|
num_inference_steps: int = 16, |
|
guidance_scale: float = 7.5, |
|
negative_prompt: Optional[Union[str, List[str]]] = None, |
|
num_images_per_prompt: Optional[int] = 1, |
|
eta: float = 0.0, |
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
latents: Optional[torch.FloatTensor] = None, |
|
prompt_embeds: Optional[torch.FloatTensor] = None, |
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None, |
|
output_type: Optional[str] = "pil", |
|
return_dict: bool = True, |
|
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
|
callback_steps: int = 1, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
guidance_rescale: float = 0.7, |
|
): |
|
r""" |
|
The call function to the pipeline for generation. |
|
|
|
Args: |
|
prompt (`str` or `List[str]`, *optional*): |
|
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. |
|
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): |
|
The height in pixels of the generated image. |
|
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): |
|
The width in pixels of the generated image. |
|
            num_inference_steps (`int`, *optional*, defaults to 16):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
|
guidance_scale (`float`, *optional*, defaults to 7.5): |
|
A higher guidance scale value encourages the model to generate images closely linked to the text |
|
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. |
|
negative_prompt (`str` or `List[str]`, *optional*): |
|
The prompt or prompts to guide what to not include in image generation. If not defined, you need to |
|
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). |
|
num_images_per_prompt (`int`, *optional*, defaults to 1): |
|
The number of images to generate per prompt. |
|
eta (`float`, *optional*, defaults to 0.0): |
|
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies |
|
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. |
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): |
|
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make |
|
generation deterministic. |
|
latents (`torch.FloatTensor`, *optional*): |
|
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image |
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents |
|
tensor is generated by sampling using the supplied random `generator`. |
|
prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not |
|
provided, text embeddings are generated from the `prompt` input argument. |
|
negative_prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If |
|
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. |
|
output_type (`str`, *optional*, defaults to `"pil"`): |
|
The output format of the generated image. Choose between `PIL.Image` or `np.array`. |
|
return_dict (`bool`, *optional*, defaults to `True`): |
|
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a |
|
plain tuple. |
|
            callback (`Callable`, *optional*):
                A function that is called every `callback_steps` steps during inference. The function is called with
                the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
callback_steps (`int`, *optional*, defaults to 1): |
|
The frequency at which the `callback` function is called. If not specified, the callback is called at |
|
every step. |
|
cross_attention_kwargs (`dict`, *optional*): |
|
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in |
|
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). |
|
guidance_rescale (`float`, *optional*, defaults to 0.7): |
|
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are |
|
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when |
|
using zero terminal SNR. |
|
|
|
Examples: |
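            A minimal usage sketch (illustrative; the checkpoint id below is a placeholder for whichever
            DeciDiffusion checkpoint and custom pipeline you are actually loading):

            ```py
            >>> import torch
            >>> from diffusers import StableDiffusionPipeline

            >>> checkpoint = "Deci/DeciDiffusion-v1-0"  # hypothetical checkpoint id
            >>> pipeline = StableDiffusionPipeline.from_pretrained(checkpoint, custom_pipeline=checkpoint, torch_dtype=torch.float16)
            >>> pipeline = pipeline.to("cuda")
            >>> image = pipeline("A photo of an astronaut riding a horse on Mars").images[0]
            ```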
|
|
|
Returns: |
|
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: |
|
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, |
|
otherwise a `tuple` is returned where the first element is a list with the generated images and the |
|
second element is a list of `bool`s indicating whether the corresponding generated image contains |
|
"not-safe-for-work" (nsfw) content. |
|
""" |
|
|
|
height = height or self.unet.config.sample_size * self.vae_scale_factor |
|
width = width or self.unet.config.sample_size * self.vae_scale_factor |
|
|
|
|
|
self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) |
|
|
|
|
|
if prompt is not None and isinstance(prompt, str): |
|
batch_size = 1 |
|
elif prompt is not None and isinstance(prompt, list): |
|
batch_size = len(prompt) |
|
else: |
|
batch_size = prompt_embeds.shape[0] |
|
|
|
device = self._execution_device |
|
|
|
|
|
|
|
do_classifier_free_guidance = guidance_scale > 1.0 |
|
|
|
|
|
text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None |
|
prompt_embeds, negative_prompt_embeds = self.encode_prompt( |
|
prompt, |
|
device, |
|
num_images_per_prompt, |
|
do_classifier_free_guidance, |
|
negative_prompt, |
|
prompt_embeds=prompt_embeds, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
lora_scale=text_encoder_lora_scale, |
|
) |
|
|
|
|
|
|
|
if do_classifier_free_guidance: |
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) |
|
|
|
|
|
        # Note: when the scheduler uses a squeezer, the actual number of timesteps is deduced by the squeezer
        # and may be smaller than `num_inference_steps`
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
|
|
|
|
|
num_channels_latents = self.unet.config.in_channels |
|
latents = self.prepare_latents( |
|
batch_size * num_images_per_prompt, |
|
num_channels_latents, |
|
height, |
|
width, |
|
prompt_embeds.dtype, |
|
device, |
|
generator, |
|
latents, |
|
) |
|
|
|
|
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
|
|
|
|
|
        # `len(timesteps)` (rather than `num_inference_steps`) drives the denoising loop and the progress bar below
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
with self.progress_bar(total=len(timesteps)) as progress_bar: |
|
for i, t in enumerate(timesteps): |
|
|
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
|
|
|
|
|
noise_pred = self.unet( |
|
latent_model_input, |
|
t, |
|
encoder_hidden_states=prompt_embeds, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
return_dict=False, |
|
)[0] |
|
|
|
|
|
if do_classifier_free_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
|
                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on Section 3.4 of https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
|
|
|
|
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] |
|
|
|
|
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
|
progress_bar.update() |
|
if callback is not None and i % callback_steps == 0: |
|
callback(i, t, latents) |
|
|
|
if not output_type == "latent": |
|
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] |
|
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) |
|
else: |
|
image = latents |
|
has_nsfw_concept = None |
|
|
|
if has_nsfw_concept is None: |
|
do_denormalize = [True] * image.shape[0] |
|
else: |
|
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] |
|
|
|
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) |
|
|
|
|
|
self.maybe_free_model_hooks() |
|
|
|
if not return_dict: |
|
return (image, has_nsfw_concept) |
|
|
|
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) |
|
|