"""
Training script using the new "LazyConfig" python config files.

This script reads a given python config file and runs the training or evaluation.
It can be used to train any models or datasets as long as they can be
instantiated by the recursive construction defined in the given config file.

Besides lazy construction of models, dataloaders, etc., this script expects a
few common configuration parameters currently defined in "configs/common/train.py".
To add more complicated training logic, you can easily add other configs
in the config file and implement a new train_net.py to handle them.
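
A typical invocation, using the flags provided by ``default_argument_parser``
(the paths here are illustrative), might look like:

    ./lazyconfig_train_net.py --config-file path/to/config.py --num-gpus 8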
"""
import logging

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from detectron2.engine import (
    AMPTrainer,
    SimpleTrainer,
    default_argument_parser,
    default_setup,
    default_writers,
    hooks,
    launch,
)
from detectron2.engine.defaults import create_ddp_model
from detectron2.evaluation import inference_on_dataset, print_csv_format
from detectron2.utils import comm

logger = logging.getLogger("detectron2")


def do_test(cfg, model):
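    # "evaluator" is optional in the config; testing is skipped when it is absent.
    # In a typical config it is a lazily-constructed evaluator, e.g. (sketch):
    #     dataloader.evaluator = L(COCOEvaluator)(dataset_name=...)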
    if "evaluator" in cfg.dataloader:
        ret = inference_on_dataset(
            model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator)
        )
        print_csv_format(ret)
        return ret


def do_train(args, cfg):
    """
    Args:
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to an fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
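
    As a rough sketch, a config file providing these attributes could look like
    the following (``MyModel`` and all values are illustrative placeholders)::

        import torch
        from detectron2.config import LazyCall as L

        model = L(MyModel)(num_classes=80)
        optimizer = L(torch.optim.SGD)(params=..., lr=0.02, momentum=0.9)
        train = dict(
            output_dir="./output",
            init_checkpoint="",
            max_iter=90000,
            amp=dict(enabled=False),
            ddp=dict(broadcast_buffers=False),
            checkpointer=dict(period=5000),
            eval_period=5000,
            log_period=20,
            device="cuda",
        )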
    """
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    # the optimizer config needs a reference to the model to build its parameter groups
    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    train_loader = instantiate(cfg.dataloader.train)

    model = create_ddp_model(model, **cfg.train.ddp)
    # use the mixed-precision trainer when AMP is enabled in the config
    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim)
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        trainer=trainer,
    )
    trainer.register_hooks(
        [
            hooks.IterationTimer(),
            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
            # checkpointing and log writing only happen on the main process
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if comm.is_main_process()
            else None,
            hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
            hooks.PeriodicWriter(
                default_writers(cfg.train.output_dir, cfg.train.max_iter),
                period=cfg.train.log_period,
            )
            if comm.is_main_process()
            else None,
        ]
    )

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
    if args.resume and checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, so we
        # start at the next iteration
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)


def main(args):
    cfg = LazyConfig.load(args.config_file)
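    # extra "key=value" pairs from the command line override config values with
    # omegaconf-style dotted keys, e.g. (illustrative): train.max_iter=100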
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args)

    if args.eval_only:
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        model = create_ddp_model(model)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
        print(do_test(cfg, model))
    else:
        do_train(args, cfg)


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )