import torch
from torch import nn
from einops import rearrange

from .dit_models_xformers import (
    DiT, get_2d_sincos_pos_embed, ImageCondDiTBlock, FinalLayer, CaptionEmbedder,
    approx_gelu, ImageCondDiTBlockPixelArt, t2i_modulate,
    ImageCondDiTBlockPixelArtRMSNorm, T2IFinalLayer,
    ImageCondDiTBlockPixelArtRMSNormNoClip)
from timm.models.vision_transformer import Mlp

try:
    from apex.normalization import FusedLayerNorm as LayerNorm
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    from torch.nn import LayerNorm
    from dit.norm import RMSNorm



class DiT_I23D(DiT):
    """DiT with 3D-aware (triplane roll-out) operations and image conditioning."""

    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        mixing_logit_init=-3,
        mixed_prediction=True,
        context_dim=False,
        pooling_ctx_dim=768,
        roll_out=False,
        vit_blk=ImageCondDiTBlock,
        final_layer_blk=T2IFinalLayer,
    ):

        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, roll_out, vit_blk,
                         T2IFinalLayer)  # note: the final layer is fixed to T2IFinalLayer; final_layer_blk is not forwarded

        assert self.roll_out

        self.clip_ctx_dim = 1024  # CLIP ViT-L/14 token dim

        # DINO spatial tokens (ViT-B/14 for MV-cond); input dim taken from context_dim, hard-coded for now
        self.dino_proj = CaptionEmbedder(context_dim,
                                         hidden_size,
                                         act_layer=approx_gelu)

        # CLIP ViT-L/14 spatial tokens
        self.clip_spatial_proj = CaptionEmbedder(1024,
                                                 hidden_size,
                                                 act_layer=approx_gelu)

    def init_PE_3D_aware(self):

        self.pos_embed = nn.Parameter(torch.zeros(
            1, self.plane_n * self.x_embedder.num_patches, self.embed_dim),
                                      requires_grad=False)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        p = int(self.x_embedder.num_patches**0.5)
        D = self.pos_embed.shape[-1]
        grid_size = (self.plane_n, p * p)  # (n, H*W)

        pos_embed = get_2d_sincos_pos_embed(D, grid_size).reshape(
            self.plane_n * p * p, D)  # (n * H * W, D)

        self.pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0))
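
    # With the defaults (input_size=32, patch_size=2, plane_n=3, hidden_size=1152), the frozen
    # table above has shape (1, 3 * 16 * 16, 1152) = (1, 768, 1152): one sin-cos embedding per
    # plane position, matching the `b (n l) c` token layout produced in forward(). (Worked-out
    # note added for clarity; the exact sizes depend on the configuration.)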

    def initialize_weights(self):
        super().initialize_weights()

        # ! add 3d-aware PE
        self.init_PE_3D_aware()

    def forward(self,
                x,
                timesteps=None,
                context=None,
                y=None,
                get_attr='',
                **kwargs):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        assert isinstance(context, dict)
        clip_cls_token = self.clip_text_proj(context['vector'])
        # split the concatenated spatial tokens: CLIP first, DINO after
        clip_spatial_token = context['crossattn'][..., :self.clip_ctx_dim]
        dino_spatial_token = self.dino_proj(context['crossattn'][..., self.clip_ctx_dim:])

        t = self.t_embedder(timesteps) + clip_cls_token  # (N, D)
        # TODO: also return the spatial CLIP features.

        # roll the three planes out into the batch dim so each plane is patchified by the shared conv embedder
        x = rearrange(x, 'b (c n) h w -> (b n) c h w', n=3)
        x = self.x_embedder(x)  # (b n) l c, with l = (H/p) * (W/p)

        # roll the planes back into the token (L) dim, so the condition is added to all plane tokens
        x = rearrange(x, '(b n) l c -> b (n l) c', n=3)
        x = x + self.pos_embed  # (N, T, D), where T = n * H * W / patch_size ** 2

        for block in self.blocks:
            x = block(x,
                      t,
                      dino_spatial_token=dino_spatial_token,
                      clip_spatial_token=clip_spatial_token)  # (N, T, D)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)

        if self.roll_out:  # move n from the L dim back to the B dim for unpatchify
            x = rearrange(x, 'b (n l) c -> (b n) l c', n=3)

        x = self.unpatchify(x)  # (N * n, out_channels, H, W)

        if self.roll_out:  # fold n back into the channel dim
            x = rearrange(x, '(b n) c h w -> b (c n) h w', n=3)
        # cast to float32 for better accuracy
        x = x.to(torch.float32).contiguous()

        return x

    # ! compat issue
    def forward_with_cfg(self, x, t, context, cfg_scale):
        """
        Forward pass of SiT, but also batches the unconSiTional forward pass for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        # half = x[: len(x) // 2]
        # combined = torch.cat([half, half], dim=0)
        eps = self.forward(x, t, context)
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        # eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return eps
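

# Hedged usage sketch for DiT_I23D (not executed; the kwargs and shapes below are assumptions
# based on the slicing in forward(), where context['crossattn'] holds CLIP ViT-L tokens in the
# first 1024 channels followed by DINO tokens):
#
#   model = DiT_B_2(context_dim=768, pooling_ctx_dim=768, roll_out=True)
#   x = torch.randn(2, 4 * 3, 32, 32)                  # triplane latent, three planes in the channel dim
#   context = {
#       'vector':    torch.randn(2, 768),              # pooled CLIP embedding
#       'crossattn': torch.randn(2, 256, 1024 + 768),  # CLIP tokens ++ DINO tokens on the last dim
#   }
#   out = model(x, timesteps=torch.randint(0, 1000, (2,)), context=context)  # same shape as x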




class DiT_I23D_PixelArt(DiT_I23D):
    """PixelArt-style variant: a single shared adaLN modulation, zero-initialized caption embedder, and RMSNorm on the CLIP spatial tokens."""

    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        mixing_logit_init=-3,
        mixed_prediction=True,
        context_dim=False,
        pooling_ctx_dim=768,
        roll_out=False,
        vit_blk=ImageCondDiTBlockPixelArtRMSNorm,
        final_layer_blk=FinalLayer,
    ):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, vit_blk, final_layer_blk)

        # a single adaLN modulation shared by all blocks (PixArt-style adaLN-single)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

        # zero-init so the modulation starts as an identity mapping
        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.adaLN_modulation[-1].bias, 0)

        del self.clip_text_proj
        # pooled CLIP embedding -> hidden_size, zero-initialized below
        self.cap_embedder = nn.Sequential(
            LayerNorm(pooling_ctx_dim),
            nn.Linear(
                pooling_ctx_dim,
                hidden_size,
            ),
        )

        nn.init.constant_(self.cap_embedder[-1].weight, 0)
        nn.init.constant_(self.cap_embedder[-1].bias, 0)

        print(self)  # check model arch

        # normalize the CLIP spatial tokens once here instead of re-normalizing in every block
        # https://github.com/Alpha-VLLM/Lumina-T2X/blob/0c8dd6a07a3b7c18da3d91f37b1e00e7ae661293/lumina_t2i/models/model.py#L570C9-L570C61
        self.attention_y_norm = RMSNorm(1024, eps=1e-5)


    def forward(self,
                x,
                timesteps=None,
                context=None,
                y=None,
                get_attr='',
                **kwargs):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        assert isinstance(context, dict)
        clip_cls_token = self.cap_embedder(context['vector'])
        clip_spatial_token = context['crossattn'][..., :self.clip_ctx_dim]
        dino_spatial_token = self.dino_proj(context['crossattn'][..., self.clip_ctx_dim:])
        clip_spatial_token = self.attention_y_norm(clip_spatial_token)  # avoid re-normalization in each block

        t = self.t_embedder(timesteps) + clip_cls_token  # (N, D)
        t0 = self.adaLN_modulation(t)  # shared adaLN-single, (N, 6 * D)

        # roll the three planes out into the batch dim so each plane is patchified by the shared conv embedder
        x = rearrange(x, 'b (c n) h w -> (b n) c h w', n=3)
        x = self.x_embedder(x)  # (b n) l c, with l = (H/p) * (W/p)

        # roll the planes back into the token (L) dim, so the condition is added to all plane tokens
        x = rearrange(x, '(b n) l c -> b (n l) c', n=3)
        x = x + self.pos_embed  # (N, T, D), where T = n * H * W / patch_size ** 2

        for block in self.blocks:
            x = block(x,
                      t0,
                      dino_spatial_token=dino_spatial_token,
                      clip_spatial_token=clip_spatial_token)  # (N, T, D)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)

        if self.roll_out:  # move n from the L dim back to the B dim for unpatchify
            x = rearrange(x, 'b (n l) c -> (b n) l c', n=3)

        x = self.unpatchify(x)  # (N * n, out_channels, H, W)

        if self.roll_out:  # fold n back into the channel dim
            x = rearrange(x, '(b n) c h w -> b (c n) h w', n=3)

        # cast to float32 for better accuracy
        x = x.to(torch.float32).contiguous()

        return x
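

# Hedged sketch of the adaLN-single conditioning used above (not executed): the shared
# `adaLN_modulation` maps the (N, D) conditioning vector to (N, 6 * D), and each block is assumed
# to split it into shift/scale/gate pairs for its attention and MLP branches, PixArt-style, e.g.
#
#   t0 = self.adaLN_modulation(t)     # (N, 6 * D)
#   shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = t0.chunk(6, dim=1)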


class DiT_I23D_PixelArt_MVCond(DiT_I23D_PixelArt):
    """Multi-view image-conditioned variant: MV DINO spatial tokens feed cross-attention, CLIP spatial tokens are appended to self-attention."""

    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        mixing_logit_init=-3,
        mixed_prediction=True,
        context_dim=False,
        pooling_ctx_dim=768,
        roll_out=False,
        vit_blk=ImageCondDiTBlockPixelArt,
        final_layer_blk=FinalLayer,
    ):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, ImageCondDiTBlockPixelArtRMSNorm,
                         final_layer_blk)  # block type fixed to ImageCondDiTBlockPixelArtRMSNorm; vit_blk is not forwarded


        # Support multi-view image condition.
        # DINO provides the spatial (and globally pooled) features here; CLIP takes care of the camera condition via ModLN.
        # The input DINO features are concatenated over views (+ global pooling). InstantMesh likewise adopts DINO, but via cross-attention.
        # Since the conditioning is cross-attention, a dynamic number of context views / any context window size should be supported.
        del self.dino_proj  # MV DINO tokens go straight to cross-attn, which has its own KV projections

    def forward(self,
                x,
                timesteps=None,
                context=None,
                y=None,
                get_attr='',
                **kwargs):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        assert isinstance(context, dict)

        # CLIP spatial tokens are appended to self-attention, hence the projection (clip_spatial_proj);
        # extra normalization should not be needed since the ViT applies ln_post and the blocks use QK norm.
        # MV DINO features are sent via cross-attention, which already has its own KV projections.
        clip_cls_token = self.cap_embedder(context['vector'])
        clip_spatial_token = self.clip_spatial_proj(context['crossattn'])
        dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c')  # flatten the MV DINO features

        t = self.t_embedder(timesteps) + clip_cls_token  # (N, D)
        t0 = self.adaLN_modulation(t)  # shared adaLN-single, (N, 6 * D)

        # roll the three planes out into the batch dim so each plane is patchified by the shared conv embedder
        x = rearrange(x, 'b (c n) h w -> (b n) c h w', n=3)
        x = self.x_embedder(x)  # (b n) l c, with l = (H/p) * (W/p)

        # roll the planes back into the token (L) dim and add the 3D-aware PE
        x = rearrange(x, '(b n) l c -> b (n l) c', n=3)
        x = x + self.pos_embed  # (N, T, D), where T = n * H * W / patch_size ** 2

        for block in self.blocks:
            # Intentional swap: the projected CLIP tokens go into the block's self-attn append slot
            # (dino_spatial_token), and the MV DINO tokens into the cross-attn slot (clip_spatial_token).
            x = block(x,
                      t0,
                      dino_spatial_token=clip_spatial_token,
                      clip_spatial_token=dino_spatial_token)  # (N, T, D)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)

        if self.roll_out:  # move n from the L dim back to the B dim for unpatchify
            x = rearrange(x, 'b (n l) c -> (b n) l c', n=3)

        x = self.unpatchify(x)  # (N * n, out_channels, H, W)

        if self.roll_out:  # fold n back into the channel dim
            x = rearrange(x, '(b n) c h w -> b (c n) h w', n=3)

        x = x.to(torch.float32).contiguous()

        return x
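

# Hedged usage sketch for the MV-conditioned model (not executed; shapes follow the context
# docstring above, other kwargs are assumptions):
#
#   model = DiT_B_Pixelart_MV_2(context_dim=768, pooling_ctx_dim=768, roll_out=True)
#   x = torch.randn(2, 4 * 3, 32, 32)              # triplane latent
#   context = {
#       'vector':    torch.randn(2, 768),          # pooled CLIP embedding
#       'crossattn': torch.randn(2, 256, 1024),    # CLIP ViT-L spatial tokens
#       'concat':    torch.randn(2, 4, 256, 768),  # 4-view DINO spatial tokens
#   }
#   out = model(x, timesteps=torch.randint(0, 1000, (2,)), context=context)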


class DiT_I23D_PixelArt_MVCond_noClip(DiT_I23D_PixelArt):
    """MV-conditioned variant without any CLIP condition: only the MV DINO tokens are used, via cross-attention."""
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        mixing_logit_init=-3,
        mixed_prediction=True,
        context_dim=False,
        pooling_ctx_dim=768,
        roll_out=False,
        vit_blk=ImageCondDiTBlockPixelArt,
        final_layer_blk=FinalLayer,
    ):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, ImageCondDiTBlockPixelArtRMSNormNoClip,
                         final_layer_blk)  # block type fixed to the no-CLIP RMSNorm block; vit_blk is not forwarded


        # Support multi-view image condition without CLIP: only the MV DINO features are used,
        # injected via cross-attention, so a dynamic number of context views should be supported.
        del self.dino_proj
        del self.clip_spatial_proj, self.cap_embedder  # no CLIP branch required

    def forward(self,
                x,
                timesteps=None,
                context=None,
                y=None,
                get_attr='',
                **kwargs):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        assert isinstance(context, dict)

        # only the MV DINO spatial features are used; cross-attention has its own KV projections
        dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c')  # flatten the MV DINO features

        t = self.t_embedder(timesteps)  # (N, D), no CLIP cls token added
        t0 = self.adaLN_modulation(t)  # shared adaLN-single, (N, 6 * D)

        # roll the three planes out into the batch dim so each plane is patchified by the shared conv embedder
        x = rearrange(x, 'b (c n) h w -> (b n) c h w', n=3)
        x = self.x_embedder(x)  # (b n) l c, with l = (H/p) * (W/p)

        # roll the planes back into the token (L) dim and add the 3D-aware PE
        x = rearrange(x, '(b n) l c -> b (n l) c', n=3)
        x = x + self.pos_embed  # (N, T, D), where T = n * H * W / patch_size ** 2

        for block in self.blocks:
            x = block(x, t0, dino_spatial_token=dino_spatial_token)  # (N, T, D)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)

        if self.roll_out:  # move n from the L dim back to the B dim for unpatchify
            x = rearrange(x, 'b (n l) c -> (b n) l c', n=3)

        x = self.unpatchify(x)  # (N * n, out_channels, H, W)

        if self.roll_out:  # fold n back into the channel dim
            x = rearrange(x, '(b n) c h w -> b (c n) h w', n=3)

        x = x.to(torch.float32).contiguous()

        return x





# point-cloud-structured latent DDPM

class DiT_pcd_I23D_PixelArt_MVCond(DiT_I23D_PixelArt_MVCond):
    """MV-conditioned DiT over a point-cloud-structured (set) latent: per-point MLP embedder, no patchify and no grid positional embedding."""
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        mixing_logit_init=-3,
        mixed_prediction=True,
        context_dim=False,
        pooling_ctx_dim=768,
        roll_out=False,
        vit_blk=ImageCondDiTBlockPixelArt,
        final_layer_blk=FinalLayer,
    ):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, ImageCondDiTBlockPixelArtRMSNorm,
                         final_layer_blk)
        # Design notes: xyz is first normalized from [-0.45, 0.45] to [-1, 1];
        # xyz is then encoded with a point Fourier feature + MLP projection, which serves as the PE here;
        # a separate MLP embeds the KL feature, and the two are added in feature space;
        # a single MLP (final_layer) maps the tokens back to 16 + 3 dims.
        self.x_embedder = Mlp(in_features=in_channels,
                              hidden_features=hidden_size,
                              out_features=hidden_size,
                              act_layer=approx_gelu,
                              drop=0)
        del self.pos_embed  # unordered point tokens: no grid positional embedding


    def forward(self,
                x,
                timesteps=None,
                context=None,
                y=None,
                get_attr='',
                **kwargs):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        assert isinstance(context, dict)

        # CLIP spatial tokens are appended to self-attention, hence the projection;
        # MV DINO features are sent via cross-attention, which has its own KV projections.
        clip_cls_token = self.cap_embedder(context['vector'])
        clip_spatial_token = self.clip_spatial_proj(context['crossattn'])
        dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c')  # flatten the MV DINO features

        t = self.t_embedder(timesteps) + clip_cls_token  # (N, D)
        t0 = self.adaLN_modulation(t)  # shared adaLN-single, (N, 6 * D)

        x = self.x_embedder(x)  # (N, L, C) -> (N, L, D) per-point MLP embedding

        for block in self.blocks:
            # Same intentional swap as the triplane MV model: CLIP tokens are appended to self-attn,
            # MV DINO tokens go to cross-attn.
            x = block(x,
                      t0,
                      dino_spatial_token=clip_spatial_token,
                      clip_spatial_token=dino_spatial_token)  # (N, T, D)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)

        x = x.to(torch.float32).contiguous()

        return x
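

# Hedged usage sketch for the point-cloud latent model (not executed; the number of points and the
# per-point latent width are assumptions, `in_channels` should match the per-point latent, e.g.
# 16 + 3 per the design notes above):
#
#   model = DiT_L_Pixelart_MV_pcd(in_channels=16 + 3, context_dim=768, pooling_ctx_dim=768, roll_out=True)
#   x = torch.randn(2, 768, 16 + 3)                # (B, n_points, per-point latent) set latent
#   context = {
#       'vector':    torch.randn(2, 768),
#       'crossattn': torch.randn(2, 256, 1024),
#       'concat':    torch.randn(2, 4, 256, 768),
#   }
#   out = model(x, timesteps=torch.randint(0, 1000, (2,)), context=context)  # same token layout as x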



#################################################################################
#                                DiT_I23D Configs                               #
#################################################################################


def DiT_XL_2(**kwargs):
    return DiT_I23D(depth=28,
                         hidden_size=1152,
                         patch_size=2,
                         num_heads=16,
                         **kwargs)


def DiT_L_2(**kwargs):
    return DiT_I23D(depth=24,
                         hidden_size=1024,
                         patch_size=2,
                         num_heads=16,
                         **kwargs)


def DiT_B_2(**kwargs):
    return DiT_I23D(depth=12,
                         hidden_size=768,
                         patch_size=2,
                         num_heads=12,
                         **kwargs)


def DiT_B_1(**kwargs):
    return DiT_I23D(depth=12,
                         hidden_size=768,
                         patch_size=1,
                         num_heads=12,
                         **kwargs)


def DiT_L_Pixelart_2(**kwargs):
    return DiT_I23D_PixelArt(depth=24,
                         hidden_size=1024,
                         patch_size=2,
                         num_heads=16,
                         **kwargs)


def DiT_B_Pixelart_2(**kwargs):
    return DiT_I23D_PixelArt(depth=12,
                         hidden_size=768,
                         patch_size=2,
                         num_heads=12,
                         **kwargs)

def DiT_L_Pixelart_MV_2(**kwargs):
    return DiT_I23D_PixelArt_MVCond(depth=24,
                         hidden_size=1024,
                         patch_size=2,
                         num_heads=16,
                         **kwargs)

def DiT_L_Pixelart_MV_2_noclip(**kwargs):
    return DiT_I23D_PixelArt_MVCond_noClip(depth=24,
                         hidden_size=1024,
                         patch_size=2,
                         num_heads=16,
                         **kwargs)

def DiT_XL_Pixelart_MV_2(**kwargs):
    return DiT_I23D_PixelArt_MVCond(depth=28,
                         hidden_size=1152,
                         patch_size=2,
                         num_heads=16,
                         **kwargs)



def DiT_B_Pixelart_MV_2(**kwargs):
    return DiT_I23D_PixelArt_MVCond(depth=12,
                         hidden_size=768,
                         patch_size=2,
                         num_heads=12,
                         **kwargs)

# pcd latent 

def DiT_L_Pixelart_MV_pcd(**kwargs):
    return DiT_pcd_I23D_PixelArt_MVCond(depth=24,
                         hidden_size=1024,
                         patch_size=1, # no spatial compression here
                         num_heads=16,
                         **kwargs)



DiT_models = {
    'DiT-XL/2': DiT_XL_2,
    'DiT-L/2': DiT_L_2,
    'DiT-B/2': DiT_B_2,
    'DiT-B/1': DiT_B_1,
    'DiT-PixArt-L/2': DiT_L_Pixelart_2,
    'DiT-PixArt-MV-XL/2': DiT_XL_Pixelart_MV_2,
    'DiT-PixArt-MV-L/2': DiT_L_Pixelart_MV_2_noclip,  # currently mapped to the no-CLIP variant; DiT_L_Pixelart_MV_2 is the CLIP+DINO version
    'DiT-PixArt-MV-PCD-L': DiT_L_Pixelart_MV_pcd,
    'DiT-PixArt-MV-B/2': DiT_B_Pixelart_MV_2,
    'DiT-PixArt-B/2': DiT_B_Pixelart_2,
}
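

if __name__ == '__main__':
    # Minimal self-contained check of the triplane roll-out token layout used throughout this file
    # (assumed defaults: 3 planes, 4 latent channels, 32x32 planes). It only exercises the einops
    # patterns, using patch_size=1 flattening as a stand-in for the conv patch embedder, and does
    # not construct any of the models above.
    b, c, n, h, w = 2, 4, 3, 32, 32
    x = torch.randn(b, c * n, h, w)  # three planes packed in the channel dim
    planes = rearrange(x, 'b (c n) h w -> (b n) c h w', n=n)  # roll the planes out into the batch dim
    assert planes.shape == (b * n, c, h, w)
    tokens = rearrange(planes.flatten(2).transpose(1, 2), '(b n) l c -> b (n l) c', n=n)
    assert tokens.shape == (b, n * h * w, c)  # token count matches the 3D-aware PE layout
    back = rearrange(rearrange(tokens, 'b (n l) c -> (b n) l c', n=n),
                     '(b n) (h w) c -> b (c n) h w', n=n, h=h, w=w)
    assert torch.equal(back, x)  # the roll-out round trip is lossless
    print('triplane roll-out round trip OK:', tuple(back.shape))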