File size: 3,897 Bytes
8e5d8c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import torch.nn.functional as F
from segmentation_models_pytorch.utils import base
from segmentation_models_pytorch.base.modules import Activation

class FocalLossFunction(base.Loss):
    """Focal loss for single-channel (binary) or multi-channel logits.

    Each element's cross entropy ``CE`` is scaled by ``alpha_t * (1 - p_t) ** gamma``,
    where ``p_t = exp(-CE)`` is the predicted probability of the true class.

    Args:
        activation: passed to ``Activation`` and applied to ``inputs`` first.
            The loss itself treats the activated inputs as logits (it calls
            ``F.cross_entropy`` / ``F.binary_cross_entropy_with_logits``), so
            this should normally stay ``None``.
        alpha: balancing factor. With a single input channel, positives are
            weighted by ``alpha`` and negatives by ``1 - alpha``. With several
            channels every term is multiplied by ``alpha``.
        gamma: focusing exponent. ``0`` gives plain (alpha-scaled) cross entropy.
        reduction: ``'mean'``, ``'sum'``, or any other value for the
            unreduced per-element loss.
    """

    def __init__(self, activation=None, alpha=0.25, gamma=1.5, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: logits with channels on dim 1.
            targets: for a single input channel, a tensor of the same shape
                as ``inputs`` (values >= 0.5 count as positive). For several
                channels, either one-hot targets with the same number of dims
                as ``inputs`` (reduced with ``argmax`` on dim 1) or integer
                class indices with dim 1 absent.

        Returns:
            A scalar for ``'mean'``/``'sum'``; otherwise the per-element loss
            with the channel dimension removed.
        """
        inputs = self.activation(inputs)

        if inputs.shape[1] == 1:
            # Binary case: the single channel is the positive-class logit.
            # BCE-with-logits is exactly a two-way softmax cross entropy over
            # the logit pair (x, 0), so p_t below is sigmoid(x) for positives.
            # The >= 0.5 threshold matches argmax over (t, 1 - t) with
            # first-index tie breaking.
            positive = (targets >= 0.5).to(inputs.dtype)
            cross_entropy = F.binary_cross_entropy_with_logits(
                inputs, positive, reduction='none').squeeze(1)
            positive = positive.squeeze(1)
            # Class balancing: alpha for positives, 1 - alpha for negatives.
            alpha_factor = self.alpha * positive + (1 - self.alpha) * (1 - positive)
        else:
            if targets.dim() == inputs.dim():  # one-hot -> class indices
                targets = torch.argmax(targets, dim=1)
            cross_entropy = F.cross_entropy(inputs, targets, reduction='none')
            alpha_factor = self.alpha

        # CE = -log(p_t), so exp(-CE) recovers the true-class probability.
        probability = torch.exp(-cross_entropy)
        focal_weight = alpha_factor * (1 - probability) ** self.gamma * cross_entropy

        if self.reduction == 'mean':
            return focal_weight.mean()
        elif self.reduction == 'sum':
            return focal_weight.sum()
        return focal_weight

class TverskyLossFunction(base.Loss):
    """Tversky loss, summed over classes (soft-Dice generalisation).

    For each class ``c`` it computes ``1 - TP / (TP + alpha*FP + beta*FN)`` on
    the flattened batch. ``inputs`` are treated as logits: sigmoid is used for
    a single channel and softmax over dim 1 otherwise.

    Args:
        activation: stored as ``self.activation`` but not applied here,
            because ``forward`` applies its own sigmoid/softmax.
        alpha: weight of false positives.
        beta: weight of false negatives.
        ignore_channels: channel index/indices dropped from both ``inputs``
            and ``targets`` before the loss is computed.
        reduction: ``'mean'`` (average over classes), ``'sum'`` (sum over
            classes), or anything else (see the note in ``forward``).
    """

    def __init__(self, activation=None, alpha=0.5, beta=0.5, ignore_channels=None,
                 reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.beta = beta
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the Tversky loss.

        Args:
            inputs: logits with channels on dim 1.
            targets: ground truth aligned with ``inputs``. For multiple
                channels it is indexed as ``targets[:, c]``. For a single
                channel it is flattened whole.
        """
        if self.ignore_channels is not None:
            keep = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
            keep[self.ignore_channels] = False
            # Targets carrying the same channel layout must be filtered with
            # the same mask, otherwise targets[:, c] no longer lines up with
            # the c-th remaining prediction channel.
            if targets.dim() == inputs.dim() and targets.shape[1] == keep.numel():
                targets = targets[:, keep.to(targets.device), ...]
            inputs = inputs[:, keep, ...]

        num_classes = inputs.shape[1]
        if num_classes == 1:
            probs = torch.sigmoid(inputs)
            class_pairs = [(probs.reshape(-1), targets.reshape(-1))]
        else:
            probs = F.softmax(inputs, dim=1)
            class_pairs = [
                (probs[:, c].reshape(-1), targets[:, c].reshape(-1))
                for c in range(num_classes)
            ]

        tversky_loss = 0
        for flat_inputs, flat_targets in class_pairs:
            intersection = (flat_inputs * flat_targets).sum()
            fps = ((1 - flat_targets) * flat_inputs).sum()
            fns = (flat_targets * (1 - flat_inputs)).sum()
            # Small epsilon keeps the ratio finite when a class is absent
            # from both prediction and target.
            denominator = intersection + self.alpha * fps + self.beta * fns + 1e-10
            tversky_loss += 1 - intersection / denominator

        if self.reduction == 'mean':
            return tversky_loss / num_classes
        elif self.reduction == 'sum':
            return tversky_loss
        # NOTE(review): the loss is already computed over the whole flattened
        # batch, so dividing by the batch size here is unusual. It is kept for
        # backward compatibility; confirm the intended semantics.
        return tversky_loss / inputs.shape[0]

class EnhancedCrossEntropy(base.Loss):
    """Cross entropy that accepts one-hot targets and can ignore channels.

    Args:
        activation: passed to ``Activation`` and applied to ``inputs`` before
            ``F.cross_entropy``. That function applies log-softmax itself, so
            the activated inputs should still be logits (normally ``None``).
        ignore_channels: channel index/indices removed from ``inputs``.
            Target pixels whose class is one of these channels are excluded
            from the loss.
        reduction: forwarded unchanged to ``F.cross_entropy``.
    """

    # Default ignore_index of F.cross_entropy.
    _IGNORE_INDEX = -100

    def __init__(self, activation=None, ignore_channels=None, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute cross entropy.

        Args:
            inputs: logits with channels on dim 1.
            targets: either 4-D one-hot targets (reduced with ``argmax`` on
                dim 1) or integer class indices in the original channel
                numbering.
        """
        inputs = self.activation(inputs)

        if targets.dim() == 4:  # Convert one-hot to class indices
            targets = torch.argmax(targets, dim=1)

        if self.ignore_channels is not None:
            num_channels = inputs.shape[1]
            keep = torch.ones(num_channels, dtype=torch.bool, device=inputs.device)
            keep[self.ignore_channels] = False
            inputs = inputs[:, keep, ...]

            # Dropping channels shifts class ids, so map each original id onto
            # its compacted position. Ignored classes map to ignore_index.
            remap = torch.full((num_channels,), self._IGNORE_INDEX,
                               dtype=torch.long, device=targets.device)
            keep_t = keep.to(targets.device)
            remap[keep_t] = torch.arange(int(keep_t.sum()), device=targets.device)
            # Preserve any pre-existing negative (ignored) labels as-is.
            valid = targets >= 0
            targets = torch.where(valid, remap[targets.clamp(min=0)], targets)

        return F.cross_entropy(inputs, targets, reduction=self.reduction)