zejunyang
init
2e4e201
raw
history blame contribute delete
No virus
35.4 kB
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
from typing import Any, Dict, Optional
import torch
from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
from einops import rearrange
from torch import nn
from diffusers.models.attention import *
from diffusers.models.attention_processor import *
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (
num_embeds_ada_norm is not None
) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (
num_embeds_ada_norm is not None
) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(
dim, max_seq_length=num_positional_embeddings
)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
)
)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim
if not double_self_attention
else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(
dim, cross_attention_dim, num_attention_heads, attention_head_dim
)
# 5. Scale-shift for PixArt-Alpha.
if self.use_ada_layer_norm_single:
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.use_layer_norm:
norm_hidden_states = self.norm1(hidden_states)
elif self.use_ada_layer_norm_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Retrieve lora scale.
lora_scale = (
cross_attention_kwargs.get("scale", 1.0)
if cross_attention_kwargs is not None
else 1.0
)
# 2. Prepare GLIGEN inputs
cross_attention_kwargs = (
cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
)
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states
if self.only_cross_attention
else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.use_ada_layer_norm_single:
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.use_ada_layer_norm:
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
norm_hidden_states = self.norm2(hidden_states)
elif self.use_ada_layer_norm_single:
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
if not self.use_ada_layer_norm_single:
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = (
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
)
if self.use_ada_layer_norm_single:
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.use_ada_layer_norm_single:
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class TemporalBasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = num_embeds_ada_norm is not None
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
self.unet_use_temporal_attention = unet_use_temporal_attention
# SC-Attn
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.norm1 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim)
)
# Cross-Attn
if cross_attention_dim is not None:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
else:
self.attn2 = None
if cross_attention_dim is not None:
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim)
)
else:
self.norm2 = None
# Feed-forward
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.norm3 = nn.LayerNorm(dim)
self.use_ada_layer_norm_zero = False
# Temp-Attn
assert unet_use_temporal_attention is not None
if unet_use_temporal_attention:
self.attn_temp = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
self.norm_temp = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim)
)
def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
attention_mask=None,
video_length=None,
):
norm_hidden_states = (
self.norm1(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm1(hidden_states)
)
if self.unet_use_cross_frame_attention:
hidden_states = (
self.attn1(
norm_hidden_states,
attention_mask=attention_mask,
video_length=video_length,
)
+ hidden_states
)
else:
hidden_states = (
self.attn1(norm_hidden_states, attention_mask=attention_mask)
+ hidden_states
)
if self.attn2 is not None:
# Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm2(hidden_states)
)
hidden_states = (
self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
)
+ hidden_states
)
# Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
# Temporal-Attention
if self.unet_use_temporal_attention:
d = hidden_states.shape[1]
hidden_states = rearrange(
hidden_states, "(b f) d c -> (b d) f c", f=video_length
)
norm_hidden_states = (
self.norm_temp(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm_temp(hidden_states)
)
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
class ResidualTemporalBasicTransformerBlock(TemporalBasicTransformerBlock):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
):
super(TemporalBasicTransformerBlock, self).__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = num_embeds_ada_norm is not None
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
self.unet_use_temporal_attention = unet_use_temporal_attention
# SC-Attn
self.attn1 = ResidualAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.norm1 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim)
)
# Cross-Attn
if cross_attention_dim is not None:
self.attn2 = ResidualAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
else:
self.attn2 = None
if cross_attention_dim is not None:
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim)
)
else:
self.norm2 = None
# Feed-forward
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.norm3 = nn.LayerNorm(dim)
self.use_ada_layer_norm_zero = False
# Temp-Attn
assert unet_use_temporal_attention is not None
if unet_use_temporal_attention:
self.attn_temp = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
self.norm_temp = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim)
)
def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
attention_mask=None,
video_length=None,
block_idx: Optional[int] = None,
additional_residuals: Optional[Dict[str, torch.FloatTensor]] = None
):
norm_hidden_states = (
self.norm1(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm1(hidden_states)
)
if self.unet_use_cross_frame_attention:
hidden_states = (
self.attn1(
norm_hidden_states,
attention_mask=attention_mask,
video_length=video_length,
block_idx=block_idx,
additional_residuals=additional_residuals,
)
+ hidden_states
)
else:
hidden_states = (
self.attn1(norm_hidden_states, attention_mask=attention_mask,
block_idx=block_idx,
additional_residuals=additional_residuals
)
+ hidden_states
)
if self.attn2 is not None:
# Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm2(hidden_states)
)
hidden_states = (
self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
block_idx=block_idx,
additional_residuals=additional_residuals,
)
+ hidden_states
)
# Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
# Temporal-Attention
if self.unet_use_temporal_attention:
d = hidden_states.shape[1]
hidden_states = rearrange(
hidden_states, "(b f) d c -> (b d) f c", f=video_length
)
norm_hidden_states = (
self.norm_temp(hidden_states, timestep)
if self.use_ada_layer_norm
else self.norm_temp(hidden_states)
)
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
class ResidualAttention(Attention):
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
is_lora = hasattr(self, "processor") and isinstance(
self.processor,
LORA_ATTENTION_PROCESSORS,
)
is_custom_diffusion = hasattr(self, "processor") and isinstance(
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
)
is_added_kv_processor = hasattr(self, "processor") and isinstance(
self.processor,
(
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
SlicedAttnAddedKVProcessor,
XFormersAttnAddedKVProcessor,
LoRAAttnAddedKVProcessor,
),
)
if use_memory_efficient_attention_xformers:
if is_added_kv_processor and (is_lora or is_custom_diffusion):
raise NotImplementedError(
f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
)
if not is_xformers_available():
raise ModuleNotFoundError(
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers"
),
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
if is_lora:
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
processor = LoRAXFormersAttnProcessor(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
attention_op=attention_op,
)
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
elif is_custom_diffusion:
processor = CustomDiffusionXFormersAttnProcessor(
train_kv=self.processor.train_kv,
train_q_out=self.processor.train_q_out,
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
attention_op=attention_op,
)
processor.load_state_dict(self.processor.state_dict())
if hasattr(self.processor, "to_k_custom_diffusion"):
processor.to(self.processor.to_k_custom_diffusion.weight.device)
elif is_added_kv_processor:
# TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
# which uses this type of cross attention ONLY because the attention mask of format
# [0, ..., -10.000, ..., 0, ...,] is not supported
# throw warning
logger.info(
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
)
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
else:
processor = ResidualXFormersAttnProcessor(attention_op=attention_op)
else:
if is_lora:
attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
processor = attn_processor_class(
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
rank=self.processor.rank,
)
processor.load_state_dict(self.processor.state_dict())
processor.to(self.processor.to_q_lora.up.weight.device)
elif is_custom_diffusion:
processor = CustomDiffusionAttnProcessor(
train_kv=self.processor.train_kv,
train_q_out=self.processor.train_q_out,
hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
)
processor.load_state_dict(self.processor.state_dict())
if hasattr(self.processor, "to_k_custom_diffusion"):
processor.to(self.processor.to_k_custom_diffusion.weight.device)
else:
# set attention processor
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
processor = (
AttnProcessor2_0()
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
else AttnProcessor()
)
self.set_processor(processor)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None,
block_idx: Optional[int] = None, additional_residuals: Optional[Dict[str, torch.FloatTensor]] = None,
is_self_attn: Optional[bool] = None, **cross_attention_kwargs):
# The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
block_idx=block_idx,
additional_residuals=additional_residuals,
is_self_attn=is_self_attn,
**cross_attention_kwargs,
)
class ResidualXFormersAttnProcessor(XFormersAttnProcessor):
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
block_idx: Optional[int] = None,
additional_residuals: Optional[Dict[str, torch.FloatTensor]] = None,
is_self_attn: Optional[bool] = None
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, key_tokens, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
# newly added
if is_self_attn and additional_residuals and f"block_{block_idx}_self_attn_q" in additional_residuals:
query = query + additional_residuals[f"block_{block_idx}_self_attn_q"]
elif not is_self_attn and additional_residuals and f"block_{block_idx}_cross_attn_q" in additional_residuals:
query = query + additional_residuals[f"block_{block_idx}_cross_attn_q"]
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
if not is_self_attn and additional_residuals and f"block_{block_idx}_cross_attn_c" in additional_residuals:
not_uc = torch.abs(encoder_hidden_states - torch.zeros_like(encoder_hidden_states)).mean(dim=[1, 2], keepdim=True) < 1e-4
encoder_hidden_states = encoder_hidden_states + additional_residuals[f"block_{block_idx}_cross_attn_c"] * not_uc
# encoder_hidden_states[not_uc] = encoder_hidden_states[not_uc] + \
# additional_residuals[f"block_{block_idx}_cross_attn_c"][not_uc]
# encoder_hidden_states[~not_uc] = encoder_hidden_states[~not_uc] + \
# additional_residuals[f"block_{block_idx}_cross_attn_c"][~not_uc] * 0.
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# newly added
if is_self_attn and additional_residuals and f"block_{block_idx}_self_attn_k" in additional_residuals:
key = key + additional_residuals[f"block_{block_idx}_self_attn_k"]
elif not is_self_attn and additional_residuals and f"block_{block_idx}_cross_attn_k" in additional_residuals:
key = key + additional_residuals[f"block_{block_idx}_cross_attn_k"]
if is_self_attn and additional_residuals and f"block_{block_idx}_self_attn_v" in additional_residuals:
value = value + additional_residuals[f"block_{block_idx}_self_attn_v"]
elif not is_self_attn and additional_residuals and f"block_{block_idx}_cross_attn_v" in additional_residuals:
value = value + additional_residuals[f"block_{block_idx}_cross_attn_v"]
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states