File size: 5,499 Bytes
2d5f249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import argparse


class TrainOptions:
    """Command-line options for training.

    The constructor only builds ``self.parser`` (an ``argparse.ArgumentParser``)
    with four argument groups: General, io, Training Options and Misc Options.
    Call :meth:`parse_args` to actually parse arguments; the resulting
    namespace is stored on ``self.args`` and returned.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()

        gen = self.parser.add_argument_group('General')
        gen.add_argument(
            '--resume',
            dest='resume',
            default=False,
            action='store_true',
            help='Resume from checkpoint (Use latest checkpoint by default)')

        io = self.parser.add_argument_group('io')
        io.add_argument('--log_dir',
                        default='logs',
                        help='Directory to store logs')
        io.add_argument(
            '--pretrained_checkpoint',
            default=None,
            help='Load a pretrained checkpoint at the beginning of training')

        train = self.parser.add_argument_group('Training Options')
        train.add_argument('--num_epochs',
                           type=int,
                           default=200,
                           help='Total number of training epochs')
        train.add_argument('--regressor',
                           type=str,
                           choices=['hmr', 'pymaf_net'],
                           default='pymaf_net',
                           help='Name of the SMPL regressor.')
        train.add_argument('--cfg_file',
                           type=str,
                           default='./configs/pymaf_config.yaml',
                           help='config file path for PyMAF.')
        train.add_argument(
            '--img_res',
            type=int,
            default=224,
            help='Rescale bounding boxes to size [img_res, img_res] before feeding them in the network'
        )
        train.add_argument(
            '--rot_factor',
            type=float,
            default=30,
            help='Random rotation in the range [-rot_factor, rot_factor]')
        train.add_argument(
            '--noise_factor',
            type=float,
            default=0.4,
            help='Randomly multiply pixel values with factor in the range [1-noise_factor, 1+noise_factor]'
        )
        train.add_argument(
            '--scale_factor',
            type=float,
            default=0.25,
            help='Rescale bounding boxes by a factor of [1-scale_factor,1+scale_factor]'
        )
        # type=float is required here: without it a value passed on the
        # command line would be kept as a str, while the default is a float.
        train.add_argument(
            '--openpose_train_weight',
            type=float,
            default=0.,
            help='Weight for OpenPose keypoints during training')
        train.add_argument('--gt_train_weight',
                           type=float,
                           default=1.,
                           help='Weight for GT keypoints during training')
        train.add_argument('--eval_dataset',
                           type=str,
                           default='h36m-p2-mosh',
                           help='Name of the evaluation dataset.')
        train.add_argument('--single_dataset',
                           default=False,
                           action='store_true',
                           help='Use a single dataset')
        train.add_argument('--single_dataname',
                           type=str,
                           default='h36m',
                           help='Name of the single dataset.')
        train.add_argument('--eval_pve',
                           default=False,
                           action='store_true',
                           help='evaluate PVE')
        train.add_argument('--overwrite',
                           default=False,
                           action='store_true',
                           help='overwrite the latest checkpoint')

        train.add_argument('--distributed',
                           action='store_true',
                           help='Use distributed training')
        train.add_argument('--dist_backend',
                           default='nccl',
                           type=str,
                           help='distributed backend')
        train.add_argument('--dist_url',
                           default='tcp://127.0.0.1:10356',
                           type=str,
                           help='url used to set up distributed training')
        train.add_argument('--world_size',
                           default=1,
                           type=int,
                           help='number of nodes for distributed training')
        # NOTE(review): presumably filled in by a distributed launcher; confirm.
        train.add_argument("--local_rank", default=0, type=int)
        train.add_argument('--rank',
                           default=0,
                           type=int,
                           help='node rank for distributed training')
        train.add_argument(
            '--multiprocessing_distributed',
            action='store_true',
            help='Use multi-processing distributed training to launch '
            'N processes per node, which has N GPUs. This is the '
            'fastest way to use PyTorch for either single node or '
            'multi node data parallel training')

        misc = self.parser.add_argument_group('Misc Options')
        # nargs=REMAINDER: every token after --misc is collected verbatim
        # into a list, so --misc must come last on the command line.
        misc.add_argument('--misc',
                          help="Modify config options using the command-line",
                          default=None,
                          nargs=argparse.REMAINDER)

    def parse_args(self, args=None):
        """Parse input arguments.

        Args:
            args: Optional sequence of argument strings. When ``None``
                (the default), ``sys.argv[1:]`` is parsed, as before.

        Returns:
            The parsed ``argparse.Namespace``, which is also stored on
            ``self.args``.
        """
        self.args = self.parser.parse_args(args)
        self.save_dump()
        return self.args

    def save_dump(self):
        """Hook for persisting the parsed arguments.

        Currently a no-op placeholder: nothing is written to disk.
        TODO: dumping ``self.args`` (e.g. to JSON under ``log_dir``) was
        apparently intended; implement or remove.
        """
        pass