zhzluke96
update
d2b7e94
raw
history blame
13.5 kB
#!/usr/bin/env python3
import statistics
import time
from collections import defaultdict, deque
from typing import Generator, Iterable, Optional, TypeVar

import torch
import torch.distributed as dist
from tqdm import tqdm as tqdm_class
from typing_extensions import Self

from .output import ansi, get_ansi_len, prints
# Public API of this module.
__all__ = [
    "SmoothedValue",
    "MetricLogger",
]
# Bytes per mebibyte; CUDA memory readings are divided by this for display.
MB = 2**20
# Element type passed through unchanged by ``MetricLogger.log_every``.
T = TypeVar("T")
class SmoothedValue:
r"""Track a series of values and provide access to smoothed values over a
window or the global series average.
See Also:
https://github.com/pytorch/vision/blob/main/references/classification/utils.py
Args:
name (str): Name string.
window_size (int): The :attr:`maxlen` of :class:`~collections.deque`.
fmt (str): The format pattern of ``str(self)``.
Attributes:
name (str): Name string.
fmt (str): The string pattern.
deque (~collections.deque): The unique data series.
count (int): The amount of data.
total (float): The sum of all data.
median (float): The median of :attr:`deque`.
avg (float): The avg of :attr:`deque`.
global_avg (float): :math:`\frac{\text{total}}{\text{count}}`
max (float): The max of :attr:`deque`.
min (float): The min of :attr:`deque`.
last_value (float): The last value of :attr:`deque`.
"""
def __init__(
self, name: str = "", window_size: int = None, fmt: str = "{global_avg:.3f}"
):
self.name = name
self.deque: deque[float] = deque(maxlen=window_size)
self.count: int = 0
self.total: float = 0.0
self.fmt = fmt
def update(self, value: float, n: int = 1) -> Self:
r"""Update :attr:`n` pieces of data with same :attr:`value`.
.. code-block:: python
self.deque.append(value)
self.total += value * n
self.count += n
Args:
value (float): the value to update.
n (int): the number of data with same :attr:`value`.
Returns:
SmoothedValue: return ``self`` for stream usage.
"""
self.deque.append(value)
self.total += value * n
self.count += n
return self
def update_list(self, value_list: list[float]) -> Self:
r"""Update :attr:`value_list`.
.. code-block:: python
for value in value_list:
self.deque.append(value)
self.total += value
self.count += len(value_list)
Args:
value_list (list[float]): the value list to update.
Returns:
SmoothedValue: return ``self`` for stream usage.
"""
for value in value_list:
self.deque.append(value)
self.total += value
self.count += len(value_list)
return self
def reset(self) -> Self:
r"""Reset ``deque``, ``count`` and ``total`` to be empty.
Returns:
SmoothedValue: return ``self`` for stream usage.
"""
self.deque = deque(maxlen=self.deque.maxlen)
self.count = 0
self.total = 0.0
return self
def synchronize_between_processes(self):
r"""
Warning:
Does NOT synchronize the deque!
"""
if not (dist.is_available() and dist.is_initialized()):
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = float(t[1])
@property
def median(self) -> float:
try:
return statistics.median(self.deque)
except Exception:
return 0.0
@property
def avg(self) -> float:
try:
return statistics.mean(self.deque)
except Exception:
return 0.0
@property
def global_avg(self) -> float:
try:
return self.total / self.count
except Exception:
return 0.0
@property
def max(self) -> float:
try:
return max(self.deque)
except Exception:
return 0.0
@property
def min(self) -> float:
try:
return min(self.deque)
except Exception:
return 0.0
@property
def last_value(self) -> float:
try:
return self.deque[-1]
except Exception:
return 0.0
def __str__(self):
return self.fmt.format(
name=self.name,
count=self.count,
total=self.total,
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
min=self.min,
max=self.max,
last_value=self.last_value,
)
def __format__(self, format_spec: str) -> str:
return self.__str__()
class MetricLogger:
    r"""Collect named :class:`SmoothedValue` meters and report them while
    iterating.

    See Also:
        https://github.com/pytorch/vision/blob/main/references/classification/utils.py

    Args:
        delimiter (str): The delimiter to join different meter strings.
            Defaults to ``''``.
        meter_length (int): The minimum length for each meter.
            Defaults to ``20``.
        tqdm (bool): Whether to use tqdm to show iteration information.
            Defaults to ``True``.
        indent (int): The space indent for the entire string.
            Defaults to ``0``.
        **kwargs: ``(meter_name: fmt)`` forwarded to :meth:`create_meters`.

    Attributes:
        meters (dict[str, SmoothedValue]): The meter dict. It is a
            ``defaultdict``, so unknown names get a default meter on access.
        iter_time (SmoothedValue): Iteration time meter.
        data_time (SmoothedValue): Data loading time meter.
        memory (SmoothedValue): Memory usage meter.
    """

    def __init__(
        self,
        delimiter: str = "",
        meter_length: int = 20,
        tqdm: bool = True,
        indent: int = 0,
        **kwargs,
    ):
        self.meters: defaultdict[str, SmoothedValue] = defaultdict(SmoothedValue)
        self.create_meters(**kwargs)
        self.delimiter = delimiter
        self.meter_length = meter_length
        self.tqdm = tqdm
        self.indent = indent
        self.iter_time = SmoothedValue()
        self.data_time = SmoothedValue()
        self.memory = SmoothedValue(fmt="{max:.0f}")

    def create_meters(self, **kwargs: Optional[str]) -> Self:
        r"""Create meters with specific ``fmt`` in :attr:`self.meters`.

        ``self.meters[meter_name] = SmoothedValue(fmt=fmt)``

        Args:
            **kwargs: ``(meter_name: fmt)``. A ``None`` fmt falls back to
                ``'{global_avg:.3f}'``.

        Returns:
            MetricLogger: return ``self`` for stream usage.
        """
        for name, fmt in kwargs.items():
            self.meters[name] = SmoothedValue(
                fmt="{global_avg:.3f}" if fmt is None else fmt
            )
        return self

    def update(self, n: int = 1, **kwargs: float) -> Self:
        r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update()`.

        ``self.meters[meter_name].update(float(value), n=n)``

        Args:
            n (int): the number of data with same value.
            **kwargs: ``{meter_name: value}``. Unknown names get a new
                default meter.

        Returns:
            MetricLogger: return ``self`` for stream usage.
        """
        for k, v in kwargs.items():
            # ``meters`` is a defaultdict: unseen names get SmoothedValue().
            self.meters[k].update(float(v), n=n)
        return self

    def update_list(self, **kwargs: list) -> Self:
        r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update_list()`.

        ``self.meters[meter_name].update_list(value_list)``

        Args:
            **kwargs: ``{meter_name: value_list}``.

        Returns:
            MetricLogger: return ``self`` for stream usage.
        """
        for k, v in kwargs.items():
            self.meters[k].update_list(v)
        return self

    def reset(self) -> Self:
        r"""Reset meter in :attr:`self.meters` by calling :meth:`SmoothedValue.reset()`.

        Returns:
            MetricLogger: return ``self`` for stream usage.
        """
        for meter in self.meters.values():
            meter.reset()
        return self

    def get_str(self, cut_too_long: bool = True, strip: bool = True, **kwargs) -> str:
        r"""Generate formatted string based on keyword arguments.

        Each entry is rendered as ``key: value`` (key in green) and padded to
        :attr:`self.meter_length` characters, not counting the extra length
        reported by ``get_ansi_len`` (presumably the ANSI colour codes).

        Args:
            cut_too_long (bool): Whether to truncate each entry to that same
                length. Defaults to ``True``.
            strip (bool): Whether to strip trailing whitespaces.
                Defaults to ``True``.
            **kwargs: Keyword arguments to generate string.

        Returns:
            str: The entries joined by :attr:`self.delimiter`.
        """
        str_list: list[str] = []
        for k, v in kwargs.items():
            v_str = str(v)
            _str: str = "{green}{k}{reset}: {v}".format(k=k, v=v_str, **ansi)
            max_length = self.meter_length + get_ansi_len(_str)
            if cut_too_long:
                _str = _str[:max_length]
            str_list.append(_str.ljust(max_length))
        _str = self.delimiter.join(str_list)
        if strip:
            _str = _str.rstrip()
        return _str

    def __getattr__(self, attr: str) -> SmoothedValue:
        # Python calls this only after normal attribute lookup has failed.
        # Read the instance ``__dict__`` directly rather than ``self.meters``.
        # On an instance whose ``__init__`` never ran (e.g. one made via
        # ``cls.__new__`` by ``copy``/``pickle`` before its state is restored),
        # ``self.meters`` would re-enter this method and recurse forever.
        instance_dict = vars(self)
        meters = instance_dict.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in instance_dict:
            return instance_dict[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self) -> str:
        return self.get_str(**self.meters)

    def synchronize_between_processes(self):
        r"""Call :meth:`SmoothedValue.synchronize_between_processes` on every
        meter in :attr:`self.meters`. The timing and memory meters are not
        synchronized."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def log_every(
        self,
        iterable: Iterable[T],
        header: str = "",
        tqdm: Optional[bool] = None,
        tqdm_header: str = "Iter",
        indent: Optional[int] = None,
        verbose: int = 1,
    ) -> Generator[T, None, None]:
        r"""Wrap an :class:`collections.abc.Iterable` with formatted outputs.

        * Middle Output:
            ``{tqdm_header}: [ current / total ] str(self) {memory} {iter_time} {data_time} {time}<{remaining}``
        * Final Output
            ``{header} str(self) {memory} {iter_time} {data_time} {total_time}``

        Args:
            iterable (~collections.abc.Iterable): The raw iterator.
                Must support ``len()`` when tqdm is enabled.
            header (str): The header string for final output.
                Defaults to ``''``.
            tqdm (bool): Whether to use tqdm to show iteration information.
                Defaults to ``self.tqdm``.
            tqdm_header (str): The header string for middle output.
                Defaults to ``'Iter'``.
            indent (int): The space indent for the entire string.
                if ``None``, use ``self.indent``.
                Defaults to ``None``.
            verbose (int): The verbose level of output information.
                ``> 1`` adds iteration/data times, ``> 2`` also adds CUDA
                memory (when CUDA is available).

        Yields:
            The items of :attr:`iterable`, unchanged. The final summary is
            printed only if the iteration runs to completion.
        """
        tqdm = tqdm if tqdm is not None else self.tqdm
        indent = indent if indent is not None else self.indent
        iterator = iterable
        if header:
            header = header.ljust(30 + get_ansi_len(header))
        if tqdm:
            length = len(str(len(iterable)))
            pattern: str = (
                "{tqdm_header}: {blue_light}"
                "[ {red}{{n_fmt:>{length}}}{blue_light} "
                "/ {red}{{total_fmt}}{blue_light} ]{reset}"
            ).format(tqdm_header=tqdm_header, length=length, **ansi)
            # ``pattern`` still holds the literal ``{n_fmt:>N}{total_fmt}``
            # placeholders; ``offset`` is how many more characters they take
            # than the ``2 * length`` digits they render to, so the padding
            # below is computed on the rendered width.
            offset = len(f"{{n_fmt:>{length}}}{{total_fmt}}") - 2 * length
            pattern = pattern.ljust(30 + offset + get_ansi_len(pattern))
            time_str = self.get_str(time="{elapsed}<{remaining}", cut_too_long=False)
            bar_format = f"{pattern}{{desc}}{time_str}"
            iterator = tqdm_class(iterable, leave=False, bar_format=bar_format)
        self.iter_time.reset()
        self.data_time.reset()
        self.memory.reset()
        end = time.time()
        start_time = time.time()
        try:
            for obj in iterator:
                # Time spent waiting for the next item.
                cur_data_time = time.time() - end
                self.data_time.update(cur_data_time)
                yield obj
                # Waiting time plus the caller's processing of ``obj``.
                cur_iter_time = time.time() - end
                self.iter_time.update(cur_iter_time)
                if torch.cuda.is_available():
                    cur_memory = torch.cuda.max_memory_allocated() / MB
                    self.memory.update(cur_memory)
                if tqdm:
                    _dict = dict(self.meters)
                    if verbose > 2 and torch.cuda.is_available():
                        _dict.update(memory=f"{cur_memory:.0f} MB")
                    if verbose > 1:
                        _dict.update(
                            iter=f"{cur_iter_time:.3f} s", data=f"{cur_data_time:.3f} s"
                        )
                    iterator.set_description_str(self.get_str(**_dict, strip=False))
                end = time.time()
        finally:
            # Close the bar even when the caller stops early (``break`` raises
            # GeneratorExit at ``yield``) instead of waiting for garbage
            # collection; tqdm's ``close()`` guards against repeated calls.
            if tqdm:
                iterator.close()
        self.synchronize_between_processes()
        total_time = time.time() - start_time
        total_time_str = tqdm_class.format_interval(total_time)
        _dict = dict(self.meters)
        if verbose > 2 and torch.cuda.is_available():
            _dict.update(memory=f"{str(self.memory)} MB")
        if verbose > 1:
            _dict.update(
                iter=f"{str(self.iter_time)} s", data=f"{str(self.data_time)} s"
            )
        _dict.update(time=total_time_str)
        prints(self.delimiter.join([header, self.get_str(**_dict)]), indent=indent)