osbm committed on
Commit
e3bc7f9
1 Parent(s): ab3fa5a

Upload 3 files

main.py ADDED
@@ -0,0 +1,13 @@
+ import gradio as gr
+
+ def predict(img):
+     return "cat"
+
+
+ demo = gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="pil"),
+     outputs=gr.Label(num_top_classes=3),
+ )
+
+ demo.launch(server_name="0.0.0.0", server_port=7860)
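Not part of the commit, but worth noting: gr.Label with num_top_classes=3 displays a top-3 ranking only when the function returns a mapping from class names to confidences, which the plain-string stub above does not exercise. A minimal sketch (the class names and scores are made up):

def predict(img):
    # gr.Label accepts a dict of label -> confidence;
    # with num_top_classes=3 it shows the three highest-scoring entries.
    return {"cat": 0.7, "dog": 0.2, "bird": 0.1}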
nnUNetTrainerV2_Loss_FL_and_CE.py ADDED
@@ -0,0 +1,77 @@
+ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from torch import nn
+
+ from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+ from nnunet.training.loss_functions.crossentropy import RobustCrossEntropyLoss
+ from nnunet.training.network_training.nnUNet_variants.loss_function.nnUNetTrainerV2_focalLoss import FocalLoss
+ # TODO: replace FocalLoss with a fixed implementation (and set smooth=0 in that one?)
+
+
+ class FL_and_CE_loss(nn.Module):
+     def __init__(self, fl_kwargs=None, ce_kwargs=None, alpha=0.5, aggregate="sum"):
+         super(FL_and_CE_loss, self).__init__()
+         if fl_kwargs is None:
+             fl_kwargs = {}
+         if ce_kwargs is None:
+             ce_kwargs = {}
+
+         self.aggregate = aggregate
+         self.fl = FocalLoss(apply_nonlin=nn.Softmax(dim=1), **fl_kwargs)
+         self.ce = RobustCrossEntropyLoss(**ce_kwargs)
+         self.alpha = alpha
+
+     def forward(self, net_output, target):
+         fl_loss = self.fl(net_output, target)
+         ce_loss = self.ce(net_output, target)
+         if self.aggregate == "sum":
+             result = self.alpha * fl_loss + (1 - self.alpha) * ce_loss
+         else:
+             raise NotImplementedError(f"aggregate mode '{self.aggregate}' is not implemented")
+         return result
+
+
+ class nnUNetTrainerV2_Loss_FL_and_CE_checkpoints(nnUNetTrainerV2):
+     """
+     Set the loss to FL + CE and keep periodic checkpoints.
+     """
+     def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+                  unpack_data=True, deterministic=True, fp16=False):
+         super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+                          deterministic, fp16)
+         self.loss = FL_and_CE_loss(alpha=0.5)
+         self.save_latest_only = False
+
+
+ class nnUNetTrainerV2_Loss_FL_and_CE_checkpoints2(nnUNetTrainerV2_Loss_FL_and_CE_checkpoints):
+     """
+     Each run is stored in a folder named after the trainer class. This subclass simply creates a new folder,
+     to allow investigating the variability between restarts.
+     """
+     def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+                  unpack_data=True, deterministic=True, fp16=False):
+         super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+                          deterministic, fp16)
+
+
+ class nnUNetTrainerV2_Loss_FL_and_CE_checkpoints3(nnUNetTrainerV2_Loss_FL_and_CE_checkpoints):
+     """
+     Each run is stored in a folder named after the trainer class. This subclass simply creates a new folder,
+     to allow investigating the variability between restarts.
+     """
+     def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+                  unpack_data=True, deterministic=True, fp16=False):
+         super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+                          deterministic, fp16)
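For intuition on the aggregation in FL_and_CE_loss.forward (alpha * FL + (1 - alpha) * CE), here is a self-contained sketch using plain PyTorch losses as stand-ins for the nnU-Net classes; the crude focal term and the toy tensor shapes are assumptions for illustration only, not the committed implementation:

import torch
from torch import nn

alpha, gamma = 0.5, 2.0
logits = torch.randn(4, 2, 8, 8)            # toy N, C, H, W segmentation output
target = torch.randint(0, 2, (4, 8, 8))     # integer class map

ce = nn.CrossEntropyLoss()(logits, target)

# crude focal term, not the FocalLoss class above: (1 - pt)^gamma * (-log pt)
pt = torch.softmax(logits, dim=1).gather(1, target.unsqueeze(1)).clamp_min(1e-8)
fl = ((1 - pt) ** gamma * -pt.log()).mean()

loss = alpha * fl + (1 - alpha) * ce        # same weighting as aggregate="sum"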
nnUNetTrainerV2_focalLoss.py ADDED
@@ -0,0 +1,140 @@
+ # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+ class FocalLoss(nn.Module):
+     """
+     Copied from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
+     This is an implementation of Focal Loss with label-smoothed cross entropy support, as proposed in
+     'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002):
+     Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)
+     :param num_class:
+     :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
+     :param gamma: (float, double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5),
+         putting more focus on hard, misclassified examples
+     :param smooth: (float, double) smoothing value for the label-smoothed cross entropy
+     :param balance_index: (int) index of the class to balance; should be specified when alpha is a float
+     :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
+     """
+
+     def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
+         super(FocalLoss, self).__init__()
+         self.apply_nonlin = apply_nonlin
+         self.alpha = alpha
+         self.gamma = gamma
+         self.balance_index = balance_index
+         self.smooth = smooth
+         self.size_average = size_average
+
+         if self.smooth is not None:
+             if self.smooth < 0 or self.smooth > 1.0:
+                 raise ValueError('smooth value should be in [0,1]')
+
+     def forward(self, logit, target):
+         if self.apply_nonlin is not None:
+             logit = self.apply_nonlin(logit)
+         num_class = logit.shape[1]
+
+         if logit.dim() > 2:
+             # flatten spatial dimensions: N,C,d1,d2,... -> N*m,C (m = d1*d2*...)
+             logit = logit.view(logit.size(0), logit.size(1), -1)
+             logit = logit.permute(0, 2, 1).contiguous()
+             logit = logit.view(-1, logit.size(-1))
+         target = torch.squeeze(target, 1)
+         target = target.view(-1, 1)
+         # print(logit.shape, target.shape)
+
+         alpha = self.alpha
+
+         if alpha is None:
+             alpha = torch.ones(num_class, 1)
+         elif isinstance(alpha, (list, np.ndarray)):
+             assert len(alpha) == num_class
+             alpha = torch.FloatTensor(alpha).view(num_class, 1)
+             alpha = alpha / alpha.sum()
+         elif isinstance(alpha, float):
+             alpha = torch.ones(num_class, 1)
+             alpha = alpha * (1 - self.alpha)
+             alpha[self.balance_index] = self.alpha
+         else:
+             raise TypeError(f'Unsupported alpha type: {type(alpha)}')
+
+         if alpha.device != logit.device:
+             alpha = alpha.to(logit.device)
+
+         idx = target.cpu().long()
+
+         one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
+         one_hot_key = one_hot_key.scatter_(1, idx, 1)
+         if one_hot_key.device != logit.device:
+             one_hot_key = one_hot_key.to(logit.device)
+
+         if self.smooth:
+             one_hot_key = torch.clamp(
+                 one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
+         pt = (one_hot_key * logit).sum(1) + self.smooth
+         logpt = pt.log()
+
+         gamma = self.gamma
+
+         alpha = alpha[idx]
+         alpha = torch.squeeze(alpha)
+         loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
+
+         if self.size_average:
+             loss = loss.mean()
+         else:
+             loss = loss.sum()
+
+         return loss
+
+
+ class nnUNetTrainerV2_focalLossAlpha75(nnUNetTrainerV2):
+     def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+                  unpack_data=True, deterministic=True, fp16=False):
+         super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+                          deterministic, fp16)
+         print("Setting up FocalLoss(alpha=[0.75, 0.25], apply_nonlin=nn.Softmax(dim=1))")
+         self.loss = FocalLoss(alpha=[0.75, 0.25], apply_nonlin=nn.Softmax(dim=1))
+
+
+ class nnUNetTrainerV2_focalLossAlpha75_checkpoints(nnUNetTrainerV2_focalLossAlpha75):
+     def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+                  unpack_data=True, deterministic=True, fp16=False):
+         super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+                          deterministic, fp16)
+         print("Saving a checkpoint every 50th epoch")
+         self.save_latest_only = False
+
+
+ class nnUNetTrainerV2_focalLossAlpha75_checkpoints2(nnUNetTrainerV2_focalLossAlpha75_checkpoints):
+     def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+                  unpack_data=True, deterministic=True, fp16=False):
+         super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+                          deterministic, fp16)
+         pass  # this is just to get a new trainer output directory
+
+
+ class nnUNetTrainerV2_focalLossAlpha75_checkpoints3(nnUNetTrainerV2_focalLossAlpha75_checkpoints):
+     def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+                  unpack_data=True, deterministic=True, fp16=False):
+         super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+                          deterministic, fp16)
+         pass  # this is just to get a new trainer output directory
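To see how gamma in the focal term above down-weights well-classified examples, a quick sanity check of -alpha * (1 - pt)^gamma * log(pt) against plain cross entropy; alpha = 0.25 and gamma = 2 are the reference values from the focal loss paper, not the settings used by the trainers in this commit:

import torch

alpha, gamma = 0.25, 2.0
pt = torch.tensor([0.1, 0.5, 0.9, 0.99])   # predicted probability of the true class
focal = -alpha * (1 - pt) ** gamma * pt.log()
ce = -pt.log()
# ratio ~ 0.20, 0.06, 0.0025, 0.000025: confident correct predictions barely contribute
print(focal / ce)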