Shannon Shen commited on
Commit
29d3845
·
1 Parent(s): ad8dc97

black formatting

Browse files
Files changed (1) hide show
  1. tools/train_net.py +51 -32
tools/train_net.py CHANGED
@@ -14,7 +14,13 @@ from detectron2.data import DatasetMapper, build_detection_train_loader
14
 
15
  from detectron2.data.datasets import register_coco_instances
16
 
17
- from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
 
 
 
 
 
 
18
  from detectron2.evaluation import (
19
  COCOEvaluator,
20
  verify_results,
@@ -25,12 +31,14 @@ import pandas as pd
25
 
26
  def get_augs(cfg):
27
  """Add all the desired augmentations here. A list of availble augmentations
28
- can be found here:
29
  https://detectron2.readthedocs.io/en/latest/modules/data_transforms.html
30
  """
31
  augs = [
32
  T.ResizeShortestEdge(
33
- cfg.INPUT.MIN_SIZE_TRAIN, cfg.INPUT.MAX_SIZE_TRAIN, cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
 
 
34
  )
35
  ]
36
  if cfg.INPUT.CROP.ENABLED:
@@ -42,9 +50,8 @@ def get_augs(cfg):
42
  cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
43
  )
44
  )
45
- horizontal_flip: bool = (cfg.INPUT.RANDOM_FLIP == 'horizontal')
46
- augs.append(T.RandomFlip(horizontal=horizontal_flip,
47
- vertical=not horizontal_flip))
48
  # Rotate the image between -90 to 0 degrees clockwise around the centre
49
  augs.append(T.RandomRotation(angle=[-90.0, 0.0]))
50
  return augs
@@ -86,8 +93,7 @@ class Trainer(DefaultTrainer):
86
  model = GeneralizedRCNNWithTTA(cfg, model)
87
  evaluators = [
88
  cls.build_evaluator(
89
- cfg, name, output_folder=os.path.join(
90
- cfg.OUTPUT_DIR, "inference_TTA")
91
  )
92
  for name in cfg.DATASETS.TEST
93
  ]
@@ -99,13 +105,12 @@ class Trainer(DefaultTrainer):
99
  def eval_and_save(cls, cfg, model):
100
  evaluators = [
101
  cls.build_evaluator(
102
- cfg, name, output_folder=os.path.join(
103
- cfg.OUTPUT_DIR, "inference")
104
  )
105
  for name in cfg.DATASETS.TEST
106
  ]
107
  res = cls.test(cfg, model, evaluators)
108
- pd.DataFrame(res).to_csv(os.path.join(cfg.OUTPUT_DIR, 'eval.csv'))
109
  return res
110
 
111
 
@@ -114,12 +119,12 @@ def setup(args):
114
  Create configs and perform basic setups.
115
  """
116
  cfg = get_cfg()
117
-
118
  if args.config_file != "":
119
  cfg.merge_from_file(args.config_file)
120
  cfg.merge_from_list(args.opts)
121
 
122
- with open(args.json_annotation_train, 'r') as fp:
123
  anno_file = json.load(fp)
124
 
125
  cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(anno_file["categories"])
@@ -134,13 +139,19 @@ def setup(args):
134
 
135
  def main(args):
136
  # Register Datasets
137
- register_coco_instances(f"{args.dataset_name}-train", {},
138
- args.json_annotation_train,
139
- args.image_path_train)
 
 
 
140
 
141
- register_coco_instances(f"{args.dataset_name}-val", {},
142
- args.json_annotation_val,
143
- args.image_path_val)
 
 
 
144
  cfg = setup(args)
145
 
146
  if args.eval_only:
@@ -156,7 +167,7 @@ def main(args):
156
  verify_results(cfg, res)
157
 
158
  # Save the evaluation results
159
- pd.DataFrame(res).to_csv(f'{cfg.OUTPUT_DIR}/eval.csv')
160
  return res
161
 
162
  # Ensure that the Output directory exists
@@ -174,8 +185,7 @@ def main(args):
174
  )
175
  if cfg.TEST.AUG.ENABLED:
176
  trainer.register_hooks(
177
- [hooks.EvalHook(
178
- 0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
179
  )
180
  return trainer.train()
181
 
@@ -184,16 +194,25 @@ if __name__ == "__main__":
184
  parser = default_argument_parser()
185
 
186
  # Extra Configurations for dataset names and paths
187
- parser.add_argument("--dataset_name",
188
- default="", help="The Dataset Name")
189
- parser.add_argument("--json_annotation_train", default="", metavar="FILE",
190
- help="The path to the training set JSON annotation")
191
- parser.add_argument("--image_path_train", default="",
192
- metavar="FILE", help="The path to the training set image folder")
193
- parser.add_argument("--json_annotation_val", default="", metavar="FILE",
194
- help="The path to the validation set JSON annotation")
195
- parser.add_argument("--image_path_val", default="",
196
- metavar="FILE", help="The path to the validation set image folder")
 
 
 
 
 
 
 
 
 
197
  args = parser.parse_args()
198
  print("Command Line Args:", args)
199
 
 
14
 
15
  from detectron2.data.datasets import register_coco_instances
16
 
17
+ from detectron2.engine import (
18
+ DefaultTrainer,
19
+ default_argument_parser,
20
+ default_setup,
21
+ hooks,
22
+ launch,
23
+ )
24
  from detectron2.evaluation import (
25
  COCOEvaluator,
26
  verify_results,
 
31
 
32
  def get_augs(cfg):
33
  """Add all the desired augmentations here. A list of availble augmentations
34
+ can be found here:
35
  https://detectron2.readthedocs.io/en/latest/modules/data_transforms.html
36
  """
37
  augs = [
38
  T.ResizeShortestEdge(
39
+ cfg.INPUT.MIN_SIZE_TRAIN,
40
+ cfg.INPUT.MAX_SIZE_TRAIN,
41
+ cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
42
  )
43
  ]
44
  if cfg.INPUT.CROP.ENABLED:
 
50
  cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
51
  )
52
  )
53
+ horizontal_flip: bool = cfg.INPUT.RANDOM_FLIP == "horizontal"
54
+ augs.append(T.RandomFlip(horizontal=horizontal_flip, vertical=not horizontal_flip))
 
55
  # Rotate the image between -90 to 0 degrees clockwise around the centre
56
  augs.append(T.RandomRotation(angle=[-90.0, 0.0]))
57
  return augs
 
93
  model = GeneralizedRCNNWithTTA(cfg, model)
94
  evaluators = [
95
  cls.build_evaluator(
96
+ cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
 
97
  )
98
  for name in cfg.DATASETS.TEST
99
  ]
 
105
  def eval_and_save(cls, cfg, model):
106
  evaluators = [
107
  cls.build_evaluator(
108
+ cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference")
 
109
  )
110
  for name in cfg.DATASETS.TEST
111
  ]
112
  res = cls.test(cfg, model, evaluators)
113
+ pd.DataFrame(res).to_csv(os.path.join(cfg.OUTPUT_DIR, "eval.csv"))
114
  return res
115
 
116
 
 
119
  Create configs and perform basic setups.
120
  """
121
  cfg = get_cfg()
122
+
123
  if args.config_file != "":
124
  cfg.merge_from_file(args.config_file)
125
  cfg.merge_from_list(args.opts)
126
 
127
+ with open(args.json_annotation_train, "r") as fp:
128
  anno_file = json.load(fp)
129
 
130
  cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(anno_file["categories"])
 
139
 
140
  def main(args):
141
  # Register Datasets
142
+ register_coco_instances(
143
+ f"{args.dataset_name}-train",
144
+ {},
145
+ args.json_annotation_train,
146
+ args.image_path_train,
147
+ )
148
 
149
+ register_coco_instances(
150
+ f"{args.dataset_name}-val",
151
+ {},
152
+ args.json_annotation_val,
153
+ args.image_path_val
154
+ )
155
  cfg = setup(args)
156
 
157
  if args.eval_only:
 
167
  verify_results(cfg, res)
168
 
169
  # Save the evaluation results
170
+ pd.DataFrame(res).to_csv(f"{cfg.OUTPUT_DIR}/eval.csv")
171
  return res
172
 
173
  # Ensure that the Output directory exists
 
185
  )
186
  if cfg.TEST.AUG.ENABLED:
187
  trainer.register_hooks(
188
+ [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
 
189
  )
190
  return trainer.train()
191
 
 
194
  parser = default_argument_parser()
195
 
196
  # Extra Configurations for dataset names and paths
197
+ parser.add_argument(
198
+ "--dataset_name",
199
+ help="The Dataset Name")
200
+ parser.add_argument(
201
+ "--json_annotation_train",
202
+ help="The path to the training set JSON annotation",
203
+ )
204
+ parser.add_argument(
205
+ "--image_path_train",
206
+ help="The path to the training set image folder",
207
+ )
208
+ parser.add_argument(
209
+ "--json_annotation_val",
210
+ help="The path to the validation set JSON annotation",
211
+ )
212
+ parser.add_argument(
213
+ "--image_path_val",
214
+ help="The path to the validation set image folder",
215
+ )
216
  args = parser.parse_args()
217
  print("Command Line Args:", args)
218