# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.resnet import Downsample2D, ResnetBlock2D

from einops import rearrange


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class ControlNetOutput(BaseOutput):
    """
    The output of [`ControlNetModel`].

    Args:
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor
            should be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`.
            Output can be used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """

    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor


class Block2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_padding: int = 1,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
        output_states = ()

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states


class IdentityModule(nn.Module):
    """Returns its first positional argument unchanged (or ``None`` if called without arguments)."""

    def __init__(self):
        super(IdentityModule, self).__init__()

    def forward(self, *args):
        if len(args) > 0:
            return args[0]
        else:
            return None


class BasicBlock(nn.Module):
    """A lightweight residual block used in place of ``ResnetBlock2D``.

    Most keyword arguments are accepted only for signature compatibility with
    ``ResnetBlock2D`` (e.g. ``temb_channels``, ``groups``, ``dropout``) and are ignored.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        stride: int = 1,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        skip_time_act: bool = False,
        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
        kernel: Optional[torch.FloatTensor] = None,
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3 if stride != 1 else 1,
                    stride=stride,
                    padding=1 if stride != 1 else 0,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x, *args):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Block2D(nn.Module):
    # NOTE: this second ``Block2D`` definition shadows the one above; it uses
    # ``BasicBlock`` in place of ``ResnetBlock2D`` / ``Downsample2D``.
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_padding: int = 1,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            # Only the last layer is a real residual block; earlier layers are identities.
            resnets.append(
                BasicBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
                if i == num_layers - 1
                else IdentityModule()
            )

        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    # A strided BasicBlock replaces the original Downsample2D.
                    BasicBlock(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        stride=2,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
        output_states = ()

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states


class ControlProject(nn.Module):
    def __init__(self, num_channels, scale=8, is_empty=False) -> None:
        super().__init__()
        # `scale` must be a power of two.
        assert scale and scale & (scale - 1) == 0
        self.is_empty = is_empty
        self.scale = scale
        if not is_empty:
            if scale > 1:
                self.down_scale = nn.AvgPool2d(scale, scale)
            else:
                self.down_scale = nn.Identity()
            self.out = nn.Conv2d(num_channels, num_channels, kernel_size=1, stride=1, bias=False)
            for p in self.out.parameters():
                nn.init.zeros_(p)

    def forward(self, hidden_states: torch.FloatTensor):
        if self.is_empty:
            shape = list(hidden_states.shape)
            shape[-2] = shape[-2] // self.scale
            shape[-1] = shape[-1] // self.scale
            return torch.zeros(shape).to(hidden_states)

        if len(hidden_states.shape) == 5:
            # Video input: fold the frame axis into the batch for the 2D ops.
            B, F, C, H, W = hidden_states.shape
            hidden_states = rearrange(hidden_states, "B F C H W -> (B F) C H W")
            hidden_states = self.down_scale(hidden_states)
            hidden_states = self.out(hidden_states)
            hidden_states = rearrange(hidden_states, "(B F) C H W -> B F C H W", F=F)
        else:
            hidden_states = self.down_scale(hidden_states)
            hidden_states = self.out(hidden_states)
        return hidden_states
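

# -----------------------------------------------------------------------------
# Illustrative shape checks (not part of the module; the tensor sizes below are
# arbitrary placeholders chosen only to show how BasicBlock and ControlProject
# transform their inputs):
#
#   block = BasicBlock(in_channels=128, out_channels=256, stride=2)
#   y = block(torch.randn(1, 128, 64, 64))     # -> torch.Size([1, 256, 32, 32])
#
#   proj = ControlProject(num_channels=256, scale=8)
#   v = torch.randn(2, 16, 256, 64, 64)        # (B, F, C, H, W) video features
#   w = proj(v)                                # -> torch.Size([2, 16, 256, 8, 8])
#
#   empty = ControlProject(num_channels=256, scale=8, is_empty=True)
#   z = empty(torch.randn(2, 256, 64, 64))     # zeros of shape (2, 256, 8, 8)
# -----------------------------------------------------------------------------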


class ControlNetModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: List[int] = [128, 128],
        out_channels: List[int] = [128, 256],
        groups: List[int] = [4, 8],
        time_embed_dim: int = 256,
        final_out_channels: int = 320,
    ):
        super().__init__()

        self.time_proj = Timesteps(128, True, downscale_freq_shift=0)
        self.time_embedding = TimestepEmbedding(128, time_embed_dim)

        # Stem: maps the 3-channel conditioning image to 128 feature channels at half resolution.
        self.embedding = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(2, 64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.GroupNorm(2, 64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.GroupNorm(2, 128),
            nn.ReLU(),
        )

        self.down_res = nn.ModuleList()
        self.down_sample = nn.ModuleList()
        for i in range(len(in_channels)):
            self.down_res.append(
                ResnetBlock2D(
                    in_channels=in_channels[i],
                    out_channels=out_channels[i],
                    temb_channels=time_embed_dim,
                    groups=groups[i],
                ),
            )
            self.down_sample.append(
                Downsample2D(
                    out_channels[i],
                    use_conv=True,
                    out_channels=out_channels[i],
                    padding=1,
                    name="op",
                )
            )

        self.mid_convs = nn.ModuleList()
        self.mid_convs.append(
            nn.Sequential(
                nn.Conv2d(
                    in_channels=out_channels[-1],
                    out_channels=out_channels[-1],
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
                nn.ReLU(),
                nn.GroupNorm(8, out_channels[-1]),
                nn.Conv2d(
                    in_channels=out_channels[-1],
                    out_channels=out_channels[-1],
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
                nn.GroupNorm(8, out_channels[-1]),
            )
        )
        self.mid_convs.append(
            nn.Conv2d(
                in_channels=out_channels[-1],
                out_channels=final_out_channels,
                kernel_size=1,
                stride=1,
            )
        )

        self.scale = 1.0  # nn.Parameter(torch.tensor(1.))

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
    ) -> Union[ControlNetOutput, Tuple]:
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        batch_size = sample.shape[0]
        timesteps = timesteps.expand(batch_size)

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb_batch = self.time_embedding(t_emb)

        # Repeat the embeddings num_video_frames times
        # emb: [batch, channels] -> [batch * frames, channels]
        emb = emb_batch

        sample = self.embedding(sample)

        for res, downsample in zip(self.down_res, self.down_sample):
            sample = res(sample, emb)
            sample = downsample(sample, emb)

        sample = self.mid_convs[0](sample) + sample
        sample = self.mid_convs[1](sample)

        return {
            "out": sample,
            "scale": self.scale,
        }


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
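

# -----------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only, not part of the library API).
# It assumes a diffusers version whose Downsample2D.forward tolerates the extra
# `emb` argument passed in `ControlNetModel.forward` above; the 512x512 hint
# size, batch size, and timestep value are arbitrary placeholders.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    controlnet = ControlNetModel(
        in_channels=[128, 128],
        out_channels=[128, 256],
        groups=[4, 8],
        time_embed_dim=256,
        final_out_channels=320,
    )

    # Dummy conditioning image: batch of 2 RGB hints at 512x512.
    hint = torch.randn(2, 3, 512, 512)

    with torch.no_grad():
        result = controlnet(hint, timestep=10)

    # The stem halves the resolution and each Downsample2D halves it again,
    # so a 512x512 hint yields a 64x64 feature map with `final_out_channels`.
    print(result["out"].shape)  # expected: torch.Size([2, 320, 64, 64])
    print(result["scale"])      # 1.0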