#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional
from torch import nn

from nnunet.network_architecture.initialization import InitWeights_He
from nnunet.network_architecture.neural_network import SegmentationNetwork
from nnunet.utilities.nd_softmax import softmax_helper


class ConvDropoutNormNonlin(nn.Module):
    """
    fixes a bug in ConvDropoutNormNonlin where lrelu was used regardless of nonlin. Bad.
    """

    def __init__(self, input_channels, output_channels,
                 conv_op=nn.Conv2d, conv_kwargs=None,
                 norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None):
        super(ConvDropoutNormNonlin, self).__init__()
        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.5, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
        if conv_kwargs is None:
            conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}

        self.nonlin_kwargs = nonlin_kwargs
        self.nonlin = nonlin
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.conv_kwargs = conv_kwargs
        self.conv_op = conv_op
        self.norm_op = norm_op

        self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs)
        if self.dropout_op is not None and self.dropout_op_kwargs['p'] is not None \
                and self.dropout_op_kwargs['p'] > 0:
            self.dropout = self.dropout_op(**self.dropout_op_kwargs)
        else:
            self.dropout = None
        self.instnorm = self.norm_op(output_channels, **self.norm_op_kwargs)
        self.lrelu = self.nonlin(**self.nonlin_kwargs)

    def forward(self, x):
        x = self.conv(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return self.lrelu(self.instnorm(x))


class ConvDropoutNonlinNorm(ConvDropoutNormNonlin):
    def forward(self, x):
        x = self.conv(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return self.instnorm(self.lrelu(x))
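# A minimal usage sketch of the block above (illustrative shapes and kwargs, kept as a comment
# so importing this module stays side-effect free):
#
#   block = ConvDropoutNormNonlin(4, 32, conv_op=nn.Conv2d, norm_op=nn.InstanceNorm2d,
#                                 dropout_op=None)
#   y = block(torch.zeros(2, 4, 64, 64))  # conv -> (no dropout) -> norm -> nonlin, (2, 32, 64, 64)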
class StackedConvLayers(nn.Module):
    def __init__(self, input_feature_channels, output_feature_channels, num_convs,
                 conv_op=nn.Conv2d, conv_kwargs=None,
                 norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None, first_stride=None,
                 basic_block=ConvDropoutNormNonlin):
        '''
        stacks ConvDropoutNormNonlin layers. first_stride will only be applied to the first layer
        in the stack; the other parameters affect all layers.

        :param input_feature_channels: number of channels of the input feature map
        :param output_feature_channels: number of channels produced by every conv in the stack
        :param num_convs: how many basic_block instances to stack
        :param first_stride: stride of the first conv only (used for downsampling);
            None keeps the stride from conv_kwargs
        '''
        super(StackedConvLayers, self).__init__()
        self.input_channels = input_feature_channels
        self.output_channels = output_feature_channels

        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.5, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
        if conv_kwargs is None:
            conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}

        self.nonlin_kwargs = nonlin_kwargs
        self.nonlin = nonlin
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.conv_kwargs = conv_kwargs
        self.conv_op = conv_op
        self.norm_op = norm_op

        if first_stride is not None:
            self.conv_kwargs_first_conv = deepcopy(conv_kwargs)
            self.conv_kwargs_first_conv['stride'] = first_stride
        else:
            self.conv_kwargs_first_conv = conv_kwargs

        self.blocks = nn.Sequential(
            *([basic_block(input_feature_channels, output_feature_channels, self.conv_op,
                           self.conv_kwargs_first_conv,
                           self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
                           self.nonlin, self.nonlin_kwargs)] +
              [basic_block(output_feature_channels, output_feature_channels, self.conv_op,
                           self.conv_kwargs,
                           self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
                           self.nonlin, self.nonlin_kwargs) for _ in range(num_convs - 1)]))

    def forward(self, x):
        return self.blocks(x)


def print_module_training_status(module):
    if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.Dropout, nn.Dropout2d, nn.Dropout3d,
                           nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
                           nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        print(str(module), module.training)


class Upsample(nn.Module):
    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=False):
        super(Upsample, self).__init__()
        self.align_corners = align_corners
        self.mode = mode
        self.scale_factor = scale_factor
        self.size = size

    def forward(self, x):
        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor,
                                         mode=self.mode, align_corners=self.align_corners)
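# Illustrative sketch of StackedConvLayers (assumed shapes, commented out so nothing runs on
# import): a two-conv stage where only the first conv downsamples via first_stride.
#
#   stage = StackedConvLayers(32, 64, num_convs=2, first_stride=2)
#   y = stage(torch.zeros(2, 32, 128, 128))  # -> (2, 64, 64, 64): stride 2 once, then stride 1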
class Generic_UNet_MTLearly(SegmentationNetwork):
    DEFAULT_BATCH_SIZE_3D = 2
    DEFAULT_PATCH_SIZE_3D = (64, 192, 160)
    SPACING_FACTOR_BETWEEN_STAGES = 2
    BASE_NUM_FEATURES_3D = 30
    MAX_NUMPOOL_3D = 999
    MAX_NUM_FILTERS_3D = 320

    DEFAULT_PATCH_SIZE_2D = (256, 256)
    BASE_NUM_FEATURES_2D = 30
    DEFAULT_BATCH_SIZE_2D = 50
    MAX_NUMPOOL_2D = 999
    MAX_FILTERS_2D = 480

    use_this_for_batch_size_computation_2D = 19739648
    use_this_for_batch_size_computation_3D = 520000000  # 505789440

    def __init__(self, input_channels, base_num_features, num_classes, num_pool, num_conv_per_stage=2,
                 feat_map_mul_on_downscale=2, conv_op=nn.Conv2d,
                 norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None, deep_supervision=True, dropout_in_localization=False,
                 final_nonlin=softmax_helper, weightInitializer=InitWeights_He(1e-2), pool_op_kernel_sizes=None,
                 conv_kernel_sizes=None,
                 upscale_logits=False, convolutional_pooling=False, convolutional_upsampling=False,
                 max_num_features=None, basic_block=ConvDropoutNormNonlin,
                 seg_output_use_bias=False):
        """
        basically more flexible than v1, architecture is the same

        Does this look complicated? Nah bro. Functionality > usability

        This does everything you need, including world peace.

        Questions? -> f.isensee@dkfz.de

        Note: num_classes must be a sequence of two entries, one per decoder; the two
        segmentation heads are built from num_classes[0] and num_classes[1] below.
        """
        super(Generic_UNet_MTLearly, self).__init__()
        self.convolutional_upsampling = convolutional_upsampling
        self.convolutional_pooling = convolutional_pooling
        self.upscale_logits = upscale_logits
        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.5, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}

        self.conv_kwargs = {'stride': 1, 'dilation': 1, 'bias': True}

        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op_kwargs = dropout_op_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.weightInitializer = weightInitializer
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.dropout_op = dropout_op
        self.num_classes = num_classes
        self.final_nonlin = final_nonlin
        self._deep_supervision = deep_supervision
        self.do_ds = deep_supervision

        if conv_op == nn.Conv2d:
            upsample_mode = 'bilinear'
            pool_op = nn.MaxPool2d
            transpconv = nn.ConvTranspose2d
            if pool_op_kernel_sizes is None:
                pool_op_kernel_sizes = [(2, 2)] * num_pool
            if conv_kernel_sizes is None:
                conv_kernel_sizes = [(3, 3)] * (num_pool + 1)
        elif conv_op == nn.Conv3d:
            upsample_mode = 'trilinear'
            pool_op = nn.MaxPool3d
            transpconv = nn.ConvTranspose3d
            if pool_op_kernel_sizes is None:
                pool_op_kernel_sizes = [(2, 2, 2)] * num_pool
            if conv_kernel_sizes is None:
                conv_kernel_sizes = [(3, 3, 3)] * (num_pool + 1)
        else:
            raise ValueError("unknown convolution dimensionality, conv op: %s" % str(conv_op))

        self.input_shape_must_be_divisible_by = np.prod(pool_op_kernel_sizes, 0, dtype=np.int64)
        self.pool_op_kernel_sizes = pool_op_kernel_sizes
        self.conv_kernel_sizes = conv_kernel_sizes

        self.conv_pad_sizes = []
        for krnl in self.conv_kernel_sizes:
            self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl])

        if max_num_features is None:
            if self.conv_op == nn.Conv3d:
                self.max_num_features = self.MAX_NUM_FILTERS_3D
            else:
                self.max_num_features = self.MAX_FILTERS_2D
        else:
            self.max_num_features = max_num_features

        self.conv_blocks_context = []
        self.conv_blocks_localization_1 = []
        self.conv_blocks_localization_2 = []
        self.td = []
        self.tu_1 = []
        self.tu_2 = []
        self.seg_outputs_1 = []
        self.seg_outputs_2 = []

        output_features = base_num_features
        input_features = input_channels

        for d in range(num_pool):
            # determine the first stride
            if d != 0 and self.convolutional_pooling:
                first_stride = pool_op_kernel_sizes[d - 1]
            else:
                first_stride = None

            self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d]
            self.conv_kwargs['padding'] = self.conv_pad_sizes[d]
            # add convolutions
            self.conv_blocks_context.append(StackedConvLayers(input_features, output_features, num_conv_per_stage,
                                                              self.conv_op, self.conv_kwargs, self.norm_op,
                                                              self.norm_op_kwargs, self.dropout_op,
                                                              self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs,
                                                              first_stride, basic_block=basic_block))
            if not self.convolutional_pooling:
                self.td.append(pool_op(pool_op_kernel_sizes[d]))
            input_features = output_features
            output_features = int(np.round(output_features * feat_map_mul_on_downscale))
            output_features = min(output_features, self.max_num_features)
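        # Worked example (illustrative numbers, not taken from a planner): with
        # base_num_features=32, feat_map_mul_on_downscale=2, num_pool=5 and the 3D cap
        # max_num_features=320, the encoder stages built above carry 32, 64, 128, 256 and 320
        # feature maps (the doubling to 512 is clipped to the cap), and output_features now
        # holds 320 for the bottleneck below.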
        # now the bottleneck.
        # determine the first stride
        if self.convolutional_pooling:
            first_stride = pool_op_kernel_sizes[-1]
        else:
            first_stride = None

        # the output of the last conv must match the number of features from the skip connection if we are not using
        # convolutional upsampling. If we use convolutional upsampling then the reduction in feature maps will be
        # done by the transposed conv
        if self.convolutional_upsampling:
            final_num_features = output_features
        else:
            final_num_features = self.conv_blocks_context[-1].output_channels

        self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool]
        self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool]
        self.conv_blocks_context.append(nn.Sequential(
            StackedConvLayers(input_features, output_features, num_conv_per_stage - 1, self.conv_op, self.conv_kwargs,
                              self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin,
                              self.nonlin_kwargs, first_stride, basic_block=basic_block),
            StackedConvLayers(output_features, final_num_features, 1, self.conv_op, self.conv_kwargs,
                              self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin,
                              self.nonlin_kwargs, basic_block=basic_block)))

        # if we don't want to do dropout in the localization pathway then we set the dropout prob to zero here
        if not dropout_in_localization:
            old_dropout_p = self.dropout_op_kwargs['p']
            self.dropout_op_kwargs['p'] = 0.0

        # now lets build the localization pathway
        for u in range(num_pool):
            nfeatures_from_down = final_num_features
            nfeatures_from_skip = self.conv_blocks_context[
                -(2 + u)].output_channels  # self.conv_blocks_context[-1] is bottleneck, so start with -2
            n_features_after_tu_and_concat = nfeatures_from_skip * 2

            # the first conv reduces the number of features to match those of skip
            # the following convs work on that number of features
            # if not convolutional upsampling then the final conv reduces the num of features again
            if u != num_pool - 1 and not self.convolutional_upsampling:
                final_num_features = self.conv_blocks_context[-(3 + u)].output_channels
            else:
                final_num_features = nfeatures_from_skip

            if not self.convolutional_upsampling:
                self.tu_1.append(Upsample(scale_factor=pool_op_kernel_sizes[-(u + 1)], mode=upsample_mode))
                self.tu_2.append(Upsample(scale_factor=pool_op_kernel_sizes[-(u + 1)], mode=upsample_mode))
            else:
                self.tu_1.append(transpconv(nfeatures_from_down, nfeatures_from_skip, pool_op_kernel_sizes[-(u + 1)],
                                            pool_op_kernel_sizes[-(u + 1)], bias=False))
                self.tu_2.append(transpconv(nfeatures_from_down, nfeatures_from_skip, pool_op_kernel_sizes[-(u + 1)],
                                            pool_op_kernel_sizes[-(u + 1)], bias=False))

            self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[- (u + 1)]
            self.conv_kwargs['padding'] = self.conv_pad_sizes[- (u + 1)]
            self.conv_blocks_localization_1.append(nn.Sequential(
                StackedConvLayers(n_features_after_tu_and_concat, nfeatures_from_skip, num_conv_per_stage - 1,
                                  self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs, self.dropout_op,
                                  self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, basic_block=basic_block),
                StackedConvLayers(nfeatures_from_skip, final_num_features, 1, self.conv_op, self.conv_kwargs,
                                  self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
                                  self.nonlin, self.nonlin_kwargs, basic_block=basic_block)
            ))
            self.conv_blocks_localization_2.append(nn.Sequential(
                StackedConvLayers(n_features_after_tu_and_concat, nfeatures_from_skip, num_conv_per_stage - 1,
                                  self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs, self.dropout_op,
                                  self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, basic_block=basic_block),
                StackedConvLayers(nfeatures_from_skip, final_num_features, 1, self.conv_op, self.conv_kwargs,
                                  self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
                                  self.nonlin, self.nonlin_kwargs, basic_block=basic_block)
            ))
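        # The two localization pathways built above are structurally identical but share no
        # parameters; only the encoder and bottleneck are common to both tasks. Presumably this
        # branching right after the bottleneck is the "early" multi-task split the class name
        # (MTLearly) refers to.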
        for ds in range(len(self.conv_blocks_localization_1)):
            self.seg_outputs_1.append(conv_op(self.conv_blocks_localization_1[ds][-1].output_channels,
                                              num_classes[0], 1, 1, 0, 1, 1, seg_output_use_bias))
        for ds in range(len(self.conv_blocks_localization_2)):
            self.seg_outputs_2.append(conv_op(self.conv_blocks_localization_2[ds][-1].output_channels,
                                              num_classes[1], 1, 1, 0, 1, 1, seg_output_use_bias))

        self.upscale_logits_ops = []
        cum_upsample = np.cumprod(np.vstack(pool_op_kernel_sizes), axis=0)[::-1]
        for usl in range(num_pool - 1):
            if self.upscale_logits:
                self.upscale_logits_ops.append(Upsample(scale_factor=tuple([int(i) for i in cum_upsample[usl + 1]]),
                                                        mode=upsample_mode))
            else:
                self.upscale_logits_ops.append(lambda x: x)

        if not dropout_in_localization:
            self.dropout_op_kwargs['p'] = old_dropout_p

        # register all modules properly
        self.conv_blocks_context = nn.ModuleList(self.conv_blocks_context)
        self.conv_blocks_localization_1 = nn.ModuleList(self.conv_blocks_localization_1)
        self.conv_blocks_localization_2 = nn.ModuleList(self.conv_blocks_localization_2)
        self.td = nn.ModuleList(self.td)
        self.tu_1 = nn.ModuleList(self.tu_1)
        self.tu_2 = nn.ModuleList(self.tu_2)
        self.seg_outputs_1 = nn.ModuleList(self.seg_outputs_1)
        self.seg_outputs_2 = nn.ModuleList(self.seg_outputs_2)
        if self.upscale_logits:
            self.upscale_logits_ops = nn.ModuleList(
                self.upscale_logits_ops)  # lambda x:x is not a Module so we need to distinguish here

        if self.weightInitializer is not None:
            self.apply(self.weightInitializer)
            # self.apply(print_module_training_status)

    def forward(self, x):
        skips = []
        seg_outputs_1 = []
        seg_outputs_2 = []

        # shared encoder
        for d in range(len(self.conv_blocks_context) - 1):
            x = self.conv_blocks_context[d](x)
            skips.append(x)
            if not self.convolutional_pooling:
                x = self.td[d](x)

        x1 = self.conv_blocks_context[-1](x)
        x2 = x1.clone()

        # Decoder 1
        for u in range(len(self.tu_1)):
            x1 = self.tu_1[u](x1)
            x1 = torch.cat((x1, skips[-(u + 1)]), dim=1)
            x1 = self.conv_blocks_localization_1[u](x1)
            seg_outputs_1.append(self.final_nonlin(self.seg_outputs_1[u](x1)))

        # Decoder 2
        for u in range(len(self.tu_2)):
            x2 = self.tu_2[u](x2)
            x2 = torch.cat((x2, skips[-(u + 1)]), dim=1)
            x2 = self.conv_blocks_localization_2[u](x2)
            seg_outputs_2.append(self.final_nonlin(self.seg_outputs_2[u](x2)))

        if self._deep_supervision and self.do_ds:
            seg_1 = tuple([seg_outputs_1[-1]] + [i(j) for i, j in
                                                 zip(list(self.upscale_logits_ops)[::-1], seg_outputs_1[:-1][::-1])])
            seg_2 = tuple([seg_outputs_2[-1]] + [i(j) for i, j in
                                                 zip(list(self.upscale_logits_ops)[::-1], seg_outputs_2[:-1][::-1])])
            seg = tuple([torch.cat([s1, s2], dim=1) for s1, s2 in zip(seg_1, seg_2)])
            return seg
        else:
            seg = tuple([torch.cat([s1, s2], dim=1) for s1, s2 in zip(seg_outputs_1, seg_outputs_2)])
            return seg[-1]
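    # Channel layout of every tensor returned by forward: the first num_classes[0] channels are
    # decoder 1 outputs (after final_nonlin), the remaining num_classes[1] channels are decoder 2
    # outputs, concatenated along dim 1. A multi-task loss is expected to split predictions at
    # that channel boundary.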
    @staticmethod
    def compute_approx_vram_consumption(patch_size, num_pool_per_axis, base_num_features, max_num_features,
                                        num_modalities, num_classes, pool_op_kernel_sizes, deep_supervision=False,
                                        conv_per_stage=2):
        """
        This only applies for num_conv_per_stage=2 and convolutional_upsampling=True
        not real vram consumption. just a constant term to which the vram consumption will be approx proportional
        (+ offset for parameter storage)
        :param deep_supervision:
        :param patch_size:
        :param num_pool_per_axis:
        :param base_num_features:
        :param max_num_features:
        :param num_modalities:
        :param num_classes:
        :param pool_op_kernel_sizes:
        :return:
        """
        if not isinstance(num_pool_per_axis, np.ndarray):
            num_pool_per_axis = np.array(num_pool_per_axis)

        npool = len(pool_op_kernel_sizes)

        map_size = np.array(patch_size)
        tmp = np.int64((conv_per_stage * 2 + 1) * np.prod(map_size, dtype=np.int64) * base_num_features +
                       num_modalities * np.prod(map_size, dtype=np.int64) +
                       num_classes * np.prod(map_size, dtype=np.int64))

        num_feat = base_num_features

        for p in range(npool):
            for pi in range(len(num_pool_per_axis)):
                map_size[pi] /= pool_op_kernel_sizes[p][pi]
            num_feat = min(num_feat * 2, max_num_features)
            # single-decoder version for reference:
            # num_blocks = (conv_per_stage * 2 + 1) if p < (npool - 1) else conv_per_stage
            # (conv_per_stage for encode, conv_per_stage for decode and 1 for transposed conv)
            num_blocks = (conv_per_stage * 5 + 1) if p < (npool - 1) else conv_per_stage
            # conv_per_stage for encode, conv_per_stage for each of the two decoders and 1 for transposed conv
            tmp += num_blocks * np.prod(map_size, dtype=np.int64) * num_feat
            if deep_supervision and p < (npool - 2):
                tmp += np.prod(map_size, dtype=np.int64) * num_classes
            # print(p, map_size, num_feat, tmp)

        return tmp
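

if __name__ == "__main__":
    # Minimal smoke test, assuming a 2-channel 2D input and two tasks with 3 and 2 classes.
    # All values here are illustrative; real configurations come from nnU-Net's experiment planner.
    net = Generic_UNet_MTLearly(input_channels=2, base_num_features=30, num_classes=(3, 2),
                                num_pool=4, conv_op=nn.Conv2d, norm_op=nn.InstanceNorm2d,
                                dropout_op=None, deep_supervision=True)
    outputs = net(torch.zeros(1, 2, 256, 256))
    # with deep supervision on, one (3+2)-channel map per resolution, full resolution first
    print([o.shape for o in outputs])

    # rough relative-memory score used for batch size planning (not bytes)
    print(Generic_UNet_MTLearly.compute_approx_vram_consumption(
        patch_size=(256, 256), num_pool_per_axis=(4, 4), base_num_features=30, max_num_features=480,
        num_modalities=2, num_classes=5, pool_op_kernel_sizes=[(2, 2)] * 4, deep_supervision=True))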