Spaces:
Runtime error
Runtime error
File size: 3,817 Bytes
128757a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from collections import defaultdict
from collections import deque
import torch
import time
from datetime import datetime
from .comm import is_main_process
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    ``median`` and ``avg`` are computed over the most recent ``window_size``
    values only; ``global_avg`` is computed over every value ever passed to
    ``update``.
    """

    def __init__(self, window_size=20):
        # Bounded window: once full, appending drops the oldest value.
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value):
        """Record one new value.

        NaN values are still appended to the window and still counted, but
        they contribute 0 to the running total used by ``global_avg``.
        """
        self.deque.append(value)
        self.count += 1
        if value != value:  # only true for NaN
            value = 0
        self.total += value

    @property
    def median(self):
        """Median of the values currently in the window.

        ``torch.median`` returns the lower of the two middle values when the
        window holds an even number of elements.
        """
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Arithmetic mean of the values currently in the window."""
        # Force a floating dtype: if every windowed value is a Python int,
        # torch.tensor would infer an integer tensor and ``mean()`` raises.
        # For windows that already contain a float this is the same dtype
        # torch would have inferred anyway.
        d = torch.tensor(list(self.deque), dtype=torch.get_default_dtype())
        return d.mean().item()

    @property
    def global_avg(self):
        """Mean over every value seen since construction (NaN counted as 0).

        Raises ZeroDivisionError if ``update`` has never been called.
        """
        return self.total / self.count
class AverageMeter(object):
    """Keep the most recent value together with a running weighted mean.

    Each ``update(val, n)`` treats ``val`` as the value of ``n`` samples.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, sample count and mean to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Store ``val`` as the latest value and fold ``val * n`` into the mean."""
        self.val = val
        contribution = val * n
        self.sum += contribution
        self.count += n
        self.avg = self.sum / self.count
class MetricLogger(object):
    """Collection of named SmoothedValue meters, created lazily on update.

    Meters are also readable as attributes: after ``update(loss=...)``,
    ``logger.loss`` returns the meter stored under ``"loss"``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Append each keyword value to the meter of the same name.

        Tensors are converted with ``.item()``; every value must then be a
        Python float or int.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. Read ``meters``
        # straight from the instance dict: using ``self.meters`` on an
        # instance whose __init__ never ran (copy.copy, pickle.loads,
        # __new__) would re-enter __getattr__('meters') and recurse until
        # RecursionError instead of raising AttributeError.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        """Render ``name: median (global_avg)`` per meter, joined by the delimiter."""
        return self.delimiter.join(
            "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
            for name, meter in self.meters.items()
        )
# haotian added tensorboard support
class TensorboardLogger(MetricLogger):
    """MetricLogger that also writes every updated value to TensorBoard.

    A ``tensorboardX.SummaryWriter`` is created only when ``is_main_process()``
    is true; otherwise ``self.writer`` is None and values are only tracked in
    the in-memory meters.
    """

    def __init__(self,
                 log_dir,
                 start_iter=0,
                 delimiter='\t'
                 ):
        super(TensorboardLogger, self).__init__(delimiter)
        # Global step passed to add_scalar; advanced once per update() call.
        self.iteration = start_iter
        self.writer = self._get_tensorboard_writer(log_dir)

    @staticmethod
    def _get_tensorboard_writer(log_dir):
        """Return a SummaryWriter for ``log_dir`` on the main process, else None.

        Raises:
            ImportError: if tensorboardX cannot be imported.
        """
        try:
            from tensorboardX import SummaryWriter
        except ImportError as err:
            raise ImportError(
                'To use tensorboard please install tensorboardX '
                '[ pip install tensorflow tensorboardX ].'
            ) from err
        if is_main_process():
            return SummaryWriter('{}'.format(log_dir))
        return None

    def update(self, **kwargs):
        """Update the meters and, if a writer exists, log each value as a scalar.

        The step counter advances on every call, whether or not a writer exists.
        """
        # Convert tensors once here instead of in both MetricLogger.update and
        # the writer loop (``.item()`` on a device tensor may force a sync).
        values = {
            k: v.item() if isinstance(v, torch.Tensor) else v
            for k, v in kwargs.items()
        }
        # The base class validates that every value is a float or int.
        super(TensorboardLogger, self).update(**values)
        if self.writer is not None:
            for k, v in values.items():
                self.writer.add_scalar(k, v, self.iteration)
        self.iteration += 1

    def close(self):
        """Close the TensorBoard writer, if this process has one."""
        if self.writer is not None:
            self.writer.close()
|