# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.utils import logging
from diffusers.models.normalization import RMSNorm

try:
    # from .dcformer import DCMHAttention
    from .customer_attention_processor import Attention, CustomLiteLAProcessor2_0, CustomerAttnProcessor2_0
except ImportError:
    # from dcformer import DCMHAttention
    from customer_attention_processor import Attention, CustomLiteLAProcessor2_0, CustomerAttnProcessor2_0


logger = logging.get_logger(__name__)


def val2list(x: Union[list, tuple, Any], repeat_time: int = 1) -> list:
    """Return `x` as a list.

    A list or tuple is copied into a new list (`repeat_time` is ignored in
    that case); any other value is repeated `repeat_time` times.
    """
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x for _ in range(repeat_time)]


def val2tuple(x: Union[list, tuple, Any], min_len: int = 1, idx_repeat: int = -1) -> tuple:
    """Return `x` as a tuple with at least `min_len` elements.

    `x` is first converted with `val2list`. If the result is non-empty and
    shorter than `min_len`, copies of the element at `idx_repeat` are inserted
    at position `idx_repeat` until the length reaches `min_len`. An empty
    input is returned as an empty tuple.
    """
    # convert to list first
    x = val2list(x)

    # repeat elements if necessary
    if len(x) > 0:
        # Slice assignment with an empty slice inserts the padding in place;
        # when no padding is needed the inserted list is empty.
        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]

    return tuple(x)


def t2i_modulate(x, shift, scale):
    """Apply the scale-and-shift modulation `x * (1 + scale) + shift`."""
    return x * (1 + scale) + shift


def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
    """Return `kernel_size // 2` for an odd kernel size.

    A tuple of kernel sizes is processed element-wise and returned as a tuple.
    Even kernel sizes trip the assertion below.
    """
    if isinstance(kernel_size, tuple):
        return tuple([get_same_padding(ks) for ks in kernel_size])
    else:
        assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be odd number"
        return kernel_size // 2


class ConvLayer(nn.Module):
    """A `nn.Conv1d` optionally followed by an RMSNorm and a SiLU activation.

    Args:
        in_dim: Number of input channels of the convolution.
        out_dim: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        dilation: Convolution dilation.
        groups: Convolution groups.
        padding: Explicit padding. When `None`, it is derived from
            `kernel_size` via `get_same_padding` and scaled by `dilation`.
        use_bias: Whether the convolution has a bias term.
        norm: Any non-`None` value enables an `RMSNorm` without affine
            parameters; the value itself is not inspected.
        act: Any non-`None` value enables an in-place `nn.SiLU`; the value
            itself is not inspected.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        padding: Union[int, None] = None,
        use_bias=False,
        norm=None,
        act=None,
    ):
        super().__init__()
        if padding is None:
            padding = get_same_padding(kernel_size)
            if isinstance(padding, tuple):
                # `tuple *= int` would repeat the tuple rather than scale its
                # entries, so scale element-wise for tuple kernel sizes.
                padding = tuple(p * dilation for p in padding)
            else:
                padding *= dilation

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.padding = padding
        self.use_bias = use_bias

        self.conv = nn.Conv1d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=use_bias,
        )
        if norm is not None:
            # NOTE(review): no `eps` is passed here, unlike the RMSNorm calls in
            # LinearTransformerBlock — confirm the installed RMSNorm accepts that.
            self.norm = RMSNorm(out_dim, elementwise_affine=False)
        else:
            self.norm = None
        if act is not None:
            self.act = nn.SiLU(inplace=True)
        else:
            self.act = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the convolution, then the optional norm and activation."""
        x = self.conv(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.act is not None:
            x = self.act(x)
        return x


class GLUMBConv(nn.Module):
    """Gated feed-forward block built from three `ConvLayer`s.

    The block expands `in_features` to `2 * hidden_features` with a pointwise
    convolution, applies a depthwise convolution (`groups` equals the channel
    count), splits the result into a value half and a gate half, multiplies the
    value by `SiLU(gate)`, and projects back to `out_feature` with a pointwise
    convolution.

    Args:
        in_features: Number of input channels.
        hidden_features: Hidden width; the first two convolutions use twice
            this many channels.
        out_feature: Number of output channels; falls back to `in_features`
            when falsy.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution (`None` lets `ConvLayer`
            derive it).
        use_bias: One value or a 3-sequence, one entry per convolution.
        norm: One value or a 3-sequence, one entry per convolution.
        act: One value or a 3-sequence. Only entries 0 and 2 are used; the
            depthwise convolution is always built with `act=None`.
        dilation: Dilation of the depthwise convolution.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_feature=None,
        kernel_size=3,
        stride=1,
        padding: Union[int, None] = None,
        use_bias=False,
        norm=(None, None, None),
        act=("silu", "silu", None),
        dilation=1,
    ):
        out_feature = out_feature or in_features
        super().__init__()
        use_bias = val2tuple(use_bias, 3)
        norm = val2tuple(norm, 3)
        act = val2tuple(act, 3)

        self.glu_act = nn.SiLU(inplace=False)
        self.inverted_conv = ConvLayer(
            in_features,
            hidden_features * 2,
            1,
            use_bias=use_bias[0],
            norm=norm[0],
            act=act[0],
        )
        self.depth_conv = ConvLayer(
            hidden_features * 2,
            hidden_features * 2,
            kernel_size,
            stride=stride,
            groups=hidden_features * 2,
            padding=padding,
            use_bias=use_bias[1],
            norm=norm[1],
            act=None,
            dilation=dilation,
        )
        self.point_conv = ConvLayer(
            hidden_features,
            out_feature,
            1,
            use_bias=use_bias[2],
            norm=norm[2],
            act=act[2],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated convolutional feed-forward.

        Dims 1 and 2 of `x` are swapped before the convolutions and swapped
        back afterwards, so the channel axis consumed by `nn.Conv1d` is the
        last axis of the input.
        """
        x = x.transpose(1, 2)
        x = self.inverted_conv(x)
        x = self.depth_conv(x)

        # Split the doubled channel axis into the value and its gate.
        x, gate = torch.chunk(x, 2, dim=1)
        gate = self.glu_act(gate)
        x = x * gate

        x = self.point_conv(x)
        x = x.transpose(1, 2)

        return x


class LinearTransformerBlock(nn.Module):
    """
    A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.

    The block runs, in order: RMSNorm + attention (`CustomLiteLAProcessor2_0`),
    an optional cross-attention (`CustomerAttnProcessor2_0`), and RMSNorm +
    `GLUMBConv` feed-forward, each added residually to the hidden states.

    Args:
        dim: Channel width of the hidden states.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Width of each attention head.
        use_adaln_single: When true, a learned `(6, dim)` table is added to
            `temb` to produce shift/scale/gate terms for both sub-layers.
        cross_attention_dim: Forwarded to the first `Attention` module.
        added_kv_proj_dim: Forwarded to the first `Attention` module.
        context_pre_only: Forwarded to the cross-attention `Attention` module.
        mlp_ratio: Feed-forward hidden width as a multiple of `dim`.
        add_cross_attention: Whether to add the separate cross-attention layer.
        add_cross_attention_dim: Context width for the cross-attention layer;
            required when `add_cross_attention` is true.
        qk_norm: Forwarded to both `Attention` modules.

    Raises:
        ValueError: If `add_cross_attention` is true but
            `add_cross_attention_dim` is `None`.
    """

    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        use_adaln_single=True,
        cross_attention_dim=None,
        added_kv_proj_dim=None,
        context_pre_only=False,
        mlp_ratio=4.0,
        add_cross_attention=False,
        add_cross_attention_dim=None,
        qk_norm=None,
    ):
        super().__init__()

        if add_cross_attention and add_cross_attention_dim is None:
            # `forward` calls `self.cross_attn` whenever `add_cross_attention`
            # is set, so fail here instead of with an AttributeError later.
            raise ValueError("`add_cross_attention_dim` must be set when `add_cross_attention=True`.")

        self.norm1 = RMSNorm(dim, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            added_kv_proj_dim=added_kv_proj_dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            qk_norm=qk_norm,
            processor=CustomLiteLAProcessor2_0(),
        )

        self.add_cross_attention = add_cross_attention
        self.context_pre_only = context_pre_only

        if add_cross_attention and add_cross_attention_dim is not None:
            self.cross_attn = Attention(
                query_dim=dim,
                cross_attention_dim=add_cross_attention_dim,
                added_kv_proj_dim=add_cross_attention_dim,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                out_dim=dim,
                context_pre_only=context_pre_only,
                bias=True,
                qk_norm=qk_norm,
                processor=CustomerAttnProcessor2_0(),
            )

        self.norm2 = RMSNorm(dim, eps=1e-06, elementwise_affine=False)

        self.ff = GLUMBConv(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            use_bias=(True, True, False),
            norm=(None, None, None),
            act=("silu", "silu", None),
        )
        self.use_adaln_single = use_adaln_single
        if use_adaln_single:
            self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        rotary_freqs_cis: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        rotary_freqs_cis_cross: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None,
        temb: Optional[torch.FloatTensor] = None,
    ):
        """Run the block and return the updated hidden states.

        Args:
            hidden_states: Input tensor; its first dimension is used as the
                batch size.
            encoder_hidden_states: Context passed to the first attention when
                there is no separate cross-attention layer, otherwise to the
                cross-attention layer.
            attention_mask: Mask forwarded to both attention layers.
            encoder_attention_mask: Mask forwarded alongside
                `encoder_hidden_states`.
            rotary_freqs_cis: Rotary frequencies forwarded to both attention
                layers.
            rotary_freqs_cis_cross: Rotary frequencies forwarded alongside
                `encoder_hidden_states`.
            temb: Conditioning tensor, reshaped to `(batch, 6, -1)`; required
                when `use_adaln_single` is true.

        Raises:
            ValueError: If `use_adaln_single` is true and `temb` is `None`.
        """
        N = hidden_states.shape[0]

        # step 1: AdaLN single
        if self.use_adaln_single:
            if temb is None:
                raise ValueError("`temb` must be provided when `use_adaln_single=True`.")
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + temb.reshape(N, 6, -1)
            ).chunk(6, dim=1)

        norm_hidden_states = self.norm1(hidden_states)
        if self.use_adaln_single:
            norm_hidden_states = t2i_modulate(norm_hidden_states, shift_msa, scale_msa)

        # step 2: attention
        if not self.add_cross_attention:
            attn_output, encoder_hidden_states = self.attn(
                hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
            )
        else:
            # The context is consumed by the dedicated cross-attention layer
            # below, so the first attention runs without it.
            attn_output, _ = self.attn(
                hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=None,
            )

        if self.use_adaln_single:
            attn_output = gate_msa * attn_output
        hidden_states = attn_output + hidden_states

        if self.add_cross_attention:
            attn_output = self.cross_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                rotary_freqs_cis=rotary_freqs_cis,
                rotary_freqs_cis_cross=rotary_freqs_cis_cross,
            )
            hidden_states = attn_output + hidden_states

        # step 3: add norm
        norm_hidden_states = self.norm2(hidden_states)
        if self.use_adaln_single:
            norm_hidden_states = t2i_modulate(norm_hidden_states, shift_mlp, scale_mlp)

        # step 4: feed forward
        ff_output = self.ff(norm_hidden_states)
        if self.use_adaln_single:
            ff_output = gate_mlp * ff_output

        hidden_states = hidden_states + ff_output

        return hidden_states