import numpy as np
import torch
import torch.nn as nn
from torch.nn import LazyConv3d, MaxPool3d, BatchNorm3d
from torch.nn.modules import Module
from torch.nn.modules import ReLU
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.instancenorm import InstanceNorm3d

from custom_modules import LazyConvDropoutNormNonlinCat, ModularConvLayers, LazyConvBottleneckLayer


def _num_filters(base_num_filter, expansion_rate, exponent):
    """Compute a channel count as ``round(expansion_rate ** exponent * base_num_filter)``.

    The result is converted to a plain Python ``int``: ``np.round`` on its own yields a NumPy float as soon
    as ``expansion_rate`` or ``base_num_filter`` is a float (e.g. an expansion rate of 1.5), and layer
    constructors such as ``LazyConv3d`` and ``BatchNorm3d`` need an integer number of channels.

    :param base_num_filter: base number of filters.
    :param expansion_rate: multiplicative factor applied once per level.
    :param exponent: the level the expansion rate is raised to.
    :return: the rounded number of channels, as an ``int``.
    """
    return int(np.round((expansion_rate ** exponent) * base_num_filter))


class modular_hdunet_encoder(Module):
    """HDUnet encoder with modular parameters.

    Each stage runs a stack of convolutional blocks, then downsamples the result twice in parallel: a pooling
    layer with kernel ``pool_kernel_sizes[stage]`` and a convolution whose kernel and stride are both
    ``pool_kernel_sizes[stage]``. Both results are concatenated along the channel axis to form the input of the
    next stage.
    """

    def __init__(self, base_num_filter, num_blocks_per_stage, num_stages, pool_kernel_sizes, conv_kernel_sizes,
                 padding='same', conv_type: Module = LazyConvDropoutNormNonlinCat, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, pooling_type: Module = MaxPool3d,
                 pooling_kernel_size=(2, 2, 2), nonlin: Module = ReLU):
        """Object creation

        :param base_num_filter: base number of filters (output channels). Stage ``i`` uses
            ``round(expansion_rate ** i * base_num_filter)`` filters.
        :param num_blocks_per_stage: number of convolutional blocks per stage (an int applied to every stage, or a
            list/tuple with one entry per stage).
        :param num_stages: number of stages.
        :param pool_kernel_sizes: kernel size (and stride) of the strided convolution and kernel size of the pooling
            layer closing each stage (can be different for each stage). Please note that this parameter is retrieved
            in our modular decoder and used as the scale factor (upsampling).
        :param conv_kernel_sizes: kernel size (can be different for each stage).
        :param padding: padding used, default is 'same'.
        :param conv_type: type of convolution used, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function;
            - concatenation.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
        :param pooling_type: type of pooling used, default is 3D max pooling. Must be a torch Module.
        :param pooling_kernel_size: stored on the instance but currently not used: each pooling layer uses
            ``pool_kernel_sizes[stage]`` so that its output can be concatenated with the strided convolution's output.
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        :raises AssertionError: if ``pool_kernel_sizes``, ``conv_kernel_sizes`` (or a list ``num_blocks_per_stage``)
            do not have exactly ``num_stages`` entries.
        """
        super(modular_hdunet_encoder, self).__init__()
        self.base_num_filter = base_num_filter
        self.num_blocks_per_stage = num_blocks_per_stage
        self.num_stages = num_stages
        self.pool_kernel_sizes = pool_kernel_sizes
        self.conv_kernel_sizes = conv_kernel_sizes
        self.padding = padding
        self.conv_type = conv_type
        self.norm_type = norm_type
        self.dropout_type = dropout_type
        self.dropout_rate = dropout_rate
        self.nonlin = nonlin
        self.expansion_rate = expansion_rate
        self.pooling_type = pooling_type
        self.pooling_kernel_size = pooling_kernel_size

        self.stages = []
        self.pooling_stages = []
        self.end_stages = []

        # Per-stage architectural information, read back by the decoder to build a consistent upsampling path.
        self.stage_output_features = []
        self.stage_pool_kernel_size = []
        self.stage_conv_kernel_size = []

        assert len(pool_kernel_sizes) == len(conv_kernel_sizes) == num_stages

        if not isinstance(num_blocks_per_stage, (list, tuple)):
            num_blocks_per_stage = [num_blocks_per_stage] * num_stages
        else:
            assert len(num_blocks_per_stage) == num_stages

        self.num_blocks_per_stage = num_blocks_per_stage

        current_out_channels = 0
        # This is where we manage the number of steps
        for stage in range(num_stages):
            current_out_channels = _num_filters(self.base_num_filter, expansion_rate, stage)
            current_num_blocks_per_stage = num_blocks_per_stage[stage]
            current_pool_kernel_size = pool_kernel_sizes[stage]
            current_kernel_size = conv_kernel_sizes[stage]

            current_stage = ModularConvLayers(output_channels=current_out_channels,
                                              num_conv_layers=current_num_blocks_per_stage,
                                              kernel_size=current_kernel_size,
                                              padding=padding,
                                              conv_type=conv_type,
                                              norm_type=norm_type,
                                              dropout_type=dropout_type,
                                              dropout_rate=dropout_rate,
                                              nonlin=self.nonlin)

            self.pooling_stages.append(pooling_type(kernel_size=current_pool_kernel_size))

            # BatchNorm3d added statically here (to be similar to the original model)
            current_end_stage = nn.Sequential(
                LazyConv3d(out_channels=current_out_channels, kernel_size=current_pool_kernel_size,
                           stride=current_pool_kernel_size, padding=0),
                nonlin(),
                BatchNorm3d(current_out_channels)
            )

            self.stages.append(current_stage)
            self.end_stages.append(current_end_stage)
            self.stage_output_features.append(current_out_channels)
            self.stage_pool_kernel_size.append(current_pool_kernel_size)
            self.stage_conv_kernel_size.append(current_kernel_size)

        self.stages = nn.ModuleList(self.stages)
        self.pooling_stages = nn.ModuleList(self.pooling_stages)
        self.end_stages = nn.ModuleList(self.end_stages)
        # Number of filters of the last stage (0 when num_stages == 0).
        self.output_features = current_out_channels

    def forward(self, x):
        """Forward inputs through the layer

        :param x: the input to forward.
        :return: a list of ``num_stages + 1`` tensors. Element ``i`` (for ``i < num_stages``) is the output of the
            convolutional blocks of stage ``i``, taken before downsampling/concatenation; the decoder uses these as
            skip connections. The last element is the output of the very last stage after concatenation (strided
            convolution output first, pooled output second) and is the input of the bottleneck.
        """
        skips = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            pooled = self.pooling_stages[i](x)
            strided = self.end_stages[i](x)
            skips.append(x)
            x = torch.cat([strided, pooled], dim=1)
        skips.append(x)
        return skips


class modular_hdunet_bottleneck(Module):
    """HDUnet bottleneck with modular parameters.

    A sequence of ``num_steps_bottleneck`` convolution layers; each step's output is concatenated (along the
    channel axis) in front of its input to form the next step's input.
    """

    def __init__(self, base_num_filter, num_stages, conv_kernel_sizes, padding='same', num_steps_bottleneck=4,
                 conv_type: Module = LazyConvBottleneckLayer, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, nonlin: Module = ReLU):
        """Object creation

        :param base_num_filter: base number of filters (output channels). Every step uses
            ``round(expansion_rate ** num_stages * base_num_filter)`` filters.
        :param num_stages: number of stages of the encoder.
        :param conv_kernel_sizes: kernel size (one entry per bottleneck step).
        :param padding: padding used, default is 'same'.
        :param num_steps_bottleneck: number of steps in the bottleneck, default is 4.
        :param conv_type: type of convolution used, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        :raises AssertionError: if ``conv_kernel_sizes`` does not have exactly ``num_steps_bottleneck`` entries.
        """
        super(modular_hdunet_bottleneck, self).__init__()
        self.base_num_filter = base_num_filter
        self.conv_kernel_sizes = conv_kernel_sizes
        self.padding = padding
        self.num_steps_bottleneck = num_steps_bottleneck
        self.conv_type = conv_type
        self.norm_type = norm_type
        self.dropout_type = dropout_type
        self.dropout_rate = dropout_rate
        self.expansion_rate = expansion_rate
        self.nonlin = nonlin

        # Rounded to an int like the encoder/decoder channel counts, so a float expansion rate still yields a
        # valid number of channels.
        encoder_output_features = _num_filters(base_num_filter, expansion_rate, num_stages)

        self.stages = []
        self.step_conv_kernel_size = []

        assert len(conv_kernel_sizes) == num_steps_bottleneck

        # This is where we manage the number of steps
        for step in range(num_steps_bottleneck):
            current_kernel_size = conv_kernel_sizes[step]
            self.stages.append(
                conv_type(output_channels=encoder_output_features, kernel_size=current_kernel_size, padding=padding,
                          norm_type=norm_type, dropout_type=dropout_type, dropout_rate=dropout_rate,
                          nonlin=self.nonlin)
            )

        self.stages = nn.ModuleList(self.stages)

    def forward(self, x):
        """Forward inputs through the layer

        :param x: the input to forward. At each step the step's output is concatenated (in front) with its input in
            order to produce the input of the next bottleneck layer.
        :return: the input forwarded through the layer.
        """
        for stage in self.stages:
            step_output = stage(x)
            x = torch.cat([step_output, x], dim=1)
        return x


class modular_hdunet_decoder(Module):
    """HDUnet decoder with modular parameters.

    Mirrors the encoder it is built from: stage ``i`` reuses the encoder's pool/conv kernel sizes in reverse order
    and consumes the skip connection stored at index ``i + 1`` of the list given to ``set_skips``.
    """

    def __init__(self, previous, base_num_filter, num_blocks_per_stage=None, padding='same',
                 conv_type: Module = LazyConvDropoutNormNonlinCat, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, nonlin: Module = ReLU):
        """Object creation

        :param previous: the encoder which was previously used in the model. It is used to retrieve information that
            must stay consistent, such as the number of stages or the kernel sizes of each stage.
        :param base_num_filter: base number of filters (output channels). Stage ``i`` uses
            ``round(expansion_rate ** (num_stages - i) * base_num_filter)`` filters.
        :param num_blocks_per_stage: number of convolutional blocks per stage (an int applied to every stage, or a
            list/tuple with one entry per stage). If set to None, the encoder's list is used, reversed.
        :param padding: padding used, default is 'same'.
        :param conv_type: type of convolution used, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function;
            - concatenation.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        :raises AssertionError: if a list ``num_blocks_per_stage`` does not have one entry per encoder stage.
        """
        super(modular_hdunet_decoder, self).__init__()
        self.base_num_filter = base_num_filter
        self.num_blocks_per_stage = num_blocks_per_stage
        self.padding = padding
        self.conv_type = conv_type
        self.norm_type = norm_type
        self.dropout_type = dropout_type
        self.dropout_rate = dropout_rate
        self.expansion_rate = expansion_rate
        self.nonlin = nonlin

        # We have to provide the skips using the set function since we are using Lazy layers and torchsummary does
        # not allow us to use a list as a parameter of the forward function.
        self.skips = []

        # We retrieve the 'architectural' information that was provided to the encoder
        # in order to have a consistent decoder
        previous_stages = previous.stages
        previous_stage_output_features = previous.stage_output_features
        previous_stage_pool_kernel_size = previous.stage_pool_kernel_size
        previous_stage_conv_kernel_size = previous.stage_conv_kernel_size

        # Same number of stages as the encoder, given that the bottleneck is done separately
        self.num_stages = len(previous_stages)

        # If num_blocks_per_stage is set to None, it will be the same as the encoder (reversed).
        if num_blocks_per_stage is None:
            self.num_blocks_per_stage = previous.num_blocks_per_stage[::-1]

        if not isinstance(self.num_blocks_per_stage, (list, tuple)):
            self.num_blocks_per_stage = [self.num_blocks_per_stage] * self.num_stages
        else:
            assert len(self.num_blocks_per_stage) == self.num_stages

        # There should be the same number of stages since we are doing the bottleneck and the encoder parts separately
        assert len(self.num_blocks_per_stage) == len(previous.num_blocks_per_stage)

        # NOTE: this is the encoder's list, kept in encoder order (not reversed); it is not used below.
        self.stage_output_features = previous_stage_output_features
        self.stage_pool_kernel_size = previous_stage_pool_kernel_size[::-1]
        self.stage_conv_kernel_size = previous_stage_conv_kernel_size[::-1]

        self.stages = []

        # This is where we manage the number of steps
        for stage in range(self.num_stages):
            current_out_channels = _num_filters(self.base_num_filter, expansion_rate, self.num_stages - stage)
            current_num_blocks_per_stage = self.num_blocks_per_stage[stage]
            current_pool_kernel_size = self.stage_pool_kernel_size[stage]
            current_kernel_size = self.stage_conv_kernel_size[stage]

            self.stages.append(
                ModularConvLayers(output_channels=current_out_channels,
                                  kernel_size=current_kernel_size,
                                  padding=padding,
                                  pool_size=current_pool_kernel_size,
                                  conv_type=conv_type,
                                  norm_type=norm_type,
                                  dropout_type=dropout_type,
                                  dropout_rate=dropout_rate,
                                  num_conv_layers=current_num_blocks_per_stage,
                                  nonlin=self.nonlin,
                                  upsampling=True))

        self.stages = nn.ModuleList(self.stages)

    def forward(self, x):
        """Forward inputs through the layer

        ``set_skips`` must have been called first with at least ``num_stages + 1`` entries; index 0 is not read here
        (in ``modular_hdunet`` it is the bottleneck output, passed as ``x``).

        :param x: the input to forward.
        :return: the input forwarded through the layer.
        """
        for i, stage in enumerate(self.stages):
            x = stage(x, self.skips[i + 1])
        return x

    def set_skips(self, skips):
        """Store the skip connections consumed by ``forward`` (index ``i + 1`` feeds stage ``i``)."""
        self.skips = skips


# We did the best we could to allow maximum modularity while keeping the parameters we expose meaningful.
# Nevertheless, we cannot guarantee that the model will work no matter what parameters you pass.
# So if you change some parameters and the result is not what you expected, make sure you understand how it works.
# If you want to change the type of convolutional layer used, we advise you to check how the existing ones have
# been implemented.
# “With great power comes great responsibility” Uncle Ben.
class modular_hdunet(Module):
    """HDUnet model with modular parameters (encoder -> bottleneck -> decoder -> 1-channel output block)."""

    def __init__(self, base_num_filter, num_blocks_per_stage_encoder, num_stages, pool_kernel_sizes, conv_kernel_sizes,
                 conv_bottleneck_kernel_sizes, num_blocks_per_stage_decoder=None, padding='same',
                 num_steps_bottleneck=4, conv_type: Module = LazyConvDropoutNormNonlinCat,
                 bottleneck_conv_type: Module = LazyConvBottleneckLayer, norm_type: Module = InstanceNorm3d,
                 dropout_type: Module = Dropout, dropout_rate=0, expansion_rate=1, pooling_type: Module = MaxPool3d,
                 pooling_kernel_size=(2, 2, 2), nonlin: Module = ReLU):
        """Object creation

        :param base_num_filter: base number of filters (output channels).
        :param num_blocks_per_stage_encoder: number of convolutional blocks per stage for the encoder (can be different
            for each stage).
        :param num_stages: number of stages.
        :param pool_kernel_sizes: last convolutional layer of each encoder stage is strided => we use this parameter to
            set its kernel size and stride (can be different for each stage).
        :param conv_kernel_sizes: kernel size for the encoder and decoder (can be different for each stage).
        :param conv_bottleneck_kernel_sizes: kernel size for the bottleneck (one entry per bottleneck step).
        :param padding: padding used, default is 'same'.
        :param num_blocks_per_stage_decoder: number of convolutional blocks per stage for the decoder (can be different
            for each stage). Default is None (the encoder's list is used, reversed).
        :param num_steps_bottleneck: number of steps in the bottleneck, default is 4.
        :param conv_type: type of convolution used, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function;
            - concatenation.
            Must be a torch Module (should be a custom Module).
        :param bottleneck_conv_type: type of convolution used in the bottleneck, default is a lazy convolution using:
            - dropout;
            - normalization;
            - nonlinear activation function.
            Must be a torch Module (should be a custom Module).
        :param norm_type: normalization type that is used, default is 3D instance normalization. Must be a torch Module.
        :param dropout_type: dropout type that is used, default is Dropout. Must be a torch Module.
        :param dropout_rate: dropout rate used by dropout, default is 0.
        :param expansion_rate: expansion rate used to modify the number of filters, default is 1.
        :param pooling_type: type of pooling used, default is 3D max pooling. Must be a torch Module.
        :param pooling_kernel_size: forwarded to the encoder, which stores it but sizes its pooling layers with
            ``pool_kernel_sizes`` instead.
        :param nonlin: the nonlinear activation function to use, default is ReLU. Must be a torch Module.
        """
        super(modular_hdunet, self).__init__()
        self.nonlin = nonlin

        self.encoder = modular_hdunet_encoder(base_num_filter=base_num_filter,
                                              num_blocks_per_stage=num_blocks_per_stage_encoder,
                                              num_stages=num_stages,
                                              pool_kernel_sizes=pool_kernel_sizes,
                                              conv_kernel_sizes=conv_kernel_sizes,
                                              padding=padding,
                                              conv_type=conv_type,
                                              norm_type=norm_type,
                                              dropout_type=dropout_type,
                                              dropout_rate=dropout_rate,
                                              expansion_rate=expansion_rate,
                                              pooling_type=pooling_type,
                                              pooling_kernel_size=pooling_kernel_size,
                                              nonlin=self.nonlin)

        self.bottleNeck = modular_hdunet_bottleneck(base_num_filter=base_num_filter,
                                                    num_stages=num_stages,
                                                    conv_kernel_sizes=conv_bottleneck_kernel_sizes,
                                                    padding=padding,
                                                    num_steps_bottleneck=num_steps_bottleneck,
                                                    conv_type=bottleneck_conv_type,
                                                    norm_type=norm_type,
                                                    dropout_type=dropout_type,
                                                    dropout_rate=dropout_rate,
                                                    expansion_rate=expansion_rate,
                                                    nonlin=self.nonlin)

        self.decoder = modular_hdunet_decoder(previous=self.encoder,
                                              base_num_filter=base_num_filter,
                                              num_blocks_per_stage=num_blocks_per_stage_decoder,
                                              padding=padding,
                                              conv_type=conv_type,
                                              norm_type=norm_type,
                                              dropout_type=dropout_type,
                                              dropout_rate=dropout_rate,
                                              expansion_rate=expansion_rate,
                                              nonlin=self.nonlin)

        self.last_block = nn.Sequential(
            LazyConv3d(out_channels=1, kernel_size=(3, 3, 3), padding='same'),
            nonlin()
        )

    def forward(self, x):
        """Forward inputs through the layer (using the forward functions of the encoder/bottleneck/decoder)

        :param x: the input to forward.
        :return: the input forwarded through the layer.
        """
        skips = self.encoder(x)
        bottleneck_output = self.bottleNeck(skips[-1])

        # After providing the last value of skips to the bottleneck,
        # we replace it with the value computed in the bottleneck
        skips = skips[:-1]
        skips.append(bottleneck_output)

        # Since the first value that'll be used in the decoder is actually the last one of the list, we reverse it.
        skips = skips[::-1]

        self.decoder.set_skips(skips)
        try:
            x = self.decoder(skips[0])
        finally:
            # Drop the stored activations once the decoder is done: otherwise they (and their autograd graph) stay
            # referenced by the module until the next call, and non-leaf tensors held as a module attribute make
            # copy.deepcopy(model) fail.
            self.decoder.set_skips([])

        return self.last_block(x)