from importlib import import_module from typing import Callable, Optional, Union import torch import torch.nn.functional as F from torch import nn from diffusers.utils import deprecate, logging from diffusers.utils.import_utils import is_xformers_available from diffusers.models.attention import Attention def register_attention_processor( model: Optional[nn.Module] = None, processor_type: str = "MasaCtrlProcessor", **attn_args, ): """ Args: model: a unet model or a list of unet models processor_type: the type of the processor """ if not isinstance(model, (list, tuple)): model = [model] if processor_type == "MasaCtrlProcessor": processor = MasaCtrlProcessor(**attn_args) else: processor = AttnProcessor() for m in model: m.set_attn_processor(processor) print(f"Model {m.__class__.__name__} is registered attention processor: {processor_type}") class AttnProcessor: r""" Default processor for performing attention-related computations. """ def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None, *args, **kwargs, ) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states class MasaCtrlProcessor(nn.Module): """ Mutual Self-attention Processor for diffusers library. Note that the all attention layers should register the same processor. """ MODEL_TYPE = { "SD": 16, "SDXL": 70 } def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_layers=32, total_steps=50, model_type="SD"): """ Mutual self-attention control for Stable-Diffusion model Args: start_step: the step to start mutual self-attention control start_layer: the layer to start mutual self-attention control layer_idx: list of the layers to apply mutual self-attention control step_idx: list the steps to apply mutual self-attention control total_steps: the total number of steps, must be same to the denoising steps used in denoising scheduler model_type: the model type, SD or SDXL """ super().__init__() self.total_steps = total_steps self.total_layers = self.MODEL_TYPE.get(model_type, 16) self.start_step = start_step self.start_layer = start_layer self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers)) self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps)) print("MasaCtrl at denoising steps: ", self.step_idx) print("MasaCtrl at U-Net layers: ", self.layer_idx) self.cur_step = 0 self.cur_att_layer = 0 self.num_attn_layers = total_layers def after_step(self): pass def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, ): out = self.attn_forward( attn, hidden_states, encoder_hidden_states, attention_mask, temb, scale, ) self.cur_att_layer += 1 if self.cur_att_layer == self.num_attn_layers: self.cur_att_layer = 0 self.cur_step += 1 self.cur_step %= self.total_steps # after step self.after_step() return out def masactrl_forward( self, query, key, value, ): """ Rearrange the key and value for mutual self-attention control """ ku_src, ku_tgt, kc_src, kc_tgt = key.chunk(4) vu_src, vu_tgt, vc_src, vc_tgt = value.chunk(4) k_rearranged = torch.cat([ku_src, ku_src, kc_src, kc_src]) v_rearranged = torch.cat([vu_src, vu_src, vc_src, vc_src]) return query, k_rearranged, v_rearranged def attn_forward( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None, *args, **kwargs, ): cur_transformer_layer = self.cur_att_layer // 2 residual = hidden_states is_cross = True if encoder_hidden_states is not None else False if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states, *args) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states, *args) value = attn.to_v(encoder_hidden_states, *args) # mutual self-attention control if not is_cross and self.cur_step in self.step_idx and cur_transformer_layer in self.layer_idx: query, key, value = self.masactrl_forward(query, key, value) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) attention_probs = attn.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states, *args) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor return hidden_states