import warnings
import os

import torch
from torch import multiprocessing as mp

import stylegan2
from stylegan2 import utils
from stylegan2.external_models import inception, lpips
from stylegan2.metrics import fid, ppl

#----------------------------------------------------------------------------

def get_arg_parser():
    parser = utils.ConfigArgumentParser()

    parser.add_argument(
        '--output', type=str, default=None, metavar='DIR',
        help='Output directory for model weights.'
    )

    #------------------------------------------------------------------------
    # Model options

    parser.add_argument(
        '--channels', nargs='*', type=int,
        default=[32, 32, 64, 128, 256, 512, 512, 512, 512], metavar='CHANNELS',
        help='Specify the channels for each layer (can be overridden for '
             'individual networks with "--g_channels" and "--d_channels"). '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--latent', type=int, default=512, metavar='VALUE',
        help='Size of the prior (noise vector). Default: %(default)s'
    )
    parser.add_argument(
        '--label', type=int, default=0, metavar='VALUE',
        help='Number of unique labels. Unused if not specified.'
    )
    parser.add_argument(
        '--base_shape', nargs=2, type=int, default=(4, 4), metavar='SIZE',
        help='Data shape of the first layer in the generator and the '
             'last layer in the discriminator. Default: %(default)s'
    )
    parser.add_argument(
        '--kernel_size', type=int, default=3, metavar='SIZE',
        help='Size of the conv kernel. Default: %(default)s'
    )
    parser.add_argument(
        '--pad_once', type=utils.bool_type, const=True, nargs='?',
        default=True, metavar='BOOL',
        help='Pad filtered convs only once before the filter instead '
             'of twice. Default: %(default)s'
    )
    parser.add_argument(
        '--pad_mode', type=str, default='constant', metavar='MODE',
        help='Padding mode for conv layers. Default: %(default)s'
    )
    parser.add_argument(
        '--pad_constant', type=float, default=0, metavar='VALUE',
        help='Padding constant for conv layers when `pad_mode` is '
             '\'constant\'. Default: %(default)s'
    )
    parser.add_argument(
        '--filter_pad_mode', type=str, default='constant', metavar='MODE',
        help='Padding mode for filter layers. Default: %(default)s'
    )
    parser.add_argument(
        '--filter_pad_constant', type=float, default=0, metavar='VALUE',
        help='Padding constant for filter layers when `filter_pad_mode` '
             'is \'constant\'. Default: %(default)s'
    )
    parser.add_argument(
        '--filter', nargs='*', type=float, default=[1, 3, 3, 1], metavar='VALUE',
        help='Filter to use whenever FIR is applied. Default: %(default)s'
    )
    parser.add_argument(
        '--weight_scale', type=utils.bool_type, const=True, nargs='?',
        default=True, metavar='BOOL',
        help='Use weight scaling for equalized learning rate. '
             'Default: %(default)s'
    )
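    # Note: the number of entries in --channels sets the number of
    # resolution levels and thereby the image size; each side of the output
    # is base_shape * 2**(len(channels) - 1) (see get_dataset below).
    # E.g. the default 9 entries with --base_shape 4 4 give
    # 4 * 2**8 = 1024, i.e. 1024x1024 images.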

    #------------------------------------------------------------------------
    # Generator options

    parser.add_argument(
        '--g_file', type=str, default=None, metavar='FILE',
        help='Load a generator model from a file instead of constructing '
             'a new one. Disabled unless a file is specified.'
    )
    parser.add_argument(
        '--g_channels', nargs='*', type=int, default=[], metavar='CHANNELS',
        help='Use these channel values for the generator instead of the '
             'values of "--channels".'
    )
    parser.add_argument(
        '--g_skip', type=utils.bool_type, const=True, nargs='?',
        default=True, metavar='BOOL',
        help='Use skip connections for the generator. Default: %(default)s'
    )
    parser.add_argument(
        '--g_resnet', type=utils.bool_type, const=True, nargs='?',
        default=False, metavar='BOOL',
        help='Use resnet connections for the generator. Default: %(default)s'
    )
    parser.add_argument(
        '--g_conv_block_size', type=int, default=2, metavar='VALUE',
        help='Number of layers in a conv block in the generator. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--g_normalize', type=utils.bool_type, const=True, nargs='?',
        default=True, metavar='BOOL',
        help='Normalize conv features for the generator. Default: %(default)s'
    )
    parser.add_argument(
        '--g_fused_conv', type=utils.bool_type, const=True, nargs='?',
        default=True, metavar='BOOL',
        help='Fuse conv & upsample into a transposed conv for the '
             'generator. Default: %(default)s'
    )
    parser.add_argument(
        '--g_activation', type=str, default='leaky:0.2', metavar='ACTIVATION',
        help='The non-linear activation function for the generator. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--g_conv_resample_mode', type=str, default='FIR', metavar='MODE',
        help='Resample mode for upsampling conv layers of the generator. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--g_skip_resample_mode', type=str, default='FIR', metavar='MODE',
        help='Resample mode for skip connection upsamples of the '
             'generator. Default: %(default)s'
    )
    parser.add_argument(
        '--g_lr', type=float, default=2e-3, metavar='VALUE',
        help='The learning rate for the generator. Default: %(default)s'
    )
    parser.add_argument(
        '--g_betas', type=float, nargs=2, default=(0, 0.99), metavar='VALUE',
        help='Beta values for the generator Adam optimizer. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--g_loss', type=str, default='logistic_ns', metavar='LOSS',
        help='Loss function for the generator. Default: %(default)s'
    )
    parser.add_argument(
        '--g_reg', type=str, default='pathreg:2', metavar='REG',
        help='Regularization function for the generator with an optional '
             'weight (:?). Default: %(default)s'
    )
    parser.add_argument(
        '--g_reg_interval', type=int, default=4, metavar='INTERVAL',
        help='Interval at which to regularize the generator. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--g_iter', type=int, default=1, metavar='ITER',
        help='Number of generator iterations per training iteration. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--style_mix', type=float, default=0.9, metavar='PROBABILITY',
        help='The probability of passing more than one latent to the '
             'generator. Default: %(default)s'
    )
    parser.add_argument(
        '--latent_mapping_layers', type=int, default=8, metavar='LAYERS',
        help='The number of layers of the latent mapping network. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--latent_mapping_lr_mul', type=float, default=0.01, metavar='LR_MUL',
        help='The learning rate multiplier for the latent mapping '
             'network. Default: %(default)s'
    )
    parser.add_argument(
        '--normalize_latent', type=utils.bool_type, const=True, nargs='?',
        default=True, metavar='BOOL',
        help='Normalize latent inputs. Default: %(default)s'
    )
    parser.add_argument(
        '--modulate_rgb', type=utils.bool_type, const=True, nargs='?',
        default=True, metavar='BOOL',
        help='Modulate RGB layers (use style for output layers of the '
             'generator). Default: %(default)s'
    )
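    # Note: loss and regularizer arguments take an optional ':'-separated
    # weight, so the default '--g_reg pathreg:2' presumably selects path
    # length regularization with weight 2, applied lazily every
    # --g_reg_interval training iterations rather than on every step.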

    #------------------------------------------------------------------------
    # Discriminator options

    parser.add_argument(
        '--d_file', type=str, default=None, metavar='FILE',
        help='Load a discriminator model from a file instead of '
             'constructing a new one. Disabled unless a file is specified.'
    )
    parser.add_argument(
        '--d_channels', nargs='*', type=int, default=[], metavar='CHANNELS',
        help='Use these channel values for the discriminator instead of '
             'the values of "--channels".'
    )
    parser.add_argument(
        '--d_skip', type=utils.bool_type, const=True, nargs='?',
        default=False, metavar='BOOL',
        help='Use skip connections for the discriminator. Default: %(default)s'
    )
    parser.add_argument(
        '--d_resnet', type=utils.bool_type, const=True, nargs='?',
        default=True, metavar='BOOL',
        help='Use resnet connections for the discriminator. Default: %(default)s'
    )
    parser.add_argument(
        '--d_conv_block_size', type=int, default=2, metavar='VALUE',
        help='Number of layers in a conv block in the discriminator. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--d_fused_conv', type=utils.bool_type, const=True, nargs='?',
        default=True, metavar='BOOL',
        help='Fuse conv & downsample into a strided conv for the '
             'discriminator. Default: %(default)s'
    )
    parser.add_argument(
        '--group_size', type=int, default=4, metavar='VALUE',
        help='Size of the groups in the batch std layer. Default: %(default)s'
    )
    parser.add_argument(
        '--d_activation', type=str, default='leaky:0.2', metavar='ACTIVATION',
        help='The non-linear activation function for the discriminator. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--d_conv_resample_mode', type=str, default='FIR', metavar='MODE',
        help='Resample mode for downsampling conv layers of the '
             'discriminator. Default: %(default)s'
    )
    parser.add_argument(
        '--d_skip_resample_mode', type=str, default='FIR', metavar='MODE',
        help='Resample mode for skip connection downsamples of the '
             'discriminator. Default: %(default)s'
    )
    parser.add_argument(
        '--d_loss', type=str, default='logistic', metavar='LOSS',
        help='Loss function for the discriminator. Default: %(default)s'
    )
    parser.add_argument(
        '--d_reg', type=str, default='r1:10', metavar='REG',
        help='Regularization function for the discriminator with an '
             'optional weight (:?). Default: %(default)s'
    )
    parser.add_argument(
        '--d_reg_interval', type=int, default=16, metavar='INTERVAL',
        help='Interval at which to regularize the discriminator. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--d_iter', type=int, default=1, metavar='ITER',
        help='Number of discriminator iterations per training iteration. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--d_lr', type=float, default=2e-3, metavar='VALUE',
        help='The learning rate for the discriminator. Default: %(default)s'
    )
    parser.add_argument(
        '--d_betas', type=float, nargs=2, default=(0, 0.99), metavar='VALUE',
        help='Beta values for the discriminator Adam optimizer. '
             'Default: %(default)s'
    )
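    # Note: the defaults above mirror the StyleGAN2 "config F" setup: skip
    # connections in the generator (--g_skip) and residual connections in
    # the discriminator (--d_resnet), with R1 regularization ('r1:10',
    # i.e. weight 10) applied every 16 iterations.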

    #------------------------------------------------------------------------
    # Training options

    parser.add_argument(
        '--iterations', type=int, default=1000000, metavar='ITERATIONS',
        help='Number of iterations to train for. Default: %(default)s'
    )
    parser.add_argument(
        '--gpu', type=int, default=[], nargs='*', metavar='DEVICE_ID',
        help='The cuda device(s) to use. Example: "--gpu 0 1" will train '
             'on GPU 0 and GPU 1. Default: Only use CPU'
    )
    parser.add_argument(
        '--distributed', type=utils.bool_type, const=True, nargs='?',
        default=True, metavar='BOOL',
        help='When more than one gpu device is passed, automatically '
             'start one process for each device and give it the correct '
             'distributed args (rank, world_size etc.). Disable this if '
             'you want training to be performed with only one process '
             'using the DataParallel module. Default: %(default)s'
    )
    parser.add_argument(
        '--rank', type=int, default=None,
        help='Rank for distributed training.'
    )
    parser.add_argument(
        '--world_size', type=int, default=None,
        help='World size for distributed training.'
    )
    parser.add_argument(
        '--master_addr', type=str, default=None,
        help='Address for distributed training.'
    )
    parser.add_argument(
        '--master_port', type=str, default=None,
        help='Port for distributed training.'
    )
    parser.add_argument(
        '--batch_size', type=int, default=32, metavar='VALUE',
        help='Size of each batch. Default: %(default)s'
    )
    parser.add_argument(
        '--device_batch_size', type=int, default=4, metavar='VALUE',
        help='Maximum number of items to fit on a single device at a '
             'time. Default: %(default)s'
    )
    parser.add_argument(
        '--g_reg_batch_size', type=int, default=16, metavar='VALUE',
        help='Size of each batch used to regularize the generator. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--g_reg_device_batch_size', type=int, default=2, metavar='VALUE',
        help='Maximum number of items to fit on a single device when '
             'regularizing the generator. Default: %(default)s'
    )
    parser.add_argument(
        '--half', type=utils.bool_type, const=True, nargs='?',
        default=False, metavar='BOOL',
        help='Use mixed precision training. Default: %(default)s'
    )
    parser.add_argument(
        '--resume', type=utils.bool_type, const=True, nargs='?',
        default=False, metavar='BOOL',
        help='Resume from the latest saved checkpoint in the '
             'checkpoint_dir. This loads all previous training settings '
             'except for the dataset options, device args (--gpu ...), '
             'distributed training args (--rank, --world_size etc.), '
             'metrics and logging.'
    )
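    # Note: when --batch_size exceeds --device_batch_size times the number
    # of devices, each training step presumably accumulates gradients over
    # several device-sized sub-batches. E.g. the defaults on two GPUs give
    # 32 / (4 * 2) = 4 accumulation steps per iteration.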

    #------------------------------------------------------------------------
    # Extra metric options

    parser.add_argument(
        '--fid_interval', type=int, default=None, metavar='INTERVAL',
        help='If specified, evaluate the FID metric with this interval.'
    )
    parser.add_argument(
        '--ppl_interval', type=int, default=None, metavar='INTERVAL',
        help='If specified, evaluate the PPL metric with this interval.'
    )
    parser.add_argument(
        '--ppl_ffhq_crop', type=utils.bool_type, const=True, nargs='?',
        default=False, metavar='BOOL',
        help='Crop images evaluated for PPL with crop values for FFHQ. '
             'Default: %(default)s'
    )

    #------------------------------------------------------------------------
    # Data options

    parser.add_argument(
        '--pixel_min', type=float, default=-1, metavar='VALUE',
        help='Minimum of the value range of pixels in generated images. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--pixel_max', type=float, default=1, metavar='VALUE',
        help='Maximum of the value range of pixels in generated images. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--data_channels', type=int, default=3, choices=[1, 3],
        metavar='CHANNELS',
        help='Number of channels in the data. Default: 3 (RGB)'
    )
    parser.add_argument(
        '--data_dir', type=str, default=None,
        help='The root directory of the dataset. This argument is required!'
    )
    parser.add_argument(
        '--data_resize', type=utils.bool_type, const=True, nargs='?',
        default=False, metavar='BOOL',
        help='Resize data to fit the input size of the discriminator. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--mirror_augment', type=utils.bool_type, const=True, nargs='?',
        default=False, metavar='BOOL',
        help='Use random horizontal flipping for data images. '
             'Default: %(default)s'
    )
    parser.add_argument(
        '--data_workers', type=int, default=4, metavar='WORKERS',
        help='Number of worker processes that handle data loading. '
             'Default: %(default)s'
    )

    #------------------------------------------------------------------------
    # Logging options

    parser.add_argument(
        '--checkpoint_dir', type=str, default=None, metavar='DIR',
        help='If specified, save checkpoints to this directory.'
    )
    parser.add_argument(
        '--checkpoint_interval', type=int, default=10000, metavar='INTERVAL',
        help='Save checkpoints with this interval. Default: %(default)s'
    )
    parser.add_argument(
        '--tensorboard_log_dir', type=str, default=None, metavar='DIR',
        help='Log to this tensorboard directory if specified.'
    )
    parser.add_argument(
        '--tensorboard_image_interval', type=int, default=None,
        metavar='INTERVAL',
        help='Log images to tensorboard with this interval if specified.'
    )
    parser.add_argument(
        '--tensorboard_image_size', type=int, default=256, metavar='VALUE',
        help='Size of images logged to tensorboard. Default: %(default)s'
    )

    return parser
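
#----------------------------------------------------------------------------
# Example invocation (paths and device ids below are hypothetical; all flags
# are defined in get_arg_parser above):
#
#   python run_training.py --data_dir ./datasets/ffhq --gpu 0 1 \
#       --checkpoint_dir ./checkpoints --fid_interval 10000 \
#       --output ./trained_models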

#----------------------------------------------------------------------------

def get_dataset(args):
    assert args.data_dir, '--data_dir has to be specified.'
    # Infer the image resolution from the number of layers: each side of
    # `base_shape` is doubled once per layer after the first.
    height, width = [
        shape * 2 ** (len(args.d_channels or args.channels) - 1)
        for shape in args.base_shape
    ]
    dataset = utils.ImageFolder(
        args.data_dir,
        mirror=args.mirror_augment,
        pixel_min=args.pixel_min,
        pixel_max=args.pixel_max,
        height=height,
        width=width,
        resize=args.data_resize,
        grayscale=args.data_channels == 1
    )
    assert len(dataset), 'No images found at {}'.format(args.data_dir)
    return dataset

#----------------------------------------------------------------------------

def get_models(args):
    common_kwargs = dict(
        data_channels=args.data_channels,
        base_shape=args.base_shape,
        conv_filter=args.filter,
        skip_filter=args.filter,
        kernel_size=args.kernel_size,
        conv_pad_mode=args.pad_mode,
        conv_pad_constant=args.pad_constant,
        filter_pad_mode=args.filter_pad_mode,
        filter_pad_constant=args.filter_pad_constant,
        pad_once=args.pad_once,
        weight_scale=args.weight_scale
    )
    if args.g_file:
        G = stylegan2.models.load(args.g_file)
        assert isinstance(G, stylegan2.models.Generator), \
            '`--g_file` should specify a generator model, ' \
            'found {}'.format(type(G))
    else:
        G_M = stylegan2.models.GeneratorMapping(
            latent_size=args.latent,
            label_size=args.label,
            num_layers=args.latent_mapping_layers,
            hidden=args.latent,
            activation=args.g_activation,
            normalize_input=args.normalize_latent,
            lr_mul=args.latent_mapping_lr_mul,
            weight_scale=args.weight_scale
        )
        G_S = stylegan2.models.GeneratorSynthesis(
            channels=args.g_channels or args.channels,
            latent_size=args.latent,
            demodulate=args.g_normalize,
            modulate_data_out=args.modulate_rgb,
            conv_block_size=args.g_conv_block_size,
            activation=args.g_activation,
            conv_resample_mode=args.g_conv_resample_mode,
            skip_resample_mode=args.g_skip_resample_mode,
            resnet=args.g_resnet,
            skip=args.g_skip,
            fused_resample=args.g_fused_conv,
            **common_kwargs
        )
        G = stylegan2.models.Generator(G_mapping=G_M, G_synthesis=G_S)
    if args.d_file:
        D = stylegan2.models.load(args.d_file)
        assert isinstance(D, stylegan2.models.Discriminator), \
            '`--d_file` should specify a discriminator model, ' \
            'found {}'.format(type(D))
    else:
        D = stylegan2.models.Discriminator(
            channels=args.d_channels or args.channels,
            label_size=args.label,
            conv_block_size=args.d_conv_block_size,
            activation=args.d_activation,
            conv_resample_mode=args.d_conv_resample_mode,
            skip_resample_mode=args.d_skip_resample_mode,
            mbstd_group_size=args.group_size,
            resnet=args.d_resnet,
            skip=args.d_skip,
            fused_resample=args.d_fused_conv,
            **common_kwargs
        )
    assert len(G.G_synthesis.channels) == len(D.channels), \
        'While the number of channels for each layer can differ ' \
        'between generator and discriminator, the number of layers ' \
        'has to be the same. Received {} generator layers and {} ' \
        'discriminator layers.'.format(
            len(G.G_synthesis.channels), len(D.channels))
    return G, D
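# Note: "--g_channels" and "--d_channels" may differ in width but must have
# the same length (enforced by the assert above). E.g. passing
# '--g_channels 32 64 128 256 512 --d_channels 64 128 256 512 512' gives two
# 5-layer networks, which with the default --base_shape of 4 4 means
# 4 * 2**4 = 64, i.e. 64x64 images.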

#----------------------------------------------------------------------------

def get_trainer(args):
    dataset = get_dataset(args)
    if args.resume and stylegan2.train._find_checkpoint(args.checkpoint_dir):
        trainer = stylegan2.train.Trainer.load_checkpoint(
            args.checkpoint_dir,
            dataset,
            device=args.gpu,
            rank=args.rank,
            world_size=args.world_size,
            master_addr=args.master_addr,
            master_port=args.master_port,
            tensorboard_log_dir=args.tensorboard_log_dir
        )
    else:
        G, D = get_models(args)
        trainer = stylegan2.train.Trainer(
            G=G,
            D=D,
            latent_size=args.latent,
            dataset=dataset,
            device=args.gpu,
            batch_size=args.batch_size,
            device_batch_size=args.device_batch_size,
            label_size=args.label,
            data_workers=args.data_workers,
            G_loss=args.g_loss,
            D_loss=args.d_loss,
            G_reg=args.g_reg,
            G_reg_interval=args.g_reg_interval,
            G_opt_kwargs={'lr': args.g_lr, 'betas': args.g_betas},
            G_reg_batch_size=args.g_reg_batch_size,
            G_reg_device_batch_size=args.g_reg_device_batch_size,
            D_reg=args.d_reg,
            D_reg_interval=args.d_reg_interval,
            D_opt_kwargs={'lr': args.d_lr, 'betas': args.d_betas},
            style_mix_prob=args.style_mix,
            G_iter=args.g_iter,
            D_iter=args.d_iter,
            tensorboard_log_dir=args.tensorboard_log_dir,
            checkpoint_dir=args.checkpoint_dir,
            checkpoint_interval=args.checkpoint_interval,
            half=args.half,
            rank=args.rank,
            world_size=args.world_size,
            master_addr=args.master_addr,
            master_port=args.master_port
        )
    # Only the main process (rank 0 or non-distributed) evaluates metrics.
    if args.fid_interval and not args.rank:
        fid_model = inception.InceptionV3FeatureExtractor(
            pixel_min=args.pixel_min, pixel_max=args.pixel_max)
        trainer.register_metric(
            name='FID (299x299)',
            eval_fn=fid.FID(
                trainer.Gs,
                trainer.prior_generator,
                dataset=dataset,
                fid_model=fid_model,
                fid_size=299,
                reals_batch_size=64
            ),
            interval=args.fid_interval
        )
        trainer.register_metric(
            name='FID',
            eval_fn=fid.FID(
                trainer.Gs,
                trainer.prior_generator,
                dataset=dataset,
                fid_model=fid_model,
                fid_size=None
            ),
            interval=args.fid_interval
        )
    if args.ppl_interval and not args.rank:
        lpips_model = lpips.LPIPS_VGG16(
            pixel_min=args.pixel_min, pixel_max=args.pixel_max)
        crop = None
        if args.ppl_ffhq_crop:
            crop = ppl.PPL.FFHQ_CROP
        trainer.register_metric(
            name='PPL_end',
            eval_fn=ppl.PPL(
                trainer.Gs,
                trainer.prior_generator,
                full_sampling=False,
                crop=crop,
                lpips_model=lpips_model,
                lpips_size=256
            ),
            interval=args.ppl_interval
        )
        trainer.register_metric(
            name='PPL_full',
            eval_fn=ppl.PPL(
                trainer.Gs,
                trainer.prior_generator,
                full_sampling=True,
                crop=crop,
                lpips_model=lpips_model,
                lpips_size=256
            ),
            interval=args.ppl_interval
        )
    if args.tensorboard_image_interval:
        for static in [True, False]:
            for trunc in [0.5, 0.7, 1.0]:
                name = 'static' if static else 'random'
                name += '/trunc_{:.1f}'.format(trunc)
                trainer.add_tensorboard_image_logging(
                    name=name,
                    num_images=4,
                    interval=args.tensorboard_image_interval,
                    resize=args.tensorboard_image_size,
                    seed=1234567890 if static else None,
                    truncation_psi=trunc,
                    pixel_min=args.pixel_min,
                    pixel_max=args.pixel_max
                )
    return trainer

#----------------------------------------------------------------------------

def run(args):
    if not args.rank:
        if not (args.checkpoint_dir or args.output):
            warnings.warn(
                'Neither an output path nor a checkpoint dir has been '
                'given. Weights from this training run will never '
                'be saved.'
            )
        if args.output:
            assert os.path.isdir(args.output) or not os.path.splitext(args.output)[-1], \
                '--output argument should specify a directory, not a file.'
    trainer = get_trainer(args)
    trainer.train(iterations=args.iterations)
    if not args.rank and args.output:
        print('Saving models to {}'.format(args.output))
        if not os.path.exists(args.output):
            os.makedirs(args.output)
        for model_name in ['G', 'D', 'Gs']:
            getattr(trainer, model_name).save(
                os.path.join(args.output, model_name + '.pth'))

#----------------------------------------------------------------------------

def run_distributed(rank, args):
    args.rank = rank
    args.world_size = len(args.gpu)
    args.gpu = args.gpu[rank]
    args.master_addr = args.master_addr or '127.0.0.1'
    args.master_port = args.master_port or '23456'
    run(args)

#----------------------------------------------------------------------------

def main():
    parser = get_arg_parser()
    args = parser.parse_args()
    if len(args.gpu) > 1 and args.distributed:
        assert args.rank is None and args.world_size is None, \
            'When --distributed is enabled (default) the rank and ' \
            'world size cannot be given as this is set up automatically. ' \
            'Use "--distributed 0" to disable automatic setup of ' \
            'distributed training.'
        mp.spawn(run_distributed, nprocs=len(args.gpu), args=(args,))
    else:
        run(args)

#----------------------------------------------------------------------------

if __name__ == '__main__':
    main()
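
#----------------------------------------------------------------------------
# A minimal sketch of reloading the saved weights for inference, assuming the
# Generator forward pass accepts a batch of latents (its call signature is
# not shown in this file; the path below is hypothetical):
#
#   import torch
#   import stylegan2
#
#   G = stylegan2.models.load('./trained_models/Gs.pth')
#   G.eval()
#   with torch.no_grad():
#       latents = torch.randn(1, 512)  # matches the default --latent size
#       images = G(latents)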