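"""
Conditional attention module: per-pixel temporal cross-attention used to fuse encoded
conditional frames with the base video diffusion model (see the class docstrings below).
"""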
import torch
import torch.nn as nn
from einops import rearrange
from diffusers.models.attention_processor import Attention


class CrossAttention(nn.Module):
    """
    CrossAttention module implements per-pixel temporal attention to fuse the conditional attention module with the base module.

    Args:
        input_channels (int): Number of input channels.
        attention_head_dim (int): Dimension of each attention head.
        norm_num_groups (int): Number of groups for GroupNorm normalization (default is 32).

    Attributes:
        attention (Attention): Attention module for computing attention scores.
        norm (torch.nn.GroupNorm): Group normalization layer.
        proj_in (nn.Linear): Linear layer for projecting input data.
        proj_out (nn.Linear): Linear layer for projecting output data.
        dropout (nn.Dropout): Dropout layer for regularization.

    Methods:
        forward(hidden_state, encoder_hidden_states, num_frames, num_conditional_frames):
            Forward pass of the CrossAttention module.
    """

    def __init__(self, input_channels, attention_head_dim, norm_num_groups=32):
        super().__init__()
        # The number of heads is derived from the channel count so that heads * dim_head == input_channels.
        self.attention = Attention(
            query_dim=input_channels,
            cross_attention_dim=input_channels,
            heads=input_channels // attention_head_dim,
            dim_head=attention_head_dim,
            bias=False,
            upcast_attention=False,
        )
        self.norm = torch.nn.GroupNorm(
            num_groups=norm_num_groups, num_channels=input_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(input_channels, input_channels)
        self.proj_out = nn.Linear(input_channels, input_channels)
        self.dropout = nn.Dropout(p=0.25)

    def forward(self, hidden_state, encoder_hidden_states, num_frames, num_conditional_frames):
        """
        The input hidden state is normalized and then projected with a linear layer.
        Multi-head cross attention is computed between the hidden state (latent of the noisy video)
        and the encoder hidden states (CLIP image encoder features).
        The output is projected with a linear layer.
        Dropout is applied only to the newly generated frames, not to the conditional (control) frames.

        Args:
            hidden_state (torch.Tensor): Input hidden state tensor.
            encoder_hidden_states (torch.Tensor): Encoder hidden states tensor.
            num_frames (int): Number of frames.
            num_conditional_frames (int): Number of conditional frames.

        Returns:
            output (torch.Tensor): Output tensor after processing with the attention mechanism.
        """
        h, w = hidden_state.shape[2], hidden_state.shape[3]

        # Normalize with the frame axis unfolded from the batch: (B*F, C, H, W) -> (B, C, F, H, W).
        hidden_state_norm = rearrange(
            hidden_state, "(B F) C H W -> B C F H W", F=num_frames)
        hidden_state_norm = self.norm(hidden_state_norm)

        # Per-pixel temporal attention: each spatial location attends over the frame axis.
        hidden_state_norm = rearrange(
            hidden_state_norm, "B C F H W -> (B H W) F C")
        hidden_state_norm = self.proj_in(hidden_state_norm)
        attn = self.attention(hidden_state_norm,
                              encoder_hidden_states=encoder_hidden_states,
                              attention_mask=None,
                              )
        residual = self.proj_out(attn)  # (B H W) F C

        # Apply dropout only to the newly generated frames; the conditional frames are left untouched.
        hidden_state = rearrange(
            hidden_state, "(B F) ... -> B F ...", F=num_frames)
        hidden_state = torch.cat([hidden_state[:, :num_conditional_frames],
                                  self.dropout(hidden_state[:, num_conditional_frames:])], dim=1)
        hidden_state = rearrange(hidden_state, "B F ... -> (B F) ...")

        # Add the attention branch back onto the input as a residual.
        residual = rearrange(
            residual, "(B H W) F C -> (B F) C H W", H=h, W=w)
        output = hidden_state + residual
        return output


class ConditionalModel(nn.Module):
    """
    ConditionalModel fuses the conditional attention module with the base model.

    Args:
        input_channels (int): Number of input channels.
        conditional_model (str): Type of conditional model to use. Currently only "cross_attention" is implemented.
        attention_head_dim (int): Dimension of each attention head (default is 64).

    Attributes:
        temporal_transformer (CrossAttention): CrossAttention module for temporal transformation.
        conditional_model (str): Type of conditional model used.

    Methods:
        forward(sample, conditioning, num_frames=None, num_conditional_frames=None):
            Forward pass of the ConditionalModel module.
    """

    def __init__(self, input_channels, conditional_model: str, attention_head_dim=64):
        super().__init__()
        if conditional_model == "cross_attention":
            self.temporal_transformer = CrossAttention(
                input_channels=input_channels, attention_head_dim=attention_head_dim)
        else:
            raise NotImplementedError(
                f"mode {conditional_model} not implemented")

        # Zero-initialize the output projection so the attention branch contributes nothing at the start of training.
        nn.init.zeros_(self.temporal_transformer.proj_out.weight)
        nn.init.zeros_(self.temporal_transformer.proj_out.bias)
        self.conditional_model = conditional_model

    def forward(self, sample, conditioning, num_frames=None, num_conditional_frames=None):
        """
        Forward pass of the ConditionalModel module.

        Args:
            sample (torch.Tensor): Input sample tensor.
            conditioning (torch.Tensor): Conditioning tensor containing the encoding of the conditional frames.
            num_frames (int): Number of frames in the sample.
            num_conditional_frames (int): Number of conditional frames.

        Returns:
            sample (torch.Tensor): Transformed sample tensor.
        """
        sample = rearrange(sample, "(B F) ... -> B F ...", F=num_frames)
        batch_size = sample.shape[0]
        conditioning = rearrange(
            conditioning, "(B F) ... -> B F ...", B=batch_size)
        assert conditioning.ndim == 5
        assert sample.ndim == 5

        # Flatten the conditioning so that each spatial location provides its own key/value sequence over frames.
        conditioning = rearrange(conditioning, "B F C H W -> (B H W) F C")
        sample = rearrange(sample, "B F C H W -> (B F) C H W")
        sample = self.temporal_transformer(
            sample, encoder_hidden_states=conditioning, num_frames=num_frames, num_conditional_frames=num_conditional_frames)
        return sample


if __name__ == "__main__":
    model = CrossAttention(input_channels=320, attention_head_dim=32)
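
    # Minimal smoke test (a sketch, not part of the original file): the shapes below
    # (B=1, F=8, C=320, 16x16 latents, 2 conditional frames) are illustrative assumptions
    # chosen only to exercise the forward passes.
    B, F, C, H, W = 1, 8, 320, 16, 16
    num_conditional_frames = 2

    hidden_state = torch.randn(B * F, C, H, W)                       # noisy video latent, frames folded into the batch
    conditioning = torch.randn(B * num_conditional_frames, C, H, W)  # encoded conditional frames on the same spatial grid

    # Calling CrossAttention directly requires the conditioning flattened to (B*H*W, F_cond, C).
    encoder_hidden_states = rearrange(conditioning, "(b f) c h w -> (b h w) f c", b=B)
    out = model(hidden_state, encoder_hidden_states=encoder_hidden_states,
                num_frames=F, num_conditional_frames=num_conditional_frames)
    print("CrossAttention output:", out.shape)  # expected: (B*F, C, H, W)

    # ConditionalModel performs that flattening internally and takes the raw (B*F_cond, C, H, W) conditioning.
    cond_model = ConditionalModel(
        input_channels=320, conditional_model="cross_attention", attention_head_dim=32)
    out = cond_model(hidden_state, conditioning,
                     num_frames=F, num_conditional_frames=num_conditional_frames)
    print("ConditionalModel output:", out.shape)  # expected: (B*F, C, H, W)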