sayakpaul's picture
sayakpaul HF staff
add files
c4b2b37
import numpy as np
import torch
from sklearn.metrics import (average_precision_score, f1_score,
precision_recall_curve, roc_curve)
SMOOTH = 1e-6
__all__ = [
"get_f1_scores",
"get_ap_scores",
"batch_pix_accuracy",
"batch_intersection_union",
"get_iou",
"get_pr",
"get_roc",
"get_ap_multiclass",
]
def get_iou(outputs: torch.Tensor, labels: torch.Tensor):
# You can comment out this line if you are passing tensors of equal shape
# But if you are passing output from UNet or something it will most probably
# be with the BATCH x 1 x H x W shape
outputs = outputs.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W
labels = labels.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W
intersection = (
(outputs & labels).float().sum((1, 2))
) # Will be zero if Truth=0 or Prediction=0
union = (outputs | labels).float().sum((1, 2)) # Will be zzero if both are 0
iou = (intersection + SMOOTH) / (
union + SMOOTH
) # We smooth our devision to avoid 0/0
return iou.cpu().numpy()
def get_f1_scores(predict, target, ignore_index=-1):
# Tensor process
batch_size = predict.shape[0]
predict = predict.data.cpu().numpy().reshape(-1)
target = target.data.cpu().numpy().reshape(-1)
pb = predict[target != ignore_index].reshape(batch_size, -1)
tb = target[target != ignore_index].reshape(batch_size, -1)
total = []
for p, t in zip(pb, tb):
total.append(np.nan_to_num(f1_score(t, p)))
return total
def get_roc(predict, target, ignore_index=-1):
target_expand = target.unsqueeze(1).expand_as(predict)
target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
# Tensor process
x = torch.zeros_like(target_expand)
t = target.unsqueeze(1).clamp(min=0)
target_1hot = x.scatter_(1, t, 1)
batch_size = predict.shape[0]
predict = predict.data.cpu().numpy().reshape(-1)
target = target_1hot.data.cpu().numpy().reshape(-1)
pb = predict[target_expand_numpy != ignore_index].reshape(batch_size, -1)
tb = target[target_expand_numpy != ignore_index].reshape(batch_size, -1)
total = []
for p, t in zip(pb, tb):
total.append(roc_curve(t, p))
return total
def get_pr(predict, target, ignore_index=-1):
target_expand = target.unsqueeze(1).expand_as(predict)
target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
# Tensor process
x = torch.zeros_like(target_expand)
t = target.unsqueeze(1).clamp(min=0)
target_1hot = x.scatter_(1, t, 1)
batch_size = predict.shape[0]
predict = predict.data.cpu().numpy().reshape(-1)
target = target_1hot.data.cpu().numpy().reshape(-1)
pb = predict[target_expand_numpy != ignore_index].reshape(batch_size, -1)
tb = target[target_expand_numpy != ignore_index].reshape(batch_size, -1)
total = []
for p, t in zip(pb, tb):
total.append(precision_recall_curve(t, p))
return total
def get_ap_scores(predict, target, ignore_index=-1):
total = []
for pred, tgt in zip(predict, target):
target_expand = tgt.unsqueeze(0).expand_as(pred)
target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
# Tensor process
x = torch.zeros_like(target_expand)
t = tgt.unsqueeze(0).clamp(min=0).long()
target_1hot = x.scatter_(0, t, 1)
predict_flat = pred.data.cpu().numpy().reshape(-1)
target_flat = target_1hot.data.cpu().numpy().reshape(-1)
p = predict_flat[target_expand_numpy != ignore_index]
t = target_flat[target_expand_numpy != ignore_index]
total.append(np.nan_to_num(average_precision_score(t, p)))
return total
def get_ap_multiclass(predict, target):
total = []
for pred, tgt in zip(predict, target):
predict_flat = pred.data.cpu().numpy().reshape(-1)
target_flat = tgt.data.cpu().numpy().reshape(-1)
total.append(np.nan_to_num(average_precision_score(target_flat, predict_flat)))
return total
def batch_precision_recall(predict, target, thr=0.5):
"""Batch Precision Recall
Args:
predict: input 4D tensor
target: label 4D tensor
"""
# _, predict = torch.max(predict, 1)
predict = predict > thr
predict = predict.data.cpu().numpy() + 1
target = target.data.cpu().numpy() + 1
tp = np.sum(((predict == 2) * (target == 2)) * (target > 0))
fp = np.sum(((predict == 2) * (target == 1)) * (target > 0))
fn = np.sum(((predict == 1) * (target == 2)) * (target > 0))
precision = float(np.nan_to_num(tp / (tp + fp)))
recall = float(np.nan_to_num(tp / (tp + fn)))
return precision, recall
def batch_pix_accuracy(predict, target):
"""Batch Pixel Accuracy
Args:
predict: input 3D tensor
target: label 3D tensor
"""
# for thr in np.linspace(0, 1, slices):
_, predict = torch.max(predict, 0)
predict = predict.cpu().numpy() + 1
target = target.cpu().numpy() + 1
pixel_labeled = np.sum(target > 0)
pixel_correct = np.sum((predict == target) * (target > 0))
assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
return pixel_correct, pixel_labeled
def batch_intersection_union(predict, target, nclass):
"""Batch Intersection of Union
Args:
predict: input 3D tensor
target: label 3D tensor
nclass: number of categories (int)
"""
_, predict = torch.max(predict, 0)
mini = 1
maxi = nclass
nbins = nclass
predict = predict.cpu().numpy() + 1
target = target.cpu().numpy() + 1
predict = predict * (target > 0).astype(predict.dtype)
intersection = predict * (predict == target)
# areas of intersection and union
area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
area_union = area_pred + area_lab - area_inter
assert (
area_inter <= area_union
).all(), "Intersection area should be smaller than Union area"
return area_inter, area_union
# ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py
def pixel_accuracy(im_pred, im_lab):
im_pred = np.asarray(im_pred)
im_lab = np.asarray(im_lab)
# Remove classes from unlabeled pixels in gt image.
# We should not penalize detections in unlabeled portions of the image.
pixel_labeled = np.sum(im_lab > 0)
pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0))
# pixel_accuracy = 1.0 * pixel_correct / pixel_labeled
return pixel_correct, pixel_labeled
def intersection_and_union(im_pred, im_lab, num_class):
im_pred = np.asarray(im_pred)
im_lab = np.asarray(im_lab)
# Remove classes from unlabeled pixels in gt image.
im_pred = im_pred * (im_lab > 0)
# Compute area intersection:
intersection = im_pred * (im_pred == im_lab)
area_inter, _ = np.histogram(
intersection, bins=num_class - 1, range=(1, num_class - 1)
)
# Compute area union:
area_pred, _ = np.histogram(im_pred, bins=num_class - 1, range=(1, num_class - 1))
area_lab, _ = np.histogram(im_lab, bins=num_class - 1, range=(1, num_class - 1))
area_union = area_pred + area_lab - area_inter
return area_inter, area_union