Spaces:
Running
Running
File size: 10,869 Bytes
fc0a183 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 |
import numpy as np
import torch
import torch.amp as amp
from torch.backends.cuda import sdp_kernel
from xfuser.core.distributed import get_sequence_parallel_rank
from xfuser.core.distributed import get_sequence_parallel_world_size
from xfuser.core.distributed import get_sp_group
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ..modules.transformer import sinusoidal_embedding_1d
def pad_freqs(original_tensor, target_len):
seq_len, s1, s2 = original_tensor.shape
pad_size = target_len - seq_len
padding_tensor = torch.ones(pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device)
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
return padded_tensor
@amp.autocast("cuda", enabled=False)
def rope_apply(x, grid_sizes, freqs):
"""
x: [B, L, N, C].
grid_sizes: [B, 3].
freqs: [M, C // 2].
"""
s, n, c = x.size(1), x.size(2), x.size(3) // 2
# split freqs
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = []
grid = [grid_sizes.tolist()] * x.size(0)
for i, (f, h, w) in enumerate(grid):
seq_len = f * h * w
# precompute multipliers
x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
freqs_i = torch.cat(
[
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
],
dim=-1,
).reshape(seq_len, 1, -1)
# apply rotary embedding
sp_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
freqs_i = pad_freqs(freqs_i, s * sp_size)
s_per_rank = s
freqs_i_rank = freqs_i[(sp_rank * s_per_rank) : ((sp_rank + 1) * s_per_rank), :, :]
x_i = torch.view_as_real(x_i * freqs_i_rank.cuda()).flatten(2)
x_i = torch.cat([x_i, x[i, s:]])
# append to collection
output.append(x_i)
return torch.stack(output).float()
def broadcast_should_calc(should_calc: bool) -> bool:
import torch.distributed as dist
device = torch.cuda.current_device()
int_should_calc = 1 if should_calc else 0
tensor = torch.tensor([int_should_calc], device=device, dtype=torch.int8)
dist.broadcast(tensor, src=0)
should_calc = tensor.item() == 1
return should_calc
def usp_dit_forward(self, x, t, context, clip_fea=None, y=None, fps=None):
"""
x: A list of videos each with shape [C, T, H, W].
t: [B].
context: A list of text embeddings each with shape [L, C].
"""
if self.model_type == "i2v":
assert clip_fea is not None and y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = torch.cat([x, y], dim=1)
# embeddings
x = self.patch_embedding(x)
grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
x = x.flatten(2).transpose(1, 2)
if self.flag_causal_attention:
frame_num = grid_sizes[0]
height = grid_sizes[1]
width = grid_sizes[2]
block_num = frame_num // self.num_frame_per_block
range_tensor = torch.arange(block_num).view(-1, 1)
range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
casual_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device)
casual_mask = casual_mask.repeat(1, height, width, 1, height, width)
casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width)
self.block_mask = casual_mask.unsqueeze(0).unsqueeze(0)
# time embeddings
with amp.autocast("cuda", dtype=torch.float32):
if t.dim() == 2:
b, f = t.shape
_flag_df = True
else:
_flag_df = False
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
) # b, dim
e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim
if self.inject_sample_info:
fps = torch.tensor(fps, dtype=torch.long, device=device)
fps_emb = self.fps_embedding(fps).float()
if _flag_df:
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
else:
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
if _flag_df:
e = e.view(b, f, 1, 1, self.dim)
e0 = e0.view(b, f, 1, 1, 6, self.dim)
e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
e0 = e0.transpose(1, 2).contiguous()
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context = self.text_embedding(context)
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
if e0.ndim == 4:
e0 = torch.chunk(e0, get_sequence_parallel_world_size(), dim=2)[get_sequence_parallel_rank()]
kwargs = dict(e=e0, grid_sizes=grid_sizes, freqs=self.freqs, context=context, block_mask=self.block_mask)
if self.enable_teacache:
modulated_inp = e0 if self.use_ref_steps else e
# teacache
if self.cnt % 2 == 0: # even -> conditon
self.is_even = True
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc_even = True
self.accumulated_rel_l1_distance_even = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_even += rescale_func(
((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean())
.cpu()
.item()
)
if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
should_calc_even = False
else:
should_calc_even = True
self.accumulated_rel_l1_distance_even = 0
self.previous_e0_even = modulated_inp.clone()
else: # odd -> unconditon
self.is_even = False
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc_odd = True
self.accumulated_rel_l1_distance_odd = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_odd += rescale_func(
((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean())
.cpu()
.item()
)
if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
should_calc_odd = False
else:
should_calc_odd = True
self.accumulated_rel_l1_distance_odd = 0
self.previous_e0_odd = modulated_inp.clone()
x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
if self.enable_teacache:
if self.is_even:
should_calc_even = broadcast_should_calc(should_calc_even)
if not should_calc_even:
x += self.previous_residual_even
else:
ori_x = x.clone()
for block in self.blocks:
x = block(x, **kwargs)
ori_x.mul_(-1)
ori_x.add_(x)
self.previous_residual_even = ori_x
else:
should_calc_odd = broadcast_should_calc(should_calc_odd)
if not should_calc_odd:
x += self.previous_residual_odd
else:
ori_x = x.clone()
for block in self.blocks:
x = block(x, **kwargs)
ori_x.mul_(-1)
ori_x.add_(x)
self.previous_residual_odd = ori_x
self.cnt += 1
if self.cnt >= self.num_steps:
self.cnt = 0
else:
# Context Parallel
for block in self.blocks:
x = block(x, **kwargs)
# head
if e.ndim == 3:
e = torch.chunk(e, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
x = self.head(x, e)
# Context Parallel
x = get_sp_group().all_gather(x, dim=1)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x.float()
def usp_attn_forward(self, x, grid_sizes, freqs, block_mask):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
seq_lens(Tensor): Shape [B]
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
half_dtypes = (torch.float16, torch.bfloat16)
def half(x):
return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
return q, k, v
x = x.to(self.q.weight.dtype)
q, k, v = qkv_fn(x)
if not self._flag_ar_attention:
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
else:
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)
v = v.to(torch.bfloat16)
# x = torch.nn.functional.scaled_dot_product_attention(
# q.transpose(1, 2),
# k.transpose(1, 2),
# v.transpose(1, 2),
# ).transpose(1, 2).contiguous()
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
x = (
torch.nn.functional.scaled_dot_product_attention(
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
)
.transpose(1, 2)
.contiguous()
)
x = xFuserLongContextAttention()(None, query=half(q), key=half(k), value=half(v), window_size=self.window_size)
# output
x = x.flatten(2)
x = self.o(x)
return x
|