File size: 2,122 Bytes
c310e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#!/usr/bin/env python3
"""
This file contains specific functions for computing losses on the SEG
file
"""

import torch


class SEGLossComputation:
    """
    This class computes the SEG loss: a dice loss between a predicted
    segmentation map and ground-truth maps rasterized from per-image targets.

    Arguments:
        cfg: config node; ``cfg.MODEL.SEG.SHRINK_RATIO`` and
            ``cfg.MODEL.SEG.IGNORE_DIFFICULT`` are read when building targets.
        eps (float): constant added to the dice denominator so the loss stays
            finite when both the masked prediction and masked target sum to 0.
            Defaults to 1e-6, the value that was previously hard-coded.
    """

    def __init__(self, cfg, eps=1e-6):
        self.eps = eps
        self.cfg = cfg

    def __call__(self, preds, targets):
        """
        Arguments:
            preds (Tensor): 4-D prediction tensor; the target maps are
                rasterized at the spatial size given by its dims 2 and 3.
                Presumably (N, 1, H, W) probabilities -- TODO confirm.
            targets (list): one entry per image, each exposing
                ``get_field("masks")`` and ``get_field("labels")``.
        Returns:
            seg_loss (Tensor): scalar dice loss reduced over the whole batch.
        """
        image_size = (preds.shape[2], preds.shape[3])
        segm_targets, masks = self.prepare_targets(targets, image_size)
        device = preds.device
        # Targets come from torch.from_numpy (CPU tensors); cast to float and
        # move them onto the same device as the predictions.
        segm_targets = segm_targets.float().to(device)
        masks = masks.float().to(device)
        return self.dice_loss(preds, segm_targets, masks)

    def dice_loss(self, pred, gt, m):
        """
        Masked dice loss:
            1 - 2 * sum(pred * gt * m) / (sum(pred * m) + sum(gt * m) + eps)

        The sums run over every element (no per-image reduction), so a single
        scalar is returned for the whole batch. Elements where ``m`` is 0 do
        not contribute to either the numerator or the denominator.

        Arguments:
            pred (Tensor): predicted map.
            gt (Tensor): ground-truth map, broadcastable against ``pred``.
            m (Tensor): training mask, broadcastable against ``pred``.
        Returns:
            Tensor: 0-dim loss tensor.
        """
        intersection = torch.sum(pred * gt * m)
        union = torch.sum(pred * m) + torch.sum(gt * m) + self.eps
        return 1 - 2.0 * intersection / union

    def project_masks_on_image(self, mask_polygons, labels, shrink_ratio, image_size):
        """
        Rasterize one image's mask annotations via their ``convert_seg_map``
        method and wrap the two returned NumPy arrays as tensors.

        Note: ``torch.from_numpy`` shares memory with the source arrays.

        Returns:
            tuple[Tensor, Tensor]: (segmentation map, training mask).
        """
        seg_map, training_mask = mask_polygons.convert_seg_map(
            labels, shrink_ratio, image_size, self.cfg.MODEL.SEG.IGNORE_DIFFICULT
        )
        return torch.from_numpy(seg_map), torch.from_numpy(training_mask)

    def prepare_targets(self, targets, image_size):
        """
        Build batched segmentation maps and training masks for ``targets``.

        Every image's maps are rasterized at ``image_size`` and then stacked
        along a new leading batch dimension. ``torch.stack`` requires all
        per-image tensors to share one shape and at least one target.

        Returns:
            tuple[Tensor, Tensor]: (stacked seg maps, stacked training masks).
        """
        segms = []
        training_masks = []
        for target_per_image in targets:
            segmentation_masks = target_per_image.get_field("masks")
            labels = target_per_image.get_field("labels")
            seg_maps_per_image, training_masks_per_image = self.project_masks_on_image(
                segmentation_masks, labels, self.cfg.MODEL.SEG.SHRINK_RATIO, image_size
            )
            segms.append(seg_maps_per_image)
            training_masks.append(training_masks_per_image)
        return torch.stack(segms), torch.stack(training_masks)


def make_seg_loss_evaluator(cfg):
    """Factory: return a SEGLossComputation built from ``cfg``."""
    return SEGLossComputation(cfg)