import json
from dataclasses import dataclass
from typing import List, Optional

from genie.factorization_utils import nth_root


@dataclass
class GenieConfig:
    """Configuration for the Genie spatio-temporal transformer world model.

    Fields are grouped by concern (sequence geometry, dataloader masking,
    vocab factorization, MaskGIT training, actions, attention, MLP). The
    config round-trips through JSON via `save_pretrained`/`from_pretrained`.
    """

    num_layers: int
    num_heads: int
    d_model: int
    T: int = 12  # temporal sequence length
    S: int = 256  # spatial sequence length, e.g. 256 for 16x16
    image_vocab_size: int = 262144
    # image_vocab_size: number of distinct image tokens;
    # actual model vocab size is larger to include special (e.g. mask) tokens.
    use_mup: bool = False

    dataloader_apply_mask: bool = True  # apply mask in dataloader
    dataloader_apply_corruption: bool = True
    dataloader_mask_ratio_min: float = 0.2
    drop_action_ratio: float = 0.0  # for datasets
    arch: str = "STTransformerDecoder"
    random_dummy_action: bool = True  # for model

    # Factorization for large vocabs (e.g. Open-MAGVIT2)
    num_factored_vocabs: int = 1
    factored_vocab_size: Optional[int] = None  # derived in __post_init__

    # MaskGIT training (all arbitrary numbers)
    max_corrupt_rate: float = 0.2  # Corrupt all tokens, uniform between [0, max_corrupt_rate]
    # Case 1: MLM training.
    # Case 2: Not standard MLM, `non_mlm`. Some earlier frames are left unmasked, as in Copilot4D.
    non_mlm_ratio: float = 0.2
    num_prompt_frames: int = 4

    # action related
    init_actions: bool = False
    d_action: int = 28  # action dimensions
    use_actions: bool = True
    action_domains: Optional[List[str]] = None
    d_actions: Optional[List[int]] = None
    action_stats: Optional[List[List[float]]] = None  # TODO: is this actually three nested lists?
    action_network: str = "mlp"
    shared_action_mlps: bool = True
    action_contrastive_loss: bool = False
    jointly_predict_actions: bool = False  # jointly predict actions
    jointly_predict_states: bool = True  # jointly predict states
    action_token_size: int = 64  # images are 16x16
    label_drop_prob: float = 0.5  # the drop ratio for action tokens
    action_loss_weight: float = 0.5  # weight for action loss

    # Attention
    qkv_bias: bool = False
    proj_bias: bool = True
    attn_drop: float = 0.0
    qk_norm: bool = True

    # MLP
    mlp_ratio: float = 4.0
    mlp_drop: float = 0.0
    mlp_bias: bool = True

    def save_pretrained(self, json_path):
        """Serialize all config fields to `json_path` as a flat JSON object."""
        with open(json_path, "w") as f:
            json.dump(vars(self), f)

    @classmethod
    def from_pretrained(cls, json_path):
        """Load a config previously written by `save_pretrained`."""
        with open(json_path, "r") as f:
            config = json.load(f)
        return cls(**config)

    def shallow_copy(self):
        """Return a field-wise copy preserving the concrete subclass.

        Bug fix: the original hard-coded `GenieConfig(**vars(self))`, which
        raised TypeError when inherited by subclasses that declare extra
        fields (e.g. CogVideoGenieConfig). `type(self)` keeps the subclass,
        so the per-subclass overrides are no longer needed.
        """
        return type(self)(**vars(self))

    def __post_init__(self):
        # Derive the per-factor vocab size from the full image vocab.
        # Bug fix: compare to None with `is`, not `==`.
        if self.image_vocab_size is None:
            self.factored_vocab_size = 64  # dummy
        else:
            self.factored_vocab_size = nth_root(self.image_vocab_size, self.num_factored_vocabs)


@dataclass
class DiffusionGenieConfig(GenieConfig):
    """Genie variant with a diffusion (MAR-style) decoding head.

    Overrides dataloader masking to happen inside the model and adds
    VAE/diffusion-loss hyperparameters.
    """

    Diffusion: bool = True
    # Attention
    dim: int = 512
    dataloader_apply_mask: bool = True  # apply mask inside the model
    dataloader_apply_corruption: bool = False  # no need for random corruptions
    dataloader_mask_ratio_min: float = 0.1
    # MLP / VAE
    vae_stride: int = 1
    patch_size: int = 1
    vae_embed_dim: int = 4
    mask_ratio_min: float = 0.7
    attn_dropout: float = 0.1
    proj_dropout: float = 0.1
    buffer_size: int = 64
    diffloss_d: int = 4
    diffloss_w: int = 1024  # 1024
    num_sampling_steps: str = '100'
    diffusion_batch_mul: int = 1
    grad_checkpointing: bool = False
    use_actions: bool = True
    jointly_predict_actions: bool = False  # jointly predict actions
    jointly_predict_states: bool = True  # jointly predict states
    action_token_size: int = 64  # images are 16x16
    # NOTE(review): the original declared label_drop_prob twice (0.1, then 0.5);
    # the later declaration silently won, so the effective default 0.5 is kept
    # and the shadowed 0.1 duplicate is removed.
    label_drop_prob: float = 0.5  # the drop ratio for action tokens
    action_loss_weight: float = 1.0  # weight for action loss
    predict_unmask: bool = False  # also predict tokens in unmasked regions
    maskgit_steps: int = 16  # the mask iterations during inference


@dataclass
class CogVideoGenieConfig(GenieConfig):
    """Genie variant built on CogVideo-style transformer blocks."""

    CogVideo: bool = True
    # Attention
    dim: int = 512
    num_attention_heads: int = 30
    attention_head_dim: int = 16
    time_embed_dim: int = 128
    # MLP
    mlp_ratio: float = 4.0
    mlp_drop: float = 0.0
    mlp_bias: bool = True