|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Train a GAN using the techniques described in the paper |
|
"Training Generative Adversarial Networks with Limited Data".""" |
|
import sys |
|
import os |
|
|
|
sys.path.insert(1, os.path.join(sys.path[0], "..")) |
|
import click |
|
import re |
|
import json |
|
import tempfile |
|
import torch |
|
import dnnlib |
|
|
|
import numpy as np |
|
|
|
import parser |
|
|
|
from training import training_loop |
|
from metrics import metric_main |
|
from torch_utils import training_stats |
|
from torch_utils import custom_ops |
|
|
|
|
|
|
|
|
|
|
|
class UserError(Exception):
    """Raised when a user-supplied command-line option fails validation."""
|
|
|
|
|
|
|
|
|
|
|
def setup_training_loop_kwargs(
    # General options (not included in desc).
    exp_name=None,  # Explicit run-directory name: <str>, default = auto-numbered
    slurm=None,  # Launch through SLURM multi-node setup: <bool>
    gpus=None,  # Number of GPUs per node: <int>, default = 1 gpu
    nodes=None,  # Number of nodes: <int>, default = 1 node
    snap=None,  # Snapshot interval: <int>, default = 50 ticks
    metrics=None,  # List of metric names: [], ['fid50k_full'] (default), ...
    seed=None,  # Random seed: <int>, default = 0
    # Dataset.
    data=None,  # Training dataset (required): <path>
    class_cond=None,  # Train a class-conditional model: <bool>
    subset=None,  # Train with only N images: <int>, default = all
    mirror=None,  # Augment dataset with x-flips: <bool>, default = False
    # Instance conditioning.
    instance_cond=None,  # Condition the model on per-image instance features: <bool>
    feature_augmentation=None,  # Feature-augmentation mode forwarded to the dataset
    root_feats=None,  # Path to precomputed instance features
    root_nns=None,  # Path to precomputed nearest neighbors
    label_dim=None,  # Number of class labels in the dataset
    # Base config.
    cfg=None,  # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
    lrate=None,  # Override learning rate for both G and D: <float>
    gamma=None,  # Override R1 gamma: <float>
    kimg=None,  # Override training duration: <int>
    batch=None,  # Override total batch size: <int>
    num_channel_g=None,  # Override generator channel base: <int>
    num_channel_d=None,  # Override discriminator channel base: <int>
    channel_max_g=None,  # Override generator max channel count: <int>
    channel_max_d=None,  # Override discriminator max channel count: <int>
    hidden_dim_c=None,  # Embedding width for class conditioning: <int>
    hidden_dim_h=None,  # Embedding width for instance-feature conditioning: <int>
    es_patience=None,  # Early-stopping patience, forwarded to the training loop
    # Discriminator augmentation.
    aug=None,  # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
    p=None,  # Augmentation probability for --aug=fixed: <float>
    target=None,  # ADA target value for --aug=ada: <float>, default = 0.6
    augpipe=None,  # Augmentation pipeline: 'blit', 'geom', ..., 'bgc' (default), ...
    # Transfer learning.
    resume=None,  # Load previous network: 'noresume' (default), <name>, <pkl>
    freezed=None,  # Freeze-D: <int>, default = 0 discriminator layers
    # Performance options (not included in desc).
    fp32=None,  # Disable mixed-precision training: <bool>, default = False
    nhwc=None,  # Use NHWC memory format with FP16: <bool>, default = False
    allow_tf32=None,  # Allow PyTorch to use TF32 internally: <bool>, default = False
    nobench=None,  # Disable cuDNN benchmarking: <bool>, default = False
    workers=None,  # Override number of DataLoader workers: <int>
    **kwargs,  # Any extra CLI options are accepted and ignored.
):
    """Validate CLI options and assemble the kwargs for training_loop().

    Returns:
        (desc, args): ``desc`` is a short string describing the run (used to
        name the output directory) and ``args`` is a ``dnnlib.EasyDict``
        holding every keyword argument of ``training_loop.training_loop()``.

    Raises:
        UserError: if any user-supplied option fails validation.
    """
    args = dnnlib.EasyDict()

    # ------------------------------------------------------------------
    # General options: gpus, nodes, snap, metrics, seed.
    # ------------------------------------------------------------------
    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError("--gpus must be a power of two")
    # BUG FIX: 'nodes' previously had no fallback, so omitting --nodes made
    # 'gpus * nodes' raise TypeError (None operand). Default to one node,
    # mirroring the default applied to 'gpus'.
    if nodes is None:
        nodes = 1
    args.num_gpus = gpus * nodes  # Total world size across all nodes.

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError("--snap must be at least 1")
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap
    args.es_patience = es_patience

    if metrics is None:
        metrics = ["fid50k_full"]
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError(
            "\n".join(
                ["--metrics can only contain the following values:"]
                + metric_main.list_valid_metrics()
            )
        )
    args.metrics = metrics

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # ------------------------------------------------------------------
    # Dataset: data, class_cond, instance_cond, subset, mirror.
    # ------------------------------------------------------------------
    assert data is not None
    assert isinstance(data, str)

    class_name = "data_utils.datasets_common.ILSVRC_HDF5_feats"
    args.class_cond = class_cond
    args.instance_cond = instance_cond

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)

    args.training_set_kwargs = dnnlib.EasyDict(
        class_name=class_name,
        root=data,
        max_size=None,
        xflip=False,
        load_labels=class_cond,
        load_features=instance_cond,
        root_feats=root_feats,
        root_nns=root_nns,
        transform=None,
        label_dim=label_dim,
        feature_dim=2048,  # NOTE(review): presumably ResNet-50 feature width — confirm.
        apply_norm=False,
        label_onehot=True,
        feature_augmentation=feature_augmentation,
    )
    args.data_loader_kwargs = dnnlib.EasyDict(
        pin_memory=True, num_workers=3, prefetch_factor=2
    )
    try:
        # Instantiate the dataset once to probe resolution and size, then
        # discard it; the training loop reconstructs it from the kwargs.
        training_set = dnnlib.util.construct_class_by_name(
            **args.training_set_kwargs
        )
        args.training_set_kwargs.resolution = training_set.resolution
        args.training_set_kwargs.load_labels = class_cond
        args.training_set_kwargs.max_size = len(training_set)
        desc = os.path.splitext(os.path.basename(data))[0]
        del training_set  # Conserve memory before the real launch.
    except IOError as err:
        raise UserError(f"--data: {err}")

    if mirror:
        desc += "-mirror"
        args.training_set_kwargs.xflip = True

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError(
                f"--subset must be between 1 and {args.training_set_kwargs.max_size}"
            )
        desc += f"-subset{subset}"
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed

    # ------------------------------------------------------------------
    # Base config: cfg, lrate, gamma, kimg, batch.
    # ------------------------------------------------------------------
    if cfg is None:
        cfg = "auto"
    assert isinstance(cfg, str)
    desc += f"-{cfg}"

    # Per-config hyperparameters; -1 means "derive automatically" ('auto').
    cfg_specs = {
        "auto": dict(
            ref_gpus=-1,
            kimg=25000,
            mb=-1,
            mbstd=-1,
            fmaps=-1,
            lrate=-1,
            gamma=-1,
            ema=-1,
            ramp=0.05,
            map=2,
        ),
        "stylegan2": dict(
            ref_gpus=8,
            kimg=25000,
            mb=32,
            mbstd=4,
            fmaps=1,
            lrate=0.002,
            gamma=10,
            ema=10,
            ramp=None,
            map=8,
        ),
        "paper256": dict(
            ref_gpus=8,
            kimg=25000,
            mb=64,
            mbstd=8,
            fmaps=0.5,
            lrate=0.0025,
            gamma=1,
            ema=20,
            ramp=None,
            map=8,
        ),
        "paper512": dict(
            ref_gpus=8,
            kimg=25000,
            mb=64,
            mbstd=8,
            fmaps=1,
            lrate=0.0025,
            gamma=0.5,
            ema=20,
            ramp=None,
            map=8,
        ),
        "paper1024": dict(
            ref_gpus=8,
            kimg=25000,
            mb=32,
            mbstd=4,
            fmaps=1,
            lrate=0.002,
            gamma=2,
            ema=10,
            ramp=None,
            map=8,
        ),
        "cifar": dict(
            ref_gpus=2,
            kimg=100000,
            mb=64,
            mbstd=32,
            fmaps=1,
            lrate=0.0025,
            gamma=0.01,
            ema=500,
            ramp=0.05,
            map=2,
        ),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == "auto":
        desc += f"{gpus:d}"
        spec.ref_gpus = args.num_gpus
        res = args.training_set_kwargs.resolution
        # Keep the per-GPU batch small enough for memory at high resolution.
        spec.mb = max(min(args.num_gpus * min(4096 // res, 32), 64), args.num_gpus)
        spec.mbstd = min(spec.mb // args.num_gpus, 4)  # Other hyperparams behave more predictably if mbstd group size remains fixed.
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res ** 2) / spec.mb  # Heuristic R1 gamma formula.
        spec.ema = spec.mb * 10 / 32

    args.G_kwargs = dnnlib.EasyDict(
        class_name="training.networks.Generator",
        z_dim=512,
        w_dim=512,
        mapping_kwargs=dnnlib.EasyDict(),
        synthesis_kwargs=dnnlib.EasyDict(),
    )
    args.D_kwargs = dnnlib.EasyDict(
        class_name="training.networks.Discriminator",
        block_kwargs=dnnlib.EasyDict(),
        mapping_kwargs=dnnlib.EasyDict(),
        epilogue_kwargs=dnnlib.EasyDict(),
    )
    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(
        spec.fmaps * 32768
    )
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    if hidden_dim_c is not None:
        args.G_kwargs.mapping_kwargs.embed_features = hidden_dim_c
        args.D_kwargs.mapping_kwargs.embed_features = hidden_dim_c
    if hidden_dim_h is not None:
        args.G_kwargs.mapping_kwargs.embed_features_feat = hidden_dim_h
        args.D_kwargs.mapping_kwargs.embed_features_feat = hidden_dim_h
    # Enable mixed precision for the 4 highest resolutions and clamp
    # convolution outputs to avoid FP16 overflow.
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    args.exp_name = exp_name
    # Explicit channel overrides take precedence over the cfg-derived values.
    if num_channel_d is not None:
        args.D_kwargs.channel_base = num_channel_d
    if channel_max_d is not None:
        args.D_kwargs.channel_max = channel_max_d
    if num_channel_g is not None:
        args.G_kwargs.synthesis_kwargs.channel_base = num_channel_g
    if channel_max_g is not None:
        args.G_kwargs.synthesis_kwargs.channel_max = channel_max_g

    if lrate is not None:
        spec.lrate = lrate

    args.G_opt_kwargs = dnnlib.EasyDict(
        class_name="torch.optim.Adam", lr=spec.lrate, betas=[0, 0.99], eps=1e-8
    )
    args.D_opt_kwargs = dnnlib.EasyDict(
        class_name="torch.optim.Adam", lr=spec.lrate, betas=[0, 0.99], eps=1e-8
    )
    args.loss_kwargs = dnnlib.EasyDict(
        class_name="training.loss.StyleGAN2Loss", r1_gamma=spec.gamma
    )

    args.total_kimg = spec.kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if cfg == "cifar":
        # CIFAR disables path-length regularization / style mixing and uses
        # the original (non-residual) discriminator architecture.
        args.loss_kwargs.pl_weight = 0
        args.loss_kwargs.style_mixing_prob = 0
        args.D_kwargs.architecture = "orig"

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError("--gamma must be non-negative")
        desc += f"-gamma{gamma:g}"
        args.loss_kwargs.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError("--kimg must be at least 1")
        desc += f"-kimg{kimg:d}"
        args.total_kimg = kimg

    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % args.num_gpus == 0):
            raise UserError(
                "--batch must be at least 1 and divisible by --gpus and --nodes"
            )
        desc += f"-batch{batch}"
        args.batch_size = batch
        args.batch_gpu = batch // (args.num_gpus)
    args.slurm = slurm

    # ------------------------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe.
    # ------------------------------------------------------------------
    if aug is None:
        aug = "ada"
    else:
        assert isinstance(aug, str)
        desc += f"-{aug}"

    if aug == "ada":
        args.ada_target = 0.6
    elif aug == "noaug":
        pass
    elif aug == "fixed":
        if p is None:
            raise UserError(f"--aug={aug} requires specifying --p")
    else:
        raise UserError(f"--aug={aug} not supported")

    if p is not None:
        assert isinstance(p, float)
        if aug != "fixed":
            raise UserError("--p can only be specified with --aug=fixed")
        if not 0 <= p <= 1:
            raise UserError("--p must be between 0 and 1")
        desc += f"-p{p:g}"
        args.augment_p = p

    if target is not None:
        assert isinstance(target, float)
        if aug != "ada":
            raise UserError("--target can only be specified with --aug=ada")
        if not 0 <= target <= 1:
            raise UserError("--target must be between 0 and 1")
        desc += f"-target{target:g}"
        args.ada_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = "bgc"
    else:
        if aug == "noaug":
            raise UserError("--augpipe cannot be specified with --aug=noaug")
        desc += f"-{augpipe}"

    # Named augmentation pipelines: b=blit, g=geom, c=color, f=filter,
    # n=noise, c(final)=cutout.
    augpipe_specs = {
        "blit": dict(xflip=1, rotate90=1, xint=1),
        "geom": dict(scale=1, rotate=1, aniso=1, xfrac=1),
        "color": dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        "filter": dict(imgfilter=1),
        "noise": dict(noise=1),
        "cutout": dict(cutout=1),
        "bg": dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        "bgc": dict(
            xflip=1,
            rotate90=1,
            xint=1,
            scale=1,
            rotate=1,
            aniso=1,
            xfrac=1,
            brightness=1,
            contrast=1,
            lumaflip=1,
            hue=1,
            saturation=1,
        ),
        "bgcf": dict(
            xflip=1,
            rotate90=1,
            xint=1,
            scale=1,
            rotate=1,
            aniso=1,
            xfrac=1,
            brightness=1,
            contrast=1,
            lumaflip=1,
            hue=1,
            saturation=1,
            imgfilter=1,
        ),
        "bgcfn": dict(
            xflip=1,
            rotate90=1,
            xint=1,
            scale=1,
            rotate=1,
            aniso=1,
            xfrac=1,
            brightness=1,
            contrast=1,
            lumaflip=1,
            hue=1,
            saturation=1,
            imgfilter=1,
            noise=1,
        ),
        "bgcfnc": dict(
            xflip=1,
            rotate90=1,
            xint=1,
            scale=1,
            rotate=1,
            aniso=1,
            xfrac=1,
            brightness=1,
            contrast=1,
            lumaflip=1,
            hue=1,
            saturation=1,
            imgfilter=1,
            noise=1,
            cutout=1,
        ),
    }

    assert augpipe in augpipe_specs
    if aug != "noaug":
        args.augment_kwargs = dnnlib.EasyDict(
            class_name="training.augment.AugmentPipe", **augpipe_specs[augpipe]
        )

    # ------------------------------------------------------------------
    # Transfer learning: resume, freezed.
    # ------------------------------------------------------------------
    resume_specs = {
        "ffhq256": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl",
        "ffhq512": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl",
        "ffhq1024": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl",
        "celebahq256": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl",
        "lsundog256": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl",
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = "noresume"
    elif resume == "noresume":
        desc += "-noresume"
    elif resume in resume_specs:
        desc += f"-resume{resume}"
        args.resume_pkl = resume_specs[resume]  # Predefined pretrained network URL.
    else:
        desc += "-resumecustom"
        args.resume_pkl = resume  # Custom path or URL.

    if resume != "noresume":
        args.ada_kimg = 100  # Make ADA react faster at the beginning of transfer.
        args.ema_rampup = None  # Disable EMA rampup when resuming.

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError("--freezed must be non-negative")
        desc += f"-freezed{freezed:d}"
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # ------------------------------------------------------------------
    # Performance options: fp32, nhwc, nobench, allow_tf32, workers.
    # ------------------------------------------------------------------
    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = (
            args.D_kwargs.block_kwargs.fp16_channels_last
        ) = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args.allow_tf32 = True

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError("--workers must be at least 1")
        args.data_loader_kwargs.num_workers = workers

    return desc, args
|
|
|
|
|
|
|
|
|
|
|
def subprocess_fn(rank, args, world_size=1, dist_url="", temp_dir="", slurm=False):
    """Per-process entry point: set up logging and the distributed process
    group, then hand off to the training loop."""
    # Every process appends to the shared log file in the run directory.
    dnnlib.util.Logger(
        file_name=os.path.join(args.run_dir, "log.txt"),
        file_mode="a",
        should_flush=True,
    )

    if not slurm and args.num_gpus > 1:
        # Single-node multi-GPU launch: rendezvous through a shared file.
        rendezvous_file = os.path.abspath(
            os.path.join(temp_dir, ".torch_distributed_init")
        )
        if os.name == "nt":
            # Windows has no NCCL; use gloo with a forward-slash file:// URL.
            torch.distributed.init_process_group(
                backend="gloo",
                init_method="file:///" + rendezvous_file.replace("\\", "/"),
                rank=rank,
                world_size=args.num_gpus,
            )
        else:
            torch.distributed.init_process_group(
                backend="nccl",
                init_method=f"file://{rendezvous_file}",
                rank=rank,
                world_size=args.num_gpus,
            )
        sync_device = torch.device("cuda", rank) if args.num_gpus > 1 else None
        training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
        local_rank = rank  # One process per GPU on a single node.
    elif slurm:
        # Multi-node SLURM launch: the real ranks come from the environment,
        # overriding whatever placeholder the caller passed in.
        rank = int(os.environ.get("SLURM_PROCID"))
        local_rank = int(os.environ.get("SLURM_LOCALID"))
        torch.distributed.init_process_group(
            backend="nccl", init_method=dist_url, rank=rank, world_size=world_size
        )
        # NOTE(review): training_stats.init_multiprocessing is not invoked on
        # this path — confirm stats collection is initialized elsewhere for
        # SLURM runs.
    else:
        # Single-process run: no distributed setup needed.
        rank = local_rank = 0

    # Silence custom-op build chatter on non-primary ranks.
    if rank != 0:
        custom_ops.verbosity = "none"

    training_loop.training_loop(
        rank=rank, local_rank=local_rank, temp_dir=temp_dir, **args
    )
|
|
|
|
|
|
|
|
|
|
|
class CommaSeparatedList(click.ParamType):
    """Click parameter type that parses 'a,b,c' into ['a', 'b', 'c'].

    The literal strings '', 'none', 'None', etc. all map to an empty list.
    """

    name = "list"

    def convert(self, value, param, ctx):
        _ = param, ctx  # Unused, but required by the click interface.
        if value is None or value == "" or value.lower() == "none":
            return []
        return value.split(",")
|
|
|
|
|
|
|
|
|
|
|
def main(args, outdir=None, master_node="", port=40000, dry_run=False, **config_kwargs):
    """Train a GAN using the techniques described in the paper
    "Training Generative Adversarial Networks with Limited Data".

    Examples:

    \b
    # Train with custom dataset using 1 GPU.
    python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1

    \b
    # Train class-conditional CIFAR-10 using 2 GPUs.
    python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \\
        --gpus=2 --cfg=cifar --cond=1

    \b
    # Transfer learn MetFaces from FFHQ using 4 GPUs.
    python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \\
        --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10

    \b
    # Reproduce original StyleGAN2 config F.
    python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \\
        --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug

    \b
    Base configs (--cfg):
      auto       Automatically select reasonable defaults based on resolution
                 and GPU count. Good starting point for new datasets.
      stylegan2  Reproduce results for StyleGAN2 config F at 1024x1024.
      paper256   Reproduce results for FFHQ and LSUN Cat at 256x256.
      paper512   Reproduce results for BreCaHAD and AFHQ at 512x512.
      paper1024  Reproduce results for MetFaces at 1024x1024.
      cifar      Reproduce results for CIFAR-10 at 32x32.

    \b
    Transfer learning source networks (--resume):
      ffhq256        FFHQ trained at 256x256 resolution.
      ffhq512        FFHQ trained at 512x512 resolution.
      ffhq1024       FFHQ trained at 1024x1024 resolution.
      celebahq256    CelebA-HQ trained at 256x256 resolution.
      lsundog256     LSUN Dog trained at 256x256 resolution.
      <PATH or URL>  Custom network pickle.
    """
    dnnlib.util.Logger(should_flush=True)

    # All training options come from the parsed argparse namespace; any
    # **config_kwargs supplied by the caller are intentionally discarded
    # (the parameter is kept only for backward compatibility).
    config_kwargs = vars(args)

    # BUG FIX: 'outdir' used to be a required positional parameter, but the
    # module entry point calls main(args) with a single argument, which made
    # every CLI run raise TypeError. It now defaults to None and falls back
    # to the value parsed from the command line.
    # NOTE(review): assumes parser.get_parser() defines an 'outdir' option —
    # confirm against parser.py.
    if outdir is None:
        outdir = config_kwargs.get("outdir")

    run_desc, args = setup_training_loop_kwargs(**config_kwargs)
    args.metrics = ["fid50k_full"]  # Metrics are pinned to FID50k here.

    # Pick the output directory.
    if args.exp_name is None:
        # Auto-number the run: <outdir>/<5-digit-id>-<desc>.
        prev_run_dirs = []
        if os.path.isdir(outdir):
            prev_run_dirs = [
                x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))
            ]
        prev_run_ids = [re.match(r"^\d+", x) for x in prev_run_dirs]
        prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
        cur_run_id = max(prev_run_ids, default=-1) + 1
        args.run_dir = os.path.join(outdir, f"{cur_run_id:05d}-{run_desc}")
        assert not os.path.exists(args.run_dir)
    else:
        args.run_dir = os.path.join(outdir, args.exp_name)

    # Print a summary of the resolved options.
    print()
    print("Training options:")
    print()
    print(f"Output directory: {args.run_dir}")
    print(f"Training data: {args.training_set_kwargs.root}")
    print(f"Training duration: {args.total_kimg} kimg")
    print(f"Number of GPUs: {args.num_gpus}")
    print(f"Number of images: {args.training_set_kwargs.max_size}")
    print(f"Image resolution: {args.training_set_kwargs.resolution}")
    print(f"Conditional model: {args.training_set_kwargs.load_labels}")
    print(f"Dataset x-flips: {args.training_set_kwargs.xflip}")
    print()

    if dry_run:
        print("Dry run; exiting.")
        return

    # Create the output directory and persist the resolved options.
    print("Creating output directory...")
    if not os.path.exists(args.run_dir):
        os.makedirs(args.run_dir, exist_ok=True)
        with open(os.path.join(args.run_dir, "training_options.json"), "wt") as f:
            json.dump(args, f, indent=2)

    if args.slurm:
        # Multi-node launch under SLURM: derive world size and the TCP
        # rendezvous URL from the environment and run in-process; SLURM has
        # already spawned one task per GPU.
        n_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES"))
        # SLURM_TASKS_PER_NODE looks like "4(x2)"; take the leading count.
        n_gpus_per_node = int(os.environ.get("SLURM_TASKS_PER_NODE").split("(")[0])
        world_size = n_gpus_per_node * n_nodes
        dist_url = "tcp://" + master_node + ":" + str(port)
        print("Dist url ", dist_url)
        temp_dir = "/scratch/slurm_tmpdir/" + str(os.environ.get("SLURM_JOB_ID"))
        subprocess_fn(
            rank=-1,  # Placeholder; the real rank comes from SLURM_PROCID.
            args=args,
            world_size=world_size,
            dist_url=dist_url,
            temp_dir=temp_dir,
            slurm=args.slurm,
        )
    else:
        # Single-node launch: one process per GPU via torch.multiprocessing.
        print("Launching processes...")
        torch.multiprocessing.set_start_method("spawn")
        with tempfile.TemporaryDirectory() as temp_dir:
            if args.num_gpus == 1:
                subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
            else:
                torch.multiprocessing.spawn(
                    fn=subprocess_fn,
                    args=(args, args.num_gpus, "", temp_dir),
                    nprocs=args.num_gpus,
                )
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the CLI, parse argv, and launch training.
    cli = parser.get_parser()
    main(cli.parse_args())
|
|
|
|
|
|