from typing import Tuple

from transformers import PretrainedConfig


class VQGANConfig(PretrainedConfig):
    """Configuration container for a VQGAN model.

    Every constructor argument is stored as an attribute of the same name,
    with two exceptions:

    * ``ch_mult`` and ``attn_resolutions`` are copied into new lists, so the
      config does not share the caller's sequence object. Any iterable of
      ints is accepted.
    * ``num_resolutions`` is derived, not passed in. It is set to
      ``len(ch_mult)``.

    Any extra keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    NOTE(review): the meaning of each hyperparameter (channel counts,
    resolution, codebook size, attention placement, ...) is only suggested by
    its name. It is defined by the model code that consumes this config, so
    confirm it there.
    """

    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        in_channels: int = 3,
        num_res_blocks: int = 2,
        resolution: int = 256,
        z_channels: int = 256,
        ch_mult: Tuple[int, ...] = (1, 1, 2, 2, 4),
        attn_resolutions: Tuple[int, ...] = (16,),
        n_embed: int = 1024,
        embed_dim: int = 256,
        dropout: float = 0.0,
        double_z: bool = False,
        resamp_with_conv: bool = True,
        give_pre_end: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ch = ch
        self.out_ch = out_ch
        self.in_channels = in_channels
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.z_channels = z_channels
        # Store independent list copies. Callers may pass tuples (the
        # defaults), lists, or any other iterable of ints.
        self.ch_mult = list(ch_mult)
        self.attn_resolutions = list(attn_resolutions)
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.double_z = double_z
        self.resamp_with_conv = resamp_with_conv
        self.give_pre_end = give_pre_end
        # Derive the count from the stored list rather than the raw argument.
        # A one-shot iterator would already be exhausted by list() above, and
        # len() on it would raise TypeError.
        self.num_resolutions = len(self.ch_mult)