# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Callable, List, Optional, Union

import torch

from ..dist_utils import master_only
from .hook import HOOKS, Hook


@HOOKS.register_module()
class ProfilerHook(Hook):
    """Profiler to analyze performance during training.

    PyTorch Profiler is a tool that allows the collection of performance
    metrics during training. More details on Profiler can be found at
    https://pytorch.org/docs/1.8.1/profiler.html#torch.profiler.profile

    Args:
        by_epoch (bool): Profile performance by epoch or by iteration.
            Default: True.
        profile_iters (int): Number of iterations for profiling.
            If ``by_epoch=True``, ``profile_iters`` is the number of epochs
            profiled at the beginning of the training; otherwise it is the
            number of iterations profiled. Default: 1.
        activities (list[str]): List of activity groups (CPU, CUDA) to use in
            profiling. Default: ['cpu', 'cuda'].
        schedule (dict, optional): Config for generating the callable
            schedule. If ``schedule`` is None, the profiler will not add step
            markers into the trace and table view. Default: None.
        on_trace_ready (callable, dict, optional): Either a handler or a dict
            used to generate a handler. Default: None.
        record_shapes (bool): Save information about operators' input shapes.
            Default: False.
        profile_memory (bool): Track tensor memory allocation/deallocation.
            Default: False.
        with_stack (bool): Record source information (file and line number)
            for the ops. Default: False.
        with_flops (bool): Use a formula to estimate the FLOPs of specific
            operators (matrix multiplication and 2D convolution).
            Default: False.
        json_trace_path (str, optional): Path to which the collected trace is
            exported in Chrome JSON format. Default: None.

    Example:
        >>> runner = ...  # instantiate a Runner
        >>> # tensorboard trace
        >>> trace_config = dict(type='tb_trace', dir_name='work_dir')
        >>> profiler_config = dict(on_trace_ready=trace_config)
        >>> runner.register_profiler_hook(profiler_config)
        >>> runner.run(data_loaders=[trainloader], workflow=[('train', 1)])
    """

    def __init__(self,
                 by_epoch: bool = True,
                 profile_iters: int = 1,
                 activities: List[str] = ['cpu', 'cuda'],
                 schedule: Optional[dict] = None,
                 on_trace_ready: Optional[Union[Callable, dict]] = None,
                 record_shapes: bool = False,
                 profile_memory: bool = False,
                 with_stack: bool = False,
                 with_flops: bool = False,
                 json_trace_path: Optional[str] = None) -> None:
        try:
            from torch import profiler  # torch version >= 1.8.1
        except ImportError:
            raise ImportError('profiler is a new feature of torch 1.8.1, '
                              f'but your version is {torch.__version__}')

        assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.'
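        # ``by_epoch`` decides whether ``profile_iters`` counts epochs or
        # training iterations; both are validated against the runner's
        # ``max_epochs``/``max_iters`` in ``before_run``.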
        self.by_epoch = by_epoch

        if profile_iters < 1:
            raise ValueError('profile_iters should be greater than 0, but '
                             f'got {profile_iters}')
        self.profile_iters = profile_iters

        if not isinstance(activities, list):
            raise ValueError(
                f'activities should be list, but got {type(activities)}')
        self.activities = []
        for activity in activities:
            activity = activity.lower()
            if activity == 'cpu':
                self.activities.append(profiler.ProfilerActivity.CPU)
            elif activity == 'cuda':
                self.activities.append(profiler.ProfilerActivity.CUDA)
            else:
                raise ValueError(
                    f'activity should be "cpu" or "cuda", but got {activity}')

        if schedule is not None:
            self.schedule = profiler.schedule(**schedule)
        else:
            self.schedule = None

        self.on_trace_ready = on_trace_ready
        self.record_shapes = record_shapes
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_flops = with_flops
        self.json_trace_path = json_trace_path

    @master_only
    def before_run(self, runner):
        if self.by_epoch and runner.max_epochs < self.profile_iters:
            raise ValueError('self.profile_iters should not be greater than '
                             f'{runner.max_epochs}')

        if not self.by_epoch and runner.max_iters < self.profile_iters:
            raise ValueError('self.profile_iters should not be greater than '
                             f'{runner.max_iters}')

        if callable(self.on_trace_ready):  # handler
            _on_trace_ready = self.on_trace_ready
        elif isinstance(self.on_trace_ready, dict):  # config of handler
            trace_cfg = self.on_trace_ready.copy()
            trace_type = trace_cfg.pop('type')  # log_trace handler
            if trace_type == 'log_trace':

                def _log_handler(prof):
                    print(prof.key_averages().table(**trace_cfg))

                _on_trace_ready = _log_handler
            elif trace_type == 'tb_trace':  # tensorboard_trace handler
                try:
                    import torch_tb_profiler  # noqa: F401
                except ImportError:
                    raise ImportError('please run "pip install '
                                      'torch-tb-profiler" to install '
                                      'torch_tb_profiler')
                _on_trace_ready = torch.profiler.tensorboard_trace_handler(
                    **trace_cfg)
            else:
                raise ValueError('trace_type should be "log_trace" or '
                                 f'"tb_trace", but got {trace_type}')
        elif self.on_trace_ready is None:
            _on_trace_ready = None  # type: ignore
        else:
            raise ValueError('on_trace_ready should be handler, dict or '
                             f'None, but got {type(self.on_trace_ready)}')

        if runner.max_epochs > 1:
            warnings.warn(f'profiler will profile {runner.max_epochs} epochs '
                          'instead of 1 epoch. Since profiler will slow down '
                          'the training, it is recommended to train 1 epoch '
                          'with ProfilerHook and adjust your setting according'
                          ' to the profiler summary. During normal training '
                          '(epoch > 1), you may disable the ProfilerHook.')

        self.profiler = torch.profiler.profile(
            activities=self.activities,
            schedule=self.schedule,
            on_trace_ready=_on_trace_ready,
            record_shapes=self.record_shapes,
            profile_memory=self.profile_memory,
            with_stack=self.with_stack,
            with_flops=self.with_flops)

        self.profiler.__enter__()
        runner.logger.info('profiler is profiling...')

    @master_only
    def after_train_epoch(self, runner):
        if self.by_epoch and runner.epoch == self.profile_iters - 1:
            runner.logger.info('profiler may take a few minutes...')
            self.profiler.__exit__(None, None, None)
            if self.json_trace_path is not None:
                self.profiler.export_chrome_trace(self.json_trace_path)

    @master_only
    def after_train_iter(self, runner):
        self.profiler.step()
        if not self.by_epoch and runner.iter == self.profile_iters - 1:
            runner.logger.info('profiler may take a few minutes...')
            self.profiler.__exit__(None, None, None)
            if self.json_trace_path is not None:
                self.profiler.export_chrome_trace(self.json_trace_path)
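# The commented block below is an illustrative usage sketch, not part of the
# hook implementation. It assumes an already constructed mmcv ``Runner``
# instance named ``runner`` and a ``train_loader`` DataLoader; the directory
# name, schedule numbers and table options are placeholder values. The
# ``sort_by``/``row_limit`` keys are forwarded to
# ``prof.key_averages().table()`` and the ``schedule`` dict to
# ``torch.profiler.schedule()``, as done in ``before_run`` above.
#
#     # Profile the first 10 iterations, print a summary table sorted by
#     # self CUDA time, and also dump a Chrome trace for chrome://tracing.
#     profiler_config = dict(
#         by_epoch=False,
#         profile_iters=10,
#         activities=['cpu', 'cuda'],
#         schedule=dict(wait=1, warmup=1, active=2, repeat=1),
#         on_trace_ready=dict(
#             type='log_trace',
#             sort_by='self_cuda_time_total',
#             row_limit=10),
#         profile_memory=True,
#         json_trace_path='work_dir/chrome_trace.json')
#     runner.register_profiler_hook(profiler_config)
#     runner.run(data_loaders=[train_loader], workflow=[('train', 1)])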