from dataclasses import dataclass |
|
from math import gcd |
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.utils.checkpoint |
|
from torch import Tensor, nn |
|
|
|
from ..configuration_utils import ConfigMixin, register_to_config |
|
from ..utils import BaseOutput, is_torch_version, logging |
|
from ..utils.torch_utils import apply_freeu |
|
from .attention_processor import ( |
|
ADDED_KV_ATTENTION_PROCESSORS, |
|
CROSS_ATTENTION_PROCESSORS, |
|
Attention, |
|
AttentionProcessor, |
|
AttnAddedKVProcessor, |
|
AttnProcessor, |
|
) |
|
from .controlnet import ControlNetConditioningEmbedding |
|
from .embeddings import TimestepEmbedding, Timesteps |
|
from .modeling_utils import ModelMixin |
|
from .unets.unet_2d_blocks import ( |
|
CrossAttnDownBlock2D, |
|
CrossAttnUpBlock2D, |
|
Downsample2D, |
|
ResnetBlock2D, |
|
Transformer2DModel, |
|
UNetMidBlock2DCrossAttn, |
|
Upsample2D, |
|
) |
|
from .unets.unet_2d_condition import UNet2DConditionModel |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
@dataclass |
|
class ControlNetXSOutput(BaseOutput): |
|
""" |
|
The output of [`UNetControlNetXSModel`]. |
|
|
|
Args: |
|
sample (`Tensor` of shape `(batch_size, num_channels, height, width)`): |
|
The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base |
|
model output, but is already the final output. |
|
""" |
|
|
|
sample: Tensor = None |
|
|
|
|
|
class DownBlockControlNetXSAdapter(nn.Module): |
|
"""Components that together with corresponding components from the base model will form a |
|
`ControlNetXSCrossAttnDownBlock2D`""" |
|
|
|
def __init__( |
|
self, |
|
resnets: nn.ModuleList, |
|
base_to_ctrl: nn.ModuleList, |
|
ctrl_to_base: nn.ModuleList, |
|
attentions: Optional[nn.ModuleList] = None, |
|
downsampler: Optional[nn.Conv2d] = None, |
|
): |
|
super().__init__() |
|
self.resnets = resnets |
|
self.base_to_ctrl = base_to_ctrl |
|
self.ctrl_to_base = ctrl_to_base |
|
self.attentions = attentions |
|
self.downsamplers = downsampler |
|
|
|
|
|
class MidBlockControlNetXSAdapter(nn.Module): |
|
"""Components that together with corresponding components from the base model will form a |
|
`ControlNetXSCrossAttnMidBlock2D`""" |
|
|
|
def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl: nn.ModuleList, ctrl_to_base: nn.ModuleList): |
|
super().__init__() |
|
self.midblock = midblock |
|
self.base_to_ctrl = base_to_ctrl |
|
self.ctrl_to_base = ctrl_to_base |
|
|
|
|
|
class UpBlockControlNetXSAdapter(nn.Module): |
|
"""Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnUpBlock2D`""" |
|
|
|
def __init__(self, ctrl_to_base: nn.ModuleList): |
|
super().__init__() |
|
self.ctrl_to_base = ctrl_to_base |
|
|
|
|
|
def get_down_block_adapter( |
|
base_in_channels: int, |
|
base_out_channels: int, |
|
ctrl_in_channels: int, |
|
ctrl_out_channels: int, |
|
temb_channels: int, |
|
max_norm_num_groups: Optional[int] = 32, |
|
    has_crossattn: bool = True,
|
transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, |
|
num_attention_heads: Optional[int] = 1, |
|
cross_attention_dim: Optional[int] = 1024, |
|
add_downsample: bool = True, |
|
upcast_attention: Optional[bool] = False, |
|
use_linear_projection: Optional[bool] = True, |
|
): |
|
num_layers = 2 |
|
|
|
resnets = [] |
|
attentions = [] |
|
ctrl_to_base = [] |
|
base_to_ctrl = [] |
|
|
|
if isinstance(transformer_layers_per_block, int): |
|
transformer_layers_per_block = [transformer_layers_per_block] * num_layers |
|
|
|
for i in range(num_layers): |
|
base_in_channels = base_in_channels if i == 0 else base_out_channels |
|
        ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels

        # Before the resnet/attention application, information is concatted from base to control.
        # Concat doesn't require a change in the number of channels
base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels)) |
|
|
|
resnets.append( |
|
ResnetBlock2D( |
|
in_channels=ctrl_in_channels + base_in_channels, |
|
out_channels=ctrl_out_channels, |
|
temb_channels=temb_channels, |
|
groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups), |
|
groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), |
|
eps=1e-5, |
|
) |
|
) |
|
|
|
if has_crossattn: |
|
attentions.append( |
|
Transformer2DModel( |
|
num_attention_heads, |
|
ctrl_out_channels // num_attention_heads, |
|
in_channels=ctrl_out_channels, |
|
num_layers=transformer_layers_per_block[i], |
|
cross_attention_dim=cross_attention_dim, |
|
use_linear_projection=use_linear_projection, |
|
upcast_attention=upcast_attention, |
|
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups), |
|
) |
|
            )

        # After the resnet/attention application, information is added from control to base
        # Addition requires a change in the number of channels
ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) |
|
|
|
    if add_downsample:
        # Before the downsampler application, information is concatted from base to control
        # Concat doesn't require a change in the number of channels
base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) |
|
|
|
downsamplers = Downsample2D( |
|
ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op" |
|
        )

        # After the downsampler application, information is added from control to base
        # Addition requires a change in the number of channels
ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) |
|
else: |
|
downsamplers = None |
|
|
|
down_block_components = DownBlockControlNetXSAdapter( |
|
resnets=nn.ModuleList(resnets), |
|
base_to_ctrl=nn.ModuleList(base_to_ctrl), |
|
ctrl_to_base=nn.ModuleList(ctrl_to_base), |
|
) |
|
|
|
if has_crossattn: |
|
down_block_components.attentions = nn.ModuleList(attentions) |
|
if downsamplers is not None: |
|
down_block_components.downsamplers = downsamplers |
|
|
|
return down_block_components |
|
|
|
|
|
def get_mid_block_adapter( |
|
base_channels: int, |
|
ctrl_channels: int, |
|
temb_channels: Optional[int] = None, |
|
max_norm_num_groups: Optional[int] = 32, |
|
transformer_layers_per_block: int = 1, |
|
num_attention_heads: Optional[int] = 1, |
|
cross_attention_dim: Optional[int] = 1024, |
|
upcast_attention: bool = False, |
|
use_linear_projection: bool = True, |
|
):
    # Before the midblock application, information is concatted from base to control.
    # Concat doesn't require a change in the number of channels
base_to_ctrl = make_zero_conv(base_channels, base_channels) |
|
|
|
midblock = UNetMidBlock2DCrossAttn( |
|
transformer_layers_per_block=transformer_layers_per_block, |
|
in_channels=ctrl_channels + base_channels, |
|
out_channels=ctrl_channels, |
|
temb_channels=temb_channels, |
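        # note: the norm group count must divide both the concatenated input channels (ctrl + base)
        # and the output channels (ctrl), hence the gcd below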
|
|
|
resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups), |
|
cross_attention_dim=cross_attention_dim, |
|
num_attention_heads=num_attention_heads, |
|
use_linear_projection=use_linear_projection, |
|
upcast_attention=upcast_attention, |
|
    )

    # After the midblock application, information is added from control to base
    # Addition requires a change in the number of channels
ctrl_to_base = make_zero_conv(ctrl_channels, base_channels) |
|
|
|
return MidBlockControlNetXSAdapter(base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base) |
|
|
|
|
|
def get_up_block_adapter( |
|
out_channels: int, |
|
prev_output_channel: int, |
|
ctrl_skip_channels: List[int], |
|
): |
|
ctrl_to_base = [] |
|
num_layers = 3 |
|
for i in range(num_layers): |
|
resnet_in_channels = prev_output_channel if i == 0 else out_channels |
|
ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels)) |
|
|
|
return UpBlockControlNetXSAdapter(ctrl_to_base=nn.ModuleList(ctrl_to_base)) |
|
|
|
|
|
class ControlNetXSAdapter(ModelMixin, ConfigMixin): |
|
r""" |
|
A `ControlNetXSAdapter` model. To use it, pass it into a `UNetControlNetXSModel` (together with a |
|
`UNet2DConditionModel` base model). |
|
|
|
    This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for its generic
|
methods implemented for all models (such as downloading or saving). |
|
|
|
    Like `UNetControlNetXSModel`, `ControlNetXSAdapter` is compatible with StableDiffusion and StableDiffusion-XL. Its
|
default parameters are compatible with StableDiffusion. |
|
|
|
Parameters: |
|
conditioning_channels (`int`, defaults to 3): |
|
Number of channels of conditioning input (e.g. an image) |
|
conditioning_channel_order (`str`, defaults to `"rgb"`): |
|
            The channel order of the conditioning image. Will be converted to `rgb` if it's `bgr`.
|
conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`): |
|
The tuple of output channels for each block in the `controlnet_cond_embedding` layer. |
|
time_embedding_mix (`float`, defaults to 1.0): |
|
            If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time
|
embedding is used. Otherwise, both are combined. |
|
learn_time_embedding (`bool`, defaults to `False`): |
|
Whether a time embedding should be learned. If yes, `UNetControlNetXSModel` will combine the time |
|
embeddings of the base model and the control adapter. If no, `UNetControlNetXSModel` will use the base |
|
model's time embedding. |
|
num_attention_heads (`list[int]`, defaults to `[4]`): |
|
The number of attention heads. |
|
block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`): |
|
The tuple of output channels for each block. |
|
base_block_out_channels (`list[int]`, defaults to `[320, 640, 1280, 1280]`): |
|
The tuple of output channels for each block in the base unet. |
|
cross_attention_dim (`int`, defaults to 1024): |
|
The dimension of the cross attention features. |
|
down_block_types (`list[str]`, defaults to `["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]`): |
|
The tuple of downsample blocks to use. |
|
sample_size (`int`, defaults to 96): |
|
Height and width of input/output sample. |
|
transformer_layers_per_block (`Union[int, Tuple[int]]`, defaults to 1): |
|
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for |
|
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. |
|
upcast_attention (`bool`, defaults to `True`): |
|
Whether the attention computation should always be upcasted. |
|
max_norm_num_groups (`int`, defaults to 32): |
|
            Maximum number of groups in group normalization. The actual number will be the largest divisor of the
            respective channels that is <= max_norm_num_groups.
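
    Example:

    A minimal usage sketch (the checkpoint name is illustrative; any `UNet2DConditionModel` works):

    ```py
    >>> from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel

    >>> unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
    >>> adapter = ControlNetXSAdapter.from_unet(unet, size_ratio=0.1)
    >>> model = UNetControlNetXSModel.from_unet(unet, adapter)
    ```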
|
""" |
|
|
|
@register_to_config |
|
def __init__( |
|
self, |
|
conditioning_channels: int = 3, |
|
conditioning_channel_order: str = "rgb", |
|
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), |
|
time_embedding_mix: float = 1.0, |
|
learn_time_embedding: bool = False, |
|
num_attention_heads: Union[int, Tuple[int]] = 4, |
|
block_out_channels: Tuple[int] = (4, 8, 16, 16), |
|
base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280), |
|
cross_attention_dim: int = 1024, |
|
down_block_types: Tuple[str] = ( |
|
"CrossAttnDownBlock2D", |
|
"CrossAttnDownBlock2D", |
|
"CrossAttnDownBlock2D", |
|
"DownBlock2D", |
|
), |
|
sample_size: Optional[int] = 96, |
|
transformer_layers_per_block: Union[int, Tuple[int]] = 1, |
|
upcast_attention: bool = True, |
|
max_norm_num_groups: int = 32, |
|
use_linear_projection: bool = True, |
|
): |
|
super().__init__() |
|
|
|
time_embedding_input_dim = base_block_out_channels[0] |
|
time_embedding_dim = base_block_out_channels[0] * 4 |
|
|
|
|
|
if conditioning_channel_order not in ["rgb", "bgr"]: |
|
raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}") |
|
|
|
if len(block_out_channels) != len(down_block_types): |
|
raise ValueError( |
|
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." |
|
) |
|
|
|
if not isinstance(transformer_layers_per_block, (list, tuple)): |
|
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) |
|
if not isinstance(cross_attention_dim, (list, tuple)): |
|
cross_attention_dim = [cross_attention_dim] * len(down_block_types) |
|
|
|
if not isinstance(num_attention_heads, (list, tuple)): |
|
num_attention_heads = [num_attention_heads] * len(down_block_types) |
|
|
|
if len(num_attention_heads) != len(down_block_types): |
|
raise ValueError( |
|
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." |
|
) |
|
|
|
|
|
self.controlnet_cond_embedding = ControlNetConditioningEmbedding( |
|
conditioning_embedding_channels=block_out_channels[0], |
|
block_out_channels=conditioning_embedding_out_channels, |
|
conditioning_channels=conditioning_channels, |
|
) |
|
|
|
|
|
if learn_time_embedding: |
|
self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim) |
|
else: |
|
self.time_embedding = None |
|
|
|
self.down_blocks = nn.ModuleList([]) |
|
self.up_connections = nn.ModuleList([]) |
|
|
|
|
|
self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) |
|
self.control_to_base_for_conv_in = make_zero_conv(block_out_channels[0], base_block_out_channels[0]) |
|
|
|
|
|
base_out_channels = base_block_out_channels[0] |
|
ctrl_out_channels = block_out_channels[0] |
|
for i, down_block_type in enumerate(down_block_types): |
|
base_in_channels = base_out_channels |
|
base_out_channels = base_block_out_channels[i] |
|
ctrl_in_channels = ctrl_out_channels |
|
ctrl_out_channels = block_out_channels[i] |
|
has_crossattn = "CrossAttn" in down_block_type |
|
is_final_block = i == len(down_block_types) - 1 |
|
|
|
self.down_blocks.append( |
|
get_down_block_adapter( |
|
base_in_channels=base_in_channels, |
|
base_out_channels=base_out_channels, |
|
ctrl_in_channels=ctrl_in_channels, |
|
ctrl_out_channels=ctrl_out_channels, |
|
temb_channels=time_embedding_dim, |
|
max_norm_num_groups=max_norm_num_groups, |
|
has_crossattn=has_crossattn, |
|
transformer_layers_per_block=transformer_layers_per_block[i], |
|
num_attention_heads=num_attention_heads[i], |
|
cross_attention_dim=cross_attention_dim[i], |
|
add_downsample=not is_final_block, |
|
upcast_attention=upcast_attention, |
|
use_linear_projection=use_linear_projection, |
|
) |
|
) |
|
|
|
|
|
self.mid_block = get_mid_block_adapter( |
|
base_channels=base_block_out_channels[-1], |
|
ctrl_channels=block_out_channels[-1], |
|
temb_channels=time_embedding_dim, |
|
transformer_layers_per_block=transformer_layers_per_block[-1], |
|
num_attention_heads=num_attention_heads[-1], |
|
cross_attention_dim=cross_attention_dim[-1], |
|
upcast_attention=upcast_attention, |
|
use_linear_projection=use_linear_projection, |
|
) |
|
|
|
|
|
|
|
ctrl_skip_channels = [block_out_channels[0]] |
|
for i, out_channels in enumerate(block_out_channels): |
|
            number_of_subblocks = (
                3 if i < len(block_out_channels) - 1 else 2
            )  # every block has 3 subblocks, except the last one, which has 2 as it has no downsampler
|
ctrl_skip_channels.extend([out_channels] * number_of_subblocks) |
|
|
|
reversed_base_block_out_channels = list(reversed(base_block_out_channels)) |
|
|
|
base_out_channels = reversed_base_block_out_channels[0] |
|
for i in range(len(down_block_types)): |
|
prev_base_output_channel = base_out_channels |
|
base_out_channels = reversed_base_block_out_channels[i] |
|
ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] |
|
|
|
self.up_connections.append( |
|
get_up_block_adapter( |
|
out_channels=base_out_channels, |
|
prev_output_channel=prev_base_output_channel, |
|
ctrl_skip_channels=ctrl_skip_channels_, |
|
) |
|
) |
|
|
|
@classmethod |
|
def from_unet( |
|
cls, |
|
unet: UNet2DConditionModel, |
|
size_ratio: Optional[float] = None, |
|
block_out_channels: Optional[List[int]] = None, |
|
num_attention_heads: Optional[List[int]] = None, |
|
learn_time_embedding: bool = False, |
|
        time_embedding_mix: float = 1.0,
|
conditioning_channels: int = 3, |
|
conditioning_channel_order: str = "rgb", |
|
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), |
|
): |
|
r""" |
|
Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`]. |
|
|
|
Parameters: |
|
unet (`UNet2DConditionModel`): |
|
The UNet model we want to control. The dimensions of the ControlNetXSAdapter will be adapted to it. |
|
size_ratio (float, *optional*, defaults to `None`): |
|
When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this |
|
or `block_out_channels` must be given. |
|
block_out_channels (`List[int]`, *optional*, defaults to `None`): |
|
Down blocks output channels in control model. Either this or `size_ratio` must be given. |
|
num_attention_heads (`List[int]`, *optional*, defaults to `None`): |
|
                The number of attention heads. The naming is historically confusing; see
|
https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why. |
|
learn_time_embedding (`bool`, defaults to `False`): |
|
Whether the `ControlNetXSAdapter` should learn a time embedding. |
|
time_embedding_mix (`float`, defaults to 1.0): |
|
If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time |
|
embedding is used. Otherwise, both are combined. |
|
conditioning_channels (`int`, defaults to 3): |
|
Number of channels of conditioning input (e.g. an image) |
|
conditioning_channel_order (`str`, defaults to `"rgb"`): |
|
                The channel order of the conditioning image. Will be converted to `rgb` if it's `bgr`.
|
conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`): |
|
                The tuple of output channels for each block in the `controlnet_cond_embedding` layer.
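
        Example:

        A sketch, assuming a `UNet2DConditionModel` is already loaded as `unet`:

        ```py
        >>> # an adapter roughly a tenth the width of the base UNet
        >>> adapter = ControlNetXSAdapter.from_unet(unet, size_ratio=0.1)
        >>> # or pass the block widths explicitly instead
        >>> adapter = ControlNetXSAdapter.from_unet(unet, block_out_channels=[32, 64, 128, 128])
        ```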
|
""" |
|
|
|
|
|
fixed_size = block_out_channels is not None |
|
relative_size = size_ratio is not None |
|
if not (fixed_size ^ relative_size): |
|
raise ValueError( |
|
"Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)." |
|
) |
|
|
|
|
|
block_out_channels = block_out_channels or [int(b * size_ratio) for b in unet.config.block_out_channels] |
|
        if num_attention_heads is None:
            # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
num_attention_heads = unet.config.attention_head_dim |
|
|
|
model = cls( |
|
conditioning_channels=conditioning_channels, |
|
conditioning_channel_order=conditioning_channel_order, |
|
conditioning_embedding_out_channels=conditioning_embedding_out_channels, |
|
time_embedding_mix=time_embedding_mix, |
|
learn_time_embedding=learn_time_embedding, |
|
num_attention_heads=num_attention_heads, |
|
block_out_channels=block_out_channels, |
|
base_block_out_channels=unet.config.block_out_channels, |
|
cross_attention_dim=unet.config.cross_attention_dim, |
|
down_block_types=unet.config.down_block_types, |
|
sample_size=unet.config.sample_size, |
|
transformer_layers_per_block=unet.config.transformer_layers_per_block, |
|
upcast_attention=unet.config.upcast_attention, |
|
max_norm_num_groups=unet.config.norm_num_groups, |
|
use_linear_projection=unet.config.use_linear_projection, |
|
) |
|
|
|
|
|
model.to(unet.dtype) |
|
|
|
return model |
|
|
|
def forward(self, *args, **kwargs): |
|
raise ValueError( |
|
"A ControlNetXSAdapter cannot be run by itself. Use it together with a UNet2DConditionModel to instantiate a UNetControlNetXSModel." |
|
) |
|
|
|
|
|
class UNetControlNetXSModel(ModelMixin, ConfigMixin): |
|
r""" |
|
    A UNet fused with a ControlNet-XS adapter model.
|
|
|
    This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for its generic
|
methods implemented for all models (such as downloading or saving). |
|
|
|
    `UNetControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. Its default parameters are
|
compatible with StableDiffusion. |
|
|
|
    Its parameters are either passed to the underlying `UNet2DConditionModel` or used exactly like in
    `ControlNetXSAdapter`. See their documentation for details.
|
""" |
|
|
|
_supports_gradient_checkpointing = True |
|
|
|
@register_to_config |
|
def __init__( |
|
self, |
|
|
|
sample_size: Optional[int] = 96, |
|
down_block_types: Tuple[str] = ( |
|
"CrossAttnDownBlock2D", |
|
"CrossAttnDownBlock2D", |
|
"CrossAttnDownBlock2D", |
|
"DownBlock2D", |
|
), |
|
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), |
|
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), |
|
norm_num_groups: Optional[int] = 32, |
|
cross_attention_dim: Union[int, Tuple[int]] = 1024, |
|
transformer_layers_per_block: Union[int, Tuple[int]] = 1, |
|
num_attention_heads: Union[int, Tuple[int]] = 8, |
|
addition_embed_type: Optional[str] = None, |
|
addition_time_embed_dim: Optional[int] = None, |
|
upcast_attention: bool = True, |
|
use_linear_projection: bool = True, |
|
time_cond_proj_dim: Optional[int] = None, |
|
projection_class_embeddings_input_dim: Optional[int] = None, |
|
|
|
time_embedding_mix: float = 1.0, |
|
ctrl_conditioning_channels: int = 3, |
|
ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256), |
|
ctrl_conditioning_channel_order: str = "rgb", |
|
ctrl_learn_time_embedding: bool = False, |
|
ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16), |
|
ctrl_num_attention_heads: Union[int, Tuple[int]] = 4, |
|
ctrl_max_norm_num_groups: int = 32, |
|
): |
|
super().__init__() |
|
|
|
if time_embedding_mix < 0 or time_embedding_mix > 1: |
|
raise ValueError("`time_embedding_mix` needs to be between 0 and 1.") |
|
if time_embedding_mix < 1 and not ctrl_learn_time_embedding: |
|
raise ValueError("To use `time_embedding_mix` < 1, `ctrl_learn_time_embedding` must be `True`") |
|
|
|
if addition_embed_type is not None and addition_embed_type != "text_time": |
|
raise ValueError( |
|
"As `UNetControlNetXSModel` currently only supports StableDiffusion and StableDiffusion-XL, `addition_embed_type` must be `None` or `'text_time'`." |
|
) |
|
|
|
if not isinstance(transformer_layers_per_block, (list, tuple)): |
|
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) |
|
if not isinstance(cross_attention_dim, (list, tuple)): |
|
cross_attention_dim = [cross_attention_dim] * len(down_block_types) |
|
if not isinstance(num_attention_heads, (list, tuple)): |
|
num_attention_heads = [num_attention_heads] * len(down_block_types) |
|
if not isinstance(ctrl_num_attention_heads, (list, tuple)): |
|
ctrl_num_attention_heads = [ctrl_num_attention_heads] * len(down_block_types) |
|
|
|
base_num_attention_heads = num_attention_heads |
|
|
|
        self.in_channels = 4  # hard-coded to the 4 latent channels used by SD and SDXL
|
|
|
|
|
self.base_conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1) |
|
self.controlnet_cond_embedding = ControlNetConditioningEmbedding( |
|
conditioning_embedding_channels=ctrl_block_out_channels[0], |
|
block_out_channels=ctrl_conditioning_embedding_out_channels, |
|
conditioning_channels=ctrl_conditioning_channels, |
|
) |
|
self.ctrl_conv_in = nn.Conv2d(4, ctrl_block_out_channels[0], kernel_size=3, padding=1) |
|
self.control_to_base_for_conv_in = make_zero_conv(ctrl_block_out_channels[0], block_out_channels[0]) |
|
|
|
|
|
time_embed_input_dim = block_out_channels[0] |
|
time_embed_dim = block_out_channels[0] * 4 |
|
|
|
self.base_time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0) |
|
self.base_time_embedding = TimestepEmbedding( |
|
time_embed_input_dim, |
|
time_embed_dim, |
|
cond_proj_dim=time_cond_proj_dim, |
|
) |
|
if ctrl_learn_time_embedding: |
|
self.ctrl_time_embedding = TimestepEmbedding( |
|
in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim |
|
) |
|
else: |
|
self.ctrl_time_embedding = None |
|
|
|
if addition_embed_type is None: |
|
self.base_add_time_proj = None |
|
self.base_add_embedding = None |
|
else: |
|
self.base_add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0) |
|
self.base_add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) |
|
|
|
|
|
down_blocks = [] |
|
base_out_channels = block_out_channels[0] |
|
ctrl_out_channels = ctrl_block_out_channels[0] |
|
for i, down_block_type in enumerate(down_block_types): |
|
base_in_channels = base_out_channels |
|
base_out_channels = block_out_channels[i] |
|
ctrl_in_channels = ctrl_out_channels |
|
ctrl_out_channels = ctrl_block_out_channels[i] |
|
has_crossattn = "CrossAttn" in down_block_type |
|
is_final_block = i == len(down_block_types) - 1 |
|
|
|
down_blocks.append( |
|
ControlNetXSCrossAttnDownBlock2D( |
|
base_in_channels=base_in_channels, |
|
base_out_channels=base_out_channels, |
|
ctrl_in_channels=ctrl_in_channels, |
|
ctrl_out_channels=ctrl_out_channels, |
|
temb_channels=time_embed_dim, |
|
norm_num_groups=norm_num_groups, |
|
ctrl_max_norm_num_groups=ctrl_max_norm_num_groups, |
|
has_crossattn=has_crossattn, |
|
transformer_layers_per_block=transformer_layers_per_block[i], |
|
base_num_attention_heads=base_num_attention_heads[i], |
|
ctrl_num_attention_heads=ctrl_num_attention_heads[i], |
|
cross_attention_dim=cross_attention_dim[i], |
|
add_downsample=not is_final_block, |
|
upcast_attention=upcast_attention, |
|
use_linear_projection=use_linear_projection, |
|
) |
|
) |
|
|
|
|
|
self.mid_block = ControlNetXSCrossAttnMidBlock2D( |
|
base_channels=block_out_channels[-1], |
|
ctrl_channels=ctrl_block_out_channels[-1], |
|
temb_channels=time_embed_dim, |
|
norm_num_groups=norm_num_groups, |
|
ctrl_max_norm_num_groups=ctrl_max_norm_num_groups, |
|
transformer_layers_per_block=transformer_layers_per_block[-1], |
|
base_num_attention_heads=base_num_attention_heads[-1], |
|
ctrl_num_attention_heads=ctrl_num_attention_heads[-1], |
|
cross_attention_dim=cross_attention_dim[-1], |
|
upcast_attention=upcast_attention, |
|
use_linear_projection=use_linear_projection, |
|
) |
|
|
|
|
|
up_blocks = [] |
|
rev_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) |
|
rev_num_attention_heads = list(reversed(base_num_attention_heads)) |
|
rev_cross_attention_dim = list(reversed(cross_attention_dim)) |
|
|
|
|
|
ctrl_skip_channels = [ctrl_block_out_channels[0]] |
|
for i, out_channels in enumerate(ctrl_block_out_channels): |
|
            number_of_subblocks = (
                3 if i < len(ctrl_block_out_channels) - 1 else 2
            )  # every block has 3 subblocks, except the last one, which has 2 as it has no downsampler
|
ctrl_skip_channels.extend([out_channels] * number_of_subblocks) |
|
|
|
reversed_block_out_channels = list(reversed(block_out_channels)) |
|
|
|
out_channels = reversed_block_out_channels[0] |
|
for i, up_block_type in enumerate(up_block_types): |
|
prev_output_channel = out_channels |
|
out_channels = reversed_block_out_channels[i] |
|
in_channels = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] |
|
ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)] |
|
|
|
has_crossattn = "CrossAttn" in up_block_type |
|
is_final_block = i == len(block_out_channels) - 1 |
|
|
|
up_blocks.append( |
|
ControlNetXSCrossAttnUpBlock2D( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
prev_output_channel=prev_output_channel, |
|
ctrl_skip_channels=ctrl_skip_channels_, |
|
temb_channels=time_embed_dim, |
|
resolution_idx=i, |
|
has_crossattn=has_crossattn, |
|
transformer_layers_per_block=rev_transformer_layers_per_block[i], |
|
num_attention_heads=rev_num_attention_heads[i], |
|
cross_attention_dim=rev_cross_attention_dim[i], |
|
add_upsample=not is_final_block, |
|
upcast_attention=upcast_attention, |
|
norm_num_groups=norm_num_groups, |
|
use_linear_projection=use_linear_projection, |
|
) |
|
) |
|
|
|
self.down_blocks = nn.ModuleList(down_blocks) |
|
self.up_blocks = nn.ModuleList(up_blocks) |
|
|
|
self.base_conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups) |
|
self.base_conv_act = nn.SiLU() |
|
self.base_conv_out = nn.Conv2d(block_out_channels[0], 4, kernel_size=3, padding=1) |
|
|
|
@classmethod |
|
def from_unet( |
|
cls, |
|
unet: UNet2DConditionModel, |
|
controlnet: Optional[ControlNetXSAdapter] = None, |
|
size_ratio: Optional[float] = None, |
|
ctrl_block_out_channels: Optional[List[float]] = None, |
|
time_embedding_mix: Optional[float] = None, |
|
ctrl_optional_kwargs: Optional[Dict] = None, |
|
): |
|
r""" |
|
        Instantiate a [`UNetControlNetXSModel`] from a [`UNet2DConditionModel`] and an optional
        [`ControlNetXSAdapter`].
|
|
|
Parameters: |
|
unet (`UNet2DConditionModel`): |
|
The UNet model we want to control. |
|
controlnet (`ControlNetXSAdapter`): |
|
                The ControlNet-XS adapter with which the UNet will be fused. If none is given, a new ControlNet-XS
|
adapter will be created. |
|
            size_ratio (`float`, *optional*, defaults to `None`):
                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
            ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`):
                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
                where this parameter is called `block_out_channels`.
            time_embedding_mix (`float`, *optional*, defaults to `None`):
                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
            ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`):
                Passed to the `init` of the new controlnet if no controlnet was given.
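
        Example:

        A sketch, assuming `unet` and (optionally) `adapter` are already loaded:

        ```py
        >>> # fuse the UNet with an existing ControlNet-XS adapter
        >>> model = UNetControlNetXSModel.from_unet(unet, controlnet=adapter)
        >>> # or let a fresh adapter be created on the fly
        >>> model = UNetControlNetXSModel.from_unet(unet, size_ratio=0.1)
        ```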
|
""" |
|
if controlnet is None: |
|
controlnet = ControlNetXSAdapter.from_unet( |
|
                unet, size_ratio, ctrl_block_out_channels, **(ctrl_optional_kwargs or {})  # kwargs may be None
|
) |
|
else: |
|
if any( |
|
o is not None for o in (size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs) |
|
): |
|
raise ValueError( |
|
"When a controlnet is passed, none of these parameters should be passed: size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs." |
|
) |
|
|
|
|
|
params_for_unet = [ |
|
"sample_size", |
|
"down_block_types", |
|
"up_block_types", |
|
"block_out_channels", |
|
"norm_num_groups", |
|
"cross_attention_dim", |
|
"transformer_layers_per_block", |
|
"addition_embed_type", |
|
"addition_time_embed_dim", |
|
"upcast_attention", |
|
"use_linear_projection", |
|
"time_cond_proj_dim", |
|
"projection_class_embeddings_input_dim", |
|
] |
|
params_for_unet = {k: v for k, v in unet.config.items() if k in params_for_unet} |
|
|
|
params_for_unet["num_attention_heads"] = unet.config.attention_head_dim |
|
|
|
params_for_controlnet = [ |
|
"conditioning_channels", |
|
"conditioning_embedding_out_channels", |
|
"conditioning_channel_order", |
|
"learn_time_embedding", |
|
"block_out_channels", |
|
"num_attention_heads", |
|
"max_norm_num_groups", |
|
] |
|
params_for_controlnet = {"ctrl_" + k: v for k, v in controlnet.config.items() if k in params_for_controlnet} |
|
params_for_controlnet["time_embedding_mix"] = controlnet.config.time_embedding_mix |
|
|
|
|
|
model = cls.from_config({**params_for_unet, **params_for_controlnet}) |
|
|
|
|
|
|
|
modules_from_unet = [ |
|
"time_embedding", |
|
"conv_in", |
|
"conv_norm_out", |
|
"conv_out", |
|
] |
|
for m in modules_from_unet: |
|
getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict()) |
|
|
|
optional_modules_from_unet = [ |
|
"add_time_proj", |
|
"add_embedding", |
|
] |
|
for m in optional_modules_from_unet: |
|
if hasattr(unet, m) and getattr(unet, m) is not None: |
|
getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict()) |
|
|
|
|
|
model.controlnet_cond_embedding.load_state_dict(controlnet.controlnet_cond_embedding.state_dict()) |
|
model.ctrl_conv_in.load_state_dict(controlnet.conv_in.state_dict()) |
|
if controlnet.time_embedding is not None: |
|
model.ctrl_time_embedding.load_state_dict(controlnet.time_embedding.state_dict()) |
|
model.control_to_base_for_conv_in.load_state_dict(controlnet.control_to_base_for_conv_in.state_dict()) |
|
|
|
|
|
model.down_blocks = nn.ModuleList( |
|
ControlNetXSCrossAttnDownBlock2D.from_modules(b, c) |
|
for b, c in zip(unet.down_blocks, controlnet.down_blocks) |
|
) |
|
model.mid_block = ControlNetXSCrossAttnMidBlock2D.from_modules(unet.mid_block, controlnet.mid_block) |
|
model.up_blocks = nn.ModuleList( |
|
ControlNetXSCrossAttnUpBlock2D.from_modules(b, c) |
|
for b, c in zip(unet.up_blocks, controlnet.up_connections) |
|
) |
|
|
|
|
|
model.to(unet.dtype) |
|
|
|
return model |
|
|
|
def freeze_unet_params(self) -> None: |
|
"""Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine |
|
tuning.""" |
|
|
|
        # unfreeze everything first; the base parts are frozen again right below
        for param in self.parameters():
|
param.requires_grad = True |
|
|
|
|
|
base_parts = [ |
|
"base_time_proj", |
|
"base_time_embedding", |
|
"base_add_time_proj", |
|
"base_add_embedding", |
|
"base_conv_in", |
|
"base_conv_norm_out", |
|
"base_conv_act", |
|
"base_conv_out", |
|
] |
|
base_parts = [getattr(self, part) for part in base_parts if getattr(self, part) is not None] |
|
for part in base_parts: |
|
for param in part.parameters(): |
|
param.requires_grad = False |
|
|
|
for d in self.down_blocks: |
|
d.freeze_base_params() |
|
self.mid_block.freeze_base_params() |
|
for u in self.up_blocks: |
|
u.freeze_base_params() |
|
|
|
def _set_gradient_checkpointing(self, module, value=False): |
|
if hasattr(module, "gradient_checkpointing"): |
|
module.gradient_checkpointing = value |
|
|
|
@property |
|
|
|
def attn_processors(self) -> Dict[str, AttentionProcessor]: |
|
r""" |
|
Returns: |
|
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
|
""" |
|
|
|
processors = {} |
|
|
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): |
|
if hasattr(module, "get_processor"): |
|
processors[f"{name}.processor"] = module.get_processor() |
|
|
|
for sub_name, child in module.named_children(): |
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) |
|
|
|
return processors |
|
|
|
for name, module in self.named_children(): |
|
fn_recursive_add_processors(name, module, processors) |
|
|
|
return processors |
|
|
|
|
|
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): |
|
r""" |
|
Sets the attention processor to use to compute attention. |
|
|
|
Parameters: |
|
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): |
|
The instantiated processor class or a dictionary of processor classes that will be set as the processor |
|
for **all** `Attention` layers. |
|
|
|
If `processor` is a dict, the key needs to define the path to the corresponding cross attention |
|
processor. This is strongly recommended when setting trainable attention processors. |
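
        A minimal sketch (resetting every attention layer to the default processor):

        ```py
        >>> from diffusers.models.attention_processor import AttnProcessor

        >>> model.set_attn_processor(AttnProcessor())
        ```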
|
|
|
""" |
|
count = len(self.attn_processors.keys()) |
|
|
|
if isinstance(processor, dict) and len(processor) != count: |
|
raise ValueError( |
|
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" |
|
f" number of attention layers: {count}. Please make sure to pass {count} processor classes." |
|
) |
|
|
|
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): |
|
if hasattr(module, "set_processor"): |
|
if not isinstance(processor, dict): |
|
module.set_processor(processor) |
|
else: |
|
module.set_processor(processor.pop(f"{name}.processor")) |
|
|
|
for sub_name, child in module.named_children(): |
|
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) |
|
|
|
for name, module in self.named_children(): |
|
fn_recursive_attn_processor(name, module, processor) |
|
|
|
|
|
def set_default_attn_processor(self): |
|
""" |
|
Disables custom attention processors and sets the default attention implementation. |
|
""" |
|
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): |
|
processor = AttnAddedKVProcessor() |
|
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): |
|
processor = AttnProcessor() |
|
else: |
|
raise ValueError( |
|
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" |
|
) |
|
|
|
self.set_attn_processor(processor) |
|
|
|
|
|
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): |
|
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. |
|
|
|
The suffixes after the scaling factors represent the stage blocks where they are being applied. |
|
|
|
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that |
|
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. |
|
|
|
Args: |
|
s1 (`float`): |
|
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to |
|
mitigate the "oversmoothing effect" in the enhanced denoising process. |
|
s2 (`float`): |
|
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to |
|
mitigate the "oversmoothing effect" in the enhanced denoising process. |
|
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. |
|
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. |
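
        Example (one combination reported to work for Stable Diffusion v1; treat it as a starting point and see the
        FreeU repository for per-model suggestions):

        ```py
        >>> model.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
        ```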
|
""" |
|
for i, upsample_block in enumerate(self.up_blocks): |
|
setattr(upsample_block, "s1", s1) |
|
setattr(upsample_block, "s2", s2) |
|
setattr(upsample_block, "b1", b1) |
|
setattr(upsample_block, "b2", b2) |
|
|
|
|
|
def disable_freeu(self): |
|
"""Disables the FreeU mechanism.""" |
|
freeu_keys = {"s1", "s2", "b1", "b2"} |
|
for i, upsample_block in enumerate(self.up_blocks): |
|
for k in freeu_keys: |
|
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: |
|
setattr(upsample_block, k, None) |
|
|
|
|
|
def fuse_qkv_projections(self): |
|
""" |
|
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) |
|
are fused. For cross-attention modules, key and value projection matrices are fused. |
|
|
|
<Tip warning={true}> |
|
|
|
This API is 🧪 experimental. |
|
|
|
</Tip> |
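
        A usage sketch:

        ```py
        >>> model.fuse_qkv_projections()
        >>> # ... run inference ...
        >>> model.unfuse_qkv_projections()
        ```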
|
""" |
|
self.original_attn_processors = None |
|
|
|
for _, attn_processor in self.attn_processors.items(): |
|
if "Added" in str(attn_processor.__class__.__name__): |
|
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") |
|
|
|
self.original_attn_processors = self.attn_processors |
|
|
|
for module in self.modules(): |
|
if isinstance(module, Attention): |
|
module.fuse_projections(fuse=True) |
|
|
|
|
|
def unfuse_qkv_projections(self): |
|
"""Disables the fused QKV projection if enabled. |
|
|
|
<Tip warning={true}> |
|
|
|
This API is 🧪 experimental. |
|
|
|
</Tip> |
|
|
|
""" |
|
if self.original_attn_processors is not None: |
|
self.set_attn_processor(self.original_attn_processors) |
|
|
|
def forward( |
|
self, |
|
sample: Tensor, |
|
timestep: Union[torch.Tensor, float, int], |
|
encoder_hidden_states: torch.Tensor, |
|
controlnet_cond: Optional[torch.Tensor] = None, |
|
conditioning_scale: Optional[float] = 1.0, |
|
class_labels: Optional[torch.Tensor] = None, |
|
timestep_cond: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, |
|
return_dict: bool = True, |
|
apply_control: bool = True, |
|
) -> Union[ControlNetXSOutput, Tuple]: |
|
""" |
|
The [`ControlNetXSModel`] forward method. |
|
|
|
Args: |
|
sample (`Tensor`): |
|
The noisy input tensor. |
|
timestep (`Union[torch.Tensor, float, int]`): |
|
The number of timesteps to denoise an input. |
|
encoder_hidden_states (`torch.Tensor`): |
|
The encoder hidden states. |
|
controlnet_cond (`Tensor`): |
|
                The conditional input tensor of shape `(batch_size, conditioning_channels, height, width)`.
|
conditioning_scale (`float`, defaults to `1.0`): |
|
How much the control model affects the base model outputs. |
|
class_labels (`torch.Tensor`, *optional*, defaults to `None`): |
|
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. |
|
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): |
|
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the |
|
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep |
|
embeddings. |
|
attention_mask (`torch.Tensor`, *optional*, defaults to `None`): |
|
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask |
|
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large |
|
negative values to the attention scores corresponding to "discard" tokens. |
|
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): |
|
A kwargs dictionary that if specified is passed along to the `AttnProcessor`. |
|
added_cond_kwargs (`dict`): |
|
Additional conditions for the Stable Diffusion XL UNet. |
|
return_dict (`bool`, defaults to `True`): |
|
                Whether or not to return a [`~models.controlnetxs.ControlNetXSOutput`] instead of a plain tuple.
|
apply_control (`bool`, defaults to `True`): |
|
If `False`, the input is run only through the base model. |
|
|
|
Returns: |
|
[`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`: |
|
If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a |
|
tuple is returned where the first element is the sample tensor. |
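
        Example:

        A shape sketch for the default config (illustrative sizes only; the conditioning image lives in pixel space,
        so it is 8x the spatial size of the latent sample):

        ```py
        >>> sample = torch.randn(1, 4, 32, 32)
        >>> cond = torch.randn(1, 3, 256, 256)
        >>> prompt_embeds = torch.randn(1, 77, 1024)  # default cross_attention_dim is 1024
        >>> out = model(sample, timestep=10, encoder_hidden_states=prompt_embeds, controlnet_cond=cond).sample
        ```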
|
""" |
|
|
|
|
|
        if self.config.ctrl_conditioning_channel_order == "bgr" and controlnet_cond is not None:
|
controlnet_cond = torch.flip(controlnet_cond, dims=[1]) |
|
|
|
|
|
if attention_mask is not None: |
|
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 |
|
attention_mask = attention_mask.unsqueeze(1) |
|
|
|
|
|
timesteps = timestep |
|
        if not torch.is_tensor(timesteps):
            # TODO: this requires a sync between CPU and GPU, so try to pass timesteps as tensors if you can
is_mps = sample.device.type == "mps" |
|
if isinstance(timestep, float): |
|
dtype = torch.float32 if is_mps else torch.float64 |
|
else: |
|
dtype = torch.int32 if is_mps else torch.int64 |
|
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) |
|
elif len(timesteps.shape) == 0: |
|
timesteps = timesteps[None].to(sample.device) |
|
|
|
|
|
timesteps = timesteps.expand(sample.shape[0]) |
|
|
|
        t_emb = self.base_time_proj(timesteps)

        # `timesteps` does not contain any weights and will always return f32 tensors,
        # but the time embedding may be running in fp16, so we need to cast here
t_emb = t_emb.to(dtype=sample.dtype) |
|
|
|
if self.config.ctrl_learn_time_embedding and apply_control: |
|
ctrl_temb = self.ctrl_time_embedding(t_emb, timestep_cond) |
|
base_temb = self.base_time_embedding(t_emb, timestep_cond) |
|
            interpolation_param = self.config.time_embedding_mix**0.3  # the exponent 0.3 mirrors the original ControlNet-XS implementation
|
|
|
temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param) |
|
else: |
|
temb = self.base_time_embedding(t_emb) |
|
|
|
|
|
aug_emb = None |
|
|
|
if self.config.addition_embed_type is None: |
|
pass |
|
elif self.config.addition_embed_type == "text_time": |
|
|
|
if "text_embeds" not in added_cond_kwargs: |
|
raise ValueError( |
|
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" |
|
) |
|
text_embeds = added_cond_kwargs.get("text_embeds") |
|
if "time_ids" not in added_cond_kwargs: |
|
raise ValueError( |
|
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" |
|
) |
|
time_ids = added_cond_kwargs.get("time_ids") |
|
time_embeds = self.base_add_time_proj(time_ids.flatten()) |
|
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) |
|
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) |
|
add_embeds = add_embeds.to(temb.dtype) |
|
aug_emb = self.base_add_embedding(add_embeds) |
|
else: |
|
raise ValueError( |
|
f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.config.addition_embed_type} is currently not supported." |
|
) |
|
|
|
temb = temb + aug_emb if aug_emb is not None else temb |
|
|
|
|
|
cemb = encoder_hidden_states |
|
|
|
|
|
h_ctrl = h_base = sample |
|
hs_base, hs_ctrl = [], [] |
|
|
|
|
|
        guided_hint = self.controlnet_cond_embedding(controlnet_cond) if controlnet_cond is not None else None
|
|
|
|
|
|
|
h_base = self.base_conv_in(h_base) |
|
h_ctrl = self.ctrl_conv_in(h_ctrl) |
|
if guided_hint is not None: |
|
h_ctrl += guided_hint |
|
if apply_control: |
|
h_base = h_base + self.control_to_base_for_conv_in(h_ctrl) * conditioning_scale |
|
|
|
hs_base.append(h_base) |
|
hs_ctrl.append(h_ctrl) |
|
|
|
for down in self.down_blocks: |
|
h_base, h_ctrl, residual_hb, residual_hc = down( |
|
hidden_states_base=h_base, |
|
hidden_states_ctrl=h_ctrl, |
|
temb=temb, |
|
encoder_hidden_states=cemb, |
|
conditioning_scale=conditioning_scale, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
attention_mask=attention_mask, |
|
apply_control=apply_control, |
|
) |
|
hs_base.extend(residual_hb) |
|
hs_ctrl.extend(residual_hc) |
|
|
|
|
|
h_base, h_ctrl = self.mid_block( |
|
hidden_states_base=h_base, |
|
hidden_states_ctrl=h_ctrl, |
|
temb=temb, |
|
encoder_hidden_states=cemb, |
|
conditioning_scale=conditioning_scale, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
attention_mask=attention_mask, |
|
apply_control=apply_control, |
|
) |
|
|
|
|
|
for up in self.up_blocks: |
|
n_resnets = len(up.resnets) |
|
skips_hb = hs_base[-n_resnets:] |
|
skips_hc = hs_ctrl[-n_resnets:] |
|
hs_base = hs_base[:-n_resnets] |
|
hs_ctrl = hs_ctrl[:-n_resnets] |
|
h_base = up( |
|
hidden_states=h_base, |
|
res_hidden_states_tuple_base=skips_hb, |
|
res_hidden_states_tuple_ctrl=skips_hc, |
|
temb=temb, |
|
encoder_hidden_states=cemb, |
|
conditioning_scale=conditioning_scale, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
attention_mask=attention_mask, |
|
apply_control=apply_control, |
|
) |
|
|
|
|
|
h_base = self.base_conv_norm_out(h_base) |
|
h_base = self.base_conv_act(h_base) |
|
h_base = self.base_conv_out(h_base) |
|
|
|
if not return_dict: |
|
return (h_base,) |
|
|
|
return ControlNetXSOutput(sample=h_base) |
|
|
|
|
|
class ControlNetXSCrossAttnDownBlock2D(nn.Module): |
|
def __init__( |
|
self, |
|
base_in_channels: int, |
|
base_out_channels: int, |
|
ctrl_in_channels: int, |
|
ctrl_out_channels: int, |
|
temb_channels: int, |
|
norm_num_groups: int = 32, |
|
ctrl_max_norm_num_groups: int = 32, |
|
        has_crossattn: bool = True,
|
transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1, |
|
base_num_attention_heads: Optional[int] = 1, |
|
ctrl_num_attention_heads: Optional[int] = 1, |
|
cross_attention_dim: Optional[int] = 1024, |
|
add_downsample: bool = True, |
|
upcast_attention: Optional[bool] = False, |
|
use_linear_projection: Optional[bool] = True, |
|
): |
|
super().__init__() |
|
base_resnets = [] |
|
base_attentions = [] |
|
ctrl_resnets = [] |
|
ctrl_attentions = [] |
|
ctrl_to_base = [] |
|
base_to_ctrl = [] |
|
|
|
num_layers = 2 |
|
|
|
if isinstance(transformer_layers_per_block, int): |
|
transformer_layers_per_block = [transformer_layers_per_block] * num_layers |
|
|
|
for i in range(num_layers): |
|
base_in_channels = base_in_channels if i == 0 else base_out_channels |
|
            ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels

            # Before the resnet/attention application, information is concatted from base to control.
            # Concat doesn't require a change in the number of channels
base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels)) |
|
|
|
base_resnets.append( |
|
ResnetBlock2D( |
|
in_channels=base_in_channels, |
|
out_channels=base_out_channels, |
|
temb_channels=temb_channels, |
|
groups=norm_num_groups, |
|
) |
|
) |
|
ctrl_resnets.append( |
|
ResnetBlock2D( |
|
in_channels=ctrl_in_channels + base_in_channels, |
|
out_channels=ctrl_out_channels, |
|
temb_channels=temb_channels, |
|
groups=find_largest_factor( |
|
ctrl_in_channels + base_in_channels, max_factor=ctrl_max_norm_num_groups |
|
), |
|
groups_out=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups), |
|
eps=1e-5, |
|
) |
|
) |
|
|
|
if has_crossattn: |
|
base_attentions.append( |
|
Transformer2DModel( |
|
base_num_attention_heads, |
|
base_out_channels // base_num_attention_heads, |
|
in_channels=base_out_channels, |
|
num_layers=transformer_layers_per_block[i], |
|
cross_attention_dim=cross_attention_dim, |
|
use_linear_projection=use_linear_projection, |
|
upcast_attention=upcast_attention, |
|
norm_num_groups=norm_num_groups, |
|
) |
|
) |
|
ctrl_attentions.append( |
|
Transformer2DModel( |
|
ctrl_num_attention_heads, |
|
ctrl_out_channels // ctrl_num_attention_heads, |
|
in_channels=ctrl_out_channels, |
|
num_layers=transformer_layers_per_block[i], |
|
cross_attention_dim=cross_attention_dim, |
|
use_linear_projection=use_linear_projection, |
|
upcast_attention=upcast_attention, |
|
norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups), |
|
) |
|
                )

            # After the resnet/attention application, information is added from control to base
            # Addition requires a change in the number of channels
ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) |
|
|
|
        if add_downsample:
            # Before the downsampler application, information is concatted from base to control
            # Concat doesn't require a change in the number of channels
base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels)) |
|
|
|
self.base_downsamplers = Downsample2D( |
|
base_out_channels, use_conv=True, out_channels=base_out_channels, name="op" |
|
) |
|
self.ctrl_downsamplers = Downsample2D( |
|
ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op" |
|
            )

            # After the downsampler application, information is added from control to base
            # Addition requires a change in the number of channels
ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels)) |
|
else: |
|
self.base_downsamplers = None |
|
self.ctrl_downsamplers = None |
|
|
|
self.base_resnets = nn.ModuleList(base_resnets) |
|
self.ctrl_resnets = nn.ModuleList(ctrl_resnets) |
|
self.base_attentions = nn.ModuleList(base_attentions) if has_crossattn else [None] * num_layers |
|
self.ctrl_attentions = nn.ModuleList(ctrl_attentions) if has_crossattn else [None] * num_layers |
|
self.base_to_ctrl = nn.ModuleList(base_to_ctrl) |
|
self.ctrl_to_base = nn.ModuleList(ctrl_to_base) |
|
|
|
self.gradient_checkpointing = False |
|
|
|
@classmethod |
|
def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: DownBlockControlNetXSAdapter): |
|
|
|
def get_first_cross_attention(block): |
|
return block.attentions[0].transformer_blocks[0].attn2 |
|
|
|
base_in_channels = base_downblock.resnets[0].in_channels |
|
base_out_channels = base_downblock.resnets[0].out_channels |
|
        ctrl_in_channels = (
            ctrl_downblock.resnets[0].in_channels - base_in_channels
        )  # the ctrl resnet sees the base features concatted onto its input, so subtract the base channels
|
ctrl_out_channels = ctrl_downblock.resnets[0].out_channels |
|
temb_channels = base_downblock.resnets[0].time_emb_proj.in_features |
|
num_groups = base_downblock.resnets[0].norm1.num_groups |
|
ctrl_num_groups = ctrl_downblock.resnets[0].norm1.num_groups |
|
if hasattr(base_downblock, "attentions"): |
|
has_crossattn = True |
|
transformer_layers_per_block = len(base_downblock.attentions[0].transformer_blocks) |
|
base_num_attention_heads = get_first_cross_attention(base_downblock).heads |
|
ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads |
|
cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim |
|
upcast_attention = get_first_cross_attention(base_downblock).upcast_attention |
|
use_linear_projection = base_downblock.attentions[0].use_linear_projection |
|
else: |
|
has_crossattn = False |
|
transformer_layers_per_block = None |
|
base_num_attention_heads = None |
|
ctrl_num_attention_heads = None |
|
cross_attention_dim = None |
|
upcast_attention = None |
|
use_linear_projection = None |
|
add_downsample = base_downblock.downsamplers is not None |
|
|
|
|
|
model = cls( |
|
base_in_channels=base_in_channels, |
|
base_out_channels=base_out_channels, |
|
ctrl_in_channels=ctrl_in_channels, |
|
ctrl_out_channels=ctrl_out_channels, |
|
temb_channels=temb_channels, |
|
norm_num_groups=num_groups, |
|
ctrl_max_norm_num_groups=ctrl_num_groups, |
|
has_crossattn=has_crossattn, |
|
transformer_layers_per_block=transformer_layers_per_block, |
|
base_num_attention_heads=base_num_attention_heads, |
|
ctrl_num_attention_heads=ctrl_num_attention_heads, |
|
cross_attention_dim=cross_attention_dim, |
|
add_downsample=add_downsample, |
|
upcast_attention=upcast_attention, |
|
use_linear_projection=use_linear_projection, |
|
) |
|
|
|
|
|
model.base_resnets.load_state_dict(base_downblock.resnets.state_dict()) |
|
model.ctrl_resnets.load_state_dict(ctrl_downblock.resnets.state_dict()) |
|
if has_crossattn: |
|
model.base_attentions.load_state_dict(base_downblock.attentions.state_dict()) |
|
model.ctrl_attentions.load_state_dict(ctrl_downblock.attentions.state_dict()) |
|
if add_downsample: |
|
model.base_downsamplers.load_state_dict(base_downblock.downsamplers[0].state_dict()) |
|
model.ctrl_downsamplers.load_state_dict(ctrl_downblock.downsamplers.state_dict()) |
|
model.base_to_ctrl.load_state_dict(ctrl_downblock.base_to_ctrl.state_dict()) |
|
model.ctrl_to_base.load_state_dict(ctrl_downblock.ctrl_to_base.state_dict()) |
|
|
|
return model |
|
|
|
def freeze_base_params(self) -> None: |
|
"""Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine |
|
tuning.""" |
|
|
|
        # unfreeze everything first; the base parts are frozen again right below
        for param in self.parameters():
|
param.requires_grad = True |
|
|
|
|
|
base_parts = [self.base_resnets] |
|
if isinstance(self.base_attentions, nn.ModuleList): |
|
base_parts.append(self.base_attentions) |
|
if self.base_downsamplers is not None: |
|
base_parts.append(self.base_downsamplers) |
|
for part in base_parts: |
|
for param in part.parameters(): |
|
param.requires_grad = False |
|
|
|
def forward( |
|
self, |
|
hidden_states_base: Tensor, |
|
temb: Tensor, |
|
encoder_hidden_states: Optional[Tensor] = None, |
|
hidden_states_ctrl: Optional[Tensor] = None, |
|
conditioning_scale: Optional[float] = 1.0, |
|
attention_mask: Optional[Tensor] = None, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
encoder_attention_mask: Optional[Tensor] = None, |
|
apply_control: bool = True, |
|
) -> Tuple[Tensor, Tensor, Tuple[Tensor, ...], Tuple[Tensor, ...]]: |
|
if cross_attention_kwargs is not None: |
|
if cross_attention_kwargs.get("scale", None) is not None: |
|
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") |
|
|
|
h_base = hidden_states_base |
|
h_ctrl = hidden_states_ctrl |
|
|
|
base_output_states = () |
|
ctrl_output_states = () |
|
|
|
base_blocks = list(zip(self.base_resnets, self.base_attentions)) |
|
ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions)) |
|
|
|
def create_custom_forward(module, return_dict=None): |
|
def custom_forward(*inputs): |
|
if return_dict is not None: |
|
return module(*inputs, return_dict=return_dict) |
|
else: |
|
return module(*inputs) |
|
|
|
return custom_forward |
|
|
|
for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip( |
|
base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base |
|
): |
|
|
|
if apply_control: |
|
h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1) |
|
|
|
|
|
if self.training and self.gradient_checkpointing: |
|
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} |
|
h_base = torch.utils.checkpoint.checkpoint( |
|
create_custom_forward(b_res), |
|
h_base, |
|
temb, |
|
**ckpt_kwargs, |
|
) |
|
else: |
|
h_base = b_res(h_base, temb) |
|
|
|
if b_attn is not None: |
|
h_base = b_attn( |
|
h_base, |
|
encoder_hidden_states=encoder_hidden_states, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
attention_mask=attention_mask, |
|
encoder_attention_mask=encoder_attention_mask, |
|
return_dict=False, |
|
)[0] |
|
|
|
|
|
if apply_control: |
|
if self.training and self.gradient_checkpointing: |
|
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} |
|
h_ctrl = torch.utils.checkpoint.checkpoint( |
|
create_custom_forward(c_res), |
|
h_ctrl, |
|
temb, |
|
**ckpt_kwargs, |
|
) |
|
else: |
|
h_ctrl = c_res(h_ctrl, temb) |
|
if c_attn is not None: |
|
h_ctrl = c_attn( |
|
h_ctrl, |
|
encoder_hidden_states=encoder_hidden_states, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
attention_mask=attention_mask, |
|
encoder_attention_mask=encoder_attention_mask, |
|
return_dict=False, |
|
)[0] |
|
|
|
|
|
if apply_control: |
|
h_base = h_base + c2b(h_ctrl) * conditioning_scale |
|
|
|
base_output_states = base_output_states + (h_base,) |
|
ctrl_output_states = ctrl_output_states + (h_ctrl,) |
|
|
|
        if self.base_downsamplers is not None:  # if there is a base downsampler, there is also a ctrl downsampler
            b2c = self.base_to_ctrl[-1]
            c2b = self.ctrl_to_base[-1]

            # concat base -> ctrl
            if apply_control:
                h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
            # apply base subblock
            h_base = self.base_downsamplers(h_base)
            # apply ctrl subblock
            if apply_control:
                h_ctrl = self.ctrl_downsamplers(h_ctrl)
            # add ctrl -> base
            if apply_control:
                h_base = h_base + c2b(h_ctrl) * conditioning_scale

            base_output_states = base_output_states + (h_base,)
            ctrl_output_states = ctrl_output_states + (h_ctrl,)

        return h_base, h_ctrl, base_output_states, ctrl_output_states


class ControlNetXSCrossAttnMidBlock2D(nn.Module):
    def __init__(
        self,
        base_channels: int,
        ctrl_channels: int,
        temb_channels: Optional[int] = None,
        norm_num_groups: int = 32,
        ctrl_max_norm_num_groups: int = 32,
        transformer_layers_per_block: int = 1,
        base_num_attention_heads: Optional[int] = 1,
        ctrl_num_attention_heads: Optional[int] = 1,
        cross_attention_dim: Optional[int] = 1024,
        upcast_attention: bool = False,
        use_linear_projection: Optional[bool] = True,
    ):
        super().__init__()

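        # Before the midblock is applied, information is concatenated from base to ctrl.
        # Concatenation does not require a change in the number of channels.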
        self.base_to_ctrl = make_zero_conv(base_channels, base_channels)

        self.base_midblock = UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=base_channels,
            temb_channels=temb_channels,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=base_num_attention_heads,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )

        self.ctrl_midblock = UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=ctrl_channels + base_channels,
            out_channels=ctrl_channels,
            temb_channels=temb_channels,
            # the number of norm groups must divide both in_channels and out_channels
            resnet_groups=find_largest_factor(
                gcd(ctrl_channels, ctrl_channels + base_channels), ctrl_max_norm_num_groups
            ),
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=ctrl_num_attention_heads,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )

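        # After the midblock is applied, information is added from ctrl to base.
        # Addition requires a change in the number of channels, hence the zero convolution.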
        self.ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)

        self.gradient_checkpointing = False

    @classmethod
    def from_modules(
        cls,
        base_midblock: UNetMidBlock2DCrossAttn,
        ctrl_midblock: MidBlockControlNetXSAdapter,
    ):
        base_to_ctrl = ctrl_midblock.base_to_ctrl
        ctrl_to_base = ctrl_midblock.ctrl_to_base
        ctrl_midblock = ctrl_midblock.midblock

        # get params
        def get_first_cross_attention(midblock):
            return midblock.attentions[0].transformer_blocks[0].attn2

        base_channels = ctrl_to_base.out_channels
        ctrl_channels = ctrl_to_base.in_channels
        transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks)
        temb_channels = base_midblock.resnets[0].time_emb_proj.in_features
        num_groups = base_midblock.resnets[0].norm1.num_groups
        ctrl_num_groups = ctrl_midblock.resnets[0].norm1.num_groups
        base_num_attention_heads = get_first_cross_attention(base_midblock).heads
        ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
        cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
        upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
        use_linear_projection = base_midblock.attentions[0].use_linear_projection

        # create model
        model = cls(
            base_channels=base_channels,
            ctrl_channels=ctrl_channels,
            temb_channels=temb_channels,
            norm_num_groups=num_groups,
            ctrl_max_norm_num_groups=ctrl_num_groups,
            transformer_layers_per_block=transformer_layers_per_block,
            base_num_attention_heads=base_num_attention_heads,
            ctrl_num_attention_heads=ctrl_num_attention_heads,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            use_linear_projection=use_linear_projection,
        )

        # load the weights from the passed modules
        model.base_to_ctrl.load_state_dict(base_to_ctrl.state_dict())
        model.base_midblock.load_state_dict(base_midblock.state_dict())
        model.ctrl_midblock.load_state_dict(ctrl_midblock.state_dict())
        model.ctrl_to_base.load_state_dict(ctrl_to_base.state_dict())

        return model
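
    # Illustrative usage sketch (hypothetical variable names, not part of this
    # module): `from_modules` is meant to be driven by the surrounding
    # UNetControlNetXSModel, which pairs each base block with its adapter
    # counterpart, roughly:
    #
    #     mid_block = ControlNetXSCrossAttnMidBlock2D.from_modules(
    #         base_unet.mid_block,  # a UNetMidBlock2DCrossAttn
    #         adapter_mid_block,    # a MidBlockControlNetXSAdapter
    #     )
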
    def freeze_base_params(self) -> None:
        """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else
        unfrozen for fine-tuning."""
        # Unfreeze everything first, ...
        for param in self.parameters():
            param.requires_grad = True

        # ... then freeze the part that came from the base model.
        for param in self.base_midblock.parameters():
            param.requires_grad = False
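
    # Illustrative fine-tuning sketch (not part of this module): after freezing the
    # base weights, only the control-side parameters remain trainable, so an
    # optimizer could be built over them like this:
    #
    #     block.freeze_base_params()
    #     trainable_params = [p for p in block.parameters() if p.requires_grad]
    #     optimizer = torch.optim.AdamW(trainable_params, lr=1e-5)
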
    def forward(
        self,
        hidden_states_base: Tensor,
        temb: Tensor,
        encoder_hidden_states: Tensor,
        hidden_states_ctrl: Optional[Tensor] = None,
        conditioning_scale: Optional[float] = 1.0,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[Tensor] = None,
        encoder_attention_mask: Optional[Tensor] = None,
        apply_control: bool = True,
    ) -> Tuple[Tensor, Tensor]:
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        h_base = hidden_states_base
        h_ctrl = hidden_states_ctrl

        joint_args = {
            "temb": temb,
            "encoder_hidden_states": encoder_hidden_states,
            "attention_mask": attention_mask,
            "cross_attention_kwargs": cross_attention_kwargs,
            "encoder_attention_mask": encoder_attention_mask,
        }

        if apply_control:
            h_ctrl = torch.cat([h_ctrl, self.base_to_ctrl(h_base)], dim=1)  # concat base -> ctrl
        h_base = self.base_midblock(h_base, **joint_args)  # apply base midblock
        if apply_control:
            h_ctrl = self.ctrl_midblock(h_ctrl, **joint_args)  # apply ctrl midblock
            h_base = h_base + self.ctrl_to_base(h_ctrl) * conditioning_scale  # add ctrl -> base

        return h_base, h_ctrl


class ControlNetXSCrossAttnUpBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        ctrl_skip_channels: List[int],
        temb_channels: int,
        norm_num_groups: int = 32,
        resolution_idx: Optional[int] = None,
        has_crossattn=True,
        transformer_layers_per_block: int = 1,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1024,
        add_upsample: bool = True,
        upcast_attention: bool = False,
        use_linear_projection: Optional[bool] = True,
    ):
        super().__init__()
        resnets = []
        attentions = []
        ctrl_to_base = []

        num_layers = 3  # fixed at 3 subblocks per up block (SD and SDXL layouts)

        self.has_cross_attention = has_crossattn
        self.num_attention_heads = num_attention_heads

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels))

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    groups=norm_num_groups,
                )
            )

            if has_crossattn:
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block[i],
                        cross_attention_dim=cross_attention_dim,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                        norm_num_groups=norm_num_groups,
                    )
                )

        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions) if has_crossattn else [None] * num_layers
        self.ctrl_to_base = nn.ModuleList(ctrl_to_base)

        if add_upsample:
            self.upsamplers = Upsample2D(out_channels, use_conv=True, out_channels=out_channels)
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    @classmethod
    def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock: UpBlockControlNetXSAdapter):
        ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base

        # get params
        def get_first_cross_attention(block):
            return block.attentions[0].transformer_blocks[0].attn2

        out_channels = base_upblock.resnets[0].out_channels
        in_channels = base_upblock.resnets[-1].in_channels - out_channels
        prev_output_channel = base_upblock.resnets[0].in_channels - out_channels
        ctrl_skip_channels = [c.in_channels for c in ctrl_to_base_skip_connections]
        temb_channels = base_upblock.resnets[0].time_emb_proj.in_features
        num_groups = base_upblock.resnets[0].norm1.num_groups
        resolution_idx = base_upblock.resolution_idx
        if hasattr(base_upblock, "attentions"):
            has_crossattn = True
            transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks)
            num_attention_heads = get_first_cross_attention(base_upblock).heads
            cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
            upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
            use_linear_projection = base_upblock.attentions[0].use_linear_projection
        else:
            has_crossattn = False
            transformer_layers_per_block = None
            num_attention_heads = None
            cross_attention_dim = None
            upcast_attention = None
            use_linear_projection = None
        add_upsample = base_upblock.upsamplers is not None

        # create model
        model = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            ctrl_skip_channels=ctrl_skip_channels,
            temb_channels=temb_channels,
            norm_num_groups=num_groups,
            resolution_idx=resolution_idx,
            has_crossattn=has_crossattn,
            transformer_layers_per_block=transformer_layers_per_block,
            num_attention_heads=num_attention_heads,
            cross_attention_dim=cross_attention_dim,
            add_upsample=add_upsample,
            upcast_attention=upcast_attention,
            use_linear_projection=use_linear_projection,
        )

        # load the weights from the passed modules
        model.resnets.load_state_dict(base_upblock.resnets.state_dict())
        if has_crossattn:
            model.attentions.load_state_dict(base_upblock.attentions.state_dict())
        if add_upsample:
            model.upsamplers.load_state_dict(base_upblock.upsamplers[0].state_dict())
        model.ctrl_to_base.load_state_dict(ctrl_to_base_skip_connections.state_dict())

        return model

    def freeze_base_params(self) -> None:
        """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else
        unfrozen for fine-tuning."""
        # Unfreeze everything first, ...
        for param in self.parameters():
            param.requires_grad = True

        # ... then freeze the parts that came from the base model.
        base_parts = [self.resnets]
        if isinstance(self.attentions, nn.ModuleList):  # attentions can be a plain list of Nones (no cross-attention)
            base_parts.append(self.attentions)
        if self.upsamplers is not None:
            base_parts.append(self.upsamplers)
        for part in base_parts:
            for param in part.parameters():
                param.requires_grad = False

    def forward(
        self,
        hidden_states: Tensor,
        res_hidden_states_tuple_base: Tuple[Tensor, ...],
        res_hidden_states_tuple_ctrl: Tuple[Tensor, ...],
        temb: Tensor,
        encoder_hidden_states: Optional[Tensor] = None,
        conditioning_scale: Optional[float] = 1.0,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[Tensor] = None,
        upsample_size: Optional[int] = None,
        encoder_attention_mask: Optional[Tensor] = None,
        apply_control: bool = True,
    ) -> Tensor:
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        is_freeu_enabled = (
            getattr(self, "s1", None)
            and getattr(self, "s2", None)
            and getattr(self, "b1", None)
            and getattr(self, "b2", None)
        )

        def create_custom_forward(module, return_dict=None):
            def custom_forward(*inputs):
                if return_dict is not None:
                    return module(*inputs, return_dict=return_dict)
                else:
                    return module(*inputs)

            return custom_forward

        def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
            # FreeU rescales skip connections; it only applies at the resolutions it was configured for
            if is_freeu_enabled:
                return apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    res_h_base,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )
            else:
                return hidden_states, res_h_base

        for resnet, attn, c2b, res_h_base, res_h_ctrl in zip(
            self.resnets,
            self.attentions,
            self.ctrl_to_base,
            reversed(res_hidden_states_tuple_base),
            reversed(res_hidden_states_tuple_ctrl),
        ):
            # add ctrl -> base
            if apply_control:
                hidden_states += c2b(res_h_ctrl) * conditioning_scale

            hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
            hidden_states = torch.cat([hidden_states, res_h_base], dim=1)

            if self.training and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            if attn is not None:
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

        if self.upsamplers is not None:
            hidden_states = self.upsamplers(hidden_states, upsample_size)

        return hidden_states


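# The helpers below implement the zero-initialized 1x1 convolutions ("zero convs")
# that connect the two streams. Because every parameter starts at zero, a freshly
# initialized control branch contributes nothing, so the model initially behaves
# exactly like the base UNet. A minimal sanity check (illustrative only, not part
# of this module):
#
#     conv = make_zero_conv(in_channels=4, out_channels=8)
#     x = torch.randn(1, 4, 16, 16)
#     assert torch.all(conv(x) == 0)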
def make_zero_conv(in_channels, out_channels=None):
    return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


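# `find_largest_factor` returns the largest divisor of `number` that is at most
# `max_factor`. It is used above to pick a valid GroupNorm group count, which must
# evenly divide the channel count. Worked examples (illustrative):
#
#     find_largest_factor(320, 32)  # -> 32, since 320 % 32 == 0
#     find_largest_factor(340, 32)  # -> 20, the largest divisor of 340 that is <= 32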
def find_largest_factor(number, max_factor):
    factor = max_factor
    if factor >= number:
        return number
    while factor != 0:
        residual = number % factor
        if residual == 0:
            return factor
        factor -= 1