# JustinLin610
# update
# 8437114
# raw history blame
# No virus
# 8.71 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import bisect
import time
from collections import OrderedDict
from typing import Dict, Optional
# Optional dependencies: each is probed once at import time, and the
# module-level name is bound to None when the package is unavailable so the
# rest of this file can test ``torch is not None`` / ``np is not None``.
try:
    import torch
except ImportError:
    torch = None

try:
    import numpy as np
except ImportError:
    np = None


def type_as(a, b):
    """Return ``a`` converted to the dtype/device of ``b`` when both are tensors.

    The conversion (``a.to(b)``) only happens when torch is importable and
    both arguments are torch tensors; in every other case ``a`` is returned
    unchanged (the very same object).
    """
    if torch is None:
        return a
    if not (torch.is_tensor(a) and torch.is_tensor(b)):
        return a
    return a.to(b)
class Meter:
    """Base class for Meters.

    Subclasses must provide :meth:`reset` and :attr:`smoothed_value`, which
    raise ``NotImplementedError`` here. The (de)serialization hooks default
    to "no state" so simple meters need not override them.
    """

    def __init__(self):
        """The base class keeps no state."""

    def state_dict(self):
        """Return this meter's serializable state (empty for the base class)."""
        return dict()

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict` (a no-op here)."""

    def reset(self):
        """Clear accumulated statistics; subclasses must implement this."""
        raise NotImplementedError

    @property
    def smoothed_value(self) -> float:
        """Smoothed value used for logging."""
        raise NotImplementedError
def safe_round(number, ndigits):
    """Round ``number`` to ``ndigits`` decimals when it can be rounded.

    Objects implementing ``__round__`` go straight to ``round``. A torch
    tensor with exactly one element, or a zero-dimensional numpy value that
    has ``.item()``, is unwrapped with ``.item()`` and rounded recursively.
    Anything else is returned untouched.
    """
    if hasattr(number, "__round__"):
        return round(number, ndigits)

    # torch is probed first; numpy is only consulted if the torch test fails.
    is_scalar_tensor = (
        torch is not None and torch.is_tensor(number) and number.numel() == 1
    )
    if is_scalar_tensor or (
        np is not None and np.ndim(number) == 0 and hasattr(number, "item")
    ):
        return safe_round(number.item(), ndigits)

    return number
class AverageMeter(Meter):
    """Computes and stores the average and current value.

    The average is weighted: each ``update(val, n)`` contributes ``val * n``
    to the sum and ``n`` to the count.
    """

    def __init__(self, round: Optional[int] = None):
        self.round = round  # digits for smoothed_value; None disables rounding
        self.reset()

    def reset(self):
        self.val = None  # value passed to the most recent accepted update
        self.sum = 0  # running sum of val * n
        self.count = 0  # running sum of n

    def update(self, val, n=1):
        """Record ``val`` with weight ``n``.

        A ``None`` value is ignored entirely. When ``n`` is not positive the
        latest value is still remembered but nothing is added to the average.
        """
        if val is None:
            return
        self.val = val
        if n > 0:
            self.sum = type_as(self.sum, val) + (val * n)
            self.count = type_as(self.count, n) + n

    def state_dict(self):
        return dict(val=self.val, sum=self.sum, count=self.count, round=self.round)

    def load_state_dict(self, state_dict):
        for field in ("val", "sum", "count"):
            setattr(self, field, state_dict[field])
        # Older state dicts may predate the "round" field.
        self.round = state_dict.get("round")

    @property
    def avg(self):
        """Weighted mean, or the latest value when nothing has been counted."""
        if self.count > 0:
            return self.sum / self.count
        return self.val

    @property
    def smoothed_value(self) -> float:
        """The average, rounded to ``self.round`` digits when configured."""
        mean = self.avg
        if mean is None or self.round is None:
            return mean
        return safe_round(mean, self.round)
class SumMeter(Meter):
    """Computes and stores the sum of every value passed to ``update``."""

    def __init__(self, round: Optional[int] = None):
        self.round = round  # digits for smoothed_value; None disables rounding
        self.reset()

    def reset(self):
        self.sum = 0  # total of all non-None values since the last reset

    def update(self, val):
        """Add ``val`` to the running total; ``None`` is ignored."""
        if val is None:
            return
        self.sum = type_as(self.sum, val) + val

    def state_dict(self):
        return dict(sum=self.sum, round=self.round)

    def load_state_dict(self, state_dict):
        self.sum = state_dict["sum"]
        # Older state dicts may predate the "round" field.
        self.round = state_dict.get("round")

    @property
    def smoothed_value(self) -> float:
        """The total, rounded to ``self.round`` digits when configured."""
        total = self.sum
        if total is None or self.round is None:
            return total
        return safe_round(total, self.round)
class TimeMeter(Meter):
    """Computes the average occurrence of some event per second.

    The rate is ``n / elapsed_time``, where ``elapsed_time`` is an initial
    offset ``init`` (in seconds) plus the wall-clock time since the last reset.
    """

    def __init__(
        self,
        init: int = 0,
        n: int = 0,
        round: Optional[int] = None,
    ):
        self.round = round  # digits for smoothed_value; None disables rounding
        self.reset(init, n)

    def reset(self, init=0, n=0):
        """Restart the clock, seeding the time offset and the event count."""
        self.init = init  # seconds credited before the clock (re)started
        self.n = n  # events counted so far
        self.i = 0  # number of update() calls since the last reset
        self.start = time.perf_counter()

    def update(self, val=1):
        """Count ``val`` more events."""
        self.n = type_as(self.n, val) + val
        self.i += 1

    def state_dict(self):
        # The elapsed time is saved as "init", so a meter restored from this
        # dict resumes with that much time already credited.
        return dict(init=self.elapsed_time, n=self.n, round=self.round)

    def load_state_dict(self, state_dict):
        if "start" not in state_dict:
            self.reset(init=state_dict["init"], n=state_dict["n"])
        else:
            # Backwards compatibility for old state_dicts (they carry a
            # "start" key): only the time offset is restored; n restarts at 0.
            self.reset(init=state_dict["init"])
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        """Events per second, including the ``init`` offset in the denominator."""
        elapsed = self.elapsed_time
        return self.n / elapsed

    @property
    def elapsed_time(self):
        """Seconds since the last reset plus the ``init`` offset."""
        now = time.perf_counter()
        return self.init + (now - self.start)

    @property
    def smoothed_value(self) -> float:
        """The rate, rounded to ``self.round`` digits when configured."""
        rate = self.avg
        if rate is None or self.round is None:
            return rate
        return safe_round(rate, self.round)
class StopwatchMeter(Meter):
    """Computes the sum/avg duration of some event in seconds."""

    def __init__(self, round: Optional[int] = None):
        self.round = round  # digits for smoothed_value; None disables rounding
        self.sum = 0  # accumulated seconds across start/stop pairs
        self.n = 0  # accumulated event count
        self.start_time = None  # None until start() is first called

    def start(self):
        """Begin (or restart) timing from now."""
        self.start_time = time.perf_counter()

    def stop(self, n=1, prehook=None):
        """Add the time since the last ``start()`` and count ``n`` events.

        Does nothing if the stopwatch was never started. ``prehook``, when
        given, is called before the clock is read. ``start_time`` is not
        cleared, so a second ``stop()`` without a new ``start()`` measures
        from the same starting point again.
        """
        if self.start_time is None:
            return
        if prehook is not None:
            prehook()
        self.sum = self.sum + (time.perf_counter() - self.start_time)
        self.n = type_as(self.n, n) + n

    def reset(self):
        """Zero the totals and immediately start timing again."""
        self.sum = 0  # cumulative time during which stopwatch was active
        self.n = 0  # total n across all start/stop
        self.start()

    def state_dict(self):
        return dict(sum=self.sum, n=self.n, round=self.round)

    def load_state_dict(self, state_dict):
        for field in ("sum", "n"):
            setattr(self, field, state_dict[field])
        # A restored stopwatch is not running until start() is called.
        self.start_time = None
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        """Mean seconds per counted event, or the raw total if ``n`` is 0."""
        if self.n > 0:
            return self.sum / self.n
        return self.sum

    @property
    def elapsed_time(self):
        """Seconds since the last ``start()``, or 0.0 if never started."""
        return (
            0.0 if self.start_time is None else time.perf_counter() - self.start_time
        )

    @property
    def smoothed_value(self) -> float:
        """The average once time has accumulated, otherwise the live elapsed time."""
        if self.sum > 0:
            val = self.avg
        else:
            val = self.elapsed_time
        if val is None or self.round is None:
            return val
        return safe_round(val, self.round)
class MetersDict(OrderedDict):
    """A sorted dictionary of :class:`Meters`.

    Meters are sorted according to a priority that is given when the
    meter is first added to the dictionary. Items are inserted as
    ``meters[key] = (priority, meter)`` or via :meth:`add_meter`; among equal
    priorities, earlier insertions come first.
    """

    def __init__(self, *args, **kwargs):
        # ``priorities`` must exist before OrderedDict.__init__ runs: any
        # initial items given in *args/**kwargs are inserted through the
        # overridden ``__setitem__``, which reads and extends this list.
        self.priorities = []
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        """Insert ``key`` where ``value`` is a ``(priority, meter)`` pair.

        Raises:
            AssertionError: if ``key`` is already present. Raised explicitly
                (not via ``assert``) so the check also holds under ``python -O``.
        """
        if key in self:
            raise AssertionError("MetersDict doesn't support reassignment")
        priority, value = value
        # The running index is unique per entry, so it breaks priority ties by
        # insertion order and the tuple comparison never reaches ``key``.
        bisect.insort(self.priorities, (priority, len(self.priorities), key))
        super().__setitem__(key, value)
        # Reorder the underlying dict to match the sorted priority list.
        for _, _, ordered_key in self.priorities:
            self.move_to_end(ordered_key)

    def add_meter(self, key, meter, priority):
        """Add ``meter`` under ``key`` with the given sort ``priority``."""
        self.__setitem__(key, (priority, meter))

    def state_dict(self):
        """Return ``[(priority, key, class_name, meter_state), ...]`` in priority order."""
        return [
            (pri, key, self[key].__class__.__name__, self[key].state_dict())
            for pri, _, key in self.priorities
            # can't serialize DerivedMeter instances
            if not isinstance(self[key], MetersDict._DerivedMeter)
        ]

    def load_state_dict(self, state_dict):
        """Rebuild the dictionary from the output of :meth:`state_dict`.

        Existing entries are discarded. Each meter class is looked up by name
        among this module's globals and constructed with no arguments before
        its own state is loaded.
        """
        self.clear()
        self.priorities.clear()
        for pri, key, meter_cls, meter_state in state_dict:
            # NOTE(review): the class name comes from serialized data and is
            # resolved through globals(); only load state from trusted sources.
            meter = globals()[meter_cls]()
            meter.load_state_dict(meter_state)
            self.add_meter(key, meter, pri)

    def get_smoothed_value(self, key: str) -> float:
        """Get a single smoothed value."""
        meter = self[key]
        if isinstance(meter, MetersDict._DerivedMeter):
            # Derived meters compute their value from the whole dictionary.
            return meter.fn(self)
        else:
            return meter.smoothed_value

    def get_smoothed_values(self) -> Dict[str, float]:
        """Get all smoothed values, skipping keys that start with ``_``."""
        return OrderedDict(
            (key, self.get_smoothed_value(key))
            for key in self.keys()
            if not key.startswith("_")
        )

    def reset(self):
        """Reset Meter instances (derived meters are left untouched)."""
        for meter in self.values():
            if isinstance(meter, MetersDict._DerivedMeter):
                continue
            meter.reset()

    class _DerivedMeter(Meter):
        """A Meter whose values are derived from other Meters."""

        def __init__(self, fn):
            # ``fn`` receives the owning MetersDict (see get_smoothed_value).
            self.fn = fn

        def reset(self):
            pass