# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import sys
import os
import click
import re
import json
import glob
import tempfile
import torch
import dnnlib
import hydra

from datetime import date
from training import training_loop
from metrics import metric_main
from torch_utils import training_stats, custom_ops, distributed_utils
from torch_utils.distributed_utils import get_init_file, get_shared_folder
from omegaconf import DictConfig, OmegaConf
#----------------------------------------------------------------------------

class UserError(Exception):
    pass

#----------------------------------------------------------------------------

def setup_training_loop_kwargs(cfg):
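    """Translate the Hydra config `cfg` into the keyword arguments expected by
    training_loop.training_loop(). Returns (desc, args), where `desc` is a
    human-readable run descriptor used to name the output directory and `args`
    is an OmegaConf dict of training options."""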
    args = OmegaConf.create({})

    # ------------------------------------------
    # General options: gpus, snap, metrics, seed
    # ------------------------------------------
    args.rank = 0
    args.gpu = 0
    args.num_gpus = torch.cuda.device_count() if cfg.gpus is None else cfg.gpus
    args.nodes = cfg.nodes if cfg.nodes is not None else 1
    args.world_size = 1
    args.dist_url = 'env://'
    args.launcher = cfg.launcher
    args.partition = cfg.partition
    args.comment = cfg.comment
    args.timeout = 4320 if cfg.timeout is None else cfg.timeout
    args.job_dir = ''

    if cfg.snap is None:
        cfg.snap = 50
    assert isinstance(cfg.snap, int)
    if cfg.snap < 1:
        raise UserError('snap must be at least 1')
    args.image_snapshot_ticks = cfg.imgsnap
    args.network_snapshot_ticks = cfg.snap
    if hasattr(cfg, 'ucp'):
        args.update_cam_prior_ticks = cfg.ucp

    if cfg.metrics is None:
        cfg.metrics = ['fid50k_full']
    cfg.metrics = list(cfg.metrics)
    if not all(metric_main.is_valid_metric(metric) for metric in cfg.metrics):
        raise UserError('\n'.join(['metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    args.metrics = cfg.metrics

    if cfg.seed is None:
        cfg.seed = 0
    assert isinstance(cfg.seed, int)
    args.random_seed = cfg.seed

    # -----------------------------------
    # Dataset: data, cond, subset, mirror
    # -----------------------------------
    assert cfg.data is not None
    assert isinstance(cfg.data, str)
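    # ImageFolderDataset follows the StyleGAN2-ADA data convention: `path` may
    # be a directory of images or an uncompressed zip, with optional labels
    # stored in a dataset.json file at its root (see the cond handling below).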
args.update({"training_set_kwargs": dict(class_name='training.dataset.ImageFolderDataset', path=cfg.data, resolution=cfg.resolution, use_labels=True, max_size=None, xflip=False)}) | |
args.update({"data_loader_kwargs": dict(pin_memory=True, num_workers=3, prefetch_factor=2)}) | |
args.generation_with_image = getattr(cfg, 'generate_with_image', False) | |
try: | |
training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs) # subclass of training.dataset.Dataset | |
args.training_set_kwargs.resolution = training_set.resolution # be explicit about resolution | |
args.training_set_kwargs.use_labels = training_set.has_labels # be explicit about labels | |
args.training_set_kwargs.max_size = len(training_set) # be explicit about dataset size | |
desc = training_set.name | |
del training_set # conserve memory | |
except IOError as err: | |
raise UserError(f'data: {err}') | |
if cfg.cond is None: | |
cfg.cond = False | |
assert isinstance(cfg.cond, bool) | |
if cfg.cond: | |
if not args.training_set_kwargs.use_labels: | |
raise UserError('cond=True requires labels specified in dataset.json') | |
desc += '-cond' | |
else: | |
args.training_set_kwargs.use_labels = False | |
if cfg.subset is not None: | |
assert isinstance(cfg.subset, int) | |
if not 1 <= cfg.subset <= args.training_set_kwargs.max_size: | |
raise UserError(f'subset must be between 1 and {args.training_set_kwargs.max_size}') | |
desc += f'-subset{cfg.subset}' | |
if cfg.subset < args.training_set_kwargs.max_size: | |
args.training_set_kwargs.max_size = cfg.subset | |
args.training_set_kwargs.random_seed = args.random_seed | |
if cfg.mirror is None: | |
cfg.mirror = False | |
assert isinstance(cfg.mirror, bool) | |
if cfg.mirror: | |
desc += '-mirror' | |
args.training_set_kwargs.xflip = True | |
    # --------------------------------------------
    # Base config: spec, model, gamma, kimg, batch
    # --------------------------------------------
    if cfg.auto:
        cfg.spec.name = 'auto'
    desc += f'-{cfg.spec.name}'
    desc += f'-{cfg.model.name}'
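    # The 'auto' spec derives hyperparameters from the dataset, mirroring
    # StyleGAN2-ADA's auto config: R1 gamma scales as resolution^2 over the
    # minibatch size, and the EMA half-life (10 kimg at batch 32) scales
    # linearly with the minibatch.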
    if cfg.spec.name == 'auto':
        res = args.training_set_kwargs.resolution
        cfg.spec.fmaps = 1 if res >= 512 else 0.5
        cfg.spec.lrate = 0.002 if res >= 1024 else 0.0025
        cfg.spec.gamma = 0.0002 * (res ** 2) / cfg.spec.mb  # heuristic formula
        cfg.spec.ema = cfg.spec.mb * 10 / 32

    if getattr(cfg.spec, 'lrate_disc', None) is None:
        cfg.spec.lrate_disc = cfg.spec.lrate  # use the same learning rate for the discriminator
    # model (generator, discriminator)
    args.update({"G_kwargs": dict(**cfg.model.G_kwargs)})
    args.update({"D_kwargs": dict(**cfg.model.D_kwargs)})
    args.update({"G_opt_kwargs": dict(class_name='torch.optim.Adam', lr=cfg.spec.lrate, betas=[0, 0.99], eps=1e-8)})
    args.update({"D_opt_kwargs": dict(class_name='torch.optim.Adam', lr=cfg.spec.lrate_disc, betas=[0, 0.99], eps=1e-8)})
    args.update({"loss_kwargs": dict(class_name='training.loss.StyleGAN2Loss', r1_gamma=cfg.spec.gamma, **cfg.model.loss_kwargs)})

    if cfg.spec.name == 'cifar':
        args.loss_kwargs.pl_weight = 0  # disable path length regularization
        args.loss_kwargs.style_mixing_prob = 0  # disable style mixing
        args.D_kwargs.architecture = 'orig'  # disable residual skip connections

    # kimg and batch config
    args.spec = cfg.spec  # just keep the dict
    args.total_kimg = cfg.spec.kimg
    args.batch_size = cfg.spec.mb
    args.batch_gpu = cfg.spec.mbstd
    args.ema_kimg = cfg.spec.ema
    args.ema_rampup = cfg.spec.ramp
    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------
    if cfg.aug is None:
        cfg.aug = 'ada'
    else:
        assert isinstance(cfg.aug, str)
        desc += f'-{cfg.aug}'

    if cfg.aug == 'ada':
        args.ada_target = 0.6
    elif cfg.aug == 'noaug':
        pass
    elif cfg.aug == 'fixed':
        if cfg.p is None:
            raise UserError(f'--aug={cfg.aug} requires specifying --p')
    else:
        raise UserError(f'--aug={cfg.aug} not supported')

    if cfg.p is not None:
        assert isinstance(cfg.p, float)
        if cfg.aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= cfg.p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += f'-p{cfg.p:g}'
        args.augment_p = cfg.p

    if cfg.target is not None:
        assert isinstance(cfg.target, float)
        if cfg.aug != 'ada':
            raise UserError('--target can only be specified with --aug=ada')
        if not 0 <= cfg.target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{cfg.target:g}'
        args.ada_target = cfg.target

    assert cfg.augpipe is None or isinstance(cfg.augpipe, str)
    if cfg.augpipe is None:
        cfg.augpipe = 'bgc'
    else:
        if cfg.aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += f'-{cfg.augpipe}'
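    # Preset names compose single-letter stages in the StyleGAN2-ADA fashion:
    # b(lit), g(eom), c(olor), f(ilter), n(oise), c(utout); e.g. 'bgc' enables
    # blit + geometric + color transforms.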
    augpipe_specs = {
        'blit':   dict(xflip=1, rotate90=1, xint=1),
        'geom':   dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':  dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter': dict(imgfilter=1),
        'noise':  dict(noise=1),
        'cutout': dict(cutout=1),
        'bgc0':   dict(xint=1, scale=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bg':     dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':    dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf':   dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn':  dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
    }
    assert cfg.augpipe in augpipe_specs
    if cfg.aug != 'noaug':
        args.update({"augment_kwargs": dict(class_name='training.augment.AugmentPipe', **augpipe_specs[cfg.augpipe])})
    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------
    # Pretrained source networks hosted on the NVIDIA CDN (StyleGAN2-ADA).
    resume_specs = {
        'ffhq256':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':    'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':  'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert cfg.resume is None or isinstance(cfg.resume, str)
    if cfg.resume is None:
        cfg.resume = 'noresume'
    elif cfg.resume == 'noresume':
        desc += '-noresume'
    elif cfg.resume in resume_specs:
        desc += f'-resume{cfg.resume}'
        args.resume_pkl = resume_specs[cfg.resume]  # predefined url
    else:
        desc += '-resumecustom'
        args.resume_pkl = cfg.resume  # custom path or url

    if cfg.resume != 'noresume':
        args.ada_kimg = 100  # make ADA react faster at the beginning
        args.ema_rampup = None  # disable EMA rampup

    if cfg.freezed is not None:
        assert isinstance(cfg.freezed, int)
        if not cfg.freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{cfg.freezed:d}'
        args.D_kwargs.block_kwargs.freeze_layers = cfg.freezed
    # -------------------------------------------------
    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------
    args.num_fp16_res = cfg.num_fp16_res
    if cfg.fp32 is None:
        cfg.fp32 = False
    assert isinstance(cfg.fp32, bool)
    if cfg.fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if cfg.nhwc is None:
        cfg.nhwc = False
    assert isinstance(cfg.nhwc, bool)
    if cfg.nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if cfg.nobench is None:
        cfg.nobench = False
    assert isinstance(cfg.nobench, bool)
    if cfg.nobench:
        args.cudnn_benchmark = False

    if cfg.allow_tf32 is None:
        cfg.allow_tf32 = False
    assert isinstance(cfg.allow_tf32, bool)
    args.allow_tf32 = cfg.allow_tf32

    if cfg.workers is not None:
        assert isinstance(cfg.workers, int)
        if not cfg.workers >= 1:
            raise UserError('--workers must be at least 1')
        args.data_loader_kwargs.num_workers = cfg.workers

    args.debug = cfg.debug
    if getattr(cfg, "prefix", None) is not None:
        desc = cfg.prefix + '-' + desc
    return desc, args
#----------------------------------------------------------------------------

def subprocess_fn(rank, args):
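    """Per-process entry point: set up logging, initialize torch.distributed,
    and run the training loop on this rank."""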
    if not args.debug:
        dnnlib.util.Logger(file_name=os.path.join(args.run_dir, 'log.txt'), file_mode='a', should_flush=True)

    # Init torch.distributed.
    distributed_utils.init_distributed_mode(rank, args)
    if args.rank != 0:
        custom_ops.verbosity = 'none'  # suppress custom-op compilation logs on non-zero ranks

    # Execute training loop.
    training_loop.training_loop(**args)

#----------------------------------------------------------------------------
class CommaSeparatedList(click.ParamType):
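    """click parameter type that parses a comma-separated string into a list;
    None, 'none', and the empty string all convert to []."""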
    name = 'list'

    def convert(self, value, param, ctx):
        _ = param, ctx
        if value is None or value.lower() == 'none' or value == '':
            return []
        return value.split(',')

# Hydra entry point: `main()` is called with no arguments below, so it needs
# the @hydra.main decorator to receive `cfg`. The config_path/config_name here
# assume the repo's usual conf/config.yaml layout.
@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
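    """Build training options from the config, prepare the run directory,
    and spawn one training process per GPU."""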
    outdir = cfg.outdir

    # Set up training options.
    run_desc, args = setup_training_loop_kwargs(cfg)

    # Pick output directory.
    prev_run_dirs = []
    if os.path.isdir(outdir):
        prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
    if cfg.resume_run is None:
        prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
        prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
        cur_run_id = max(prev_run_ids, default=-1) + 1
    else:
        cur_run_id = cfg.resume_run
    args.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}')
    print(outdir, args.run_dir)

    if cfg.resume_run is not None:
        # Resume from the latest snapshot in the run directory, recovering the
        # starting image count from the 'network-snapshot-<kimg>.pkl' filename.
        pkls = sorted(glob.glob(args.run_dir + '/network*.pkl'))
        if len(pkls) > 0:
            args.resume_pkl = pkls[-1]
            args.resume_start = int(args.resume_pkl.split('-')[-1][:-4]) * 1000
        else:
            args.resume_start = 0
    # Print options.
    print()
    print('Training options:')
    print(OmegaConf.to_yaml(args))
    print()
    print(f'Output directory:   {args.run_dir}')
    print(f'Training data:      {args.training_set_kwargs.path}')
    print(f'Training duration:  {args.total_kimg} kimg')
    print(f'Number of images:   {args.training_set_kwargs.max_size}')
    print(f'Image resolution:   {args.training_set_kwargs.resolution}')
    print(f'Conditional model:  {args.training_set_kwargs.use_labels}')
    print(f'Dataset x-flips:    {args.training_set_kwargs.xflip}')
    print()

    # Dry run?
    if cfg.dry_run:
        print('Dry run; exiting.')
        return

    # Create output directory.
    print('Creating output directory...')
    if not os.path.exists(args.run_dir):
        os.makedirs(args.run_dir)
    with open(os.path.join(args.run_dir, 'training_options.yaml'), 'wt') as fp:
        OmegaConf.save(config=args, f=fp)  # write through the open handle instead of reopening by name
    # Launch processes.
    print('Launching processes...')
    if (args.launcher == 'spawn') and (args.num_gpus > 1):
        args.dist_url = distributed_utils.get_init_file().as_uri()
        torch.multiprocessing.set_start_method('spawn')
        torch.multiprocessing.spawn(fn=subprocess_fn, args=(args,), nprocs=args.num_gpus)
    else:
        subprocess_fn(rank=0, args=args)

#----------------------------------------------------------------------------
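# Example launch (hypothetical paths/values; the flag names follow the cfg
# keys referenced above, assuming this file is saved as run_train.py):
#
#   python run_train.py outdir=runs data=/path/to/dataset.zip gpus=8 aug=ada resume=ffhq512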
if __name__ == "__main__":
    if os.getenv('SLURM_ARGS') is not None:
        # Deprecated launcher for SLURM jobs. SLURM_ARGS is expected to hold a
        # Python expression (evaluated with eval for backward compatibility).
        slurm_arg = eval(os.getenv('SLURM_ARGS'))
        all_args = sys.argv[1:]
        print(slurm_arg)
        print(all_args)

        from launcher import launch
        launch(slurm_arg, all_args)
    else:
        main()  # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------