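# Scraped Hugging Face file-view header (page residue, not module code):
# uploader: james-oldfield · commit 2a76164 ("Upload 194 files") ·
# page links: raw / history / blame · scan badge: "No virus" · size: 4.79 kB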
# python3.7
"""Collects all available models together."""
from .model_zoo import MODEL_ZOO
from .pggan_generator import PGGANGenerator
from .pggan_discriminator import PGGANDiscriminator
from .stylegan_generator import StyleGANGenerator
from .stylegan_discriminator import StyleGANDiscriminator
from .stylegan2_generator import StyleGAN2Generator
from .stylegan2_discriminator import StyleGAN2Discriminator
from .encoder import EncoderNet
from .perceptual_model import PerceptualModel
# Public API of this package: the model classes plus the builder functions.
__all__ = [
    'MODEL_ZOO',
    'PGGANGenerator',
    'PGGANDiscriminator',
    'StyleGANGenerator',
    'StyleGANDiscriminator',
    'StyleGAN2Generator',
    'StyleGAN2Discriminator',
    'EncoderNet',
    'PerceptualModel',
    'build_generator',
    'build_discriminator',
    'build_encoder',
    'build_perceptual',
    'build_model',
]

# GAN families accepted by the per-type builder functions. Its repr is
# interpolated verbatim into the ValueError messages those builders raise.
_GAN_TYPES_ALLOWED = ['pggan', 'stylegan', 'stylegan2']

# Module kinds that `build_model()` knows how to dispatch.
_MODULES_ALLOWED = ['generator', 'discriminator', 'encoder', 'perceptual']
def build_generator(gan_type, resolution, **kwargs):
    """Instantiates the generator network for the requested GAN family.

    Args:
        gan_type: Name of the GAN family the generator belongs to; must be one
            of `_GAN_TYPES_ALLOWED` ('pggan', 'stylegan' or 'stylegan2').
        resolution: Synthesis resolution, passed positionally to the
            generator constructor.
        **kwargs: Extra keyword arguments forwarded to the generator
            constructor.

    Returns:
        A `PGGANGenerator`, `StyleGANGenerator` or `StyleGAN2Generator`
        instance, selected by `gan_type`.

    Raises:
        ValueError: If `gan_type` is not in `_GAN_TYPES_ALLOWED`.
        NotImplementedError: If `gan_type` passes validation but has no
            generator class listed below.
    """
    # Validate first so an unknown type gets the descriptive ValueError.
    if gan_type not in _GAN_TYPES_ALLOWED:
        raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
                         f'Types allowed: {_GAN_TYPES_ALLOWED}.')

    # (type name, generator class) pairs, compared in order with `==`.
    candidates = (
        ('pggan', PGGANGenerator),
        ('stylegan', StyleGANGenerator),
        ('stylegan2', StyleGAN2Generator),
    )
    for name, generator_cls in candidates:
        if gan_type == name:
            return generator_cls(resolution, **kwargs)

    # Defensive: reached only if `_GAN_TYPES_ALLOWED` lists a type that has
    # no entry in `candidates`.
    raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!')
def build_discriminator(gan_type, resolution, **kwargs):
    """Instantiates the discriminator network for the requested GAN family.

    Args:
        gan_type: Name of the GAN family the discriminator belongs to; must be
            one of `_GAN_TYPES_ALLOWED` ('pggan', 'stylegan' or 'stylegan2').
        resolution: Synthesis resolution, passed positionally to the
            discriminator constructor.
        **kwargs: Extra keyword arguments forwarded to the discriminator
            constructor.

    Returns:
        A `PGGANDiscriminator`, `StyleGANDiscriminator` or
        `StyleGAN2Discriminator` instance, selected by `gan_type`.

    Raises:
        ValueError: If `gan_type` is not in `_GAN_TYPES_ALLOWED`.
        NotImplementedError: If `gan_type` passes validation but has no
            discriminator class listed below.
    """
    # Validate first so an unknown type gets the descriptive ValueError.
    if gan_type not in _GAN_TYPES_ALLOWED:
        raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
                         f'Types allowed: {_GAN_TYPES_ALLOWED}.')

    # (type name, discriminator class) pairs, compared in order with `==`.
    candidates = (
        ('pggan', PGGANDiscriminator),
        ('stylegan', StyleGANDiscriminator),
        ('stylegan2', StyleGAN2Discriminator),
    )
    for name, discriminator_cls in candidates:
        if gan_type == name:
            return discriminator_cls(resolution, **kwargs)

    # Defensive: reached only if `_GAN_TYPES_ALLOWED` lists a type that has
    # no entry in `candidates`.
    raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!')
def build_encoder(gan_type, resolution, **kwargs):
    """Instantiates the encoder network for the requested GAN family.

    Args:
        gan_type: Name of the GAN family the encoder belongs to; must be one
            of `_GAN_TYPES_ALLOWED`. Only 'stylegan' and 'stylegan2' have an
            encoder.
        resolution: Input resolution for the encoder, passed positionally to
            `EncoderNet`.
        **kwargs: Extra keyword arguments forwarded to `EncoderNet`.

    Returns:
        An `EncoderNet` instance.

    Raises:
        ValueError: If `gan_type` is not in `_GAN_TYPES_ALLOWED`.
        NotImplementedError: If `gan_type` is valid but has no encoder
            (e.g. 'pggan').
    """
    if gan_type not in _GAN_TYPES_ALLOWED:
        raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
                         f'Types allowed: {_GAN_TYPES_ALLOWED}.')

    # Both StyleGAN families share the same encoder architecture.
    encoder_gan_types = ('stylegan', 'stylegan2')
    if gan_type not in encoder_gan_types:
        raise NotImplementedError(
            f'Unsupported GAN type `{gan_type}` for encoder!')

    return EncoderNet(resolution, **kwargs)
def build_perceptual(**kwargs):
    """Builds the perceptual model.

    Unlike the other builders, this takes no `gan_type` or `resolution`: the
    same `PerceptualModel` class is used regardless of GAN type.

    Args:
        **kwargs: Keyword arguments forwarded unchanged to the
            `PerceptualModel` constructor.

    Returns:
        A `PerceptualModel` instance.
    """
    return PerceptualModel(**kwargs)
def build_model(gan_type, module, resolution, **kwargs):
    """Builds a GAN module (generator/discriminator/encoder/perceptual model).

    Dispatches to `build_generator()`, `build_discriminator()`,
    `build_encoder()` or `build_perceptual()` according to `module`.

    Args:
        gan_type: GAN type to which the module belongs. Ignored when `module`
            is 'perceptual'.
        module: GAN module to build; one of `_MODULES_ALLOWED`
            ('generator', 'discriminator', 'encoder', 'perceptual').
        resolution: Synthesis resolution for generators/discriminators, or
            input resolution for encoders. Ignored when `module` is
            'perceptual'.
        **kwargs: Additional keyword arguments forwarded to the selected
            builder.

    Returns:
        The model instance returned by the selected builder.

    Raises:
        ValueError: If `module` is not supported, or (propagated from the
            selected builder) if `gan_type` is not supported for a
            GAN-specific module.
        NotImplementedError: If `module` is not implemented, or (propagated
            from the selected builder) if `gan_type` has no implementation
            for that module, e.g. an encoder for 'pggan'.
    """
    if module not in _MODULES_ALLOWED:
        raise ValueError(f'Invalid module: `{module}`!\n'
                         f'Modules allowed: {_MODULES_ALLOWED}.')

    if module == 'generator':
        return build_generator(gan_type, resolution, **kwargs)
    if module == 'discriminator':
        return build_discriminator(gan_type, resolution, **kwargs)
    if module == 'encoder':
        return build_encoder(gan_type, resolution, **kwargs)
    if module == 'perceptual':
        # The perceptual model does not depend on GAN type or resolution.
        return build_perceptual(**kwargs)

    # Defensive: reached only if `_MODULES_ALLOWED` gains an entry without a
    # matching branch above.
    raise NotImplementedError(f'Unsupported module `{module}`!')