# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
from collections import OrderedDict | |
from functools import partial | |
import numpy as np | |
import torch | |
from vidar.utils.distributed import reduce_value | |
from vidar.utils.tensor import same_shape, interpolate | |
class BaseEvaluation:
    """
    Base class for evaluation metrics

    Subclasses are expected to override ``reduce_fn``, ``populate_metrics_dict``
    and ``print``; the versions defined here only raise ``NotImplementedError``.

    Parameters
    ----------
    cfg : Config
        Configuration object; must provide ``has(key, default)``
    name : String
        Evaluation name (also the prefix used to select output keys)
    task : String
        Task referent to the evaluation
    metrics : Sequence
        Metrics names (one table column is reserved per entry)
    """
    def __init__(self, cfg, name, task, metrics):
        self.name = name
        self.task = task
        # Table width: fixed part for the label column plus a fixed amount per metric
        self.width = 32 + 11 * len(metrics)
        self.metrics = metrics
        self.modes = ['']

        self.font1 = {'color': 'magenta', 'attrs': ('bold',)}
        self.font2 = {'color': 'cyan', 'attrs': ()}

        # Pre-configured interpolation functions (only the target size is left open)
        self.nearest = partial(interpolate, scale_factor=None, mode='nearest', align_corners=None)
        self.bilinear = partial(interpolate, scale_factor=None, mode='bilinear', align_corners=True)

        self.only_first = cfg.has('only_first', False)

    def horz_line(self):
        """Return a horizontal line of asterisks, ``self.width`` wide, between vertical bars"""
        return '|{:<}|'.format('*' * self.width)

    def metr_line(self):
        """Return the format template for the metrics header (one label + one field per metric)"""
        return '| {:^30} |' + ' {:^8} |' * len(self.metrics)

    def outp_line(self):
        """Return the format template for an output row (one label + one float per metric)"""
        return '{:<30}' + ' | {:^8.3f}' * len(self.metrics)

    @staticmethod
    def wrap(string):
        """Wrap line around vertical bars"""
        # Static: takes no instance, so it works both as Class.wrap(s) and self.wrap(s)
        return '| {} |'.format(string)

    def check_name(self, key):
        """Check whether a key starts with the evaluation name, optionally prefixed by fwd_ or bwd_"""
        return key.startswith((self.name, 'fwd_' + self.name, 'bwd_' + self.name))

    def reduce_fn(self, *args, **kwargs):
        """Reduce function (to be implemented by subclasses)"""
        # Instances have no __name__ attribute; the class name lives on the type
        raise NotImplementedError('reduce_fn not implemented for {}'.format(type(self).__name__))

    def populate_metrics_dict(self, *args, **kwargs):
        """Populate metrics function (to be implemented by subclasses)"""
        raise NotImplementedError(
            'populate_metrics_dict not implemented for {}'.format(type(self).__name__))

    def print(self, *args, **kwargs):
        """Print function (to be implemented by subclasses)"""
        raise NotImplementedError('print not implemented for {}'.format(type(self).__name__))

    @staticmethod
    def interp(dst, src, fn):
        """
        Interpolate dst to be the size of src using the interpolation function fn

        Parameters
        ----------
        dst : Tensor or None
            Tensor to be resized (None is returned unchanged)
        src : Tensor
            Reference tensor; must have the same number of dimensions as dst
        fn : Callable
            Interpolation function, called as ``fn(dst, size=src)``

        Returns
        -------
        dst : Tensor or None
            Resized tensor, or the input itself when no resizing is needed
        """
        if dst is None:
            return dst
        assert dst.dim() == src.dim()
        # Only 4-dimensional tensors with mismatching shapes are resized
        if dst.dim() == 4 and not same_shape(dst.shape, src.shape):
            # The reference tensor itself is forwarded as ``size``
            # NOTE(review): assumes fn accepts a tensor for ``size`` - confirm
            dst = fn(dst, size=src)
        return dst

    def interp_bilinear(self, dst, src):
        """Bilinear interpolation"""
        return self.interp(dst, src, self.bilinear)

    def interp_nearest(self, dst, src):
        """Nearest interpolation"""
        return self.interp(dst, src, self.nearest)

    def reduce(self, output, dataloaders, prefixes, verbose=True):
        """
        Reduce outputs, build the metrics dictionary and optionally print the results

        Parameters
        ----------
        output : list
            Evaluation outputs, forwarded to ``reduce_metrics``
        dataloaders : list
            Per-dataset objects, forwarded to ``reduce_metrics``
        prefixes : list
            One prefix per dataset
        verbose : Bool
            If True, ``self.print`` is called with the reduced data

        Returns
        -------
        metrics_dict : Dict
            Dictionary produced by ``create_metrics_dict``
        """
        reduced_data = self.reduce_metrics(output, dataloaders)
        metrics_dict = self.create_metrics_dict(reduced_data, prefixes)
        if verbose:
            self.print(reduced_data, prefixes)
        return metrics_dict

    def create_metrics_dict(self, reduced_data, prefixes):
        """
        Create metrics dictionary

        Parameters
        ----------
        reduced_data : list
            One metrics container per dataset (empty ones are skipped)
        prefixes : list
            One prefix per dataset, indexed in step with reduced_data

        Returns
        -------
        metrics_dict : Dict
            Dictionary filled in by ``populate_metrics_dict``
        """
        metrics_dict = {}
        # For all datasets
        for n, metrics in enumerate(reduced_data):
            if metrics:  # If there are calculated metrics
                self.populate_metrics_dict(metrics, metrics_dict, prefixes[n])
        # Return metrics dictionary
        return metrics_dict

    def reduce_metrics(self, dataset_outputs, datasets, ontology=None, strict=True):
        """
        Reduce metrics

        Parameters
        ----------
        dataset_outputs : list
            Per-dataset lists of batch output dictionaries; a single list of
            dictionaries is treated as one dataset
        datasets : list
            One sized object per dataset (only ``len`` is used)
        ontology : Any
            Not used by this method
        strict : Bool
            If True, samples are counted once per dataset and every sample is
            required to have been seen; otherwise samples are counted per metric,
            considering only the outputs that contain that metric

        Returns
        -------
        all_metrics_dict : list
            One OrderedDict per dataset, mapping metric names to ``reduce_fn`` results
        """
        # If there is only one dataset, wrap in a list
        if isinstance(dataset_outputs[0], dict):
            dataset_outputs = [dataset_outputs]
        # List storing metrics for all datasets
        all_metrics_dict = []
        # Loop over all datasets and all batches
        for batch_outputs, dataset in zip(dataset_outputs, datasets):
            # Initialize metrics dictionary
            metrics_dict = OrderedDict()
            # Get length, names and dimensions (taken from the first batch output)
            length = len(dataset)
            first = batch_outputs[0]
            names = [key for key in first if self.check_name(key)]
            dims = [tuple(first[name].size()) for name in names]
            # Get data device
            device = first['idx'].device
            # Count how many times each sample was seen
            if strict:
                seen = torch.zeros(length, device=device)
                for output in batch_outputs:
                    seen[output['idx']] += 1
                seen = reduce_value(seen, average=False, name='idx')
                assert not np.any(seen.cpu().numpy() == 0), \
                    'Not all samples were seen during evaluation'
            # Reduce relevant metrics
            for name, dim in zip(names, dims):
                metrics = torch.zeros([length] + list(dim), device=device)
                # Count how many times each sample was seen for this metric
                if not strict:
                    seen = torch.zeros(length, device=device)
                    for output in batch_outputs:
                        if name in output:
                            seen[output['idx']] += 1
                    seen = reduce_value(seen, average=False, name='idx')
                # Scatter each batch value into its sample slot
                for output in batch_outputs:
                    if name in output:
                        metrics[output['idx']] = output[name]
                metrics = reduce_value(metrics, average=False, name=name)
                metrics_dict[name] = self.reduce_fn(metrics, seen)
            # Append metrics dictionary to the list
            all_metrics_dict.append(metrics_dict)
        # Return list of metrics dictionary
        return all_metrics_dict