Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
from collections import defaultdict | |
from collections import deque | |
import torch | |
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    Args:
        window_size: number of most recent values kept for the windowed
            statistics (``median`` and ``avg``).

    Attributes:
        deque: the last ``window_size`` values passed to ``update``.
        series: every value ever passed to ``update`` (never trimmed).
        total: running sum of all values.
        count: number of values recorded.
    """

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.series = []
        self.total = 0.0
        self.count = 0

    def update(self, value):
        """Record ``value`` in the window, the full series and the running totals."""
        self.deque.append(value)
        self.series.append(value)
        self.count += 1
        self.total += value

    @property
    def median(self):
        """Median of the values currently in the window.

        Uses ``torch.Tensor.median``, which returns the lower of the two middle
        values when the window holds an even number of elements.
        """
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Arithmetic mean of the values currently in the window."""
        # torch.mean only accepts floating/complex tensors; an all-int window
        # would otherwise infer int64 and raise. Float inputs already infer the
        # default dtype, so their result is unchanged.
        d = torch.tensor(list(self.deque), dtype=torch.get_default_dtype())
        return d.mean().item()

    @property
    def global_avg(self):
        """Arithmetic mean of every value recorded so far.

        Raises:
            ZeroDivisionError: if ``update`` has never been called.
        """
        return self.total / self.count
class MetricLogger(object):
    """Keep a named collection of ``SmoothedValue`` meters and render them as text.

    A meter is created the first time its name is passed to ``update`` and can
    then be read back as an attribute, e.g. ``logger.update(loss=0.5)``
    followed by ``logger.loss``.

    Args:
        delimiter: string placed between meters in ``str(logger)``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Append one value per keyword argument to the meter of the same name.

        ``torch.Tensor`` values are converted with ``Tensor.item()``, so they
        must contain exactly one element.

        Raises:
            AssertionError: if a value is not a float or int after conversion
                (the check is skipped when Python runs with ``-O``).
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Expose each meter as an attribute named after it.

        Raises:
            AttributeError: if ``attr`` is neither a regular attribute nor a
                meter name.
        """
        # __getattr__ only runs after normal lookup fails. Reading ``meters``
        # through __dict__ (instead of ``self.meters``) keeps an instance that
        # has no ``meters`` yet -- e.g. one built via ``__new__`` by copy or
        # pickle -- from recursing into __getattr__ endlessly.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        """Render every meter as ``name: median (global_avg)`` joined by ``delimiter``."""
        return self.delimiter.join(
            "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
            for name, meter in self.meters.items()
        )