Spaces:
Sleeping
Sleeping
print("Importing standard...") | |
from abc import ABC, abstractmethod | |
print("Importing external...") | |
import torch | |
from torch.nn.functional import binary_cross_entropy | |
# from matplotlib import pyplot as plt | |
print("Importing internal...") | |
from utils import preprocess_masks_features, get_row_col, symlog, calculate_iou | |
######### BINARY LOSSES ############### | |
def my_lovasz_hinge(logits, gt, downsample=False):
    """Lovasz hinge loss for binary masks, computed per row and averaged.

    Args:
        logits: tensor of shape (B, HW) with signed scores; positive values
            vote for the foreground.
        gt: tensor of shape (B, HW) with binary labels (bool or 0/1).
        downsample: falsy to use every column; otherwise an integer stride.
            Only every ``downsample``-th column is kept, starting at a
            random offset drawn from ``[0, downsample - 1]``.

    Returns:
        Scalar tensor: the mean of the per-row losses, skipping NaN rows.
    """
    if downsample:
        stride = int(downsample)
        # The upper bound of randint is exclusive, so passing the stride
        # itself makes every offset in [0, stride - 1] reachable and keeps
        # stride == 1 valid (an empty range would raise).
        offset = int(torch.randint(stride, (1,)))
        logits, gt = logits[:, offset::stride], gt[:, offset::stride]
    # B, HW
    gt = 1.0 * gt  # promote bool/int labels to float
    areas = gt.sum(dim=1, keepdim=True)  # B, 1
    # hinge errors: labels mapped to {-1, +1}, margin of 1
    signs = 2 * gt - 1
    errors = 1 - logits * signs
    errors_sorted, perm = torch.sort(errors, dim=1, descending=True)
    gt_sorted = torch.gather(gt, 1, perm)  # B, HW
    # Lovasz gradient: discrete derivative of the Jaccard loss along the
    # sorted errors
    intersection = areas - gt_sorted.cumsum(dim=1)  # B, HW
    union = areas + (1 - gt_sorted).cumsum(dim=1)  # B, HW
    jaccard = 1 - intersection / union  # B, HW
    # the right-hand side is materialised before the slice is overwritten
    jaccard[:, 1:] = jaccard[:, 1:] - jaccard[:, :-1]
    loss = (torch.relu(errors_sorted) * jaccard).sum(dim=1)  # B,
    return torch.nanmean(loss)
def focal_loss(scores, targets, alpha=0.25, gamma=2):
    """Element-wise focal loss on probabilities.

    Args:
        scores: predicted probabilities; they are passed unchanged to
            ``binary_cross_entropy``.
        targets: float tensor of the same shape as ``scores`` holding the
            binary labels.
        alpha: weight of the positive class; a negative value disables the
            class balancing term.
        gamma: exponent of the modulating factor ``(1 - p_t) ** gamma``.

    Returns:
        Tensor shaped like ``scores`` with the unreduced loss.
    """
    # probability the prediction assigns to the true class
    prob_of_truth = scores * targets + (1 - scores) * (1 - targets)
    cross_entropy = binary_cross_entropy(scores, targets, reduction="none")
    modulated = cross_entropy * ((1 - prob_of_truth) ** gamma)
    if alpha < 0:
        return modulated
    class_weight = alpha * targets + (1 - alpha) * (1 - targets)
    return class_weight * modulated
# also binary_cross_entropy and lovasz | |
########## SUBFUNCTIONS ######################
def get_distances(features, refs, sigma, norm_p, square_distances, H, W):
    """Scaled distance from every pixel feature to every reference feature.

    Args:
        features: B, 1, F, H*W pixel features.
        refs: B, M, F, 1 reference features.
        sigma: scale, either a number or a tensor broadcastable against
            B, M, 1, H*W.
        norm_p: order of the norm taken over the feature axis.
        square_distances: square the norm before scaling when true.
        H, W: spatial size; only their product is used.

    Returns:
        Tensor of shape B, M, H*W holding ``d / (2 * sigma**2)``.
    """
    batch, n_refs = refs.shape[:2]
    # broadcasting pairs every reference with every pixel: B, M, F, H*W
    delta = features - refs
    dist = torch.norm(delta, p=norm_p, dim=2, keepdim=True)  # B, M, 1, H*W
    if square_distances:
        dist = dist**2
    scaled = dist / (2 * sigma**2)
    return scaled.reshape(batch, n_refs, H * W)
def activate(features, masks, activation, use_sigma, offset_pos, ret_prediction):
    """Activate raw features and bring them into the layout the losses use.

    Args:
        features: raw network output. With ``masks`` given it is handed to
            ``preprocess_masks_features``. With ``masks=None`` (inference)
            it must be 2-D; the smaller dimension is taken as F and the
            larger as the pixel count N, which is assumed to be a square
            H*W grid (presumably the layout is (F, H*W) -- confirm against
            callers).
        masks: ground-truth masks, or ``None`` when inferencing.
        activation: ``"sigmoid"`` or ``"symlog"``.
        use_sigma: when true the last feature channel is split off and
            passed through softplus to serve as a per-pixel sigma.
        offset_pos: when true the row/column grid from ``get_row_col`` is
            added to the first two feature channels.
        ret_prediction: when true the activated features are also returned
            reshaped to B, 1, F, H, W.

    Returns:
        Training (``masks`` given):
            ``(features, masks, sigma, prediction, B, M, F, H, W)`` with
            features B, 1, F, H*W and masks B, M, 1, H*W.
        Inference (``masks is None``):
            ``(features, sigma, H, W)`` with features reshaped to
            (-1, H*W).
        ``sigma`` is the integer 1 whenever ``use_sigma`` is false.
    """
    # sigmoid is very similar to exp
    # prepare features
    assert activation in ["sigmoid", "symlog"]
    if masks is None:  # when inferencing
        B, M = 1, 1
        F, N = sorted(features.shape)
        H, W = [int(N ** (0.5))] * 2
        features = features.reshape(1, 1, -1, H * W)
    else:
        masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)
    # features: B, 1, F, H*W
    # masks: B, M, 1, H*W
    if use_sigma:
        sigma = torch.nn.functional.softplus(features)[:, :, -1:]  # B, 1, 1, H*W
        features = features[:, :, :-1]
        F = features.shape[2]
    else:
        sigma = 1
    features = symlog(features) if activation == "symlog" else torch.sigmoid(features)
    if offset_pos:
        assert F >= 2
        row, col = get_row_col(H, W, features.device)
        row = row.reshape(1, 1, 1, H, 1).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W)
        col = col.reshape(1, 1, 1, 1, W).expand(B, 1, 1, H, W).reshape(B, 1, 1, H * W)
        positional_features = torch.cat([row, col], dim=2)  # B, 1, 2, H*W
        # Build the shifted tensor out of place: writing into the activation
        # output in place invalidates the value autograd saved for the
        # sigmoid backward pass.
        shifted = (features[:, :, :2] + positional_features).to(features.dtype)
        features = torch.cat([shifted, features[:, :, 2:]], dim=2)
    prediction = features.reshape(B, 1, -1, H, W) if ret_prediction else None
    if masks is None:
        features = features.reshape(-1, H * W)
        sigma = sigma.reshape(-1, H * W) if use_sigma else 1
        return features, sigma, H, W
    return features, masks, sigma, prediction, B, M, F, H, W
class AbstractLoss(ABC):
    """Interface of a loss family.

    A family pairs a training loss with the inference rule that turns one
    query pixel into a mask. Both members take no instance and are meant to
    be fetched from the class itself, so they are declared static; marking
    them abstract makes a subclass that forgets one of them impossible to
    instantiate.
    """

    @staticmethod
    @abstractmethod
    def loss(features, masks, ret_prediction=False, **kwargs):
        """Return the training loss for ``features`` against ``masks``."""

    @staticmethod
    @abstractmethod
    def get_mask_from_query(features, sindex, **kwargs):
        """Return the mask predicted for the query pixel ``sindex``."""
class IISLoss(AbstractLoss):
    """Loss based on feature similarity to randomly sampled query pixels.

    For every mask a few of its pixels are sampled; the score of any pixel
    with respect to a sample is ``exp(-d)``, where ``d`` is the scaled
    squared L2 distance between their symlog-activated features. Both
    members are static so they work whether fetched from the class or from
    an instance.
    """

    @staticmethod
    def loss(features, masks, ret_prediction=False, K=3, logger=None):
        """Focal + Lovasz loss of the similarity scores against the masks.

        Args:
            features: raw features, preprocessed by ``activate``.
            masks: ground-truth masks, preprocessed by ``activate``. After
                preprocessing they are used as a boolean index, so they are
                presumably boolean -- confirm in
                ``preprocess_masks_features``.
            ret_prediction: forwarded to ``activate``.
            K: number of query pixels sampled per mask. Every mask must
                contain at least K pixels, otherwise stacking the samples
                fails.
            logger: accepted but not used here.

        Returns:
            ``(loss, prediction)`` where ``prediction`` is whatever
            ``activate`` returned for ``ret_prediction``.
        """
        features, masks, sigma, prediction, B, M, F, H, W = activate(
            features, masks, "symlog", False, False, ret_prediction
        )
        # features: B, 1, F, H*W ; masks: B, M, 1, H*W ; sigma is 1 here
        rindices = torch.randperm(H * W, device=masks.device)
        # One shared random pixel order; per mask keep the first K pixels of
        # that order that fall inside the mask.
        # the following should work if all masks have more than K pixels
        sindices = torch.stack(
            [
                torch.stack([rindices[masks[b, m, 0, rindices]][:K] for m in range(M)])
                for b in range(B)
            ]
        )  # B, M, K
        feats_at_sindices = torch.gather(
            features.permute(0, 3, 1, 2).expand(B, H * W, K, F),
            dim=1,
            index=sindices.reshape(B, M, K, 1).expand(B, M, K, F),
        )  # B, M, K, F
        # every sampled pixel becomes one reference: B, M*K, F, 1
        refs = feats_at_sindices.reshape(B, M * K, F, 1)
        dists = get_distances(features, refs, sigma, 2, True, H, W)
        score = torch.exp(-dists)  # B, M*K, H*W in [0, 1]
        # each mask is repeated once per sample taken from it
        targets = (
            masks.expand(B, M, K, H * W).reshape(B, M * K, H * W).float()
        )  # B, M*K, H*W
        floss = focal_loss(score, targets).mean()
        # Lovasz expects signed scores, so map [0, 1] to [-1, 1]
        lloss = my_lovasz_hinge(
            score.view(B * M * K, H * W) * 2 - 1,
            targets.view(B * M * K, H * W),
        )
        loss = floss + lloss
        return loss, prediction

    @staticmethod
    def get_mask_from_query(features, sindex):
        """Predict the mask containing the query pixel ``sindex``.

        Args:
            features: raw features of one image, handled by the inference
                path of ``activate``.
            sindex: flat pixel index of the query.

        Returns:
            Boolean tensor of shape 1, 1, H*W: true where the similarity
            score to the query pixel exceeds 0.5.
        """
        features, _, H, W = activate(features, None, "symlog", False, False, False)
        F = features.shape[0]
        query_feat = features[:, sindex]
        dists = get_distances(
            features.reshape(1, 1, F, H * W),
            query_feat.reshape(1, 1, F, 1),
            1,
            2,
            True,
            H,
            W,
        )
        score = torch.exp(-dists)  # 1, 1, H*W
        pred = score > 0.5
        return pred
def iis_iou(features, masks, get_mask_from_query, K=20):
    """Mean IoU of masks predicted from randomly sampled query pixels.

    For every mask up to K of its pixels are used as queries; each query is
    turned into a predicted mask by ``get_mask_from_query`` and compared
    with the ground truth through ``calculate_iou``.

    Args:
        features: raw features, preprocessed by
            ``preprocess_masks_features``.
        masks: ground-truth masks, preprocessed together with the features.
            They are used as a boolean index afterwards, so they are
            presumably boolean -- confirm in the helper.
        get_mask_from_query: callable ``(features_of_one_image, sindex)``
            returning the predicted mask.
        K: maximum number of query pixels per mask. A mask with fewer
            pixels contributes as many queries as it has.

    Returns:
        The accumulated IoU divided by the number of queries evaluated.

    Raises:
        ZeroDivisionError: if no query could be sampled at all.
    """
    masks, features, M, B, H, W, F = preprocess_masks_features(masks, features)
    # features: B, 1, F, H*W
    # masks: B, M, 1, H*W
    rindices = torch.randperm(H * W).to(masks.device)
    cum_iou, n_samples = 0, 0
    for b in range(B):
        for m in range(M):
            # First K pixels of the shared random order that lie inside the
            # mask. The samples are consumed per mask instead of being
            # stacked into a B, M, K tensor, so masks smaller than K pixels
            # no longer break the evaluation.
            query_indices = rindices[masks[b, m, 0, rindices]][:K]
            for sindex in query_indices:
                pred = get_mask_from_query(features[b, 0], sindex)
                cum_iou += calculate_iou(pred, masks[b, m, 0, :])
                n_samples += 1
    return cum_iou / n_samples
# Names accepted by get_loss_class.
losses_names = ["iis"]
# | |
def get_loss_class(loss_name):
    """Return the loss class registered under ``loss_name``.

    Raises:
        NotImplementedError: for any name other than ``"iis"``.
    """
    if loss_name != "iis":
        raise NotImplementedError
    return IISLoss
def get_get_mask_from_query(loss_name):
    """Return the ``get_mask_from_query`` function of the named loss."""
    return get_loss_class(loss_name).get_mask_from_query
def get_loss(loss_name):
    """Return the ``loss`` function of the named loss."""
    return get_loss_class(loss_name).loss