Margerie commited on
Commit
66427d3
1 Parent(s): 267d901

Upload 2 files

Browse files
Files changed (2) hide show
  1. best_metric_dose_model.pth +3 -0
  2. modular_hdunet.py +398 -0
best_metric_dose_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d505acd7831b764949f0f6b00c9a1fbd5df25d93a3dbc1bc69771d36c6225704
3
+ size 17616673
modular_hdunet.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import LazyConv3d, MaxPool3d, BatchNorm3d
5
+
6
+ from torch.nn.modules import Module
7
+ from torch.nn.modules import ReLU
8
+ from torch.nn.modules.dropout import Dropout
9
+ from torch.nn.modules.instancenorm import InstanceNorm3d
10
+ from custom_modules import LazyConvDropoutNormNonlinCat, ModularConvLayers, LazyConvBottleneckLayer
11
+
12
+
13
+ class modular_hdunet_encoder(Module):
14
+ """HDUnet encoder with modular parameters
15
+ """
16
+
17
+ def __init__(self, base_num_filter, num_blocks_per_stage, num_stages, pool_kernel_sizes, conv_kernel_sizes,
18
+ padding='same', conv_type: Module = LazyConvDropoutNormNonlinCat, norm_type: Module = InstanceNorm3d,
19
+ dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, pooling_type: Module = MaxPool3d,
20
+ pooling_kernel_size=(2, 2, 2), nonlin: Module = ReLU):
21
+ """Object creation
22
+
23
+ :param base_num_filter: base number of filters (output channels).
24
+ :param num_blocks_per_stage: number of convolutional block per stage (can be different for each stage).
25
+ :param num_stages: number of stages.
26
+ :param pool_kernel_sizes: last conv layer is strided => we use this parameter to set its kernel size and stride
27
+ (can be different for each stage).
28
+ Please note that this parameter is retrieved in our modular decoder and used as the scale factor (upsampling).
29
+ :param conv_kernel_sizes: kernel size (can be different for each stage).
30
+ :param padding: padding used, default is 'same'.
31
+ :param conv_type: type of convolution used, default is a lazy convolution using:
32
+ - dropout;
33
+ - normalization;
34
+ - nonlinear activation function;
35
+ - concatenation.
36
+ Must be a torch Module (should be a custom Module).
37
+ :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
38
+ :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
39
+ :param dropout_rate: dropout rate used by dropout, default is 0.
40
+ :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
41
+ :param pooling_type: type of pooling used, default is 3D max pooling. Must be a torch Module.
42
+ :param pooling_kernel_size: kernel size of the pooling layer, default is (2, 2, 2).
43
+ :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
44
+ """
45
+ super(modular_hdunet_encoder, self).__init__()
46
+ self.base_num_filter = base_num_filter
47
+ self.num_blocks_per_stage = num_blocks_per_stage
48
+ self.num_stages = num_stages
49
+ self.pool_kernel_sizes = pool_kernel_sizes
50
+ self.conv_kernel_sizes = conv_kernel_sizes
51
+ self.padding = padding
52
+ self.conv_type = conv_type
53
+ self.norm_type = norm_type
54
+ self.dropout_type = dropout_type
55
+ self.dropout_rate = dropout_rate
56
+ self.nonlin = nonlin
57
+ self.expansion_rate = expansion_rate
58
+ self.pooling_type = pooling_type
59
+ self.pooling_kernel_size = pooling_kernel_size
60
+
61
+ self.stages = []
62
+ self.pooling_stages = []
63
+ self.end_stages = []
64
+ self.stage_output_features = []
65
+ self.stage_pool_kernel_size = []
66
+ self.stage_conv_kernel_size = []
67
+
68
+ assert len(pool_kernel_sizes) == len(conv_kernel_sizes) == num_stages
69
+
70
+ if not isinstance(num_blocks_per_stage, (list, tuple)):
71
+ num_blocks_per_stage = [num_blocks_per_stage] * num_stages
72
+ else:
73
+ assert len(num_blocks_per_stage) == num_stages
74
+
75
+ self.num_blocks_per_stage = num_blocks_per_stage
76
+
77
+ current_out_channels = 0
78
+ # This is where we manage the number of steps
79
+ for stage in range(num_stages):
80
+ current_out_channels = np.round((expansion_rate ** stage) * self.base_num_filter)
81
+ current_num_blocks_per_stage = num_blocks_per_stage[stage]
82
+ current_pool_kernel_size = pool_kernel_sizes[stage]
83
+ current_kernel_size = conv_kernel_sizes[stage]
84
+
85
+ current_stage = ModularConvLayers(output_channels=current_out_channels,
86
+ num_conv_layers=current_num_blocks_per_stage,
87
+ kernel_size=current_kernel_size, padding=padding, conv_type=conv_type,
88
+ norm_type=norm_type, dropout_type=dropout_type, dropout_rate=dropout_rate,
89
+ nonlin=self.nonlin)
90
+
91
+ self.pooling_stages.append(pooling_type(kernel_size=current_pool_kernel_size))
92
+
93
+ # BatchNorm3d added statically here (to be similar to the original model)
94
+ current_end_stage = nn.Sequential(
95
+ LazyConv3d(out_channels=current_out_channels, kernel_size=current_pool_kernel_size,
96
+ stride=current_pool_kernel_size, padding=0), nonlin(), BatchNorm3d(current_out_channels)
97
+ )
98
+
99
+ self.stages.append(current_stage)
100
+ self.end_stages.append(current_end_stage)
101
+ self.stage_output_features.append(current_out_channels)
102
+ self.stage_pool_kernel_size.append(current_pool_kernel_size)
103
+ self.stage_conv_kernel_size.append(current_kernel_size)
104
+
105
+ self.stages = nn.ModuleList(self.stages)
106
+ self.pooling_stages = nn.ModuleList(self.pooling_stages)
107
+ self.end_stages = nn.ModuleList(self.end_stages)
108
+ self.output_features = current_out_channels
109
+ #self.features_reduction = nn.Conv1d(current_out_channels, current_out_channels//2, 3, stride=2)
110
+ def forward(self, x):
111
+ """Forward inputs through the layer
112
+
113
+ :param x: the input to forward.
114
+ :return: an array containing the results of the input at each stage of the down-sampling (before concatenation)
115
+ which will be used in the decoder later on. The last value of the array is the very last value provided by the
116
+ encoder (after concatenation) and will be used in the bottleneck. Therefore, provided x is the number of stages
117
+ there are x + 1 values in the array.
118
+ """
119
+ skips = []
120
+
121
+ for i, stage in enumerate(self.stages):
122
+ x = stage(x)
123
+ buff = self.pooling_stages[i](x)
124
+ tmp = self.end_stages[i](x)
125
+ skips.append(x)
126
+ x = torch.cat([tmp, buff], dim=1)
127
+ skips.append(x)
128
+ # skips[-1]=self.features_reduction(skips[-1])
129
+ return skips
130
+
131
+
132
+ class modular_hdunet_bottleneck(Module):
133
+ """HDUnet bottleneck with modular parameters
134
+ """
135
+
136
+ def __init__(self, base_num_filter, num_stages, conv_kernel_sizes, padding='same', num_steps_bottleneck=4,
137
+ conv_type: Module = LazyConvBottleneckLayer, norm_type: Module = InstanceNorm3d,
138
+ dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, nonlin: Module = ReLU):
139
+ """Object creation
140
+
141
+ :param base_num_filter: base number of filters (output channels).
142
+ :param num_stages: number of stages of the encoder.
143
+ :param conv_kernel_sizes: kernel size (can be different for each stage).
144
+ :param padding: padding used, default is 'same'.
145
+ :param num_steps_bottleneck: number of steps in the bottleneck, default is 4.
146
+ :param conv_type: type of convolution used, default is a lazy convolution using:
147
+ - dropout;
148
+ - normalization;
149
+ - nonlinear activation function.
150
+ Must be a torch Module (should be a custom Module).
151
+ :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
152
+ :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
153
+ :param dropout_rate: dropout rate used by dropout, default is 0.
154
+ :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
155
+ :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
156
+ """
157
+ super(modular_hdunet_bottleneck, self).__init__()
158
+ self.base_num_filter = base_num_filter
159
+ self.conv_kernel_sizes = conv_kernel_sizes
160
+ self.padding = padding
161
+ self.num_steps_bottleneck = num_steps_bottleneck
162
+ self.conv_type = conv_type
163
+ self.norm_type = norm_type
164
+ self.dropout_type = dropout_type
165
+ self.dropout_rate = dropout_rate
166
+ self.expansion_rate = expansion_rate
167
+ self.nonlin = nonlin
168
+
169
+ encoder_output_features = (expansion_rate ** num_stages * base_num_filter)
170
+
171
+ self.stages = []
172
+ self.step_conv_kernel_size = []
173
+
174
+ assert len(conv_kernel_sizes) == num_steps_bottleneck
175
+
176
+ # This is where we manage the number of steps
177
+ for step in range(num_steps_bottleneck):
178
+ current_kernel_size = conv_kernel_sizes[step]
179
+ self.stages.append(
180
+ conv_type(output_channels=encoder_output_features, kernel_size=current_kernel_size, padding=padding,
181
+ norm_type=norm_type, dropout_type=dropout_type,
182
+ dropout_rate=dropout_rate, nonlin=self.nonlin)
183
+ )
184
+
185
+ self.stages = nn.ModuleList(self.stages)
186
+
187
+
188
+ def forward(self, x):
189
+ """Forward inputs through the layer
190
+
191
+ :param x: the input to forward. At each step the input is concatenated with
192
+ its result in order to produce the input of the next bottleneck layer.
193
+ :return: the input forwarded through the layer.
194
+ """
195
+ for stage in self.stages:
196
+ buff = stage(x)
197
+ x = torch.cat([buff, x], dim=1)
198
+ return x
199
+
200
+
201
+ class modular_hdunet_decoder(Module):
202
+ """HDUnet decoder with modular parameters
203
+ """
204
+
205
+ def __init__(self, previous, base_num_filter, num_blocks_per_stage=None, padding='same',
206
+ conv_type: Module = LazyConvDropoutNormNonlinCat, norm_type: Module = InstanceNorm3d,
207
+ dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, nonlin: Module = ReLU):
208
+ """Object creation
209
+
210
+ :param previous: the encoder which was previously used in the model. It is useful to retrieve some information
211
+ that do not change such as the number of stages or the kernel sizes of each stages per example.
212
+ :param base_num_filter: base number of filters (output channels).
213
+ :param num_blocks_per_stage: number of convolutional block per stage (can be different for each stage).
214
+ If set to None, it will be same than the encoder (reversed).
215
+ :param padding: padding used, default is 'same'.
216
+ :param conv_type: type of convolution used, default is a lazy convolution using:
217
+ - dropout;
218
+ - normalization;
219
+ - nonlinear activation function;
220
+ - concatenation.
221
+ Must be a torch Module (should be a custom Module).
222
+ :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
223
+ :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
224
+ :param dropout_rate: dropout rate used by dropout, default is 0.
225
+ :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
226
+ :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
227
+ """
228
+ super(modular_hdunet_decoder, self).__init__()
229
+ self.base_num_filter = base_num_filter
230
+ self.num_blocks_per_stage = num_blocks_per_stage
231
+ self.padding = padding
232
+ self.conv_type = conv_type
233
+ self.norm_type = norm_type
234
+ self.dropout_type = dropout_type
235
+ self.dropout_rate = dropout_rate
236
+ self.expansion_rate = expansion_rate
237
+ self.nonlin = nonlin
238
+
239
+ # We had to provide the skips using the set function since we are using Lazy layer and torchsummary does not
240
+ # allow us to use an array as a parameter for the forward function.
241
+ self.skips = []
242
+
243
+ # We retrieve the 'architectural' information that were provided to the encoder
244
+ # in order to have a consistent decoder
245
+ previous_stages = previous.stages
246
+ previous_stage_output_features = previous.stage_output_features
247
+ previous_stage_pool_kernel_size = previous.stage_pool_kernel_size
248
+ previous_stage_conv_kernel_size = previous.stage_conv_kernel_size
249
+
250
+ # We have the same as the first stage given that bottleneck is done separately
251
+ self.num_stages = len(previous_stages)
252
+
253
+ # If num_blocks_per_stage is set to None, it will be same than the encoder (reversed).
254
+ if num_blocks_per_stage is None:
255
+ self.num_blocks_per_stage = previous.num_blocks_per_stage[:][::-1]
256
+
257
+ if not isinstance(self.num_blocks_per_stage, (list, tuple)):
258
+ self.num_blocks_per_stage = [self.num_blocks_per_stage] * self.num_stages
259
+ else:
260
+ assert len(self.num_blocks_per_stage) == self.num_stages
261
+
262
+ # There should be the same number of stages since we are doing the bottleneck and the encoder parts separately
263
+ assert len(self.num_blocks_per_stage) == len(previous.num_blocks_per_stage)
264
+
265
+ self.stage_output_features = previous_stage_output_features
266
+ self.stage_pool_kernel_size = previous_stage_pool_kernel_size[::-1]
267
+ self.stage_conv_kernel_size = previous_stage_conv_kernel_size[::-1]
268
+
269
+ self.stages = []
270
+
271
+ number_half_layer = self.num_stages + 1
272
+ # This is where we manage the number of steps
273
+ for stage in range(self.num_stages):
274
+ current_out_channels = np.round(
275
+ (expansion_rate ** (2 * number_half_layer - (stage + number_half_layer) - 1)) * self.base_num_filter)
276
+ current_num_blocks_per_stage = self.num_blocks_per_stage[stage]
277
+ current_pool_kernel_size = self.stage_pool_kernel_size[stage]
278
+ current_kernel_size = self.stage_conv_kernel_size[stage]
279
+ self.stages.append(
280
+ ModularConvLayers(output_channels=current_out_channels, kernel_size=current_kernel_size,
281
+ padding=padding, pool_size=current_pool_kernel_size, conv_type=conv_type,
282
+ norm_type=norm_type, dropout_type=dropout_type, dropout_rate=dropout_rate,
283
+ num_conv_layers=current_num_blocks_per_stage, nonlin=self.nonlin, upsampling=True))
284
+
285
+ self.stages = nn.ModuleList(self.stages)
286
+
287
+ def forward(self, x):
288
+ """Forward inputs through the layer
289
+
290
+ :param x: the input to forward.
291
+ :return: the input forwarded through the layer.
292
+ """
293
+ for i, stage in enumerate(self.stages):
294
+ x = stage(x, self.skips[i + 1])
295
+ return x
296
+
297
+ def set_skips(self, skips):
298
+ self.skips = skips
299
+
300
+
301
+ # We did our best we could to allow a maximum of modularity while keeping a certain sense in the parameters
302
+ # we propose to modify. Nevertheless, we cannot guarantee that the model will work no matter what parameters you pass.
303
+ # So if you change some parameters and the result is not what you expected, be careful to understand how it works
304
+ # If you want to change the type of convolutional layer used, we advise you to check how the existing ones have
305
+ # been implemented.
306
+ # “With great power comes great responsibility” Uncle Ben.
307
+ class modular_hdunet(Module):
308
+ """HDUnet model with modular parameters
309
+ """
310
+
311
+ def __init__(self, base_num_filter, num_blocks_per_stage_encoder, num_stages,
312
+ pool_kernel_sizes, conv_kernel_sizes, conv_bottleneck_kernel_sizes, num_blocks_per_stage_decoder=None,
313
+ padding='same', num_steps_bottleneck=4, conv_type: Module = LazyConvDropoutNormNonlinCat,
314
+ bottleneck_conv_type: Module = LazyConvBottleneckLayer, norm_type: Module = InstanceNorm3d,
315
+ dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, pooling_type: Module = MaxPool3d,
316
+ pooling_kernel_size=(2, 2, 2), nonlin: Module = ReLU):
317
+ """Object creation
318
+
319
+ :param base_num_filter: base number of filters (output channels).
320
+ :param num_blocks_per_stage_encoder: number of convolutional block per stage for the encoder
321
+ (can be different for each stage).
322
+ :param num_stages: number of stages.
323
+ :param pool_kernel_sizes: last convolutional layer of the encoder is strided => we use this parameter
324
+ to set its kernel size and stride (can be different for each stage).
325
+ :param conv_kernel_sizes: kernel size for the encoder and decoder (can be different for each stage).
326
+ :param conv_bottleneck_kernel_sizes: kernel size for the bottleneck (can be different for each stage).
327
+ :param padding: padding used, default is 'same'.
328
+ :param num_blocks_per_stage_decoder: number of convolutional block per stage for the decoder
329
+ (can be different for each stage). Default is None (it will be the same as the encoder).
330
+ :param num_steps_bottleneck: number of steps in the bottleneck, default is 4.
331
+ :param conv_type: type of convolution used, default is a lazy convolution using:
332
+ - dropout;
333
+ - normalization;
334
+ - nonlinear activation function;
335
+ - concatenation.
336
+ Must be a torch Module (should be a custom Module).
337
+ :param bottleneck_conv_type: type of convolution used in the bottleneck, default is a lazy convolution using:
338
+ - dropout;
339
+ - normalization;
340
+ - nonlinear activation function.
341
+ Must be a torch Module (should be a custom Module).
342
+ :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
343
+ :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
344
+ :param dropout_rate: dropout rate used by dropout, default is 0.
345
+ :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
346
+ :param pooling_type: type of pooling used, default is 3D max pooling. Must be a torch Module.
347
+ :param pooling_kernel_size: kernel size of the pooling layer, default is (2, 2, 2).
348
+ :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
349
+ """
350
+ super(modular_hdunet, self).__init__()
351
+ self.nonlin = nonlin
352
+ self.encoder = modular_hdunet_encoder(base_num_filter=base_num_filter,
353
+ num_blocks_per_stage=num_blocks_per_stage_encoder, num_stages=num_stages,
354
+ pool_kernel_sizes=pool_kernel_sizes, conv_kernel_sizes=conv_kernel_sizes,
355
+ padding=padding, conv_type=conv_type, norm_type=norm_type,
356
+ dropout_type=dropout_type, dropout_rate=dropout_rate,
357
+ expansion_rate=expansion_rate, pooling_type=pooling_type,
358
+ pooling_kernel_size=pooling_kernel_size, nonlin=self.nonlin)
359
+
360
+ self.bottleNeck = modular_hdunet_bottleneck(base_num_filter=base_num_filter, num_stages=num_stages,
361
+ conv_kernel_sizes=conv_bottleneck_kernel_sizes, padding=padding,
362
+ num_steps_bottleneck=num_steps_bottleneck,
363
+ conv_type=bottleneck_conv_type, norm_type=norm_type,
364
+ dropout_type=dropout_type, dropout_rate=dropout_rate,
365
+ expansion_rate=expansion_rate, nonlin=self.nonlin)
366
+
367
+ self.decoder = modular_hdunet_decoder(previous=self.encoder, base_num_filter=base_num_filter,
368
+ num_blocks_per_stage=num_blocks_per_stage_decoder, padding=padding,
369
+ conv_type=conv_type, norm_type=norm_type, dropout_type=dropout_type,
370
+ dropout_rate=dropout_rate, expansion_rate=expansion_rate,
371
+ nonlin=self.nonlin)
372
+
373
+ self.last_block = nn.Sequential(
374
+ LazyConv3d(out_channels=1, kernel_size=(3, 3, 3), padding='same'),
375
+ nonlin()
376
+ )
377
+
378
+ def forward(self, x):
379
+ """Forward inputs through the layer
380
+ (using the forward functions of the encoder/bottleneck/decoder)
381
+
382
+ :param x: the input to forward.
383
+ :return: the input forwarded through the layer.
384
+ """
385
+ skips = self.encoder(x)
386
+ tmp = self.bottleNeck(skips[-1])
387
+
388
+ # After providing the last value of skips to the bottleneck,
389
+ # we replace it with the value computed in the bottleneck
390
+ skips = skips[:-1]
391
+ skips.append(tmp)
392
+
393
+ # Since the first value that'll be used in the decoder is actually the last one of the array, we reverse it.
394
+ skips = skips[::-1]
395
+ self.decoder.set_skips(skips)
396
+ x = skips[0]
397
+ x = self.decoder(x)
398
+ return self.last_block(x)