from transformers import PretrainedConfig


class STDiT2Config(PretrainedConfig):
    """Configuration for the STDiT2 (Spatial-Temporal Diffusion Transformer v2) model."""

    model_type = "stdit2"

    def __init__(
        self,
        input_size=(None, None, None),  # latent (T, H, W); None means resolved at runtime
        input_sq_size=32,  # base square size for scaling positional embeddings
        in_channels=4,  # input latent channels (e.g. from a VAE)
        patch_size=(1, 2, 2),  # patch size along (T, H, W)
        hidden_size=1152,  # transformer embedding dimension
        depth=28,  # number of transformer blocks
        num_heads=16,  # attention heads per block
        mlp_ratio=4.0,  # MLP hidden dim = mlp_ratio * hidden_size
        class_dropout_prob=0.1,  # caption dropout for classifier-free guidance
        pred_sigma=True,  # also predict the noise variance (sigma)
        drop_path=0.0,  # stochastic-depth (DropPath) rate
        no_temporal_pos_emb=False,  # disable temporal positional embeddings
        caption_channels=4096,  # dimension of the text-encoder features (e.g. T5)
        model_max_length=120,  # maximum caption length in tokens
        freeze=None,  # optional parameter group to freeze
        qk_norm=False,  # normalize queries/keys in attention
        enable_flash_attn=False,  # use FlashAttention kernels
        enable_layernorm_kernel=False,  # use a fused LayerNorm kernel
        enable_sequence_parallelism=False,  # shard the sequence across devices
        **kwargs,
    ):
        self.input_size = input_size
        self.input_sq_size = input_sq_size
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.class_dropout_prob = class_dropout_prob
        self.pred_sigma = pred_sigma
        self.drop_path = drop_path
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.caption_channels = caption_channels
        self.model_max_length = model_max_length
        self.freeze = freeze
        self.qk_norm = qk_norm
        self.enable_flash_attn = enable_flash_attn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.enable_sequence_parallelism = enable_sequence_parallelism
        super().__init__(**kwargs)
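

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module):
# because STDiT2Config subclasses PretrainedConfig, it inherits the standard
# save/load round trip. The override values and temporary directory below
# are arbitrary examples for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    config = STDiT2Config(depth=28, hidden_size=1152, num_heads=16)
    with tempfile.TemporaryDirectory() as tmp_dir:
        config.save_pretrained(tmp_dir)  # writes <tmp_dir>/config.json
        reloaded = STDiT2Config.from_pretrained(tmp_dir)

    # Custom attributes set in __init__ survive serialization.
    assert reloaded.hidden_size == config.hidden_size
    assert reloaded.model_type == "stdit2"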