import torch
import torch.nn as nn


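class Loss(nn.Module):
    """Weighted sum of map- and volume-level label, mask, consistency, and entropy losses.

    Each loss argument is a callable returning a dict with a ``"loss"`` entry;
    each weight scales the corresponding term in the total.
    """
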
    def __init__(
        self,
        map_label_loss,
        volume_label_loss,
        map_mask_loss,
        volume_mask_loss,
        consistency_loss,
        entropy_loss,
        map_label_weight,
        volume_label_weight,
        map_mask_weight,
        volume_mask_weight,
        consistency_weight,
        map_entropy_weight,
        volume_entropy_weight,
        consistency_source,
    ):
        super().__init__()

        self.map_label_loss = map_label_loss
        self.volume_label_loss = volume_label_loss
        self.map_mask_loss = map_mask_loss
        self.volume_mask_loss = volume_mask_loss
        self.consistency_loss = consistency_loss
        self.entropy_loss = entropy_loss

        self.map_label_weight = map_label_weight
        self.volume_label_weight = volume_label_weight
        self.map_mask_weight = map_mask_weight
        self.volume_mask_weight = volume_mask_weight
        self.consistency_weight = consistency_weight
        self.map_entropy_weight = map_entropy_weight
        self.volume_entropy_weight = volume_entropy_weight
        self.consistency_source = consistency_source

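    def forward(self, output, label, mask):
        """Return a dict of the individual loss terms plus their weighted sum under ``"total_loss"``.

        ``output`` must provide ``"map_pred"`` and ``"out_map"``, plus ``"vol_pred"``
        and ``"out_vol"`` when the corresponding volume or consistency terms are enabled.
        """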
        total_loss = 0.0
        loss_dict = {}

        # --- label loss ---
        label = label.float()
        # always computed and reported, even when map_label_weight is 0
        map_label_loss = self.map_label_loss(
            output["map_pred"], output["out_map"], label
        )["loss"]
        total_loss = total_loss + self.map_label_weight * map_label_loss
        loss_dict.update({"map_label_loss": map_label_loss})

        if self.volume_label_weight != 0.0:
            volume_label_loss = self.volume_label_loss(
                output["vol_pred"], output["out_vol"], label
            )["loss"]
            total_loss = total_loss + self.volume_label_weight * volume_label_loss
            loss_dict.update({"vol_label_loss": volume_label_loss})

        # --- mask loss ---
        # always computed and reported, even when map_mask_weight is 0
        map_mask_loss = self.map_mask_loss(output["out_map"], mask)["loss"]
        total_loss = total_loss + self.map_mask_weight * map_mask_loss
        loss_dict.update({"map_mask_loss": map_mask_loss})

        if self.volume_mask_weight != 0.0:
            volume_mask_loss = self.volume_mask_loss(output["out_vol"], mask)["loss"]
            total_loss = total_loss + self.volume_mask_weight * volume_mask_loss
            loss_dict.update({"vol_mask_loss": volume_mask_loss})

        # --- self-consistency loss ---
        if self.consistency_weight != 0.0 and self.consistency_source == "self":
            consistency_loss = self.consistency_loss(
                output["out_vol"], output["out_map"], label
            )
            consistency_loss = consistency_loss["loss"]
            total_loss = total_loss + self.consistency_weight * consistency_loss
            loss_dict.update({"consistency_loss": consistency_loss})

        # --- entropy loss ---
        if self.map_entropy_weight != 0.0:
            map_entropy_loss = self.entropy_loss(output["out_map"])["loss"]
            total_loss = total_loss + self.map_entropy_weight * map_entropy_loss
            loss_dict.update({"map_entropy_loss": map_entropy_loss})

        if self.volume_entropy_weight != 0.0:
            volume_entropy_loss = self.entropy_loss(output["out_vol"])["loss"]
            total_loss = total_loss + self.volume_entropy_weight * volume_entropy_loss
            loss_dict.update({"vol_entropy_loss": volume_entropy_loss})

        loss_dict.update({"total_loss": total_loss})
        return loss_dict