Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
import statistics | |
import time | |
from collections import defaultdict, deque | |
from typing import Generator, Iterable, TypeVar | |
import torch | |
import torch.distributed as dist | |
from tqdm import tqdm as tqdm_class | |
from typing_extensions import Self | |
from .output import ansi, get_ansi_len, prints | |
__all__ = ["SmoothedValue", "MetricLogger"] | |
MB = 1 << 20 | |
T = TypeVar("T") | |
class SmoothedValue:
    r"""Track a series of values and provide access to smoothed values over a
    window or the global series average.

    See Also:
        https://github.com/pytorch/vision/blob/main/references/classification/utils.py

    Args:
        name (str): Name string.
        window_size (int): The :attr:`maxlen` of :class:`~collections.deque`.
            ``None`` (the default) means the window is unbounded.
        fmt (str): The format pattern of ``str(self)``. It may reference
            ``name``, ``count``, ``total``, ``median``, ``avg``,
            ``global_avg``, ``min``, ``max`` and ``last_value``.

    Attributes:
        name (str): Name string.
        fmt (str): The string pattern.
        deque (~collections.deque): The windowed data series
            (one entry per :meth:`update` call / per list element).
        count (int): The amount of data.
        total (float): The sum of all data.
        median (float): The median of :attr:`deque`.
        avg (float): The avg of :attr:`deque`.
        global_avg (float): :math:`\frac{\text{total}}{\text{count}}`
        max (float): The max of :attr:`deque`.
        min (float): The min of :attr:`deque`.
        last_value (float): The last value of :attr:`deque`.

    Note:
        All statistic properties return ``0.0`` while no data is available.
    """

    def __init__(
        self, name: str = "", window_size: int = None, fmt: str = "{global_avg:.3f}"
    ):
        self.name = name
        self.deque: deque[float] = deque(maxlen=window_size)
        self.count: int = 0
        self.total: float = 0.0
        self.fmt = fmt

    def update(self, value: float, n: int = 1) -> Self:
        r"""Update :attr:`n` pieces of data with same :attr:`value`.

        .. code-block:: python

            self.deque.append(value)
            self.total += value * n
            self.count += n

        Args:
            value (float): the value to update.
            n (int): the number of data with same :attr:`value`.

        Returns:
            SmoothedValue: return ``self`` for stream usage.
        """
        # The window receives a single entry regardless of ``n``;
        # only the global statistics are weighted by ``n``.
        self.deque.append(value)
        self.total += value * n
        self.count += n
        return self

    def update_list(self, value_list: list[float]) -> Self:
        r"""Update :attr:`value_list`.

        .. code-block:: python

            for value in value_list:
                self.deque.append(value)
                self.total += value
            self.count += len(value_list)

        Args:
            value_list (list[float]): the value list to update.

        Returns:
            SmoothedValue: return ``self`` for stream usage.
        """
        for value in value_list:
            self.deque.append(value)
            self.total += value
        self.count += len(value_list)
        return self

    def reset(self) -> Self:
        r"""Reset ``deque``, ``count`` and ``total`` to be empty.

        The window size (``maxlen``) of the deque is preserved.

        Returns:
            SmoothedValue: return ``self`` for stream usage.
        """
        self.deque = deque(maxlen=self.deque.maxlen)
        self.count = 0
        self.total = 0.0
        return self

    def synchronize_between_processes(self) -> None:
        r"""All-reduce :attr:`count` and :attr:`total` across processes.

        Does nothing when ``torch.distributed`` is unavailable or
        not initialized.

        Warning:
            Does NOT synchronize the deque!
        """
        if not (dist.is_available() and dist.is_initialized()):
            return
        # Fall back to CPU so CPU-only process groups do not crash here.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
        dist.barrier()
        dist.all_reduce(t)
        count, total = t.tolist()
        self.count = int(count)
        self.total = float(total)

    # The statistics below are properties: ``__str__`` feeds them directly
    # into numeric format specs such as ``{global_avg:.3f}``.

    @property
    def median(self) -> float:
        r"""The median of :attr:`deque` (``0.0`` if empty)."""
        if not self.deque:
            return 0.0
        return statistics.median(self.deque)

    @property
    def avg(self) -> float:
        r"""The mean of :attr:`deque` (``0.0`` if empty)."""
        if not self.deque:
            return 0.0
        return statistics.mean(self.deque)

    @property
    def global_avg(self) -> float:
        r"""``total / count`` (``0.0`` if ``count`` is zero)."""
        if self.count == 0:
            return 0.0
        return self.total / self.count

    @property
    def max(self) -> float:
        r"""The max of :attr:`deque` (``0.0`` if empty)."""
        if not self.deque:
            return 0.0
        return max(self.deque)

    @property
    def min(self) -> float:
        r"""The min of :attr:`deque` (``0.0`` if empty)."""
        if not self.deque:
            return 0.0
        return min(self.deque)

    @property
    def last_value(self) -> float:
        r"""The most recently appended value of :attr:`deque` (``0.0`` if empty)."""
        if not self.deque:
            return 0.0
        return self.deque[-1]

    def __str__(self) -> str:
        return self.fmt.format(
            name=self.name,
            count=self.count,
            total=self.total,
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            min=self.min,
            max=self.max,
            last_value=self.last_value,
        )

    def __format__(self, format_spec: str) -> str:
        # ``format_spec`` is intentionally ignored: the layout is fully
        # controlled by ``self.fmt``.
        return self.__str__()
class MetricLogger:
    r"""Collect named :class:`SmoothedValue` meters and print them while
    iterating over an iterable.

    See Also:
        https://github.com/pytorch/vision/blob/main/references/classification/utils.py

    Args:
        delimiter (str): The delimiter to join different meter strings.
            Defaults to ``''``.
        meter_length (int): The visible width each meter entry is padded to
            (and truncated to when ``cut_too_long`` is set in :meth:`get_str`).
            Defaults to ``20``.
        tqdm (bool): Whether to use tqdm to show iteration information.
            Defaults to ``True``.
        indent (int): The space indent for the entire string.
            Defaults to ``0``.
        **kwargs: ``(meter_name: fmt)`` pairs forwarded to :meth:`create_meters`.

    Attributes:
        meters (dict[str, SmoothedValue]): The meter dict.
        iter_time (SmoothedValue): Iteration time meter.
        data_time (SmoothedValue): Data loading time meter.
        memory (SmoothedValue): Memory usage meter.
    """

    def __init__(
        self,
        delimiter: str = "",
        meter_length: int = 20,
        tqdm: bool = True,
        indent: int = 0,
        **kwargs,
    ):
        self.meters: defaultdict[str, SmoothedValue] = defaultdict(SmoothedValue)
        self.create_meters(**kwargs)
        self.delimiter = delimiter
        self.meter_length = meter_length
        self.tqdm = tqdm
        self.indent = indent

        self.iter_time = SmoothedValue()
        self.data_time = SmoothedValue()
        self.memory = SmoothedValue(fmt="{max:.0f}")

    def create_meters(self, **kwargs: str) -> Self:
        r"""Create meters with specific ``fmt`` in :attr:`self.meters`.

        ``self.meters[meter_name] = SmoothedValue(fmt=fmt)``

        Args:
            **kwargs: ``(meter_name: fmt)``. A ``None`` fmt selects
                ``'{global_avg:.3f}'``.

        Returns:
            MetricLogger: return ``self`` for stream usage.
        """
        for k, v in kwargs.items():
            self.meters[k] = SmoothedValue(fmt="{global_avg:.3f}" if v is None else v)
        return self

    def update(self, n: int = 1, **kwargs: float) -> Self:
        r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update()`.

        ``self.meters[meter_name].update(float(value), n=n)``

        Args:
            n (int): the number of data with same value.
            **kwargs: ``{meter_name: value}``.

        Returns:
            MetricLogger: return ``self`` for stream usage.
        """
        for k, v in kwargs.items():
            if k not in self.meters:
                self.meters[k] = SmoothedValue()
            self.meters[k].update(float(v), n=n)
        return self

    def update_list(self, **kwargs: list) -> Self:
        r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update_list()`.

        ``self.meters[meter_name].update_list(value_list)``

        Args:
            **kwargs: ``{meter_name: value_list}``.

        Returns:
            MetricLogger: return ``self`` for stream usage.
        """
        for k, v in kwargs.items():
            # Missing meters are created on access by the defaultdict.
            self.meters[k].update_list(v)
        return self

    def reset(self) -> Self:
        r"""Reset meter in :attr:`self.meters` by calling :meth:`SmoothedValue.reset()`.

        Returns:
            MetricLogger: return ``self`` for stream usage.
        """
        for meter in self.meters.values():
            meter.reset()
        return self

    def get_str(self, cut_too_long: bool = True, strip: bool = True, **kwargs) -> str:
        r"""Generate formatted string based on keyword arguments.

        Each entry is rendered as ``key: value`` and left-justified to
        :attr:`self.meter_length` plus ``get_ansi_len`` of the entry
        (presumably the length of its ANSI escape codes, so that the
        visible width is :attr:`self.meter_length`).

        Args:
            cut_too_long (bool): Whether to truncate entries that exceed
                that length.
                Defaults to ``True``.
            strip (bool): Whether to strip trailing whitespaces.
                Defaults to ``True``.
            **kwargs: Keyword arguments to generate string.

        Returns:
            str: The entries joined by :attr:`self.delimiter`.
        """
        str_list: list[str] = []
        for k, v in kwargs.items():
            v_str = str(v)
            _str: str = "{green}{k}{reset}: {v}".format(k=k, v=v_str, **ansi)
            max_length = self.meter_length + get_ansi_len(_str)
            if cut_too_long:
                _str = _str[:max_length]
            str_list.append(_str.ljust(max_length))
        _str = self.delimiter.join(str_list)
        if strip:
            _str = _str.rstrip()
        return _str

    def __getattr__(self, attr: str) -> "SmoothedValue":
        r"""Expose meters as attributes, e.g. ``logger.loss``.

        Only called when normal attribute lookup fails.

        Raises:
            AttributeError: If :attr:`attr` is not the name of a meter.
        """
        # Go through ``__dict__`` directly: if ``meters`` is not set yet
        # (bare ``__new__``, copy, unpickle), ``self.meters`` would re-enter
        # ``__getattr__`` and recurse without bound.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self) -> str:
        return self.get_str(**self.meters)

    def synchronize_between_processes(self) -> None:
        r"""Call :meth:`SmoothedValue.synchronize_between_processes()`
        on every meter in :attr:`self.meters`."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def log_every(
        self,
        iterable: Iterable[T],
        header: str = "",
        tqdm: bool = None,
        tqdm_header: str = "Iter",
        indent: int = None,
        verbose: int = 1,
    ) -> Generator[T, None, None]:
        r"""Wrap an :class:`collections.abc.Iterable` with formatted outputs.

        * Middle Output:
          ``{tqdm_header}: [ current / total ] str(self) {memory} {iter_time} {data_time} {time}<{remaining}``
        * Final Output
          ``{header} str(self) {memory} {iter_time} {data_time} {total_time}``

        Args:
            iterable (~collections.abc.Iterable): The raw iterator.
            header (str): The header string for final output.
                Defaults to ``''``.
            tqdm (bool): Whether to use tqdm to show iteration information.
                if ``None``, use ``self.tqdm``.
                Defaults to ``None``.
            tqdm_header (str): The header string for middle output.
                Defaults to ``'Iter'``.
            indent (int): The space indent for the entire string.
                if ``None``, use ``self.indent``.
                Defaults to ``None``.
            verbose (int): The verbose level of output information.
                ``> 1`` adds iteration/data time, ``> 2`` additionally adds
                CUDA memory (when CUDA is available).

        Yields:
            The items of :attr:`iterable`, unchanged.
        """
        use_tqdm = self.tqdm if tqdm is None else tqdm
        indent = self.indent if indent is None else indent
        iterator = iterable
        if header:
            header = header.ljust(30 + get_ansi_len(header))
        if use_tqdm:
            try:
                length = len(str(len(iterable)))
            except TypeError:
                # Unsized iterable (e.g. a generator): no known total,
                # so no padding width for the counter.
                length = 0
            pattern: str = (
                "{tqdm_header}: {blue_light}"
                "[ {red}{{n_fmt:>{length}}}{blue_light} "
                "/ {red}{{total_fmt}}{blue_light} ]{reset}"
            ).format(tqdm_header=tqdm_header, length=length, **ansi)
            # Compensate for the placeholders, which are longer than the
            # numbers tqdm substitutes for them.
            offset = len(f"{{n_fmt:>{length}}}{{total_fmt}}") - 2 * length
            pattern = pattern.ljust(30 + offset + get_ansi_len(pattern))
            time_str = self.get_str(time="{elapsed}<{remaining}", cut_too_long=False)
            bar_format = f"{pattern}{{desc}}{time_str}"
            iterator = tqdm_class(iterable, leave=False, bar_format=bar_format)

        self.iter_time.reset()
        self.data_time.reset()
        self.memory.reset()

        end = time.time()
        start_time = time.time()
        for obj in iterator:
            # Time spent waiting for the iterable to produce ``obj``.
            cur_data_time = time.time() - end
            self.data_time.update(cur_data_time)
            yield obj
            # Data time plus the time the consumer spent on ``obj``.
            cur_iter_time = time.time() - end
            self.iter_time.update(cur_iter_time)
            if torch.cuda.is_available():
                cur_memory = torch.cuda.max_memory_allocated() / MB
                self.memory.update(cur_memory)
            if use_tqdm:
                _dict = dict(self.meters)
                if verbose > 2 and torch.cuda.is_available():
                    _dict.update(memory=f"{cur_memory:.0f} MB")
                if verbose > 1:
                    _dict.update(
                        iter=f"{cur_iter_time:.3f} s", data=f"{cur_data_time:.3f} s"
                    )
                iterator.set_description_str(self.get_str(**_dict, strip=False))
            end = time.time()
        self.synchronize_between_processes()
        total_time = time.time() - start_time
        total_time_str = tqdm_class.format_interval(total_time)

        _dict = dict(self.meters)
        if verbose > 2 and torch.cuda.is_available():
            _dict.update(memory=f"{str(self.memory)} MB")
        if verbose > 1:
            _dict.update(
                iter=f"{str(self.iter_time)} s", data=f"{str(self.data_time)} s"
            )
        _dict.update(time=total_time_str)
        prints(self.delimiter.join([header, self.get_str(**_dict)]), indent=indent)