loliipopshock commited on
Commit
9aafcc3
·
1 Parent(s): 4cf1221

Add eval_and_save hook in trainer

Browse files
Files changed (1) hide show
  1. tools/train_net.py +14 -0
tools/train_net.py CHANGED
@@ -61,6 +61,17 @@ class Trainer(DefaultTrainer):
61
  res = OrderedDict({k + "_TTA": v for k, v in res.items()})
62
  return res
63
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  def setup(args):
66
  """
@@ -110,6 +121,9 @@ def main(args):
110
  """
111
  trainer = Trainer(cfg)
112
  trainer.resume_or_load(resume=args.resume)
 
 
 
113
  if cfg.TEST.AUG.ENABLED:
114
  trainer.register_hooks(
115
  [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
 
61
  res = OrderedDict({k + "_TTA": v for k, v in res.items()})
62
  return res
63
 
64
+ @classmethod
65
+ def eval_and_save(cls, cfg, model):
66
+ evaluators = [
67
+ cls.build_evaluator(
68
+ cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference")
69
+ )
70
+ for name in cfg.DATASETS.TEST
71
+ ]
72
+ res = cls.test(cfg, model, evaluators)
73
+ pd.DataFrame(res).to_csv(os.path.join(cfg.OUTPUT_DIR, 'eval.csv'))
74
+ return res
75
 
76
  def setup(args):
77
  """
 
121
  """
122
  trainer = Trainer(cfg)
123
  trainer.resume_or_load(resume=args.resume)
124
+ trainer.register_hooks(
125
+ [hooks.EvalHook(0, lambda: trainer.eval_and_save(cfg, trainer.model))]
126
+ )
127
  if cfg.TEST.AUG.ENABLED:
128
  trainer.register_hooks(
129
  [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]