import argparse
import shlex
import os
import pickle

import swapae.util as util
import swapae.models as models
import swapae.models.networks as networks
import swapae.data as data
import swapae.evaluation as evaluation
import swapae.optimizers as optimizers
from swapae.util import IterationCounter
from swapae.util import Visualizer


class BaseOptions():
    def initialize(self, parser):
        # experiment specifics
        parser.add_argument('--name', type=str, default="ffhq512_pretrained", help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--easy_label', type=str, default="")
        parser.add_argument('--num_gpus', type=int, default=1, help='# GPUs to use. 0 means CPU mode')
        parser.add_argument('--checkpoints_dir', type=str, default='/home/xtli/Documents/GITHUB/swapping-autoencoder-pytorch/checkpoints/', help='models are saved here')
        parser.add_argument('--model', type=str, default='swapping_autoencoder', help='which model to use')
        parser.add_argument('--optimizer', type=str, default='swapping_autoencoder', help='which optimizer to use')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc.')
        parser.add_argument('--resume_iter', type=str, default="latest", help="# iterations (in thousands) to resume")
        parser.add_argument('--num_classes', type=int, default=0)

        # input/output sizes
        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
        parser.add_argument('--preprocess', type=str, default='resize', help='scaling and cropping of images at load time.')
        parser.add_argument('--load_size', type=int, default=512, help='Scale images to this size. The final image will be cropped to --crop_size.')
        parser.add_argument('--crop_size', type=int, default=512, help='Crop to the width of crop_size (after initially scaling the images to load_size).')
        parser.add_argument('--preprocess_crop_padding', type=int, default=None, help='padding parameter of transforms.RandomCrop(). It is not used if --preprocess does not contain the crop option.')
        parser.add_argument('--no_flip', action='store_true')
        parser.add_argument('--shuffle_dataset', type=str, default=None, choices=('true', 'false'))

        # for setting inputs
        parser.add_argument('--dataroot', type=str, default="/home/xtli/Dropbox/swapping-autoencoder-pytorch/testphotos/ffhq512/fig9/")
        parser.add_argument('--dataset_mode', type=str, default='imagefolder')
        parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data')

        # networks
        parser.add_argument("--netG", default="StyleGAN2Resnet")
        parser.add_argument("--netD", default="StyleGAN2")
        parser.add_argument("--netE", default="StyleGAN2Resnet")
        parser.add_argument("--netPatchD", default="StyleGAN2")
        parser.add_argument("--use_antialias", type=util.str2bool, default=True)
        parser.add_argument("-f", "--config_file", type=str, default='models/swap/json/sem_cons.json', help='JSON file including all arguments')
        parser.add_argument("--local_rank", type=int)

        return parser

    def gather_options(self, command=None):
        parser = AugmentedArgumentParser()
        parser.custom_command = command

        # define the basic options
        parser = self.initialize(parser)

        # parse the basic options
        opt, unknown = parser.parse_known_args()

        # modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(parser, self.isTrain)

        # modify network-related parser options
        parser = networks.modify_commandline_options(parser, self.isTrain)

        # modify optimizer-related parser options
        optimizer_name = opt.optimizer
        optimizer_option_setter = optimizers.get_option_setter(optimizer_name)
        parser = optimizer_option_setter(parser, self.isTrain)

        # modify dataset-related parser options
        dataset_mode = opt.dataset_mode
        dataset_option_setter = data.get_option_setter(dataset_mode)
        parser = dataset_option_setter(parser, self.isTrain)

        # modify parser options related to visualization
        parser = Visualizer.modify_commandline_options(parser, self.isTrain)

        # modify parser options related to iteration counting
        parser = IterationCounter.modify_commandline_options(parser, self.isTrain)

        # modify evaluation-related parser options
        evaluation_option_setter = evaluation.get_option_setter()
        parser = evaluation_option_setter(parser, self.isTrain)

        opt, unknown = parser.parse_known_args()
        opt = parser.parse_args()
        self.parser = parser
        return opt

    def print_options(self, opt):
        """Print and save options

        It will print both current options and default values (if different).
        It will save options into a text file: [checkpoints_dir] / opt.txt
        """
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

    def option_file_path(self, opt, makedir=False):
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        if makedir:
            util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt')
        return file_name

    def save_options(self, opt):
        file_name = self.option_file_path(opt, makedir=True)
        with open(file_name + '.txt', 'wt') as opt_file:
            for k, v in sorted(vars(opt).items()):
                comment = ''
                default = self.parser.get_default(k)
                if v != default:
                    comment = '\t[default: %s]' % str(default)
                opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))

        with open(file_name + '.pkl', 'wb') as opt_file:
            pickle.dump(opt, opt_file)

    def parse(self, save=False, command=None):
        opt = self.gather_options(command)
        opt.isTrain = self.isTrain  # train or test
        self.print_options(opt)
        if opt.isTrain:
            self.save_options(opt)

        opt.dataroot = os.path.expanduser(opt.dataroot)

        assert opt.num_gpus <= opt.batch_size, "Batch size must not be smaller than num_gpus"
        return opt


class TrainOptions(BaseOptions):
    def __init__(self):
        super().__init__()
        self.isTrain = True

    def initialize(self, parser):
        super().initialize(parser)
        parser.add_argument('--continue_train', type=util.str2bool, default=False,
                            help="resume training from last checkpoint")
        parser.add_argument('--pretrained_name', type=str, default=None,
                            help="Load weights from the checkpoint of another experiment")
        return parser


class TestOptions(BaseOptions):
    def __init__(self):
        super().__init__()
        self.isTrain = False

    def initialize(self, parser):
        super().initialize(parser)
        parser.add_argument("--result_dir", type=str, default="results")
        return parser


class AugmentedArgumentParser(argparse.ArgumentParser):
    def parse_args(self, args=None, namespace=None):
        """Enables passing bash commands as arguments to the class."""
        print("parsing args...")
        if args is None and hasattr(self, 'custom_command') and self.custom_command is not None:
            print('using custom command')
            print(self.custom_command)
            args = shlex.split(self.custom_command)[2:]
        return super().parse_args(args, namespace)

    def parse_known_args(self, args=None, namespace=None):
        if args is None and hasattr(self, 'custom_command') and self.custom_command is not None:
            args = shlex.split(self.custom_command)[2:]
        return super().parse_known_args(args, namespace)

    def add_argument(self, *args, **kwargs):
        """Support for providing a new argument type called "str2bool".

        Example:
        parser.add_argument("--my_option", type=util.str2bool, default=|bool|)

        1. "python train.py" sets my_option to be |bool|
        2. "python train.py --my_option" sets my_option to be True
        3. "python train.py --my_option False" sets my_option to be False
        4. "python train.py --my_option True" sets my_option to be True

        https://stackoverflow.com/a/43357954
        """
        if 'type' in kwargs and kwargs['type'] == util.str2bool:
            if 'nargs' not in kwargs:
                kwargs['nargs'] = "?"
            if 'const' not in kwargs:
                kwargs['const'] = True
        super().add_argument(*args, **kwargs)
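

# A minimal usage sketch (illustrative, not part of the original module). It assumes
# the swapae package and its model/dataset/optimizer option setters are importable,
# and shows how a full shell command string can be routed through
# AugmentedArgumentParser via the `command` argument of parse(). The experiment
# name and option values below are hypothetical.
if __name__ == '__main__':
    # The first two tokens ("python train.py") are dropped by shlex.split(...)[2:],
    # so they must be included in the command string.
    opt = TrainOptions().parse(
        command="python train.py --name demo_experiment "
                "--checkpoints_dir ./checkpoints --batch_size 2 --num_gpus 1"
    )
    print(opt.name, opt.batch_size, opt.isTrain)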