osbm committed on
Commit
7c0804b
1 Parent(s): 82b5429

Create nnUNetTrainerV2_focalLoss.py

Files changed (1)
  1. nnUNetTrainerV2_focalLoss.py +140 -0
nnUNetTrainerV2_focalLoss.py ADDED
@@ -0,0 +1,140 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+import torch
+import torch.nn as nn
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class FocalLoss(nn.Module):
+    """
+    Copied from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
+    This is an implementation of Focal Loss with label-smoothed cross entropy support, as proposed in
+    'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002):
+        Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)
+    :param apply_nonlin: (callable, optional) non-linearity applied to the network output before the loss
+        (e.g. softmax over the class dimension)
+    :param alpha: (list/ndarray of length num_class, or float) per-class weighting factor; if a float is given,
+        the class at balance_index receives alpha and every other class receives (1 - alpha)
+    :param gamma: (float, double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5),
+        putting more focus on hard, misclassified examples
+    :param smooth: (float, double) label-smoothing value used in the cross entropy
+    :param balance_index: (int) index of the class to balance; must be specified when alpha is a float
+    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
+    """
+
+    def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
+        super(FocalLoss, self).__init__()
+        self.apply_nonlin = apply_nonlin
+        self.alpha = alpha
+        self.gamma = gamma
+        self.balance_index = balance_index
+        self.smooth = smooth
+        self.size_average = size_average
+
+        if self.smooth is not None:
+            if self.smooth < 0 or self.smooth > 1.0:
+                raise ValueError('smooth value should be in [0,1]')
+
+    def forward(self, logit, target):
+        if self.apply_nonlin is not None:
+            logit = self.apply_nonlin(logit)
+        num_class = logit.shape[1]
+
+        if logit.dim() > 2:
+            # flatten spatial dimensions N,C,d1,d2 -> N,C,m (m=d1*d2*...)
+            logit = logit.view(logit.size(0), logit.size(1), -1)
+            logit = logit.permute(0, 2, 1).contiguous()
+            logit = logit.view(-1, logit.size(-1))
+        target = torch.squeeze(target, 1)
+        target = target.view(-1, 1)
+        # print(logit.shape, target.shape)
+
+        alpha = self.alpha
+
+        if alpha is None:
+            alpha = torch.ones(num_class, 1)
+        elif isinstance(alpha, (list, np.ndarray)):
+            assert len(alpha) == num_class
+            alpha = torch.FloatTensor(alpha).view(num_class, 1)
+            alpha = alpha / alpha.sum()
+        elif isinstance(alpha, float):
+            alpha = torch.ones(num_class, 1)
+            alpha = alpha * (1 - self.alpha)
+            alpha[self.balance_index] = self.alpha
+        else:
+            raise TypeError(f'Unsupported alpha type: {type(alpha)}')
+
+        if alpha.device != logit.device:
+            alpha = alpha.to(logit.device)
+
+        idx = target.cpu().long()
+
+        one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
+        one_hot_key = one_hot_key.scatter_(1, idx, 1)
+        if one_hot_key.device != logit.device:
+            one_hot_key = one_hot_key.to(logit.device)
+
+        if self.smooth:
+            one_hot_key = torch.clamp(
+                one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
+        pt = (one_hot_key * logit).sum(1) + self.smooth
+        logpt = pt.log()
+
+        gamma = self.gamma
+
+        alpha = alpha[idx]
+        alpha = torch.squeeze(alpha)
+        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
+
+        if self.size_average:
+            loss = loss.mean()
+        else:
+            loss = loss.sum()
+
+        return loss
+
+
+class nnUNetTrainerV2_focalLossAlpha75(nnUNetTrainerV2):
+    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+                 unpack_data=True, deterministic=True, fp16=False):
+        super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+                         deterministic, fp16)
+        print("Setting up FocalLoss(alpha=[0.75, 0.25], apply_nonlin=nn.Softmax())")
+        self.loss = FocalLoss(alpha=[0.75, 0.25], apply_nonlin=nn.Softmax())
+
+
+class nnUNetTrainerV2_focalLossAlpha75_checkpoints(nnUNetTrainerV2_focalLossAlpha75):
+    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+                 unpack_data=True, deterministic=True, fp16=False):
+        super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+                         deterministic, fp16)
+        print("Saving checkpoint every 50th epoch")
+        self.save_latest_only = False
+
+
+class nnUNetTrainerV2_focalLossAlpha75_checkpoints2(nnUNetTrainerV2_focalLossAlpha75_checkpoints):
+    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+                 unpack_data=True, deterministic=True, fp16=False):
+        super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+                         deterministic, fp16)
+        pass  # this is just to get a new Trainer directory
+
+
+class nnUNetTrainerV2_focalLossAlpha75_checkpoints3(nnUNetTrainerV2_focalLossAlpha75_checkpoints):
+    def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+                 unpack_data=True, deterministic=True, fp16=False):
+        super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+                         deterministic, fp16)
+        pass  # this is just to get a new Trainer directory
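As a side note (not part of the commit): a minimal standalone sanity check of the FocalLoss module added above could look like the sketch below. The tensor shapes, the explicit nn.Softmax(dim=1), and the variable names are assumptions chosen to mirror how the trainer classes use the loss (binary segmentation with alpha=[0.75, 0.25]); during actual training, nnU-Net calls the loss internally.

    # Hypothetical sanity check, not part of the committed file.
    # Assumes the FocalLoss class defined above is in scope (e.g. imported from this module).
    import torch
    import torch.nn as nn

    loss_fn = FocalLoss(alpha=[0.75, 0.25], apply_nonlin=nn.Softmax(dim=1))

    logits = torch.randn(2, 2, 8, 8, 8)            # network output: N, C, D, H, W
    target = torch.randint(0, 2, (2, 1, 8, 8, 8))  # integer labels:  N, 1, D, H, W

    value = loss_fn(logits, target)
    print(value.item())  # scalar, averaged over all voxels since size_average=True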