Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # Lightning Trainer should be considered beta at this point | |
| # We have confirmed that training and validation run correctly and produce correct results | |
| # Depending on how you launch the trainer, there are issues with processes terminating correctly | |
| # This module is still dependent on D2 logging, but could be transferred to use Lightning logging | |
| import logging | |
| import os | |
| import time | |
| import weakref | |
| from collections import OrderedDict | |
| from typing import Any, Dict, List | |
| import detectron2.utils.comm as comm | |
| from detectron2.checkpoint import DetectionCheckpointer | |
| from detectron2.config import get_cfg | |
| from detectron2.data import build_detection_test_loader, build_detection_train_loader | |
| from detectron2.engine import ( | |
| DefaultTrainer, | |
| SimpleTrainer, | |
| default_argument_parser, | |
| default_setup, | |
| default_writers, | |
| hooks, | |
| ) | |
| from detectron2.evaluation import print_csv_format | |
| from detectron2.evaluation.testing import flatten_results_dict | |
| from detectron2.modeling import build_model | |
| from detectron2.solver import build_lr_scheduler, build_optimizer | |
| from detectron2.utils.events import EventStorage | |
| from detectron2.utils.logger import setup_logger | |
| import pytorch_lightning as pl # type: ignore | |
| from pytorch_lightning import LightningDataModule, LightningModule | |
| from train_net import build_evaluator | |
# Module-wide logger shared with detectron2; the root logger is configured at
# INFO level so that records emitted through it are not filtered out.
logger = logging.getLogger("detectron2")
logging.basicConfig(level=logging.INFO)
class TrainingModule(LightningModule):
    """LightningModule that trains and evaluates a detectron2 model.

    The module owns the detectron2 model, an ``EventStorage`` used for metric
    logging, the detectron2 writers, and one evaluator per test dataset.
    """

    def __init__(self, cfg):
        """Build the model from ``cfg`` (after worker auto-scaling)."""
        super().__init__()
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        self.cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
        # Created lazily in training_step; it is None until the first step runs.
        self.storage: EventStorage = None
        self.model = build_model(self.cfg)

        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Record the current iteration in the Lightning checkpoint."""
        # The storage only exists once training has started; fall back to the
        # starting iteration instead of dereferencing None.
        storage = self.storage
        checkpoint["iteration"] = storage.iter if storage is not None else self.start_iter

    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
        """Restore the iteration counter saved by ``on_save_checkpoint``."""
        self.start_iter = checkpointed_state["iteration"]
        # The storage is created in training_step (seeded with start_iter), so
        # it is normally still None here; only patch it if it already exists.
        if self.storage is not None:
            self.storage.iter = self.start_iter

    def setup(self, stage: str):
        """Create the checkpointer, optionally load weights, and start timers."""
        # Always build the checkpointer: training_epoch_end saves through it
        # even when no initial weights are configured.
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            self.model,
            self.cfg.OUTPUT_DIR,
        )
        if self.cfg.MODEL.WEIGHTS:
            logger.info(f"Load model weights from checkpoint: {self.cfg.MODEL.WEIGHTS}.")
            # Only load weights, use lightning checkpointing if you want to resume
            self.checkpointer.load(self.cfg.MODEL.WEIGHTS)

        self.iteration_timer = hooks.IterationTimer()
        self.iteration_timer.before_train()
        self.data_start = time.perf_counter()
        self.writers = None

    def training_step(self, batch, batch_idx):
        """Run one forward pass, log metrics and the LR, and return the total loss."""
        data_time = time.perf_counter() - self.data_start
        # Need to manually enter/exit since trainer may launch processes
        # This ideally belongs in setup, but setup seems to run before processes are spawned
        if self.storage is None:
            # Seed with start_iter so a resumed run continues its iteration count.
            self.storage = EventStorage(self.start_iter)
            self.storage.__enter__()
            self.iteration_timer.trainer = weakref.proxy(self)
            self.iteration_timer.before_step()
            self.writers = (
                default_writers(self.cfg.OUTPUT_DIR, self.max_iter)
                if comm.is_main_process()
                else {}
            )

        loss_dict = self.model(batch)
        SimpleTrainer.write_metrics(loss_dict, data_time)

        opt = self.optimizers()
        self.storage.put_scalar(
            "lr", opt.param_groups[self._best_param_group_id]["lr"], smoothing_hint=False
        )
        self.iteration_timer.after_step()
        self.storage.step()
        # A little odd to put before step here, but it's the best way to get a proper timing
        self.iteration_timer.before_step()

        if self.storage.iter % 20 == 0:
            for writer in self.writers:
                writer.write()
        return sum(loss_dict.values())

    def training_step_end(self, training_step_outpus):
        """Restart the data-loading timer and pass the step output through."""
        self.data_start = time.perf_counter()
        return training_step_outpus

    def training_epoch_end(self, training_step_outputs):
        """Stop the timer, save ``model_final``, flush/close writers, exit the storage."""
        self.iteration_timer.after_train()
        if comm.is_main_process():
            self.checkpointer.save("model_final")
        for writer in self.writers:
            writer.write()
            writer.close()
        self.storage.__exit__(None, None, None)

    def _process_dataset_evaluation_results(self) -> OrderedDict:
        """Evaluate every test dataset and return the collected results.

        Returns a mapping from dataset name to its evaluation result; when
        there is exactly one dataset, that dataset's result is returned
        directly instead of the mapping.
        """
        results = OrderedDict()
        for idx, dataset_name in enumerate(self.cfg.DATASETS.TEST):
            results[dataset_name] = self._evaluators[idx].evaluate()
            if comm.is_main_process():
                print_csv_format(results[dataset_name])

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    def _reset_dataset_evaluators(self):
        """Build and reset a fresh evaluator for each test dataset."""
        self._evaluators = []
        for dataset_name in self.cfg.DATASETS.TEST:
            evaluator = build_evaluator(self.cfg, dataset_name)
            evaluator.reset()
            self._evaluators.append(evaluator)

    def on_validation_epoch_start(self, _outputs=None):
        """Reset the evaluators before a validation epoch.

        ``_outputs`` is unused and optional: Lightning invokes this hook
        without arguments, and the default keeps older call sites working.
        """
        self._reset_dataset_evaluators()

    def validation_epoch_end(self, _outputs):
        """Aggregate evaluator results and log them as scalars.

        Raises:
            ValueError: if a flattened result value cannot be converted to float.
        """
        # The helper takes no arguments; the evaluators already hold all state.
        results = self._process_dataset_evaluation_results()

        flattened_results = flatten_results_dict(results)
        for k, v in flattened_results.items():
            try:
                float(v)
            except Exception as e:
                raise ValueError(
                    "[EvalHook] eval_function should return a nested dict of float. "
                    "Got '{}: {}' instead.".format(k, v)
                ) from e
        # In eval-only runs no training step has created the storage yet.
        if self.storage is not None:
            self.storage.put_scalars(**flattened_results, smoothing_hint=False)

    def validation_step(self, batch, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Run inference on ``batch`` and feed the outputs to its evaluator."""
        if not isinstance(batch, list):
            batch = [batch]
        outputs = self.model(batch)
        self._evaluators[dataloader_idx].process(batch, outputs)

    def configure_optimizers(self):
        """Build the optimizer and a per-step LR scheduler from the config."""
        optimizer = build_optimizer(self.cfg, self.model)
        # Remembered so training_step can log the LR of this parameter group.
        self._best_param_group_id = hooks.LRScheduler.get_best_param_group_id(optimizer)
        scheduler = build_lr_scheduler(self.cfg, optimizer)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
class DataModule(LightningDataModule):
    """LightningDataModule that serves detectron2 train and test loaders."""

    def __init__(self, cfg):
        """Store ``cfg`` after auto-scaling it to the current world size."""
        super().__init__()
        world_size = comm.get_world_size()
        self.cfg = DefaultTrainer.auto_scale_workers(cfg, world_size)

    def train_dataloader(self):
        """Return the detectron2 training loader for the stored config."""
        return build_detection_train_loader(self.cfg)

    def val_dataloader(self):
        """Return one test loader per dataset listed in ``cfg.DATASETS.TEST``."""
        return [
            build_detection_test_loader(self.cfg, name)
            for name in self.cfg.DATASETS.TEST
        ]
def main(args):
    """Create the config from the parsed arguments and hand it to ``train``."""
    train(setup(args), args)
def train(cfg, args):
    """Configure a Lightning ``Trainer`` from ``cfg``/``args`` and run it.

    Runs validation only when ``args.eval_only`` is set, otherwise fits the
    model. When ``args.resume`` is set, training resumes from ``last.ckpt``
    inside ``cfg.OUTPUT_DIR``.
    """
    eval_period = cfg.TEST.EVAL_PERIOD
    trainer_params = dict(
        # training loop is bounded by max steps, use a large max_epochs to make
        # sure max_steps is met first
        max_epochs=10**8,
        max_steps=cfg.SOLVER.MAX_ITER,
        val_check_interval=eval_period if eval_period > 0 else 10**8,
        num_nodes=args.num_machines,
        gpus=args.num_gpus,
        num_sanity_val_steps=0,
    )
    if cfg.SOLVER.AMP.ENABLED:
        trainer_params["precision"] = 16

    last_checkpoint = os.path.join(cfg.OUTPUT_DIR, "last.ckpt")
    if args.resume:
        # resume training from checkpoint
        trainer_params["resume_from_checkpoint"] = last_checkpoint
        logger.info(f"Resuming training from checkpoint: {last_checkpoint}.")

    trainer = pl.Trainer(**trainer_params)
    logger.info(f"start to train with {args.num_machines} nodes and {args.num_gpus} GPUs")

    module = TrainingModule(cfg)
    data_module = DataModule(cfg)
    if not args.eval_only:
        logger.info("Running training")
        trainer.fit(module, data_module)
        return
    logger.info("Running inference")
    trainer.validate(module, data_module)
def setup(args):
    """
    Build the config from the config file and command-line overrides, freeze
    it, run the default setup, and return it.
    """
    config = get_cfg()
    # Command-line overrides are merged after the file so they are applied last.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    default_setup(config, args)
    return config
if __name__ == "__main__":
    # Script entry point: parse the command line and start the run.
    parser = default_argument_parser()
    args = parser.parse_args()
    # The message needs a %s placeholder for its argument; without one the
    # logging module cannot format the record and reports a logging error.
    logger.info("Command Line Args: %s", args)
    main(args)