| |
| |
| import concurrent.futures |
| import logging |
| import numpy as np |
| import time |
| import weakref |
| from typing import List, Mapping, Optional |
| import torch |
| from torch.nn.parallel import DataParallel, DistributedDataParallel |
|
|
| import detectron2.utils.comm as comm |
| from detectron2.utils.events import EventStorage, get_event_storage |
| from detectron2.utils.logger import _log_api_usage |
|
|
# Public API exported by ``from <this module> import *``.
__all__ = [
    "HookBase",
    "TrainerBase",
    "SimpleTrainer",
    "AMPTrainer",
]
|
|
|
|
class HookBase:
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.

    Each hook can implement 5 methods. The way they are called is demonstrated
    in the following snippet:
    ::
        hook.before_train()
        for iter in range(start_iter, max_iter):
            hook.before_step()
            trainer.run_step()   # SimpleTrainer/AMPTrainer call hook.after_backward()
                                 # inside run_step, right after the backward pass
            hook.after_step()
        iter += 1
        hook.after_train()

    Notes:
        1. In the hook method, users can access ``self.trainer`` to access more
           properties about the context (e.g., model, current iteration, or config
           if using :class:`DefaultTrainer`).

        2. A hook that does something in :meth:`before_step` can often be
           implemented equivalently in :meth:`after_step`.
           If the hook takes non-trivial time, it is strongly recommended to
           implement the hook in :meth:`after_step` instead of :meth:`before_step`.
           The convention is that :meth:`before_step` should only take negligible time.

           Following this convention will allow hooks that do care about the difference
           between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
           function properly.

        3. :meth:`after_backward` is only invoked by trainers whose ``run_step``
           calls ``self.after_backward()`` (both :class:`SimpleTrainer` and
           :class:`AMPTrainer` do).
    """

    # A weak reference (``weakref.proxy``) to the trainer object. Set by the trainer
    # when the hook is registered; ``None`` until then.
    trainer: "TrainerBase" = None

    def before_train(self):
        """
        Called before the first iteration.
        """
        pass

    def after_train(self):
        """
        Called after the last iteration (also when training stops with an exception).
        """
        pass

    def before_step(self):
        """
        Called before each iteration. Should take negligible time.
        """
        pass

    def after_backward(self):
        """
        Called after the backward pass of each iteration, before the optimizer step.
        """
        pass

    def after_step(self):
        """
        Called after each iteration.
        """
        pass

    def state_dict(self):
        """
        Hooks are stateless by default, but can be made checkpointable by
        implementing `state_dict` and `load_state_dict`.

        Note that this base class does not define `load_state_dict`: a hook that
        overrides `state_dict` to return a non-empty dict must also implement
        `load_state_dict`, because :meth:`TrainerBase.load_state_dict` calls it
        on the matching hook when restoring a checkpoint.
        """
        return {}
|
|
|
|
class TrainerBase:
    """
    Base class for iterative trainer with hooks.

    The only assumption we made here is: the training runs in a loop.
    A subclass can implement what the loop is (by overriding :meth:`run_step`).
    We made no assumptions about the existence of dataloader, optimizer, model, etc.

    Attributes:
        iter(int): the current iteration.

        start_iter(int): The iteration to start with.
            By convention the minimum possible value is 0.

        max_iter(int): The iteration to end training (exclusive upper bound of
            the loop in :meth:`train`).

        storage(EventStorage): An EventStorage that's opened during the course of training.
    """

    def __init__(self) -> None:
        self._hooks: List[HookBase] = []
        self.iter: int = 0
        self.start_iter: int = 0
        self.max_iter: int
        self.storage: EventStorage
        _log_api_usage("trainer." + self.__class__.__name__)

    def register_hooks(self, hooks: List[Optional["HookBase"]]) -> None:
        """
        Register hooks to the trainer. The hooks are executed in the order
        they are registered.

        Args:
            hooks (list[Optional[HookBase]]): list of hooks; ``None`` entries are skipped.
        """
        hooks = [h for h in hooks if h is not None]
        for h in hooks:
            assert isinstance(h, HookBase)
            # The trainer keeps its hooks in self._hooks; giving each hook only a
            # weak proxy back to the trainer avoids a trainer <-> hook reference cycle.
            h.trainer = weakref.proxy(self)
        self._hooks.extend(hooks)

    def train(self, start_iter: int, max_iter: int):
        """
        Run the training loop for iterations ``start_iter, ..., max_iter - 1``.

        Args:
            start_iter, max_iter (int): See docs above

        Raises:
            Any exception raised by a hook or :meth:`run_step` is logged and
            re-raised; :meth:`after_train` runs in either case.
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting training from iteration {}".format(start_iter))

        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
                # On normal completion, advance self.iter to max_iter (one past the
                # last executed step); an exception skips this, so after_train sees
                # the iteration that failed. If the range was empty (start_iter >=
                # max_iter) no step ran and self.iter must stay at start_iter rather
                # than overshoot it.
                if start_iter < max_iter:
                    self.iter += 1
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                self.after_train()

    def before_train(self):
        for h in self._hooks:
            h.before_train()

    def after_train(self):
        # Keep the storage's iteration counter in sync with the final self.iter.
        self.storage.iter = self.iter
        for h in self._hooks:
            h.after_train()

    def before_step(self):
        # Sync the storage's iteration counter before hooks run; presumably
        # scalars logged without an explicit cur_iter use storage.iter.
        self.storage.iter = self.iter

        for h in self._hooks:
            h.before_step()

    def after_backward(self):
        for h in self._hooks:
            h.after_backward()

    def after_step(self):
        for h in self._hooks:
            h.after_step()

    def run_step(self):
        raise NotImplementedError

    def state_dict(self):
        """
        Returns:
            dict: ``{"iteration": self.iter}`` plus, if any hook has a non-empty
            state, ``"hooks"``: a mapping from hook class ``__qualname__`` to that
            hook's state dict.
        """
        ret = {"iteration": self.iter}
        hooks_state = {}
        for h in self._hooks:
            sd = h.state_dict()
            if not sd:
                continue
            name = type(h).__qualname__
            if name in hooks_state:
                # Hook states are keyed by class name, so only the first stateful
                # hook of each class can be checkpointed; say so instead of
                # dropping the state silently.
                logging.getLogger(__name__).warning(
                    f"Multiple stateful hooks of type '{name}' are registered; "
                    "only the state of the first one is saved."
                )
                continue
            hooks_state[name] = sd
        if hooks_state:
            ret["hooks"] = hooks_state
        return ret

    def load_state_dict(self, state_dict):
        """
        Restore ``self.iter`` and the state of registered hooks from a dict
        produced by :meth:`state_dict`. Each saved hook state goes to the first
        registered hook whose class ``__qualname__`` matches; states without a
        matching hook are ignored with a warning.
        """
        logger = logging.getLogger(__name__)
        self.iter = state_dict["iteration"]
        for key, value in state_dict.get("hooks", {}).items():
            for h in self._hooks:
                if type(h).__qualname__ == key:
                    h.load_state_dict(value)
                    break
            else:
                logger.warning(f"Cannot find the hook '{key}', its state_dict is ignored.")
|
|
|
|
class SimpleTrainer(TrainerBase):
    """
    Trainer for the most common kind of task: one loss, one optimizer and one
    data source, optimized iteratively (optionally with data parallelism).
    Every step it:

    1. Pulls one item from the data_loader and computes the losses with it.
    2. Back-propagates the summed losses to obtain gradients.
    3. Updates the model through the optimizer.

    Everything else that happens during training (checkpointing, logging,
    evaluation, LR schedule) is delegated to hooks, registered through
    :meth:`TrainerBase.register_hooks`.

    For anything more elaborate, subclass TrainerBase and implement your own
    `run_step`, or write your own training loop.
    """

    def __init__(
        self,
        model,
        data_loader,
        optimizer,
        gather_metric_period=1,
        zero_grad_before_forward=False,
        async_write_metrics=False,
    ):
        """
        Args:
            model: a torch Module. Called with one item from data_loader; returns
                either a dict of losses or a single loss tensor.
            data_loader: an iterable. Its items are passed to the model.
            optimizer: a torch optimizer.
            gather_metric_period: an int. Metrics are gathered from all ranks to
                rank 0 and logged on iterations where (iter + 1) is a multiple of it.
            zero_grad_before_forward: if True, zero the gradients before the forward
                pass instead of right before the backward pass.
            async_write_metrics: bool. If True, metrics are written on a background
                thread instead of inline in `run_step`.
        """
        super().__init__()

        # The trainer puts the model in training mode. Training a model in eval mode
        # is still possible: override the model's (or a submodule's) train() method.
        model.train()

        self.model = model
        self.data_loader = data_loader
        # Iterator over data_loader, created on first use by `_data_loader_iter`.
        self._data_loader_iter_obj = None
        self.optimizer = optimizer
        self.gather_metric_period = gather_metric_period
        self.zero_grad_before_forward = zero_grad_before_forward
        self.async_write_metrics = async_write_metrics
        # One worker: queued metric writes run one at a time, in submission order.
        self.concurrent_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def run_step(self):
        """
        Execute one iteration: load data, forward, backward, record metrics,
        then step the optimizer.
        """
        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"

        # To customize the data, wrap the data loader.
        t0 = time.perf_counter()
        batch = next(self._data_loader_iter)
        data_time = time.perf_counter() - t0

        # To accumulate gradients (or similar), wrap the optimizer with a custom
        # `zero_grad()`.
        if self.zero_grad_before_forward:
            self.optimizer.zero_grad()

        # To post-process the losses, wrap the model.
        loss_dict = self.model(batch)
        if isinstance(loss_dict, torch.Tensor):
            losses, loss_dict = loss_dict, {"total_loss": loss_dict}
        else:
            losses = sum(loss_dict.values())

        if not self.zero_grad_before_forward:
            self.optimizer.zero_grad()
        losses.backward()

        self.after_backward()

        if not self.async_write_metrics:
            self._write_metrics(loss_dict, data_time)
        else:
            # Pin the iteration now: self.iter may have advanced by the time the
            # background job runs.
            self.concurrent_executor.submit(
                self._write_metrics, loss_dict, data_time, iter=self.iter
            )

        # Gradient clipping/scaling can be done by wrapping the optimizer's `step()`,
        # though that is suboptimal (https://arxiv.org/abs/2006.15704 Sec 3.2.4).
        self.optimizer.step()

    @property
    def _data_loader_iter(self):
        # Built from self.data_loader on first access, then reused across steps.
        it = self._data_loader_iter_obj
        if it is None:
            it = self._data_loader_iter_obj = iter(self.data_loader)
        return it

    def reset_data_loader(self, data_loader_builder):
        """
        Replace the current data loader with ``data_loader_builder()`` (called with
        no arguments). The next step starts iterating the new loader from scratch.
        """
        # Drop this trainer's reference to the old loader before the new one is built.
        del self.data_loader
        self.data_loader = data_loader_builder()
        self._data_loader_iter_obj = None

    def _write_metrics(
        self,
        loss_dict: Mapping[str, torch.Tensor],
        data_time: float,
        prefix: str = "",
        iter: Optional[int] = None,
    ) -> None:
        # Writes metrics only every `gather_metric_period` iterations; `iter`
        # defaults to self.iter. Failures are logged, then re-raised.
        logger = logging.getLogger(__name__)

        cur_iter = self.iter if iter is None else iter
        if (cur_iter + 1) % self.gather_metric_period != 0:
            return
        try:
            SimpleTrainer.write_metrics(loss_dict, data_time, cur_iter, prefix)
        except Exception:
            logger.exception("Exception in writing metrics: ")
            raise

    @staticmethod
    def write_metrics(
        loss_dict: Mapping[str, torch.Tensor],
        data_time: float,
        cur_iter: int,
        prefix: str = "",
    ) -> None:
        """
        Args:
            loss_dict (dict): dict of scalar losses
            data_time (float): time taken by the dataloader iteration
            cur_iter (int): iteration the scalars are recorded under
            prefix (str): prefix for the total-loss logging key

        Raises:
            FloatingPointError: on the main process, if the summed mean losses
                are not finite.
        """
        local = {name: loss.detach().cpu().item() for name, loss in loss_dict.items()}
        local["data_time"] = data_time

        storage = get_event_storage()
        # This rank's own data time, recorded before any cross-rank reduction.
        storage.put_scalar("rank_data_time", data_time, cur_iter=cur_iter)

        # Every rank calls gather; only the main process consumes the result.
        gathered = comm.gather(local)
        if not comm.is_main_process():
            return

        # Log the largest data time across ranks; pop it so only losses remain.
        max_data_time = np.max([m.pop("data_time") for m in gathered])
        storage.put_scalar("data_time", max_data_time, cur_iter=cur_iter)

        # Average each loss over ranks.
        mean_losses = {key: np.mean([m[key] for m in gathered]) for key in gathered[0].keys()}
        total = sum(mean_losses.values())
        if not np.isfinite(total):
            raise FloatingPointError(
                f"Loss became infinite or NaN at iteration={cur_iter}!\n"
                f"loss_dict = {mean_losses}"
            )

        storage.put_scalar("{}total_loss".format(prefix), total, cur_iter=cur_iter)
        if len(mean_losses) > 1:
            storage.put_scalars(cur_iter=cur_iter, **mean_losses)

    def state_dict(self):
        state = super().state_dict()
        state["optimizer"] = self.optimizer.state_dict()
        return state

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.optimizer.load_state_dict(state_dict["optimizer"])

    def after_train(self):
        super().after_train()
        # Block until any queued asynchronous metric writes have finished.
        self.concurrent_executor.shutdown(wait=True)
|
|
|
|
class AMPTrainer(SimpleTrainer):
    """
    Like :class:`SimpleTrainer`, but uses PyTorch's native automatic mixed precision
    in the training loop.
    """

    def __init__(
        self,
        model,
        data_loader,
        optimizer,
        gather_metric_period=1,
        zero_grad_before_forward=False,
        grad_scaler=None,
        precision: torch.dtype = torch.float16,
        log_grad_scaler: bool = False,
        async_write_metrics=False,
    ):
        """
        Args:
            model, data_loader, optimizer, gather_metric_period, zero_grad_before_forward,
                async_write_metrics: same as in :class:`SimpleTrainer`.
            grad_scaler: torch GradScaler to automatically scale gradients. If None,
                a default ``torch.cuda.amp.GradScaler()`` is created.
            precision: torch.dtype as the target precision to cast to in computations
            log_grad_scaler: bool. If True, record the grad scaler's current scale
                as the ``"[metric]grad_scaler"`` scalar on every step.
        """
        unsupported = "AMPTrainer does not support single-process multi-device training!"
        if isinstance(model, DistributedDataParallel):
            assert not (model.device_ids and len(model.device_ids) > 1), unsupported
        assert not isinstance(model, DataParallel), unsupported

        super().__init__(
            model,
            data_loader,
            optimizer,
            gather_metric_period=gather_metric_period,
            zero_grad_before_forward=zero_grad_before_forward,
            # Must be forwarded explicitly: otherwise SimpleTrainer falls back to
            # its default (False) and the caller's setting is silently ignored.
            async_write_metrics=async_write_metrics,
        )

        if grad_scaler is None:
            from torch.cuda.amp import GradScaler

            grad_scaler = GradScaler()
        self.grad_scaler = grad_scaler
        self.precision = precision
        self.log_grad_scaler = log_grad_scaler

    def run_step(self):
        """
        Implement the AMP training logic: forward under autocast, backward on the
        scaled loss, record metrics, then step the optimizer via the grad scaler.
        """
        assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        start = time.perf_counter()
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        if self.zero_grad_before_forward:
            self.optimizer.zero_grad()
        # Only the forward pass and loss reduction run under autocast.
        with autocast(dtype=self.precision):
            loss_dict = self.model(data)
            if isinstance(loss_dict, torch.Tensor):
                losses = loss_dict
                loss_dict = {"total_loss": loss_dict}
            else:
                losses = sum(loss_dict.values())

        if not self.zero_grad_before_forward:
            self.optimizer.zero_grad()

        # Backpropagate the scaled loss; the grad scaler unscales before stepping.
        self.grad_scaler.scale(losses).backward()

        if self.log_grad_scaler:
            storage = get_event_storage()
            storage.put_scalar("[metric]grad_scaler", self.grad_scaler.get_scale())

        self.after_backward()

        if self.async_write_metrics:
            # Pin the iteration now: self.iter may have advanced by the time the
            # background job runs.
            self.concurrent_executor.submit(
                self._write_metrics, loss_dict, data_time, iter=self.iter
            )
        else:
            self._write_metrics(loss_dict, data_time)

        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()

    def state_dict(self):
        ret = super().state_dict()
        ret["grad_scaler"] = self.grad_scaler.state_dict()
        return ret

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.grad_scaler.load_state_dict(state_dict["grad_scaler"])
|
|