feat: add load checkpoint
Browse files
train.py
CHANGED
@@ -14,6 +14,7 @@ torch.set_float32_matmul_precision("high")
|
|
14 |
parser = argparse.ArgumentParser()
|
15 |
parser.add_argument("-d", "--devices", nargs="*", type=int, default=[0])
|
16 |
parser.add_argument("-b", "--single-batch-size", type=int, default=64)
|
|
|
17 |
|
18 |
args = parser.parse_args()
|
19 |
|
@@ -88,5 +89,5 @@ detector = FontDetector(
|
|
88 |
num_iters=num_iters,
|
89 |
)
|
90 |
|
91 |
-
trainer.fit(detector, datamodule=data_module)
|
92 |
trainer.test(detector, datamodule=data_module)
|
|
|
14 |
parser = argparse.ArgumentParser()
|
15 |
parser.add_argument("-d", "--devices", nargs="*", type=int, default=[0])
|
16 |
parser.add_argument("-b", "--single-batch-size", type=int, default=64)
|
17 |
+
parser.add_argument("-c", "--checkpoint", type=str, default=None)
|
18 |
|
19 |
args = parser.parse_args()
|
20 |
|
|
|
89 |
num_iters=num_iters,
|
90 |
)
|
91 |
|
92 |
+
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
|
93 |
trainer.test(detector, datamodule=data_module)
|