|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from utils.geometry_utils import edge_acc |
|
|
|
|
|
class CornerCriterion(nn.Module):
    """Weighted binary cross-entropy loss and recall for corner heatmaps.

    Positions whose soft (Gaussian) target exceeds 0.5 get loss weight
    ``loss_rate + 1``; all other positions get weight 1. Recall is measured
    on the predictions thresholded at 0.5, over the positions where the hard
    target equals 1.
    """

    def __init__(self, image_size, loss_rate=9):
        """
        Args:
            image_size: Accepted for call-site compatibility; not used.
            loss_rate: Extra weight added to positions whose Gaussian target
                is above 0.5. The default of 9 reproduces the previously
                hard-coded value (those positions weigh 10x).
        """
        super().__init__()
        self.loss_rate = loss_rate

    def forward(self, outputs_s1, targets, gauss_targets, epoch=0):
        """Compute the corner loss and recall.

        Args:
            outputs_s1: Predicted probabilities. They are passed straight to
                ``F.binary_cross_entropy``, so they must already lie in
                [0, 1]. Presumably shaped (batch, H, W) — TODO confirm.
            targets: Hard targets; entries equal to 1 are the positives used
                for recall.
            gauss_targets: Soft targets with the same shape as ``outputs_s1``.
            epoch: Unused; kept for call-site compatibility.

        Returns:
            ``(loss_s1, recall_s1)``:
            - ``loss_s1``: the weighted BCE summed over the last two dims and
              averaged over the remaining ones.
            - ``recall_s1``: the fraction of positive positions predicted
              >= 0.5. It is 0 (rather than NaN) when there are no positives.
        """
        preds_s1 = (outputs_s1 >= 0.5).float()

        # preds are 0/1, so summing them over the positives counts the hits.
        pos_mask = targets == 1
        num_pos = int(pos_mask.sum())
        correct = preds_s1[pos_mask].sum()
        # Guard the denominator: a batch without positives would otherwise
        # yield 0/0 = NaN.
        recall_s1 = correct / max(num_pos, 1)

        loss_weight = (gauss_targets > 0.5).float() * self.loss_rate + 1
        loss_s1 = F.binary_cross_entropy(outputs_s1, gauss_targets, weight=loss_weight, reduction='none')
        loss_s1 = loss_s1.sum(dim=(-2, -1)).mean()

        return loss_s1, recall_s1
|
|
|
|
|
class EdgeCriterion(nn.Module):
    """Masked, class-weighted edge-classification losses and accuracies.

    Three sets of logits are scored:
    - a stage-1 head over all edges;
    - two stage-2 heads ("hybrid" and "rel") over the subset of edges
      gathered by ``s2_ids``.

    Accuracies are delegated to ``edge_acc``.
    """

    def __init__(self, class_weights=(0.33, 1.0)):
        """
        Args:
            class_weights: Per-class weights for the cross-entropy. The
                default ``(0.33, 1.0)`` reproduces the previously hard-coded
                values.
        """
        super().__init__()
        # The weight is created on the CPU and moved to the logits' device on
        # every call (see _weighted_ce), so construction no longer requires
        # CUDA. The nn.CrossEntropyLoss is kept so `self.edge_loss` and its
        # `weight` buffer still exist.
        self.edge_loss = nn.CrossEntropyLoss(
            weight=torch.as_tensor(class_weights, dtype=torch.float32),
            reduction='none',
        )

    def _weighted_ce(self, logits, labels):
        # Per-element class-weighted cross-entropy, with the class weights
        # placed on the same device as `logits`.
        weight = self.edge_loss.weight.to(logits.device)
        return F.cross_entropy(logits, labels, weight=weight, reduction='none')

    @staticmethod
    def _masked_loss(losses, ignore_mask, batch_size):
        # Zero the ignored positions, sum the rest, and divide by batch_size.
        # Cross-entropy is non-negative, so the `> 0` filter only removes the
        # zeroed entries and any NaNs (NaN > 0 is False). It is kept to
        # preserve the original behaviour.
        losses[ignore_mask] = 0
        return losses[losses > 0].sum() / batch_size

    def forward(self, logits_s1, logits_s2_hybrid, logits_s2_rel, s2_ids, s2_edge_mask, edge_labels, edge_lengths,
                edge_mask, s2_gt_values):
        """Compute the three edge losses and their accuracies.

        Args:
            logits_s1: Stage-1 logits, scored against ``edge_labels``.
            logits_s2_hybrid: Stage-2 "hybrid" logits, scored against the
                gathered stage-2 labels.
            logits_s2_rel: Stage-2 "rel" logits, scored against the gathered
                stage-2 labels.
            s2_ids: Indices along dim 1 of ``edge_labels`` selected for
                stage 2.
            s2_edge_mask: Stage-2 mask. Entries equal to 1 are excluded from
                the loss; entries equal to 0 are counted into the lengths
                passed to ``edge_acc``.
            edge_labels: Integer class labels for all edges.
            edge_lengths: Passed unchanged to ``edge_acc`` for stage 1.
            edge_mask: Stage-1 mask; entries equal to 1 are excluded from the
                loss.
            s2_gt_values: Only stage-2 entries equal to 2 contribute to the
                stage-2 losses.

        Returns:
            ``(s1_losses, s1_acc, s2_losses_hybrid, s2_acc_hybrid,
            s2_losses_rel, s2_acc_rel)``. Each loss is the sum of the
            unmasked per-edge losses divided by ``mask.shape[0]``
            (presumably the batch size).
        """
        # Stage 1: all edges not masked out by edge_mask.
        s1_losses = self._masked_loss(
            self._weighted_ce(logits_s1, edge_labels), edge_mask == 1, edge_mask.shape[0])
        # edge_acc receives a constant gt value of 2 for every stage-1 edge.
        gt_values = torch.full_like(edge_mask, 2, dtype=torch.long)
        s1_acc = edge_acc(logits_s1, edge_labels, edge_lengths, gt_values)

        # Stage 2: both heads share the gathered labels, mask and lengths.
        s2_labels = torch.gather(edge_labels, 1, s2_ids)
        s2_ignore = (s2_edge_mask == 1) | (s2_gt_values != 2)
        s2_edge_lengths = (s2_edge_mask == 0).sum(dim=-1)
        s2_batch = s2_edge_mask.shape[0]

        s2_losses_hybrid = self._masked_loss(
            self._weighted_ce(logits_s2_hybrid, s2_labels), s2_ignore, s2_batch)
        s2_acc_hybrid = edge_acc(logits_s2_hybrid, s2_labels, s2_edge_lengths, s2_gt_values)

        s2_losses_rel = self._masked_loss(
            self._weighted_ce(logits_s2_rel, s2_labels), s2_ignore, s2_batch)
        s2_acc_rel = edge_acc(logits_s2_rel, s2_labels, s2_edge_lengths, s2_gt_values)

        return s1_losses, s1_acc, s2_losses_hybrid, s2_acc_hybrid, s2_losses_rel, s2_acc_rel
|
|