# peekaboo-demo: src/models/transformer_temporal.py
# Last commit 44f2ca8 ("Demo commit") by Anshul Nasery.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
import math
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput
from diffusers.models.modeling_utils import ModelMixin
from .attention import BasicTransformerBlock
@dataclass
class TransformerTemporalModelOutput(BaseOutput):
    """
    Container for the result of [`TransformerTemporalModel.forward`] when `return_dict=True`.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size * num_frames, num_channels, height, width)`):
            Hidden states after temporal attention, in the same frame-stacked layout as the model input
            and conditioned on the `encoder_hidden_states` input.
    """

    sample: torch.FloatTensor
class TransformerTemporalModel(ModelMixin, ConfigMixin):
    """
    A Transformer model for video-like data that attends along the frame axis.

    Every spatial location of every clip is treated as an independent sequence of `num_frames` tokens. The blocks'
    output is projected back to `in_channels` and added to the input as a residual.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`):
            The number of channels in the input and output. Required in practice: it sizes the input `GroupNorm`
            and both linear projections.
        out_channels (`int`, *optional*):
            Recorded in the config but not used by this class; the output always has `in_channels` channels.
        num_layers (`int`, *optional*, defaults to 1): The number of Transformer blocks to stack.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability passed to each block.
        norm_num_groups (`int`, *optional*, defaults to 32): Number of groups in the input `GroupNorm`.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the `TransformerBlock` attention should contain a bias parameter.
        sample_size (`int`, *optional*): Recorded in the config but not used by this class.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Passed through unchanged to each `BasicTransformerBlock`.
        double_self_attention (`bool`, *optional*, defaults to `True`):
            Configure if each `TransformerBlock` should contain two self-attention layers.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        double_self_attention: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.in_channels = in_channels

        # 1. Input normalisation and projection to the attention width.
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        # 2. Transformer blocks.
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    double_self_attention=double_self_attention,
                    norm_elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        # 3. Project back to `in_channels` so the residual add in `forward` lines up.
        self.proj_out = nn.Linear(inner_dim, in_channels)

    @staticmethod
    def _prepare_temporal_attention_mask(attention_mask, num_frames, height, width):
        """
        Convert one spatial attention mask into an additive mask for attention over frames.

        Based on the author's inline shape notes, the incoming mask appears to be laid out as
        `(batch * num_frames, 1, mask_tokens)` (e.g. `(32, 1, 1024)`). `mask_tokens` must be
        `s**2 * height * width` for some integer `s >= 1`, i.e. the same spatial grid upsampled by `s` along
        each axis -- TODO confirm this layout against callers.

        Returns:
            `torch.Tensor` of shape `(batch * height * width, 1, num_frames)`, matching the
            `(batch * height * width, num_frames, channels)` token layout used in `forward`. The input mask is
            inverted: entries are `0.0` where the input was negative and `-10000.0` everywhere else.

        Raises:
            ValueError: if `mask_tokens` is not `s**2 * height * width` for an integer `s >= 1`.
        """
        mask_tokens = attention_mask.shape[2]
        # (N, 1, T) -> (1, N, T) -> (N // num_frames, num_frames, T); no in-place ops follow, so no clone needed.
        mask = attention_mask.permute(1, 0, 2).reshape(-1, num_frames, mask_tokens)

        spatial_tokens = height * width
        ratio, remainder = divmod(mask_tokens, spatial_tokens)
        scaling_factor = math.isqrt(ratio)
        # Reject non-square or smaller-than-feature-map masks explicitly. Otherwise a zero slice step errors
        # cryptically, or `reshape(-1, ...)` silently mixes rows from different clips.
        if remainder or scaling_factor < 1 or scaling_factor * scaling_factor != ratio:
            raise ValueError(
                f"Cannot map an attention mask with {mask_tokens} spatial tokens onto a {height}x{width} "
                "feature map: the token count must be s**2 * height * width for an integer s >= 1."
            )

        mask_h = height * scaling_factor
        mask_w = width * scaling_factor
        # Downsample to the feature-map grid by keeping every `scaling_factor`-th row and column.
        mask = mask.reshape(-1, num_frames, mask_h, mask_w)[:, :, ::scaling_factor, ::scaling_factor]
        # (B, F, h, w) -> (B, h*w, F) -> (B*h*w, 1, F): one row per spatial location, as in `forward`.
        mask = mask.reshape(-1, num_frames, spatial_tokens).permute(0, 2, 1)
        mask = mask.reshape(-1, 1, num_frames)
        # Invert the mask: positions the input marked negative become attendable (0), everything else is
        # suppressed. The original author's note: "so that background is the only thing active".
        return torch.where(mask < 0.0, 0.0, -10000.0).type(mask.dtype).to(mask.device)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        timestep=None,
        class_labels=None,
        num_frames=1,
        cross_attention_kwargs=None,
        return_dict: bool = True,
        attention_mask=None,
        encoder_attention_mask=None,
        **kwargs,
    ):
        """
        The [`TransformerTemporalModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size * num_frames, channel, height, width)`):
                Frame-stacked input hidden states; `channel` must equal `in_channels`.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Conditional embeddings for cross-attention, forwarded unchanged to every block.
            timestep (`torch.long`, *optional*):
                Denoising step, forwarded unchanged to every block.
            class_labels (`torch.LongTensor`, *optional*):
                Class-label conditioning, forwarded unchanged to every block.
            num_frames (`int`, *optional*, defaults to 1):
                Frames per clip. The first dimension of `hidden_states` must be divisible by it.
            cross_attention_kwargs (`dict`, *optional*): Forwarded unchanged to every block.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a
                plain tuple.
            attention_mask (`torch.Tensor` or `list` of `torch.Tensor`, *optional*):
                Spatial mask(s), each converted by `_prepare_temporal_attention_mask` into an inverted additive
                mask over frames. A list yields a list of converted masks, which is passed to the blocks as-is
                (presumably handled by the local `BasicTransformerBlock` -- verify).
            encoder_attention_mask (`torch.Tensor`, *optional*): Forwarded unchanged to every block.
            **kwargs: Forwarded unchanged to every block.

        Returns:
            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
                If `return_dict` is True, a [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
                returned, otherwise a `tuple` whose first element is the sample tensor. In both cases the sample
                has the same shape as `hidden_states` and already includes the input residual.

        Raises:
            ValueError: if the first dimension of `hidden_states` is not divisible by `num_frames`, or if an
                attention mask's token count does not match the feature map (see `_prepare_temporal_attention_mask`).
        """
        # 1. Input
        batch_frames, channel, height, width = hidden_states.shape
        if batch_frames % num_frames != 0:
            raise ValueError(
                f"hidden_states has {batch_frames} rows, which is not divisible by num_frames={num_frames}."
            )
        batch_size = batch_frames // num_frames

        if attention_mask is None:
            temporal_attn_mask = None
        elif isinstance(attention_mask, list):
            temporal_attn_mask = [
                self._prepare_temporal_attention_mask(mask, num_frames, height, width) for mask in attention_mask
            ]
        else:
            temporal_attn_mask = self._prepare_temporal_attention_mask(attention_mask, num_frames, height, width)

        residual = hidden_states

        # (B*F, C, H, W) -> (B, C, F, H, W): GroupNorm expects channels in dim 1.
        hidden_states = hidden_states.reshape(batch_size, num_frames, channel, height, width)
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
        hidden_states = self.norm(hidden_states)
        # (B, C, F, H, W) -> (B*H*W, F, C): each spatial location becomes one sequence over frames.
        hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
        hidden_states = self.proj_in(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
                attention_mask=temporal_attn_mask,
                encoder_attention_mask=encoder_attention_mask,
                # NOTE(author TODO): `make_2d_attention_mask=True` / `block_diagonal_attention=True` were
                # considered here but left disabled.
                **kwargs,
            )

        # 3. Output
        hidden_states = self.proj_out(hidden_states)
        # Invert the input flattening: (B*H*W, F, C) -> (B, H, W, F, C) -> (B, F, C, H, W) -> (B*F, C, H, W).
        hidden_states = (
            hidden_states.reshape(batch_size, height, width, num_frames, channel)
            .permute(0, 3, 4, 1, 2)
            .contiguous()
        )
        hidden_states = hidden_states.reshape(batch_frames, channel, height, width)

        output = hidden_states + residual

        if not return_dict:
            return (output,)

        return TransformerTemporalModelOutput(sample=output)