BerfScene / models /__init__.py
3v324v23's picture
init
2f85de4
raw
history blame
2.4 kB
# python3.7
"""Collects all models."""
from .pggan_generator import PGGANGenerator
from .pggan_discriminator import PGGANDiscriminator
from .stylegan_generator import StyleGANGenerator
from .stylegan_discriminator import StyleGANDiscriminator
from .stylegan2_generator import StyleGAN2Generator
from .stylegan2_discriminator import StyleGAN2Discriminator
from .stylegan3_generator import StyleGAN3Generator
from .ghfeat_encoder import GHFeatEncoder
from .perceptual_model import PerceptualModel
from .inception_model import InceptionModel
from .eg3d_generator import EG3DGenerator
from .eg3d_discriminator import DualDiscriminator
from .pigan_generator import PiGANGenerator
from .pigan_discriminator import PiGANDiscriminator
from .volumegan_generator import VolumeGANGenerator
from .volumegan_discriminator import VolumeGANDiscriminator
from .eg3d_generator_fv import EG3DGeneratorFV
from .bev3d_generator import BEV3DGenerator
from .sgbev3d_generator import SGBEV3DGenerator
# `build_model` is the only name exported by a star-import of this package.
__all__ = ['build_model']

# Registry consulted by `build_model`: maps each supported `model_type` string
# to the callable that is invoked with the caller's keyword arguments.
# Insertion order is visible to users, because the keys are listed in the
# error message raised for an unsupported type.
_MODELS = {
    'PGGANGenerator': PGGANGenerator,
    'PGGANDiscriminator': PGGANDiscriminator,
    'StyleGANGenerator': StyleGANGenerator,
    'StyleGANDiscriminator': StyleGANDiscriminator,
    'StyleGAN2Generator': StyleGAN2Generator,
    'StyleGAN2Discriminator': StyleGAN2Discriminator,
    'StyleGAN3Generator': StyleGAN3Generator,
    'GHFeatEncoder': GHFeatEncoder,
    # The next two entries register the `build_model` attribute of the class
    # rather than the class itself.
    'PerceptualModel': PerceptualModel.build_model,
    'InceptionModel': InceptionModel.build_model,
    'EG3DGenerator': EG3DGenerator,
    # Registered under the key `EG3DDiscriminator`, although the imported
    # class is named `DualDiscriminator`.
    'EG3DDiscriminator': DualDiscriminator,
    'PiGANGenerator': PiGANGenerator,
    'PiGANDiscriminator': PiGANDiscriminator,
    'VolumeGANGenerator': VolumeGANGenerator,
    'VolumeGANDiscriminator': VolumeGANDiscriminator,
    'EG3DGeneratorFV': EG3DGeneratorFV,
    'BEV3DGenerator': BEV3DGenerator,
    'SGBEV3DGenerator': SGBEV3DGenerator,
}
def build_model(model_type, **kwargs):
    """Instantiates the model registered under `model_type`.

    The name is looked up in the module-level `_MODELS` registry and the
    registered callable is invoked with `kwargs` forwarded unchanged.

    Args:
        model_type: Name of the model to build. The lookup is an exact,
            case-sensitive match against the registry keys.
        **kwargs: Keyword arguments passed straight to the registered
            callable.

    Returns:
        Whatever the registered callable returns for the given `kwargs`.

    Raises:
        ValueError: If `model_type` is not a key of the registry. The message
            lists every supported type.
    """
    builder = _MODELS.get(model_type)
    # No registry entry is `None`, so a `None` result can only mean that the
    # requested name is unknown.
    if builder is None:
        raise ValueError(f'Invalid model type: `{model_type}`!\n'
                         f'Types allowed: {list(_MODELS)}.')
    return builder(**kwargs)