|
|
|
import numpy as np |
|
import torch |
|
|
|
|
|
class SMeasure:
    """Structure-measure (S-measure) between a prediction map and a mask.

    The final score blends an object-aware term and a region-aware term::

        S = alpha * S_object + (1 - alpha) * S_region

    Every intermediate tensor is created on the device of the inputs, so the
    metric works for CPU and CUDA tensors alike.
    """

    def __init__(self, alpha: float = 0.5):
        """Store the blending weight.

        Args:
            alpha: Weight of the object-aware term; ``1 - alpha`` is the
                weight of the region-aware term.
        """
        self.alpha: float = alpha
        # Kept only so code that reads or assigns ``.cuda`` keeps working.
        # The device is taken from the input tensors, not from this flag.
        self.cuda: bool = torch.cuda.is_available()

    def _centroid(self, gt):
        """Return the rounded, 0-based centroid ``(X, Y)`` of ``gt``.

        ``gt`` is reshaped to ``(rows, cols)`` from its last two dimensions,
        so it must hold exactly one map. ``X`` is the column coordinate and
        ``Y`` the row coordinate. When ``gt`` sums to zero the image centre
        ``(round(cols / 2), round(rows / 2))`` is returned instead.

        Returns:
            Two 0-dim ``long`` tensors on ``gt``'s device.
        """
        rows, cols = gt.size()[-2:]
        gt = gt.reshape(rows, cols)
        device = gt.device
        total = gt.sum()
        if total == 0:
            X = torch.tensor(round(cols / 2), dtype=torch.long, device=device)
            Y = torch.tensor(round(rows / 2), dtype=torch.long, device=device)
            return X, Y

        # Coordinate vectors live on the same device as the mask.
        i = torch.arange(cols, dtype=torch.float32, device=device)
        j = torch.arange(rows, dtype=torch.float32, device=device)
        X = torch.round((gt.sum(dim=0) * i).sum() / total)
        Y = torch.round((gt.sum(dim=1) * j).sum() / total)
        return X.long(), Y.long()

    def _ssim(self, pred, gt):
        """Return the structural-similarity score of one region.

        Uses the means, variances and covariance of ``pred`` and ``gt``
        (normalised by ``N - 1``, ``N`` being the pixel count taken from the
        last two dimensions of ``pred``).

        Returns:
            ``alpha / beta`` when ``alpha`` is non-zero, ``1.0`` when both
            ``alpha`` and ``beta`` are zero, and ``0`` otherwise.
        """
        gt = gt.float()
        h, w = pred.size()[-2:]
        N = h * w
        x = pred.mean()
        y = gt.mean()

        # The epsilon keeps the denominator non-zero for a one-pixel region.
        denom = N - 1 + 1e-20
        dx = pred - x
        dy = gt - y
        sigma_x2 = (dx * dx).sum() / denom
        sigma_y2 = (dy * dy).sum() / denom
        sigma_xy = (dx * dy).sum() / denom

        alpha = 4 * x * y * sigma_xy
        beta = (x * x + y * y) * (sigma_x2 + sigma_y2)

        if alpha != 0:
            return alpha / (beta + 1e-20)
        if beta == 0:
            return 1.0
        return 0

    def _object(self, pred, gt):
        """Score ``pred`` over the pixels where ``gt == 1``.

        The score is ``2 * mean / (mean**2 + 1 + std)`` of the selected
        values.

        Returns:
            A 0-dim tensor. An empty selection scores 0 instead of NaN, and a
            single selected pixel is treated as having zero spread (the
            unbiased ``std`` of one sample is NaN).
        """
        region = pred[gt == 1]
        count = region.numel()
        if count == 0:
            return pred.new_zeros(())

        x = region.mean()
        sigma_x = region.std() if count > 1 else region.new_zeros(())
        return 2.0 * x / (x * x + 1.0 + sigma_x + 1e-20)

    def _s_object(self, pred, gt):
        """Return the object-aware term.

        The foreground score (``pred`` where ``gt == 1``) and the background
        score (``1 - pred`` where ``gt == 0``) are averaged with weights
        ``gt.mean()`` and ``1 - gt.mean()``.
        """
        zeros = torch.zeros_like(pred)
        fg = torch.where(gt == 0, zeros, pred)
        bg = torch.where(gt == 1, zeros, 1 - pred)
        o_fg = self._object(fg, gt)
        o_bg = self._object(bg, 1 - gt)
        u = gt.mean()
        return u * o_fg + (1 - u) * o_bg

    def _divide_gt(self, gt, X, Y):
        """Split ``gt`` at column ``X`` / row ``Y`` and compute area weights.

        Args:
            gt: Mask holding a single map in its last two dimensions.
            X: Split column, as returned by ``_centroid``.
            Y: Split row, as returned by ``_centroid``.

        Returns:
            ``(LT, RT, LB, RB, w1, w2, w3, w4)``: the four quadrants followed
            by each quadrant's share of the total area.
        """
        h, w = gt.size()[-2:]
        area = h * w
        LT, RT, LB, RB = self._divide_prediction(gt, X, Y)
        X = X.float()
        Y = Y.float()
        w1 = X * Y / area
        w2 = (w - X) * Y / area
        w3 = X * (h - Y) / area
        w4 = 1 - w1 - w2 - w3
        return LT, RT, LB, RB, w1, w2, w3, w4

    def _divide_prediction(self, pred, X, Y):
        """Split ``pred`` into ``(LT, RT, LB, RB)`` at column ``X`` / row ``Y``."""
        h, w = pred.size()[-2:]
        pred = pred.reshape(h, w)
        col, row = int(X), int(Y)
        LT = pred[:row, :col]
        RT = pred[:row, col:w]
        LB = pred[row:h, :col]
        RB = pred[row:h, col:w]
        return LT, RT, LB, RB

    def _s_region(self, pred, gt):
        """Return the region-aware term.

        The maps are split into four quadrants around the centroid of ``gt``
        and the per-quadrant ``_ssim`` scores are combined using the
        quadrants' area weights.
        """
        X, Y = self._centroid(gt)
        gt1, gt2, gt3, gt4, w1, w2, w3, w4 = self._divide_gt(gt, X, Y)
        p1, p2, p3, p4 = self._divide_prediction(pred, X, Y)

        Q = torch.zeros_like(w1)
        quadrants = (
            (w1, p1, gt1),
            (w2, p2, gt2),
            (w3, p3, gt3),
            (w4, p4, gt4),
        )
        for weight, p, g in quadrants:
            # An empty quadrant (centroid on an image border) carries zero
            # weight, but its mean is NaN and 0 * NaN would poison the sum.
            if p.numel() == 0:
                continue
            Q = Q + weight * self._ssim(p, g)
        return Q

    def __call__(self, pred_mask: torch.Tensor, gt_mask: torch.Tensor):
        """Compute the S-measure of one prediction against one mask.

        Args:
            pred_mask: Prediction map. Its last two dimensions are the image;
                any leading dimensions must have size 1.
            gt_mask: Ground-truth mask of the same shape. It is thresholded
                at 0.5 on a copy; the caller's tensor is left untouched.

        Returns:
            The score as a Python ``float``. Negative blended scores are
            clamped to ``0.0``.
        """
        assert pred_mask.shape == gt_mask.shape, "pred and gt shapes differ"

        pred = pred_mask if pred_mask.is_floating_point() else pred_mask.float()
        # Binarise first so the degenerate-mask checks below see the same
        # mask the object/region terms would see.
        gt = (gt_mask >= 0.5).to(pred.dtype)

        y = gt.mean()
        if y == 0:
            # No foreground at all: reward predictions that are dark.
            return (1.0 - pred.mean()).item()
        if y == 1:
            # Everything is foreground: reward predictions that are bright.
            return pred.mean().item()

        Q = self.alpha * self._s_object(pred, gt) + (
            1 - self.alpha
        ) * self._s_region(pred, gt)
        score = Q.item()
        return 0.0 if score < 0 else score
|
|