# hotshot_xl/models/transformer_temporal.py
# Copyright 2023 Natural Synthetics Inc. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
import math
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn
from einops import rearrange, repeat
from diffusers.models.attention import Attention, FeedForward
from diffusers.utils import BaseOutput
class PositionalEncoding(nn.Module):
"""
Implements positional encoding as described in "Attention Is All You Need".
    Adds sinusoidal positional encodings to the input tensor.
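
    Example (illustrative doctest; the shapes below are arbitrary, not values used by the model):
        >>> pe = PositionalEncoding(dim=4, max_length=8)
        >>> pe(torch.zeros(2, 3, 4), length=3).shape
        torch.Size([2, 3, 4])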
"""
_SCALE_FACTOR = 10000.0 # Scale factor used in the positional encoding computation.
def __init__(self, dim: int, dropout: float = 0.0, max_length: int = 24):
        super().__init__()
self.dropout = nn.Dropout(p=dropout)
# The size is (1, max_length, dim) to allow easy addition to input tensors.
positional_encoding = torch.zeros(1, max_length, dim)
# Position and dim are used in the sinusoidal computation.
position = torch.arange(max_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(self._SCALE_FACTOR) / dim))
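        # The two assignments below give PE[p, 2i] = sin(p / 10000^(2i / dim)) and
        # PE[p, 2i + 1] = cos(p / 10000^(2i / dim)) for every frame position p.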
positional_encoding[0, :, 0::2] = torch.sin(position * div_term)
positional_encoding[0, :, 1::2] = torch.cos(position * div_term)
# Register the positional encoding matrix as a buffer,
# so it's part of the model's state but not the parameters.
self.register_buffer('positional_encoding', positional_encoding)
def forward(self, hidden_states: torch.Tensor, length: int) -> torch.Tensor:
hidden_states = hidden_states + self.positional_encoding[:, :length]
return self.dropout(hidden_states)
class TemporalAttention(Attention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pos_encoder = PositionalEncoding(kwargs["query_dim"], dropout=0)
    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, number_of_frames=8):
        sequence_length = hidden_states.shape[1]
        # Fold the spatial positions into the batch so attention runs along the frame axis only.
        hidden_states = rearrange(hidden_states, "(b f) s c -> (b s) f c", f=number_of_frames)
        hidden_states = self.pos_encoder(hidden_states, length=number_of_frames)
        if encoder_hidden_states is not None:
            # Repeat the conditioning once per spatial position that was folded into the batch.
            encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b s) n c", s=sequence_length)
        hidden_states = super().forward(hidden_states, encoder_hidden_states, attention_mask=attention_mask)
        return rearrange(hidden_states, "(b s) f c -> (b f) s c", s=sequence_length)
@dataclass
class TransformerTemporalOutput(BaseOutput):
sample: torch.FloatTensor
class TransformerTemporal(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
in_channels: int,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
activation_fn: str = "geglu",
upcast_attention: bool = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
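        # proj_in maps per-location features from `in_channels` to the transformer width `inner_dim`.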
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
TransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
cross_attention_dim=cross_attention_dim
)
for _ in range(num_layers)
]
)
self.proj_out = nn.Linear(inner_dim, in_channels)
    def forward(self, hidden_states, encoder_hidden_states=None):
        _, num_channels, num_frames, height, width = hidden_states.shape
        # Fold the frame axis into the batch so each frame is normalised and projected as a 2D feature map.
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        skip = hidden_states
        hidden_states = self.norm(hidden_states)
        # Flatten the spatial grid into a token sequence for the transformer blocks.
        hidden_states = rearrange(hidden_states, "bf c h w -> bf (h w) c")
        hidden_states = self.proj_in(hidden_states)
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, number_of_frames=num_frames)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = rearrange(hidden_states, "bf (h w) c -> bf c h w", h=height, w=width).contiguous()
        # Residual connection around the temporal transformer, then restore the frame axis.
        output = hidden_states + skip
        output = rearrange(output, "(b f) c h w -> b c f h w", f=num_frames)
        return output
class TransformerBlock(nn.Module):
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
dropout=0.0,
activation_fn="geglu",
attention_bias=False,
upcast_attention=False,
depth=2,
cross_attention_dim: Optional[int] = None
):
super().__init__()
self.is_cross = cross_attention_dim is not None
attention_blocks = []
norms = []
for _ in range(depth):
attention_blocks.append(
TemporalAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
)
norms.append(nn.LayerNorm(dim))
self.attention_blocks = nn.ModuleList(attention_blocks)
self.norms = nn.ModuleList(norms)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.ff_norm = nn.LayerNorm(dim)
    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, number_of_frames=None):
        if not self.is_cross:
            # No cross-attention dimension was configured, so ignore any conditioning and self-attend only.
            encoder_hidden_states = None
        for block, norm in zip(self.attention_blocks, self.norms):
            # Pre-norm temporal attention with a residual connection.
            norm_hidden_states = norm(hidden_states)
            hidden_states = block(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                number_of_frames=number_of_frames,
            ) + hidden_states
        # Pre-norm feed-forward with a residual connection.
        norm_hidden_states = self.ff_norm(hidden_states)
        hidden_states = self.ff(norm_hidden_states) + hidden_states
        return hidden_states
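

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: it instantiates a randomly
    # initialised TransformerTemporal and checks that a dummy video latent keeps its shape.
    # All sizes below are arbitrary illustrative choices, not values taken from Hotshot-XL.
    batch, channels, frames, height, width = 1, 32, 8, 16, 16
    model = TransformerTemporal(
        num_attention_heads=8,
        attention_head_dim=8,
        in_channels=channels,
    )
    dummy_latents = torch.randn(batch, channels, frames, height, width)
    with torch.no_grad():
        output = model(dummy_latents)
    assert output.shape == dummy_latents.shape
    print("TransformerTemporal preserved shape:", tuple(output.shape))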