Spaces:
Running
Running
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. | |
import torch | |
import torch.cuda.amp as amp | |
from xfuser.core.distributed import (get_sequence_parallel_rank, | |
get_sequence_parallel_world_size, | |
get_sp_group) | |
from xfuser.core.long_ctx_attention import xFuserLongContextAttention | |
from ..modules.model import sinusoidal_embedding_1d | |
def pad_freqs(original_tensor, target_len): | |
seq_len, s1, s2 = original_tensor.shape | |
pad_size = target_len - seq_len | |
padding_tensor = torch.ones( | |
pad_size, | |
s1, | |
s2, | |
dtype=original_tensor.dtype, | |
device=original_tensor.device) | |
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) | |
return padded_tensor | |
def rope_apply(x, grid_sizes, freqs): | |
""" | |
x: [B, L, N, C]. | |
grid_sizes: [B, 3]. | |
freqs: [M, C // 2]. | |
""" | |
s, n, c = x.size(1), x.size(2), x.size(3) // 2 | |
# split freqs | |
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) | |
# loop over samples | |
output = [] | |
for i, (f, h, w) in enumerate(grid_sizes.tolist()): | |
seq_len = f * h * w | |
# precompute multipliers | |
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape( | |
s, n, -1, 2)) | |
freqs_i = torch.cat([ | |
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), | |
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), | |
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) | |
], | |
dim=-1).reshape(seq_len, 1, -1) | |
# apply rotary embedding | |
sp_size = get_sequence_parallel_world_size() | |
sp_rank = get_sequence_parallel_rank() | |
freqs_i = pad_freqs(freqs_i, s * sp_size) | |
s_per_rank = s | |
freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * | |
s_per_rank), :, :] | |
x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) | |
x_i = torch.cat([x_i, x[i, s:]]) | |
# append to collection | |
output.append(x_i) | |
return torch.stack(output).float() | |
def usp_dit_forward( | |
self, | |
x, | |
t, | |
context, | |
seq_len, | |
clip_fea=None, | |
y=None, | |
): | |
""" | |
x: A list of videos each with shape [C, T, H, W]. | |
t: [B]. | |
context: A list of text embeddings each with shape [L, C]. | |
""" | |
if self.model_type == 'i2v': | |
assert clip_fea is not None and y is not None | |
# params | |
device = self.patch_embedding.weight.device | |
if self.freqs.device != device: | |
self.freqs = self.freqs.to(device) | |
if y is not None: | |
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] | |
# embeddings | |
x = [self.patch_embedding(u.unsqueeze(0)) for u in x] | |
grid_sizes = torch.stack( | |
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) | |
x = [u.flatten(2).transpose(1, 2) for u in x] | |
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) | |
assert seq_lens.max() <= seq_len | |
x = torch.cat([ | |
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) | |
for u in x | |
]) | |
# time embeddings | |
with amp.autocast(dtype=torch.float32): | |
e = self.time_embedding( | |
sinusoidal_embedding_1d(self.freq_dim, t).float()) | |
e0 = self.time_projection(e).unflatten(1, (6, self.dim)) | |
assert e.dtype == torch.float32 and e0.dtype == torch.float32 | |
# context | |
context_lens = None | |
context = self.text_embedding( | |
torch.stack([ | |
torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) | |
for u in context | |
])) | |
if clip_fea is not None: | |
context_clip = self.img_emb(clip_fea) # bs x 257 x dim | |
context = torch.concat([context_clip, context], dim=1) | |
# arguments | |
kwargs = dict( | |
e=e0, | |
seq_lens=seq_lens, | |
grid_sizes=grid_sizes, | |
freqs=self.freqs, | |
context=context, | |
context_lens=context_lens) | |
# Context Parallel | |
x = torch.chunk( | |
x, get_sequence_parallel_world_size(), | |
dim=1)[get_sequence_parallel_rank()] | |
for block in self.blocks: | |
x = block(x, **kwargs) | |
# head | |
x = self.head(x, e) | |
# Context Parallel | |
x = get_sp_group().all_gather(x, dim=1) | |
# unpatchify | |
x = self.unpatchify(x, grid_sizes) | |
return [u.float() for u in x] | |
def usp_attn_forward(self, | |
x, | |
seq_lens, | |
grid_sizes, | |
freqs, | |
dtype=torch.bfloat16): | |
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim | |
half_dtypes = (torch.float16, torch.bfloat16) | |
def half(x): | |
return x if x.dtype in half_dtypes else x.to(dtype) | |
# query, key, value function | |
def qkv_fn(x): | |
q = self.norm_q(self.q(x)).view(b, s, n, d) | |
k = self.norm_k(self.k(x)).view(b, s, n, d) | |
v = self.v(x).view(b, s, n, d) | |
return q, k, v | |
q, k, v = qkv_fn(x) | |
q = rope_apply(q, grid_sizes, freqs) | |
k = rope_apply(k, grid_sizes, freqs) | |
# TODO: We should use unpaded q,k,v for attention. | |
# k_lens = seq_lens // get_sequence_parallel_world_size() | |
# if k_lens is not None: | |
# q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0) | |
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) | |
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) | |
x = xFuserLongContextAttention()( | |
None, | |
query=half(q), | |
key=half(k), | |
value=half(v), | |
window_size=self.window_size) | |
# TODO: padding after attention. | |
# x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1) | |
# output | |
x = x.flatten(2) | |
x = self.o(x) | |
return x | |