from typing import Optional

import torch
from torch import nn

from diffusers.utils import deprecate
from diffusers.models.attention_processor import Attention


def register_attention_processor(
    model: Optional[nn.Module] = None,
    processor_type: str = "MasaCtrlProcessor",
    **attn_args,
):
    """
    Register an attention processor on one or more UNet models.

    Args:
        model: a UNet model or a list/tuple of UNet models
        processor_type: the processor to register; "MasaCtrlProcessor" enables
            mutual self-attention control, any other value falls back to the default processor
        attn_args: keyword arguments forwarded to the MasaCtrlProcessor constructor
    """
    if not isinstance(model, (list, tuple)):
        model = [model]

    if processor_type == "MasaCtrlProcessor":
        processor = MasaCtrlProcessor(**attn_args)
    else:
        processor = AttnProcessor()

    for m in model:
        m.set_attn_processor(processor)
        print(f"Registered attention processor '{processor_type}' on {m.__class__.__name__}")
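

# Usage sketch (illustrative only; it assumes a standard diffusers pipeline already
# exists as `pipe` -- that name is not defined in this module). The values shown
# match MasaCtrlProcessor's defaults:
#
#     register_attention_processor(pipe.unet, "MasaCtrlProcessor", start_step=4, start_layer=10)
#
# Any other `processor_type` restores the plain AttnProcessor defined below.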


class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

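    # The call below implements standard scaled dot-product attention,
    #   softmax(Q K^T / sqrt(d_head)) V,
    # using the projection layers (to_q / to_k / to_v / to_out) owned by `attn`;
    # `head_to_batch_dim` / `batch_to_head_dim` fold the attention heads into the
    # batch dimension and back.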
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        # flatten (B, C, H, W) feature maps into (B, H*W, C) token sequences
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # self-attention when no encoder states are given, cross-attention otherwise
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # output projection
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class MasaCtrlProcessor(nn.Module):
    """
    Mutual self-attention processor for the diffusers library.
    Note that all attention layers should register the same processor instance,
    since it keeps a running count of attention calls to track the current
    denoising step and layer.
    """

    # number of transformer blocks (each with one self- and one cross-attention layer) per UNet
    MODEL_TYPE = {
        "SD": 16,
        "SDXL": 70,
    }

    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_layers=32, total_steps=50, model_type="SD"):
        """
        Mutual self-attention control for Stable Diffusion models.

        Args:
            start_step: the denoising step at which to start mutual self-attention control
            start_layer: the transformer layer at which to start mutual self-attention control
            layer_idx: list of the transformer layers to apply mutual self-attention control to
            step_idx: list of the denoising steps to apply mutual self-attention control to
            total_layers: total number of attention layers (self- plus cross-attention) in the UNet
            total_steps: total number of denoising steps; must match the number of steps used by the scheduler
            model_type: the model type, "SD" or "SDXL"
        """
        super().__init__()
        self.total_steps = total_steps
        # number of transformer blocks for the given model type (defaults to SD's 16)
        self.total_layers = self.MODEL_TYPE.get(model_type, 16)
        self.start_step = start_step
        self.start_layer = start_layer
        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
        self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
        print("MasaCtrl at denoising steps: ", self.step_idx)
        print("MasaCtrl at U-Net layers: ", self.layer_idx)
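        # With the defaults (start_step=4, start_layer=10, model_type="SD") this yields
        #   step_idx  = [4, 5, ..., 49]: active for all but the first 4 denoising steps
        #   layer_idx = [10, 11, ..., 15]: only the last 6 of the 16 transformer blocks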

        # bookkeeping: which attention call and which denoising step we are currently in
        self.cur_step = 0
        self.cur_att_layer = 0
        self.num_attn_layers = total_layers

    def after_step(self):
        # hook invoked once a full denoising step has finished; no-op by default
        pass

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ):
        out = self.attn_forward(
            attn,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            temb,
            *args,
            **kwargs,
        )
        # track which attention layer / denoising step we are in
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_attn_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.cur_step %= self.total_steps
            self.after_step()
        return out
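
    # Bookkeeping example, assuming the SD defaults (total_layers=32 attention
    # layers, i.e. 16 transformer blocks with one self- and one cross-attention
    # each): calls 0..31 cover one UNet forward pass; after call 31,
    # `cur_att_layer` wraps to 0 and `cur_step` advances. `total_layers` and
    # `total_steps` therefore have to match the real UNet depth and the
    # scheduler's number of inference steps for the step/layer gating to line up.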

    def masactrl_forward(
        self,
        query,
        key,
        value,
    ):
        """
        Rearrange the key and value for mutual self-attention control.

        The batch is assumed to be ordered as
        [uncond_source, uncond_target, cond_source, cond_target], which is what a
        classifier-free-guidance pipeline produces for a (source, target) prompt pair.
        """
        ku_src, ku_tgt, kc_src, kc_tgt = key.chunk(4)
        vu_src, vu_tgt, vc_src, vc_tgt = value.chunk(4)

        # both target branches reuse the source keys/values, so the target query
        # attends to the source image's content
        k_rearranged = torch.cat([ku_src, ku_src, kc_src, kc_src])
        v_rearranged = torch.cat([vu_src, vu_src, vc_src, vc_src])

        return query, k_rearranged, v_rearranged
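
    # Net effect on the gated self-attention layers: the two target branches compute
    #   out_tgt = Attention(Q_tgt, K_src, V_src)
    # while the source branches are unchanged, so the edited (target) image queries
    # the source image's keys/values and inherits its spatial layout.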

    def attn_forward(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ):
        # each transformer block holds one self-attention and one cross-attention
        # layer, so the attention-call index maps to a transformer block via // 2
        cur_transformer_layer = self.cur_att_layer // 2
        residual = hidden_states

        is_cross = encoder_hidden_states is not None

        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # apply mutual self-attention control only on self-attention layers,
        # and only at the selected denoising steps and transformer layers
        if not is_cross and self.cur_step in self.step_idx and cur_transformer_layer in self.layer_idx:
            query, key, value = self.masactrl_forward(query, key, value)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # output projection
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
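

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (illustrative only; the model id, prompts, and
# hyper-parameters below are assumptions, not requirements of this module).
# With classifier-free guidance enabled and two prompts that share one initial
# latent, the UNet batch is laid out as
# [uncond_source, uncond_target, cond_source, cond_target], which is exactly
# the ordering MasaCtrlProcessor.masactrl_forward expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)

    # total_steps must match num_inference_steps below for the step gating to line up
    register_attention_processor(pipe.unet, "MasaCtrlProcessor", start_step=4, start_layer=10, total_steps=50)

    # source and target share the same starting latent so the target keeps the source layout
    latent = torch.randn(1, pipe.unet.config.in_channels, 64, 64, device=device, dtype=pipe.unet.dtype)
    latents = latent.repeat(2, 1, 1, 1)

    images = pipe(
        prompt=["a photo of a corgi sitting", "a photo of a corgi standing"],
        latents=latents,
        num_inference_steps=50,
        guidance_scale=7.5,
    ).images
    images[0].save("source.png")
    images[1].save("target.png")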