"""Sequence-parallel (xFuser) forward functions for a DiT model.

``usp_dit_forward`` and ``usp_attn_forward`` take ``self`` as their first
parameter, so they are presumably bound onto a model / attention module in
place of their own ``forward`` elsewhere -- TODO confirm at the binding site.
"""
import contextlib
from typing import Optional

import torch
from einops import rearrange
from xfuser.core.distributed import (get_sequence_parallel_rank,
                                     get_sequence_parallel_world_size,
                                     get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention


def sinusoidal_embedding_1d(dim, position):
    """Return a sinusoidal embedding of the 1-D tensor ``position``.

    Args:
        dim: Requested embedding width. Only ``dim // 2`` frequencies are
            used, so the result has ``2 * (dim // 2)`` columns (one fewer
            than ``dim`` when ``dim`` is odd).
        position: 1-D tensor of positions (``torch.outer`` requires 1-D).

    Returns:
        Tensor of shape ``(len(position), 2 * (dim // 2))`` laid out as
        ``[cos(p * w_0), ..., cos(p * w_{k-1}), sin(p * w_0), ...]`` with
        ``k = dim // 2`` and ``w_i = 10000 ** (-i / k)``, cast back to
        ``position.dtype``.
    """
    # The angles are computed in float64 for precision and only cast at the end.
    sinusoid = torch.outer(position.type(torch.float64), torch.pow(
        10000, -torch.arange(dim // 2, dtype=torch.float64,
                             device=position.device).div(dim // 2)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x.to(position.dtype)


def pad_freqs(original_tensor, target_len):
    """Append rows of ones along dim 0 so the result has ``target_len`` rows.

    Args:
        original_tensor: 3-D tensor ``(seq_len, s1, s2)``. In ``rope_apply``
            it multiplies a complex tensor, so presumably it is complex and
            the padding (``1 + 0j``) is an identity rotation for padded
            positions -- TODO confirm dtype at the call site.
        target_len: Desired length of dim 0; must be ``>= seq_len``.

    Returns:
        Tensor of shape ``(target_len, s1, s2)`` with the same dtype/device.

    Raises:
        RuntimeError: raised by ``torch.ones`` when ``target_len < seq_len``
            (negative dimension).
    """
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor


def rope_apply(x, freqs, num_heads):
    """Apply rotary factors to this rank's shard of the token sequence.

    Args:
        x: Activations ``(b, s_local, num_heads * head_dim)`` where
            ``s_local`` is the number of tokens held by this
            sequence-parallel rank.
        freqs: Rotation factors for the FULL sequence, shaped
            ``(S, 1, k)`` where ``k`` must broadcast against
            ``head_dim // 2``; presumably complex.
        num_heads: Number of heads used to split the last dimension.

    Returns:
        Tensor of shape ``(b, s_local, num_heads * head_dim)`` in ``x``'s
        original dtype.

    Assumes every rank holds exactly ``s_local`` tokens, so this rank's
    tokens occupy positions ``[sp_rank * s_local, (sp_rank + 1) * s_local)``
    of the full sequence.
    """
    x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    s_per_rank = x.shape[1]

    # Pair adjacent channels of each head into complex numbers (in float64).
    x_out = torch.view_as_complex(x.to(torch.float64).reshape(
        x.shape[0], x.shape[1], x.shape[2], -1, 2))

    sp_size = get_sequence_parallel_world_size()
    sp_rank = get_sequence_parallel_rank()

    # Extend freqs to the gathered length (a no-op when it already equals
    # s_per_rank * sp_size), then take the window owned by this rank.
    freqs = pad_freqs(freqs, s_per_rank * sp_size)
    freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]

    # Rotate, split complex back to (real, imag) and flatten heads+channels.
    x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
    return x_out.to(x.dtype)


def usp_dit_forward(self,
                    x: torch.Tensor,
                    timestep: torch.Tensor,
                    context: torch.Tensor,
                    clip_feature: Optional[torch.Tensor] = None,
                    y: Optional[torch.Tensor] = None,
                    use_gradient_checkpointing: bool = False,
                    use_gradient_checkpointing_offload: bool = False,
                    **kwargs,
                    ):
    """Sequence-parallel forward pass of the DiT.

    The pass proceeds in these steps:

    1. Embed the timestep and the text context.
    2. When ``self.has_image_input``, concatenate ``y`` onto ``x`` along
       dim 1 and prepend the CLIP image embedding to the context.
    3. Patchify ``x`` and build per-token rotary factors for the full
       ``f x h x w`` grid.
    4. Keep only this rank's slice of the token sequence and run every
       block and the head on it.
    5. All-gather the slices along the sequence dimension and unpatchify.

    Args:
        x: Latent input; the concat comment below suggests
            ``(b, c, f, h, w)`` -- TODO confirm.
        timestep: 1-D tensor passed to ``sinusoidal_embedding_1d``.
        context: Text features passed to ``self.text_embedding``.
        clip_feature: Input to ``self.img_emb``; needed when
            ``self.has_image_input``.
        y: Extra channels concatenated onto ``x``; needed when
            ``self.has_image_input``.
        use_gradient_checkpointing: Checkpoint each block. Only honoured
            while ``self.training``.
        use_gradient_checkpointing_offload: When checkpointing, also keep
            saved tensors on CPU via ``torch.autograd.graph.save_on_cpu``.
        **kwargs: Accepted and ignored.

    Returns:
        ``self.unpatchify`` applied to the gathered full sequence.

    Raises:
        ValueError: If the patchified sequence length is not divisible by
            the sequence-parallel world size.
    """
    t = self.time_embedding(
        sinusoidal_embedding_1d(self.freq_dim, timestep))
    t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
    context = self.text_embedding(context)

    if self.has_image_input:
        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
        clip_embedding = self.img_emb(clip_feature)
        context = torch.cat([clip_embedding, context], dim=1)

    x, (f, h, w) = self.patchify(x)

    # Rotary factors for every (frame, row, column) token of the full grid.
    freqs = torch.cat([
        self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    # Context Parallel: each rank keeps one equal-sized slice of the tokens.
    # torch.chunk yields unequal (or too few) chunks when the length is not
    # divisible. Downstream code cannot handle that: rope_apply assumes
    # equal shards, so pad_freqs receives a negative pad on the last rank,
    # and all_gather concatenates equal-sized pieces. Fail up front instead
    # -- presumably on every rank at once, since all ranks see the same
    # input shape.
    sp_size = get_sequence_parallel_world_size()
    seq_len = x.shape[1]
    if seq_len % sp_size != 0:
        raise ValueError(
            f"Sequence length after patchify ({seq_len}) must be divisible "
            f"by the sequence-parallel world size ({sp_size}).")
    x = torch.chunk(x, sp_size, dim=1)[get_sequence_parallel_rank()]

    for block in self.blocks:
        if self.training and use_gradient_checkpointing:
            # Same checkpoint call either way; offloading only changes where
            # autograd keeps the tensors it saves.
            offload_ctx = (torch.autograd.graph.save_on_cpu()
                           if use_gradient_checkpointing_offload
                           else contextlib.nullcontext())
            with offload_ctx:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, context, t_mod, freqs,
                    use_reentrant=False,
                )
        else:
            x = block(x, context, t_mod, freqs)

    x = self.head(x, t)

    # Context Parallel: reassemble the full sequence on every rank.
    x = get_sp_group().all_gather(x, dim=1)

    # unpatchify
    x = self.unpatchify(x, (f, h, w))
    return x


def usp_attn_forward(self, x, freqs):
    """Self-attention over this rank's token shard using xFuser attention.

    Args:
        x: Activations ``(b, s_local, num_heads * head_dim)`` for this rank.
        freqs: Full-sequence rotary factors, forwarded to ``rope_apply``.

    Returns:
        ``self.o`` applied to the attention output, flattened back to
        ``(b, s_local, num_heads * head_dim)`` -- assuming
        ``xFuserLongContextAttention`` returns ``(b, s_local, n, d)``
        (TODO confirm).
    """
    q = self.norm_q(self.q(x))
    k = self.norm_k(self.k(x))
    v = self.v(x)

    # Rotary embedding is applied to queries and keys only.
    q = rope_apply(q, freqs, self.num_heads)
    k = rope_apply(k, freqs, self.num_heads)
    q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
    k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
    v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)

    # Attention across the whole sequence-parallel group; the first
    # positional argument is passed as None.
    x = xFuserLongContextAttention()(
        None,
        query=q,
        key=k,
        value=v,
    )
    x = x.flatten(2)

    # Drop the per-head tensors and release cached allocator blocks,
    # presumably to cap peak memory. NOTE(review): this runs on every
    # attention call and may cost throughput.
    del q, k, v
    torch.cuda.empty_cache()
    return self.o(x)