|
|
|
|
|
"""Meters.""" |
|
|
|
import datetime |
|
import numpy as np |
|
import os |
|
from collections import defaultdict, deque |
|
import torch |
|
from fvcore.common.timer import Timer |
|
from sklearn.metrics import average_precision_score |
|
|
|
import timesformer.utils.logging as logging |
|
import timesformer.utils.metrics as metrics |
|
import timesformer.utils.misc as misc |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class TestMeter(object): |
|
""" |
|
    Perform the multi-view ensemble for testing: each video with a unique index
|
will be sampled with multiple clips, and the predictions of the clips will |
|
be aggregated to produce the final prediction for the video. |
|
The accuracy is calculated with the given ground truth labels. |
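    Example (illustrative sketch; `clip_batches` is a hypothetical iterable of
    (preds, labels, clip_ids) batches, assuming 2 videos, 3 clips per video
    and 10 classes):

        meter = TestMeter(num_videos=2, num_clips=3, num_cls=10, overall_iters=6)
        for preds, labels, clip_ids in clip_batches:
            meter.update_stats(preds, labels, clip_ids)
        meter.finalize_metrics(ks=(1, 5))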
|
""" |
|
|
|
def __init__( |
|
self, |
|
num_videos, |
|
num_clips, |
|
num_cls, |
|
overall_iters, |
|
multi_label=False, |
|
ensemble_method="sum", |
|
): |
|
""" |
|
Construct tensors to store the predictions and labels. Expect to get |
|
num_clips predictions from each video, and calculate the metrics on |
|
num_videos videos. |
|
Args: |
|
num_videos (int): number of videos to test. |
|
num_clips (int): number of clips sampled from each video for |
|
aggregating the final prediction for the video. |
|
num_cls (int): number of classes for each prediction. |
|
overall_iters (int): overall iterations for testing. |
|
            multi_label (bool): if True, use mAP as the metric.
|
ensemble_method (str): method to perform the ensemble, options |
|
include "sum", and "max". |
|
""" |
|
|
|
self.iter_timer = Timer() |
|
self.data_timer = Timer() |
|
self.net_timer = Timer() |
|
self.num_clips = num_clips |
|
self.overall_iters = overall_iters |
|
self.multi_label = multi_label |
|
self.ensemble_method = ensemble_method |
|
|
|
self.video_preds = torch.zeros((num_videos, num_cls)) |
|
if multi_label: |
|
self.video_preds -= 1e10 |
|
|
|
self.video_labels = ( |
|
torch.zeros((num_videos, num_cls)) |
|
if multi_label |
|
else torch.zeros((num_videos)).long() |
|
) |
|
self.clip_count = torch.zeros((num_videos)).long() |
|
self.topk_accs = [] |
|
self.stats = {} |
|
|
|
|
|
self.reset() |
|
|
|
def reset(self): |
|
""" |
|
Reset the metric. |
|
""" |
|
self.clip_count.zero_() |
|
self.video_preds.zero_() |
|
if self.multi_label: |
|
self.video_preds -= 1e10 |
|
self.video_labels.zero_() |
|
|
|
def update_stats(self, preds, labels, clip_ids): |
|
""" |
|
        Collect the predictions from the current batch and perform on-the-fly
|
        summation (or max) as the ensemble.
|
Args: |
|
preds (tensor): predictions from the current batch. Dimension is |
|
                N x C where N is the batch size and C is the number of classes
|
(num_cls). |
|
labels (tensor): the corresponding labels of the current batch. |
|
Dimension is N. |
|
            clip_ids (tensor): clip indices of the current batch, dimension is
|
N. |
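        For example, with num_clips=3, clips with clip_ids 0, 1 and 2 all map
        to vid_id = clip_id // num_clips = 0, so their predictions are summed
        (or max-pooled) into self.video_preds[0].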
|
""" |
|
for ind in range(preds.shape[0]): |
|
vid_id = int(clip_ids[ind]) // self.num_clips |
|
if self.video_labels[vid_id].sum() > 0: |
|
assert torch.equal( |
|
self.video_labels[vid_id].type(torch.FloatTensor), |
|
labels[ind].type(torch.FloatTensor), |
|
) |
|
self.video_labels[vid_id] = labels[ind] |
|
if self.ensemble_method == "sum": |
|
self.video_preds[vid_id] += preds[ind] |
|
elif self.ensemble_method == "max": |
|
self.video_preds[vid_id] = torch.max( |
|
self.video_preds[vid_id], preds[ind] |
|
) |
|
else: |
|
raise NotImplementedError( |
|
"Ensemble Method {} is not supported".format( |
|
self.ensemble_method |
|
) |
|
) |
|
self.clip_count[vid_id] += 1 |
|
|
|
def log_iter_stats(self, cur_iter): |
|
""" |
|
Log the stats. |
|
Args: |
|
cur_iter (int): the current iteration of testing. |
|
""" |
|
eta_sec = self.iter_timer.seconds() * (self.overall_iters - cur_iter) |
|
eta = str(datetime.timedelta(seconds=int(eta_sec))) |
|
stats = { |
|
"split": "test_iter", |
|
"cur_iter": "{}".format(cur_iter + 1), |
|
"eta": eta, |
|
"time_diff": self.iter_timer.seconds(), |
|
} |
|
logging.log_json_stats(stats) |
|
|
|
def iter_tic(self): |
|
""" |
|
        Start recording time.
|
""" |
|
self.iter_timer.reset() |
|
self.data_timer.reset() |
|
|
|
def iter_toc(self): |
|
""" |
|
        Stop recording time.
|
""" |
|
self.iter_timer.pause() |
|
self.net_timer.pause() |
|
|
|
    def data_toc(self):
        """
        Stop the data timer and start the network timer.
        """
|
self.data_timer.pause() |
|
self.net_timer.reset() |
|
|
|
def finalize_metrics(self, ks=(1, 5)): |
|
""" |
|
Calculate and log the final ensembled metrics. |
|
ks (tuple): list of top-k values for topk_accuracies. For example, |
|
            ks = (1, 5) corresponds to top-1 and top-5 accuracy.
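        After this call, self.stats holds the ensembled results, e.g.
        {"split": "test_final", "top1_acc": "75.00", "top5_acc": "92.50"} in
        the single-label case (numbers are illustrative only), or a "map"
        entry in the multi-label case.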
|
""" |
|
if not all(self.clip_count == self.num_clips): |
|
logger.warning( |
|
"clip count {} ~= num clips {}".format( |
|
", ".join( |
|
[ |
|
"{}: {}".format(i, k) |
|
for i, k in enumerate(self.clip_count.tolist()) |
|
] |
|
), |
|
self.num_clips, |
|
) |
|
) |
|
|
|
self.stats = {"split": "test_final"} |
|
if self.multi_label: |
|
            mean_ap = get_map(
|
self.video_preds.cpu().numpy(), self.video_labels.cpu().numpy() |
|
) |
|
self.stats["map"] = map |
|
else: |
|
num_topks_correct = metrics.topks_correct( |
|
self.video_preds, self.video_labels, ks |
|
) |
|
topks = [ |
|
(x / self.video_preds.size(0)) * 100.0 |
|
for x in num_topks_correct |
|
] |
|
|
|
assert len({len(ks), len(topks)}) == 1 |
|
for k, topk in zip(ks, topks): |
|
self.stats["top{}_acc".format(k)] = "{:.{prec}f}".format( |
|
topk, prec=2 |
|
) |
|
logging.log_json_stats(self.stats) |
|
|
|
|
|
class ScalarMeter(object): |
|
""" |
|
    A scalar meter uses a deque to track a series of scalar values with a given
|
window size. It supports calculating the median and average values of the |
|
window, and also supports calculating the global average. |
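    Example (illustrative sketch):

        meter = ScalarMeter(window_size=3)
        for v in [1.0, 2.0, 3.0, 4.0]:
            meter.add_value(v)
        meter.get_win_avg()     # 3.0, mean of the last 3 values in the window
        meter.get_global_avg()  # 2.5, mean of all 4 values added so far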
|
""" |
|
|
|
def __init__(self, window_size): |
|
""" |
|
Args: |
|
            window_size (int): the maximum length of the deque.
|
""" |
|
self.deque = deque(maxlen=window_size) |
|
self.total = 0.0 |
|
self.count = 0 |
|
|
|
def reset(self): |
|
""" |
|
Reset the deque. |
|
""" |
|
self.deque.clear() |
|
self.total = 0.0 |
|
self.count = 0 |
|
|
|
def add_value(self, value): |
|
""" |
|
Add a new scalar value to the deque. |
|
""" |
|
self.deque.append(value) |
|
self.count += 1 |
|
self.total += value |
|
|
|
def get_win_median(self): |
|
""" |
|
Calculate the current median value of the deque. |
|
""" |
|
return np.median(self.deque) |
|
|
|
def get_win_avg(self): |
|
""" |
|
Calculate the current average value of the deque. |
|
""" |
|
return np.mean(self.deque) |
|
|
|
def get_global_avg(self): |
|
""" |
|
Calculate the global mean value. |
|
""" |
|
return self.total / self.count |
|
|
|
|
|
class TrainMeter(object): |
|
""" |
|
Measure training stats. |
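    Typical usage (illustrative sketch; `train_loader`, `cur_epoch`,
    `top1_err`, `top5_err`, `loss` and `lr` are hypothetical names from the
    surrounding training loop):

        meter = TrainMeter(epoch_iters=len(train_loader), cfg=cfg)
        meter.iter_tic()
        for cur_iter, (inputs, labels) in enumerate(train_loader):
            meter.data_toc()                     # data loading finished
            # ... forward / backward pass ...
            meter.iter_toc()                     # iteration finished
            meter.update_stats(top1_err, top5_err, loss, lr, inputs.size(0))
            meter.log_iter_stats(cur_epoch, cur_iter)
            meter.iter_tic()
        meter.log_epoch_stats(cur_epoch)
        meter.reset()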
|
""" |
|
|
|
def __init__(self, epoch_iters, cfg): |
|
""" |
|
Args: |
|
epoch_iters (int): the overall number of iterations of one epoch. |
|
cfg (CfgNode): configs. |
|
""" |
|
self._cfg = cfg |
|
self.epoch_iters = epoch_iters |
|
self.MAX_EPOCH = cfg.SOLVER.MAX_EPOCH * epoch_iters |
|
self.iter_timer = Timer() |
|
self.data_timer = Timer() |
|
self.net_timer = Timer() |
|
self.loss = ScalarMeter(cfg.LOG_PERIOD) |
|
self.loss_total = 0.0 |
|
self.lr = None |
|
|
|
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) |
|
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) |
|
|
|
self.num_top1_mis = 0 |
|
self.num_top5_mis = 0 |
|
self.num_samples = 0 |
|
self.output_dir = cfg.OUTPUT_DIR |
|
self.extra_stats = {} |
|
self.extra_stats_total = {} |
|
self.log_period = cfg.LOG_PERIOD |
|
|
|
def reset(self): |
|
""" |
|
Reset the Meter. |
|
""" |
|
self.loss.reset() |
|
self.loss_total = 0.0 |
|
self.lr = None |
|
self.mb_top1_err.reset() |
|
self.mb_top5_err.reset() |
|
self.num_top1_mis = 0 |
|
self.num_top5_mis = 0 |
|
self.num_samples = 0 |
|
|
|
for key in self.extra_stats.keys(): |
|
self.extra_stats[key].reset() |
|
self.extra_stats_total[key] = 0.0 |
|
|
|
def iter_tic(self): |
|
""" |
|
        Start recording time.
|
""" |
|
self.iter_timer.reset() |
|
self.data_timer.reset() |
|
|
|
def iter_toc(self): |
|
""" |
|
        Stop recording time.
|
""" |
|
self.iter_timer.pause() |
|
self.net_timer.pause() |
|
|
|
    def data_toc(self):
        """
        Stop the data timer and start the network timer.
        """
|
self.data_timer.pause() |
|
self.net_timer.reset() |
|
|
|
def update_stats(self, top1_err, top5_err, loss, lr, mb_size, stats={}): |
|
""" |
|
Update the current stats. |
|
Args: |
|
top1_err (float): top1 error rate. |
|
top5_err (float): top5 error rate. |
|
loss (float): loss value. |
|
lr (float): learning rate. |
|
            mb_size (int): mini batch size.
            stats (dict): additional stats to be aggregated and logged.
|
""" |
|
self.loss.add_value(loss) |
|
self.lr = lr |
|
self.loss_total += loss * mb_size |
|
self.num_samples += mb_size |
|
|
|
if not self._cfg.DATA.MULTI_LABEL: |
|
|
|
self.mb_top1_err.add_value(top1_err) |
|
self.mb_top5_err.add_value(top5_err) |
|
|
|
self.num_top1_mis += top1_err * mb_size |
|
self.num_top5_mis += top5_err * mb_size |
|
|
|
for key in stats.keys(): |
|
if key not in self.extra_stats: |
|
self.extra_stats[key] = ScalarMeter(self.log_period) |
|
self.extra_stats_total[key] = 0.0 |
|
self.extra_stats[key].add_value(stats[key]) |
|
self.extra_stats_total[key] += stats[key] * mb_size |
|
|
|
def log_iter_stats(self, cur_epoch, cur_iter): |
|
""" |
|
        Log the stats of the current iteration.
|
Args: |
|
cur_epoch (int): the number of current epoch. |
|
cur_iter (int): the number of current iteration. |
|
""" |
|
if (cur_iter + 1) % self._cfg.LOG_PERIOD != 0: |
|
return |
|
eta_sec = self.iter_timer.seconds() * ( |
|
self.MAX_EPOCH - (cur_epoch * self.epoch_iters + cur_iter + 1) |
|
) |
|
eta = str(datetime.timedelta(seconds=int(eta_sec))) |
|
stats = { |
|
"_type": "train_iter", |
|
"epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH), |
|
"iter": "{}/{}".format(cur_iter + 1, self.epoch_iters), |
|
"dt": self.iter_timer.seconds(), |
|
"dt_data": self.data_timer.seconds(), |
|
"dt_net": self.net_timer.seconds(), |
|
"eta": eta, |
|
"loss": self.loss.get_win_median(), |
|
"lr": self.lr, |
|
"gpu_mem": "{:.2f}G".format(misc.gpu_mem_usage()), |
|
} |
|
if not self._cfg.DATA.MULTI_LABEL: |
|
stats["top1_err"] = self.mb_top1_err.get_win_median() |
|
stats["top5_err"] = self.mb_top5_err.get_win_median() |
|
for key in self.extra_stats.keys(): |
|
stats[key] = self.extra_stats_total[key] / self.num_samples |
|
logging.log_json_stats(stats) |
|
|
|
def log_epoch_stats(self, cur_epoch): |
|
""" |
|
Log the stats of the current epoch. |
|
Args: |
|
cur_epoch (int): the number of current epoch. |
|
""" |
|
eta_sec = self.iter_timer.seconds() * ( |
|
self.MAX_EPOCH - (cur_epoch + 1) * self.epoch_iters |
|
) |
|
eta = str(datetime.timedelta(seconds=int(eta_sec))) |
|
stats = { |
|
"_type": "train_epoch", |
|
"epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH), |
|
"dt": self.iter_timer.seconds(), |
|
"dt_data": self.data_timer.seconds(), |
|
"dt_net": self.net_timer.seconds(), |
|
"eta": eta, |
|
"lr": self.lr, |
|
"gpu_mem": "{:.2f}G".format(misc.gpu_mem_usage()), |
|
"RAM": "{:.2f}/{:.2f}G".format(*misc.cpu_mem_usage()), |
|
} |
|
if not self._cfg.DATA.MULTI_LABEL: |
|
top1_err = self.num_top1_mis / self.num_samples |
|
top5_err = self.num_top5_mis / self.num_samples |
|
avg_loss = self.loss_total / self.num_samples |
|
stats["top1_err"] = top1_err |
|
stats["top5_err"] = top5_err |
|
stats["loss"] = avg_loss |
|
for key in self.extra_stats.keys(): |
|
stats[key] = self.extra_stats_total[key] / self.num_samples |
|
logging.log_json_stats(stats) |
|
|
|
|
|
class ValMeter(object): |
|
""" |
|
Measures validation stats. |
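    Used analogously to TrainMeter: call iter_tic/data_toc/iter_toc around
    each validation batch, update_stats (or update_predictions in the
    multi-label case) per batch, then log_epoch_stats and reset at the end of
    the epoch.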
|
""" |
|
|
|
def __init__(self, max_iter, cfg): |
|
""" |
|
Args: |
|
            max_iter (int): the max number of iterations of the current epoch.
|
cfg (CfgNode): configs. |
|
""" |
|
self._cfg = cfg |
|
self.max_iter = max_iter |
|
self.iter_timer = Timer() |
|
self.data_timer = Timer() |
|
self.net_timer = Timer() |
|
|
|
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) |
|
self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) |
|
|
|
self.min_top1_err = 100.0 |
|
self.min_top5_err = 100.0 |
|
|
|
self.num_top1_mis = 0 |
|
self.num_top5_mis = 0 |
|
self.num_samples = 0 |
|
self.all_preds = [] |
|
self.all_labels = [] |
|
self.output_dir = cfg.OUTPUT_DIR |
|
self.extra_stats = {} |
|
self.extra_stats_total = {} |
|
self.log_period = cfg.LOG_PERIOD |
|
|
|
def reset(self): |
|
""" |
|
Reset the Meter. |
|
""" |
|
self.iter_timer.reset() |
|
self.mb_top1_err.reset() |
|
self.mb_top5_err.reset() |
|
self.num_top1_mis = 0 |
|
self.num_top5_mis = 0 |
|
self.num_samples = 0 |
|
self.all_preds = [] |
|
self.all_labels = [] |
|
|
|
for key in self.extra_stats.keys(): |
|
self.extra_stats[key].reset() |
|
self.extra_stats_total[key] = 0.0 |
|
|
|
def iter_tic(self): |
|
""" |
|
        Start recording time.
|
""" |
|
self.iter_timer.reset() |
|
self.data_timer.reset() |
|
|
|
def iter_toc(self): |
|
""" |
|
        Stop recording time.
|
""" |
|
self.iter_timer.pause() |
|
self.net_timer.pause() |
|
|
|
    def data_toc(self):
        """
        Stop the data timer and start the network timer.
        """
|
self.data_timer.pause() |
|
self.net_timer.reset() |
|
|
|
def update_stats(self, top1_err, top5_err, mb_size, stats={}): |
|
""" |
|
Update the current stats. |
|
Args: |
|
top1_err (float): top1 error rate. |
|
top5_err (float): top5 error rate. |
|
            mb_size (int): mini batch size.
            stats (dict): additional stats to be aggregated and logged.
|
""" |
|
self.mb_top1_err.add_value(top1_err) |
|
self.mb_top5_err.add_value(top5_err) |
|
self.num_top1_mis += top1_err * mb_size |
|
self.num_top5_mis += top5_err * mb_size |
|
self.num_samples += mb_size |
|
|
|
for key in stats.keys(): |
|
if key not in self.extra_stats: |
|
self.extra_stats[key] = ScalarMeter(self.log_period) |
|
self.extra_stats_total[key] = 0.0 |
|
self.extra_stats[key].add_value(stats[key]) |
|
self.extra_stats_total[key] += stats[key] * mb_size |
|
|
|
|
|
def update_predictions(self, preds, labels): |
|
""" |
|
Update predictions and labels. |
|
Args: |
|
preds (tensor): model output predictions. |
|
labels (tensor): labels. |
|
""" |
|
|
|
self.all_preds.append(preds) |
|
self.all_labels.append(labels) |
|
|
|
def log_iter_stats(self, cur_epoch, cur_iter): |
|
""" |
|
        Log the stats of the current iteration.
|
Args: |
|
cur_epoch (int): the number of current epoch. |
|
cur_iter (int): the number of current iteration. |
|
""" |
|
if (cur_iter + 1) % self._cfg.LOG_PERIOD != 0: |
|
return |
|
eta_sec = self.iter_timer.seconds() * (self.max_iter - cur_iter - 1) |
|
eta = str(datetime.timedelta(seconds=int(eta_sec))) |
|
stats = { |
|
"_type": "val_iter", |
|
"epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH), |
|
"iter": "{}/{}".format(cur_iter + 1, self.max_iter), |
|
"time_diff": self.iter_timer.seconds(), |
|
"eta": eta, |
|
"gpu_mem": "{:.2f}G".format(misc.gpu_mem_usage()), |
|
} |
|
if not self._cfg.DATA.MULTI_LABEL: |
|
stats["top1_err"] = self.mb_top1_err.get_win_median() |
|
stats["top5_err"] = self.mb_top5_err.get_win_median() |
|
for key in self.extra_stats.keys(): |
|
stats[key] = self.extra_stats[key].get_win_median() |
|
logging.log_json_stats(stats) |
|
|
|
def log_epoch_stats(self, cur_epoch): |
|
""" |
|
Log the stats of the current epoch. |
|
Args: |
|
cur_epoch (int): the number of current epoch. |
|
""" |
|
stats = { |
|
"_type": "val_epoch", |
|
"epoch": "{}/{}".format(cur_epoch + 1, self._cfg.SOLVER.MAX_EPOCH), |
|
"time_diff": self.iter_timer.seconds(), |
|
"gpu_mem": "{:.2f}G".format(misc.gpu_mem_usage()), |
|
"RAM": "{:.2f}/{:.2f}G".format(*misc.cpu_mem_usage()), |
|
} |
|
if self._cfg.DATA.MULTI_LABEL: |
|
stats["map"] = get_map( |
|
torch.cat(self.all_preds).cpu().numpy(), |
|
torch.cat(self.all_labels).cpu().numpy(), |
|
) |
|
else: |
|
top1_err = self.num_top1_mis / self.num_samples |
|
top5_err = self.num_top5_mis / self.num_samples |
|
self.min_top1_err = min(self.min_top1_err, top1_err) |
|
self.min_top5_err = min(self.min_top5_err, top5_err) |
|
|
|
stats["top1_err"] = top1_err |
|
stats["top5_err"] = top5_err |
|
stats["min_top1_err"] = self.min_top1_err |
|
stats["min_top5_err"] = self.min_top5_err |
|
|
|
for key in self.extra_stats.keys(): |
|
stats[key] = self.extra_stats_total[key] / self.num_samples |
|
|
|
logging.log_json_stats(stats) |
|
|
|
|
|
def get_map(preds, labels): |
|
""" |
|
Compute mAP for multi-label case. |
|
Args: |
|
        preds (numpy array): num_examples x num_classes.
|
        labels (numpy array): num_examples x num_classes.
|
Returns: |
|
        mean_ap (float): final mAP score.
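    Example (illustrative sketch):

        preds = np.random.rand(4, 3)
        labels = np.array([[1, 0, 0], [0, 0, 1], [1, 0, 1], [0, 0, 0]])
        mean_ap = get_map(preds, labels)  # class 1 (all-zero labels) is dropped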
|
""" |
|
|
|
logger.info("Getting mAP for {} examples".format(preds.shape[0])) |
|
|
|
preds = preds[:, ~(np.all(labels == 0, axis=0))] |
|
labels = labels[:, ~(np.all(labels == 0, axis=0))] |
|
aps = [0] |
|
try: |
|
aps = average_precision_score(labels, preds, average=None) |
|
except ValueError: |
|
        print(
            "Average precision requires a sufficient number of samples "
            "in a batch which are missing in this sample."
        )
|
|
|
mean_ap = np.mean(aps) |
|
return mean_ap |
|
|