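# NOTE(review): the next three lines look like residue from the hosting page
# (uploader avatar caption, commit message, commit hash), not code; they are
# kept here as comments so the module can be parsed.
# apolinario's picture
# upload clipseg
# 48fa639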
from torch.functional import Tensor
from general_utils import log
from collections import defaultdict
import numpy as np
import torch
from torch.nn import functional as nnf
class BaseMetric(object):
    """Common interface and bookkeeping for metric accumulators.

    Subclasses implement ``add`` (consume one batch) and ``value`` (compute
    the final result).  This base class stores the metric names, the flags
    that tell callers when the metric should be evaluated, and the indices
    used to pick the prediction / ground-truth tensors out of the model
    outputs and the dataset targets.
    """

    def __init__(self, metric_names, pred_range=None, gt_index=0, pred_index=0, eval_intermediate=True,
                 eval_validation=True):
        """Store the configuration.

        Args:
            metric_names: iterable of names; frozen into a tuple.
            pred_range: optional ``(start, stop)`` pair used by
                ``_get_pred_gt`` to slice the second axis of the prediction.
            gt_index: position of the ground truth inside ``ground_truth``.
            pred_index: position of the prediction inside ``predictions``.
            eval_intermediate: value returned by ``eval_intermediate()``.
            eval_validation: value returned by ``eval_validation()``.
        """
        self._names = tuple(metric_names)
        self._eval_intermediate = eval_intermediate
        self._eval_validation = eval_validation

        self._pred_range = pred_range
        self._pred_index = pred_index
        self._gt_index = gt_index

        self.predictions = []
        self.ground_truths = []

    def eval_intermediate(self):
        """Return the ``eval_intermediate`` flag given at construction."""
        return self._eval_intermediate

    def eval_validation(self):
        """Return the ``eval_validation`` flag given at construction."""
        return self._eval_validation

    def names(self):
        """Return the metric names as a tuple."""
        return self._names

    def add(self, predictions, ground_truth):
        """Consume one batch; must be implemented by subclasses."""
        raise NotImplementedError

    def value(self):
        """Compute the metric; must be implemented by subclasses."""
        raise NotImplementedError

    def scores(self):
        """Return the metric values together with their names.

        If ``value()`` yields a dict it is returned unchanged.  Otherwise it
        must be a list or tuple, and the result is a list of
        ``(name, value)`` pairs built from ``names()``.
        """
        # Compute once: value() of a concrete metric may be expensive.
        value = self.value()
        if isinstance(value, dict):
            return value

        assert isinstance(value, (list, tuple))
        return list(zip(self.names(), value))

    def _get_pred_gt(self, predictions, ground_truth):
        """Select the configured prediction and ground truth.

        When ``pred_range`` was given, the prediction is additionally sliced
        along its second axis to ``[start:stop]``.
        """
        pred = predictions[self._pred_index]
        gt = ground_truth[self._gt_index]

        if self._pred_range is not None:
            start, stop = self._pred_range[0], self._pred_range[1]
            pred = pred[:, start:stop]

        return pred, gt
class FixedIntervalMetrics(BaseMetric):
    """Accumulate per-sample confusion counts on a fixed grid of thresholds.

    ``add`` stores, for every sample and every threshold, the number of true
    positives, false positives, false negatives and true negatives.
    ``value`` sums these counts (overall and, if class ids were supplied, per
    class) and derives average precision, foreground IoU, binary IoU and
    class-averaged IoU from them.
    """

    def __init__(self, sigmoid=False, ignore_mask=False, resize_to=None,
                 resize_pred=None, n_values=51, custom_threshold=None):
        """Configure the accumulator.

        Args:
            sigmoid: if true, ``torch.sigmoid`` is applied to the predictions
                in ``add``.
            ignore_mask: if true, the masks in ``gt[1]`` are not used.
            resize_to: if not None, ground truth (nearest) and predictions
                (bilinear) are interpolated to this size in ``add``.
            resize_pred: if truthy, each prediction is resized to the
                spatial size of its ground truth.
            n_values: number of points of ``np.linspace(0, 1, n_values)``;
                the end points 0 and 1 are dropped.  ``value`` requires 0.1,
                0.2, 0.3 and 0.5 to be on the resulting grid.
            custom_threshold: optional extra threshold to report; it must be
                on the threshold grid.
        """
        # NOTE: these names appear to mirror an older tuple-style return
        # value; value() returns a dict with its own keys.
        super().__init__(('ap', 'best_fgiou', 'best_miou', 'fgiou0.5', 'fgiou0.1', 'mean_iou_0p5', 'mean_iou_0p1', 'best_biniou', 'biniou_0.5', 'fgiou_thresh'))
        self.intersections = []
        self.unions = []
        self.sigmoid = sigmoid
        self.resize_to = resize_to
        self.resize_pred = resize_pred  # resize prediction to match ground truth
        self.class_count = defaultdict(lambda: 0)
        self.per_class = defaultdict(lambda: [0, 0])
        self.ignore_mask = ignore_mask
        self.custom_threshold = custom_threshold

        self.scores_ap = []
        self.scores_iou = []
        self.gts, self.preds = [], []
        self.classes = []

        # [1:-1] ignores 0 and 1
        self.threshold_values = np.linspace(0, 1, n_values)[1:-1]

        # One entry per sample; each entry is a list with one count per threshold.
        self.metrics = dict(tp=[], fp=[], fn=[], tn=[])

    def add(self, pred, gt):
        """Record the confusion counts of one batch.

        Args:
            pred: sequence whose first element is the batch of predictions
                (a tensor).
            gt: sequence with the ground-truth batch at index 0, optionally
                per-sample masks at index 1 and class ids at index 2.
        """
        pred_batch = pred[0].cpu()

        if self.sigmoid:
            pred_batch = torch.sigmoid(pred_batch)

        gt_batch = gt[0].cpu()
        n_samples = len(pred_batch)

        use_mask = len(gt) > 1 and not self.ignore_mask and gt[1].numel() > 0
        mask_batch = gt[1] if use_mask else [None] * n_samples
        cls_batch = gt[2] if len(gt) > 2 else [None] * n_samples

        if self.resize_to is not None:
            # assumes batched (N, C, H, W) tensors here — TODO confirm with callers
            gt_batch = nnf.interpolate(gt_batch, self.resize_to, mode='nearest')
            pred_batch = nnf.interpolate(pred_batch, self.resize_to, mode='bilinear', align_corners=False)

        if isinstance(cls_batch, torch.Tensor):
            cls_batch = cls_batch.cpu().numpy().tolist()

        assert len(gt_batch) == len(pred_batch) == len(cls_batch), f'{len(gt_batch)} {len(pred_batch)} {len(cls_batch)}'

        for predictions, ground_truth, mask, cls in zip(pred_batch, gt_batch, mask_batch, cls_batch):

            if self.resize_pred:
                predictions = nnf.interpolate(predictions.unsqueeze(0).float(), size=ground_truth.size()[-2:], mode='bilinear', align_corners=True)

            p = predictions.flatten()
            g = ground_truth.flatten()

            assert len(p) == len(g)

            if mask is not None:
                m = mask.flatten().bool()
                p = p[m]
                g = g[m]

            # Sort ascending by score so that, for each threshold, everything
            # before the first score above it is a predicted negative and
            # everything from there on is a predicted positive.
            p_sorted = p.sort()
            p = p_sorted.values
            g = g[p_sorted.indices]

            tps, fps, fns, tns = [], [], [], []
            total = len(g)
            for thresh in self.threshold_values:
                above = torch.where(p > thresh)[0]
                # n = number of predicted negatives at this threshold
                n = int(above[0]) if len(above) > 0 else total

                fn = int(g[:n].sum())
                tp = int(g[n:].sum())
                fns.append(fn)
                tns.append(n - fn)
                tps.append(tp)
                fps.append(total - n - tp)

            self.metrics['tp'].append(tps)
            self.metrics['fp'].append(fps)
            self.metrics['fn'].append(fns)
            self.metrics['tn'].append(tns)

            self.classes.append(cls.item() if isinstance(cls, torch.Tensor) else cls)

    def _threshold_index(self, threshold):
        """Return the position of ``threshold`` on the threshold grid.

        Uses a tolerance instead of exact float equality, because values
        produced by ``np.linspace`` need not compare equal to the literal.

        Raises:
            ValueError: if the threshold is not on the grid.
        """
        matches = np.flatnonzero(np.isclose(self.threshold_values, threshold))
        if len(matches) == 0:
            raise ValueError(f'{threshold} is not in the threshold grid')
        return int(matches[0])

    def value(self):
        """Compute all metrics from the accumulated counts.

        Returns:
            dict with average precision (``ap``), foreground IoU
            (``fgiou_*``), binary IoU (``biniou_*``), class-averaged IoU
            (``miou_*`` / ``mean_iou_*``; None when no class ids were
            provided), the custom-threshold entries and the raw summed
            statistics.

        Raises:
            AssertionError: if no positive prediction was made at any
                threshold.
            ValueError: if a required threshold is not on the grid.
        """
        import time
        t_start = time.time()

        if set(self.classes) == {None}:
            all_classes = None
            log.warning('classes were not provided, cannot compute mIoU')
        else:
            all_classes = {int(c) for c in self.classes}

        keys = list(self.metrics.keys())
        n_thresholds = len(self.threshold_values)

        # Sum the per-sample counts for every threshold.
        summed = {
            k: [sum(sample[j] for sample in self.metrics[k]) for j in range(n_thresholds)]
            for k in keys
        }

        # Stay None when class ids are missing, so that the result can still
        # be assembled instead of failing on undefined names.
        summed_by_cls = None
        if all_classes is not None:

            assert len(self.classes) == len(self.metrics['tp']) == len(self.metrics['fn'])

            # group by class
            metrics_by_class = {c: {k: [] for k in keys} for c in all_classes}
            for i, cls in enumerate(self.classes):
                for k in keys:
                    # keys of metrics_by_class are ints, so normalise the lookup
                    metrics_by_class[int(cls)][k].append(self.metrics[k][i])

            # sum over all instances within the classes
            summed_by_cls = {
                k: {c: np.array(metrics_by_class[c][k]).sum(0).tolist() for c in all_classes}
                for k in keys
            }

        # Compute average precision
        assert (np.array(summed['fp']) + np.array(summed['tp'])).sum(), 'no predictions is made'

        # only consider values where a prediction is made
        predicted = [j for j in range(n_thresholds) if summed['tp'][j] + summed['fp'][j] > 0]
        # the +1 in the denominators smooths the ratios and avoids division by zero
        precisions = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j]) for j in predicted]
        recalls = [summed['tp'][j] / (1 + summed['tp'][j] + summed['fn'][j]) for j in predicted]

        # remove duplicate recall-precision-pairs (and sort by recall value)
        recalls, precisions = zip(*sorted(set(zip(recalls, precisions)), key=lambda pair: pair[0]))

        # `simps` was removed in SciPy 1.14; `simpson` is its replacement.
        try:
            from scipy.integrate import simpson
        except ImportError:  # very old SciPy only provides the old name
            from scipy.integrate import simps as simpson
        ap = simpson(precisions, x=recalls)

        # Compute best IoU
        fgiou_scores = [
            summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j])
            for j in range(n_thresholds)
        ]

        biniou_scores = [
            0.5 * (summed['tp'][j] / (1 + summed['tp'][j] + summed['fp'][j] + summed['fn'][j])) +
            0.5 * (summed['tn'][j] / (1 + summed['tn'][j] + summed['fn'][j] + summed['fp'][j]))
            for j in range(n_thresholds)
        ]

        index_0p5 = self._threshold_index(0.5)
        index_0p1 = self._threshold_index(0.1)
        index_0p2 = self._threshold_index(0.2)
        index_0p3 = self._threshold_index(0.3)

        has_ct = self.custom_threshold is not None
        index_ct = self._threshold_index(self.custom_threshold) if has_ct else None

        if all_classes is not None:
            # mean IoU
            tp_c, fp_c, fn_c = summed_by_cls['tp'], summed_by_cls['fp'], summed_by_cls['fn']
            mean_ious = [
                np.mean([tp_c[c][j] / (1 + tp_c[c][j] + fp_c[c][j] + fn_c[c][j]) for c in all_classes])
                for j in range(n_thresholds)
            ]

            mean_iou_dict = {
                'miou_best': max(mean_ious),
                'miou_0.5': mean_ious[index_0p5],
                'miou_0.1': mean_ious[index_0p1],
                'miou_0.2': mean_ious[index_0p2],
                'miou_0.3': mean_ious[index_0p3],
                'miou_best_t': self.threshold_values[np.argmax(mean_ious)],
                'mean_iou_ct': mean_ious[index_ct] if has_ct else None,
                'mean_iou_scores': mean_ious,
            }
        else:
            # without class ids no class-averaged IoU can be computed
            mean_iou_dict = {
                'miou_best': None,
                'miou_0.5': None,
                'miou_0.1': None,
                'miou_0.2': None,
                'miou_0.3': None,
                'miou_best_t': None,
                'mean_iou_ct': None,
                'mean_iou_scores': None,
            }

        print(f'metric computation on {(len(all_classes) if all_classes is not None else "no")} classes took {time.time() - t_start:.1f}s')

        return {
            'ap': ap,

            # fgiou
            'fgiou_best': max(fgiou_scores),
            'fgiou_0.5': fgiou_scores[index_0p5],
            'fgiou_0.1': fgiou_scores[index_0p1],
            'fgiou_0.2': fgiou_scores[index_0p2],
            'fgiou_0.3': fgiou_scores[index_0p3],
            'fgiou_best_t': self.threshold_values[np.argmax(fgiou_scores)],

            # biniou
            'biniou_best': max(biniou_scores),
            'biniou_0.5': biniou_scores[index_0p5],
            'biniou_0.1': biniou_scores[index_0p1],
            'biniou_0.2': biniou_scores[index_0p2],
            'biniou_0.3': biniou_scores[index_0p3],
            'biniou_best_t': self.threshold_values[np.argmax(biniou_scores)],

            # custom threshold
            'fgiou_ct': fgiou_scores[index_ct] if has_ct else None,
            'biniou_ct': biniou_scores[index_ct] if has_ct else None,
            'ct': self.custom_threshold,

            # statistics
            'fgiou_scores': fgiou_scores,
            'biniou_scores': biniou_scores,
            'precision_recall_curve': sorted(set(zip(recalls, precisions))),
            'summed_statistics': summed,
            'summed_by_cls_statistics': summed_by_cls,

            # mean iou
            **mean_iou_dict
        }