osbm committed on
Commit
82b5429
1 Parent(s): 0ff9694

Create nnUNetTrainerV2_Loss_FL_and_CE.py

Files changed (1)
  1. nnUNetTrainerV2_Loss_FL_and_CE.py +77 -0
nnUNetTrainerV2_Loss_FL_and_CE.py ADDED
@@ -0,0 +1,77 @@
+ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from torch import nn
+
+ from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+ from nnunet.training.loss_functions.crossentropy import RobustCrossEntropyLoss
+ from nnunet.training.network_training.nnUNet_variants.loss_function.nnUNetTrainerV2_focalLoss import FocalLoss
+ # TODO: replace FocalLoss with a fixed implementation (and set smooth=0 in that one?)
+
+
+ class FL_and_CE_loss(nn.Module):
+     def __init__(self, fl_kwargs=None, ce_kwargs=None, alpha=0.5, aggregate="sum"):
+         super(FL_and_CE_loss, self).__init__()
+         if fl_kwargs is None:
+             fl_kwargs = {}
+         if ce_kwargs is None:
+             ce_kwargs = {}
+
+         self.aggregate = aggregate
+         self.fl = FocalLoss(apply_nonlin=nn.Softmax(dim=1), **fl_kwargs)  # softmax over the class dimension
+         self.ce = RobustCrossEntropyLoss(**ce_kwargs)
+         self.alpha = alpha  # weight of the focal loss term; CE gets (1 - alpha)
+
+     def forward(self, net_output, target):
+         fl_loss = self.fl(net_output, target)
+         ce_loss = self.ce(net_output, target)
+         if self.aggregate == "sum":
+             result = self.alpha * fl_loss + (1 - self.alpha) * ce_loss
+         else:
+             raise NotImplementedError("aggregate must be 'sum'")
+         return result
+
+
+ class nnUNetTrainerV2_Loss_FL_and_CE_checkpoints(nnUNetTrainerV2):
+     """
+     Set the loss to FL + CE and store intermediate checkpoints
+     """
+     def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+                  unpack_data=True, deterministic=True, fp16=False):
+         super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+                          deterministic, fp16)
+         self.loss = FL_and_CE_loss(alpha=0.5)
+         self.save_latest_only = False  # keep periodic checkpoints instead of only the latest one
+
+
+ class nnUNetTrainerV2_Loss_FL_and_CE_checkpoints2(nnUNetTrainerV2_Loss_FL_and_CE_checkpoints):
+     """
+     Each run is stored in a folder named after the trainer class. This subclass simply creates a new folder,
+     allowing the variability between restarts to be investigated.
+     """
+     def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+                  unpack_data=True, deterministic=True, fp16=False):
+         super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+                          deterministic, fp16)
+
+
+ class nnUNetTrainerV2_Loss_FL_and_CE_checkpoints3(nnUNetTrainerV2_Loss_FL_and_CE_checkpoints):
+     """
+     Each run is stored in a folder named after the trainer class. This subclass simply creates a new folder,
+     allowing the variability between restarts to be investigated.
+     """
+     def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+                  unpack_data=True, deterministic=True, fp16=False):
+         super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+                          deterministic, fp16)
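
For reference, a minimal usage sketch (not part of the commit) showing how the combined loss is applied. It assumes nnU-Net and PyTorch are installed; the tensor sizes are arbitrary, chosen only to match nnU-Net's convention of (batch, classes, x, y) logits and (batch, 1, x, y) integer labels.

# Minimal sketch, assuming the file above is importable; the dummy shapes
# below are illustrative, not taken from the commit.
import torch

loss_fn = FL_and_CE_loss(alpha=0.5)            # 0.5 * FL + 0.5 * CE, i.e. the plain average
net_output = torch.randn(2, 3, 32, 32)         # raw logits for 3 classes
target = torch.randint(0, 3, (2, 1, 32, 32))   # integer class labels
combined = loss_fn(net_output, target)         # scalar: alpha * FL + (1 - alpha) * CE

Because the "sum" aggregation is a convex combination, alpha=1.0 recovers pure focal loss and alpha=0.0 pure cross-entropy; the trainer classes above fix alpha at 0.5.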