# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Changes were made to this source code by Yuwei Guo.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.utils import BaseOutput, logging
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F

from .resnet import InflatedConv3d
from .unet_blocks import (
    CrossAttnDownBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    get_down_block,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class ImageControlNetOutput(BaseOutput):
    """Container for the residuals produced by `ImageControlNetModel.forward`."""

    # One residual per entry of `controlnet_down_blocks`, in down-path order.
    down_block_res_samples: Tuple[torch.Tensor]
    # Residual produced by `controlnet_mid_block` from the mid-block output.
    mid_block_res_sample: torch.Tensor


class ImageControlNetConditioningEmbedding(nn.Module):
    """Convolutional encoder that maps the conditioning input to feature channels.

    The stack is `conv_in`, then for every consecutive pair of
    `block_out_channels` a same-width conv followed by a stride-2 conv, and
    finally a zero-initialised `conv_out`. SiLU is applied after every layer
    except `conv_out`.

    Args:
        conditioning_embedding_channels: Output channels of `conv_out`.
        conditioning_channels: Input channels of `conv_in`.
        block_out_channels: Channel widths of the intermediate stages.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
    ):
        super().__init__()

        self.conv_in = InflatedConv3d(
            conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
        )

        self.blocks = nn.ModuleList([])
        for channel_in, channel_out in zip(
            block_out_channels[:-1], block_out_channels[1:]
        ):
            self.blocks.append(
                InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
            )
            # The stride-2 conv changes the channel width for the next stage.
            self.blocks.append(
                InflatedConv3d(
                    channel_in, channel_out, kernel_size=3, padding=1, stride=2
                )
            )

        # Zero-initialised so the embedding contributes nothing at the start.
        self.conv_out = zero_module(
            InflatedConv3d(
                block_out_channels[-1],
                conditioning_embedding_channels,
                kernel_size=3,
                padding=1,
            )
        )

    def forward(self, conditioning):
        """Encode `conditioning` through the conv stack and return the embedding."""
        embedding = F.silu(self.conv_in(conditioning))

        for block in self.blocks:
            embedding = F.silu(block(embedding))

        return self.conv_out(embedding)


class ImageControlNetModel(ModelMixin, ConfigMixin):
    """ControlNet-style encoder that emits residuals for a 3D (inflated) UNet.

    The model runs the down path and mid block of a UNet on
    `sample + controlnet_cond_embedding(controlnet_cond)` and passes every
    intermediate activation through a zero-initialised 1x1 convolution. Those
    outputs are returned as residuals.
    """

    _supports_gradient_checkpointing = True

    # NOTE: `motion_module_kwargs` keeps its dict default on purpose:
    # `register_to_config` records the signature defaults, so replacing it with
    # `None` would change what gets written to saved configs. The dict is not
    # mutated anywhere in this file.
    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 3,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
        use_motion_module=True,
        motion_module_resolutions=(1, 2, 4, 8),
        motion_module_mid_block=False,
        motion_module_type="Vanilla",
        motion_module_kwargs={
            "num_attention_heads": 8,
            "num_transformer_block": 1,
            "attention_block_types": ["Temporal_Self"],
            "temporal_position_encoding": True,
            "temporal_position_encoding_max_len": 32,
            "temporal_attention_dim_div": 1,
            "causal_temporal_attention": False,
        },
        concate_conditioning_mask: bool = True,
        use_simplified_condition_embedding: bool = False,
        set_noisy_sample_input_to_zero: bool = False,
    ):
        """Build the down path, mid block and zero-initialised output convs.

        Most arguments are forwarded unchanged to `get_down_block` and
        `UNetMidBlock3DCrossAttn`. Arguments handled directly here:

        Args:
            concate_conditioning_mask: When true, the conditioning embedding is
                built with one extra input channel and `forward` concatenates
                `conditioning_mask` onto `controlnet_cond`.
            use_simplified_condition_embedding: When true, the conditioning
                embedding is a single zero-initialised conv instead of
                `ImageControlNetConditioningEmbedding`.
            set_noisy_sample_input_to_zero: When true, `forward` replaces
                `sample` with zeros before processing.
            use_motion_module / motion_module_resolutions: A down block at
                index `i` gets a motion module when `use_motion_module` is true
                and `2**i` is in `motion_module_resolutions`.
            motion_module_mid_block: Whether the mid block gets a motion module
                (only when `use_motion_module` is also true).

        Raises:
            ValueError: If `block_out_channels`, `only_cross_attention` or
                `num_attention_heads` do not match `down_block_types` in
                length, or if `class_embed_type == "projection"` without
                `projection_class_embeddings_input_dim`.
        """
        super().__init__()

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in
        # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too
        # backwards breaking which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(
            only_cross_attention
        ) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
            down_block_types
        ):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        # input
        self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero

        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = InflatedConv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=conv_in_kernel,
            padding=conv_in_padding,
        )

        # The mask is concatenated on the channel axis in `forward`, so the
        # conditioning embedding needs one additional input channel.
        if concate_conditioning_mask:
            conditioning_channels = conditioning_channels + 1
        self.concate_conditioning_mask = concate_conditioning_mask

        # control net conditioning embedding
        if use_simplified_condition_embedding:
            self.controlnet_cond_embedding = zero_module(
                InflatedConv3d(
                    conditioning_channels,
                    block_out_channels[0],
                    kernel_size=conv_in_kernel,
                    padding=conv_in_padding,
                )
            )
        else:
            self.controlnet_cond_embedding = ImageControlNetConditioningEmbedding(
                conditioning_embedding_channels=block_out_channels[0],
                block_out_channels=conditioning_embedding_out_channels,
                conditioning_channels=conditioning_channels,
            )
        self.use_simplified_condition_embedding = use_simplified_condition_embedding

        # time
        time_embed_dim = block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
        )

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal
            # embeddings. As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(
                projection_class_embeddings_input_dim, time_embed_dim
            )
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.controlnet_down_blocks = nn.ModuleList([])

        # Broadcast scalar settings to one value per down block.
        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]

        # Output conv for the `conv_in` activation (first residual).
        self.controlnet_down_blocks.append(self._zero_conv(output_channel))

        for i, down_block_type in enumerate(down_block_types):
            res = 2**i
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=(
                    attention_head_dim[i]
                    if attention_head_dim[i] is not None
                    else output_channel
                ),
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                use_inflated_groupnorm=True,
                use_motion_module=use_motion_module
                and (res in motion_module_resolutions),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.down_blocks.append(down_block)

            # One output conv per layer of the block ...
            for _ in range(layers_per_block):
                self.controlnet_down_blocks.append(self._zero_conv(output_channel))

            # ... plus one for the downsampler, which the final block lacks.
            if not is_final_block:
                self.controlnet_down_blocks.append(self._zero_conv(output_channel))

        # mid
        mid_block_channel = block_out_channels[-1]

        self.controlnet_mid_block = self._zero_conv(mid_block_channel)

        self.mid_block = UNetMidBlock3DCrossAttn(
            in_channels=mid_block_channel,
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift=resnet_time_scale_shift,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=num_attention_heads[-1],
            resnet_groups=norm_num_groups,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
            use_inflated_groupnorm=True,
            use_motion_module=use_motion_module and motion_module_mid_block,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )

    @staticmethod
    def _zero_conv(channels: int) -> nn.Module:
        """Return a zero-initialised 1x1 `InflatedConv3d` mapping `channels` to `channels`."""
        return zero_module(InflatedConv3d(channels, channels, kernel_size=1))

    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
        load_weights_from_unet: bool = True,
        controlnet_additional_kwargs: Optional[dict] = None,
    ):
        """Instantiate a controlnet whose architecture mirrors `unet.config`.

        Args:
            unet: UNet whose config (and optionally weights) is copied.
            controlnet_conditioning_channel_order: Forwarded to `__init__`.
            conditioning_embedding_out_channels: Forwarded to `__init__`.
            load_weights_from_unet: When true, copy the weights of `conv_in`,
                `time_proj`, `time_embedding`, `class_embedding` (if present),
                `down_blocks` and `mid_block` from `unet`, skipping entries
                dropped by `image_layer_filter`.
            controlnet_additional_kwargs: Extra keyword arguments for
                `__init__`; `None` means no extra arguments.

        Returns:
            The newly created controlnet.

        Raises:
            AssertionError: If any copied state dict contains keys that the
                controlnet module does not have.
        """
        controlnet = cls(
            in_channels=unet.config.in_channels,
            flip_sin_to_cos=unet.config.flip_sin_to_cos,
            freq_shift=unet.config.freq_shift,
            down_block_types=unet.config.down_block_types,
            only_cross_attention=unet.config.only_cross_attention,
            block_out_channels=unet.config.block_out_channels,
            layers_per_block=unet.config.layers_per_block,
            downsample_padding=unet.config.downsample_padding,
            mid_block_scale_factor=unet.config.mid_block_scale_factor,
            act_fn=unet.config.act_fn,
            norm_num_groups=unet.config.norm_num_groups,
            norm_eps=unet.config.norm_eps,
            cross_attention_dim=unet.config.cross_attention_dim,
            attention_head_dim=unet.config.attention_head_dim,
            num_attention_heads=unet.config.num_attention_heads,
            use_linear_projection=unet.config.use_linear_projection,
            class_embed_type=unet.config.class_embed_type,
            num_class_embeds=unet.config.num_class_embeds,
            upcast_attention=unet.config.upcast_attention,
            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
            **(controlnet_additional_kwargs or {}),
        )

        if load_weights_from_unet:
            shared_modules = ["conv_in", "time_proj", "time_embedding"]
            if controlnet.class_embedding is not None:
                shared_modules.append("class_embedding")
            shared_modules += ["down_blocks", "mid_block"]

            for module_name in shared_modules:
                source_state = cls.image_layer_filter(
                    getattr(unet, module_name).state_dict()
                )
                # Missing keys are tolerated (strict=False); unexpected keys
                # mean the two architectures do not line up.
                _, unexpected = getattr(controlnet, module_name).load_state_dict(
                    source_state, strict=False
                )
                assert (
                    len(unexpected) == 0
                ), f"Unexpected keys when loading `{module_name}` from unet: {unexpected}"

        return controlnet

    @staticmethod
    def image_layer_filter(state_dict):
        """Return a copy of `state_dict` without motion-module and LoRA entries.

        An entry is dropped when its name contains `"motion_modules."` or
        `"lora"`.
        """
        return {
            name: param
            for name, param in state_dict.items()
            if "motion_modules." not in name and "lora" not in name
        }

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.

        Raises:
            ValueError: If a list of the wrong length is given, or if any slice size exceeds the head dimension
                of its layer.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        # A scalar applies to every sliceable layer.
        if not isinstance(slice_size, list):
            slice_size = num_sliceable_layers * [slice_size]

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for size, dim in zip(slice_size, sliceable_head_dims):
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(
            module: torch.nn.Module, slice_size: List[int]
        ):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        # Reversed so that `pop()` hands sizes out in traversal order.
        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle `gradient_checkpointing` on the 3D down blocks.

        The check uses the 3D block classes imported by this module; the 2D
        names referenced previously were never imported and raised `NameError`.
        """
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.FloatTensor,
        conditioning_mask: Optional[torch.FloatTensor] = None,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ImageControlNetOutput, Tuple]:
        """Compute the controlnet residuals for `sample`.

        Args:
            sample: Noisy input; replaced by zeros when the model was built
                with `set_noisy_sample_input_to_zero=True`.
            timestep: Scalar or tensor timestep(s); repeated to the batch size
                of `sample`.
            encoder_hidden_states: Conditioning for cross-attention blocks;
                repeated along dim 0 to the batch size of `sample`.
            controlnet_cond: Conditioning input for the conditioning embedding.
            conditioning_mask: Concatenated to `controlnet_cond` on dim 1;
                required when the model was built with
                `concate_conditioning_mask=True`.
            conditioning_scale: Multiplier applied to every residual.
            class_labels: Required when the model has a class embedding.
            attention_mask: Optional mask, converted to an additive bias
                (`(1 - mask) * -10000.0`).
            cross_attention_kwargs: Accepted for interface compatibility;
                currently not forwarded to the blocks.
            guess_mode: When true (and `global_pool_conditions` is false),
                residuals are scaled by a log-spaced ramp from 0.1 to 1.0.
            return_dict: Return an `ImageControlNetOutput` instead of a tuple.

        Returns:
            `ImageControlNetOutput`, or `(down_block_res_samples,
            mid_block_res_sample)` when `return_dict` is false.

        Raises:
            ValueError: If `class_labels` or `conditioning_mask` is missing
                although the model configuration requires it.
        """
        # set input noise to zero
        if self.set_noisy_sample_input_to_zero:
            # `zeros_like` already matches the device and dtype of `sample`.
            sample = torch.zeros_like(sample)

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0])
        encoder_hidden_states = encoder_hidden_states.repeat(
            sample.shape[0] // encoder_hidden_states.shape[0], 1, 1
        )

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError(
                    "class_labels should be provided when num_class_embeds > 0"
                )

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # 2. pre-process
        sample = self.conv_in(sample)

        if self.concate_conditioning_mask:
            if conditioning_mask is None:
                # Without this guard `torch.cat` fails with an opaque TypeError.
                raise ValueError(
                    "`conditioning_mask` must be provided when the model is configured with"
                    " `concate_conditioning_mask=True`."
                )
            controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)

        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)

        sample = sample + controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if (
                hasattr(downsample_block, "has_cross_attention")
                and downsample_block.has_cross_attention
            ):
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    # cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                # cross_attention_kwargs=cross_attention_kwargs,
            )

        # 5. controlnet blocks
        down_block_res_samples = tuple(
            controlnet_block(res_sample)
            for res_sample, controlnet_block in zip(
                down_block_res_samples, self.controlnet_down_blocks
            )
        )
        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            scales = torch.logspace(
                -1, 0, len(down_block_res_samples) + 1, device=sample.device
            )  # 0.1 to 1.0

            scales = scales * conditioning_scale
            down_block_res_samples = [
                res_sample * scale
                for res_sample, scale in zip(down_block_res_samples, scales)
            ]
            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
        else:
            down_block_res_samples = [
                res_sample * conditioning_scale for res_sample in down_block_res_samples
            ]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            # NOTE(review): dims (2, 3) are the spatial axes of a 4-D image
            # tensor. If the activations here are 5-D video tensors, these
            # would be different axes -- confirm against InflatedConv3d.
            down_block_res_samples = [
                torch.mean(res_sample, dim=(2, 3), keepdim=True)
                for res_sample in down_block_res_samples
            ]
            mid_block_res_sample = torch.mean(
                mid_block_res_sample, dim=(2, 3), keepdim=True
            )

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ImageControlNetOutput(
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
        )

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name. Children whose name contains `"temporal_transformer"` are skipped.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(
            name: str,
            module: torch.nn.Module,
            processors: Dict[str, AttentionProcessor],
        ):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                if "temporal_transformer" not in sub_name:
                    fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            if "temporal_transformer" not in name:
                fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
    ):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        if isinstance(processor, dict):
            # Entries are popped while walking the modules; work on a shallow
            # copy so the caller's dict is left intact.
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                if "temporal_transformer" not in sub_name:
                    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            if "temporal_transformer" not in name:
                fn_recursive_attn_processor(name, module, processor)


def zero_module(module):
    """Set every parameter of `module` to zero in place and return `module`."""
    for p in module.parameters():
        nn.init.zeros_(p)
    return module