# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from copy import deepcopy

from nnunet.network_architecture.custom_modules.helperModules import Identity
from torch import nn


class ConvDropoutNormReLU(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, network_props):
        """
        Single conv -> (dropout) -> (norm) -> nonlinearity unit.

        if network_props['dropout_op'] is None then no dropout
        if network_props['norm_op'] is None then no norm

        :param input_channels: number of channels fed into the convolution
        :param output_channels: number of channels produced by the convolution (and passed to norm_op)
        :param kernel_size: iterable with one kernel extent per spatial dimension; padding is derived from it
        :param network_props: dict with the keys 'conv_op', 'conv_op_kwargs', 'dropout_op', 'dropout_op_kwargs',
            'norm_op', 'norm_op_kwargs', 'nonlin' and 'nonlin_kwargs'
        """
        super(ConvDropoutNormReLU, self).__init__()

        network_props = deepcopy(network_props)  # network_props is a dict and mutable, so we deepcopy to be safe.

        # padding of (k - 1) // 2 per axis keeps the spatial size unchanged for odd kernel sizes at stride 1
        self.conv = network_props['conv_op'](input_channels, output_channels, kernel_size,
                                             padding=[(i - 1) // 2 for i in kernel_size],
                                             **network_props['conv_op_kwargs'])

        # maybe dropout
        if network_props['dropout_op'] is not None:
            self.do = network_props['dropout_op'](**network_props['dropout_op_kwargs'])
        else:
            self.do = Identity()

        # maybe norm
        if network_props['norm_op'] is not None:
            self.norm = network_props['norm_op'](output_channels, **network_props['norm_op_kwargs'])
        else:
            self.norm = Identity()

        self.nonlin = network_props['nonlin'](**network_props['nonlin_kwargs'])

        # NOTE: the submodules are registered both individually and inside self.all. This is kept as is because
        # removing either registration would change the state_dict keys of existing checkpoints.
        self.all = nn.Sequential(self.conv, self.do, self.norm, self.nonlin)

    def forward(self, x):
        return self.all(x)


class StackedConvLayers(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, network_props, num_convs, first_stride=None):
        """
        Sequence of num_convs ConvDropoutNormReLU units. Only the first unit maps input_channels to
        output_channels and (optionally) uses first_stride; the remaining ones use network_props unchanged.

        if network_props['dropout_op'] is None then no dropout
        if network_props['norm_op'] is None then no norm

        :param input_channels: number of channels fed into the first unit
        :param output_channels: number of channels produced by every unit
        :param kernel_size: iterable with one kernel extent per spatial dimension
        :param network_props: dict, see ConvDropoutNormReLU
        :param num_convs: total number of ConvDropoutNormReLU units in the stack
        :param first_stride: if not None, overrides conv_op_kwargs['stride'] for the first unit only
        """
        super(StackedConvLayers, self).__init__()

        network_props = deepcopy(network_props)  # network_props is a dict and mutable, so we deepcopy to be safe.
        network_props_first = deepcopy(network_props)

        if first_stride is not None:
            network_props_first['conv_op_kwargs']['stride'] = first_stride

        self.convs = nn.Sequential(
            ConvDropoutNormReLU(input_channels, output_channels, kernel_size, network_props_first),
            *[ConvDropoutNormReLU(output_channels, output_channels, kernel_size, network_props)
              for _ in range(num_convs - 1)]
        )

    def forward(self, x):
        return self.convs(x)


class BasicResidualBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, props, stride=None):
        """
        This is the conv bn nonlin conv bn nonlin kind of block. The skip connection is added before the
        second nonlinearity.

        :param in_planes: number of input channels
        :param out_planes: number of output channels
        :param kernel_size: iterable with one kernel extent per spatial dimension
        :param props: dict with the keys 'conv_op', 'conv_op_kwargs', 'norm_op', 'norm_op_kwargs', 'nonlin',
            'nonlin_kwargs', 'dropout_op' and 'dropout_op_kwargs'. It is copied, the caller's dict is not modified
        :param stride: optional iterable of per-axis strides for the first convolution (and the skip projection).
            None means stride 1
        """
        super().__init__()

        # props is a dict and mutable. We overwrite the stride below, so work on a private copy instead of
        # silently modifying the dict of the caller.
        props = deepcopy(props)

        self.kernel_size = kernel_size
        props['conv_op_kwargs']['stride'] = 1

        self.stride = stride
        self.props = props
        self.out_planes = out_planes
        self.in_planes = in_planes

        if stride is not None:
            kwargs_conv1 = deepcopy(props['conv_op_kwargs'])
            kwargs_conv1['stride'] = stride
        else:
            kwargs_conv1 = props['conv_op_kwargs']

        self.conv1 = props['conv_op'](in_planes, out_planes, kernel_size,
                                      padding=[(i - 1) // 2 for i in kernel_size],
                                      **kwargs_conv1)
        self.norm1 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
        self.nonlin1 = props['nonlin'](**props['nonlin_kwargs'])

        # dropout is only instantiated if there is a dropout op and its probability is non-zero
        if props.get('dropout_op') is not None and props['dropout_op_kwargs']['p'] != 0:
            self.dropout = props['dropout_op'](**props['dropout_op_kwargs'])
        else:
            self.dropout = Identity()

        self.conv2 = props['conv_op'](out_planes, out_planes, kernel_size,
                                      padding=[(i - 1) // 2 for i in kernel_size],
                                      **props['conv_op_kwargs'])
        self.norm2 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
        self.nonlin2 = props['nonlin'](**props['nonlin_kwargs'])

        # the residual needs a 1x1 projection whenever the main path changes spatial size or channel count
        if (self.stride is not None and any((i != 1 for i in self.stride))) or (in_planes != out_planes):
            stride_here = stride if stride is not None else 1
            self.downsample_skip = nn.Sequential(props['conv_op'](in_planes, out_planes, 1, stride_here, bias=False),
                                                 props['norm_op'](out_planes, **props['norm_op_kwargs']))
        else:
            # Identity is a proper nn.Module, so (unlike a lambda) the block stays picklable
            self.downsample_skip = Identity()

    def forward(self, x):
        residual = x

        out = self.dropout(self.conv1(x))
        out = self.nonlin1(self.norm1(out))

        out = self.norm2(self.conv2(out))

        residual = self.downsample_skip(residual)

        out += residual

        return self.nonlin2(out)


class ResidualBottleneckBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, props, stride=None):
        """
        Bottleneck residual block: 1x1 conv (in_planes -> out_planes // 4), kernel_size conv, 1x1 conv
        (out_planes // 4 -> out_planes), each followed by norm. The skip connection is added before the last
        nonlinearity. Dropout is not supported.

        :param in_planes: number of input channels
        :param out_planes: number of output channels. The bottleneck width is out_planes // 4
        :param kernel_size: iterable with one kernel extent per spatial dimension (used by the middle convolution)
        :param props: dict with the keys 'conv_op', 'conv_op_kwargs', 'norm_op', 'norm_op_kwargs', 'nonlin',
            'nonlin_kwargs' and (optionally) 'dropout_op' / 'dropout_op_kwargs'. It is copied, the caller's dict
            is not modified
        :param stride: optional iterable of per-axis strides for the first convolution (and the skip projection).
            None means stride 1
        :raises NotImplementedError: if props configures a dropout op with p > 0
        """
        super().__init__()

        # The previous check ('is None and ... > 0') could never raise the intended error and crashed with a
        # TypeError if dropout_op_kwargs was None.
        dropout_kwargs = props.get('dropout_op_kwargs')
        if props.get('dropout_op') is not None and dropout_kwargs and dropout_kwargs.get('p', 0) > 0:
            raise NotImplementedError("ResidualBottleneckBlock does not yet support dropout!")

        # props is a dict and mutable. We overwrite the stride below, so work on a private copy instead of
        # silently modifying the dict of the caller.
        props = deepcopy(props)

        self.kernel_size = kernel_size
        props['conv_op_kwargs']['stride'] = 1

        self.stride = stride
        self.props = props
        self.out_planes = out_planes
        self.in_planes = in_planes
        self.bottleneck_planes = out_planes // 4

        if stride is not None:
            kwargs_conv1 = deepcopy(props['conv_op_kwargs'])
            kwargs_conv1['stride'] = stride
        else:
            kwargs_conv1 = props['conv_op_kwargs']

        self.conv1 = props['conv_op'](in_planes, self.bottleneck_planes, [1 for _ in kernel_size],
                                      padding=[0 for i in kernel_size],
                                      **kwargs_conv1)
        self.norm1 = props['norm_op'](self.bottleneck_planes, **props['norm_op_kwargs'])
        self.nonlin1 = props['nonlin'](**props['nonlin_kwargs'])

        self.conv2 = props['conv_op'](self.bottleneck_planes, self.bottleneck_planes, kernel_size,
                                      padding=[(i - 1) // 2 for i in kernel_size],
                                      **props['conv_op_kwargs'])
        self.norm2 = props['norm_op'](self.bottleneck_planes, **props['norm_op_kwargs'])
        self.nonlin2 = props['nonlin'](**props['nonlin_kwargs'])

        self.conv3 = props['conv_op'](self.bottleneck_planes, out_planes, [1 for _ in kernel_size],
                                      padding=[0 for i in kernel_size],
                                      **props['conv_op_kwargs'])
        self.norm3 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
        self.nonlin3 = props['nonlin'](**props['nonlin_kwargs'])

        # the residual needs a 1x1 projection whenever the main path changes spatial size or channel count
        if (self.stride is not None and any((i != 1 for i in self.stride))) or (in_planes != out_planes):
            stride_here = stride if stride is not None else 1
            self.downsample_skip = nn.Sequential(props['conv_op'](in_planes, out_planes, 1, stride_here, bias=False),
                                                 props['norm_op'](out_planes, **props['norm_op_kwargs']))
        else:
            # Identity is a proper nn.Module, so (unlike a lambda) the block stays picklable
            self.downsample_skip = Identity()

    def forward(self, x):
        residual = x

        out = self.nonlin1(self.norm1(self.conv1(x)))
        out = self.nonlin2(self.norm2(self.conv2(out)))

        out = self.norm3(self.conv3(out))

        residual = self.downsample_skip(residual)

        out += residual

        return self.nonlin3(out)


class ResidualLayer(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, network_props, num_blocks, first_stride=None,
                 block=BasicResidualBlock):
        """
        Sequence of num_blocks residual blocks. Only the first block maps input_channels to output_channels and
        receives first_stride; the remaining blocks are built without a stride argument.

        :param input_channels: number of channels fed into the first block
        :param output_channels: number of channels produced by every block
        :param kernel_size: iterable with one kernel extent per spatial dimension
        :param network_props: dict that is handed to each block as its props argument
        :param num_blocks: total number of blocks in the layer
        :param first_stride: stride passed to the first block only (None means no stride override)
        :param block: block class to instantiate, called as block(in, out, kernel_size, props[, stride])
        """
        super().__init__()

        network_props = deepcopy(network_props)  # network_props is a dict and mutable, so we deepcopy to be safe.

        self.convs = nn.Sequential(
            block(input_channels, output_channels, kernel_size, network_props, first_stride),
            *[block(output_channels, output_channels, kernel_size, network_props)
              for _ in range(num_blocks - 1)]
        )

    def forward(self, x):
        return self.convs(x)