Spaces:
Runtime error
Runtime error
File size: 9,250 Bytes
1b2a9b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
import argparse
import shlex
import os
import pickle
import swapae.util as util
import swapae.models as models
import swapae.models.networks as networks
import swapae.data as data
import swapae.evaluation as evaluation
import swapae.optimizers as optimizers
from swapae.util import IterationCounter
from swapae.util import Visualizer
class BaseOptions():
    """Command-line options shared by training and testing.

    Subclasses (TrainOptions / TestOptions) set ``self.isTrain``. That flag is
    read by :meth:`gather_options` and :meth:`parse`, so this base class is not
    meant to be parsed directly.
    """

    def initialize(self, parser):
        """Register the options common to every phase on ``parser`` and return it."""
        # experiment specifics
        parser.add_argument('--name', type=str, default="ffhq512_pretrained",
                            help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--easy_label', type=str, default="")
        parser.add_argument('--num_gpus', type=int, default=1, help='#GPUs to use. 0 means CPU mode')
        parser.add_argument('--checkpoints_dir', type=str,
                            default='/home/xtli/Documents/GITHUB/swapping-autoencoder-pytorch/checkpoints/',
                            help='models are saved here')
        parser.add_argument('--model', type=str, default='swapping_autoencoder', help='which model to use')
        parser.add_argument('--optimizer', type=str, default='swapping_autoencoder', help='which optimizer to use')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        parser.add_argument('--resume_iter', type=str, default="latest",
                            help="# iterations (in thousands) to resume")
        parser.add_argument('--num_classes', type=int, default=0)

        # input/output sizes
        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
        parser.add_argument('--preprocess', type=str, default='resize',
                            help='scaling and cropping of images at load time.')
        parser.add_argument('--load_size', type=int, default=512,
                            help='Scale images to this size. The final image will be cropped to --crop_size.')
        parser.add_argument('--crop_size', type=int, default=512,
                            help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
        parser.add_argument('--preprocess_crop_padding', type=int, default=None,
                            help='padding parameter of transforms.RandomCrop(). It is not used if --preprocess does not contain crop option.')
        parser.add_argument('--no_flip', action='store_true')
        parser.add_argument('--shuffle_dataset', type=str, default=None, choices=('true', 'false'))

        # for setting inputs
        parser.add_argument('--dataroot', type=str,
                            default="/home/xtli/Dropbox/swapping-autoencoder-pytorch/testphotos/ffhq512/fig9/")
        parser.add_argument('--dataset_mode', type=str, default='imagefolder')
        parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data')

        # networks
        parser.add_argument("--netG", default="StyleGAN2Resnet")
        parser.add_argument("--netD", default="StyleGAN2")
        parser.add_argument("--netE", default="StyleGAN2Resnet")
        parser.add_argument("--netPatchD", default="StyleGAN2")
        parser.add_argument("--use_antialias", type=util.str2bool, default=True)
        parser.add_argument("-f", "--config_file", type=str, default='models/swap/json/sem_cons.json',
                            help='json files including all arguments')
        parser.add_argument("--local_rank", type=int)
        return parser

    def gather_options(self, command=None):
        """Build the complete parser and parse the command line with it.

        Parsing happens in two passes. A lenient first pass
        (``parse_known_args``) reads the selector options (``--model``,
        ``--optimizer``, ``--dataset_mode``). Their setters then register
        further options. A strict final pass (``parse_args``) parses
        everything and rejects unknown arguments.

        Args:
            command: optional full shell command string. When given, the
                parser reads its arguments from it instead of ``sys.argv``
                (AugmentedArgumentParser drops the first two tokens, e.g.
                "python train.py").

        Returns:
            The parsed ``argparse.Namespace``. The final parser is kept on
            ``self.parser`` for later default lookups.
        """
        parser = AugmentedArgumentParser()
        parser.custom_command = command

        # basic options (plus any the subclass adds)
        parser = self.initialize(parser)

        # First pass: options added by the setters below do not exist yet,
        # so unknown arguments must be tolerated here.
        opt, _ = parser.parse_known_args()

        # model-related options
        parser = models.get_option_setter(opt.model)(parser, self.isTrain)
        # network-related options
        parser = networks.modify_commandline_options(parser, self.isTrain)
        # optimizer-related options
        parser = optimizers.get_option_setter(opt.optimizer)(parser, self.isTrain)
        # dataset-related options
        parser = data.get_option_setter(opt.dataset_mode)(parser, self.isTrain)
        # visualization options
        parser = Visualizer.modify_commandline_options(parser, self.isTrain)
        # iteration-counting options
        parser = IterationCounter.modify_commandline_options(parser, self.isTrain)
        # evaluation-related options
        parser = evaluation.get_option_setter()(parser, self.isTrain)

        # Final strict pass over the fully populated parser.
        opt = parser.parse_args()
        self.parser = parser
        return opt

    def _option_lines(self, opt):
        """Yield one formatted line per option, sorted by name.

        A line gets a ``[default: ...]`` suffix when its value differs from
        the parser's default. Requires ``self.parser`` to be set.
        """
        for key, value in sorted(vars(opt).items()):
            default = self.parser.get_default(key)
            comment = '\t[default: %s]' % str(default) if value != default else ''
            yield '{:>25}: {:<30}{}\n'.format(str(key), str(value), comment)

    def print_options(self, opt):
        """Print all options, showing the default next to every changed value.

        This only prints; writing options to disk is done by save_options.
        """
        message = ''.join([
            '----------------- Options ---------------\n',
            *self._option_lines(opt),
            '----------------- End -------------------',
        ])
        print(message)

    def option_file_path(self, opt, makedir=False):
        """Return ``<checkpoints_dir>/<name>/opt`` (without extension).

        When ``makedir`` is True, the experiment directory is created first.
        """
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        if makedir:
            util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt')
        return file_name

    def save_options(self, opt):
        """Write the options as ``opt.txt`` (readable) and ``opt.pkl`` (pickled namespace)."""
        file_name = self.option_file_path(opt, makedir=True)
        with open(file_name + '.txt', 'wt') as opt_file:
            opt_file.writelines(self._option_lines(opt))
        with open(file_name + '.pkl', 'wb') as opt_file:
            pickle.dump(opt, opt_file)

    def parse(self, save=False, command=None):
        """Gather, print, and (in training) save the options, then post-process them.

        Args:
            save: currently unused; kept for backward compatibility. Saving
                is controlled by ``isTrain``.
            command: optional shell command string forwarded to gather_options.

        Returns:
            The options namespace with ``isTrain`` set and ``dataroot``
            user-expanded. The saved files keep the unexpanded path.
        """
        opt = self.gather_options(command)
        opt.isTrain = self.isTrain  # train or test

        self.print_options(opt)
        if opt.isTrain:
            self.save_options(opt)

        opt.dataroot = os.path.expanduser(opt.dataroot)

        assert opt.num_gpus <= opt.batch_size, "Batch size must not be smaller than num_gpus"
        return opt
class TrainOptions(BaseOptions):
    """Option set used for training runs (``isTrain`` is True)."""

    def __init__(self):
        super().__init__()
        self.isTrain = True

    def initialize(self, parser):
        """Register the shared options, then the training-only ones, and return ``parser``."""
        super().initialize(parser)
        add = parser.add_argument
        add(
            '--continue_train',
            type=util.str2bool,
            default=False,
            help="resume training from last checkpoint",
        )
        add(
            '--pretrained_name',
            type=str,
            default=None,
            help="Load weights from the checkpoint of another experiment",
        )
        return parser
class TestOptions(BaseOptions):
    """Option set used for evaluation/inference runs (``isTrain`` is False)."""

    def __init__(self):
        super().__init__()
        self.isTrain = False

    def initialize(self, parser):
        """Register the shared options plus ``--result_dir`` and return ``parser``."""
        super().initialize(parser)
        add = parser.add_argument
        add(
            "--result_dir",
            type=str,
            default="results",
        )
        return parser
class AugmentedArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that can take its arguments from a stored shell command.

    Set ``custom_command`` to a full command line such as
    ``"python train.py --name foo"``. When ``parse_args`` or
    ``parse_known_args`` is called without explicit ``args``, the command is
    tokenized with :func:`shlex.split` and its first two tokens (interpreter
    and script) are dropped.

    ``util.str2bool`` options also get a bare-flag form (see add_argument).
    """

    def _uses_custom_command(self, args):
        """Return True when no explicit ``args`` were given and ``custom_command`` is set."""
        return args is None and getattr(self, 'custom_command', None) is not None

    def _custom_command_args(self):
        """Tokenize ``custom_command``, skipping its leading "python <script>" tokens."""
        return shlex.split(self.custom_command)[2:]

    def parse_args(self, args=None, namespace=None):
        """Parse arguments, taking them from ``custom_command`` when ``args`` is None.
        """
        print("parsing args...")
        if self._uses_custom_command(args):
            print('using custom command')
            print(self.custom_command)
            args = self._custom_command_args()
        return super().parse_args(args, namespace)

    def parse_known_args(self, args=None, namespace=None):
        """Like parse_known_args, but take arguments from ``custom_command`` when ``args`` is None."""
        if self._uses_custom_command(args):
            args = self._custom_command_args()
        return super().parse_known_args(args, namespace)

    def add_argument(self, *args, **kwargs):
        """ Support for providing a new argument type called "str2bool"
        Example:
        parser.add_argument("--my_option", type=util.str2bool, default=|bool|)
        1. "python train.py" sets my_option to be |bool|
        2. "python train.py --my_option" sets my_option to be True
        3. "python train.py --my_option False" sets my_option to be False
        4. "python train.py --my_option True" sets my_option to be True
        https://stackoverflow.com/a/43357954

        Returns the created ``argparse.Action``, as the base class does.
        """
        # Check 'type' first so util is only consulted for typed arguments.
        if 'type' in kwargs and kwargs['type'] == util.str2bool:
            # A bare flag (no value) consumes nothing and becomes True.
            kwargs.setdefault('nargs', "?")
            kwargs.setdefault('const', True)
        return super().add_argument(*args, **kwargs)
|