File size: 4,914 Bytes
482ab8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from skimage import segmentation


def get_multi_view_consistency_loss(opt):
    """Factory: build a :class:`MultiViewConsistencyLoss` from parsed options.

    Reads the ``mvc_*`` fields plus ``modality`` from *opt* and forwards them
    positionally to the loss constructor.
    """
    return MultiViewConsistencyLoss(
        opt.mvc_soft,
        opt.mvc_zeros_on_au,
        opt.mvc_single_weight,
        opt.modality,
        opt.mvc_spixel,
        opt.mvc_num_spixel,
    )


class MultiViewConsistencyLoss(nn.Module):
    """Pull every modality's prediction map toward a shared target map.

    The target is a weighted sum of all modalities' ``out_map`` tensors,
    detached from the autograd graph, optionally averaged over superpixel
    regions and/or binarized at 0.5, and each modality is regressed onto it
    with a mean-squared error.
    """

    def __init__(
        self,
        soft: bool,
        zeros_on_au: bool,
        single_weight: Dict,
        modality: List,
        spixel: bool = False,
        num_spixel: int = 100,
        eps: float = 1e-4,
    ):
        """
        Args:
            soft: keep the target map continuous; when False it is binarized
                at a 0.5 threshold.
            zeros_on_au: zero the target map for samples whose label is 0.
            single_weight: per-modality weight, keyed by lower-cased
                modality name.
            modality: ordered list of modality names (keys into ``output``).
            spixel: average the target map within superpixel regions.
            num_spixel: nominal superpixel count (stored; not used in this
                class — presumably consumed by the superpixel generator).
            eps: numerical-stability constant (stored; not used here).
        """
        super().__init__()
        self.soft = soft
        self.zeros_on_au = zeros_on_au
        self.single_weight = single_weight
        self.modality = modality
        self.spixel = spixel
        self.num_spixel = num_spixel
        self.eps = eps

        self.mse_loss = nn.MSELoss(reduction="mean")

    def forward(self, output: Dict, label, spixel=None, image=None, mask=None):
        """Compute the multi-view consistency loss.

        Args:
            output: ``{modality: {"out_map": tensor}}`` predictions.
            label: per-sample binary labels, indexable along the batch dim.
            spixel: superpixel index map; required when ``self.spixel``.
            image, mask: unused here; kept for the ``_save`` debug interface.

        Returns:
            Dict with one loss per modality, the shared ``"tgt_map"``, and
            the summed ``"total_loss"``.
        """
        first_map = output[self.modality[0]]["out_map"]
        tgt_map = torch.zeros_like(first_map, requires_grad=False)
        # Build the target as a weighted sum of every modality's map,
        # detached so it acts as a fixed regression target.
        with torch.no_grad():
            for name in self.modality:
                w = self.single_weight[name.lower()]
                tgt_map = tgt_map + w * output[name]["out_map"]

        if self.spixel:
            # raw_tgt_map = tgt_map.clone()
            tgt_map = get_spixel_tgt_map(tgt_map, spixel)

        if not self.soft:
            # Positive samples whose whole map sits below the threshold
            # would binarize to all zeros; promote their argmax pixels first.
            for idx in range(tgt_map.shape[0]):
                sample = tgt_map[idx, ...]
                if sample.max() <= 0.5 and label[idx] == 1.0:
                    sample[sample == sample.max()] = 1.0
            # Hard 0/1 threshold, then clear every negative-label sample.
            tgt_map[tgt_map > 0.5] = 1
            tgt_map[tgt_map <= 0.5] = 0
            tgt_map[torch.where(label == 0.0)[0], ...] = 0.0

        if self.zeros_on_au:
            tgt_map[torch.where(label == 0.0)[0], ...] = 0.0

        loss_dict = {}
        total_loss = 0.0
        for name in self.modality:
            single = self.mse_loss(output[name]["out_map"], tgt_map)
            loss_dict[f"multi_view_consistency_loss_{name}"] = single
            total_loss = total_loss + single

        return {**loss_dict, "tgt_map": tgt_map, "total_loss": total_loss}

    def _save(
        self,
        spixel: torch.Tensor,
        image: torch.Tensor,
        mask: torch.Tensor,
        tgt_map: torch.Tensor,
        raw_tgt_map: torch.Tensor,
        out_path: str = "tmp/spixel_tgt_map.png",
    ):
        """Write a debug figure (one row per batch element, five panels)."""

        def to_hwc(t):
            # NCHW tensor -> NHWC numpy array, detached from the graph.
            return t.permute(0, 2, 3, 1).detach().cpu().numpy()

        spixel = to_hwc(spixel)
        image = to_hwc(image)
        mask = to_hwc(mask) * 255.0
        tgt_map = to_hwc(tgt_map).squeeze(3) * 255.0
        raw_tgt_map = to_hwc(raw_tgt_map).squeeze(3) * 255.0

        panels = [
            (image, "image"),
            (mask, "mask"),
            (spixel, "superpixel"),
            (raw_tgt_map, "raw target map"),
            (tgt_map, "target map"),
        ]
        n_cols = len(panels)
        bn = spixel.shape[0]
        for row in range(bn):
            for col, (data, title) in enumerate(panels):
                plt.subplot(bn, n_cols, row * n_cols + col + 1)
                plt.imshow(data[row])
                plt.axis("off")
                plt.title(title)
        plt.tight_layout()
        plt.savefig(out_path, dpi=300)
        plt.close()


def get_spixel_tgt_map(weighted_sum, spixel):
    """Pool ``weighted_sum`` into a piecewise-constant map over superpixels.

    Every pixel is replaced by the mean of ``weighted_sum`` over the
    superpixel region it belongs to, independently per batch element.

    Args:
        weighted_sum: float tensor, batch-first (e.g. ``(B, 1, H, W)``).
        spixel: integer superpixel-label tensor whose per-batch slice
            ``spixel[b]`` is shape-compatible with ``weighted_sum[b]`` so
            the two can be masked together.

    Returns:
        Tensor shaped like ``weighted_sum`` holding the per-region averages.
        This is a soft map with ``requires_grad=False``; any thresholding
        is done by the caller.
    """
    spixel_tgt_map = torch.zeros_like(weighted_sum, requires_grad=False)

    for bidx in range(weighted_sum.shape[0]):
        sample = weighted_sum[bidx]
        labels = spixel[bidx]
        pooled = spixel_tgt_map[bidx]  # view: writes land in the output
        for region in labels.unique().tolist():
            # Build the region mask once; the original recomputed the
            # comparison three times per region (area, sum, and write).
            mask = labels == region
            pooled[mask] = sample[mask].sum() / mask.sum()

    return spixel_tgt_map


if __name__ == "__main__":
    # Smoke test. The previous call crashed with a TypeError: the required
    # `modality` argument was missing, and `single_weight` was given a list
    # where a dict keyed by lower-cased modality name is expected.
    mvc_loss = MultiViewConsistencyLoss(
        True,
        True,
        {"rgb": 1.0, "depth": 1.0},
        ["RGB", "Depth"],
    )
    print("a")