# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
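"""Command-line options for StyleGAN2-ADA training runs, with optional IC-GAN
instance conditioning and optional SLURM scheduling."""
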
from argparse import ArgumentParser


def get_parser():
    parser = ArgumentParser(description="StyleGANv2")
    # General options.
    parser.add_argument(
        "--json_config",
        type=str,
        default="",
        help="JSON config from which to load the configuration parameters.",
    )
    parser.add_argument(
        "--exp_name", help="Experiment name", required=False, default="default_name"
    )
    parser.add_argument(
        "--base_root",
        help="Where to save the results",
        required=False,
        default="training-runs",
        metavar="DIR",
    )
    parser.add_argument(
        "--slurm_logdir",
        help="Where to save the logs from SLURM",
        required=False,
        default="training-runs",
        metavar="DIR",
    )
    parser.add_argument(
        "--partition",
        help="Partition name for SLURM",
        required=False,
        default="learnlab",
    )
    parser.add_argument(
        "--slurm_time",
        help="Time limit in minutes for the experiment in SLURM",
        default=3200,
        type=int,
        metavar="INT",
    )
    parser.add_argument(
        "--gpus", help="Number of GPUs to use [default: 1]", type=int, metavar="INT"
    )
    parser.add_argument(
        "--nodes",
        help="Number of nodes to use [default: 1]",
        type=int,
        metavar="INT",
        default=1,
    )
    parser.add_argument(
        "--snap", help="Snapshot interval [default: 50 ticks]", type=int, metavar="INT"
    )
    parser.add_argument(
        "--seed", help="Random seed [default: 0]", type=int, metavar="INT"
    )
    parser.add_argument(
        "--port",
        help="Port number for DDP connection [default: 40000]",
        type=int,
        default=40000,
        metavar="INT",
    )
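    # Note: the bool-typed flags below are converted with ``bool(value)``, so any
    # non-empty string (including "False") parses as True.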
    parser.add_argument(
        "--dry-run", help="Print training options and exit", type=bool, metavar="BOOL"
    )
    # Dataset.
    parser.add_argument(
        "--data_root",
        help="Path where to find the data",
        metavar="PATH",
        required=False,
        default=None,
    )
    parser.add_argument(
        "--data",
        help="Training data (directory or zip)",
        metavar="PATH",
        required=False,
        default="datasets/cocostuff_128.zip",
    )
    # parser.add_argument('--cond',
    #     help='Train conditional model based on dataset labels [default: false]',
    #     type=bool, metavar='BOOL')
    parser.add_argument(
        "--class_cond",
        help="Use class labels to condition the model [default: false]",
        type=bool,
        metavar="BOOL",
    )
    parser.add_argument(
        "--subset",
        help="Train with only N images [default: all]",
        type=int,
        metavar="INT",
    )
    parser.add_argument(
        "--mirror",
        help="Enable dataset x-flips [default: false]",
        type=bool,
        metavar="BOOL",
    )
    parser.add_argument(
        "--label_dim",
        help="Number of classes when using class conditioning",
        default=1000,
        type=int,
        metavar="INT",
    )
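    # IC-GAN (instance conditioning): instead of relying on class labels alone,
    # generation is conditioned on feature vectors extracted from individual
    # training instances, and real images are drawn from each instance's
    # nearest neighbors during training.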
    # IC-GAN options for Dataset.
    parser.add_argument(
        "--root_feats",
        help="Training data features used as instance conditioning (hdf5 file)",
        metavar="PATH",
        required=False,
        default="",
    )
    parser.add_argument(
        "--root_nns",
        help="Nearest neighbors of the training data for each conditioning instance (hdf5 file)",
        metavar="PATH",
        required=False,
        default="",
    )
    parser.add_argument(
        "--instance_cond",
        help="Use instance features to condition the model [default: false]",
        type=bool,
        metavar="BOOL",
    )
    parser.add_argument(
        "--feature_augmentation",
        help="Use horizontal flips of the instances when computing instance features [default: false]",
        type=bool,
        metavar="BOOL",
    )
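    # Values left unset below are typically filled in downstream from the preset
    # selected with --cfg (hence the "Override ..." wording in the help strings).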
    # Base config.
    parser.add_argument(
        "--cfg",
        help="Base config [default: auto]",
        choices=["auto", "stylegan2", "paper256", "paper512", "paper1024", "cifar"],
    )
    parser.add_argument("--gamma", help="Override R1 gamma", type=float)
    parser.add_argument(
        "--kimg", help="Override training duration", type=int, metavar="INT"
    )
    parser.add_argument("--batch", help="Override batch size", type=int, metavar="INT")
    parser.add_argument("--lrate", help="Override learning rate", type=float)
    parser.add_argument(
        "--num_channel_g", help="Override width of generator", type=int
    )
    parser.add_argument(
        "--num_channel_d", help="Override width of discriminator", type=int
    )
    parser.add_argument(
        "--channel_max_g", help="Override max width of generator", type=int
    )
    parser.add_argument(
        "--channel_max_d", help="Override max width of discriminator", type=int
    )
    parser.add_argument(
        "--hidden_dim_c",
        help="Override embedding size in mapping network for class conditioning",
        type=int,
    )
    parser.add_argument(
        "--hidden_dim_h",
        help="Override embedding size in mapping network for instance feature conditioning",
        type=int,
    )
    parser.add_argument(
        "--es_patience",
        help="Early stopping patience",
        type=int,
        default=100000000,
        metavar="INT",
    )
    # Discriminator augmentation.
    parser.add_argument(
        "--aug",
        help="Augmentation mode [default: ada]",
        choices=["noaug", "ada", "fixed"],
    )
    parser.add_argument(
        "--p", help="Augmentation probability for --aug=fixed", type=float
    )
    parser.add_argument("--target", help="ADA target value for --aug=ada", type=float)
    parser.add_argument(
        "--augpipe",
        help="Augmentation pipeline [default: bgc]",
        choices=[
            "blit",
            "geom",
            "color",
            "filter",
            "noise",
            "cutout",
            "bg",
            "bgc",
            "bgcf",
            "bgcfn",
            "bgcfnc",
        ],
    )
    # Transfer learning.
    parser.add_argument(
        "--resume", help="Resume training [default: noresume]", metavar="PKL"
    )
    parser.add_argument(
        "--freezed", help="Freeze-D [default: 0 layers]", type=int, metavar="INT"
    )
    # Performance options.
    parser.add_argument(
        "--fp32", help="Disable mixed-precision training", type=bool, metavar="BOOL"
    )
    parser.add_argument(
        "--nhwc", help="Use NHWC memory format with FP16", type=bool, metavar="BOOL"
    )
    parser.add_argument(
        "--nobench", help="Disable cuDNN benchmarking", type=bool, metavar="BOOL"
    )
    parser.add_argument(
        "--allow-tf32",
        help="Allow PyTorch to use TF32 internally",
        type=bool,
        metavar="BOOL",
    )
    parser.add_argument(
        "--workers",
        help="Override number of DataLoader workers",
        type=int,
        metavar="INT",
    )
    # Experiment setup.
    parser.add_argument(
        "--slurm",
        help="Use SLURM to launch the experiment on a cluster",
        type=bool,
        metavar="BOOL",
    )
    return parser

# ----------------------------------------------------------------------------
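

# Minimal usage sketch (illustrative only; this parser is presumably consumed by
# the repository's launch script, and the option values below are made up):
if __name__ == "__main__":
    example_args = get_parser().parse_args(["--exp_name", "demo_run", "--gpus", "8"])
    # Print the full option namespace to inspect defaults and overrides.
    print(vars(example_args))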