|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from copy import deepcopy |
|
from nnunet.network_architecture.custom_modules.helperModules import Identity |
|
from torch import nn |
|
|
|
|
|
class ConvDropoutNormReLU(nn.Module):
    """One conv -> dropout -> norm -> nonlinearity unit configured from a props dict."""

    def __init__(self, input_channels, output_channels, kernel_size, network_props):
        """
        Build the four stages from ``network_props``.

        A ``None`` value in network_props['dropout_op'] disables dropout, and a ``None`` value in
        network_props['norm_op'] disables normalization; both are replaced by Identity in that case.

        :param input_channels: number of channels fed into the convolution
        :param output_channels: number of channels produced by the convolution (also passed to norm_op)
        :param kernel_size: per-axis kernel sizes; iterated to derive the per-axis padding
        :param network_props: dict with 'conv_op', 'conv_op_kwargs', 'dropout_op', 'dropout_op_kwargs',
            'norm_op', 'norm_op_kwargs', 'nonlin' and 'nonlin_kwargs'. It is deep-copied, so the
            caller's dict is never modified.
        """
        super().__init__()

        props = deepcopy(network_props)

        # (k - 1) // 2 per axis: keeps the spatial size for odd kernels at stride 1
        # (assuming conv_op is a standard torch convolution)
        same_padding = [(k - 1) // 2 for k in kernel_size]
        self.conv = props['conv_op'](input_channels, output_channels, kernel_size,
                                     padding=same_padding, **props['conv_op_kwargs'])

        dropout_op = props['dropout_op']
        self.do = Identity() if dropout_op is None else dropout_op(**props['dropout_op_kwargs'])

        norm_op = props['norm_op']
        self.norm = Identity() if norm_op is None else norm_op(output_channels, **props['norm_op_kwargs'])

        self.nonlin = props['nonlin'](**props['nonlin_kwargs'])

        # The stages are also exposed as attributes above; forward() runs them through this Sequential.
        self.all = nn.Sequential(self.conv, self.do, self.norm, self.nonlin)

    def forward(self, x):
        """Apply conv, dropout, norm and nonlinearity in that order."""
        return self.all(x)
|
|
|
|
|
class StackedConvLayers(nn.Module):
    """A chain of ``num_convs`` ConvDropoutNormReLU units."""

    def __init__(self, input_channels, output_channels, kernel_size, network_props, num_convs, first_stride=None):
        """
        Build the chain of units.

        The first unit maps ``input_channels`` -> ``output_channels`` and, when ``first_stride`` is not None,
        uses it as its convolution stride. The remaining ``num_convs - 1`` units map
        ``output_channels`` -> ``output_channels`` with network_props['conv_op_kwargs'] as given.
        The first unit is always built, even if ``num_convs`` < 1.

        A ``None`` network_props['dropout_op'] means no dropout; a ``None`` network_props['norm_op'] means no norm.

        :param input_channels: channels entering the first unit
        :param output_channels: channels produced by every unit
        :param kernel_size: per-axis kernel sizes shared by all units
        :param network_props: props dict forwarded to every unit (deep-copied here, caller's dict untouched)
        :param num_convs: total number of units in the chain
        :param first_stride: optional stride override for the first unit only
        """
        super().__init__()

        props = deepcopy(network_props)
        first_props = deepcopy(props)
        if first_stride is not None:
            first_props['conv_op_kwargs']['stride'] = first_stride

        units = [ConvDropoutNormReLU(input_channels, output_channels, kernel_size, first_props)]
        units.extend(
            ConvDropoutNormReLU(output_channels, output_channels, kernel_size, props)
            for _ in range(num_convs - 1)
        )
        self.convs = nn.Sequential(*units)

    def forward(self, x):
        """Run the input through every unit in order."""
        return self.convs(x)
|
|
|
|
|
class BasicResidualBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, props, stride=None):
        """
        This is the conv bn nonlin conv bn nonlin kind of block:
        conv1 -> dropout -> norm1 -> nonlin1 -> conv2 -> norm2, added to the skip path, then nonlin2.

        :param in_planes: number of input channels
        :param out_planes: number of output channels
        :param kernel_size: per-axis kernel sizes of conv1 and conv2 (iterated to compute the padding)
        :param props: dict with 'conv_op', 'conv_op_kwargs', 'norm_op', 'norm_op_kwargs', 'nonlin',
            'nonlin_kwargs', 'dropout_op' and 'dropout_op_kwargs'. It is deep-copied, so the caller's
            dict (including props['conv_op_kwargs']['stride']) is left untouched. A ``None`` dropout_op
            or a dropout p of 0 means no dropout.
        :param stride: optional stride for conv1 and for the skip projection, either per axis
            (list/tuple) or a single int. If None, every convolution in the block uses stride 1.
        """
        super().__init__()

        # Private copy so the stride override below does not modify the caller's dict.
        props = deepcopy(props)

        self.kernel_size = kernel_size
        # conv2 (and conv1 unless `stride` is given) always runs at stride 1
        props['conv_op_kwargs']['stride'] = 1

        self.stride = stride
        self.props = props
        self.out_planes = out_planes
        self.in_planes = in_planes

        # (k - 1) // 2 per axis: keeps the spatial size for odd kernels at stride 1
        padding = [(i - 1) // 2 for i in kernel_size]

        kwargs_conv1 = deepcopy(props['conv_op_kwargs'])
        if stride is not None:
            kwargs_conv1['stride'] = stride

        self.conv1 = props['conv_op'](in_planes, out_planes, kernel_size, padding=padding, **kwargs_conv1)
        self.norm1 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
        self.nonlin1 = props['nonlin'](**props['nonlin_kwargs'])

        if props['dropout_op_kwargs']['p'] != 0 and props['dropout_op'] is not None:
            self.dropout = props['dropout_op'](**props['dropout_op_kwargs'])
        else:
            self.dropout = Identity()

        self.conv2 = props['conv_op'](out_planes, out_planes, kernel_size, padding=padding,
                                      **props['conv_op_kwargs'])
        self.norm2 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
        self.nonlin2 = props['nonlin'](**props['nonlin_kwargs'])

        # Accept a scalar stride as well as a per-axis sequence when deciding whether to project.
        stride_per_axis = (stride,) if isinstance(stride, int) else stride
        if (stride is not None and any(s != 1 for s in stride_per_axis)) or (in_planes != out_planes):
            stride_here = stride if stride is not None else 1
            # 1x1 conv + norm so the skip path matches the main path's shape
            self.downsample_skip = nn.Sequential(props['conv_op'](in_planes, out_planes, 1, stride_here, bias=False),
                                                 props['norm_op'](out_planes, **props['norm_op_kwargs']))
        else:
            # A module instead of `lambda x: x`: lambdas cannot be pickled, which breaks pickling
            # the whole model (e.g. torch.save(model)).
            self.downsample_skip = Identity()

    def forward(self, x):
        residual = x

        out = self.dropout(self.conv1(x))
        out = self.nonlin1(self.norm1(out))

        out = self.norm2(self.conv2(out))

        # identity, or 1x1 conv + norm when stride/channels change
        residual = self.downsample_skip(residual)

        out += residual

        return self.nonlin2(out)
|
|
|
|
|
class ResidualBottleneckBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, props, stride=None):
        """
        Bottleneck residual block:
        1x1 conv1 (in_planes -> out_planes // 4) -> norm1 -> nonlin1 ->
        conv2 with kernel_size (out_planes // 4 -> out_planes // 4) -> norm2 -> nonlin2 ->
        1x1 conv3 (-> out_planes) -> norm3, added to the skip path, then nonlin3.

        :param in_planes: number of input channels
        :param out_planes: number of output channels (bottleneck width is out_planes // 4)
        :param kernel_size: per-axis kernel sizes of conv2; its length also sets the rank of the 1x1 convs
        :param props: dict with 'conv_op', 'conv_op_kwargs', 'norm_op', 'norm_op_kwargs', 'nonlin',
            'nonlin_kwargs', 'dropout_op' and 'dropout_op_kwargs'. It is deep-copied, so the caller's
            dict (including props['conv_op_kwargs']['stride']) is left untouched.
        :param stride: optional stride for conv1 and for the skip projection, either per axis
            (list/tuple) or a single int. If None, every convolution in the block uses stride 1.
        :raises NotImplementedError: if dropout is requested (dropout_op is not None and
            dropout_op_kwargs['p'] > 0); this block has no dropout layer.
        """
        super().__init__()

        # Refuse requested dropout instead of silently ignoring it.
        dropout_kwargs = props.get('dropout_op_kwargs')
        if props.get('dropout_op') is not None and dropout_kwargs is not None and dropout_kwargs.get('p', 0) > 0:
            raise NotImplementedError("ResidualBottleneckBlock does not yet support dropout!")

        # Private copy so the stride override below does not modify the caller's dict.
        props = deepcopy(props)

        self.kernel_size = kernel_size
        # conv2/conv3 (and conv1 unless `stride` is given) always run at stride 1
        props['conv_op_kwargs']['stride'] = 1

        self.stride = stride
        self.props = props
        self.out_planes = out_planes
        self.in_planes = in_planes
        self.bottleneck_planes = out_planes // 4

        kwargs_conv1 = deepcopy(props['conv_op_kwargs'])
        if stride is not None:
            kwargs_conv1['stride'] = stride

        pointwise_kernel = [1 for _ in kernel_size]
        no_padding = [0 for _ in kernel_size]

        self.conv1 = props['conv_op'](in_planes, self.bottleneck_planes, pointwise_kernel, padding=no_padding,
                                      **kwargs_conv1)
        self.norm1 = props['norm_op'](self.bottleneck_planes, **props['norm_op_kwargs'])
        self.nonlin1 = props['nonlin'](**props['nonlin_kwargs'])

        self.conv2 = props['conv_op'](self.bottleneck_planes, self.bottleneck_planes, kernel_size,
                                      padding=[(i - 1) // 2 for i in kernel_size], **props['conv_op_kwargs'])
        self.norm2 = props['norm_op'](self.bottleneck_planes, **props['norm_op_kwargs'])
        self.nonlin2 = props['nonlin'](**props['nonlin_kwargs'])

        self.conv3 = props['conv_op'](self.bottleneck_planes, out_planes, pointwise_kernel, padding=no_padding,
                                      **props['conv_op_kwargs'])
        self.norm3 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
        self.nonlin3 = props['nonlin'](**props['nonlin_kwargs'])

        # Accept a scalar stride as well as a per-axis sequence when deciding whether to project.
        stride_per_axis = (stride,) if isinstance(stride, int) else stride
        if (stride is not None and any(s != 1 for s in stride_per_axis)) or (in_planes != out_planes):
            stride_here = stride if stride is not None else 1
            # 1x1 conv + norm so the skip path matches the main path's shape
            self.downsample_skip = nn.Sequential(props['conv_op'](in_planes, out_planes, 1, stride_here, bias=False),
                                                 props['norm_op'](out_planes, **props['norm_op_kwargs']))
        else:
            # A module instead of `lambda x: x`: lambdas cannot be pickled, which breaks pickling
            # the whole model (e.g. torch.save(model)).
            self.downsample_skip = Identity()

    def forward(self, x):
        residual = x

        out = self.nonlin1(self.norm1(self.conv1(x)))
        out = self.nonlin2(self.norm2(self.conv2(out)))

        out = self.norm3(self.conv3(out))

        # identity, or 1x1 conv + norm when stride/channels change
        residual = self.downsample_skip(residual)

        out += residual

        return self.nonlin3(out)
|
|
|
|
|
class ResidualLayer(nn.Module):
    """A chain of ``num_blocks`` residual blocks of type ``block``."""

    def __init__(self, input_channels, output_channels, kernel_size, network_props, num_blocks, first_stride=None, block=BasicResidualBlock):
        """
        Build the chain of residual blocks.

        Only the first block receives ``first_stride`` and maps ``input_channels`` -> ``output_channels``.
        The remaining ``num_blocks - 1`` blocks map ``output_channels`` -> ``output_channels`` without an
        explicit stride. All blocks share one deep copy of ``network_props``; the caller's dict is never
        handed to them. The first block is always built, even if ``num_blocks`` < 1.

        :param input_channels: channels entering the first block
        :param output_channels: channels produced by every block
        :param kernel_size: per-axis kernel sizes passed to every block
        :param network_props: props dict passed (as a deep copy) to every block
        :param num_blocks: total number of blocks
        :param first_stride: optional stride for the first block only
        :param block: residual block class to instantiate
        """
        super().__init__()

        shared_props = deepcopy(network_props)

        head = block(input_channels, output_channels, kernel_size, shared_props, first_stride)
        tail = [block(output_channels, output_channels, kernel_size, shared_props) for _ in range(num_blocks - 1)]
        self.convs = nn.Sequential(head, *tail)

    def forward(self, x):
        """Run the input through every block in order."""
        return self.convs(x)
|
|
|
|