# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

import math

import ipdb  # noqa: F401
import numpy as np
import torch
import torch.nn as nn
from timm.models.vision_transformer import Attention, Mlp, PatchEmbed

from diffusionsfm.model.memory_efficient_attention import MEAttention


def modulate(x, shift, scale):
    """Apply an adaLN-style affine modulation to token features.

    ``shift`` and ``scale`` hold one row per sequence in ``x``; they are
    broadcast over the token axis (dim 1) via ``unsqueeze(1)``.
    """
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#################################################################################
#                     Embedding Layers for Timesteps                            #
#################################################################################


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        # Cosine half first, then sine half.
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad with a zero column so the output width is exactly ``dim``.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        """Map timesteps ``t`` (1-D) to ``(len(t), hidden_size)`` embeddings."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


#################################################################################
#                                 Core DiT Model                                #
#################################################################################


class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        use_xformers_attention=False,
        **block_kwargs
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        attn = MEAttention if use_xformers_attention else Attention
        self.attn = attn(
            hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
        )
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        def approx_gelu():
            return nn.GELU(approximate="tanh")

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )
        # Produces shift/scale/gate for both the attention and MLP branches.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        """Apply the block to tokens ``x`` conditioned on ``c`` (one row per sequence)."""
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(
            modulate(self.norm1(x), shift_msa, scale_msa)
        )
        x = x + gate_mlp.unsqueeze(1) * self.mlp(
            modulate(self.norm2(x), shift_mlp, scale_mlp)
        )
        return x


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        """Project tokens to ``patch_size**2 * out_channels`` values each."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        in_channels=442,
        out_channels=6,
        width=16,
        hidden_size=1152,
        depth=8,
        num_heads=16,
        mlp_ratio=4.0,
        max_num_images=8,
        P=1,
        within_image=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.width = width
        self.hidden_size = hidden_size
        self.max_num_images = max_num_images
        self.P = P
        self.within_image = within_image

        # flatten=False keeps the (B*N, C, H/P, W/P) layout; forward() permutes it.
        self.x_embedder = PatchEmbed(
            img_size=self.width,
            patch_size=self.P,
            in_chans=in_channels,
            embed_dim=hidden_size,
            bias=True,
            flatten=False,
        )
        self.x_pos_enc = FeaturePositionalEncoding(
            max_num_images, hidden_size, width**2, P=self.P
        )
        self.t_embedder = TimestepEmbedder(hidden_size)

        # Use the memory-efficient attention only when xformers is importable.
        try:
            import xformers  # noqa: F401

            use_xformers_attention = True
        except ImportError:
            # xformers not available
            use_xformers_attention = False

        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    use_xformers_attention=use_xformers_attention,
                )
                for _ in range(depth)
            ]
        )
        self.final_layer = FinalLayer(hidden_size, P, out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # NOTE: the final layer is deliberately NOT zero-initialized here; it
        # keeps the xavier/zero-bias init applied by _basic_init above.
        # nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        # nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        # nn.init.constant_(self.final_layer.linear.weight, 0)
        # nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)

        Assumes a square token grid (T must be a perfect square).
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nhpwqc", x)
        imgs = x.reshape(shape=(x.shape[0], h * p, w * p, c))
        return imgs

    def forward(
        self,
        x,
        t,
        return_dpt_activations=False,
        multiview_unconditional=False,
    ):
        """
        Args:
            x: Image/Ray features (B, N, C, H, W).
            t: Timesteps, one per batch element (B,). A tensor of shape (1,)
                is repeated across the batch.
            return_dpt_activations: If True, return intermediate activations
                collected after every 4th block instead of the prediction.
            multiview_unconditional: If True, every block attends only within
                each image.

        Returns:
            (B, N, out_channels, H, W) prediction, or, when
            ``return_dpt_activations`` is True, a list of at most 4 tensors of
            shape (B * N, hidden_size, H / P, W / P).
        """
        B, N, c, h, w = x.shape
        P = self.P
        num_patches = h * w // P**2  # tokens per image

        x = x.reshape((B * N, c, h, w))  # (B * N, C, H, W)
        x = self.x_embedder(x)  # (B * N, C, H / P, W / P)
        x = x.permute(0, 2, 3, 1)  # (B * N, H / P, W / P, C)
        # (B, N, H / P, W / P, C)
        x = x.reshape((B, N, h // P, w // P, self.hidden_size))
        x = self.x_pos_enc(x)  # (B, N * H * W / P ** 2, C)

        # Eval time: a single timestep of shape (1,) is shared by the batch.
        if x.shape[0] != t.shape[0] and t.shape[0] == 1:
            t = t.repeat_interleave(B)

        if self.within_image or multiview_unconditional:
            # One conditioning row per image for within-image blocks.
            t_within = t.repeat_interleave(N)
            t_within = self.t_embedder(t_within)

        t = self.t_embedder(t)

        dpt_activations = []
        for i, block in enumerate(self.blocks):
            if (self.within_image and i % 2 == 0) or multiview_unconditional:
                # Within-image block: each image's tokens form one sequence.
                x = x.reshape((B * N, num_patches, self.hidden_size))
                x = block(x, t_within)
            else:
                # All-patches block: tokens of all N images form one sequence.
                x = x.reshape((B, N * num_patches, self.hidden_size))
                x = block(x, t)

            if return_dpt_activations and i % 4 == 3:
                # The token grid is (H / P, W / P), not (H, W), when P > 1.
                x_prime = x.reshape(B * N, h // P, w // P, self.hidden_size)
                x_prime = x_prime.permute((0, 3, 1, 2))
                dpt_activations.append(x_prime)

        # The final layer is conditioned per batch element (t is (B, D)), so
        # always regroup tokens by batch. This is a no-op if the last block
        # was an all-patches block; it also covers multiview_unconditional
        # and a within-image last block (odd depth).
        x = x.reshape((B, N * num_patches, self.hidden_size))
        # (B, N * H * W / P ** 2, out_channels * P ** 2)
        x = self.final_layer(x, t)

        x = x.reshape((B * N, num_patches, self.out_channels * P**2))
        x = self.unpatchify(x)  # (B * N, H, W, C)
        x = x.reshape((B, N) + x.shape[1:])
        x = x.permute(0, 1, 4, 2, 3)  # (B, N, C, H, W)

        if return_dpt_activations:
            return dpt_activations[:4]

        return x


class FeaturePositionalEncoding(nn.Module):
    """Add fixed sinusoidal encodings of image index and patch index to tokens."""

    def _get_sinusoid_encoding_table(self, n_position, d_hid, base):
        """Sinusoid position encoding table.

        Returns a float tensor of shape (1, n_position, d_hid): sine on even
        feature indices, cosine on odd ones.
        """

        def get_position_angle_vec(position):
            return [
                position / np.power(base, 2 * (hid_j // 2) / d_hid)
                for hid_j in range(d_hid)
            ]

        sinusoid_table = np.array(
            [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
        )
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def __init__(self, max_num_images=8, feature_dim=1152, num_patches=256, P=1):
        super().__init__()
        self.max_num_images = max_num_images
        self.feature_dim = feature_dim
        self.P = P
        # Tokens per image after patchifying with patch size P.
        self.num_patches = num_patches // self.P**2

        self.register_buffer(
            "image_pos_table",
            self._get_sinusoid_encoding_table(
                self.max_num_images, self.feature_dim, 10000
            ),
        )

        # A different base than image_pos_table, so the two encodings differ.
        self.register_buffer(
            "token_pos_table",
            self._get_sinusoid_encoding_table(
                self.num_patches, self.feature_dim, 70007
            ),
        )

    def forward(self, x):
        """Add image- and patch-index encodings.

        ``x`` must be reshapeable to (B, N, num_patches, feature_dim); the
        result is flattened to (B, N * num_patches, feature_dim).
        """
        batch_size = x.shape[0]
        num_images = x.shape[1]

        x = x.reshape(batch_size, num_images, self.num_patches, self.feature_dim)

        # To encode image index
        pe1 = self.image_pos_table[:, :num_images].clone().detach()
        pe1 = pe1.reshape((1, num_images, 1, self.feature_dim))
        pe1 = pe1.repeat((batch_size, 1, self.num_patches, 1))

        # To encode patch index
        pe2 = self.token_pos_table.clone().detach()
        pe2 = pe2.reshape((1, 1, self.num_patches, self.feature_dim))
        pe2 = pe2.repeat((batch_size, num_images, 1, 1))

        x_pe = x + pe1 + pe2
        x_pe = x_pe.reshape(
            (batch_size, num_images * self.num_patches, self.feature_dim)
        )

        return x_pe

    def forward_unet(self, x, B, N):
        """Same encodings as ``forward`` for a 4-D channels-first feature map.

        Returns a (B * N, feature_dim, D, D) tensor with D = sqrt(num_patches).
        """
        D = int(self.num_patches**0.5)

        # x must be 4-D (permute takes 4 dims); presumably
        # (B * N, feature_dim, D, D) -- TODO confirm with callers.
        x = x.permute((0, 2, 3, 1))
        x = x.reshape(B, N, self.num_patches, self.feature_dim)

        # To encode image index
        pe1 = self.image_pos_table[:, :N].clone().detach()
        pe1 = pe1.reshape((1, N, 1, self.feature_dim))
        pe1 = pe1.repeat((B, 1, self.num_patches, 1))

        # To encode patch index
        pe2 = self.token_pos_table.clone().detach()
        pe2 = pe2.reshape((1, 1, self.num_patches, self.feature_dim))
        pe2 = pe2.repeat((B, N, 1, 1))

        x_pe = x + pe1 + pe2
        x_pe = x_pe.reshape((B * N, D, D, self.feature_dim))
        x_pe = x_pe.permute((0, 3, 1, 2))

        return x_pe