from typing import Tuple
from transformers import PretrainedConfig
class VQGANConfig(PretrainedConfig):
    """Configuration for a VQGAN (vector-quantized GAN) model.

    Stores the encoder/decoder architecture hyperparameters and the codebook
    (vector-quantizer) settings. Any extra keyword arguments are forwarded to
    ``PretrainedConfig`` (e.g. serialization-related fields).

    Args:
        ch: Base channel count for the conv stacks.
        out_ch: Number of output channels (e.g. 3 for RGB).
        in_channels: Number of input channels (e.g. 3 for RGB).
        num_res_blocks: Residual blocks per resolution level.
        resolution: Input image resolution (square, in pixels).
        z_channels: Channel count of the latent feature map.
        ch_mult: Per-level channel multipliers; its length determines the
            number of resolution levels (``num_resolutions``).
        attn_resolutions: Resolutions at which attention layers are applied.
        n_embed: Number of codebook entries in the quantizer.
        embed_dim: Dimensionality of each codebook entry.
        dropout: Dropout probability inside residual blocks.
        double_z: Whether the encoder outputs 2x ``z_channels``
            (mean + logvar, VAE-style).
        resamp_with_conv: Use strided/transposed convs for resampling
            instead of pooling/interpolation.
        give_pre_end: If True, the decoder returns features before its
            final output layer.
    """

    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        in_channels: int = 3,
        num_res_blocks: int = 2,
        resolution: int = 256,
        z_channels: int = 256,
        ch_mult: Tuple[int, ...] = (1, 1, 2, 2, 4),
        # Fixed annotation: default is a tuple of ints, not an int.
        attn_resolutions: Tuple[int, ...] = (16,),
        n_embed: int = 1024,
        embed_dim: int = 256,
        dropout: float = 0.0,
        double_z: bool = False,
        resamp_with_conv: bool = True,
        give_pre_end: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ch = ch
        self.out_ch = out_ch
        self.in_channels = in_channels
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.z_channels = z_channels
        # Stored as lists so the config round-trips cleanly through JSON.
        self.ch_mult = list(ch_mult)
        self.attn_resolutions = list(attn_resolutions)
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.double_z = double_z
        self.resamp_with_conv = resamp_with_conv
        self.give_pre_end = give_pre_end
        # Derived: one resolution level per channel-multiplier entry.
        self.num_resolutions = len(ch_mult)