diff --git "a/models/unet_2d_blocks.py" "b/models/unet_2d_blocks.py" --- "a/models/unet_2d_blocks.py" +++ "b/models/unet_2d_blocks.py" @@ -1,4 +1,4 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. +# Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,12 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Optional, Tuple + import numpy as np import torch +import torch.nn.functional as F from torch import nn -from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel -from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, Upsample2D +from diffusers.utils import is_torch_version, logging +from diffusers.models.attention import AdaGroupNorm +from models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from models.dual_transformer_2d import DualTransformer2DModel +from models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D +from models.transformer_2d import Transformer2DModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name def get_down_block( @@ -28,16 +38,30 @@ def get_down_block( add_downsample, resnet_eps, resnet_act_fn, - attn_num_head_channels, + transformer_layers_per_block=1, + num_attention_heads=None, resnet_groups=None, cross_attention_dim=None, downsample_padding=None, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + downsample_type=None, ): - down_block_type = down_block_type[7:] if down_block_type.startswith( - "UNetRes") else down_block_type + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." 
+ ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": return DownBlock2D( num_layers=num_layers, @@ -49,26 +73,46 @@ def get_down_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, ) elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' return AttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, - add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for CrossAttnDownBlock2D") + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") return CrossAttnDownBlock2D( num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -78,10 +122,32 @@ def get_down_block( resnet_groups=resnet_groups, downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, + num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, ) elif down_block_type == "SkipDownBlock2D": return SkipDownBlock2D( @@ -93,6 +159,7 @@ def get_down_block( resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "AttnSkipDownBlock2D": return AttnSkipDownBlock2D( @@ -103,8 +170,8 @@ def get_down_block( 
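# Illustrative sketch (not part of the patch): how the refactored `get_down_block` above is
# expected to be called after the rename away from `attn_num_head_channels`. Cross-attention
# blocks consume `num_attention_heads`, self-attention blocks consume `attention_head_dim`,
# and omitting `attention_head_dim` falls back to `num_attention_heads` with a warning.
# The import path assumes this repo's layout; the concrete sizes are made up for the example.
from models.unet_2d_blocks import get_down_block

block = get_down_block(
    "CrossAttnDownBlock2D",
    num_layers=2,
    in_channels=320,
    out_channels=640,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    num_attention_heads=8,
    attention_head_dim=80,   # 8 heads * 80 dims = 640 channels
    cross_attention_dim=768,
)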
add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, - downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "DownEncoderBlock2D": return DownEncoderBlock2D( @@ -116,6 +183,7 @@ def get_down_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "AttnDownEncoderBlock2D": return AttnDownEncoderBlock2D( @@ -127,7 +195,31 @@ def get_down_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, - attn_num_head_channels=attn_num_head_channels, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, ) raise ValueError(f"{down_block_type} does not exist.") @@ -142,15 +234,29 @@ def get_up_block( add_upsample, resnet_eps, resnet_act_fn, - attn_num_head_channels, + transformer_layers_per_block=1, + num_attention_heads=None, resnet_groups=None, cross_attention_dim=None, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + upsample_type=None, ): - up_block_type = up_block_type[7:] if up_block_type.startswith( - "UNetRes") else up_block_type + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." 
+ ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": return UpBlock2D( num_layers=num_layers, @@ -162,13 +268,29 @@ def get_up_block( resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, ) elif up_block_type == "CrossAttnUpBlock2D": if cross_attention_dim is None: - raise ValueError( - "cross_attention_dim must be specified for CrossAttnUpBlock2D") + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") return CrossAttnUpBlock2D( num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, @@ -178,23 +300,52 @@ def get_up_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, + num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, ) elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + return AttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, - add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, - attn_num_head_channels=attn_num_head_channels, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, ) elif up_block_type == "SkipUpBlock2D": return SkipUpBlock2D( @@ -206,6 +357,7 @@ def get_up_block( add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "AttnSkipUpBlock2D": return AttnSkipUpBlock2D( @@ -217,7 +369,8 @@ def get_up_block( 
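# Sketch (not part of the patch) of the `downsample_type` / `upsample_type` defaulting used
# above for AttnDownBlock2D and AttnUpBlock2D: disabling the sampler always wins over a
# requested type, and an unspecified type falls back to the plain conv sampler. The helper
# name below is hypothetical and only mirrors the inline if/else from the dispatch code.
def _resolve_sample_type(add_sampling: bool, sample_type=None):
    if add_sampling is False:
        return None
    return sample_type or "conv"

assert _resolve_sample_type(True) == "conv"              # default: Downsample2D / Upsample2D
assert _resolve_sample_type(True, "resnet") == "resnet"  # ResnetBlock2D with down=True / up=True
assert _resolve_sample_type(False, "resnet") is None     # no sampler is built at all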
add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, - attn_num_head_channels=attn_num_head_channels, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "UpDecoderBlock2D": return UpDecoderBlock2D( @@ -228,6 +381,8 @@ def get_up_block( resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, ) elif up_block_type == "AttnUpDecoderBlock2D": return AttnUpDecoderBlock2D( @@ -238,8 +393,33 @@ def get_up_block( resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, - attn_num_head_channels=attn_num_head_channels, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, ) + raise ValueError(f"{up_block_type} does not exist.") @@ -251,19 +431,17 @@ class UNetMidBlock2D(nn.Module): dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", + resnet_time_scale_shift: str = "default", # default, spatial resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attn_num_head_channels=1, - attention_type="default", + add_attention: bool = True, + attention_head_dim=1, output_scale_factor=1.0, ): super().__init__() - - self.attention_type = attention_type - resnet_groups = resnet_groups if resnet_groups is not None else min( - in_channels // 4, 32) + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention # there is always at least one resnet resnets = [ @@ -282,16 +460,32 @@ class UNetMidBlock2D(nn.Module): ] attentions = [] + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." 
+ ) + attention_head_dim = in_channels + for _ in range(num_layers): - attentions.append( - AttentionBlock( - in_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) ) - ) + else: + attentions.append(None) + resnets.append( ResnetBlock2D( in_channels=in_channels, @@ -310,13 +504,12 @@ class UNetMidBlock2D(nn.Module): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, encoder_states=None): + def forward(self, hidden_states, temb=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - if self.attention_type == "default": - hidden_states = attn(hidden_states) - else: - hidden_states = attn(hidden_states, encoder_states) + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + # Rich-Text: ignore the features hidden_states, _ = resnet(hidden_states, temb) return hidden_states @@ -329,24 +522,24 @@ class UNetMidBlock2DCrossAttn(nn.Module): temb_channels: int, dropout: float = 0.0, num_layers: int = 1, + transformer_layers_per_block: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attn_num_head_channels=1, - attention_type="default", + num_attention_heads=1, output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False, + upcast_attention=False, ): super().__init__() - self.attention_type = attention_type - self.attn_num_head_channels = attn_num_head_channels - resnet_groups = resnet_groups if resnet_groups is not None else min( - in_channels // 4, 32) + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) # there is always at least one resnet resnets = [ @@ -369,20 +562,21 @@ class UNetMidBlock2DCrossAttn(nn.Module): if not dual_cross_attention: attentions.append( Transformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, + num_attention_heads, + in_channels // num_attention_heads, in_channels=in_channels, - num_layers=1, + num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, ) ) else: attentions.append( DualTransformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, + num_attention_heads, + in_channels // num_attention_heads, in_channels=in_channels, num_layers=1, cross_attention_dim=cross_attention_dim, @@ -407,34 +601,151 @@ class UNetMidBlock2DCrossAttn(nn.Module): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else 
head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + # Rich-Text: ignore the features + hidden_states, _ = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + # Rich-Text: ignore the features + hidden_states, _ = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, + ): + super().__init__() + + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ] + attentions = [] + + for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) ) - for attn in self.attentions: - attn._set_attention_slice(slice_size) + self.attentions = nn.ModuleList(attentions) + 
self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - for attn in self.attentions: - attn._set_use_memory_efficient_attention_xformers( - use_memory_efficient_attention_xformers) + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, - text_format_dict={}): - hidden_states, _ = self.resnets[0](hidden_states, temb) + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states, - text_format_dict).sample + # attn + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + # resnet + # Rich-Text: ignore the features hidden_states, _ = resnet(hidden_states, temb) return hidden_states @@ -453,17 +764,21 @@ class AttnDownBlock2D(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attn_num_head_channels=1, - attention_type="default", + attention_head_dim=1, output_scale_factor=1.0, downsample_padding=1, - add_downsample=True, + downsample_type="conv", ): super().__init__() resnets = [] attentions = [] + self.downsample_type = downsample_type - self.attention_type = attention_type + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." 
+ ) + attention_head_dim = out_channels for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels @@ -482,19 +797,24 @@ class AttnDownBlock2D(nn.Module): ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - if add_downsample: + if downsample_type == "conv": self.downsamplers = nn.ModuleList( [ Downsample2D( @@ -502,20 +822,42 @@ class AttnDownBlock2D(nn.Module): ) ] ) + elif downsample_type == "resnet": + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) else: self.downsamplers = None - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, upsample_size=None): output_states = () for resnet, attn in zip(self.resnets, self.attentions): + # Rich-Text: ignore the features hidden_states, _ = resnet(hidden_states, temb) hidden_states = attn(hidden_states) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + if self.downsample_type == "resnet": + hidden_states = downsampler(hidden_states, temb=temb) + else: + hidden_states = downsampler(hidden_states) output_states += (hidden_states,) @@ -530,27 +872,28 @@ class CrossAttnDownBlock2D(nn.Module): temb_channels: int, dropout: float = 0.0, num_layers: int = 1, + transformer_layers_per_block: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attn_num_head_channels=1, + num_attention_heads=1, cross_attention_dim=1280, - attention_type="default", output_scale_factor=1.0, downsample_padding=1, add_downsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, + upcast_attention=False, ): super().__init__() resnets = [] attentions = [] - self.attention_type = attention_type - self.attn_num_head_channels = attn_num_head_channels + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels @@ -571,21 +914,22 @@ class CrossAttnDownBlock2D(nn.Module): if not dual_cross_attention: attentions.append( Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, + num_attention_heads, + out_channels // num_attention_heads, in_channels=out_channels, - num_layers=1, + num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, ) ) else: attentions.append( DualTransformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, + num_attention_heads, + 
out_channels // num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, @@ -608,30 +952,15 @@ class CrossAttnDownBlock2D(nn.Module): self.gradient_checkpointing = False - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) - - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - for attn in self.attentions: - attn._set_use_memory_efficient_attention_xformers( - use_memory_efficient_attention_xformers) - - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, - text_format_dict={}): + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -646,25 +975,43 @@ class CrossAttnDownBlock2D(nn.Module): return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb) + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward( - attn, return_dict=False), hidden_states, encoder_hidden_states, - text_format_dict + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, )[0] else: + # Rich-Text: ignore the features hidden_states, _ = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, - text_format_dict=text_format_dict).sample + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -733,18 +1080,25 @@ class DownBlock2D(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, 
use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: + # Rich-Text: ignore the features hidden_states, _ = resnet(hidden_states, temb) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -800,7 +1154,7 @@ class DownEncoderBlock2D(nn.Module): def forward(self, hidden_states): for resnet in self.resnets: - hidden_states, _ = resnet(hidden_states, temb=None) + hidden_states = resnet(hidden_states, temb=None) if self.downsamplers is not None: for downsampler in self.downsamplers: @@ -821,7 +1175,7 @@ class AttnDownEncoderBlock2D(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attn_num_head_channels=1, + attention_head_dim=1, output_scale_factor=1.0, add_downsample=True, downsample_padding=1, @@ -830,6 +1184,12 @@ class AttnDownEncoderBlock2D(nn.Module): resnets = [] attentions = [] + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( @@ -847,12 +1207,17 @@ class AttnDownEncoderBlock2D(nn.Module): ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -872,6 +1237,7 @@ class AttnDownEncoderBlock2D(nn.Module): def forward(self, hidden_states): for resnet, attn in zip(self.resnets, self.attentions): + # Rich-Text: ignore the features hidden_states, _ = resnet(hidden_states, temb=None) hidden_states = attn(hidden_states) @@ -894,17 +1260,19 @@ class AttnSkipDownBlock2D(nn.Module): resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_pre_norm: bool = True, - attn_num_head_channels=1, - attention_type="default", + attention_head_dim=1, output_scale_factor=np.sqrt(2.0), - downsample_padding=1, add_downsample=True, ): super().__init__() self.attentions = nn.ModuleList([]) self.resnets = nn.ModuleList([]) - self.attention_type = attention_type + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." 
+ ) + attention_head_dim = out_channels for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels @@ -924,11 +1292,17 @@ class AttnSkipDownBlock2D(nn.Module): ) ) self.attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -948,10 +1322,8 @@ class AttnSkipDownBlock2D(nn.Module): down=True, kernel="fir", ) - self.downsamplers = nn.ModuleList( - [FirDownsample2D(out_channels, out_channels=out_channels)]) - self.skip_conv = nn.Conv2d( - 3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) else: self.resnet_down = None self.downsamplers = None @@ -961,6 +1333,7 @@ class AttnSkipDownBlock2D(nn.Module): output_states = () for resnet, attn in zip(self.resnets, self.attentions): + # Rich-Text: ignore the features hidden_states, _ = resnet(hidden_states, temb) hidden_states = attn(hidden_states) output_states += (hidden_states,) @@ -1030,10 +1403,8 @@ class SkipDownBlock2D(nn.Module): down=True, kernel="fir", ) - self.downsamplers = nn.ModuleList( - [FirDownsample2D(out_channels, out_channels=out_channels)]) - self.skip_conv = nn.Conv2d( - 3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) else: self.resnet_down = None self.downsamplers = None @@ -1043,6 +1414,7 @@ class SkipDownBlock2D(nn.Module): output_states = () for resnet in self.resnets: + # Rich-Text: ignore the features hidden_states, _ = resnet(hidden_states, temb) output_states += (hidden_states,) @@ -1058,11 +1430,10 @@ class SkipDownBlock2D(nn.Module): return hidden_states, output_states, skip_sample -class AttnUpBlock2D(nn.Module): +class ResnetDownsampleBlock2D(nn.Module): def __init__( self, in_channels: int, - prev_output_channel: int, out_channels: int, temb_channels: int, dropout: float = 0.0, @@ -1072,25 +1443,18 @@ class AttnUpBlock2D(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attention_type="default", - attn_num_head_channels=1, output_scale_factor=1.0, - add_upsample=True, + add_downsample=True, + skip_time_act=False, ): super().__init__() resnets = [] - attentions = [] - - self.attention_type = attention_type for i in range(num_layers): - res_skip_channels = in_channels if ( - i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - + in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, + in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, @@ -1100,51 +1464,76 @@ class AttnUpBlock2D(nn.Module): non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=resnet_groups, + 
skip_time_act=skip_time_act, ) ) - self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) else: - self.upsamplers = None + self.downsamplers = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat( - [hidden_states, res_hidden_states], dim=1) + self.gradient_checkpointing = False - hidden_states, _ = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) + def forward(self, hidden_states, temb=None): + output_states = () - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: - return hidden_states + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + # Rich-Text: ignore the features + hidden_states, _ = resnet(hidden_states, temb) + output_states = output_states + (hidden_states,) -class CrossAttnUpBlock2D(nn.Module): + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class SimpleCrossAttnDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, - prev_output_channel: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, @@ -1153,30 +1542,29 @@ class CrossAttnUpBlock2D(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attn_num_head_channels=1, + attention_head_dim=1, cross_attention_dim=1280, - attention_type="default", output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, + add_downsample=True, + skip_time_act=False, only_cross_attention=False, + cross_attention_norm=None, ): super().__init__() + + self.has_cross_attention = True + resnets = [] attentions = [] - self.attention_type = attention_type - self.attn_num_head_channels = attn_num_head_channels + self.attention_head_dim = attention_head_dim + self.num_heads = out_channels // self.attention_head_dim for i in range(num_layers): - res_skip_channels = in_channels if ( - i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - + in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( - in_channels=resnet_in_channels + 
res_skip_channels, + in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, @@ -1186,83 +1574,80 @@ class CrossAttnUpBlock2D(nn.Module): non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, ) ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - ) - ) - else: - attentions.append( - DualTransformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, ) + ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) + else: + self.downsamplers = None self.gradient_checkpointing = False - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - for attn in self.attentions: - attn._set_use_memory_efficient_attention_xformers( - use_memory_efficient_attention_xformers) - def forward( self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - upsample_size=None, - text_format_dict={} + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + 
encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat( - [hidden_states, res_hidden_states], dim=1) + output_states = () + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -1274,83 +1659,86 @@ class CrossAttnUpBlock2D(nn.Module): return custom_forward + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward( - attn, return_dict=False), hidden_states, encoder_hidden_states, - text_format_dict + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + mask, + cross_attention_kwargs, )[0] else: + # Rich-Text: ignore the features hidden_states, _ = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, - text_format_dict=text_format_dict).sample - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) - return hidden_states + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + output_states = output_states + (hidden_states,) -class UpBlock2D(nn.Module): + return hidden_states, output_states + + +class KDownBlock2D(nn.Module): def __init__( self, in_channels: int, - prev_output_channel: int, out_channels: int, temb_channels: int, dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + add_downsample=False, ): super().__init__() resnets = [] for i in range(num_layers): - res_skip_channels = in_channels if ( - i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = 
out_channels // resnet_group_size resnets.append( ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, + in_channels=in_channels, out_channels=out_channels, + dropout=dropout, temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, ) ) self.resnets = nn.ModuleList(resnets) - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + if add_downsample: + # YiYi's comments- might be able to use FirDownsample2D, look into details later + self.downsamplers = nn.ModuleList([KDownsample2D()]) else: - self.upsamplers = None + self.downsamplers = None self.gradient_checkpointing = False - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): - for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat( - [hidden_states, res_hidden_states], dim=1) + def forward(self, hidden_states, temb=None): + output_states = () + for resnet in self.resnets: if self.training and self.gradient_checkpointing: def create_custom_forward(module): @@ -1359,78 +1747,163 @@ class UpBlock2D(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: + # Rich-Text: ignore the features hidden_states, _ = resnet(hidden_states, temb) - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + output_states += (hidden_states,) - return hidden_states + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states -class UpDecoderBlock2D(nn.Module): +class KCrossAttnDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, + temb_channels: int, + cross_attention_dim: int, dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor=1.0, - add_upsample=True, + num_layers: int = 4, + resnet_group_size: int = 32, + add_downsample=True, + attention_head_dim: int = 64, + add_self_attention: bool = False, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", ): super().__init__() resnets = [] + attentions = [] + + self.has_cross_attention = True for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size resnets.append( ResnetBlock2D( - in_channels=input_channels, + in_channels=in_channels, out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, dropout=dropout, - 
time_embedding_norm=resnet_time_scale_shift, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + out_channels, + out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + group_size=resnet_group_size, ) ) self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) - if add_upsample: - self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + if add_downsample: + self.downsamplers = nn.ModuleList([KDownsample2D()]) else: - self.upsamplers = None + self.downsamplers = None - def forward(self, hidden_states): - for resnet in self.resnets: - hidden_states, _ = resnet(hidden_states, temb=None) + self.gradient_checkpointing = False - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + output_states = () - return hidden_states + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) -class AttnUpDecoderBlock2D(nn.Module): + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + cross_attention_kwargs, + encoder_attention_mask, + **ckpt_kwargs, + ) + else: + # Rich-Text: ignore the features + hidden_states, _ = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.downsamplers is None: + output_states += (None,) + else: + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class AttnUpBlock2D(nn.Module): def __init__( self, in_channels: int, + prev_output_channel: int, out_channels: int, + temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -1438,22 +1911,31 @@ class AttnUpDecoderBlock2D(nn.Module): resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attn_num_head_channels=1, + attention_head_dim=1, output_scale_factor=1.0, - add_upsample=True, + 
upsample_type="conv", ): super().__init__() resnets = [] attentions = [] + self.upsample_type = upsample_type + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock2D( - in_channels=input_channels, + in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, - temb_channels=None, + temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, @@ -1464,75 +1946,109 @@ class AttnUpDecoderBlock2D(nn.Module): ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - if add_upsample: + if upsample_type == "conv": + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + elif upsample_type == "resnet": self.upsamplers = nn.ModuleList( - [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) else: self.upsamplers = None - def forward(self, hidden_states): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet, attn in zip(self.resnets, self.attentions): - hidden_states, _ = resnet(hidden_states, temb=None) + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + # Rich-Text: ignore the features + hidden_states, _ = resnet(hidden_states, temb) hidden_states = attn(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + if self.upsample_type == "resnet": + hidden_states = upsampler(hidden_states, temb=temb) + else: + hidden_states = upsampler(hidden_states) return hidden_states -class AttnSkipUpBlock2D(nn.Module): +class CrossAttnUpBlock2D(nn.Module): def __init__( self, in_channels: int, - prev_output_channel: int, out_channels: int, + prev_output_channel: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, + transformer_layers_per_block: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", + resnet_groups: int = 32, resnet_pre_norm: bool = True, - attn_num_head_channels=1, - attention_type="default", - output_scale_factor=np.sqrt(2.0), - upsample_padding=1, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, add_upsample=True, + dual_cross_attention=False, + 
use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, ): super().__init__() - self.attentions = nn.ModuleList([]) - self.resnets = nn.ModuleList([]) + resnets = [] + attentions = [] - self.attention_type = attention_type + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads for i in range(num_layers): - res_skip_channels = in_channels if ( - i == num_layers - 1) else out_channels + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels - self.resnets.append( + resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, - groups=min(resnet_in_channels + - res_skip_channels // 4, 32), - groups_out=min(out_channels // 4, 32), + groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, @@ -1540,76 +2056,107 @@ class AttnSkipUpBlock2D(nn.Module): pre_norm=resnet_pre_norm, ) ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) - self.attentions.append( - AttentionBlock( - out_channels, - num_head_channels=attn_num_head_channels, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - ) - ) - - self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) if add_upsample: - self.resnet_up = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - up=True, - kernel="fir", - ) - self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=( - 3, 3), stride=(1, 1), padding=(1, 1)) - self.skip_norm = torch.nn.GroupNorm( - num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True - ) - self.act = nn.SiLU() + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: - self.resnet_up = None - self.skip_conv = None - self.skip_norm = None - self.act = None + self.upsamplers = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): - for resnet in self.resnets: + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = 
None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat( - [hidden_states, res_hidden_states], dim=1) - - hidden_states, _ = resnet(hidden_states, temb) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = self.attentions[0](hidden_states) + if self.training and self.gradient_checkpointing: - if skip_sample is not None: - skip_sample = self.upsampler(skip_sample) - else: - skip_sample = 0 + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) - if self.resnet_up is not None: - skip_sample_states = self.skip_norm(hidden_states) - skip_sample_states = self.act(skip_sample_states) - skip_sample_states = self.skip_conv(skip_sample_states) + return custom_forward - skip_sample = skip_sample + skip_sample_states + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + # Rich-Text: ignore the features + hidden_states, _ = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] - hidden_states = self.resnet_up(hidden_states, temb) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) - return hidden_states, skip_sample + return hidden_states -class SkipUpBlock2D(nn.Module): +class UpBlock2D(nn.Module): def __init__( self, in_channels: int, @@ -1621,28 +2168,25 @@ class SkipUpBlock2D(nn.Module): resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", + resnet_groups: int = 32, resnet_pre_norm: bool = True, - output_scale_factor=np.sqrt(2.0), + output_scale_factor=1.0, add_upsample=True, - upsample_padding=1, ): super().__init__() - self.resnets = nn.ModuleList([]) + resnets = [] for i in range(num_layers): - res_skip_channels = in_channels if ( - i == num_layers - 1) else out_channels + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels - self.resnets.append( + resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, - groups=min( - (resnet_in_channels + res_skip_channels) // 4, 32), - groups_out=min(out_channels // 4, 32), + groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, @@ -1651,205 +2195,1004 @@ class SkipUpBlock2D(nn.Module): ) ) - self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + self.resnets = nn.ModuleList(resnets) + if 
add_upsample: - self.resnet_up = ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=min(out_channels // 4, 32), - groups_out=min(out_channels // 4, 32), - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - use_in_shortcut=True, - up=True, - kernel="fir", - ) - self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=( - 3, 3), stride=(1, 1), padding=(1, 1)) - self.skip_norm = torch.nn.GroupNorm( - num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True - ) - self.act = nn.SiLU() + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: - self.resnet_up = None - self.skip_conv = None - self.skip_norm = None - self.act = None + self.upsamplers = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat( - [hidden_states, res_hidden_states], dim=1) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states, _ = resnet(hidden_states, temb) + if self.training and self.gradient_checkpointing: - if skip_sample is not None: - skip_sample = self.upsampler(skip_sample) - else: - skip_sample = 0 + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) - if self.resnet_up is not None: - skip_sample_states = self.skip_norm(hidden_states) - skip_sample_states = self.act(skip_sample_states) - skip_sample_states = self.skip_conv(skip_sample_states) + return custom_forward - skip_sample = skip_sample + skip_sample_states + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + # Rich-Text: ignore the features + hidden_states, _ = resnet(hidden_states, temb) - hidden_states = self.resnet_up(hidden_states, temb) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) - return hidden_states, skip_sample + return hidden_states -class ResnetBlock2D(nn.Module): +class UpDecoderBlock2D(nn.Module): def __init__( self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout=0.0, - temb_channels=512, - groups=32, - groups_out=None, - pre_norm=True, - eps=1e-6, - non_linearity="swish", - time_embedding_norm="default", - kernel=None, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, output_scale_factor=1.0, - use_in_shortcut=None, - up=False, - down=False, + add_upsample=True, + temb_channels=None, ): super().__init__() - self.pre_norm = pre_norm - self.pre_norm = True - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels 
= out_channels - self.use_conv_shortcut = conv_shortcut - self.time_embedding_norm = time_embedding_norm - self.up = up - self.down = down - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - self.norm1 = torch.nn.GroupNorm( - num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - - self.conv1 = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1) - - if temb_channels is not None: - if self.time_embedding_norm == "default": - time_emb_proj_out_channels = out_channels - elif self.time_embedding_norm == "scale_shift": - time_emb_proj_out_channels = out_channels * 2 - else: - raise ValueError( - f"unknown time_embedding_norm : {self.time_embedding_norm} ") - - self.time_emb_proj = torch.nn.Linear( - temb_channels, time_emb_proj_out_channels) - else: - self.time_emb_proj = None - - self.norm2 = torch.nn.GroupNorm( - num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1) - - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - - self.upsample = self.downsample = None - if self.up: - if kernel == "fir": - fir_kernel = (1, 3, 3, 1) - self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) - elif kernel == "sde_vp": - self.upsample = partial( - F.interpolate, scale_factor=2.0, mode="nearest") - else: - self.upsample = Upsample2D(in_channels, use_conv=False) - elif self.down: - if kernel == "fir": - fir_kernel = (1, 3, 3, 1) - self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) - elif kernel == "sde_vp": - self.downsample = partial( - F.avg_pool2d, kernel_size=2, stride=2) - else: - self.downsample = Downsample2D( - in_channels, use_conv=False, padding=1, name="op") - - self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut - - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, input_tensor, temb, inject_states=None): - hidden_states = input_tensor + resnets = [] - hidden_states = self.norm1(hidden_states) - hidden_states = self.nonlinearity(hidden_states) + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels - if self.upsample is not None: - # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - input_tensor = input_tensor.contiguous() - hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor) - hidden_states = self.upsample(hidden_states) - elif self.downsample is not None: - input_tensor = self.downsample(input_tensor) - hidden_states = self.downsample(hidden_states) + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) - hidden_states = self.conv1(hidden_states) + self.resnets = nn.ModuleList(resnets) - if temb is not None: - temb = self.time_emb_proj(self.nonlinearity(temb))[ - :, :, None, None] + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None - if temb is not None and self.time_embedding_norm == "default": - hidden_states = hidden_states + temb + def forward(self, hidden_states, temb=None): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb) - hidden_states = self.norm2(hidden_states) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) - if temb is not None and self.time_embedding_norm == "scale_shift": - scale, shift = torch.chunk(temb, 2, dim=1) - hidden_states = hidden_states * (1 + scale) + shift + return hidden_states - hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) +class AttnUpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=1.0, + add_upsample=True, + temb_channels=None, + ): + super().__init__() + resnets = [] + attentions = [] - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." 
+ ) + attention_head_dim = out_channels - if inject_states is not None: - output_tensor = (input_tensor + inject_states) / \ - self.output_scale_factor - else: - output_tensor = (input_tensor + hidden_states) / \ - self.output_scale_factor + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels - return output_tensor, hidden_states + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states, temb=None): + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=temb) + hidden_states = attn(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnSkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attention_head_dim=1, + output_scale_factor=np.sqrt(2.0), + add_upsample=True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(resnet_in_channels + res_skip_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." 
+ ) + attention_head_dim = out_channels + + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + # Rich-Text: ignore the features + hidden_states, _ = resnet(hidden_states, temb) + + hidden_states = self.attentions[0](hidden_states) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class SkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor=np.sqrt(2.0), + add_upsample=True, + upsample_padding=1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + 
groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + # Rich-Text: ignore the features + hidden_states, _ = resnet(hidden_states, temb) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class ResnetUpsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + skip_time_act=False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return 
module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + # Rich-Text: ignore the features + hidden_states, _ = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class SimpleCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + skip_time_act=False, + only_cross_attention=False, + cross_attention_norm=None, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if 
attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + # resnet + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + mask, + cross_attention_kwargs, + )[0] + else: + # Rich-Text: ignore the features + hidden_states, _ = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class KUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 5, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: Optional[int] = 32, + add_upsample=True, + ): + super().__init__() + resnets = [] + k_in_channels = 2 * out_channels + k_out_channels = in_channels + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=k_out_channels if (i == num_layers - 1) else out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if 
is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + # Rich-Text: ignore the features + hidden_states, _ = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class KCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + attention_head_dim=1, # attention dim_head + cross_attention_dim: int = 768, + add_upsample: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + + is_first_block = in_channels == out_channels == temb_channels + is_middle_block = in_channels != out_channels + add_self_attention = True if is_first_block else False + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + # in_channels, and out_channels for the block (k-unet) + k_in_channels = out_channels if is_first_block else 2 * out_channels + k_out_channels = in_channels + + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + if is_middle_block and (i == num_layers - 1): + conv_2d_out_channels = k_out_channels + else: + conv_2d_out_channels = None + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + conv_2d_out_channels=conv_2d_out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + k_out_channels if (i == num_layers - 1) else out_channels, + k_out_channels // attention_head_dim + if (i == num_layers - 1) + else out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + upcast_attention=upcast_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, 
return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + temb, + attention_mask, + cross_attention_kwargs, + encoder_attention_mask, + **ckpt_kwargs, + )[0] + else: + # Rich-Text: ignore the features + hidden_states, _ = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +# can potentially later be renamed to `No-feed-forward` attention +class KAttentionBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + upcast_attention: bool = False, + temb_channels: int = 768, # for ada_group_norm + add_self_attention: bool = False, + cross_attention_norm: Optional[str] = None, + group_size: int = 32, + ): + super().__init__() + self.add_self_attention = add_self_attention + + # 1. Self-Attn + if add_self_attention: + self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + cross_attention_norm=None, + ) + + # 2. 
Cross-Attn
+        self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
+        self.attn2 = Attention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            upcast_attention=upcast_attention,
+            cross_attention_norm=cross_attention_norm,
+        )
+
+    def _to_3d(self, hidden_states, height, width):
+        return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * width, -1)
+
+    def _to_4d(self, hidden_states, height, width):
+        return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, width)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        # TODO: mark emb as non-optional (self.norm2 requires it).
+        #       requires assessing impact of change to positional param interface.
+        emb: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ):
+        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+        # 1. Self-Attention
+        if self.add_self_attention:
+            norm_hidden_states = self.norm1(hidden_states, emb)
+
+            height, width = norm_hidden_states.shape[2:]
+            norm_hidden_states = self._to_3d(norm_hidden_states, height, width)
+
+            attn_output = self.attn1(
+                norm_hidden_states,
+                encoder_hidden_states=None,
+                attention_mask=attention_mask,
+                **cross_attention_kwargs,
+            )
+            attn_output = self._to_4d(attn_output, height, width)
+
+            hidden_states = attn_output + hidden_states
+
+        # 2. Cross-Attention/None
+        norm_hidden_states = self.norm2(hidden_states, emb)
+
+        height, width = norm_hidden_states.shape[2:]
+        norm_hidden_states = self._to_3d(norm_hidden_states, height, width)
+        attn_output = self.attn2(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask,
+            **cross_attention_kwargs,
+        )
+        attn_output = self._to_4d(attn_output, height, width)
+
+        hidden_states = attn_output + hidden_states
+
+        return hidden_states
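
A minimal, hypothetical smoke test (not part of the patch above) may help illustrate how `KAttentionBlock` is meant to be driven: it flattens a 4D feature map into a token sequence for attention and reshapes it back via `_to_3d`/`_to_4d`, with `emb` feeding the `AdaGroupNorm` layers. The import path assumes the file lands at `models/unet_2d_blocks.py` as in the diff header; all channel counts and shapes below are arbitrary, illustrative values.

import torch
from models.unet_2d_blocks import KAttentionBlock

# Illustrative sizes only: 64 feature channels, 8 heads of width 8,
# a 512-dim timestep embedding, and 77 context tokens of width 768.
block = KAttentionBlock(
    dim=64,
    num_attention_heads=8,
    attention_head_dim=8,
    cross_attention_dim=768,
    temb_channels=512,
    add_self_attention=True,
)

x = torch.randn(2, 64, 16, 16)   # (batch, channels, height, width) feature map
emb = torch.randn(2, 512)        # timestep embedding consumed by AdaGroupNorm
ctx = torch.randn(2, 77, 768)    # encoder hidden states (e.g. text embeddings)

with torch.no_grad():
    out = block(x, encoder_hidden_states=ctx, emb=emb)

assert out.shape == x.shape      # the block is shape-preserving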