# Copyright (c) OpenMMLab. All rights reserved. """Hyper-parameter Scheduler Visualization. This tool aims to help the user to check the hyper-parameter scheduler of the optimizer(without training), which support the "learning rate", "momentum", and "weight_decay". Example: ```shell python tools/analysis_tools/vis_scheduler.py \ configs/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py \ --dataset-size 118287 \ --ngpus 8 \ --out-dir ./output ``` Modified from: https://github.com/open-mmlab/mmclassification/blob/1.x/tools/visualizations/vis_scheduler.py # noqa """ import argparse import json import os.path as osp import re from pathlib import Path from unittest.mock import MagicMock import matplotlib.pyplot as plt import rich import torch.nn as nn from mmengine.config import Config, DictAction from mmengine.hooks import Hook from mmengine.model import BaseModel from mmengine.registry import init_default_scope from mmengine.runner import Runner from mmengine.utils.path import mkdir_or_exist from mmengine.visualization import Visualizer from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn def parse_args(): parser = argparse.ArgumentParser( description='Visualize a hyper-parameter scheduler') parser.add_argument('config', help='config file path') parser.add_argument( '-p', '--parameter', type=str, default='lr', choices=['lr', 'momentum', 'wd'], help='The parameter to visualize its change curve, choose from' '"lr", "wd" and "momentum". Defaults to "lr".') parser.add_argument( '-d', '--dataset-size', type=int, help='The size of the dataset. If specify, `DATASETS.build` will ' 'be skipped and use this size as the dataset size.') parser.add_argument( '-n', '--ngpus', type=int, default=1, help='The number of GPUs used in training.') parser.add_argument( '-o', '--out-dir', type=Path, help='Path to output file') parser.add_argument( '--log-level', default='WARNING', help='The log level of the handler and logger. Defaults to ' 'WARNING.') parser.add_argument('--title', type=str, help='title of figure') parser.add_argument( '--style', type=str, default='whitegrid', help='style of plt') parser.add_argument('--not-show', default=False, action='store_true') parser.add_argument( '--window-size', default='12*7', help='Size of the window to display images, in format of "$W*$H".') parser.add_argument( '--cfg-options', nargs='+', action=DictAction, help='override some settings in the used config, the key-value pair ' 'in xxx=yyy format will be merged into config file. If the value to ' 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 'Note that the quotation marks are necessary and that no white space ' 'is allowed.') args = parser.parse_args() if args.window_size != '': assert re.match(r'\d+\*\d+', args.window_size), \ "'window-size' must be in format 'W*H'." return args class SimpleModel(BaseModel): """simple model that do nothing in train_step.""" def __init__(self): super().__init__() self.data_preprocessor = nn.Identity() self.conv = nn.Conv2d(1, 1, 1) def forward(self, inputs, data_samples, mode='tensor'): pass def train_step(self, data, optim_wrapper): pass class ParamRecordHook(Hook): def __init__(self, by_epoch): super().__init__() self.by_epoch = by_epoch self.lr_list = [] self.momentum_list = [] self.wd_list = [] self.task_id = 0 self.progress = Progress(BarColumn(), MofNCompleteColumn(), TextColumn('{task.description}')) def before_train(self, runner): if self.by_epoch: total = runner.train_loop.max_epochs self.task_id = self.progress.add_task( 'epochs', start=True, total=total) else: total = runner.train_loop.max_iters self.task_id = self.progress.add_task( 'iters', start=True, total=total) self.progress.start() def after_train_epoch(self, runner): if self.by_epoch: self.progress.update(self.task_id, advance=1) # TODO: Support multiple schedulers def after_train_iter(self, runner, batch_idx, data_batch, outputs): if not self.by_epoch: self.progress.update(self.task_id, advance=1) self.lr_list.append(runner.optim_wrapper.get_lr()['lr'][0]) self.momentum_list.append( runner.optim_wrapper.get_momentum()['momentum'][0]) self.wd_list.append( runner.optim_wrapper.param_groups[0]['weight_decay']) def after_train(self, runner): self.progress.stop() def plot_curve(lr_list, args, param_name, iters_per_epoch, by_epoch=True): """Plot learning rate vs iter graph.""" try: import seaborn as sns sns.set_style(args.style) except ImportError: pass wind_w, wind_h = args.window_size.split('*') wind_w, wind_h = int(wind_w), int(wind_h) plt.figure(figsize=(wind_w, wind_h)) ax: plt.Axes = plt.subplot() ax.plot(lr_list, linewidth=1) if by_epoch: ax.xaxis.tick_top() ax.set_xlabel('Iters') ax.xaxis.set_label_position('top') sec_ax = ax.secondary_xaxis( 'bottom', functions=(lambda x: x / iters_per_epoch, lambda y: y * iters_per_epoch)) sec_ax.set_xlabel('Epochs') else: plt.xlabel('Iters') plt.ylabel(param_name) if args.title is None: plt.title(f'{osp.basename(args.config)} {param_name} curve') else: plt.title(args.title) def simulate_train(data_loader, cfg, by_epoch): model = SimpleModel() param_record_hook = ParamRecordHook(by_epoch=by_epoch) default_hooks = dict( param_scheduler=cfg.default_hooks['param_scheduler'], runtime_info=None, timer=None, logger=None, checkpoint=None, sampler_seed=None, param_record=param_record_hook) runner = Runner( model=model, work_dir=cfg.work_dir, train_dataloader=data_loader, train_cfg=cfg.train_cfg, log_level=cfg.log_level, optim_wrapper=cfg.optim_wrapper, param_scheduler=cfg.param_scheduler, default_scope=cfg.default_scope, default_hooks=default_hooks, visualizer=MagicMock(spec=Visualizer), custom_hooks=cfg.get('custom_hooks', None)) runner.train() param_dict = dict( lr=param_record_hook.lr_list, momentum=param_record_hook.momentum_list, wd=param_record_hook.wd_list) return param_dict def main(): args = parse_args() cfg = Config.fromfile(args.config) if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) if cfg.get('work_dir', None) is None: # use config filename as default work_dir if cfg.work_dir is None cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) cfg.log_level = args.log_level init_default_scope(cfg.get('default_scope', 'mmyolo')) # init logger print('Param_scheduler :') rich.print_json(json.dumps(cfg.param_scheduler)) # prepare data loader batch_size = cfg.train_dataloader.batch_size * args.ngpus if 'by_epoch' in cfg.train_cfg: by_epoch = cfg.train_cfg.get('by_epoch') elif 'type' in cfg.train_cfg: by_epoch = cfg.train_cfg.get('type') == 'EpochBasedTrainLoop' else: raise ValueError('please set `train_cfg`.') if args.dataset_size is None and by_epoch: from mmyolo.registry import DATASETS dataset_size = len(DATASETS.build(cfg.train_dataloader.dataset)) else: dataset_size = args.dataset_size or batch_size class FakeDataloader(list): dataset = MagicMock(metainfo=None) data_loader = FakeDataloader(range(dataset_size // batch_size)) dataset_info = ( f'\nDataset infos:' f'\n - Dataset size: {dataset_size}' f'\n - Batch size per GPU: {cfg.train_dataloader.batch_size}' f'\n - Number of GPUs: {args.ngpus}' f'\n - Total batch size: {batch_size}') if by_epoch: dataset_info += f'\n - Iterations per epoch: {len(data_loader)}' rich.print(dataset_info + '\n') # simulation training process param_dict = simulate_train(data_loader, cfg, by_epoch) param_list = param_dict[args.parameter] if args.parameter == 'lr': param_name = 'Learning Rate' elif args.parameter == 'momentum': param_name = 'Momentum' else: param_name = 'Weight Decay' plot_curve(param_list, args, param_name, len(data_loader), by_epoch) if args.out_dir: # make dir for output mkdir_or_exist(args.out_dir) # save the graph out_file = osp.join( args.out_dir, f'{osp.basename(args.config)}-{args.parameter}.jpg') plt.savefig(out_file) print(f'\nThe {param_name} graph is saved at {out_file}') if not args.not_show: plt.show() if __name__ == '__main__': main()