ilyassmoummad
committed on
Create default.py
Browse files- default.py +202 -0
default.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import print_function
|
4 |
+
|
5 |
+
import os.path as op
|
6 |
+
import yaml
|
7 |
+
from yacs.config import CfgNode as CN
|
8 |
+
|
9 |
+
from config import comm
|
10 |
+
|
11 |
+
|
12 |
+
# Default configuration tree. Every key declared here can be overridden by a
# YAML file and by command-line opts (see update_config below); nodes created
# with new_allowed=True additionally accept keys that are not declared here.
_C = CN()

# Parent config files, given relative to the YAML file that lists them. They
# are merged before the listing file itself; empty strings are skipped.
_C.BASE = ['']
# update_config prefixes this with the base name of the config file.
_C.NAME = ''
_C.DATA_DIR = ''
_C.DIST_BACKEND = 'nccl'
_C.GPUS = (0,)
# _C.LOG_DIR = ''
_C.MULTIPROCESSING_DISTRIBUTED = True
_C.OUTPUT_DIR = ''
_C.PIN_MEMORY = True
_C.PRINT_FREQ = 20
# Overwritten with comm.rank by update_config.
_C.RANK = 0
_C.VERBOSE = True
_C.WORKERS = 4
_C.MODEL_SUMMARY = False

_C.AMP = CN()
_C.AMP.ENABLED = False
_C.AMP.MEMORY_FORMAT = 'nchw'

# Cudnn related params
_C.CUDNN = CN()
_C.CUDNN.BENCHMARK = True
_C.CUDNN.DETERMINISTIC = False
_C.CUDNN.ENABLED = True

# common params for NETWORK
_C.MODEL = CN()
_C.MODEL.NAME = 'cls_hrnet'
_C.MODEL.INIT_WEIGHTS = True
_C.MODEL.PRETRAINED = ''
_C.MODEL.PRETRAINED_LAYERS = ['*']
_C.MODEL.NUM_CLASSES = 1000
# Free-form node: model-specific keys may be added from YAML.
_C.MODEL.SPEC = CN(new_allowed=True)

_C.LOSS = CN(new_allowed=True)
_C.LOSS.LABEL_SMOOTHING = 0.0
_C.LOSS.LOSS = 'softmax'

# DATASET related params
_C.DATASET = CN()
_C.DATASET.ROOT = ''
_C.DATASET.DATASET = 'imagenet'
_C.DATASET.TRAIN_SET = 'train'
_C.DATASET.TEST_SET = 'val'
_C.DATASET.DATA_FORMAT = 'jpg'
_C.DATASET.LABELMAP = ''
_C.DATASET.TRAIN_TSV_LIST = []
_C.DATASET.TEST_TSV_LIST = []
_C.DATASET.SAMPLER = 'default'

_C.DATASET.TARGET_SIZE = -1

# input statistics
# NOTE(review): the original comment here read "training data augmentation";
# these look like per-channel normalization mean/std -- confirm against the
# data pipeline.
_C.INPUT = CN()
_C.INPUT.MEAN = [0.485, 0.456, 0.406]
_C.INPUT.STD = [0.229, 0.224, 0.225]

# data augmentation
_C.AUG = CN()
_C.AUG.SCALE = (0.08, 1.0)
_C.AUG.RATIO = (3.0/4.0, 4.0/3.0)
_C.AUG.COLOR_JITTER = [0.4, 0.4, 0.4, 0.1, 0.0]
_C.AUG.GRAY_SCALE = 0.0
_C.AUG.GAUSSIAN_BLUR = 0.0
_C.AUG.DROPBLOCK_LAYERS = [3, 4]
_C.AUG.DROPBLOCK_KEEP_PROB = 1.0
_C.AUG.DROPBLOCK_BLOCK_SIZE = 7
# update_config forces MIXUP_PROB to 1.0 whenever MIXUP > 0, MIXCUT > 0 or
# MIXCUT_MINMAX is non-empty.
_C.AUG.MIXUP_PROB = 0.0
_C.AUG.MIXUP = 0.0
_C.AUG.MIXCUT = 0.0
_C.AUG.MIXCUT_MINMAX = []
_C.AUG.MIXUP_SWITCH_PROB = 0.5
_C.AUG.MIXUP_MODE = 'batch'
_C.AUG.MIXCUT_AND_MIXUP = False
_C.AUG.INTERPOLATION = 2
_C.AUG.TIMM_AUG = CN(new_allowed=True)
_C.AUG.TIMM_AUG.USE_LOADER = False
_C.AUG.TIMM_AUG.USE_TRANSFORM = False

# train
_C.TRAIN = CN()

_C.TRAIN.AUTO_RESUME = True
_C.TRAIN.CHECKPOINT = ''
# Free-form node with no defaults. update_config reads LR_SCHEDULER.METHOD
# and, when it is 'timm', sets LR_SCHEDULER.ARGS.epochs to TRAIN.END_EPOCH.
_C.TRAIN.LR_SCHEDULER = CN(new_allowed=True)
# When SCALE_LR is true, update_config multiplies LR by comm.world_size.
_C.TRAIN.SCALE_LR = True
_C.TRAIN.LR = 0.001

_C.TRAIN.OPTIMIZER = 'sgd'
# When OPTIMIZER is 'timm', update_config copies TRAIN.LR into
# OPTIMIZER_ARGS.lr.
_C.TRAIN.OPTIMIZER_ARGS = CN(new_allowed=True)
_C.TRAIN.MOMENTUM = 0.9
_C.TRAIN.WD = 0.0001
_C.TRAIN.WITHOUT_WD_LIST = []
_C.TRAIN.NESTEROV = True
# for adam
_C.TRAIN.GAMMA1 = 0.99
_C.TRAIN.GAMMA2 = 0.0

_C.TRAIN.BEGIN_EPOCH = 0
_C.TRAIN.END_EPOCH = 100

_C.TRAIN.IMAGE_SIZE = [224, 224]  # width * height, ex: 192 * 256
_C.TRAIN.BATCH_SIZE_PER_GPU = 32
_C.TRAIN.SHUFFLE = True

_C.TRAIN.EVAL_BEGIN_EPOCH = 0

_C.TRAIN.DETECT_ANOMALY = False

_C.TRAIN.CLIP_GRAD_NORM = 0.0
_C.TRAIN.SAVE_ALL_MODELS = False

# testing
_C.TEST = CN()

# NOTE(review): the original comment here read "size of images for each
# device"; the key name suggests a per-device batch size -- confirm.
_C.TEST.BATCH_SIZE_PER_GPU = 32
_C.TEST.CENTER_CROP = True
_C.TEST.IMAGE_SIZE = [224, 224]  # width * height, ex: 192 * 256
_C.TEST.INTERPOLATION = 2
_C.TEST.MODEL_FILE = ''
_C.TEST.REAL_LABELS = False
_C.TEST.VALID_LABELS = ''

_C.FINETUNE = CN()
_C.FINETUNE.FINETUNE = False
_C.FINETUNE.USE_TRAIN_AUG = False
_C.FINETUNE.BASE_LR = 0.003
_C.FINETUNE.BATCH_SIZE = 512
_C.FINETUNE.EVAL_EVERY = 3000
_C.FINETUNE.TRAIN_MODE = True
# _C.FINETUNE.MODEL_FILE = ''
_C.FINETUNE.FROZEN_LAYERS = []
_C.FINETUNE.LR_SCHEDULER = CN(new_allowed=True)
_C.FINETUNE.LR_SCHEDULER.DECAY_TYPE = 'step'

# debug
_C.DEBUG = CN()
_C.DEBUG.DEBUG = False
|
153 |
+
|
154 |
+
|
155 |
+
def _update_config_from_file(config, cfg_file):
    """Merge ``cfg_file`` into ``config``, parent configs first.

    The YAML file may list parent configs under the top-level ``BASE`` key,
    as paths relative to ``cfg_file``. Each non-empty entry is merged
    recursively before ``cfg_file`` itself, so values in the child file
    override values inherited from its parents.

    Args:
        config: yacs ``CfgNode`` to update in place. It is frozen on return.
        cfg_file: Path of the YAML config file to merge.
    """
    with open(cfg_file, 'r') as f:
        # NOTE(security): yaml.load with FullLoader should only be pointed at
        # trusted files; yaml.safe_load is the stricter alternative.
        # An empty YAML document loads as None, so fall back to an empty
        # mapping instead of failing on the BASE lookup below.
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) or {}

    # `or []` also covers an explicit `BASE: null` entry.
    for base_cfg in yaml_cfg.get('BASE') or []:
        if base_cfg:
            _update_config_from_file(
                config, op.join(op.dirname(cfg_file), base_cfg)
            )

    print('=> merge config from {}'.format(cfg_file))
    # Each recursive call above leaves `config` frozen, so unfreeze right
    # before merging rather than once at the top of the function.
    config.defrost()
    config.merge_from_file(cfg_file)
    config.freeze()
|
168 |
+
|
169 |
+
|
170 |
+
def update_config(config, args):
    """Build the final run configuration from a YAML file and CLI overrides.

    Steps, in order:
      1. merge ``args.cfg`` (and any parent configs it lists) into ``config``;
      2. apply the ``args.opts`` key/value overrides;
      3. multiply ``TRAIN.LR`` by ``comm.world_size`` when ``TRAIN.SCALE_LR``
         is true;
      4. prefix ``NAME`` with the config file's base name and set ``RANK``
         from ``comm.rank``;
      5. propagate epochs / learning rate into the timm scheduler / optimizer
         argument nodes when those are selected;
      6. force ``AUG.MIXUP_PROB`` to 1.0 when any mixup/cutmix option is set.

    Args:
        config: yacs ``CfgNode`` to update in place. It is frozen on return.
        args: Object exposing ``cfg`` (path to the YAML file) and ``opts``
            (flat ``[key, value, ...]`` override list; may be empty or None).
    """
    _update_config_from_file(config, args.cfg)

    config.defrost()
    # Skip the merge when no overrides were given; this also tolerates
    # `opts` being None rather than an empty list.
    if args.opts:
        config.merge_from_list(args.opts)
    if config.TRAIN.SCALE_LR:
        config.TRAIN.LR *= comm.world_size
    file_name, _ = op.splitext(op.basename(args.cfg))
    config.NAME = file_name + config.NAME
    config.RANK = comm.rank

    # LR_SCHEDULER is a new_allowed node with no default METHOD, so attribute
    # access would raise when the YAML does not define it; use .get instead.
    if config.TRAIN.LR_SCHEDULER.get('METHOD') == 'timm':
        config.TRAIN.LR_SCHEDULER.ARGS.epochs = config.TRAIN.END_EPOCH

    if config.TRAIN.OPTIMIZER == 'timm':
        config.TRAIN.OPTIMIZER_ARGS.lr = config.TRAIN.LR

    aug = config.AUG
    if aug.MIXUP > 0.0 or aug.MIXCUT > 0.0 or aug.MIXCUT_MINMAX:
        aug.MIXUP_PROB = 1.0
    config.freeze()
|
191 |
+
|
192 |
+
|
193 |
+
def save_config(cfg, path):
    """Write ``cfg.dump()`` to ``path``, on the main process only.

    Args:
        cfg: Config object providing ``dump()``.
        path: Destination file path; opened in text write mode.
    """
    if not comm.is_main_process():
        return
    with open(path, 'w') as out:
        out.write(cfg.dump())
|
197 |
+
|
198 |
+
|
199 |
+
if __name__ == '__main__':
    # Script entry: print the default config tree to the file named by the
    # first command-line argument.
    import sys

    out_path = sys.argv[1]
    with open(out_path, 'w') as out:
        print(_C, file=out)
|