import math
import torch
import numpy as np
from typing import Optional
from einops import pack, rearrange, repeat
import torch.nn as nn
import torch.nn.functional as F

"""
DiT-v5
- Add convolution in DiTBlock to increase high-freq component
"""


class MLP(torch.nn.Module):
    """Feed-forward network: fc1 -> act -> dropout -> (optional norm) -> fc2 -> dropout.

    Args:
        in_features: input feature size.
        hidden_features: hidden size; defaults to ``in_features``.
        out_features: output size; defaults to ``in_features``.
        act_layer: activation constructor (called with no arguments).
        norm_layer: optional norm constructor applied on the hidden features;
            ``nn.Identity`` is used when it is ``None``.
        bias: whether the two linear layers carry a bias.
        drop: dropout probability used after the activation and after fc2.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.,
    ):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class Attention(torch.nn.Module):
    """Multi-head self-attention built on ``F.scaled_dot_product_attention``.

    Q/K/V are projected to ``num_heads * head_dim`` channels and optionally
    normalised per head (``qk_norm``). The output is projected back to ``dim``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        head_dim: int = 64,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.inner_dim = num_heads * head_dim
        # Kept for reference: SDPA's default scale is already head_dim ** -0.5.
        self.scale = head_dim ** -0.5

        self.to_q = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim, self.inner_dim, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.proj = nn.Linear(self.inner_dim, dim)

    def to_heads(self, ts: torch.Tensor):
        """Split channels into heads: (b, t, nh * c) -> (b, nh, t, c)."""
        b, t, c = ts.shape
        ts = ts.reshape(b, t, self.num_heads, c // self.num_heads)  # (b, t, nh, c)
        ts = ts.transpose(1, 2)
        return ts

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        """Full-sequence attention.

        Args:
            x (torch.Tensor): shape (b, t, c)
            attn_mask (torch.Tensor): bool mask of shape (b, t, t), or any shape
                that broadcasts to it after a head axis is inserted at dim 1
                (e.g. (b, 1, t) as passed by ``DiT.blocks_forward``).
        """
        b, t, c = x.shape
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.to_heads(q)  # (b, nh, t, c)
        k = self.to_heads(k)
        v = self.to_heads(v)

        q = self.q_norm(q)
        k = self.k_norm(k)

        attn_mask = attn_mask.unsqueeze(1)  # add head axis
        x = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )  # (b, nh, t, c)
        x = x.transpose(1, 2).reshape(b, t, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def forward_chunk(self, x: torch.Tensor, att_cache: torch.Tensor = None, attn_mask: torch.Tensor = None):
        """Streaming attention over the current chunk plus cached keys/values.

        The fresh keys/values are placed *before* the cached ones along the time
        axis. ``DiT.forward_chunk``'s CUDA-graph path relies on this ordering:
        it zero-pads the cache at the end and marks only the leading
        ``chunk + cache_len`` positions as valid in the mask.

        Args:
            x: shape (b, dt, c)
            att_cache: shape (b, nh, t, c * 2), keys and values concatenated on dim 3.
            attn_mask: optional bool mask; a head axis is inserted at dim 1.

        Returns:
            (output of shape (b, dt, dim), new cache of shape (b, nh, dt + t, c * 2))
        """
        b, t, c = x.shape
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.to_heads(q)  # (b, nh, t, c)
        k = self.to_heads(k)
        v = self.to_heads(v)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # Unpack the {k,v} cache (the previous code duplicated this in two identical branches).
        if att_cache is not None:
            k_cache, v_cache = att_cache.chunk(2, dim=3)
            k = torch.cat([k, k_cache], dim=2)
            v = torch.cat([v, v_cache], dim=2)

        new_att_cache = torch.cat([k, v], dim=3)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)  # (b, nh, t, c)
        x = x.transpose(1, 2).reshape(b, t, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, new_att_cache


def modulate(x, shift, scale):
    """adaLN modulation: ``x * (1 + scale) + shift``."""
    return x * (1 + scale) + shift


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        # Timesteps are multiplied by this before the sinusoidal embedding
        # (from SinusoidalPosEmb); presumably t is in [0, 1] — confirm with callers.
        self.scale = 1000

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half) / half
        ).to(t)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd dim: pad one zero column so the output has exactly `dim` features.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t * self.scale, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


# Convolution related
class Transpose(torch.nn.Module):
    """Module wrapper around ``torch.transpose`` so it can live in ``nn.Sequential``."""

    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor):
        x = torch.transpose(x, self.dim0, self.dim1)
        return x


class CausalConv1d(torch.nn.Conv1d):
    """Conv1d that left-pads by ``kernel_size - 1`` so output t depends only on inputs <= t."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels, kernel_size)
        self.causal_padding = (kernel_size - 1, 0)

    def forward(self, x: torch.Tensor):
        x = F.pad(x, self.causal_padding)
        x = super(CausalConv1d, self).forward(x)
        return x

    def forward_chunk(self, x: torch.Tensor, cnn_cache: torch.Tensor = None):
        """Streaming convolution.

        Args:
            x: shape (b, in_channels, dt)
            cnn_cache: last ``kernel_size - 1`` input frames of the previous
                chunk, shape (b, in_channels, kernel_size - 1); zeros if None.

        Returns:
            (output of shape (b, out_channels, dt), new cache)
        """
        pad = self.causal_padding[0]
        if cnn_cache is None:
            cnn_cache = x.new_zeros((x.shape[0], self.in_channels, pad))
        x = torch.cat([cnn_cache, x], dim=2)
        # Use an explicit start index: `x[..., -pad:]` with pad == 0
        # (kernel_size == 1) is `x[..., -0:]`, i.e. the WHOLE tensor, which would
        # make the cache grow every chunk instead of staying empty.
        new_cnn_cache = x[..., x.shape[-1] - pad:]
        x = super(CausalConv1d, self).forward(x)
        return x, new_cnn_cache


class CausalConvBlock(nn.Module):
    """Two causal convs with LayerNorm + Mish in between, on (b, t, c) tensors."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        # NOTE: forward_chunk indexes this Sequential by position (0, 1, 2:6, 6, 7);
        # keep the layout in sync if layers are added or removed.
        self.block = torch.nn.Sequential(
            # conv1
            Transpose(1, 2),
            CausalConv1d(in_channels, out_channels, kernel_size),
            Transpose(1, 2),
            # norm & act
            nn.LayerNorm(out_channels),
            nn.Mish(),
            # conv2
            Transpose(1, 2),
            CausalConv1d(out_channels, out_channels, kernel_size),
            Transpose(1, 2),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """
        Args:
            x: shape (b, t, c)
            mask: shape (b, t, 1); applied before and after the convolutions.
        """
        if mask is not None:
            x = x * mask
        x = self.block(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_chunk(self, x: torch.Tensor, cnn_cache: torch.Tensor = None):
        """
        Args:
            x: shape (b, dt, c)
            cnn_cache: shape (b, c1+c2, kernel_size - 1): conv1's cache
                (in_channels) stacked with conv2's cache (out_channels).
        """
        if cnn_cache is not None:
            cnn_cache1, cnn_cache2 = cnn_cache.split((self.in_channels, self.out_channels), dim=1)
        else:
            cnn_cache1, cnn_cache2 = None, None
        x = self.block[0](x)
        x, new_cnn_cache1 = self.block[1].forward_chunk(x, cnn_cache1)
        x = self.block[2:6](x)
        x, new_cnn_cache2 = self.block[6].forward_chunk(x, cnn_cache2)
        x = self.block[7](x)
        new_cnn_cache = torch.cat((new_cnn_cache1, new_cnn_cache2), dim=1)
        return x, new_cnn_cache


class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    Residual sub-layers run in order: attention -> causal conv -> MLP, each
    pre-normalised and modulated by shift/scale/gate derived from ``c``.
    """

    def __init__(self, hidden_size, num_heads, head_dim, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, qk_norm=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = MLP(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.conv = CausalConvBlock(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3)
        # 9 chunks = (shift, scale, gate) for each of attention, MLP and conv.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 9 * hidden_size, bias=True)
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor, attn_mask: torch.Tensor):
        """
        Args:
            x: shape (b, t, c)
            c: shape (b, 1, c)
            attn_mask: bool attention mask, shape (b, t, t) or broadcastable
                to it (``DiT.blocks_forward`` passes (b, 1, t)).
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_conv, scale_conv, gate_conv \
            = self.adaLN_modulation(c).chunk(9, dim=-1)

        # attention
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), attn_mask)

        # conv
        x = x + gate_conv * self.conv(modulate(self.norm3(x), shift_conv, scale_conv))

        # mlp
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

    def forward_chunk(self, x: torch.Tensor, c: torch.Tensor, cnn_cache: torch.Tensor = None, att_cache: torch.Tensor = None, mask: torch.Tensor = None):
        """
        Args:
            x: shape (b, dt, c)
            c: shape (b, 1, c)
            cnn_cache: shape (b, c1+c2, 2)
            att_cache: shape (b, nh, t, c * 2)
            mask: optional bool attention mask forwarded to ``Attention.forward_chunk``.

        Returns:
            (x, new_cnn_cache, new_att_cache)
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_conv, scale_conv, gate_conv \
            = self.adaLN_modulation(c).chunk(9, dim=-1)

        # attention
        x_att, new_att_cache = self.attn.forward_chunk(modulate(self.norm1(x), shift_msa, scale_msa), att_cache, mask)
        x = x + gate_msa * x_att

        # conv
        x_conv, new_cnn_cache = self.conv.forward_chunk(modulate(self.norm3(x), shift_conv, scale_conv), cnn_cache)
        x = x + gate_conv * x_conv

        # mlp
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x, new_cnn_cache, new_att_cache


class FinalLayer(nn.Module):
    """
    The final layer of DiT: adaLN-modulated LayerNorm followed by a linear projection.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Supports full-sequence ``forward`` and streaming ``forward_chunk``. The
    latter can optionally replay pre-captured CUDA graphs (see
    ``_init_cuda_graph_all``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mlp_ratio: float = 4.0,
        depth: int = 28,
        num_heads: int = 8,
        head_dim: int = 64,
        hidden_size: int = 256,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.t_embedder = TimestepEmbedder(hidden_size)

        self.in_proj = nn.Linear(in_channels, hidden_size)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, head_dim, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

        self.enable_cuda_graph = False
        self.use_cuda_graph = False

        # Per-chunk-size CUDA graphs and their static input/output tensors.
        self.graph_chunk = {}
        self.inference_buffers_chunk = {}
        self.max_size_chunk = {}

        # Output buffers for the eager streaming path. Shapes are hard-coded:
        # (depth=16, batch=2, nh=8, max_t=1000, 2*head_dim=128) and
        # (depth=16, batch=2, c1+c2=1024, kernel_size-1=2).
        # NOTE(review): these only match a depth-16 / 8-head / head_dim-64 /
        # hidden-512 config with batch 2, and cap the attention cache at 1000 frames.
        self.register_buffer('att_cache_buffer', torch.zeros((16, 2, 8, 1000, 128)), persistent=False)
        self.register_buffer('cnn_cache_buffer', torch.zeros((16, 2, 1024, 2)), persistent=False)

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def _init_cuda_graph_chunk(self):
        """Capture one CUDA graph of ``blocks_forward_chunk`` per supported chunk size."""
        # get dtype, device from registered buffer
        dtype, device = self.cnn_cache_buffer.dtype, self.cnn_cache_buffer.device
        # init cuda graph for streaming forward
        with torch.no_grad():
            for chunk_size in [30, 48, 96]:
                # Maximum attention-cache length the static buffers can hold.
                max_size = 500 if chunk_size in (30, 48) else 1000
                self.max_size_chunk[chunk_size] = max_size
                static_x1 = torch.zeros((2, 320, chunk_size), dtype=dtype, device=device)
                static_t1 = torch.zeros((2, 1, 512), dtype=dtype, device=device)
                static_mask1 = torch.ones((2, chunk_size, max_size + chunk_size), dtype=torch.bool, device=device)
                static_att_cache = torch.zeros((16, 2, 8, max_size, 128), dtype=dtype, device=device)
                static_cnn_cache = torch.zeros((16, 2, 1024, 2), dtype=dtype, device=device)
                static_inputs1 = [
                    static_x1,
                    static_t1,
                    static_mask1,
                    static_cnn_cache,
                    static_att_cache,
                ]
                static_new_cnn_cache = torch.zeros((16, 2, 1024, 2), dtype=dtype, device=device)
                static_new_att_cache = torch.zeros((16, 2, 8, max_size + chunk_size, 128), dtype=dtype, device=device)
                # Warm-up run before capture.
                self.blocks_forward_chunk(
                    static_inputs1[0],
                    static_inputs1[1],
                    static_inputs1[2],
                    static_inputs1[3],
                    static_inputs1[4],
                    static_new_cnn_cache,
                    static_new_att_cache)
                graph_chunk = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph_chunk):
                    static_out1 = self.blocks_forward_chunk(static_x1, static_t1, static_mask1, static_cnn_cache, static_att_cache, static_new_cnn_cache, static_new_att_cache)
                static_outputs1 = [static_out1, static_new_cnn_cache, static_new_att_cache]
                self.inference_buffers_chunk[chunk_size] = {
                    'static_inputs': static_inputs1,
                    'static_outputs': static_outputs1
                }
                self.graph_chunk[chunk_size] = graph_chunk

    def _init_cuda_graph_all(self):
        self._init_cuda_graph_chunk()
        self.use_cuda_graph = True
        print(f"CUDA Graph initialized successfully for chunk decoder")

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """
        Args:
            x: shape (b, c, t)
            mask: shape (b, 1, t)
            mu: shape (b, c', t); concatenated with x along channels
            t: shape (b,)
            spks: shape (b, c)
            cond: shape (b, c, t)
        """
        # (sfy) chunk training strategy should not be open-sourced

        # time
        t = self.t_embedder(t).unsqueeze(1)  # (b, 1, c)
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        return self.blocks_forward(x, t, mask)

    def blocks_forward(self, x, t, mask):
        """Run in_proj, all DiT blocks and the final layer on a (b, c, t) input."""
        x = x.transpose(1, 2)
        attn_mask = mask.bool()
        x = self.in_proj(x)
        for block in self.blocks:
            x = block(x, t, attn_mask)
        x = self.final_layer(x, t)
        x = x.transpose(1, 2)
        return x

    def forward_chunk(self,
                      x: torch.Tensor,
                      mu: torch.Tensor,
                      t: torch.Tensor,
                      spks: torch.Tensor,
                      cond: torch.Tensor,
                      cnn_cache: torch.Tensor = None,
                      att_cache: torch.Tensor = None,
                      ):
        """Streaming forward over one chunk.

        Args:
            x: shape (b, c, dt)
            mu: shape (b, c', dt)
            t: shape (b,)
            spks: shape (b, c)
            cond: shape (b, c'', dt)
            cnn_cache: shape (depth, b, c1+c2, 2)
            att_cache: shape (depth, b, nh, t, c * 2)

        Returns:
            (output of shape (b, out_channels, dt), new_cnn_cache, new_att_cache).
            NOTE: the returned caches are views into internal buffers (the
            eager-path buffers or the CUDA-graph static outputs). They are
            overwritten by the next call, so clone them if they must outlive it.
        """
        # time
        t = self.t_embedder(t).unsqueeze(1)  # (b, 1, c)
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        # create fake cache
        if cnn_cache is None:
            cnn_cache = [None] * len(self.blocks)
        if att_cache is None:
            att_cache = [None] * len(self.blocks)
        last_att_len = att_cache.shape[3] if att_cache[0] is not None else 0
        chunk_size = x.shape[2]

        use_graph = (
            self.use_cuda_graph
            and att_cache[0] is not None
            and chunk_size in self.graph_chunk
            and last_att_len <= self.max_size_chunk[chunk_size]
        )
        if use_graph:
            buffers = self.inference_buffers_chunk[chunk_size]
            static_x, static_t, static_mask, static_cnn_cache, static_att_cache = buffers['static_inputs']
            valid_len = last_att_len + chunk_size

            static_x.copy_(x)
            static_t.copy_(t)
            # Attention sees [new chunk | cache | zero padding]: mark only the
            # first `valid_len` key positions as valid. Writing straight into the
            # captured buffers keeps the shapes in sync with _init_cuda_graph_chunk
            # and avoids allocating padded copies on every chunk.
            static_mask.zero_()
            static_mask[:, :, :valid_len] = True
            static_cnn_cache.copy_(cnn_cache)
            static_att_cache.zero_()
            static_att_cache[:, :, :, :last_att_len, :] = att_cache

            self.graph_chunk[chunk_size].replay()

            static_out, static_new_cnn_cache, static_new_att_cache = buffers['static_outputs']
            x = static_out[:, :, :chunk_size]
            new_cnn_cache = static_new_cnn_cache
            new_att_cache = static_new_att_cache[:, :, :, :valid_len, :]
        else:
            # Eager path: every key is valid, so no mask is needed.
            x = self.blocks_forward_chunk(x, t, None, cnn_cache, att_cache, self.cnn_cache_buffer, self.att_cache_buffer)
            new_cnn_cache = self.cnn_cache_buffer
            new_att_cache = self.att_cache_buffer[:, :, :, :last_att_len + chunk_size, :]

        return x, new_cnn_cache, new_att_cache

    def blocks_forward_chunk(self, x, t, mask, cnn_cache=None, att_cache=None, cnn_cache_buffer=None, att_cache_buffer=None):
        """Run all blocks on one chunk, writing each block's new caches into the given buffers.

        Block ``i`` reads ``cnn_cache[i]`` / ``att_cache[i]`` before its entry in
        the output buffers is written. That ordering is what makes it safe for the
        inputs to be views of the same buffers from the previous call.
        """
        x = x.transpose(1, 2)
        x = self.in_proj(x)
        for b_idx, block in enumerate(self.blocks):
            x, this_new_cnn_cache, this_new_att_cache \
                = block.forward_chunk(x, t, cnn_cache[b_idx], att_cache[b_idx], mask)
            cnn_cache_buffer[b_idx] = this_new_cnn_cache
            att_cache_buffer[b_idx][:, :, :this_new_att_cache.shape[2], :] = this_new_att_cache
        x = self.final_layer(x, t)
        x = x.transpose(1, 2)
        return x