|
from transformers import PretrainedConfig
|
|
|
|
|
|
class LightweightGANConfig(PretrainedConfig):
    """Configuration object for models registered as ``"lightweight-gan"``.

    Every constructor argument is stored unchanged as an attribute of the
    same name; any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        image_size: Stored as ``self.image_size`` (default ``64``).
        latent_dim: Stored as ``self.latent_dim`` (default ``256``).
        fmap_max: Stored as ``self.fmap_max`` (default ``512``).
        fmap_inverse_coef: Stored as ``self.fmap_inverse_coef``
            (default ``12``).
        transparent: Stored as ``self.transparent`` (default ``False``).
        greyscale: Stored as ``self.greyscale`` (default ``False``).
        attn_res_layers: Stored as ``self.attn_res_layers``. When omitted
            (or ``None``) a fresh ``[32]`` list is created for this instance.
            NOTE(review): presumably the resolutions at which attention
            layers are inserted -- confirm against the model code.
        freq_chan_attn: Stored as ``self.freq_chan_attn``
            (default ``False``).
        syncbatchnorm: Stored as ``self.syncbatchnorm`` (default ``False``).
        antialias: Stored as ``self.antialias`` (default ``False``).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "lightweight-gan"

    def __init__(
        self,
        image_size=64,
        latent_dim=256,
        fmap_max=512,
        fmap_inverse_coef=12,
        transparent=False,
        greyscale=False,
        attn_res_layers=None,
        freq_chan_attn=False,
        syncbatchnorm=False,
        antialias=False,
        **kwargs,
    ):
        # A literal list default would be a single object shared by every
        # instance created without this argument; build a new one per call so
        # in-place edits on one config cannot leak into another.
        if attn_res_layers is None:
            attn_res_layers = [32]

        self.image_size = image_size
        self.latent_dim = latent_dim
        self.fmap_max = fmap_max
        self.fmap_inverse_coef = fmap_inverse_coef
        self.transparent = transparent
        self.greyscale = greyscale
        self.attn_res_layers = attn_res_layers
        self.freq_chan_attn = freq_chan_attn
        self.syncbatchnorm = syncbatchnorm
        self.antialias = antialias

        # The base-class initializer runs last, after the model-specific
        # attributes are in place (same ordering as the original code).
        super().__init__(**kwargs)
|
|
|