sunshineatnoon
Add application file
1b2a9b1
import argparse
import shlex
import os
import pickle
import swapae.util as util
import swapae.models as models
import swapae.models.networks as networks
import swapae.data as data
import swapae.evaluation as evaluation
import swapae.optimizers as optimizers
from swapae.util import IterationCounter
from swapae.util import Visualizer
class BaseOptions:
    """Shared command-line options and the option-parsing pipeline.

    Subclasses (``TrainOptions`` / ``TestOptions``) set ``self.isTrain``;
    ``gather_options`` and ``parse`` read it.
    """

    def initialize(self, parser):
        """Register the options common to training and testing on ``parser``.

        Returns the same ``parser`` so subclasses can add further arguments.
        """
        # experiment specifics
        # NOTE(review): several defaults below (checkpoints_dir, dataroot) are
        # absolute paths on the original author's machine; override them on the
        # command line.
        parser.add_argument('--name', type=str, default="ffhq512_pretrained", help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--easy_label', type=str, default="")
        parser.add_argument('--num_gpus', type=int, default=1, help='#GPUs to use. 0 means CPU mode')
        parser.add_argument('--checkpoints_dir', type=str, default='/home/xtli/Documents/GITHUB/swapping-autoencoder-pytorch/checkpoints/', help='models are saved here')
        parser.add_argument('--model', type=str, default='swapping_autoencoder', help='which model to use')
        parser.add_argument('--optimizer', type=str, default='swapping_autoencoder', help='which optimizer to use')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        parser.add_argument('--resume_iter', type=str, default="latest",
                            help="# iterations (in thousands) to resume")
        parser.add_argument('--num_classes', type=int, default=0)

        # input/output sizes
        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
        parser.add_argument('--preprocess', type=str, default='resize', help='scaling and cropping of images at load time.')
        parser.add_argument('--load_size', type=int, default=512, help='Scale images to this size. The final image will be cropped to --crop_size.')
        parser.add_argument('--crop_size', type=int, default=512, help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
        parser.add_argument('--preprocess_crop_padding', type=int, default=None, help='padding parameter of transforms.RandomCrop(). It is not used if --preprocess does not contain crop option.')
        parser.add_argument('--no_flip', action='store_true')
        parser.add_argument('--shuffle_dataset', type=str, default=None, choices=('true', 'false'))

        # for setting inputs
        parser.add_argument('--dataroot', type=str, default="/home/xtli/Dropbox/swapping-autoencoder-pytorch/testphotos/ffhq512/fig9/")
        parser.add_argument('--dataset_mode', type=str, default='imagefolder')
        parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data')

        # networks
        parser.add_argument("--netG", default="StyleGAN2Resnet")
        parser.add_argument("--netD", default="StyleGAN2")
        parser.add_argument("--netE", default="StyleGAN2Resnet")
        parser.add_argument("--netPatchD", default="StyleGAN2")
        parser.add_argument("--use_antialias", type=util.str2bool, default=True)
        parser.add_argument("-f", "--config_file", type=str, default='models/swap/json/sem_cons.json', help='json files including all arguments')
        parser.add_argument("--local_rank", type=int)
        return parser

    def gather_options(self, command=None):
        """Build the full parser in stages and parse the command line.

        Args:
            command: optional shell-style command string (e.g.
                ``"python train.py --name foo"``) parsed instead of
                ``sys.argv``; ``AugmentedArgumentParser`` drops its first two
                tokens.

        The basic options are parsed leniently first, so that the selected
        model, optimizer and dataset mode can register their own options. The
        final parse is strict, so unknown options are rejected.
        """
        parser = AugmentedArgumentParser()
        parser.custom_command = command

        # get basic options
        parser = self.initialize(parser)

        # Preliminary parse: tolerate options that later stages have not added yet.
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_option_setter = models.get_option_setter(opt.model)
        parser = model_option_setter(parser, self.isTrain)

        # modify network-related parser options
        parser = networks.modify_commandline_options(parser, self.isTrain)

        # modify optimizer-related parser options
        optimizer_option_setter = optimizers.get_option_setter(opt.optimizer)
        parser = optimizer_option_setter(parser, self.isTrain)

        # modify dataset-related parser options
        dataset_option_setter = data.get_option_setter(opt.dataset_mode)
        parser = dataset_option_setter(parser, self.isTrain)

        # modify visualization-related parser options
        parser = Visualizer.modify_commandline_options(parser, self.isTrain)

        # modify parser options related to iteration counting
        parser = IterationCounter.modify_commandline_options(parser, self.isTrain)

        # modify evaluation-related parser options
        evaluation_option_setter = evaluation.get_option_setter()
        parser = evaluation_option_setter(parser, self.isTrain)

        # Final strict parse over the fully-populated parser.
        opt = parser.parse_args()
        self.parser = parser
        return opt

    def _format_options(self, opt):
        """Return one formatted line per option in ``opt``, sorted by name.

        Each line is suffixed with ``[default: ...]`` when the value differs
        from the parser default. Requires ``self.parser`` to be set.
        """
        lines = []
        for k, v in sorted(vars(opt).items()):
            default = self.parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>25}: {:<30}{}'.format(str(k), str(v), comment))
        return lines

    def print_options(self, opt):
        """Print all options, annotating those whose value differs from the default.

        This only prints; ``save_options`` writes the options to disk.
        """
        print('\n'.join([
            '----------------- Options ---------------',
            *self._format_options(opt),
            '----------------- End -------------------',
        ]))

    def option_file_path(self, opt, makedir=False):
        """Return ``<checkpoints_dir>/<name>/opt`` (no extension).

        When ``makedir`` is true, the experiment directory is created first.
        """
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        if makedir:
            util.mkdirs(expr_dir)
        return os.path.join(expr_dir, 'opt')

    def save_options(self, opt):
        """Save ``opt`` to two files next to each other.

        Writes a human-readable ``opt.txt`` and a pickled ``opt.pkl``.
        """
        file_name = self.option_file_path(opt, makedir=True)
        with open(file_name + '.txt', 'wt', encoding='utf-8') as opt_file:
            opt_file.writelines(line + '\n' for line in self._format_options(opt))
        with open(file_name + '.pkl', 'wb') as opt_file:
            pickle.dump(opt, opt_file)

    def parse(self, save=False, command=None):
        """Gather and print the options, save them when training, then validate them.

        Args:
            save: currently unused. Options are saved only when
                ``self.isTrain`` is true. The parameter is kept for backward
                compatibility.
            command: optional command string forwarded to ``gather_options``.

        Returns:
            The parsed options, with ``isTrain`` set and ``dataroot``
            user-expanded.
        """
        opt = self.gather_options(command)
        opt.isTrain = self.isTrain  # train or test
        self.print_options(opt)
        if opt.isTrain:
            self.save_options(opt)
        opt.dataroot = os.path.expanduser(opt.dataroot)
        assert opt.num_gpus <= opt.batch_size, "Batch size must not be smaller than num_gpus"
        return opt
class TrainOptions(BaseOptions):
    """Options for training runs; marks the run as training (``isTrain = True``)."""

    def __init__(self):
        super().__init__()
        self.isTrain = True

    def initialize(self, parser):
        """Add the training-only flags on top of the shared base options."""
        super().initialize(parser)
        parser.add_argument(
            '--continue_train',
            type=util.str2bool,
            default=False,
            help="resume training from last checkpoint",
        )
        parser.add_argument(
            '--pretrained_name',
            type=str,
            default=None,
            help="Load weights from the checkpoint of another experiment",
        )
        return parser
class TestOptions(BaseOptions):
    """Options for test/inference runs; marks the run as not training (``isTrain = False``)."""

    def __init__(self):
        super().__init__()
        self.isTrain = False

    def initialize(self, parser):
        """Add the test-only ``--result_dir`` flag on top of the shared base options."""
        super().initialize(parser)
        parser.add_argument(
            "--result_dir",
            type=str,
            default="results",
        )
        return parser
class AugmentedArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that can parse a stored command string and supports str2bool flags.

    If the attribute ``custom_command`` is set to a shell-style string (e.g.
    ``"python train.py --name foo"``), ``parse_args`` / ``parse_known_args``
    called without explicit ``args`` parse that string instead of
    ``sys.argv``. The string's first two tokens are dropped before parsing.
    """

    def _uses_custom_command(self, args):
        """Return True when no explicit ``args`` were given and ``custom_command`` is set."""
        return args is None and getattr(self, 'custom_command', None) is not None

    def _resolve_args(self, args):
        """Return the argument list to parse.

        This is ``custom_command`` split with shlex and minus its first two
        tokens when applicable; otherwise ``args`` is returned unchanged.
        """
        if self._uses_custom_command(args):
            return shlex.split(self.custom_command)[2:]
        return args

    def parse_args(self, args=None, namespace=None):
        """ Enables passing bash commands as arguments to the class.
        """
        print("parsing args...")
        if self._uses_custom_command(args):
            print('using custom command')
            print(self.custom_command)
        return super().parse_args(self._resolve_args(args), namespace)

    def parse_known_args(self, args=None, namespace=None):
        """Like ``ArgumentParser.parse_known_args``, but honours ``custom_command``."""
        return super().parse_known_args(self._resolve_args(args), namespace)

    def add_argument(self, *args, **kwargs):
        """ Support for providing a new argument type called "str2bool"
        Example:
        parser.add_argument("--my_option", type=util.str2bool, default=|bool|)
        1. "python train.py" sets my_option to be |bool|
        2. "python train.py --my_option" sets my_option to be True
        3. "python train.py --my_option False" sets my_option to be False
        4. "python train.py --my_option True" sets my_option to be True
        https://stackoverflow.com/a/43357954

        Returns the created ``argparse.Action``, as ``ArgumentParser.add_argument`` does.

        NOTE: arguments added through ``add_argument_group(...)`` or mutually
        exclusive groups go through argparse's own ``add_argument`` and do not
        receive this str2bool handling.
        """
        # Short-circuit on 'type' so util is only consulted for typed arguments.
        if 'type' in kwargs and kwargs['type'] is util.str2bool:
            kwargs.setdefault('nargs', "?")
            kwargs.setdefault('const', True)
        return super().add_argument(*args, **kwargs)