# python3.7 """Collects all available models together.""" from .model_zoo import MODEL_ZOO from .pggan_generator import PGGANGenerator from .pggan_discriminator import PGGANDiscriminator from .stylegan_generator import StyleGANGenerator from .stylegan_discriminator import StyleGANDiscriminator from .stylegan2_generator import StyleGAN2Generator from .stylegan2_discriminator import StyleGAN2Discriminator __all__ = [ 'MODEL_ZOO', 'PGGANGenerator', 'PGGANDiscriminator', 'StyleGANGenerator', 'StyleGANDiscriminator', 'StyleGAN2Generator', 'StyleGAN2Discriminator', 'build_generator', 'build_discriminator', 'build_model' ] _GAN_TYPES_ALLOWED = ['pggan', 'stylegan', 'stylegan2'] _MODULES_ALLOWED = ['generator', 'discriminator'] def build_generator(gan_type, resolution, **kwargs): """Builds generator by GAN type. Args: gan_type: GAN type to which the generator belong. resolution: Synthesis resolution. **kwargs: Additional arguments to build the generator. Raises: ValueError: If the `gan_type` is not supported. NotImplementedError: If the `gan_type` is not implemented. """ if gan_type not in _GAN_TYPES_ALLOWED: raise ValueError(f'Invalid GAN type: `{gan_type}`!\n' f'Types allowed: {_GAN_TYPES_ALLOWED}.') if gan_type == 'pggan': return PGGANGenerator(resolution, **kwargs) if gan_type == 'stylegan': return StyleGANGenerator(resolution, **kwargs) if gan_type == 'stylegan2': return StyleGAN2Generator(resolution, **kwargs) raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!') def build_discriminator(gan_type, resolution, **kwargs): """Builds discriminator by GAN type. Args: gan_type: GAN type to which the discriminator belong. resolution: Synthesis resolution. **kwargs: Additional arguments to build the discriminator. Raises: ValueError: If the `gan_type` is not supported. NotImplementedError: If the `gan_type` is not implemented. """ if gan_type not in _GAN_TYPES_ALLOWED: raise ValueError(f'Invalid GAN type: `{gan_type}`!\n' f'Types allowed: {_GAN_TYPES_ALLOWED}.') if gan_type == 'pggan': return PGGANDiscriminator(resolution, **kwargs) if gan_type == 'stylegan': return StyleGANDiscriminator(resolution, **kwargs) if gan_type == 'stylegan2': return StyleGAN2Discriminator(resolution, **kwargs) raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!') def build_model(gan_type, module, resolution, **kwargs): """Builds a GAN module (generator/discriminator/etc). Args: gan_type: GAN type to which the model belong. module: GAN module to build, such as generator or discrimiantor. resolution: Synthesis resolution. **kwargs: Additional arguments to build the discriminator. Raises: ValueError: If the `module` is not supported. NotImplementedError: If the `module` is not implemented. """ if module not in _MODULES_ALLOWED: raise ValueError(f'Invalid module: `{module}`!\n' f'Modules allowed: {_MODULES_ALLOWED}.') if module == 'generator': return build_generator(gan_type, resolution, **kwargs) if module == 'discriminator': return build_discriminator(gan_type, resolution, **kwargs) raise NotImplementedError(f'Unsupported module `{module}`!') def parse_gan_type(module): """Parses GAN type of a given module. Args: module: The module to parse GAN type from. Returns: A string, indicating the GAN type. Raises: ValueError: If the GAN type is unknown. """ if isinstance(module, (PGGANGenerator, PGGANDiscriminator)): return 'pggan' if isinstance(module, (StyleGANGenerator, StyleGANDiscriminator)): return 'stylegan' if isinstance(module, (StyleGAN2Generator, StyleGAN2Discriminator)): return 'stylegan2' raise ValueError(f'Unable to parse GAN type from type `{type(module)}`!')