#!/usr/bin/env python3

import statistics
import time
from collections import defaultdict, deque
from typing import Generator, Iterable, Optional, TypeVar

import torch
import torch.distributed as dist
from tqdm import tqdm as tqdm_class
from typing_extensions import Self

from .output import ansi, get_ansi_len, prints

__all__ = ["SmoothedValue", "MetricLogger"]

MB = 1 << 20  # bytes per MB, used to report CUDA memory usage
T = TypeVar("T")


class SmoothedValue:
    r"""Track a series of values and provide access to smoothed values over a
    window or the global series average.

    See Also:
        https://github.com/pytorch/vision/blob/main/references/classification/utils.py

    Args:
        name (str): Name string.
        window_size (int | None): The :attr:`maxlen` of :class:`~collections.deque`.
            ``None`` keeps every value.
        fmt (str): The format pattern of ``str(self)``.

    Attributes:
        name (str): Name string.
        fmt (str): The string pattern.
        deque (~collections.deque): The most recent values
            (at most ``window_size`` of them).
        count (int): The amount of data.
        total (float): The sum of all data.
        median (float): The median of :attr:`deque`.
        avg (float): The avg of :attr:`deque`.
        global_avg (float): :math:`\frac{\text{total}}{\text{count}}`
        max (float): The max of :attr:`deque`.
        min (float): The min of :attr:`deque`.
        last_value (float): The last value of :attr:`deque`.

    Note:
        Window statistics (:attr:`median`, :attr:`avg`, :attr:`max`,
        :attr:`min`, :attr:`last_value`) return ``0.0`` when they cannot be
        computed, e.g. while :attr:`deque` is empty.
    """

    def __init__(
        self,
        name: str = "",
        window_size: Optional[int] = None,
        fmt: str = "{global_avg:.3f}",
    ):
        self.name = name
        self.deque: deque[float] = deque(maxlen=window_size)
        self.count: int = 0
        self.total: float = 0.0
        self.fmt = fmt

    def update(self, value: float, n: int = 1) -> Self:
        r"""Update :attr:`n` pieces of data with same :attr:`value`.

        .. code-block:: python

            self.deque.append(value)
            self.total += value * n
            self.count += n

        The window receives ``value`` only once, while :attr:`total` and
        :attr:`count` are weighted by ``n``; so :attr:`global_avg` is a
        weighted average but the window statistics are not.

        Args:
            value (float): the value to update.
            n (int): the number of data with same :attr:`value`.

        Returns:
            SmoothedValue: return ``self`` for stream usage.
        """
        self.deque.append(value)
        self.total += value * n
        self.count += n
        return self

    def update_list(self, value_list: Iterable[float]) -> Self:
        r"""Update every value of :attr:`value_list`, each counted once.

        .. code-block:: python

            for value in value_list:
                self.deque.append(value)
                self.total += value
                self.count += 1

        Args:
            value_list (~collections.abc.Iterable[float]): the values to update.
                Any iterable is accepted, including one-shot generators.

        Returns:
            SmoothedValue: return ``self`` for stream usage.
        """
        # Count while iterating instead of calling ``len()`` afterwards, which
        # would fail on (already consumed) generators.
        for value in value_list:
            self.deque.append(value)
            self.total += value
            self.count += 1
        return self

    def reset(self) -> Self:
        r"""Reset ``deque``, ``count`` and ``total`` to be empty.

        The window size (``deque.maxlen``) is preserved.

        Returns:
            SmoothedValue: return ``self`` for stream usage.
        """
        self.deque = deque(maxlen=self.deque.maxlen)
        self.count = 0
        self.total = 0.0
        return self

    def synchronize_between_processes(self) -> None:
        r"""All-reduce (sum) :attr:`count` and :attr:`total` across processes.

        No-op when :mod:`torch.distributed` is unavailable or not initialized.

        Warning:
            Does NOT synchronize the deque! Afterwards only :attr:`global_avg`
            reflects all processes; window statistics stay process-local.
        """
        if not (dist.is_available() and dist.is_initialized()):
            return
        # NCCL only reduces CUDA tensors; other backends (e.g. gloo) handle CPU
        # tensors, which also keeps CPU-only distributed runs working.
        device = "cuda" if dist.get_backend() == dist.Backend.NCCL else "cpu"
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
        dist.barrier()
        dist.all_reduce(t)
        count, total = t.tolist()
        self.count = int(count)
        self.total = float(total)

    @property
    def median(self) -> float:
        try:
            return statistics.median(self.deque)
        except Exception:  # best-effort: empty window or non-numeric data
            return 0.0

    @property
    def avg(self) -> float:
        try:
            return statistics.mean(self.deque)
        except Exception:  # best-effort: empty window or non-numeric data
            return 0.0

    @property
    def global_avg(self) -> float:
        if not self.count:
            return 0.0
        return self.total / self.count

    @property
    def max(self) -> float:
        try:
            return max(self.deque)
        except Exception:  # best-effort: empty window or non-comparable data
            return 0.0

    @property
    def min(self) -> float:
        try:
            return min(self.deque)
        except Exception:  # best-effort: empty window or non-comparable data
            return 0.0

    @property
    def last_value(self) -> float:
        try:
            return self.deque[-1]
        except Exception:  # best-effort: empty window
            return 0.0

    def __str__(self):
        return self.fmt.format(
            name=self.name,
            count=self.count,
            total=self.total,
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            min=self.min,
            max=self.max,
            last_value=self.last_value,
        )

    def __format__(self, format_spec: str) -> str:
        # ``format_spec`` is intentionally ignored: formatting is driven by
        # ``self.fmt``, so ``f"{meter}"`` and ``f"{meter:.2f}"`` are identical.
        return self.__str__()


class MetricLogger:
    r"""Collect named :class:`SmoothedValue` meters and print them while
    iterating over an iterable (see :meth:`log_every`).

    See Also:
        https://github.com/pytorch/vision/blob/main/references/classification/utils.py

    Args:
        delimiter (str): The delimiter to join different meter strings.
            Defaults to ``''``.
        meter_length (int): The minimum length for each meter.
            Defaults to ``20``.
        tqdm (bool): Whether to use tqdm to show iteration information.
            Defaults to ``True``.
        indent (int): The space indent for the entire string.
            Defaults to ``0``.
        **kwargs: Meters to create up front as ``{meter_name: fmt}``
            (see :meth:`create_meters`).

    Attributes:
        meters (dict[str, SmoothedValue]): The meter dict. It is a
            :class:`~collections.defaultdict`, so unknown names get a fresh
            :class:`SmoothedValue` on first access. Meters are also readable
            as attributes, e.g. ``logger.loss``.
        iter_time (SmoothedValue): Iteration time meter.
        data_time (SmoothedValue): Data loading time meter.
        memory (SmoothedValue): Memory usage meter.
    """

    def __init__(
        self,
        delimiter: str = "",
        meter_length: int = 20,
        tqdm: bool = True,
        indent: int = 0,
        **kwargs,
    ):
        self.meters: defaultdict[str, SmoothedValue] = defaultdict(SmoothedValue)
        self.create_meters(**kwargs)
        self.delimiter = delimiter
        self.meter_length = meter_length
        self.tqdm = tqdm
        self.indent = indent
        self.iter_time = SmoothedValue()
        self.data_time = SmoothedValue()
        self.memory = SmoothedValue(fmt="{max:.0f}")

    def create_meters(self, **kwargs: Optional[str]) -> Self:
        r"""Create meters with specific ``fmt`` in :attr:`self.meters`.

        ``self.meters[meter_name] = SmoothedValue(fmt=fmt)``

        Existing meters with the same name are replaced.

        Args:
            **kwargs: ``(meter_name: fmt)``. A ``None`` fmt falls back to
                ``'{global_avg:.3f}'``.

        Returns:
            MetricLogger: return ``self`` for stream usage.
        """
        for k, v in kwargs.items():
            self.meters[k] = SmoothedValue(fmt="{global_avg:.3f}" if v is None else v)
        return self

    def update(self, n: int = 1, **kwargs: float) -> Self:
        r"""Update values to :attr:`self.meters`
        by calling :meth:`SmoothedValue.update()`.

        ``self.meters[meter_name].update(float(value), n=n)``

        Args:
            n (int): the number of data with same value.
            **kwargs: ``{meter_name: value}``. Unknown names create a new meter.

        Returns:
            MetricLogger: return ``self`` for stream usage.
        """
        for k, v in kwargs.items():
            # ``self.meters`` is a defaultdict: unknown names get a new meter.
            self.meters[k].update(float(v), n=n)
        return self

    def update_list(self, **kwargs: Iterable[float]) -> Self:
        r"""Update values to :attr:`self.meters`
        by calling :meth:`SmoothedValue.update_list()`.

        ``self.meters[meter_name].update_list(value_list)``

        Args:
            **kwargs: ``{meter_name: value_list}``. Unknown names create a new meter.

        Returns:
            MetricLogger: return ``self`` for stream usage.
        """
        for k, v in kwargs.items():
            self.meters[k].update_list(v)
        return self

    def reset(self) -> Self:
        r"""Reset meter in :attr:`self.meters`
        by calling :meth:`SmoothedValue.reset()`.

        :attr:`iter_time`, :attr:`data_time` and :attr:`memory` are not
        touched here; :meth:`log_every` resets them itself.

        Returns:
            MetricLogger: return ``self`` for stream usage.
        """
        for meter in self.meters.values():
            meter.reset()
        return self

    def get_str(self, cut_too_long: bool = True, strip: bool = True, **kwargs) -> str:
        r"""Generate formatted string based on keyword arguments.

        Each item renders as a colored ``key: value`` and is padded to
        ``meter_length + get_ansi_len(item)`` characters; ``get_ansi_len`` is
        assumed to measure the invisible ANSI escape codes so the visible width
        is :attr:`meter_length`. Items are joined by :attr:`delimiter`.

        Args:
            cut_too_long (bool): Whether to truncate items that are longer than
                the padded width above. Defaults to ``True``.
            strip (bool): Whether to strip trailing whitespaces of the result.
                Defaults to ``True``.
            **kwargs: ``{key: value}`` pairs; each value is rendered with ``str()``.

        Returns:
            str: The joined string.
        """
        str_list: list[str] = []
        for k, v in kwargs.items():
            v_str = str(v)
            _str: str = "{green}{k}{reset}: {v}".format(k=k, v=v_str, **ansi)
            max_length = self.meter_length + get_ansi_len(_str)
            if cut_too_long:
                _str = _str[:max_length]
            str_list.append(_str.ljust(max_length))
        _str = self.delimiter.join(str_list)
        if strip:
            _str = _str.rstrip()
        return _str

    def __getattr__(self, attr: str) -> SmoothedValue:
        r"""Expose meters as attributes, e.g. ``logger.loss``.

        Only called when normal attribute lookup fails. ``meters`` is read from
        ``__dict__`` directly: on instances created without ``__init__`` (as
        :mod:`copy` and :mod:`pickle` do before probing ``__setstate__``),
        reading ``self.meters`` would re-enter this method forever.

        Raises:
            AttributeError: If ``attr`` is neither an attribute nor a meter.
        """
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self) -> str:
        return self.get_str(**self.meters)

    def synchronize_between_processes(self) -> None:
        r"""Synchronize every meter in :attr:`meters` across processes
        (see :meth:`SmoothedValue.synchronize_between_processes`)."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def _tqdm_bar_format(self, iterable: Iterable, tqdm_header: str) -> str:
        r"""Build the tqdm ``bar_format`` used for the per-iteration output."""
        try:
            length = len(str(len(iterable)))
        except TypeError:
            # Unsized iterables (e.g. generators): tqdm then shows ``?`` as total.
            length = 1
        pattern: str = (
            "{tqdm_header}: {blue_light}"
            "[ {red}{{n_fmt:>{length}}}{blue_light} "
            "/ {red}{{total_fmt}}{blue_light} ]{reset}"
        ).format(tqdm_header=tqdm_header, length=length, **ansi)
        # The ``{n_fmt:>L}{total_fmt}`` placeholders are longer than the
        # ``2 * length`` characters tqdm substitutes; pad by the difference so
        # the rendered prefix keeps a fixed visible width.
        offset = len(f"{{n_fmt:>{length}}}{{total_fmt}}") - 2 * length
        pattern = pattern.ljust(30 + offset + get_ansi_len(pattern))
        time_str = self.get_str(time="{elapsed}<{remaining}", cut_too_long=False)
        return f"{pattern}{{desc}}{time_str}"

    def log_every(
        self,
        iterable: Iterable[T],
        header: str = "",
        tqdm: Optional[bool] = None,
        tqdm_header: str = "Iter",
        indent: Optional[int] = None,
        verbose: int = 1,
    ) -> Generator[T, None, None]:
        r"""Wrap an :class:`collections.abc.Iterable` with formatted outputs.

        * Middle Output (tqdm only, refreshed every iteration)
          ``{tqdm_header}: [ current / total ] str(self) {memory} {iter_time} {data_time} {time}<{remaining}``
        * Final Output (after the iterable is exhausted)
          ``{header} str(self) {memory} {iter_time} {data_time} {total_time}``

        ``data_time`` is the time spent obtaining each item; ``iter_time``
        additionally includes the consumer's work on it. ``memory`` is the
        peak allocated CUDA memory in MB.

        If the consumer stops iterating early, the final output is skipped,
        but the tqdm bar is still closed.

        Args:
            iterable (~collections.abc.Iterable): The raw iterator.
            header (str): The header string for final output.
                Defaults to ``''``.
            tqdm (bool): Whether to use tqdm to show iteration information.
                Defaults to ``self.tqdm``.
            tqdm_header (str): The header string for middle output.
                Defaults to ``'Iter'``.
            indent (int): The space indent for the entire string.
                if ``None``, use ``self.indent``.
                Defaults to ``None``.
            verbose (int): The verbose level of output information.
                ``> 1`` adds iteration/data times, ``> 2`` also adds memory
                (only when CUDA is available).

        Yields:
            The items of :attr:`iterable`, unchanged.
        """
        tqdm = tqdm if tqdm is not None else self.tqdm
        indent = indent if indent is not None else self.indent
        cuda = torch.cuda.is_available()

        if header:
            header = header.ljust(30 + get_ansi_len(header))

        iterator = iterable
        if tqdm:
            bar_format = self._tqdm_bar_format(iterable, tqdm_header)
            iterator = tqdm_class(iterable, leave=False, bar_format=bar_format)

        self.iter_time.reset()
        self.data_time.reset()
        self.memory.reset()
        # perf_counter is monotonic, so durations are immune to clock changes.
        start_time = end = time.perf_counter()
        try:
            for obj in iterator:
                cur_data_time = time.perf_counter() - end
                self.data_time.update(cur_data_time)
                yield obj
                cur_iter_time = time.perf_counter() - end
                self.iter_time.update(cur_iter_time)
                if cuda:
                    cur_memory = torch.cuda.max_memory_allocated() / MB
                    self.memory.update(cur_memory)
                if tqdm:
                    _dict = dict(self.meters)
                    if verbose > 2 and cuda:
                        _dict.update(memory=f"{cur_memory:.0f} MB")
                    if verbose > 1:
                        _dict.update(
                            iter=f"{cur_iter_time:.3f} s",
                            data=f"{cur_data_time:.3f} s",
                        )
                    iterator.set_description_str(self.get_str(**_dict, strip=False))
                end = time.perf_counter()
        finally:
            # Also runs when the consumer breaks out early (GeneratorExit);
            # tqdm's close() is idempotent.
            if tqdm:
                iterator.close()

        self.synchronize_between_processes()
        total_time = time.perf_counter() - start_time
        total_time_str = tqdm_class.format_interval(total_time)

        _dict = dict(self.meters)
        if verbose > 2 and cuda:
            _dict.update(memory=f"{self.memory} MB")
        if verbose > 1:
            _dict.update(iter=f"{self.iter_time} s", data=f"{self.data_time} s")
        _dict.update(time=total_time_str)
        prints(self.delimiter.join([header, self.get_str(**_dict)]), indent=indent)