Hila's picture
init commit
7754b29
raw
history blame
7.47 kB
import numpy as np
import torch
from sklearn.metrics import f1_score, average_precision_score
from sklearn.metrics import precision_recall_curve, roc_curve
SMOOTH = 1e-6
__all__ = ['get_f1_scores', 'get_ap_scores', 'batch_pix_accuracy', 'batch_intersection_union', 'get_iou', 'get_pr',
'get_roc', 'get_ap_multiclass']
def get_iou(outputs: torch.Tensor, labels: torch.Tensor):
# You can comment out this line if you are passing tensors of equal shape
# But if you are passing output from UNet or something it will most probably
# be with the BATCH x 1 x H x W shape
outputs = outputs.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W
labels = labels.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W
intersection = (outputs & labels).float().sum((1, 2)) # Will be zero if Truth=0 or Prediction=0
union = (outputs | labels).float().sum((1, 2)) # Will be zzero if both are 0
iou = (intersection + SMOOTH) / (union + SMOOTH) # We smooth our devision to avoid 0/0
return iou.cpu().numpy()
def get_f1_scores(predict, target, ignore_index=-1):
# Tensor process
batch_size = predict.shape[0]
predict = predict.data.cpu().numpy().reshape(-1)
target = target.data.cpu().numpy().reshape(-1)
pb = predict[target != ignore_index].reshape(batch_size, -1)
tb = target[target != ignore_index].reshape(batch_size, -1)
total = []
for p, t in zip(pb, tb):
total.append(np.nan_to_num(f1_score(t, p)))
return total
def get_roc(predict, target, ignore_index=-1):
target_expand = target.unsqueeze(1).expand_as(predict)
target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
# Tensor process
x = torch.zeros_like(target_expand)
t = target.unsqueeze(1).clamp(min=0)
target_1hot = x.scatter_(1, t, 1)
batch_size = predict.shape[0]
predict = predict.data.cpu().numpy().reshape(-1)
target = target_1hot.data.cpu().numpy().reshape(-1)
pb = predict[target_expand_numpy != ignore_index].reshape(batch_size, -1)
tb = target[target_expand_numpy != ignore_index].reshape(batch_size, -1)
total = []
for p, t in zip(pb, tb):
total.append(roc_curve(t, p))
return total
def get_pr(predict, target, ignore_index=-1):
target_expand = target.unsqueeze(1).expand_as(predict)
target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
# Tensor process
x = torch.zeros_like(target_expand)
t = target.unsqueeze(1).clamp(min=0)
target_1hot = x.scatter_(1, t, 1)
batch_size = predict.shape[0]
predict = predict.data.cpu().numpy().reshape(-1)
target = target_1hot.data.cpu().numpy().reshape(-1)
pb = predict[target_expand_numpy != ignore_index].reshape(batch_size, -1)
tb = target[target_expand_numpy != ignore_index].reshape(batch_size, -1)
total = []
for p, t in zip(pb, tb):
total.append(precision_recall_curve(t, p))
return total
def get_ap_scores(predict, target, ignore_index=-1):
total = []
for pred, tgt in zip(predict, target):
target_expand = tgt.unsqueeze(0).expand_as(pred)
target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
# Tensor process
x = torch.zeros_like(target_expand)
t = tgt.unsqueeze(0).clamp(min=0).long()
target_1hot = x.scatter_(0, t, 1)
predict_flat = pred.data.cpu().numpy().reshape(-1)
target_flat = target_1hot.data.cpu().numpy().reshape(-1)
p = predict_flat[target_expand_numpy != ignore_index]
t = target_flat[target_expand_numpy != ignore_index]
total.append(np.nan_to_num(average_precision_score(t, p)))
return total
def get_ap_multiclass(predict, target):
total = []
for pred, tgt in zip(predict, target):
predict_flat = pred.data.cpu().numpy().reshape(-1)
target_flat = tgt.data.cpu().numpy().reshape(-1)
total.append(np.nan_to_num(average_precision_score(target_flat, predict_flat)))
return total
def batch_precision_recall(predict, target, thr=0.5):
"""Batch Precision Recall
Args:
predict: input 4D tensor
target: label 4D tensor
"""
# _, predict = torch.max(predict, 1)
predict = predict > thr
predict = predict.data.cpu().numpy() + 1
target = target.data.cpu().numpy() + 1
tp = np.sum(((predict == 2) * (target == 2)) * (target > 0))
fp = np.sum(((predict == 2) * (target == 1)) * (target > 0))
fn = np.sum(((predict == 1) * (target == 2)) * (target > 0))
precision = float(np.nan_to_num(tp / (tp + fp)))
recall = float(np.nan_to_num(tp / (tp + fn)))
return precision, recall
def batch_pix_accuracy(predict, target):
"""Batch Pixel Accuracy
Args:
predict: input 3D tensor
target: label 3D tensor
"""
# for thr in np.linspace(0, 1, slices):
_, predict = torch.max(predict, 0)
predict = predict.cpu().numpy() + 1
target = target.cpu().numpy() + 1
pixel_labeled = np.sum(target > 0)
pixel_correct = np.sum((predict == target) * (target > 0))
assert pixel_correct <= pixel_labeled, \
"Correct area should be smaller than Labeled"
return pixel_correct, pixel_labeled
def batch_intersection_union(predict, target, nclass):
"""Batch Intersection of Union
Args:
predict: input 3D tensor
target: label 3D tensor
nclass: number of categories (int)
"""
_, predict = torch.max(predict, 0)
mini = 1
maxi = nclass
nbins = nclass
predict = predict.cpu().numpy() + 1
target = target.cpu().numpy() + 1
predict = predict * (target > 0).astype(predict.dtype)
intersection = predict * (predict == target)
# areas of intersection and union
area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
area_union = area_pred + area_lab - area_inter
assert (area_inter <= area_union).all(), \
"Intersection area should be smaller than Union area"
return area_inter, area_union
# ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py
def pixel_accuracy(im_pred, im_lab):
im_pred = np.asarray(im_pred)
im_lab = np.asarray(im_lab)
# Remove classes from unlabeled pixels in gt image.
# We should not penalize detections in unlabeled portions of the image.
pixel_labeled = np.sum(im_lab > 0)
pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0))
# pixel_accuracy = 1.0 * pixel_correct / pixel_labeled
return pixel_correct, pixel_labeled
def intersection_and_union(im_pred, im_lab, num_class):
im_pred = np.asarray(im_pred)
im_lab = np.asarray(im_lab)
# Remove classes from unlabeled pixels in gt image.
im_pred = im_pred * (im_lab > 0)
# Compute area intersection:
intersection = im_pred * (im_pred == im_lab)
area_inter, _ = np.histogram(intersection, bins=num_class - 1,
range=(1, num_class - 1))
# Compute area union:
area_pred, _ = np.histogram(im_pred, bins=num_class - 1,
range=(1, num_class - 1))
area_lab, _ = np.histogram(im_lab, bins=num_class - 1,
range=(1, num_class - 1))
area_union = area_pred + area_lab - area_inter
return area_inter, area_union