Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
""" | |
This file contains the functions for computing the SEG loss.
""" | |
import torch | |
class SEGLossComputation:
    """
    This class computes the SEG loss: a masked dice loss between a
    predicted map and the segmentation targets rendered from the per-image
    annotations.
    """

    def __init__(self, cfg, eps=1e-6):
        """
        Arguments:
            cfg: project config object. This class reads
                cfg.MODEL.SEG.SHRINK_RATIO and
                cfg.MODEL.SEG.IGNORE_DIFFICULT from it.
            eps (float): constant added to the dice denominator so the
                division stays defined when the masked prediction and the
                masked target both sum to zero.
        """
        self.eps = eps
        self.cfg = cfg

    def __call__(self, preds, targets):
        """
        Arguments:
            preds (Tensor): predicted map. Its dims 2 and 3 are used as the
                size at which the targets are rendered (presumably
                N x C x H x W -- confirm against the caller).
            targets (list): per-image annotation objects exposing
                get_field("masks") and get_field("labels").

        Returns:
            seg_loss (Tensor)
        """
        image_size = (preds.shape[2], preds.shape[3])
        segm_targets, masks = self.prepare_targets(targets, image_size)
        # The targets are created from numpy arrays, so cast them to float
        # and move them next to the predictions before mixing them.
        device = preds.device
        segm_targets = segm_targets.float().to(device)
        masks = masks.float().to(device)
        return self.dice_loss(preds, segm_targets, masks)

    def dice_loss(self, pred, gt, m):
        """
        Masked dice loss, reduced over every element:

            1 - 2 * sum(pred * gt * m) / (sum(pred * m) + sum(gt * m) + eps)

        Arguments:
            pred (Tensor): predicted map.
            gt (Tensor): target map.
            m (Tensor): mask multiplied into every term; positions where it
                is 0 contribute nothing to the loss.

        Returns:
            loss (Tensor): scalar tensor.
        """
        intersection = torch.sum(pred * gt * m)
        union = torch.sum(pred * m) + torch.sum(gt * m) + self.eps
        return 1 - 2.0 * intersection / union

    def project_masks_on_image(self, mask_polygons, labels, shrink_ratio, image_size):
        """
        Render one image's annotations into a segmentation map and a
        training mask.

        The rendering itself is delegated to mask_polygons.convert_seg_map,
        which also receives cfg.MODEL.SEG.IGNORE_DIFFICULT; the two numpy
        arrays it returns are wrapped as tensors.

        Returns:
            (Tensor, Tensor): segmentation map and training mask.
        """
        seg_map, training_mask = mask_polygons.convert_seg_map(
            labels, shrink_ratio, image_size, self.cfg.MODEL.SEG.IGNORE_DIFFICULT
        )
        return torch.from_numpy(seg_map), torch.from_numpy(training_mask)

    def prepare_targets(self, targets, image_size):
        """
        Build the batched segmentation maps and training masks.

        Arguments:
            targets (list): per-image annotation objects; the "masks" and
                "labels" fields of each one are read.
            image_size (tuple): size forwarded to the per-image rendering.

        Returns:
            (Tensor, Tensor): per-image segmentation maps and training
            masks, each stacked along a new leading dimension.
        """
        segms = []
        training_masks = []
        for target_per_image in targets:
            seg_map, training_mask = self.project_masks_on_image(
                target_per_image.get_field("masks"),
                target_per_image.get_field("labels"),
                self.cfg.MODEL.SEG.SHRINK_RATIO,
                image_size,
            )
            segms.append(seg_map)
            training_masks.append(training_mask)
        return torch.stack(segms), torch.stack(training_masks)
def make_seg_loss_evaluator(cfg):
    """Build the SEG loss evaluator for the given config."""
    return SEGLossComputation(cfg)