import torch
from transformers import PretrainedConfig
from typing import List


class STDiTConfig(PretrainedConfig):
    """Configuration class for the STDiT model.

    Stores the hyperparameters needed to instantiate an STDiT model:
    latent input geometry, transformer width/depth, caption-conditioning
    sizes, and the optional kernel/parallelism switches. Any extra keyword
    arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "stdit"

    def __init__(
        self,
        input_size=(1, 32, 32),
        in_channels=4,
        patch_size=(1, 2, 2),
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        pred_sigma=True,
        drop_path=0.0,
        no_temporal_pos_emb=False,
        caption_channels=4096,
        model_max_length=120,
        space_scale=1.0,
        time_scale=1.0,
        freeze=None,
        enable_flash_attn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
        **kwargs,
    ):
        # Latent/video geometry.
        self.input_size = input_size
        self.in_channels = in_channels
        self.patch_size = patch_size

        # Transformer architecture.
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.class_dropout_prob = class_dropout_prob
        self.pred_sigma = pred_sigma
        self.drop_path = drop_path
        self.no_temporal_pos_emb = no_temporal_pos_emb

        # Caption conditioning.
        self.caption_channels = caption_channels
        self.model_max_length = model_max_length

        # Positional-embedding scaling.
        self.space_scale = space_scale
        self.time_scale = time_scale

        # Training / runtime switches.
        self.freeze = freeze
        self.enable_flash_attn = enable_flash_attn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.enable_sequence_parallelism = enable_sequence_parallelism

        # Let PretrainedConfig absorb any remaining kwargs.
        super().__init__(**kwargs)