# Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import warnings from typing import Callable, Optional, Union import torch import torch.nn.functional as F from torch import nn from diffusers.utils import deprecate, logging, maybe_allow_in_graph logger = logging.get_logger(__name__) # pylint: disable=invalid-name @maybe_allow_in_graph class Attention(nn.Module): r""" A cross attention layer. Parameters: query_dim (`int`): The number of channels in the query. cross_attention_dim (`int`, *optional*): The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. bias (`bool`, *optional*, defaults to False): Set to `True` for the query, key, and value linear layers to contain a bias parameter. """ def __init__( self, query_dim: int, cross_attention_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0, bias=False, upcast_attention: bool = False, upcast_softmax: bool = False, cross_attention_norm: Optional[str] = None, cross_attention_norm_num_groups: int = 32, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, out_bias: bool = True, scale_qk: bool = True, only_cross_attention: bool = False, eps: float = 1e-5, rescale_output_factor: float = 1.0, residual_connection: bool = False, _from_deprecated_attn_block=False, processor: Optional["AttnProcessor"] = None, ): super().__init__() inner_dim = dim_head * heads cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax self.rescale_output_factor = rescale_output_factor self.residual_connection = residual_connection # we make use of this private variable to know whether this class is loaded # with an deprecated state dict so that we can convert it on the fly self._from_deprecated_attn_block = _from_deprecated_attn_block self.scale_qk = scale_qk self.scale = dim_head**-0.5 if self.scale_qk else 1.0 self.heads = heads # for slice_size > 0 the attention score computation # is split across the batch axis to save memory # You can set slice_size with `set_attention_slice` self.sliceable_head_dim = heads self.added_kv_proj_dim = added_kv_proj_dim self.only_cross_attention = only_cross_attention if self.added_kv_proj_dim is None and self.only_cross_attention: raise ValueError( "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." ) if norm_num_groups is not None: self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) else: self.group_norm = None if spatial_norm_dim is not None: self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) else: self.spatial_norm = None if cross_attention_norm is None: self.norm_cross = None elif cross_attention_norm == "layer_norm": self.norm_cross = nn.LayerNorm(cross_attention_dim) elif cross_attention_norm == "group_norm": if self.added_kv_proj_dim is not None: # The given `encoder_hidden_states` are initially of shape # (batch_size, seq_len, added_kv_proj_dim) before being projected # to (batch_size, seq_len, cross_attention_dim). The norm is applied # before the projection, so we need to use `added_kv_proj_dim` as # the number of channels for the group norm. norm_cross_num_channels = added_kv_proj_dim else: norm_cross_num_channels = cross_attention_dim self.norm_cross = nn.GroupNorm( num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True ) else: raise ValueError( f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" ) self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) else: self.to_k = None self.to_v = None if self.added_kv_proj_dim is not None: self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim) self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim) self.to_out = nn.ModuleList([]) self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias)) self.to_out.append(nn.Dropout(dropout)) # set attention processor # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 if processor is None: # processor = ( # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() # ) # Note: efficient attention is not used. We can use efficient attention to speed up. processor = AttnProcessor() self.set_processor(processor) def set_processor(self, processor: "AttnProcessor"): # if current processor is in `self._modules` and if passed `processor` is not, we need to # pop `processor` from `self._modules` if ( hasattr(self, "processor") and isinstance(self.processor, torch.nn.Module) and not isinstance(processor, torch.nn.Module) ): logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") self._modules.pop("processor") self.processor = processor def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, return_attntion_probs=False, **cross_attention_kwargs): # The `Attention` class can call different attention processors / attention functions # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty return self.processor( self, hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, return_attntion_probs=return_attntion_probs, **cross_attention_kwargs, ) def batch_to_head_dim(self, tensor): head_size = self.heads batch_size, seq_len, dim = tensor.shape tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor def head_to_batch_dim(self, tensor, out_dim=3): head_size = self.heads batch_size, seq_len, dim = tensor.shape tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) tensor = tensor.permute(0, 2, 1, 3) if out_dim == 3: tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) return tensor def get_attention_scores(self, query, key, attention_mask=None): dtype = query.dtype if self.upcast_attention: query = query.float() key = key.float() if attention_mask is None: baddbmm_input = torch.empty( query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device ) beta = 0 else: baddbmm_input = attention_mask beta = 1 attention_scores = torch.baddbmm( baddbmm_input, query, key.transpose(-1, -2), beta=beta, alpha=self.scale, ) del baddbmm_input if self.upcast_softmax: attention_scores = attention_scores.float() attention_probs = attention_scores.softmax(dim=-1) del attention_scores attention_probs = attention_probs.to(dtype) return attention_probs def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3): if batch_size is None: deprecate( "batch_size=None", "0.0.15", ( "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect" " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to" " `prepare_attention_mask` when preparing the attention_mask." ), ) batch_size = 1 head_size = self.heads if attention_mask is None: return attention_mask current_length: int = attention_mask.shape[-1] if current_length != target_length: if attention_mask.device.type == "mps": # HACK: MPS: Does not support padding by greater than dimension of input tensor. # Instead, we can manually construct the padding tensor. padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat([attention_mask, padding], dim=2) else: # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: # we want to instead pad by (0, remaining_length), where remaining_length is: # remaining_length: int = target_length - current_length # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) if out_dim == 3: if attention_mask.shape[0] < batch_size * head_size: attention_mask = attention_mask.repeat_interleave(head_size, dim=0) elif out_dim == 4: attention_mask = attention_mask.unsqueeze(1) attention_mask = attention_mask.repeat_interleave(head_size, dim=1) return attention_mask def norm_encoder_hidden_states(self, encoder_hidden_states): assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states" if isinstance(self.norm_cross, nn.LayerNorm): encoder_hidden_states = self.norm_cross(encoder_hidden_states) elif isinstance(self.norm_cross, nn.GroupNorm): # Group norm norms along the channels dimension and expects # input to be in the shape of (N, C, *). In this case, we want # to norm along the hidden dimension, so we need to move # (batch_size, sequence_length, hidden_size) -> # (batch_size, hidden_size, sequence_length) encoder_hidden_states = encoder_hidden_states.transpose(1, 2) encoder_hidden_states = self.norm_cross(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.transpose(1, 2) else: assert False return encoder_hidden_states class AttnProcessor: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call_fast__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) inner_dim = hidden_states.shape[-1] if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, return_attntion_probs=False, attn_key=None, attn_process_fn=None, return_cond_ca_only=False, return_token_ca_only=None, offload_cross_attn_to_cpu=False, save_attn_to_dict=None, save_keys=None, enable_flash_attn=True, ): """ attn_key: current key (a tuple of hierarchy index (up/mid/down, stage id, block id, sub-block id), sub block id should always be 0 in SD UNet) save_attn_to_dict: pass in a dict to save to dict """ cross_attn = encoder_hidden_states is not None if (not cross_attn) or ( (attn_process_fn is None) and not (save_attn_to_dict is not None and (save_keys is None or (tuple(attn_key) in save_keys))) and not return_attntion_probs): with torch.backends.cuda.sdp_kernel(enable_flash=enable_flash_attn, enable_math=True, enable_mem_efficient=enable_flash_attn): return self.__call_fast__(attn, hidden_states, encoder_hidden_states, attention_mask, temb) residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) # Currently only process cross-attention if attn_process_fn is not None and cross_attn: attention_probs_before_process = attention_probs.clone() attention_probs = attn_process_fn(attention_probs, query, key, value, attn_key=attn_key, cross_attn=cross_attn, batch_size=batch_size, heads=attn.heads) else: attention_probs_before_process = attention_probs hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor if return_attntion_probs or save_attn_to_dict is not None: # Recover batch dimension: (batch_size, heads, flattened_2d, text_tokens) attention_probs_unflattened = attention_probs_before_process.unflatten(dim=0, sizes=(batch_size, attn.heads)) if return_token_ca_only is not None: # (batch size, n heads, 2d dimension, num text tokens) if isinstance(return_token_ca_only, int): # return_token_ca_only: an integer attention_probs_unflattened = attention_probs_unflattened[:, :, :, return_token_ca_only:return_token_ca_only+1] else: # return_token_ca_only: A 1d index tensor attention_probs_unflattened = attention_probs_unflattened[:, :, :, return_token_ca_only] if return_cond_ca_only: assert batch_size % 2 == 0, f"Samples are not in pairs: {batch_size} samples" attention_probs_unflattened = attention_probs_unflattened[batch_size // 2:] if offload_cross_attn_to_cpu: attention_probs_unflattened = attention_probs_unflattened.cpu() if save_attn_to_dict is not None and (save_keys is None or (tuple(attn_key) in save_keys)): save_attn_to_dict[tuple(attn_key)] = attention_probs_unflattened if return_attntion_probs: return hidden_states, attention_probs_unflattened return hidden_states # For typing AttentionProcessor = AttnProcessor class SpatialNorm(nn.Module): """ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002 """ def __init__( self, f_channels, zq_channels, ): super().__init__() self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) def forward(self, f, zq): f_size = f.shape[-2:] zq = F.interpolate(zq, size=f_size, mode="nearest") norm_f = self.norm_layer(f) new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f