gyrojeff commited on
Commit
8d9c0ef
1 Parent(s): fd9442f

feat: add load checkpoint

Browse files
Files changed (1) hide show
  1. train.py +2 -1
train.py CHANGED
@@ -14,6 +14,7 @@ torch.set_float32_matmul_precision("high")
14
  parser = argparse.ArgumentParser()
15
  parser.add_argument("-d", "--devices", nargs="*", type=int, default=[0])
16
  parser.add_argument("-b", "--single-batch-size", type=int, default=64)
 
17
 
18
  args = parser.parse_args()
19
 
@@ -88,5 +89,5 @@ detector = FontDetector(
88
  num_iters=num_iters,
89
  )
90
 
91
- trainer.fit(detector, datamodule=data_module)
92
  trainer.test(detector, datamodule=data_module)
 
14
  parser = argparse.ArgumentParser()
15
  parser.add_argument("-d", "--devices", nargs="*", type=int, default=[0])
16
  parser.add_argument("-b", "--single-batch-size", type=int, default=64)
17
+ parser.add_argument("-c", "--checkpoint", type=str, default=None)
18
 
19
  args = parser.parse_args()
20
 
 
89
  num_iters=num_iters,
90
  )
91
 
92
+ trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
93
  trainer.test(detector, datamodule=data_module)