# Genshin-Impact-XL-MasaCtrl / masactrl / masactrl_processor.py
from importlib import import_module
from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.utils import deprecate, logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import Attention


def register_attention_processor(
model: Optional[nn.Module] = None,
processor_type: str = "MasaCtrlProcessor",
**attn_args,
):
"""
Args:
model: a unet model or a list of unet models
processor_type: the type of the processor
"""
if not isinstance(model, (list, tuple)):
model = [model]
if processor_type == "MasaCtrlProcessor":
processor = MasaCtrlProcessor(**attn_args)
    else:
        # any other processor_type falls back to the default attention processor
        processor = AttnProcessor()
for m in model:
m.set_attn_processor(processor)
print(f"Model {m.__class__.__name__} is registered attention processor: {processor_type}")
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
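

# Minimal sketch of the shape flow through AttnProcessor, assuming a dummy diffusers
# Attention module (the sizes below are arbitrary illustration values):
#
#   attn = Attention(query_dim=64, heads=4, dim_head=16)
#   attn.set_processor(AttnProcessor())
#   out = attn(torch.randn(2, 77, 64))   # (batch, seq_len, query_dim) -> same shape out
#
# head_to_batch_dim folds the 4 heads into the batch axis, torch.bmm computes the per-head
# attention, and batch_to_head_dim folds the heads back before the to_out projection.
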
class MasaCtrlProcessor(nn.Module):
"""
Mutual Self-attention Processor for diffusers library.
Note that the all attention layers should register the same processor.
"""
    # number of transformer blocks (each holding one self- and one cross-attention layer) per U-Net variant
    MODEL_TYPE = {
        "SD": 16,
        "SDXL": 70
    }
def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_layers=32, total_steps=50, model_type="SD"):
"""
Mutual self-attention control for Stable-Diffusion model
Args:
start_step: the step to start mutual self-attention control
start_layer: the layer to start mutual self-attention control
layer_idx: list of the layers to apply mutual self-attention control
step_idx: list the steps to apply mutual self-attention control
total_steps: the total number of steps, must be same to the denoising steps used in denoising scheduler
model_type: the model type, SD or SDXL
"""
super().__init__()
self.total_steps = total_steps
        # number of transformer blocks, looked up from MODEL_TYPE (the `total_layers` argument is not used here)
        self.total_layers = self.MODEL_TYPE.get(model_type, 16)
self.start_step = start_step
self.start_layer = start_layer
self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
print("MasaCtrl at denoising steps: ", self.step_idx)
print("MasaCtrl at U-Net layers: ", self.layer_idx)
        # shared counters, advanced in __call__ every time an attention layer runs
        self.cur_step = 0
        self.cur_att_layer = 0
        # total number of attention layers (self- + cross-attention) visited per denoising step
        self.num_attn_layers = total_layers
    def after_step(self):
        """Hook invoked after every full denoising step; a no-op by default, meant to be overridden if needed."""
        pass
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
):
out = self.attn_forward(
attn,
hidden_states,
encoder_hidden_states,
attention_mask,
temb,
scale,
)
        # bookkeeping: count attention layers; once all of them have run, one denoising step is done
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_attn_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.cur_step %= self.total_steps  # wrap around so the processor can be reused across generations
            # after step
            self.after_step()
return out
def masactrl_forward(
self,
query,
key,
value,
):
"""
Rearrange the key and value for mutual self-attention control
"""
ku_src, ku_tgt, kc_src, kc_tgt = key.chunk(4)
vu_src, vu_tgt, vc_src, vc_tgt = value.chunk(4)
k_rearranged = torch.cat([ku_src, ku_src, kc_src, kc_src])
v_rearranged = torch.cat([vu_src, vu_src, vc_src, vc_src])
return query, k_rearranged, v_rearranged
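
    # Toy illustration of the rearrangement above (shapes are arbitrary; the batch-of-4
    # layout is the assumption stated in the docstring):
    #
    #   k = torch.arange(4).view(4, 1).float()          # rows ~ [u_src, u_tgt, c_src, c_tgt]
    #   _, k2, _ = MasaCtrlProcessor().masactrl_forward(k, k, k)
    #   # k2 == [[0.], [0.], [2.], [2.]]: both unconditional rows reuse the source row 0,
    #   # both conditional rows reuse the source row 2
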
def attn_forward(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
):
        # each transformer block contributes two attention layers (self- then cross-attention),
        # so the block index is the attention-layer counter divided by two
        cur_transformer_layer = self.cur_att_layer // 2
        residual = hidden_states
        is_cross = encoder_hidden_states is not None
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
        # mutual self-attention control: only applied to self-attention layers,
        # and only within the selected denoising steps and transformer blocks
        if not is_cross and self.cur_step in self.step_idx and cur_transformer_layer in self.layer_idx:
            query, key, value = self.masactrl_forward(query, key, value)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
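

if __name__ == "__main__":
    # Minimal end-to-end sketch, under assumptions this file does not define: the Stable
    # Diffusion 1.5 checkpoint id below, classifier-free guidance being enabled so that the
    # U-Net batch is ordered [uncond_source, uncond_target, cond_source, cond_target], and
    # num_inference_steps matching the processor's total_steps.
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    # pipe = pipe.to("cuda")  # optional: move to GPU for a reasonable runtime
    register_attention_processor(
        pipe.unet,
        "MasaCtrlProcessor",
        start_step=4,
        start_layer=10,
        total_steps=50,
        model_type="SD",
    )

    # batch the source prompt and the edited target prompt together so that
    # masactrl_forward sees the four-way batch it expects
    prompts = ["a photo of a cat sitting", "a photo of a cat standing"]
    # share one initial latent between source and target so both branches start identically
    latent = torch.randn(1, pipe.unet.config.in_channels, 64, 64)
    images = pipe(
        prompts,
        latents=latent.repeat(2, 1, 1, 1),
        num_inference_steps=50,
        guidance_scale=7.5,
    ).images
    images[1].save("masactrl_target.png")  # images[0] is the source branch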