# ZIT-Controlnet / videox_fun / models / wan_audio_injector.py
# Alexander Bagus
# 22
# be751d2
# Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/motioner.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import importlib.metadata
import math
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models import ModelMixin
from diffusers.models.attention import AdaLayerNorm
from diffusers.utils import BaseOutput, is_torch_version, logging
from einops import rearrange, repeat
from .attention_utils import attention
from .wan_transformer3d import WanAttentionBlock, WanCrossAttention
def rope_precompute(x, grid_sizes, freqs, start=None):
    """Precompute the complex rotary multipliers for every token position.

    This follows the same position bookkeeping as ``rope_apply`` but, rather
    than rotating ``x``, it stores the rotary factors themselves.

    Args:
        x: Tensor of shape ``[B, S, N, D]``. It is detached and reinterpreted
            as complex pairs to allocate the output; positions not covered by
            any grid group keep those reinterpreted values.
        grid_sizes: One group or a list of groups laid out back to back along
            the sequence axis. A group is a list ``[origin, end, target]`` of
            integer tensors indexed as ``[sample][f, h, w]``. A bare tensor is
            treated as ``end`` with a zero origin and ``target == end``.
            A negative ``target[..., 0]`` selects the trainable table.
        freqs: Complex frequency table with ``D // 2`` columns, or a list
            ``[table, trainable_table]``.
        start: Optional per-sample ``(f, h, w)`` origin that overrides the
            origin stored in every group. Tensors and plain sequences of
            ints are both accepted.

    Returns:
        complex128 tensor of shape ``[B, S, N, D // 2]``.

    Raises:
        ValueError: if a non-empty group has a zero temporal target, or asks
            for the trainable table when none was supplied.
    """
    b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2

    # split freqs into (temporal, height, width) column groups
    trainable_freqs = None
    if isinstance(freqs, list):
        freqs, trainable_freqs = freqs[0], freqs[1]
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    output = torch.view_as_complex(
        x.detach().reshape(b, s, n, -1, 2).to(torch.float64))

    if not isinstance(grid_sizes, list):
        grid_sizes = [grid_sizes]

    offset = 0  # first sequence index of the current group
    for g in grid_sizes:
        if not isinstance(g, list):
            # A bare size tensor means "start at zero and cover exactly the
            # grid"; the target entry is required because g[2] is read below.
            g = [torch.zeros_like(g), g, g]
        seq_len = 0
        for i in range(g[0].shape[0]):
            origin = g[0][i] if start is None else start[i]
            # Plain ints keep the arithmetic below independent of whether
            # the caller passed tensors or Python sequences.
            f_o, h_o, w_o = (int(v) for v in origin)
            f, h, w = (int(v) for v in g[1][i])
            t_f, t_h, t_w = (int(v) for v in g[2][i])

            seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
            seq_len = seq_f * seq_h * seq_w
            if seq_len <= 0:
                continue

            if t_f > 0:
                # Sample seq_* table rows evenly over the target range.
                if f_o >= 0:
                    f_sam = np.linspace(f_o, t_f + f_o - 1,
                                        seq_f).astype(int).tolist()
                else:
                    # Negative origins index the table by magnitude and use
                    # the conjugate (opposite rotation) further down.
                    f_sam = np.linspace(-f_o, -t_f - f_o + 1,
                                        seq_f).astype(int).tolist()
                h_sam = np.linspace(h_o, t_h + h_o - 1,
                                    seq_h).astype(int).tolist()
                w_sam = np.linspace(w_o, t_w + w_o - 1,
                                    seq_w).astype(int).tolist()

                assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
                freqs_0 = freqs[0][f_sam]
                if f_o < 0:
                    freqs_0 = freqs_0.conj()
                freqs_0 = freqs_0.view(seq_f, 1, 1, -1)

                freqs_i = torch.cat([
                    freqs_0.expand(seq_f, seq_h, seq_w, -1),
                    freqs[1][h_sam].view(1, seq_h, 1, -1).expand(
                        seq_f, seq_h, seq_w, -1),
                    freqs[2][w_sam].view(1, 1, seq_w, -1).expand(
                        seq_f, seq_h, seq_w, -1),
                ], dim=-1).reshape(seq_len, 1, -1)
            elif t_f < 0:
                if trainable_freqs is None:
                    raise ValueError(
                        "grid group requests trainable frequencies but "
                        "`freqs` was not given as [table, trainable_table]")
                freqs_i = trainable_freqs.unsqueeze(1)
            else:
                raise ValueError(
                    "temporal target size of a non-empty grid group must be "
                    "non-zero")

            output[i, offset:offset + seq_len] = freqs_i
        # Every sample of a group occupies the same sequence span.
        offset += seq_len
    return output
def sinusoidal_embedding_1d(dim, position):
    """Return a float64 sinusoidal embedding of a 1-D position tensor.

    Args:
        dim: Embedding width; must be even.
        position: 1-D tensor of positions.

    Returns:
        Tensor of shape ``[len(position), dim]`` whose first half holds the
        cosines and whose second half holds the sines.
    """
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float64)

    # Frequencies decay geometrically from 1 towards 1/10000.
    exponents = -torch.arange(half).to(position).div(half)
    inv_freq = torch.pow(10000, exponents)
    angles = torch.outer(position, inv_freq)
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
@amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    """Build a table of unit-magnitude complex rotary factors.

    Args:
        max_seq_len: Number of positions (rows) in the table.
        dim: Rotary width; must be even. The table has ``dim // 2`` columns.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``[max_seq_len, dim // 2]`` where entry
        ``(p, k)`` is ``exp(1j * p / theta ** (2k / dim))``.
    """
    assert dim % 2 == 0
    positions = torch.arange(max_seq_len)
    exponents = torch.arange(0, dim, 2).to(torch.float64).div(dim)
    inv_freq = 1.0 / torch.pow(theta, exponents)
    angles = torch.outer(positions, inv_freq)
    # polar(1, angle) == cos(angle) + 1j * sin(angle)
    return torch.polar(torch.ones_like(angles), angles)
@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs, start=None):
    """Rotate ``x`` with 3-D rotary position embeddings.

    Args:
        x: Tensor of shape ``[B, S, N, D]``. Positions not covered by any
            grid group are returned unchanged.
        grid_sizes: One group or a list of groups laid out back to back along
            the sequence axis. A group is a list ``[origin, end, target]`` of
            integer tensors indexed as ``[sample][f, h, w]``. A bare tensor is
            treated as ``end`` with a zero origin and ``target == end``.
            A negative ``target[..., 0]`` selects the trainable table.
        freqs: Complex frequency table with ``D // 2`` columns, or a list
            ``[table, trainable_table]``.
        start: Optional per-sample ``(f, h, w)`` origin that overrides the
            origin stored in every group. Tensors and plain sequences of
            ints are both accepted.

    Returns:
        float32 tensor with the same shape as ``x``.

    Raises:
        ValueError: if a non-empty group has a zero temporal target, or asks
            for the trainable table when none was supplied.
    """
    n, c = x.size(2), x.size(3) // 2

    # split freqs into (temporal, height, width) column groups
    trainable_freqs = None
    if isinstance(freqs, list):
        freqs, trainable_freqs = freqs[0], freqs[1]
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    output = x.clone()
    if not isinstance(grid_sizes, list):
        grid_sizes = [grid_sizes]

    offset = 0  # first sequence index of the current group
    for g in grid_sizes:
        if not isinstance(g, list):
            # A bare size tensor means "start at zero and cover exactly the
            # grid"; the target entry is required because g[2] is read below.
            g = [torch.zeros_like(g), g, g]
        seq_len = 0
        for i in range(g[0].shape[0]):
            origin = g[0][i] if start is None else start[i]
            # Plain ints keep the arithmetic below independent of whether
            # the caller passed tensors or Python sequences.
            f_o, h_o, w_o = (int(v) for v in origin)
            f, h, w = (int(v) for v in g[1][i])
            t_f, t_h, t_w = (int(v) for v in g[2][i])

            seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
            seq_len = seq_f * seq_h * seq_w
            if seq_len <= 0:
                continue

            if t_f > 0:
                # Sample seq_* table rows evenly over the target range.
                if f_o >= 0:
                    f_sam = np.linspace(f_o, t_f + f_o - 1,
                                        seq_f).astype(int).tolist()
                else:
                    # Negative origins index the table by magnitude and use
                    # the conjugate (opposite rotation) further down.
                    f_sam = np.linspace(-f_o, -t_f - f_o + 1,
                                        seq_f).astype(int).tolist()
                h_sam = np.linspace(h_o, t_h + h_o - 1,
                                    seq_h).astype(int).tolist()
                w_sam = np.linspace(w_o, t_w + w_o - 1,
                                    seq_w).astype(int).tolist()

                assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
                freqs_0 = freqs[0][f_sam]
                if f_o < 0:
                    freqs_0 = freqs_0.conj()
                freqs_0 = freqs_0.view(seq_f, 1, 1, -1)

                freqs_i = torch.cat([
                    freqs_0.expand(seq_f, seq_h, seq_w, -1),
                    freqs[1][h_sam].view(1, seq_h, 1, -1).expand(
                        seq_f, seq_h, seq_w, -1),
                    freqs[2][w_sam].view(1, 1, seq_w, -1).expand(
                        seq_f, seq_h, seq_w, -1),
                ], dim=-1).reshape(seq_len, 1, -1)
            elif t_f < 0:
                if trainable_freqs is None:
                    raise ValueError(
                        "grid group requests trainable frequencies but "
                        "`freqs` was not given as [table, trainable_table]")
                freqs_i = trainable_freqs.unsqueeze(1)
            else:
                raise ValueError(
                    "temporal target size of a non-empty grid group must be "
                    "non-zero")

            # apply rotary embedding in float64, pairing adjacent channels
            # as (real, imag)
            span = slice(offset, offset + seq_len)
            x_i = torch.view_as_complex(
                x[i, span].to(torch.float64).reshape(seq_len, n, -1, 2))
            output[i, span] = torch.view_as_real(x_i * freqs_i).flatten(2)
        # Every sample of a group occupies the same sequence span.
        offset += seq_len
    return output.float()
class CausalConv1d(nn.Module):
    """1-D convolution that pads only the left (past) side of the time axis.

    Args:
        chan_in: Number of input channels.
        chan_out: Number of output channels.
        kernel_size: Convolution kernel width.
        stride: Convolution stride.
        dilation: Convolution dilation.
        pad_mode: Padding mode forwarded to ``F.pad``.
        **kwargs: Extra keyword arguments forwarded to ``nn.Conv1d``.
    """

    def __init__(self,
                 chan_in,
                 chan_out,
                 kernel_size=3,
                 stride=1,
                 dilation=1,
                 pad_mode='replicate',
                 **kwargs):
        super().__init__()
        self.pad_mode = pad_mode
        # (left, right) padding of the last axis: kernel_size - 1 steps of
        # history, nothing from the future.
        self.time_causal_padding = (kernel_size - 1, 0)
        conv_options = dict(kwargs, stride=stride, dilation=dilation)
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, **conv_options)

    def forward(self, x):
        """Left-pad ``x`` along its last axis, then convolve."""
        padded = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(padded)
class MotionEncoder_tc(nn.Module):
    """Causal temporal-convolution encoder producing per-frame tokens.

    The first convolution of the local path emits ``num_heads`` groups of
    ``hidden_dim // 4`` channels; each group is processed independently as
    its own batch entry. Two stride-2 causal convolutions then widen the
    channels to ``hidden_dim``. One learned padding token is appended along
    the token axis of the local output.

    Args:
        in_dim: Channel width of the input features.
        hidden_dim: Channel width of the produced tokens.
        num_heads: Number of local tokens produced per time step.
        need_global: When true, also build and run the single-token global
            path and its final linear projection.
        dtype: Forwarded to the layer norms and the final linear layer.
        device: Forwarded to the layer norms and the final linear layer.
    """

    def __init__(self,
                 in_dim: int,
                 hidden_dim: int,
                 num_heads: int,
                 need_global=True,
                 dtype=None,
                 device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()

        self.num_heads = num_heads
        self.need_global = need_global
        self.conv1_local = CausalConv1d(
            in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
        if need_global:
            self.conv1_global = CausalConv1d(
                in_dim, hidden_dim // 4, 3, stride=1)
        self.norm1 = nn.LayerNorm(
            hidden_dim // 4,
            elementwise_affine=False,
            eps=1e-6,
            **factory_kwargs)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
        self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)

        if need_global:
            self.final_linear = nn.Linear(hidden_dim, hidden_dim,
                                          **factory_kwargs)

        self.norm2 = nn.LayerNorm(
            hidden_dim // 2,
            elementwise_affine=False,
            eps=1e-6,
            **factory_kwargs)
        self.norm3 = nn.LayerNorm(
            hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def _downsample(self, x):
        """Run the norm/act/conv stack shared by the local and global paths.

        ``x`` is channel-last ``(batch, time, hidden_dim // 4)``; the result
        is channel-last with ``hidden_dim`` channels.
        """
        x = self.act(self.norm1(x))
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv2(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.act(self.norm2(x))
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv3(x)
        x = rearrange(x, 'b c t -> b t c')
        return self.act(self.norm3(x))

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, time, in_dim)``.

        Returns:
            ``x_local`` when ``need_global`` is false, otherwise the tuple
            ``(x_global, x_local)``. ``x_local`` has ``num_heads + 1`` tokens
            per time step (the last one is the learned padding token);
            ``x_global`` has one.
        """
        x = rearrange(x, 'b t c -> b c t')
        b = x.shape[0]

        # Local path: fold the head groups into the batch axis so every
        # group goes through the shared stack on its own.
        x_local = self.conv1_local(x)
        x_local = rearrange(
            x_local, 'b (n c) t -> (b n) t c', n=self.num_heads)
        x_local = self._downsample(x_local)
        x_local = rearrange(x_local, '(b n) t c -> b t n c', b=b)
        padding = self.padding_tokens.repeat(b, x_local.shape[1], 1, 1)
        x_local = torch.cat([x_local, padding], dim=-2)

        if not self.need_global:
            return x_local

        # Global path: a single token per time step.
        x_global = self.conv1_global(x)
        x_global = rearrange(x_global, 'b c t -> b t c')
        x_global = self._downsample(x_global)
        x_global = self.final_linear(x_global)
        x_global = rearrange(x_global, '(b n) t c -> b t n c', b=b)

        return x_global, x_local
class CausalAudioEncoder(nn.Module):
    """Blend per-layer audio features with learned weights, then encode them.

    Args:
        dim: Channel width of the incoming features.
        num_layers: Number of feature layers that are blended.
        out_dim: Token width produced by the encoder.
        video_rate: Accepted but not used by this class.
        num_token: Number of local tokens per time step.
        need_global: Forwarded to ``MotionEncoder_tc``.
    """

    def __init__(self,
                 dim=5120,
                 num_layers=25,
                 out_dim=2048,
                 video_rate=8,
                 num_token=4,
                 need_global=False):
        super().__init__()
        self.encoder = MotionEncoder_tc(
            in_dim=dim,
            hidden_dim=out_dim,
            num_heads=num_token,
            need_global=need_global)
        # One blending weight per feature layer, all starting at 0.01.
        self.weights = torch.nn.Parameter(
            torch.ones((1, num_layers, 1, 1)) * 0.01)
        self.act = torch.nn.SiLU()

    def forward(self, features):
        """Encode ``features``, laid out as (B, num_layers, dim, length)."""
        with amp.autocast(dtype=torch.float32):
            layer_weights = self.act(self.weights)
            normaliser = layer_weights.sum(dim=1, keepdims=True)
            # Normalised weighted sum over the layer axis -> (b, dim, f)
            blended = ((features * layer_weights) / normaliser).sum(dim=1)
            blended = blended.permute(0, 2, 1)  # (b, f, dim)
            return self.encoder(blended)  # (b, f, n, dim)
class AudioCrossAttention(WanCrossAttention):
    """Cross-attention from ``x`` onto ``context``, always run through the
    ``FLASH_ATTENTION`` backend of ``attention``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
            dtype: dtype every projection input and attention operand is
                cast to.
            t: Accepted but not used.
        """
        batch = x.size(0)
        heads, head_dim = self.num_heads, self.head_dim

        queries_in = x.to(dtype)
        context_in = context.to(dtype)

        # project and split into heads
        q = self.norm_q(self.q(queries_in)).view(batch, -1, heads, head_dim)
        k = self.norm_k(self.k(context_in)).view(batch, -1, heads, head_dim)
        v = self.v(context_in).view(batch, -1, heads, head_dim)

        attended = attention(
            q.to(dtype),
            k.to(dtype),
            v.to(dtype),
            k_lens=context_lens,
            attention_type="FLASH_ATTENTION")

        # merge heads and project out
        merged = attended.flatten(2)
        return self.o(merged.to(dtype))
class AudioInjector_WAN(nn.Module):
    """Container for the per-block audio cross-attention injectors.

    One injector (cross-attention plus two parameter-free layer norms, and
    optionally AdaLN layers) is created for every ``WanAttentionBlock`` whose
    name refers to a block index listed in ``inject_layer``.
    ``injected_block_id`` maps that block index to the injector's position
    in the module lists.

    Args:
        all_modules: Modules to scan.
        all_modules_names: Names matching ``all_modules`` one to one; a block
            is recognised by ``transformer_blocks.<index>`` in its name.
        dim: Feature width of the injectors.
        num_heads: Number of attention heads of each injector.
        inject_layer: Block indices that receive an injector.
        root_net: Accepted but not used.
        enable_adain: Also create one ``AdaLayerNorm`` per injector.
        adain_dim: Embedding width of the AdaLN layers.
        need_adain_ont: With ``enable_adain``, also create one linear output
            layer per injector.
    """

    def __init__(self,
                 all_modules,
                 all_modules_names,
                 dim=2048,
                 num_heads=32,
                 inject_layer=(0, 27),
                 root_net=None,
                 enable_adain=False,
                 adain_dim=2048,
                 need_adain_ont=False):
        super().__init__()
        self.injected_block_id = {}
        audio_injector_id = 0
        for mod_name, mod in zip(all_modules_names, all_modules):
            if not isinstance(mod, WanAttentionBlock):
                continue
            for inject_id in inject_layer:
                if self._names_block(mod_name, inject_id):
                    self.injected_block_id[inject_id] = audio_injector_id
                    audio_injector_id += 1

        self.injector = nn.ModuleList([
            AudioCrossAttention(
                dim=dim,
                num_heads=num_heads,
                qk_norm=True,
            ) for _ in range(audio_injector_id)
        ])
        self.injector_pre_norm_feat = nn.ModuleList([
            nn.LayerNorm(
                dim,
                elementwise_affine=False,
                eps=1e-6,
            ) for _ in range(audio_injector_id)
        ])
        self.injector_pre_norm_vec = nn.ModuleList([
            nn.LayerNorm(
                dim,
                elementwise_affine=False,
                eps=1e-6,
            ) for _ in range(audio_injector_id)
        ])
        if enable_adain:
            self.injector_adain_layers = nn.ModuleList([
                AdaLayerNorm(
                    output_dim=dim * 2, embedding_dim=adain_dim, chunk_dim=1)
                for _ in range(audio_injector_id)
            ])
            if need_adain_ont:
                self.injector_adain_output_layers = nn.ModuleList(
                    [nn.Linear(dim, dim) for _ in range(audio_injector_id)])

    @staticmethod
    def _names_block(mod_name, block_id):
        """Return True when ``mod_name`` refers to block ``block_id``.

        A plain substring test would let index 2 match blocks 20-29, so the
        index must be followed by a non-digit or the end of the name.
        """
        marker = f'transformer_blocks.{block_id}'
        pos = mod_name.find(marker)
        while pos != -1:
            end = pos + len(marker)
            if end == len(mod_name) or not mod_name[end].isdigit():
                return True
            pos = mod_name.find(marker, pos + 1)
        return False
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned
    per-channel scale.

    Args:
        dim: Size of the normalised (last) axis.
        eps: Constant added to the mean square before the reciprocal root.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim, self.eps = dim, eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalise in float32, cast back to ``x``'s dtype, then scale."""
        normalised = self._norm(x.float())
        return normalised.type_as(x) * self.weight

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last axis."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that computes in float32 and returns the input dtype.

    Args:
        dim: Size of the normalised (last) axis.
        eps: Numerical-stability constant.
        elementwise_affine: Whether to learn a per-channel scale and bias.
    """

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        """Normalise ``x`` in float32 and cast the result back."""
        normalised = super().forward(x.float())
        return normalised.type_as(x)
class SelfAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        window_size: Forwarded to ``attention``.
        qk_norm: Apply ``RMSNorm`` to the query and key projections.
        eps: Epsilon for the query/key norms.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # projections
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        if qk_norm:
            self.norm_q = RMSNorm(dim, eps=eps)
            self.norm_k = RMSNorm(dim, eps=eps)
        else:
            self.norm_q = nn.Identity()
            self.norm_k = nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs):
        """Attend ``x`` (``[B, S, dim]``) to itself.

        ``seq_lens`` is passed to ``attention`` as ``k_lens``; ``grid_sizes``
        and ``freqs`` are passed to ``rope_apply`` for queries and keys.
        """
        batch, seq = x.shape[:2]
        head_shape = (batch, seq, self.num_heads, self.head_dim)

        query = self.norm_q(self.q(x)).view(*head_shape)
        key = self.norm_k(self.k(x)).view(*head_shape)
        value = self.v(x).view(*head_shape)

        attended = attention(
            q=rope_apply(query, grid_sizes, freqs),
            k=rope_apply(key, grid_sizes, freqs),
            v=value,
            k_lens=seq_lens,
            window_size=self.window_size)

        # merge heads and project out
        return self.o(attended.flatten(2))
class SwinSelfAttention(SelfAttention):
    """Self-attention in which every frame attends to itself, its two
    temporal neighbours and a reference frame.

    The last frame of the sequence is treated as the reference frame: its
    keys/values are appended to every other frame's keys/values, and its raw
    value projection is appended to the output unchanged.
    """

    def forward(self, x, seq_lens, grid_sizes, freqs):
        """Run neighbour-window attention on a single sample.

        ``x`` must have batch size 1. ``grid_sizes[0]`` is read as
        ``(T, H, W)`` and the sequence is split into ``T`` frames of
        ``H * W`` tokens. ``seq_lens`` is accepted but not used here.
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
        assert b == 1, 'Only support batch_size 1'

        # query, key, value function
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        # NOTE(review): the same grid_sizes object is indexed as a tensor
        # below - confirm rope_apply accepts that form.
        q = rope_apply(q, grid_sizes, freqs)
        k = rope_apply(k, grid_sizes, freqs)
        T, H, W = grid_sizes[0].tolist()

        # Move frames to the batch axis: one attention problem per frame.
        q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
        k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
        v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)

        # The last frame is the reference frame; ref_q is not used afterwards.
        ref_q = q[-1:]
        q = q[:-1]

        ref_k = repeat(
            k[-1:], "1 s n d -> t s n d", t=k.shape[0] - 1)  # t hw n d
        k = k[:-1]
        # Edge-pad by repeating the first and last frame so every frame has
        # a previous and a next neighbour.
        k = torch.cat([k[:1], k, k[-1:]])
        # Per frame: [current, next, previous, reference] along the token axis.
        k = torch.cat([k[1:-1], k[2:], k[:-2], ref_k], dim=1)  # (bt) (4hw) n d

        # Values follow exactly the same layout as the keys.
        ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=v.shape[0] - 1)
        v = v[:-1]
        v = torch.cat([v[:1], v, v[-1:]])
        v = torch.cat([v[1:-1], v[2:], v[:-2], ref_v], dim=1)

        # q: (t) (h w) n d
        # k: (t) (4 h w) n d
        out = attention(
            q=q,
            k=k,
            v=v,
            # k_lens=torch.tensor([k.shape[1]] * k.shape[0], device=x.device, dtype=torch.long),
            window_size=self.window_size)
        # The reference frame gets its own value projection as output.
        out = torch.cat([out, ref_v[:1]], axis=0)
        out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)
        x = out

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x
# Fix the reference frame RoPE to 1,H,W.
# Set the current frame RoPE to 1.
# Set the previous frame RoPE to 0.
class CasualSelfAttention(SelfAttention):
    """Self-attention in which every frame attends to itself, the
    ``shifting`` frames before it and a reference frame.

    The last frame of the sequence is treated as the reference frame; its raw
    value projection is appended to the output unchanged.
    """

    def forward(self, x, seq_lens, grid_sizes, freqs):
        """Run causal-window attention on a single sample.

        ``x`` must have batch size 1. ``grid_sizes[0]`` is read as
        ``(T, H, W)``. ``seq_lens`` is accepted but not used here.
        """
        # Number of preceding frames each frame can see.
        shifting = 3
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
        assert b == 1, 'Only support batch_size 1'

        # query, key, value function
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        T, H, W = grid_sizes[0].tolist()

        # Move frames to the batch axis: one attention problem per frame.
        q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
        k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
        v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)

        # The last frame is the reference frame; ref_q is not used afterwards.
        ref_q = q[-1:]
        q = q[:-1]

        # NOTE(review): rope_apply is given a bare tensor for grid_sizes and
        # a plain Python list for start here and below - confirm it accepts
        # both forms and that the intended positions are produced.
        grid_sizes = torch.tensor([[1, H, W]] * q.shape[0], dtype=torch.long)
        start = [[shifting, 0, 0]] * q.shape[0]
        q = rope_apply(q, grid_sizes, freqs, start=start)

        ref_k = k[-1:]
        grid_sizes = torch.tensor([[1, H, W]], dtype=torch.long)
        # start = [[shifting, H, W]]

        # The reference frame is given a temporal origin offset by 10.
        start = [[shifting + 10, 0, 0]]
        ref_k = rope_apply(ref_k, grid_sizes, freqs, start)
        ref_k = repeat(
            ref_k, "1 s n d -> t s n d", t=k.shape[0] - 1)  # t hw n d

        k = k[:-1]
        # Front-pad with `shifting` copies of the first frame so that early
        # frames also have `shifting` predecessors.
        k = torch.cat([*([k[:1]] * shifting), k])
        cat_k = []
        # Slice i holds, for every frame, its (shifting - i)-th predecessor.
        for i in range(shifting):
            cat_k.append(k[i:i - shifting])
        # The unshifted slice is the current frame itself.
        cat_k.append(k[shifting:])
        k = torch.cat(cat_k, dim=1)  # (bt) ((shifting + 1) hw) n d

        grid_sizes = torch.tensor(
            [[shifting + 1, H, W]] * q.shape[0], dtype=torch.long)
        k = rope_apply(k, grid_sizes, freqs)
        k = torch.cat([k, ref_k], dim=1)

        # Values use the same windowing as the keys, without RoPE.
        ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=q.shape[0])  # t hw n d
        v = v[:-1]
        v = torch.cat([*([v[:1]] * shifting), v])
        cat_v = []
        for i in range(shifting):
            cat_v.append(v[i:i - shifting])
        cat_v.append(v[shifting:])
        v = torch.cat(cat_v, dim=1)  # (bt) ((shifting + 1) hw) n d
        v = torch.cat([v, ref_v], dim=1)

        # q: (t) (h w) n d
        # k: (t) ((shifting + 2) h w) n d
        # Attention is evaluated one frame at a time.
        outs = []
        for i in range(q.shape[0]):
            out = attention(
                q=q[i:i + 1],
                k=k[i:i + 1],
                v=v[i:i + 1],
                window_size=self.window_size)
            outs.append(out)
        out = torch.cat(outs, dim=0)
        # The reference frame gets its own value projection as output.
        out = torch.cat([out, ref_v[:1]], axis=0)
        out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)
        x = out

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x
class MotionerAttentionBlock(nn.Module):
    """Residual transformer block: self-attention followed by a GELU MLP.

    Args:
        dim: Model width.
        ffn_dim: Hidden width of the feed-forward network.
        num_heads: Number of attention heads.
        window_size: Forwarded to the attention module.
        qk_norm: Forwarded to the attention module.
        cross_attn_norm: Stored on the instance; not used by this block.
        eps: Epsilon for the layer norms and the attention module.
        self_attn_block: Which attention class to use: ``"SelfAttention"``,
            ``"SwinSelfAttention"`` or ``"CasualSelfAttention"``.

    Raises:
        ValueError: if ``self_attn_block`` is not one of the names above.
    """

    _SELF_ATTN_NAMES = ("SelfAttention", "SwinSelfAttention",
                        "CasualSelfAttention")

    def __init__(self,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6,
                 self_attn_block="SelfAttention"):
        super().__init__()
        # Fail at construction time instead of leaving `self.self_attn`
        # undefined until the first forward pass.
        if self_attn_block not in self._SELF_ATTN_NAMES:
            raise ValueError(
                f"Unknown self_attn_block {self_attn_block!r}; expected one "
                f"of {self._SELF_ATTN_NAMES}")

        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers
        self.norm1 = LayerNorm(dim, eps)
        attn_classes = {
            "SelfAttention": SelfAttention,
            "SwinSelfAttention": SwinSelfAttention,
            "CasualSelfAttention": CasualSelfAttention,
        }
        self.self_attn = attn_classes[self_attn_block](
            dim, num_heads, window_size, qk_norm, eps)

        self.norm2 = LayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim))

    def forward(
        self,
        x,
        seq_lens,
        grid_sizes,
        freqs,
    ):
        """Apply pre-norm self-attention and feed-forward, each with a
        residual connection. ``seq_lens``, ``grid_sizes`` and ``freqs`` are
        passed through to the attention module."""
        # self-attention (normalised input is cast to float32 first)
        y = self.self_attn(self.norm1(x).float(), seq_lens, grid_sizes, freqs)
        x = x + y
        # feed-forward
        y = self.ffn(self.norm2(x).float())
        x = x + y
        return x
class Head(nn.Module):
    """Layer norm followed by a linear projection to patch-sized outputs.

    Args:
        dim: Input feature width.
        out_dim: Output channels per patch element.
        patch_size: Patch extents; the projection emits
            ``prod(patch_size) * out_dim`` features.
        eps: Epsilon of the layer norm.
    """

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers
        projected_width = math.prod(patch_size) * out_dim
        self.norm = LayerNorm(dim, eps)
        self.head = nn.Linear(dim, projected_width)

    def forward(self, x):
        """Normalise ``x`` and project it."""
        normalised = self.norm(x)
        return self.head(normalised)
class MotionerTransformers(nn.Module, PeftAdapterMixin):
    """Transformer that summarises motion-frame latents into a fixed number
    of learned tokens.

    ``forward`` patch-embeds every input video, appends ``motion_token_num``
    learned tokens to the sequence, runs the attention blocks and returns
    only the positions of those learned tokens.
    """

    def __init__(
        self,
        patch_size=(1, 2, 2),
        in_dim=16,
        dim=2048,
        ffn_dim=8192,
        freq_dim=256,
        out_dim=16,
        num_heads=16,
        num_layers=32,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=False,
        eps=1e-6,
        self_attn_block="SelfAttention",
        motion_token_num=1024,
        enable_tsm=False,
        motion_stride=4,
        expand_ratio=2,
        trainable_token_pos_emb=False,
    ):
        """Build the patch embedding, attention blocks and learned tokens.

        ``motion_token_num`` must be a perfect square. ``enable_tsm`` turns
        on frame sub-sampling via ``sample_indices`` (driven by
        ``motion_stride`` and ``expand_ratio``). ``trainable_token_pos_emb``
        makes the rotary factors of the learned tokens a parameter.
        """
        super().__init__()

        self.patch_size = patch_size
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        self.enable_tsm = enable_tsm
        self.motion_stride = motion_stride
        self.expand_ratio = expand_ratio
        # Temporal patch extent; passed to sample_indices as `c`.
        self.sample_c = self.patch_size[0]

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)

        # blocks
        self.blocks = nn.ModuleList([
            MotionerAttentionBlock(
                dim,
                ffn_dim,
                num_heads,
                window_size,
                qk_norm,
                cross_attn_norm,
                eps,
                self_attn_block=self_attn_block) for _ in range(num_layers)
        ])

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        # Rotary table with the per-head channels split into temporal,
        # height and width parts (concatenated along dim=1).
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ],
                               dim=1)

        self.gradient_checkpointing = False

        # The learned tokens are laid out on a square grid.
        self.motion_side_len = int(math.sqrt(motion_token_num))
        assert self.motion_side_len**2 == motion_token_num
        self.token = nn.Parameter(
            torch.zeros(1, motion_token_num, dim).contiguous())

        self.trainable_token_pos_emb = trainable_token_pos_emb
        if trainable_token_pos_emb:
            # Rotating a tensor whose (real, imag) pairs are all (1, 0)
            # yields the rotary factors themselves; they are scaled by 0.01
            # and stored as a parameter.
            x = torch.zeros([1, motion_token_num, num_heads, d])
            x[..., ::2] = 1

            gride_sizes = [[
                torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1),
                torch.tensor([1, self.motion_side_len,
                              self.motion_side_len]).unsqueeze(0).repeat(1, 1),
                torch.tensor([1, self.motion_side_len,
                              self.motion_side_len]).unsqueeze(0).repeat(1, 1),
            ]]
            token_freqs = rope_apply(x, gride_sizes, self.freqs)
            token_freqs = token_freqs[0, :, 0].reshape(motion_token_num, -1, 2)
            token_freqs = token_freqs * 0.01
            self.token_freqs = torch.nn.Parameter(token_freqs)

    def after_patch_embedding(self, x):
        """Hook applied right after patch embedding; identity here."""
        return x

    def forward(
        self,
        x,
    ):
        """
        Args:
            x: A list of videos each with shape [C, T, H, W].

        Returns:
            Tensor holding the last ``motion_token_num`` sequence positions
            (the learned tokens) after all blocks.
        """
        # params
        motion_frames = x[0].shape[1]  # not used below
        device = self.patch_embedding.weight.device
        freqs = self.freqs
        if freqs.device != device:
            freqs = freqs.to(device)

        if self.trainable_token_pos_emb:
            with amp.autocast(dtype=torch.float64):
                # Normalise every (real, imag) pair to unit length so it is
                # a pure rotation, then hand it over as the trainable table.
                token_freqs = self.token_freqs.to(torch.float64)
                token_freqs = token_freqs / token_freqs.norm(
                    dim=-1, keepdim=True)
                freqs = [freqs, torch.view_as_complex(token_freqs)]

        if self.enable_tsm:
            sample_idx = [
                sample_indices(
                    u.shape[1],
                    stride=self.motion_stride,
                    expand_ratio=self.expand_ratio,
                    c=self.sample_c) for u in x
            ]
            # Indices are applied to the time-reversed video, and the
            # selection is flipped back to the original time order.
            x = [
                torch.flip(torch.flip(u, [1])[:, idx], [1])
                for idx, u in zip(sample_idx, x)
            ]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        x = self.after_patch_embedding(x)

        seq_f, seq_h, seq_w = x[0].shape[-3:]
        batch_size = len(x)
        if not self.enable_tsm:
            grid_sizes = torch.stack(
                [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
            # One group: origin zero, end and target both the patch grid.
            grid_sizes = [[
                torch.zeros_like(grid_sizes), grid_sizes, grid_sizes
            ]]
            seq_f = 0
        else:
            grid_sizes = []
            # One single-frame group per sampled frame, positioned at its
            # sampled temporal index (taken from the first video).
            for idx in sample_idx[0][::-1][::self.sample_c]:
                tsm_frame_grid_sizes = [[
                    torch.tensor([idx, 0,
                                  0]).unsqueeze(0).repeat(batch_size, 1),
                    torch.tensor([idx + 1, seq_h,
                                  seq_w]).unsqueeze(0).repeat(batch_size, 1),
                    torch.tensor([1, seq_h,
                                  seq_w]).unsqueeze(0).repeat(batch_size, 1),
                ]]
                grid_sizes += tsm_frame_grid_sizes
            seq_f = sample_idx[0][-1] + 1

        # Flatten every patch grid into a token sequence.
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        x = torch.cat([u for u in x])

        batch_size = len(x)

        # Grid group of the learned tokens: one frame at temporal index
        # seq_f on a motion_side_len x motion_side_len grid.
        token_grid_sizes = [[
            torch.tensor([seq_f, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
            torch.tensor(
                [seq_f + 1, self.motion_side_len,
                 self.motion_side_len]).unsqueeze(0).repeat(batch_size, 1),
            torch.tensor(
                [1 if not self.trainable_token_pos_emb else -1, seq_h,
                 seq_w]).unsqueeze(0).repeat(batch_size, 1),
        ]  # The third row gives the range the RoPE embedding should cover
                            ]

        grid_sizes = grid_sizes + token_grid_sizes
        # Not used below.
        token_unpatch_grid_sizes = torch.stack([
            torch.tensor([1, 32, 32], dtype=torch.long)
            for b in range(batch_size)
        ])
        token_len = self.token.shape[1]
        token = self.token.clone().repeat(x.shape[0], 1, 1).contiguous()
        seq_lens = seq_lens + torch.tensor([t.size(0) for t in token],
                                           dtype=torch.long)
        # Learned tokens go at the end of every sequence.
        x = torch.cat([x, token], dim=1)
        # arguments
        kwargs = dict(
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=freqs,
        )

        # grad ckpt args
        def create_custom_forward(module, return_dict=None):

            def custom_forward(*inputs, **kwargs):
                if return_dict is not None:
                    return module(*inputs, **kwargs, return_dict=return_dict)
                else:
                    return module(*inputs, **kwargs)

            return custom_forward

        ckpt_kwargs: Dict[str, Any] = ({
            "use_reentrant": False
        } if is_torch_version(">=", "1.11.0") else {})

        for idx, block in enumerate(self.blocks):
            if self.training and self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x,
                    **kwargs,
                    **ckpt_kwargs,
                )
            else:
                x = block(x, **kwargs)
        # head
        out = x[:, -token_len:]
        return out

    def unpatchify(self, x, grid_sizes):
        """Rebuild dense volumes from patch tokens.

        For every pair of a token tensor ``u`` and a grid ``v`` (a row of
        ``grid_sizes``), the first ``prod(v)`` tokens are reshaped to
        ``[out_dim, *(v * patch_size)]``.
        """
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            # Interleave each grid axis with its patch axis.
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        """Xavier-initialise all linear layers and the patch embedding."""
        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
class FramePackMotioner(nn.Module):
    """Pack motion latents into tokens at three resolutions.

    The most recent frames are patchified finely and older frames
    progressively more coarsely; matching rotary factors are produced with
    ``rope_precompute``.

    Args:
        inner_dim: Token width produced by the patch projections.
        num_heads: Number of heads in the backbone network; only used to
            shape the rotary factors.
        zip_frame_buckets: Three frame counts, from the nearest to the
            farthest frames, handled at 1x, 2x and 4x patch size.
        drop_mode: ``"drop"`` removes unused buckets from the sequence; any
            other value zeroes the corresponding frames instead.
    """

    def __init__(
            self,
            inner_dim=1024,
            num_heads=16,
            zip_frame_buckets=(1, 2, 16),
            drop_mode="drop",
            *args,
            **kwargs):
        super().__init__(*args, **kwargs)
        self.proj = nn.Conv3d(
            16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.proj_2x = nn.Conv3d(
            16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
        self.proj_4x = nn.Conv3d(
            16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
        self.zip_frame_buckets = torch.tensor(
            list(zip_frame_buckets), dtype=torch.long)

        self.inner_dim = inner_dim
        self.num_heads = num_heads

        assert (inner_dim %
                num_heads) == 0 and (inner_dim // num_heads) % 2 == 0
        d = inner_dim // num_heads
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ],
                               dim=1)
        self.drop_mode = drop_mode

    @staticmethod
    def _grid_group(origin, end, target):
        """Wrap three (f, h, w) triples as one single-sample rope group."""
        return [[
            torch.tensor(origin).unsqueeze(0).repeat(1, 1),
            torch.tensor(end).unsqueeze(0).repeat(1, 1),
            torch.tensor(target).unsqueeze(0).repeat(1, 1),
        ]]

    def forward(self, motion_latents, add_last_motion=2):
        """Tokenise every latent in ``motion_latents``.

        Args:
            motion_latents: Iterable of latents indexed as ``[C, T, H, W]``
                with 16 channels.
            add_last_motion: Values below 2 discard the 1x bucket and values
                below 1 also discard the 2x bucket (dropped or zeroed
                depending on ``drop_mode``).

        Returns:
            ``(mot, mot_remb)``: per-latent token tensors and their
            precomputed rotary factors.
        """
        buckets = self.zip_frame_buckets
        drop = self.drop_mode == "drop"
        mot = []
        mot_remb = []
        for m in motion_latents:
            lat_height, lat_width = m.shape[2], m.shape[3]
            # Right-align the available frames in a zero buffer long enough
            # for all buckets.
            padd_lat = torch.zeros(16, buckets.sum(), lat_height,
                                   lat_width).to(
                                       device=m.device, dtype=m.dtype)
            overlap_frame = min(padd_lat.shape[1], m.shape[1])
            if overlap_frame > 0:
                padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]

            if add_last_motion < 2 and not drop:
                # Padding mode: blank the most recent frames in place.
                zero_end_frame = buckets[:len(buckets) - add_last_motion -
                                         1].sum()
                padd_lat[:, -zero_end_frame:] = 0

            padd_lat = padd_lat.unsqueeze(0)
            # Oldest frames first: 4x bucket, then 2x, then the newest (1x).
            clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[
                :, :, -buckets.sum():, :, :].split(
                    list(buckets)[::-1], dim=2)

            # patchify
            clean_latents_post = self.proj(clean_latents_post).flatten(
                2).transpose(1, 2)
            clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(
                2).transpose(1, 2)
            clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(
                2).transpose(1, 2)

            if add_last_motion < 2 and drop:
                # Drop mode: remove the tokens of the unused buckets.
                clean_latents_post = clean_latents_post[:, :0]
                if add_last_motion < 1:
                    clean_latents_2x = clean_latents_2x[:, :0]

            motion_lat = torch.cat(
                [clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)

            # rope: the buckets sit at negative temporal positions, the
            # newest one closest to zero.
            start_time_id = -(buckets[:1].sum())
            end_time_id = start_time_id + buckets[0]
            grid_sizes = [] if add_last_motion < 2 and drop else \
                self._grid_group(
                    [start_time_id, 0, 0],
                    [end_time_id, lat_height // 2, lat_width // 2],
                    [buckets[0], lat_height // 2, lat_width // 2])

            start_time_id = -(buckets[:2].sum())
            end_time_id = start_time_id + buckets[1] // 2
            grid_sizes_2x = [] if add_last_motion < 1 and drop else \
                self._grid_group(
                    [start_time_id, 0, 0],
                    [end_time_id, lat_height // 4, lat_width // 4],
                    [buckets[1], lat_height // 2, lat_width // 2])

            start_time_id = -(buckets[:3].sum())
            end_time_id = start_time_id + buckets[2] // 4
            grid_sizes_4x = self._grid_group(
                [start_time_id, 0, 0],
                [end_time_id, lat_height // 8, lat_width // 8],
                [buckets[2], lat_height // 2, lat_width // 2])

            grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x

            motion_rope_emb = rope_precompute(
                motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads,
                                         self.inner_dim // self.num_heads),
                grid_sizes,
                self.freqs,
                start=None)

            mot.append(motion_lat)
            mot_remb.append(motion_rope_emb)
        return mot, mot_remb
def sample_indices(N, stride, expand_ratio, c):
    """Pick frame indices from ``range(N)`` in buckets of growing width.

    Each bucket's width is ``stride`` scaled by ``expand_ratio`` raised to
    (indices chosen so far / ``stride``). Within a bucket, anchors are taken
    from its last index backwards at a fixed interval, and every anchor
    contributes itself plus its ``c - 1`` predecessors.

    Args:
        N: Number of frames.
        stride: Base bucket width.
        expand_ratio: Growth factor of the bucket width.
        c: Number of consecutive indices taken per anchor.

    Returns:
        List of selected indices, bucket by bucket.
    """
    picked = []
    bucket_start = 0
    while bucket_start < N:
        width = int(stride * (expand_ratio**(len(picked) / stride)))
        step = int(width / stride * c)
        bucket_end = min(N, bucket_start + width)

        anchors = range(bucket_end - 1, bucket_start - 1, -step)
        newest_first = [
            anchor - back for anchor in anchors for back in range(c)
        ]
        picked.extend(reversed(newest_first))
        bucket_start += width
    return picked