Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import copy | |
import datetime | |
import re | |
from collections import OrderedDict | |
from itertools import chain | |
from typing import List, Optional, Tuple | |
import numpy as np | |
import torch | |
from mmengine.device import get_max_cuda_memory, is_cuda_available | |
from mmengine.registry import LOG_PROCESSORS | |
class LogProcessor:
    """A log processor used to format log information collected from
    ``runner.message_hub.log_scalars``.

    ``LogProcessor`` instance is built by runner and will format
    ``runner.message_hub.log_scalars`` to ``tag`` and ``log_str``, which can
    be directly used by ``LoggerHook`` and ``MMLogger``. Besides, the argument
    ``custom_cfg`` of constructor can control the statistics method of logs.

    Args:
        window_size (int): Default smoothing interval, applied to every
            scalar whose key matches ``mean_pattern``. Defaults to 10.
        by_epoch (bool): Whether to format logs with epoch style. Defaults to
            True.
        custom_cfg (list[dict], optional): Contains multiple log config
            dicts. ``data_src`` names the logged scalar to read, the optional
            ``log_name`` names the produced item, and every remaining key
            (e.g. ``method_name``, ``window_size``) is forwarded as a keyword
            argument to the history buffer's ``statistics`` method.
            Defaults to None.

            - If custom_cfg is None, all logs will be formatted via default
              methods, such as smoothing loss by default window_size. If
              custom_cfg is defined as a list of config dict, for example:
              [dict(data_src='loss', method_name='mean',
              log_name='global_loss', window_size='global')]. It means the
              log item ``loss`` will be counted as global mean and
              additionally logged as ``global_loss`` (defined by
              ``log_name``). If ``log_name`` is not defined in config dict,
              the original logged key will be overwritten.
            - The original log item cannot be overwritten twice. Here is
              an error example:
              [dict(data_src='loss', method_name='mean',
              window_size='global'),
              dict(data_src='loss', method_name='mean',
              window_size='epoch')].
              Neither config dict defines ``log_name``, so the ``loss`` item
              would be overwritten twice.
            - ``window_size`` may be an int, ``'epoch'`` (the iterations of
              the current loop so far) or ``'global'`` (``runner.iter + 1``
              iterations). If ``by_epoch`` is False, ``window_size`` must not
              be ``'epoch'``.
        num_digits (int): The number of significant digit shown in the
            logging message. Defaults to 4.
        log_with_hierarchy (bool): Whether to log with hierarchy. If it is
            True, the information is written to visualizer backend such as
            :obj:`LocalVisBackend` and :obj:`TensorboardBackend`
            with hierarchy. For example, ``loss`` will be saved as
            ``train/loss``, and accuracy will be saved as ``val/accuracy``.
            Defaults to False.
            `New in version 0.7.0.`
        mean_pattern (str): A regular expression. Scalars whose key matches
            it are smoothed as a mean over ``window_size``; all other scalars
            report their latest value. Defaults to
            ``r'.*(loss|time|data_time|grad_norm).*'``.
            `New in version 0.7.3.`

    Examples:
        >>> # `log_name` is defined, `loss_large_window` will be an additional
        >>> # record.
        >>> log_processor = dict(
        >>>     window_size=10,
        >>>     by_epoch=True,
        >>>     custom_cfg=[dict(data_src='loss',
        >>>                       log_name='loss_large_window',
        >>>                       method_name='mean',
        >>>                       window_size=100)])
        >>> # `log_name` is not defined. `loss` will be overwritten.
        >>> log_processor = dict(
        >>>     window_size=10,
        >>>     by_epoch=True,
        >>>     custom_cfg=[dict(data_src='loss',
        >>>                       method_name='mean',
        >>>                       window_size=100)])
        >>> # Record loss with different statistics methods.
        >>> log_processor = dict(
        >>>     window_size=10,
        >>>     by_epoch=True,
        >>>     custom_cfg=[dict(data_src='loss',
        >>>                       log_name='loss_large_window',
        >>>                       method_name='mean',
        >>>                       window_size=100),
        >>>                  dict(data_src='loss',
        >>>                       method_name='mean',
        >>>                       window_size=100)])
        >>> # Overwrite loss item twice will raise an error.
        >>> log_processor = dict(
        >>>     window_size=10,
        >>>     by_epoch=True,
        >>>     custom_cfg=[dict(data_src='loss',
        >>>                       method_name='mean',
        >>>                       window_size=100),
        >>>                  dict(data_src='loss',
        >>>                       method_name='max',
        >>>                       window_size=100)])
        AssertionError
    """

    def __init__(self,
                 window_size: int = 10,
                 by_epoch: bool = True,
                 custom_cfg: Optional[List[dict]] = None,
                 num_digits: int = 4,
                 log_with_hierarchy: bool = False,
                 mean_pattern: str = r'.*(loss|time|data_time|grad_norm).*'):
        self.window_size = window_size
        self.by_epoch = by_epoch
        self.custom_cfg = custom_cfg if custom_cfg else []
        self.num_digits = num_digits
        self.log_with_hierarchy = log_with_hierarchy
        # Compiled once; matched against every scalar key on each log call.
        self.mean_pattern = re.compile(mean_pattern)
        self._check_custom_cfg()

    def get_log_after_iter(self, runner, batch_idx: int,
                           mode: str) -> Tuple[dict, str]:
        """Format log string after training, validation or testing iteration.

        Args:
            runner (Runner): The runner of training phase.
            batch_idx (int): The index of the current batch in the current
                loop.
            mode (str): Current mode of runner, train, test or val.

        Return:
            Tuple[dict, str]: Formatted log dict/string which will be
            recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
        """
        assert mode in ['train', 'test', 'val']
        # Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
        parsed_cfg = self._parse_windows_size(runner, batch_idx,
                                              self.custom_cfg)
        # log_tag is used to write log information to terminal
        log_tag = self._collect_scalars(parsed_cfg, runner, mode)

        # If `self.log_with_hierarchy` is False, the tag is the same as
        # log_tag. Otherwise, each key in tag starts with prefix `train`,
        # `test` or `val`
        if not self.log_with_hierarchy:
            tag = copy.deepcopy(log_tag)
        else:
            tag = self._collect_scalars(parsed_cfg, runner, mode, True)

        # Record learning rate. `tag` and `log_tag` are distinct dicts, so
        # popping from `log_tag` while iterating `tag` is safe.
        lr_str_list = []
        for key, value in tag.items():
            if key.endswith('lr'):
                key = self._remove_prefix(key, f'{mode}/')
                log_tag.pop(key)
                lr_str_list.append(f'{key}: '
                                   f'{value:.{self.num_digits}e}')
        lr_str = ' '.join(lr_str_list)

        # Format log header.
        # by_epoch == True
        #   train/val: Epoch(train) [5][5/10] ...
        #   test:      Epoch(test) [5/10] ...
        # by_epoch == False
        #   train:     Iter(train) [5/10000] ... (divided by `max_iters`)
        #   val/test:  Iter(val) [5/2000] ... (divided by dataloader length)
        if self.by_epoch:
            # Align the iteration log:
            # Epoch(train)  [9][010/270]
            # ...                 ||| |||
            # Epoch(train) [10][100/270]
            dataloader_len = self._get_dataloader_size(runner, mode)
            cur_iter = self._get_iter(runner, batch_idx)
            cur_iter_str = str(cur_iter).rjust(len(str(dataloader_len)))

            if mode in ['train', 'val']:
                cur_epoch = self._get_epoch(runner, mode)
                if not (isinstance(runner._train_loop, dict)
                        or runner._train_loop is None):
                    # Right Align the epoch log:
                    # Epoch(train)   [9][100/270]
                    # ...             ||
                    # Epoch(train) [100][100/270]
                    max_epochs = runner.max_epochs
                    # 3 means the three characters: "[", "]", and " " occupied
                    # in " [{max_epochs}]"
                    cur_epoch_str = f'[{cur_epoch}]'.rjust(
                        len(str(max_epochs)) + 3, ' ')
                else:
                    cur_epoch_str = f'[{cur_epoch}]'
                tag['epoch'] = cur_epoch
                log_str = (f'Epoch({mode}){cur_epoch_str}'
                           f'[{cur_iter_str}/{dataloader_len}] ')
            else:
                log_str = (f'Epoch({mode}) '
                           f'[{cur_iter_str}/{dataloader_len}] ')
        else:
            if mode == 'train':
                cur_iter = self._get_iter(runner, batch_idx)
                cur_iter_str = str(cur_iter).rjust(len(str(runner.max_iters)))
                log_str = (f'Iter({mode}) '
                           f'[{cur_iter_str}/{runner.max_iters}] ')
            else:
                dataloader_len = self._get_dataloader_size(runner, mode)
                cur_iter_str = str(batch_idx + 1).rjust(
                    len(str(dataloader_len)))
                log_str = (f'Iter({mode}) [{cur_iter_str}/{dataloader_len}] ')
        # Add global iter. Without a built train loop there is no global
        # iteration counter, so 0 is recorded.
        if isinstance(runner._train_loop, dict) or runner._train_loop is None:
            tag['iter'] = 0
        else:
            tag['iter'] = runner.iter + 1
        # Concatenate lr, momentum string with log header.
        log_str += f'{lr_str} '
        # If IterTimerHook used in runner, eta, time, and data_time should be
        # recorded.
        if (all(item in log_tag for item in ['time', 'data_time'])
                and 'eta' in runner.message_hub.runtime_info):
            eta = runner.message_hub.get_info('eta')
            eta_str = str(datetime.timedelta(seconds=int(eta)))
            log_str += f'eta: {eta_str} '
            log_str += (f'time: {log_tag["time"]:.{self.num_digits}f} '
                        f'data_time: '
                        f'{log_tag["data_time"]:.{self.num_digits}f} ')
            # Pop recorded keys so they are not printed a second time below.
            log_tag.pop('time')
            log_tag.pop('data_time')

        # If cuda is available, the max memory occupied should be calculated.
        if is_cuda_available():
            max_memory = self._get_max_memory(runner)
            log_str += f'memory: {max_memory} '
            tag['memory'] = max_memory
        # Loop left keys to fill `log_str`.
        if mode in ('train', 'val'):
            log_items = []
            for name, val in log_tag.items():
                # NOTE(review): `log_tag` keys have already had their
                # `{mode}/` prefix stripped by `_collect_scalars`, so this
                # only keeps names that still start with 'val/loss' after
                # stripping (e.g. a custom `log_name`). Confirm that this is
                # the intended val-mode filter.
                if mode == 'val' and not name.startswith('val/loss'):
                    continue
                if isinstance(val, float):
                    val = f'{val:.{self.num_digits}f}'
                log_items.append(f'{name}: {val}')
            log_str += ' '.join(log_items)
        return tag, log_str

    def get_log_after_epoch(self,
                            runner,
                            batch_idx: int,
                            mode: str,
                            with_non_scalar: bool = False) -> Tuple[dict, str]:
        """Format log string after validation or testing epoch.

        ``time`` and ``data_time`` are averaged over the whole epoch unless
        ``custom_cfg`` already configures them, and are moved to the end of
        both the returned tag and the log string.

        Args:
            runner (Runner): The runner of validation/testing phase.
            batch_idx (int): The index of the current batch in the current
                loop.
            mode (str): Current mode of runner, ``'val'`` or ``'test'``.
            with_non_scalar (bool): Whether to include non-scalar infos in the
                returned tag. Defaults to False.

        Return:
            Tuple[dict, str]: Formatted log dict/string which will be
            recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
        """
        assert mode in [
            'test', 'val'
        ], ('`get_log_after_epoch` only accept val or test mode, but got '
            f'{mode}')
        dataloader_len = self._get_dataloader_size(runner, mode)

        # By epoch:
        #     Epoch(val) [10][1000/1000]  ...
        #     Epoch(test) [1000/1000] ...
        # By iteration:
        #     Iter(val) [1000/1000]  ...
        #     Iter(test) [1000/1000]  ...
        if self.by_epoch:
            if mode == 'val':
                cur_epoch = self._get_epoch(runner, mode)
                log_str = (f'Epoch({mode}) [{cur_epoch}][{dataloader_len}/'
                           f'{dataloader_len}] ')
            else:
                log_str = (
                    f'Epoch({mode}) [{dataloader_len}/{dataloader_len}] ')

        else:
            log_str = (f'Iter({mode}) [{dataloader_len}/{dataloader_len}] ')

        custom_cfg_copy = copy.deepcopy(self.custom_cfg)
        # Compare against un-prefixed names so 'val/time' and 'time' are
        # treated as the same configured source.
        custom_keys = [
            self._remove_prefix(cfg['data_src'], f'{mode}/')
            for cfg in custom_cfg_copy
        ]
        # Count the averaged time and data_time by epoch
        if 'time' not in custom_keys:
            custom_cfg_copy.append(
                dict(data_src='time', window_size='epoch', method_name='mean'))
        if 'data_time' not in custom_keys:
            custom_cfg_copy.append(
                dict(
                    data_src='data_time',
                    window_size='epoch',
                    method_name='mean'))
        parsed_cfg = self._parse_windows_size(runner, batch_idx,
                                              custom_cfg_copy)
        # tag is used to write log information to different backends.
        ori_tag = self._collect_scalars(parsed_cfg, runner, mode,
                                        self.log_with_hierarchy)
        non_scalar_tag = self._collect_non_scalars(runner, mode)
        # move `time` or `data_time` to the end of the log
        tag = OrderedDict()
        time_tag = OrderedDict()
        time_keys = (f'{mode}/time', f'{mode}/data_time', 'time', 'data_time')
        for key, value in ori_tag.items():
            if key in time_keys:
                time_tag[key] = value
            else:
                tag[key] = value
        # Log other messages.
        log_items = []
        log_str += ' '
        for name, val in chain(tag.items(), non_scalar_tag.items(),
                               time_tag.items()):
            if isinstance(val, float):
                val = f'{val:.{self.num_digits}f}'
            if isinstance(val, (torch.Tensor, np.ndarray)):
                # newline to display tensor and array.
                val = f'\n{val}\n'
            log_items.append(f'{name}: {val}')
        log_str += ' '.join(log_items)

        if with_non_scalar:
            tag.update(non_scalar_tag)
        tag.update(time_tag)
        return tag, log_str

    def _collect_scalars(self,
                         custom_cfg: List[dict],
                         runner,
                         mode: str,
                         reserve_prefix: bool = False) -> dict:
        """Collect log information to compose a dict according to mode.

        Every scalar whose key starts with ``mode`` is included. Keys that
        match ``self.mean_pattern`` are averaged over ``self.window_size``;
        the rest use their current value. Entries of ``custom_cfg`` then add
        or overwrite items via ``statistics``.

        Args:
            custom_cfg (List[dict]): A copy of ``self.custom_cfg`` with int
                ``window_size``.
            runner (Runner): The runner of the training/testing/validation
                process.
            mode (str): Current mode of runner.
            reserve_prefix (bool): Whether to reserve the prefix of the key.

        Returns:
            dict: Statistical values of logs.
        """
        # Deep copy because `data_src`/`log_name` are popped below.
        custom_cfg = copy.deepcopy(custom_cfg)
        tag = OrderedDict()
        # history_scalars of train/val/test phase.
        history_scalars = runner.message_hub.log_scalars
        # corresponding mode history_scalars
        mode_history_scalars = OrderedDict()
        # extract log scalars and remove prefix to `mode_history_scalars`
        # according to mode.
        for prefix_key, log_buffer in history_scalars.items():
            if prefix_key.startswith(mode):
                if not reserve_prefix:
                    key = self._remove_prefix(prefix_key, f'{mode}/')
                else:
                    key = prefix_key
                mode_history_scalars[key] = log_buffer
        for key, log_buffer in mode_history_scalars.items():
            # Update the latest learning rate and smoothed time logs.
            if self.mean_pattern.search(key) is not None:
                tag[key] = log_buffer.mean(self.window_size)
            else:
                # Default statistic method is current.
                tag[key] = log_buffer.current()
        # Update custom keys.
        for log_cfg in custom_cfg:
            data_src = log_cfg.pop('data_src')
            log_name = log_cfg.pop('log_name', data_src)
            if reserve_prefix:
                data_src = f'{mode}/{data_src}'
                log_name = f'{mode}/{log_name}'
            # log item in custom_cfg could only exist in train or val
            # mode.
            if data_src in mode_history_scalars:
                # Remaining keys (method_name, window_size, ...) are the
                # arguments of the buffer's statistics method.
                tag[log_name] = mode_history_scalars[data_src].statistics(
                    **log_cfg)
        return tag

    def _collect_non_scalars(self, runner, mode: str) -> dict:
        """Collect log information to compose a dict according to mode.

        Args:
            runner (Runner): The runner of the training/testing/validation
                process.
            mode (str): Current mode of runner.

        Returns:
            dict: non-scalar infos of the specified mode.
        """
        # infos of train/val/test phase.
        infos = runner.message_hub.runtime_info
        # corresponding mode infos
        mode_infos = OrderedDict()
        # extract log info and remove prefix to `mode_infos` according to mode.
        for prefix_key, value in infos.items():
            if prefix_key.startswith(mode):
                if self.log_with_hierarchy:
                    key = prefix_key
                else:
                    key = self._remove_prefix(prefix_key, f'{mode}/')
                mode_infos[key] = value
        return mode_infos

    def _remove_prefix(self, string: str, prefix: str) -> str:
        """Remove the prefix ``train``, ``val`` and ``test`` of the key.

        Returns ``string`` unchanged when it does not start with ``prefix``.
        """
        if string.startswith(prefix):
            return string[len(prefix):]
        else:
            return string

    def _check_custom_cfg(self) -> None:
        """Check the legality of ``self.custom_cfg``.

        Raises:
            AssertionError: If a config lacks ``data_src``, if two configs
                resolve to the same ``log_name``, or if a config uses
                ``window_size='epoch'`` while ``by_epoch`` is False.
        """

        def _check_window_size():
            # Epoch windows only make sense for epoch-based logging. Configs
            # without a ``window_size`` key are legal (see
            # ``_parse_windows_size``), hence ``.get`` instead of indexing.
            if self.by_epoch:
                return
            for log_cfg in self.custom_cfg:
                assert log_cfg.get('window_size') != 'epoch', \
                    'window_size cannot be epoch if LogProcessor.by_epoch' \
                    ' is False.'

        def _check_repeated_log_name():
            # The `log_name` of the same data_src should not be repeated.
            # If `log_name` is not specified, `data_src` will be overwritten.
            # But only allowed to be overwritten once.
            check_set = set()
            for log_cfg in self.custom_cfg:
                assert 'data_src' in log_cfg, (
                    'Every config in `custom_cfg` of `log_processor` must '
                    f'define `data_src`, but got {log_cfg}')
                data_src = log_cfg['data_src']
                log_name = log_cfg.get('log_name', data_src)
                assert log_name not in check_set, (
                    f'Found duplicate {log_name} for {data_src}. Please check '
                    'your `custom_cfg` for `log_processor`. You should '
                    f'neither define duplicate `{log_name}` for {data_src} '
                    'nor omit `log_name` in multiple configs of '
                    f'{data_src}, See more information in the docstring of '
                    'LogProcessor')

                check_set.add(log_name)

        _check_repeated_log_name()
        _check_window_size()

    def _parse_windows_size(self,
                            runner,
                            batch_idx: int,
                            custom_cfg: Optional[list] = None) -> list:
        """Parse window_size defined in custom_cfg to int value.

        ``'epoch'`` becomes ``batch_idx + 1`` and ``'global'`` becomes
        ``runner.iter + 1``; int or missing values are kept as-is.

        Args:
            runner (Runner): The runner of the training/testing/validation
                process.
            batch_idx (int): The iteration index of current dataloader.
            custom_cfg (list): A copy of ``self.custom_cfg``. Defaults to None
                to keep backward compatibility.

        Returns:
            list: A deep copy of the configs with resolved ``window_size``.

        Raises:
            TypeError: If ``window_size`` is not an int, ``'epoch'`` or
                ``'global'``.
        """
        if custom_cfg is None:
            custom_cfg = copy.deepcopy(self.custom_cfg)
        else:
            custom_cfg = copy.deepcopy(custom_cfg)
        for log_cfg in custom_cfg:
            window_size = log_cfg.get('window_size', None)
            if window_size is None or isinstance(window_size, int):
                continue
            elif window_size == 'epoch':
                log_cfg['window_size'] = batch_idx + 1
            elif window_size == 'global':
                log_cfg['window_size'] = runner.iter + 1
            else:
                raise TypeError(
                    'window_size should be int, epoch or global, but got '
                    f'invalid {window_size}')
        return custom_cfg

    def _get_max_memory(self, runner) -> int:
        """Returns the maximum GPU memory occupied by tensors in megabytes (MB)
        for a given device.

        Args:
            runner (Runner): The runner of the training/testing/validation
                process.

        Returns:
            The maximum GPU memory occupied by tensors in megabytes for a given
            device.
        """
        # `output_device` is absent on unwrapped models; None is passed then.
        device = getattr(runner.model, 'output_device', None)
        return get_max_cuda_memory(device)

    def _get_iter(self, runner, batch_idx: int) -> int:
        """Get current iteration index.

        Args:
            runner (Runner): The runner of the training/testing/validation
                process.
            batch_idx (int): The iteration index of current dataloader.

        Returns:
            int: ``batch_idx + 1`` (inner iter) when ``by_epoch`` is True,
            otherwise ``runner.iter + 1`` (global iter).
        """
        if self.by_epoch:
            current_iter = batch_idx + 1
        else:
            current_iter = runner.iter + 1
        return current_iter

    def _get_epoch(self, runner, mode: str) -> int:
        """Get current epoch according to mode.

        Args:
            runner (Runner): The runner of the training/testing/validation
                process.
            mode (str): Current mode of runner.

        Returns:
            int: The current epoch.

        Raises:
            ValueError: If ``mode`` is neither ``'train'`` nor ``'val'``.
        """
        if mode == 'train':
            epoch = runner.epoch + 1
        elif mode == 'val':
            if (isinstance(runner._train_loop, dict)
                    or runner._train_loop is None):
                epoch = 0
            else:
                # normal val mode
                # runner.epoch += 1 has been done before validation
                epoch = runner.epoch
        else:
            raise ValueError(
                f"runner mode should be 'train' or 'val', but got {mode}")
        return epoch

    def _get_cur_loop(self, runner, mode: str):
        """Get current loop according to mode.

        Args:
            runner (Runner): The runner of the training/validation/testing
                process.
            mode (str): Current mode of runner.

        Returns:
            BaseLoop: Current loop of runner.
        """
        # returns type hint will occur circular import
        if mode == 'train':
            return runner.train_loop
        elif mode == 'val':
            return runner.val_loop
        else:
            return runner.test_loop

    def _get_dataloader_size(self, runner, mode) -> int:
        """Get dataloader size of current loop.

        Args:
            runner (Runner): The runner of the training/validation/testing
                process.
            mode (str): Current mode of runner.

        Returns:
            int: The dataloader size of current loop.
        """
        return len(self._get_cur_loop(runner=runner, mode=mode).dataloader)