franchesoni's picture
v0
e1b51e5
raw
history blame
No virus
7.06 kB
print("Importing standard...")
from abc import ABC, abstractmethod
print("Importing external...")
import torch
from torch.nn.functional import binary_cross_entropy
# from matplotlib import pyplot as plt
print("Importing internal...")
from utils import preprocess_masks_features, get_row_col, symlog, calculate_iou
######### BINARY LOSSES ###############
def my_lovasz_hinge(logits, gt, downsample=False):
    """Batched Lovasz hinge loss for binary masks flattened to ``(B, HW)``.

    Each row is treated as one image ("per image" mode, no ignore label).

    Args:
        logits: ``(B, HW)`` real-valued scores; positive means foreground.
        gt: ``(B, HW)`` binary ground truth (bool, int or float).
        downsample: falsy to use every pixel; otherwise an integer stride.
            A random start offset in ``[0, stride)`` is drawn and only every
            ``stride``-th pixel (column) is kept.

    Returns:
        Scalar tensor: the mean of the per-row losses, ignoring NaN rows.
    """
    if downsample:
        step = int(downsample)
        # Draw from the whole stride. The former bound ``downsample - 1``
        # could never select the last phase and made ``torch.randint`` raise
        # on an empty range for ``downsample=1``.
        offset = int(torch.randint(step, (1,)))
        logits, gt = logits[:, offset::step], gt[:, offset::step]
    # B, HW
    gt = 1.0 * gt  # go float
    areas = gt.sum(dim=1, keepdim=True)  # B, 1

    # Hinge errors: zero or negative when the logit is on the correct side
    # of the margin.
    signs = 2 * gt - 1
    errors = 1 - logits * signs
    errors_sorted, perm = torch.sort(errors, dim=1, descending=True)
    gt_sorted = torch.gather(gt, 1, perm)  # B, HW

    # Lovasz gradient: discrete derivative of the Jaccard loss along the
    # sorted errors.
    intersection = areas - gt_sorted.cumsum(dim=1)  # B, HW
    union = areas + (1 - gt_sorted).cumsum(dim=1)  # B, HW
    jaccard = 1 - intersection / union  # B, HW
    # Built out-of-place so no tensor is mutated after it was created.
    lovasz_grad = torch.cat(
        [jaccard[:, :1], jaccard[:, 1:] - jaccard[:, :-1]], dim=1
    )

    loss = (torch.relu(errors_sorted) * lovasz_grad).sum(dim=1)  # B,
    return torch.nanmean(loss)
def focal_loss(scores, targets, alpha=0.25, gamma=2):
    """Element-wise focal loss computed on probabilities (not logits).

    Args:
        scores: predicted probabilities, passed to ``binary_cross_entropy``.
        targets: float tensor of the same shape holding the binary labels.
        alpha: weight of the positive class; the negative class gets
            ``1 - alpha``. A negative value disables the class weighting.
        gamma: exponent of the ``(1 - p_t)`` modulating factor.

    Returns:
        Unreduced loss tensor with the same shape as ``scores``.
    """
    bce = binary_cross_entropy(scores, targets, reduction="none")
    # Probability the prediction assigns to the true class of each element.
    prob_true = scores * targets + (1 - scores) * (1 - targets)
    modulated = bce * ((1 - prob_true) ** gamma)
    if alpha >= 0:
        class_weight = alpha * targets + (1 - alpha) * (1 - targets)
        return class_weight * modulated
    return modulated
# also binary_cross_entropy and lovasz
########## SUBFUNCTIONS ######################
def get_distances(features, refs, sigma, norm_p, square_distances, H, W):
    """Scaled distance from every pixel feature to every reference feature.

    Args:
        features: ``(B, 1, F, H*W)`` per-pixel features.
        refs: ``(B, M, F, 1)`` reference features.
        sigma: scale, a scalar or a tensor broadcastable against
            ``(B, M, 1, H*W)``; the distance is divided by ``2 * sigma**2``.
        norm_p: order of the norm taken over the feature axis.
        square_distances: square the norm before scaling when true.
        H, W: spatial size; only their product is used, for the final reshape.

    Returns:
        ``(B, M, H*W)`` tensor of scaled distances.
    """
    batch, n_refs = refs.shape[:2]
    # Norm over the feature axis (dim 2) of the broadcast difference.
    gap = torch.norm(features - refs, p=norm_p, dim=2, keepdim=True)  # B, M, 1, H*W
    if square_distances:
        gap = gap**2
    scaled = gap / (2 * sigma**2)
    return scaled.reshape(batch, n_refs, H * W)
def activate(features, masks, activation, use_sigma, offset_pos, ret_prediction):
    """Reshape raw network features and apply the output activation.

    Args:
        features: raw features. With ``masks=None`` (inference) this is a 2-D
            tensor whose smaller dimension is taken as the feature count and
            whose larger one as ``H*W`` of a square image. Otherwise it is
            whatever ``preprocess_masks_features`` accepts.
        masks: ground-truth masks, or ``None`` when inferencing.
        activation: ``"sigmoid"`` or ``"symlog"``.
        use_sigma: if true, the last feature channel is split off and passed
            through softplus to become a per-pixel sigma; otherwise sigma is 1.
        offset_pos: if true, row/column coordinates from ``get_row_col`` are
            added to the first two feature channels.
        ret_prediction: if true, also return the activated features reshaped
            to ``(B, 1, F, H, W)``.

    Returns:
        Inference (``masks is None``): ``(features, sigma, H, W)`` with
        ``features`` reshaped to ``(-1, H*W)``.
        Training: ``(features, masks, sigma, prediction, B, M, F, H, W)`` with
        ``features`` of shape ``(B, 1, F, H*W)``.
    """
    # sigmoid is very similar to exp
    assert activation in ["sigmoid", "symlog"]
    if masks is None:  # when inferencing
        B, M = 1, 1
        F, N = sorted(features.shape)
        H, W = [int(N ** (0.5))] * 2
        # NOTE(review): reshape does not transpose, so this presumably expects
        # an (F, H*W) layout -- confirm against callers.
        features = features.reshape(1, 1, -1, H * W)
    else:
        masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)
    # features: B, 1, F, H*W
    # masks: B, M, 1, H*W
    if use_sigma:
        sigma = torch.nn.functional.softplus(features)[:, :, -1:]  # B, 1, 1, H*W
        features = features[:, :, :-1]
        F = features.shape[2]
    else:
        sigma = 1
    features = symlog(features) if activation == "symlog" else torch.sigmoid(features)
    if offset_pos:
        assert F >= 2
        row, col = get_row_col(H, W, features.device)
        row = row.reshape(1, 1, 1, H, 1).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W)
        col = col.reshape(1, 1, 1, 1, W).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W)
        positional_features = torch.cat([row, col], dim=2)  # B, 1, 2, H*W
        # Build the shifted tensor out-of-place: writing into the activation
        # output in place invalidates the value autograd saved for backward
        # (e.g. for sigmoid) and fails when gradients are computed.
        shifted = (features[:, :, :2] + positional_features).to(features.dtype)
        features = torch.cat([shifted, features[:, :, 2:]], dim=2)
    prediction = features.reshape(B, 1, -1, H, W) if ret_prediction else None
    if masks is None:
        features = features.reshape(-1, H * W)
        sigma = sigma.reshape(-1, H * W) if use_sigma else 1
        return features, sigma, H, W
    return features, masks, sigma, prediction, B, M, F, H, W
class AbstractLoss(ABC):
    """Interface shared by the loss classes of this module.

    Both members are abstract static methods, so a concrete subclass must
    override them and they are called on the class itself.
    """

    @staticmethod
    @abstractmethod
    def loss(features, masks, ret_prediction=False, **kwargs):
        """Compute the training loss for ``features`` against ``masks``."""

    @staticmethod
    @abstractmethod
    def get_mask_from_query(features, sindex, **kwargs):
        """Predict a mask from ``features`` for the query pixel ``sindex``."""
class IISLoss(AbstractLoss):
    """Loss that scores every pixel by feature distance to sampled seed pixels.

    Features go through ``activate`` with the "symlog" activation, no sigma
    channel and no positional offset. The score of a pixel for a seed is
    ``exp(-d)``, where ``d`` is the squared L2 feature distance divided by 2.
    """

    @staticmethod
    def loss(features, masks, ret_prediction=False, K=3, logger=None):
        """Focal + Lovasz hinge loss over ``K`` random seed pixels per mask.

        Args:
            features: raw network features, forwarded to ``activate``.
            masks: ground-truth masks, forwarded to ``activate``; after it
                they are used to index pixels, so presumably boolean --
                TODO confirm against ``preprocess_masks_features``.
            ret_prediction: forwarded to ``activate``.
            K: number of seed pixels drawn inside every mask.
            logger: unused.

        Returns:
            ``(loss, prediction)``.
        """
        features, masks, sigma, prediction, B, M, F, H, W = activate(
            features, masks, "symlog", False, False, ret_prediction
        )
        n_pixels = H * W
        shuffled = torch.randperm(n_pixels, device=masks.device)

        # First K shuffled pixel indices lying inside each mask. Stacking
        # only works if every mask yields the same number of indices, i.e.
        # all masks have at least K pixels.
        seeds_per_image = []
        for b in range(B):
            seeds_per_mask = []
            for m in range(M):
                inside = masks[b, m, 0, shuffled]
                seeds_per_mask.append(shuffled[inside][:K])
            seeds_per_image.append(torch.stack(seeds_per_mask))
        seeds = torch.stack(seeds_per_image)  # B, M, K

        # Look up the feature vector of every seed pixel.
        pixel_major = features.permute(0, 3, 1, 2).expand(B, n_pixels, K, F)
        gather_index = seeds.reshape(B, M, K, 1).expand(B, M, K, F)
        seed_feats = torch.gather(pixel_major, dim=1, index=gather_index)  # B, M, K, F

        dists = get_distances(
            features, seed_feats.reshape(B, M * K, F, 1), sigma, 2, True, H, W
        )
        score = torch.exp(-dists)  # B, M*K, H*W, values in [0, 1]
        # Every seed of a mask shares that mask as its target.
        targets = masks.expand(B, M, K, n_pixels).reshape(B, M * K, n_pixels).float()

        focal_term = focal_loss(score, targets).mean()
        # Map scores from [0, 1] to [-1, 1] to use them as hinge logits.
        lovasz_term = my_lovasz_hinge(
            score.view(B * M * K, n_pixels) * 2 - 1,
            targets.view(B * M * K, n_pixels),
        )
        return focal_term + lovasz_term, prediction

    @staticmethod
    def get_mask_from_query(features, sindex):
        """Boolean mask of the pixels whose score w.r.t. ``sindex`` exceeds 0.5.

        Args:
            features: raw 2-D features of one image (see ``activate``).
            sindex: index of the query pixel along the flattened ``H*W`` axis.

        Returns:
            Boolean tensor of shape ``(1, 1, H*W)``.
        """
        features, _, H, W = activate(features, None, "symlog", False, False, False)
        n_feats = features.shape[0]
        query = features[:, sindex]
        dists = get_distances(
            features.reshape(1, 1, n_feats, H * W),
            query.reshape(1, 1, n_feats, 1),
            1,
            2,
            True,
            H,
            W,
        )
        return torch.exp(-dists) > 0.5
def iis_iou(features, masks, get_mask_from_query, K=20):
    """Mean IoU of the masks predicted from random seed pixels.

    For every image and mask, ``K`` pixels inside the mask are sampled. Each
    one is passed to ``get_mask_from_query`` with that image's features, and
    the prediction is compared to the mask with ``calculate_iou``.

    Args:
        features: raw features, forwarded to ``preprocess_masks_features``.
        masks: ground-truth masks, forwarded to ``preprocess_masks_features``.
        get_mask_from_query: callable ``(features, sindex) -> prediction``.
        K: number of seed pixels per mask.

    Returns:
        Sum of all IoUs divided by the number of (image, mask, seed) samples.
    """
    masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)
    # features: B, 1, F, H*W
    # masks: B, M, 1, H*W
    shuffled = torch.randperm(H * W).to(masks.device)

    # First K shuffled pixel indices that fall inside each mask.
    seeds_per_image = []
    for b in range(B):
        seeds_per_mask = []
        for m in range(M):
            inside = masks[b, m, 0, shuffled]
            seeds_per_mask.append(shuffled[inside][:K])
        seeds_per_image.append(torch.stack(seeds_per_mask))
    seeds = torch.stack(seeds_per_image)  # B, M, K

    ious = [
        calculate_iou(
            get_mask_from_query(features[b, 0], seeds[b, m, k]),
            masks[b, m, 0, :],
        )
        for b in range(B)
        for m in range(M)
        for k in range(K)
    ]
    return sum(ious) / len(ious)
# Loss names that get_loss_class can resolve.
losses_names = ["iis"]
#
def get_loss_class(loss_name):
    """Return the loss class registered under ``loss_name``.

    Args:
        loss_name: name of the loss; only ``"iis"`` is supported.

    Returns:
        The matching loss class (``IISLoss`` for ``"iis"``).

    Raises:
        NotImplementedError: if ``loss_name`` is not a supported name. The
            message names the rejected value.
    """
    if loss_name == "iis":
        return IISLoss
    raise NotImplementedError(
        f"Unknown loss name {loss_name!r}; supported names: 'iis'"
    )
def get_get_mask_from_query(loss_name):
    """Return the ``get_mask_from_query`` function of the loss ``loss_name``.

    The name is resolved with ``get_loss_class``, which raises
    ``NotImplementedError`` for unsupported names.
    """
    return get_loss_class(loss_name).get_mask_from_query
def get_loss(loss_name):
    """Return the ``loss`` function of the loss named ``loss_name``.

    The name is resolved with ``get_loss_class``, which raises
    ``NotImplementedError`` for unsupported names.
    """
    return get_loss_class(loss_name).loss