import os
import json
import argparse


class BaseOptions():
    """Command-line / JSON option handling for the segmentation experiments.

    Options are first parsed from the command line (see ``initialize``),
    then overridden by the JSON file named by ``--config_file``, and finally
    printed and saved to ``<out_dir>/exp_args.txt``.
    """

    def initialize(self, parser):
        """Register every command-line option on ``parser`` and return it.

        Any value registered here may later be overridden by a key of the
        same name in the JSON file given by ``--config_file``.
        """
        parser.add_argument('--nChannel', metavar='N', default=100, type=int, help='number of channels')
        parser.add_argument('--maxIter', metavar='T', default=1000, type=int, help='number of maximum iterations')
        parser.add_argument('--lr', metavar='LR', default=0.1, type=float, help='learning rate')
        parser.add_argument('--nConv', metavar='M', default=2, type=int, help='number of convolutional layers')
        parser.add_argument("--work_dir", type=str, default="./", help='project directory')
        parser.add_argument("--out_dir", type=str, default=None, help='logging output')
        parser.add_argument("--use_wandb", type=int, default=0, help='use wandb or not')
        parser.add_argument("--data_path", type=str, default="images", help="data path")
        parser.add_argument("--img_path", type=str, default=None, help="image path")
        parser.add_argument('--crop_size', type=int, default=224, help='crop_size')
        parser.add_argument("--batch_size", type=int, default=1, help='batch size')
        parser.add_argument('--workers', type=int, default=4, help='number of data loading workers')
        parser.add_argument("--use_slic", default=1, type=int, help="choose to use slic or gt label")
        parser.add_argument("-f", "--config_file", type=str, default='models/week0417/json/single_scale_grouping_ft.json', help='json files including all arguments')
        parser.add_argument("--log_freq", type=int, default=10, help='frequency to print log')
        parser.add_argument("--display_freq", type=int, default=100, help='frequency to save visualization')
        # NOTE(review): machine-specific absolute default path; likely only valid on the author's host.
        parser.add_argument("--pretrained_ae", type=str, default="/home/xli/WORKDIR/07-16/transformer/cpk.pth")
        parser.add_argument("--pretrained_path", type=str, default=None, help='pretrained reconstruction model')
        parser.add_argument('--momentum', type=float, default=0.5, help='momentum for sgd, alpha parameter for adam')
        parser.add_argument('--beta', type=float, default=0.999, help='beta parameter for adam')
        parser.add_argument("--l1_loss_wt", default=1.0, type=float)
        parser.add_argument("--perceptual_loss_wt", default=1.0, type=float)
        parser.add_argument('--project_name', type=str, default='test_time', help='project name')
        parser.add_argument("--save_freq", type=int, default=2000, help='frequency to save model')
        parser.add_argument("--local_rank", type=int)
        parser.add_argument('--lr_decay_freq', type=int, default=3000, help='frequency to decay learning rate')
        parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
        parser.add_argument('--sp_num', type=int, default=None, help='superpixel number')
        parser.add_argument('--add_self_loops', type=int, default=1, help='set to 1 to add self loops in GCNs')
        # Help text previously duplicated the --add_self_loops description.
        parser.add_argument('--test_time', type=int, default=0, help='set to 1 to enable test-time mode')
        parser.add_argument('--add_texture_epoch', type=int, default=1000, help='when to add texture synthesis')
        parser.add_argument('--add_clustering_epoch', type=int, default=1000, help='when to add grouping')
        # NOTE(review): parsed as int, so fractional temperatures (e.g. 0.5) are rejected on the CLI.
        parser.add_argument('--temperature', type=int, default=1, help='temperature in SoftMax')
        parser.add_argument('--gumbel', type=int, default=0, help='if use gumbel SoftMax')
        parser.add_argument('--patch_size', type=int, default=40, help='patch size in texture synthesis')
        parser.add_argument("--netE_num_downsampling_sp", default=4, type=int)
        parser.add_argument('--num_classes', type=int, default=0)
        parser.add_argument(
            "--netG_num_base_resnet_layers", default=2, type=int,
            help="The number of resnet layers before the upsampling layers."
        )
        parser.add_argument("--netG_scale_capacity", default=1.0, type=float)
        parser.add_argument("--netG_resnet_ch", type=int, default=256)
        parser.add_argument("--spatial_code_ch", default=8, type=int)
        parser.add_argument("--texture_code_ch", default=256, type=int)
        parser.add_argument("--netE_scale_capacity", default=1.0, type=float)
        parser.add_argument("--netE_num_downsampling_gl", default=2, type=int)
        parser.add_argument("--netE_nc_steepness", default=2.0, type=float)
        parser.add_argument("--spatial_code_dim", type=int, default=256, help="codebook entry dimension")
        return parser

    def _format_options(self, opt):
        """Return one formatted line per option, sorted by name.

        Each line ends with a newline; options whose value differs from the
        parser default are annotated with ``[default: ...]``. Requires
        ``self.parser`` to have been set (done in ``gather_options``).
        """
        lines = []
        for k, v in sorted(vars(opt).items()):
            default = self.parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
        return lines

    def print_options(self, opt):
        """Print all options to stdout.

        Values that differ from the parser defaults are followed by the
        default value. This method only prints; ``save_options`` writes the
        same listing to disk.
        """
        message = '----------------- Options ---------------\n'
        message += ''.join(self._format_options(opt))
        message += '----------------- End -------------------'
        print(message)

    def save_options(self, opt):
        """Write all options to ``<opt.out_dir>/exp_args.txt``.

        Creates ``opt.out_dir`` if needed and overwrites any existing file.
        """
        os.makedirs(opt.out_dir, exist_ok=True)
        file_name = os.path.join(opt.out_dir, 'exp_args.txt')
        # The ``with`` block closes the file; no explicit close() is needed.
        with open(file_name, 'wt', encoding='utf-8') as opt_file:
            opt_file.writelines(self._format_options(opt))

    def gather_options(self):
        """Parse CLI options, merge the JSON config, then print and save them.

        ``out_dir`` becomes ``<work_dir>/<exp_name>`` when the JSON config
        provides ``exp_name``; otherwise a user-supplied ``--out_dir`` is
        kept as-is. ``use_slic`` and ``use_wandb`` are converted from 0/1
        ints to booleans.

        Raises:
            ValueError: if neither ``exp_name`` (from the config) nor
                ``--out_dir`` is available to determine the output directory.
        """
        parser = argparse.ArgumentParser(description='PyTorch Unsupervised Segmentation')
        self.parser = self.initialize(parser)
        opt = self.parser.parse_args()
        opt = self.update_with_json(opt)

        # ``exp_name`` is not a CLI option, so it can only come from the JSON file.
        exp_name = getattr(opt, 'exp_name', None)
        if exp_name is not None:
            opt.out_dir = os.path.join(opt.work_dir, exp_name)
        elif opt.out_dir is None:
            raise ValueError(
                "Cannot determine output directory: the config file does not "
                "define 'exp_name' and --out_dir was not given."
            )

        opt.use_slic = (opt.use_slic == 1)
        opt.use_wandb = (opt.use_wandb == 1)

        # logging
        self.print_options(opt)
        self.save_options(opt)
        return opt

    def update_with_json(self, args):
        """Return a new Namespace with ``args`` overridden by the JSON config.

        Keys in the JSON object take precedence over command-line values, and
        new keys (e.g. ``exp_name``) are added. The input ``args`` is not
        modified. An empty/None ``config_file`` leaves the options unchanged.

        Raises:
            TypeError: if the JSON file's top-level value is not an object.
        """
        merged = dict(vars(args))
        if not args.config_file:
            return argparse.Namespace(**merged)
        # arguments house keeping
        with open(args.config_file, 'r', encoding='utf-8') as f:
            file_args = json.load(f)
        if not isinstance(file_args, dict):
            raise TypeError(
                "Config file %s must contain a JSON object, got %s"
                % (args.config_file, type(file_args).__name__)
            )
        merged.update(file_args)
        return argparse.Namespace(**merged)