File size: 9,250 Bytes
1b2a9b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import argparse
import shlex
import os
import pickle

import swapae.util as util
import swapae.models as models
import swapae.models.networks as networks
import swapae.data as data
import swapae.evaluation as evaluation
import swapae.optimizers as optimizers
from swapae.util import IterationCounter
from swapae.util import Visualizer


class BaseOptions():
    """Command-line options shared by training and testing.

    This class never sets ``self.isTrain`` itself; subclasses such as
    ``TrainOptions`` / ``TestOptions`` do so in ``__init__``. ``gather_options``
    passes it to every component option setter, and ``parse`` uses it to decide
    whether the options are saved to disk.
    """

    def initialize(self, parser):
        """Register the basic options on ``parser`` and return it."""
        # experiment specifics
        parser.add_argument('--name', type=str, default="ffhq512_pretrained", help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--easy_label', type=str, default="")

        parser.add_argument('--num_gpus', type=int, default=1, help='#GPUs to use. 0 means CPU mode')
        # NOTE(review): machine-specific absolute default path; most users will need to override it.
        parser.add_argument('--checkpoints_dir', type=str, default='/home/xtli/Documents/GITHUB/swapping-autoencoder-pytorch/checkpoints/', help='models are saved here')
        parser.add_argument('--model', type=str, default='swapping_autoencoder', help='which model to use')
        parser.add_argument('--optimizer', type=str, default='swapping_autoencoder', help='which optimizer to use')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        parser.add_argument('--resume_iter', type=str, default="latest",
                            help="# iterations (in thousands) to resume")
        parser.add_argument('--num_classes', type=int, default=0)

        # input/output sizes
        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
        parser.add_argument('--preprocess', type=str, default='resize', help='scaling and cropping of images at load time.')
        parser.add_argument('--load_size', type=int, default=512, help='Scale images to this size. The final image will be cropped to --crop_size.')
        parser.add_argument('--crop_size', type=int, default=512, help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
        parser.add_argument('--preprocess_crop_padding', type=int, default=None, help='padding parameter of transforms.RandomCrop(). It is not used if --preprocess does not contain crop option.')
        parser.add_argument('--no_flip', action='store_true')
        parser.add_argument('--shuffle_dataset', type=str, default=None, choices=('true', 'false'))

        # for setting inputs
        # NOTE(review): machine-specific absolute default path; most users will need to override it.
        parser.add_argument('--dataroot', type=str, default="/home/xtli/Dropbox/swapping-autoencoder-pytorch/testphotos/ffhq512/fig9/")
        parser.add_argument('--dataset_mode', type=str, default='imagefolder')
        parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data')

        # networks
        parser.add_argument("--netG", default="StyleGAN2Resnet")
        parser.add_argument("--netD", default="StyleGAN2")
        parser.add_argument("--netE", default="StyleGAN2Resnet")
        parser.add_argument("--netPatchD", default="StyleGAN2")
        parser.add_argument("--use_antialias", type=util.str2bool, default=True)

        parser.add_argument("-f", "--config_file", type=str, default='models/swap/json/sem_cons.json', help='json files including all arguments')
        parser.add_argument("--local_rank", type=int)

        return parser

    def gather_options(self, command=None):
        """Build the full parser and parse the command line.

        Parsing happens in two stages. A lenient first pass over the basic
        options discovers which model, optimizer and dataset were selected.
        Each of those components (plus networks, visualizer, iteration counter
        and evaluation) then registers its own options before a final strict
        parse.

        Args:
            command: optional full command string to parse instead of
                ``sys.argv`` (handled by ``AugmentedArgumentParser``).

        Returns:
            The parsed ``argparse.Namespace``. The fully populated parser is
            kept on ``self.parser`` for default lookups.
        """
        parser = AugmentedArgumentParser()
        parser.custom_command = command

        # get basic options
        parser = self.initialize(parser)

        # Lenient first pass: flags owned by components that have not
        # registered yet must not cause an error; only the component names
        # are needed here.
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_option_setter = models.get_option_setter(opt.model)
        parser = model_option_setter(parser, self.isTrain)

        # modify network-related parser options
        parser = networks.modify_commandline_options(parser, self.isTrain)

        # modify optimizer-related parser options
        optimizer_option_setter = optimizers.get_option_setter(opt.optimizer)
        parser = optimizer_option_setter(parser, self.isTrain)

        # modify dataset-related parser options
        dataset_option_setter = data.get_option_setter(opt.dataset_mode)
        parser = dataset_option_setter(parser, self.isTrain)

        # modify parser options related to visualization
        parser = Visualizer.modify_commandline_options(parser, self.isTrain)

        # modify parser options related to iteration_counting
        parser = IterationCounter.modify_commandline_options(parser, self.isTrain)

        # modify evaluation-related parser options
        evaluation_option_setter = evaluation.get_option_setter()
        parser = evaluation_option_setter(parser, self.isTrain)

        # Final strict parse: every component has registered its flags, so
        # unrecognized arguments are now reported as errors.
        opt = parser.parse_args()
        self.parser = parser
        return opt

    def _format_option_lines(self, opt):
        """Return one newline-terminated line per option, sorted by name.

        Values that differ from ``self.parser``'s default get a
        ``[default: ...]`` suffix.
        """
        lines = []
        for k, v in sorted(vars(opt).items()):
            default = self.parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
        return lines

    def print_options(self, opt):
        """Print all options, showing the default next to any changed value.

        This method only prints; ``save_options`` writes the options to disk.
        """
        message = '----------------- Options ---------------\n'
        message += ''.join(self._format_option_lines(opt))
        message += '----------------- End -------------------'
        print(message)

    def option_file_path(self, opt, makedir=False):
        """Return ``<checkpoints_dir>/<name>/opt`` (no extension).

        If ``makedir`` is true, the experiment directory is created first.
        """
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        if makedir:
            util.mkdirs(expr_dir)
        return os.path.join(expr_dir, 'opt')

    def save_options(self, opt):
        """Save options as human-readable ``opt.txt`` and pickled ``opt.pkl``."""
        file_name = self.option_file_path(opt, makedir=True)
        with open(file_name + '.txt', 'wt', encoding='utf-8') as opt_file:
            opt_file.write(''.join(self._format_option_lines(opt)))

        # NOTE: opt.pkl is a pickle; only unpickle it from trusted sources.
        with open(file_name + '.pkl', 'wb') as opt_file:
            pickle.dump(opt, opt_file)

    def parse(self, save=False, command=None):
        """Gather, print and (in training mode) save the options.

        Args:
            save: currently unused. Options are saved whenever ``isTrain`` is
                true. The parameter is kept for backward compatibility.
            command: optional command string forwarded to ``gather_options``.

        Returns:
            The parsed options, with ``isTrain`` set and ``dataroot``
            user-expanded.
        """
        opt = self.gather_options(command)
        opt.isTrain = self.isTrain   # train or test
        self.print_options(opt)
        if opt.isTrain:
            self.save_options(opt)

        opt.dataroot = os.path.expanduser(opt.dataroot)

        assert opt.num_gpus <= opt.batch_size, "Batch size must not be smaller than num_gpus"
        return opt



class TrainOptions(BaseOptions):
    """Training-mode options: the base options plus resume/warm-start flags."""

    def __init__(self):
        super().__init__()
        # Consumed by BaseOptions: passed to the component option setters and
        # used by parse() to decide whether the options are saved.
        self.isTrain = True

    def initialize(self, parser):
        """Add the training-only flags on top of the base options."""
        parser = super().initialize(parser)
        parser.add_argument(
            '--continue_train',
            type=util.str2bool,
            default=False,
            help="resume training from last checkpoint",
        )
        parser.add_argument(
            '--pretrained_name',
            type=str,
            default=None,
            help="Load weights from the checkpoint of another experiment",
        )
        return parser


class TestOptions(BaseOptions):
    """Test-mode options: the base options plus an output directory."""

    def __init__(self):
        super().__init__()
        # Consumed by BaseOptions; False means parse() will not save options.
        self.isTrain = False

    def initialize(self, parser):
        """Add the test-only flags on top of the base options."""
        parser = super().initialize(parser)
        parser.add_argument(
            "--result_dir",
            type=str,
            default="results",
        )
        return parser


class AugmentedArgumentParser(argparse.ArgumentParser):
    """ArgumentParser with two extensions.

    1. If ``custom_command`` is set to a full command string, it is parsed
       instead of ``sys.argv`` whenever no explicit ``args`` are given.
    2. Arguments declared with ``type=util.str2bool`` may be passed as a bare
       flag (see ``add_argument``).
    """

    def _uses_custom_command(self, args):
        """True when no explicit args were given and ``custom_command`` is set."""
        return args is None and getattr(self, 'custom_command', None) is not None

    def _custom_command_args(self):
        """Split ``custom_command`` into an argument list.

        The first two tokens are dropped. They are presumably the interpreter
        and the script path (e.g. ``python train.py ...``).
        """
        return shlex.split(self.custom_command)[2:]

    def parse_args(self, args=None, namespace=None):
        """ Enables passing bash commands as arguments to the class.
        """
        print("parsing args...")
        if self._uses_custom_command(args):
            print('using custom command')
            print(self.custom_command)
            args = self._custom_command_args()
        return super().parse_args(args, namespace)

    def parse_known_args(self, args=None, namespace=None):
        """Like ``ArgumentParser.parse_known_args`` but honours ``custom_command``."""
        if self._uses_custom_command(args):
            args = self._custom_command_args()
        return super().parse_known_args(args, namespace)

    def add_argument(self, *args, **kwargs):
        """ Support for providing a new argument type called "str2bool"

        Example:
        parser.add_argument("--my_option", type=util.str2bool, default=|bool|)

        1. "python train.py" sets my_option to be |bool|
        2. "python train.py --my_option" sets my_option to be True
        3. "python train.py --my_option False" sets my_option to be False
        4. "python train.py --my_option True" sets my_option to be True

        https://stackoverflow.com/a/43357954

        Returns the created ``argparse.Action``, like the base class does.
        """
        # Short-circuit on 'type' so util is only consulted when a type is given.
        if 'type' in kwargs and kwargs['type'] is util.str2bool:
            kwargs.setdefault('nargs', "?")   # value is optional
            kwargs.setdefault('const', True)  # bare flag means True
        return super().add_argument(*args, **kwargs)