File size: 1,123 Bytes
482ab8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn


def get_map_label_loss(opt):
    """Create the ``MapLabelLoss`` selected by an options object.

    Args:
        opt: Options object; only its ``label_loss_on_whole_map`` attribute
            is read here.

    Returns:
        A ``MapLabelLoss`` constructed with that flag.
    """
    whole_map = opt.label_loss_on_whole_map
    return MapLabelLoss(label_loss_on_whole_map=whole_map)


class MapLabelLoss(nn.Module):
    """Binary cross-entropy label loss, optionally supervised through the map.

    With ``label_loss_on_whole_map`` disabled, the loss is the mean BCE between
    ``pred`` and ``label``. With it enabled, samples whose label equals 0
    ("pristine") are scored on the mean of their ``out_map`` instead of on
    ``pred``; every other sample ("modified") is still scored on ``pred``.
    """

    def __init__(self, label_loss_on_whole_map=False):
        """Store the mode flag and build the element-wise BCE criterion.

        Args:
            label_loss_on_whole_map: If true, pristine samples are penalised
                through the mean of ``out_map`` rather than through ``pred``.
        """
        super().__init__()

        # reduction="none" keeps per-element losses; forward() averages them.
        self.bce_loss = nn.BCELoss(reduction="none")
        self.label_loss_on_whole_map = label_loss_on_whole_map

    def forward(self, pred, out_map, label):
        """Compute the label loss for one batch.

        Args:
            pred: Predicted probabilities. In the default mode it must have
                the same shape as ``label`` (a requirement of ``nn.BCELoss``).
                In whole-map mode it must hold exactly one value per sample.
            out_map: Predicted map with the batch on the first axis; only read
                in whole-map mode, where each sample's map is averaged over
                all remaining axes. Values must be valid probabilities.
            label: Binary targets, batch on the first axis. In whole-map mode
                it must hold exactly one value per sample, e.g. shape ``(B,)``
                or ``(B, 1)``.

        Returns:
            ``{"loss": loss}`` where ``loss`` is a scalar tensor.
        """
        if not self.label_loss_on_whole_map:
            return {"loss": self.bce_loss(pred, label).mean()}

        batch_size = label.shape[0]
        # Flatten to one value per sample so (B,) and (B, 1) inputs both line
        # up with the 0-d per-sample map mean (BCELoss rejects size mismatch).
        flat_label = label.reshape(batch_size)
        flat_pred = pred.reshape(batch_size)
        map_mean = out_map.reshape(out_map.shape[0], -1).mean(dim=1)

        # Pristine samples (label == 0) are scored on the whole-map mean,
        # modified ones on the prediction. Selecting before the BCE call keeps
        # gradients flowing only into the branch that was actually used and
        # avoids a Python-level branch (and device sync) per sample.
        score = torch.where(flat_label == 0, map_mean, flat_pred)
        return {"loss": self.bce_loss(score, flat_label).mean()}