mckabue committed on
Commit 1c0c18d · verified · 1 Parent(s): 1f55314

RE_UPLOAD-REBUILD-RESTART

model/layout-model-training/tools/train_net.py ADDED
@@ -0,0 +1,229 @@
+"""
+The script is based on https://github.com/facebookresearch/detectron2/blob/master/tools/train_net.py.
+"""
+
+import logging
+import os
+import json
+from collections import OrderedDict
+import detectron2.utils.comm as comm
+import detectron2.data.transforms as T
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.config import get_cfg
+from detectron2.data import DatasetMapper, build_detection_train_loader
+
+from detectron2.data.datasets import register_coco_instances
+
+from detectron2.engine import (
+    DefaultTrainer,
+    default_argument_parser,
+    default_setup,
+    hooks,
+    launch,
+)
+from detectron2.evaluation import (
+    COCOEvaluator,
+    verify_results,
+)
+from detectron2.modeling import GeneralizedRCNNWithTTA
+import pandas as pd
+
+
+def get_augs(cfg):
+    """Add all the desired augmentations here. A list of available augmentations
+    can be found here:
+    https://detectron2.readthedocs.io/en/latest/modules/data_transforms.html
+    """
+    augs = [
+        T.ResizeShortestEdge(
+            cfg.INPUT.MIN_SIZE_TRAIN,
+            cfg.INPUT.MAX_SIZE_TRAIN,
+            cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
+        )
+    ]
+    if cfg.INPUT.CROP.ENABLED:
+        augs.append(
+            T.RandomCrop_CategoryAreaConstraint(
+                cfg.INPUT.CROP.TYPE,
+                cfg.INPUT.CROP.SIZE,
+                cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
+                cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
+            )
+        )
+    horizontal_flip: bool = cfg.INPUT.RANDOM_FLIP == "horizontal"
+    augs.append(T.RandomFlip(horizontal=horizontal_flip, vertical=not horizontal_flip))
+    # Rotate by an angle sampled from [-90, 0] degrees (i.e. up to 90 degrees clockwise) around the centre
+    augs.append(T.RandomRotation(angle=[-90.0, 0.0]))
+    return augs
+
+
+class Trainer(DefaultTrainer):
+    """
+    We use the "DefaultTrainer" which contains pre-defined default logic for
+    the standard training workflow. It may not work for you, especially if you
+    are working on a new research project. In that case you can use the cleaner
+    "SimpleTrainer", or write your own training loop. You can use
+    "tools/plain_train_net.py" as an example.
+
+    Adapted from:
+    https://github.com/facebookresearch/detectron2/blob/master/projects/DeepLab/train_net.py
+    """
+
+    @classmethod
+    def build_train_loader(cls, cfg):
+        mapper = DatasetMapper(cfg, is_train=True, augmentations=get_augs(cfg))
+        return build_detection_train_loader(cfg, mapper=mapper)
+
+    @classmethod
+    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
+        """
+        Returns:
+            DatasetEvaluator or None
+
+        It is not implemented by default.
+        """
+        return COCOEvaluator(dataset_name, cfg, True, output_folder)
+
+    @classmethod
+    def test_with_TTA(cls, cfg, model):
+        logger = logging.getLogger("detectron2.trainer")
+        # At the end of training, run an evaluation with TTA.
+        # Only supports some R-CNN models.
+        logger.info("Running inference with test-time augmentation ...")
+        model = GeneralizedRCNNWithTTA(cfg, model)
+        evaluators = [
+            cls.build_evaluator(
+                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
+            )
+            for name in cfg.DATASETS.TEST
+        ]
+        res = cls.test(cfg, model, evaluators)
+        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
+        return res
+
+    @classmethod
+    def eval_and_save(cls, cfg, model):
+        evaluators = [
+            cls.build_evaluator(
+                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference")
+            )
+            for name in cfg.DATASETS.TEST
+        ]
+        res = cls.test(cfg, model, evaluators)
+        pd.DataFrame(res).to_csv(os.path.join(cfg.OUTPUT_DIR, "eval.csv"))
+        return res
+
+
+def setup(args):
+    """
+    Create configs and perform basic setups.
+    """
+    cfg = get_cfg()
+
+    if args.config_file != "":
+        cfg.merge_from_file(args.config_file)
+    cfg.merge_from_list(args.opts)
+
+    with open(args.json_annotation_train, "r") as fp:
+        anno_file = json.load(fp)
+
+    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(anno_file["categories"])
+    del anno_file
+
+    cfg.DATASETS.TRAIN = (f"{args.dataset_name}-train",)
+    cfg.DATASETS.TEST = (f"{args.dataset_name}-val",)
+    cfg.freeze()
+    default_setup(cfg, args)
+    return cfg
+
+
+def main(args):
+    # Register the train/val datasets in COCO format
+    register_coco_instances(
+        f"{args.dataset_name}-train",
+        {},
+        args.json_annotation_train,
+        args.image_path_train,
+    )
+
+    register_coco_instances(
+        f"{args.dataset_name}-val",
+        {},
+        args.json_annotation_val,
+        args.image_path_val,
+    )
+    cfg = setup(args)
+
+    if args.eval_only:
+        model = Trainer.build_model(cfg)
+        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
+            cfg.MODEL.WEIGHTS, resume=args.resume
+        )
+        res = Trainer.test(cfg, model)
+
+        if cfg.TEST.AUG.ENABLED:
+            res.update(Trainer.test_with_TTA(cfg, model))
+        if comm.is_main_process():
+            verify_results(cfg, res)
+
+        # Save the evaluation results
+        pd.DataFrame(res).to_csv(f"{cfg.OUTPUT_DIR}/eval.csv")
+        return res
+
+    # Ensure that the output directory exists
+    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
+
+    """
+    If you'd like to do anything fancier than the standard training logic,
+    consider writing your own training loop (see plain_train_net.py) or
+    subclassing the trainer.
+    """
+    trainer = Trainer(cfg)
+    trainer.resume_or_load(resume=args.resume)
+    trainer.register_hooks(
+        [hooks.EvalHook(0, lambda: trainer.eval_and_save(cfg, trainer.model))]
+    )
+    if cfg.TEST.AUG.ENABLED:
+        trainer.register_hooks(
+            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
+        )
+    return trainer.train()
+
+
+if __name__ == "__main__":
+    parser = default_argument_parser()
+
+    # Extra configurations for dataset names and paths
+    parser.add_argument(
+        "--dataset_name",
+        help="The dataset name")
+    parser.add_argument(
+        "--json_annotation_train",
+        help="The path to the training set JSON annotation",
+    )
+    parser.add_argument(
+        "--image_path_train",
+        help="The path to the training set image folder",
+    )
+    parser.add_argument(
+        "--json_annotation_val",
+        help="The path to the validation set JSON annotation",
+    )
+    parser.add_argument(
+        "--image_path_val",
+        help="The path to the validation set image folder",
+    )
+    args = parser.parse_args()
+    print("Command Line Args:", args)
+
+    # Dataset registration is moved into the main function to support multi-GPU training;
+    # see https://github.com/facebookresearch/detectron2/issues/253#issuecomment-554216517
+
+    launch(
+        main,
+        args.num_gpus,
+        num_machines=args.num_machines,
+        machine_rank=args.machine_rank,
+        dist_url=args.dist_url,
+        args=(args,),
+    )
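
As the get_augs docstring notes, this function is the place to append further data augmentations from detectron2.data.transforms. A minimal sketch of how the list could be extended with photometric transforms is below; RandomBrightness and RandomContrast are real detectron2 augmentations, but the 0.9-1.1 ranges and the wrapper name are illustrative assumptions, not part of this commit.

# Sketch: extend the augmentation list consumed by Trainer.build_train_loader.
# The intensity ranges below are example values, not values used by this script.
import detectron2.data.transforms as T

def get_augs_with_photometric(cfg):
    augs = get_augs(cfg)  # the function defined in train_net.py above
    augs.extend(
        [
            T.RandomBrightness(0.9, 1.1),  # scale pixel intensities by a random factor
            T.RandomContrast(0.9, 1.1),    # adjust contrast by a random factor
        ]
    )
    return augs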
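
Given the custom arguments registered above plus detectron2's default_argument_parser flags (--config-file, --num-gpus, --eval-only, --resume, and trailing config overrides passed to cfg.merge_from_list), the script can be launched roughly as follows. The dataset paths, config file, and OUTPUT_DIR below are placeholders, not files shipped with this repository.

# Placeholder paths; adjust to your own COCO-format dataset and detectron2 config.
python train_net.py \
    --dataset_name mydataset \
    --json_annotation_train data/train.json \
    --image_path_train data/train_images \
    --json_annotation_val data/val.json \
    --image_path_val data/val_images \
    --config-file configs/my_config.yaml \
    --num-gpus 1 \
    OUTPUT_DIR outputs/mydataset

Adding --eval-only (with MODEL.WEIGHTS pointing at a trained checkpoint) takes the evaluation branch in main instead of training.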