File size: 2,807 Bytes
482ab8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import math
from typing import Dict, List, Optional

import torch
import torch.nn as nn


class BundledLoss(nn.Module):
    """Combine per-modality losses, a multi-view consistency loss and an
    optional ensemble-consistency term into a single objective.

    Args:
        single_modality_loss: Callable ``(output[modality], label, mask)``
            returning a dict that contains at least ``"total_loss"``.
        multi_view_consistency_loss: Callable
            ``(output, label, spixel, raw_image, mask)`` returning a dict that
            contains ``"total_loss"`` and, when the consistency term is
            enabled, ``"tgt_map"``.
        volume_mask_loss: Callable ``(out_vol, tgt_map)`` returning a dict
            with a ``"loss"`` entry.
        multi_view_consistency_weight: Base weight of the multi-view term.
        mvc_time_dependent: If True, the multi-view weight is scaled by the
            epoch ramp-up factor computed by ``_rampup``.
        mvc_steepness: Steepness of the ramp-up curve used by ``_rampup``.
        modality: Keys of ``output`` for which per-modality losses are computed.
        consistency_weight: Weight of the ensemble-consistency term; ``0.0``
            disables the term.
        consistency_source: The consistency term is only applied when this
            equals ``"ensemble"``; other values disable it.
    """

    def __init__(
        self,
        single_modality_loss,
        multi_view_consistency_loss,
        volume_mask_loss,
        multi_view_consistency_weight: float,
        mvc_time_dependent: bool,
        mvc_steepness: float,
        modality: List,
        consistency_weight: float,
        consistency_source: str,
    ):
        super().__init__()

        self.single_modality_loss = single_modality_loss
        self.multi_view_consistency_loss = multi_view_consistency_loss
        self.volume_mask_loss = volume_mask_loss

        self.mvc_weight = multi_view_consistency_weight
        self.mvc_time_dependent = mvc_time_dependent
        self.mvc_steepness = mvc_steepness
        self.modality = modality
        self.consistency_weight = consistency_weight
        self.consistency_source = consistency_source

    def _rampup(self, epoch: int, max_epoch: int) -> float:
        """Return ``exp(-mvc_steepness * (1 - epoch / max_epoch) ** 2)``.

        The factor equals 1.0 at ``epoch == max_epoch`` and
        ``exp(-mvc_steepness)`` at ``epoch == 0``.
        """
        return math.exp(-self.mvc_steepness * (1 - epoch / max_epoch) ** 2)

    def forward(
        self,
        output: Dict,
        label,
        mask,
        epoch: int = 1,
        max_epoch: int = 70,
        spixel=None,
        raw_image=None,
    ):
        """Compute the combined loss for one batch.

        Args:
            output: Mapping indexed by every entry of ``self.modality``; when
                the consistency term is active each entry must also provide
                ``"out_vol"``.
            label: Passed through to the single-modality and multi-view losses.
            mask: Passed through to the single-modality and multi-view losses.
            epoch: Current epoch, used by the ramp-up factor.
            max_epoch: Epoch at which the ramp-up factor reaches 1.0.
            spixel: Passed through to the multi-view loss.
            raw_image: Passed through to the multi-view loss.

        Returns:
            Dict whose first entry is ``"total_loss"``. It is followed by every
            per-modality entry (``"<key>/<modality>"``), then the multi-view
            entries other than ``"total_loss"`` and ``"tgt_map"``, then
            ``"consistency_loss/<modality>"`` when the consistency term is active.
        """
        total_loss = 0.0
        loss_dict = {}
        for modality in self.modality:
            single_loss = self.single_modality_loss(output[modality], label, mask)
            for k, v in single_loss.items():
                loss_dict[f"{k}/{modality}"] = v
            total_loss = total_loss + single_loss["total_loss"]

        if self.mvc_time_dependent:
            mvc_weight = self.mvc_weight * self._rampup(epoch, max_epoch)
        else:
            mvc_weight = self.mvc_weight

        mvc_loss = self.multi_view_consistency_loss(
            output, label, spixel, raw_image, mask
        )
        # "total_loss" is folded in below and "tgt_map" is an intermediate
        # target, so neither is reported in the returned dict.
        for k, v in mvc_loss.items():
            if k not in ("total_loss", "tgt_map"):
                loss_dict[k] = v

        if self.consistency_weight != 0.0 and self.consistency_source == "ensemble":
            # NOTE(review): the consistency term is always ramped up, even when
            # mvc_time_dependent is False — confirm that is intended.
            ramp = self._rampup(epoch, max_epoch)  # loop-invariant, compute once
            for modality in self.modality:
                consistency_loss = self.volume_mask_loss(
                    output[modality]["out_vol"], mvc_loss["tgt_map"]
                )["loss"]
                loss_dict[f"consistency_loss/{modality}"] = consistency_loss
                total_loss = (
                    total_loss + self.consistency_weight * consistency_loss * ramp
                )

        total_loss = total_loss + mvc_weight * mvc_loss["total_loss"]

        return {"total_loss": total_loss, **loss_dict}