# ASVA/unet_utils.py
from typing import Optional
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention import Attention


class InflatedConv3d(nn.Conv2d):
    """2D convolution applied frame-wise to a 5D video tensor of shape (b, c, f, h, w)."""

    def forward(self, x):
        video_length = x.shape[2]
        # Fold the frame axis into the batch axis, run the 2D conv, then unfold.
        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
        return x
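

# Hedged usage sketch (not part of the original file): the tensor sizes below are
# illustrative assumptions, chosen only to show that InflatedConv3d consumes and
# returns 5D (batch, channels, frames, height, width) tensors.
def _demo_inflated_conv3d():
    conv = InflatedConv3d(in_channels=4, out_channels=8, kernel_size=3, padding=1)
    video = torch.randn(2, 4, 16, 32, 32)  # (b, c, f, h, w)
    out = conv(video)
    assert out.shape == (2, 8, 16, 32, 32)
    return out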


class FFInflatedConv3d(nn.Conv2d):
    """Frame-wise 2D convolution followed by a learned temporal mixing layer."""

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            **kwargs,
        )
        # Temporal mixer over the concatenated (head, previous, current) frame features.
        self.conv_temp = nn.Linear(3 * out_channels, out_channels)
        # Zero-initialized so the temporal branch is a no-op at the start of training.
        nn.init.zeros_(self.conv_temp.weight.data)
        nn.init.zeros_(self.conv_temp.bias.data)

    def forward(self, x):
        video_length = x.shape[2]
        # Frame-wise 2D convolution.
        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        *_, h, w = x.shape
        # Move to (b*h*w, f, c) so the temporal mixer operates per spatial location.
        x = rearrange(x, "(b f) c h w -> (b h w) f c", f=video_length)

        # Indices of the head (first), previous, and current frame for every frame.
        head_frame_index = [0] * video_length
        prev_frame_index = torch.clamp(
            torch.arange(video_length) - 1, min=0
        ).long()
        curr_frame_index = torch.arange(video_length).long()

        conv_temp_nn_input = torch.cat([
            x[:, head_frame_index],
            x[:, prev_frame_index],
            x[:, curr_frame_index],
        ], dim=2).contiguous()
        # Residual temporal update; starts as identity because conv_temp is zero-initialized.
        x = x + self.conv_temp(conv_temp_nn_input)

        x = rearrange(x, "(b h w) f c -> b c f h w", h=h, w=w)
        return x
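

# Hedged sketch (not part of the original file): one plausible way to "inflate" a
# pretrained 2D convolution into FFInflatedConv3d. Because conv_temp is zero-initialized,
# the inflated layer reproduces the 2D layer frame by frame right after the weight copy.
# The helper name and the copy strategy are assumptions, not part of the original module.
def _inflate_conv2d(conv2d: nn.Conv2d) -> FFInflatedConv3d:
    inflated = FFInflatedConv3d(
        in_channels=conv2d.in_channels,
        out_channels=conv2d.out_channels,
        kernel_size=conv2d.kernel_size,
        stride=conv2d.stride,
        padding=conv2d.padding,
        dilation=conv2d.dilation,
        groups=conv2d.groups,
        bias=conv2d.bias is not None,
    )
    # Copy the pretrained spatial weights; the temporal branch stays zero.
    inflated.weight.data.copy_(conv2d.weight.data)
    if conv2d.bias is not None:
        inflated.bias.data.copy_(conv2d.bias.data)
    return inflated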


class FFAttention(Attention):
    r"""
    A cross attention layer that defaults to `FFAttnProcessor`, which processes video
    inputs frame by frame with keys and values gathered from the first frame.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """

    def __init__(
        self,
        *args,
        scale_qk: bool = True,
        processor: Optional["FFAttnProcessor"] = None,
        **kwargs,
    ):
        super().__init__(*args, scale_qk=scale_qk, processor=processor, **kwargs)
        # Set the attention processor. `FFAttnProcessor` is used by default; it relies on
        # torch.nn.functional.scaled_dot_product_attention (PyTorch 2.x) for native
        # flash / memory-efficient attention, but only with the default `scale` argument.
        if processor is None:
            processor = FFAttnProcessor()
        self.set_processor(processor)

    def forward(self, hidden_states, video_length, encoder_hidden_states=None, attention_mask=None,
                **cross_attention_kwargs):
        # The `Attention` class can call different attention processors / attention functions.
        # Here we simply pass along all tensors to the selected processor class;
        # for the standard processors defined here, `**cross_attention_kwargs` is empty.
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            video_length=video_length,
            **cross_attention_kwargs,
        )
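

# Hedged usage sketch (not part of the original file): the dimensions are illustrative
# assumptions. A video feature map is flattened to (batch * frames, tokens, channels)
# before the call, matching the "(b f) d c" layout that FFAttnProcessor expects.
def _demo_ff_attention():
    attn = FFAttention(query_dim=64, heads=8, dim_head=8)
    b, f, tokens, c = 2, 16, 77, 64
    hidden_states = torch.randn(b * f, tokens, c)
    out = attn(hidden_states, video_length=f)
    assert out.shape == (b * f, tokens, c)
    return out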


class FFAttnProcessor:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FFAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or later."
            )

    def __call__(self, attn: Attention, hidden_states, video_length, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        inner_dim = hidden_states.shape[-1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Sparse causal attention: keys and values are gathered from frame 0 only, so every
        # frame attends to the head frame of its clip (`former_frame_index` is computed but
        # unused in this first-frame variant).
        former_frame_index = torch.arange(video_length) - 1
        former_frame_index[0] = 0

        key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
        key = key[:, [0] * video_length].contiguous()
        key = rearrange(key, "b f d c -> (b f) d c")

        value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
        value = value[:, [0] * video_length].contiguous()
        value = rearrange(value, "b f d c -> (b f) d c")

        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # The output of scaled_dot_product_attention is (batch, num_heads, seq_len, head_dim).
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
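

# Hedged illustration (not part of the original file) of the key/value gather used above:
# for a clip of f frames, a tensor of shape ((b f), d, c) is regrouped per video and every
# frame's block is replaced by frame 0's block, so all frames attend to the head frame.
# Sizes are illustrative assumptions.
def _demo_first_frame_gather():
    b, f, d, c = 2, 4, 3, 5
    key = torch.randn(b * f, d, c)
    key = rearrange(key, "(b f) d c -> b f d c", f=f)
    key = key[:, [0] * f].contiguous()
    key = rearrange(key, "b f d c -> (b f) d c")
    # Within each video, every frame's block now equals the first frame's block.
    assert torch.equal(key[0], key[f - 1])
    return key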