import argparse
import os
import random
import sys
import time
from datetime import datetime

import numpy as np
import torch


class Options(object):
    """Command-line options for training and testing."""
    def __init__(self):
        super(Options, self).__init__()

    def initialize(self):
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        # general options
        parser.add_argument('--mode', type=str, default='train', help='Mode of code. [train|test]')
        parser.add_argument('--model', type=str, default='ganimation', help='[ganimation|stargan], see model.__init__ for more details.')
        parser.add_argument('--lucky_seed', type=int, default=0, help='seed for random initialization; 0 to use current time.')
        parser.add_argument('--visdom_env', type=str, default="main", help='visdom env.')
        parser.add_argument('--visdom_port', type=int, default=8097, help='visdom port.')
        parser.add_argument('--visdom_display_id', type=int, default=1, help='set a value larger than 0 to display with visdom.')
        parser.add_argument('--results', type=str, default="results", help='save test results to this path.')
        parser.add_argument('--interpolate_len', type=int, default=5, help='interpolation length for test.')
        parser.add_argument('--no_test_eval', action='store_true', help='do not use eval mode during test time.')
        parser.add_argument('--save_test_gif', action='store_true', help='save gif images instead of the concatenation of static images.')
        # data options
        parser.add_argument('--data_root', required=False, help='path to the data set.')
        parser.add_argument('--imgs_dir', type=str, default="imgs", help='directory of images.')
        parser.add_argument('--aus_pkl', type=str, default="aus_openface.pkl", help='AUs pickle dictionary.')
        parser.add_argument('--train_csv', type=str, default="train_ids.csv", help='train image paths.')
        parser.add_argument('--test_csv', type=str, default="test_ids.csv", help='test image paths.')
        parser.add_argument('--batch_size', type=int, default=25, help='input batch size.')
        parser.add_argument('--serial_batches', action='store_true', help='if specified, input images in order.')
        parser.add_argument('--n_threads', type=int, default=6, help='number of workers to load data.')
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='maximum number of samples.')
        parser.add_argument('--resize_or_crop', type=str, default='none', help='image preprocessing [resize_and_crop|crop|none]')
        parser.add_argument('--load_size', type=int, default=148, help='scale image to this size.')
        parser.add_argument('--final_size', type=int, default=128, help='crop image to this size.')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip images.')
        parser.add_argument('--no_aus_noise', action='store_true', help='if specified, do not add noise to target AUs.')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0,1,2; -1 for cpu.')
        parser.add_argument('--ckpt_dir', type=str, default='./ckpts', help='directory to save checkpoints.')
        parser.add_argument('--load_epoch', type=int, default=0, help='epoch to load; 0: do not load.')
        parser.add_argument('--log_file', type=str, default="logs.txt", help='loss log file.')
        parser.add_argument('--opt_file', type=str, default="opt.txt", help='options file.')

        # train options
        parser.add_argument('--img_nc', type=int, default=3, help='number of image channels.')
        parser.add_argument('--aus_nc', type=int, default=17, help='number of AU channels.')
        parser.add_argument('--ngf', type=int, default=64, help='base number of generator filters (ngf).')
        parser.add_argument('--ndf', type=int, default=64, help='base number of discriminator filters (ndf).')
        parser.add_argument('--use_dropout', action='store_true', help='if specified, use dropout.')
        parser.add_argument('--gan_type', type=str, default='wgan-gp', help='GAN loss [wgan-gp|lsgan|gan]')
        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [batch|instance|none]')
        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam.')
        parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam.')
        parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy [lambda|step|plateau|cosine]')
        parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations.')
        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count; models are saved at epoch <epoch_count>, <epoch_count>+<save_epoch_freq>, ...')
        parser.add_argument('--niter', type=int, default=20, help='# of iterations at the starting learning rate.')
        parser.add_argument('--niter_decay', type=int, default=10, help='# of iterations to linearly decay the learning rate to zero.')

        # loss options
        parser.add_argument('--lambda_dis', type=float, default=1.0, help='discriminator weight in loss.')
        parser.add_argument('--lambda_aus', type=float, default=160.0, help='AUs weight in loss.')
        parser.add_argument('--lambda_rec', type=float, default=10.0, help='reconstruction loss weight.')
        parser.add_argument('--lambda_mask', type=float, default=0, help='mask loss weight.')
        parser.add_argument('--lambda_tv', type=float, default=0, help='total variation loss weight.')
        parser.add_argument('--lambda_wgan_gp', type=float, default=10., help='wgan gradient penalty weight.')

        # frequency options
        parser.add_argument('--train_gen_iter', type=int, default=5, help='train G every n iterations.')
        parser.add_argument('--print_losses_freq', type=int, default=100, help='print losses every print_losses_freq steps.')
        parser.add_argument('--plot_losses_freq', type=int, default=20000, help='plot losses every plot_losses_freq steps.')
        parser.add_argument('--sample_img_freq', type=int, default=2000, help='draw sample images every sample_img_freq steps.')
        parser.add_argument('--save_epoch_freq', type=int, default=2, help='save a checkpoint every save_epoch_freq epochs.')

        return parser

    def parse(self):
        parser = self.initialize()
        parser.set_defaults(name=datetime.now().strftime("%y%m%d_%H%M%S"))
        opt = parser.parse_args()

        dataset_name = os.path.basename(opt.data_root.strip('/'))

        # update checkpoint dir
        if opt.mode == 'train' and opt.load_epoch == 0:
            opt.ckpt_dir = os.path.join(opt.ckpt_dir, dataset_name, opt.model, opt.name)
            if not os.path.exists(opt.ckpt_dir):
                os.makedirs(opt.ckpt_dir)

        # if test, disable visdom and update results path
        if opt.mode == "test":
            opt.visdom_display_id = 0
            opt.results = os.path.join(opt.results, "%s_%s_%s" % (dataset_name, opt.model, opt.load_epoch))
            if not os.path.exists(opt.results):
                os.makedirs(opt.results)

        # set gpu device
        str_ids = opt.gpu_ids.split(',')
        opt.gpu_ids = []
        for str_id in str_ids:
            cur_id = int(str_id)
            if cur_id >= 0:
                opt.gpu_ids.append(cur_id)
        if len(opt.gpu_ids) > 0:
            torch.cuda.set_device(opt.gpu_ids[0])

        # set seed
        if opt.lucky_seed == 0:
            opt.lucky_seed = int(time.time())
        random.seed(a=opt.lucky_seed)
        np.random.seed(seed=opt.lucky_seed)
        torch.manual_seed(opt.lucky_seed)
        if len(opt.gpu_ids) > 0:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            torch.cuda.manual_seed(opt.lucky_seed)
            torch.cuda.manual_seed_all(opt.lucky_seed)

        # write command to file
        script_dir = opt.ckpt_dir
        with open(os.path.join(script_dir, "run_script.sh"), 'a+') as f:
            f.write("[%5s][%s]python %s\n" % (opt.mode, opt.name, ' '.join(sys.argv)))

        # print options and write them to file
        msg = ''
        msg += '------------------- [%5s][%s]Options --------------------\n' % (opt.mode, opt.name)
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default_v = parser.get_default(k)
            if v != default_v:
                comment = '\t[default: %s]' % str(default_v)
            msg += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        msg += '--------------------- [%5s][%s]End ----------------------\n' % (opt.mode, opt.name)
        print(msg)
        with open(os.path.join(script_dir, "opt.txt"), 'a+') as f:
            f.write(msg + '\n\n')

        return opt
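

if __name__ == '__main__':
    # Minimal usage sketch (an addition for illustration, not part of the original
    # training pipeline): parse options directly to sanity-check the CLI.
    # Note that parse() derives the dataset name from --data_root and touches CUDA
    # when --gpu_ids is not -1, so a CPU-only dry run would look like, e.g.
    # (assuming this file is saved as options.py; the dataset path is hypothetical):
    #   python options.py --data_root datasets/emotionNet --mode train --gpu_ids -1
    opt = Options().parse()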