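# Residue from the hosting web page, kept as comments so the file remains valid Python:
# ho11laqe's picture
# init
# ecf08bc
# raw
# history blame
# 20 kB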
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from nnunet.network_architecture.custom_modules.conv_blocks import StackedConvLayers
from nnunet.network_architecture.generic_UNet import Upsample
from nnunet.network_architecture.neural_network import SegmentationNetwork
from nnunet.training.loss_functions.dice_loss import DC_and_CE_loss
from torch import nn
import numpy as np
from torch.optim import SGD
"""
The idea behind this modular U-net ist that we decouple encoder and decoder and thus make things a) a lot more easy to
combine and b) enable easy swapping between segmentation or classification mode of the same architecture
"""
def get_default_network_config(dim=2, dropout_p=None, nonlin="LeakyReLU", norm_type="bn"):
    """
    Return a dictionary with pointers to the conv, dropout, norm and nonlin ops plus the default kwargs I like to use.

    Keys: 'conv_op', 'dropout_op', 'norm_op', 'nonlin' hold layer classes (not instances); 'conv_op_kwargs',
    'dropout_op_kwargs', 'norm_op_kwargs', 'nonlin_kwargs' hold the kwargs to instantiate them with. The
    convolution kernel size is not included, it will be set by the network.

    :param dim: spatial dimensionality, 2 or 3
    :param dropout_p: dropout probability. None disables dropout ('dropout_op' is None, p is 0)
    :param nonlin: "LeakyReLU" or "ReLU"
    :param norm_type: "bn" (batch norm) or "in" (instance norm)
    :return: dict as described above
    :raises NotImplementedError: if dim or norm_type is not supported
    :raises ValueError: if nonlin is not supported
    """
    if dim == 2:
        conv_op, dropout_op = nn.Conv2d, nn.Dropout2d
    elif dim == 3:
        conv_op, dropout_op = nn.Conv3d, nn.Dropout3d
    else:
        raise NotImplementedError("dim must be 2 or 3, got %s" % str(dim))

    if norm_type == "bn":
        norm_op = nn.BatchNorm2d if dim == 2 else nn.BatchNorm3d
    elif norm_type == "in":
        norm_op = nn.InstanceNorm2d if dim == 2 else nn.InstanceNorm3d
    else:
        raise NotImplementedError("norm_type must be 'bn' or 'in', got %s" % str(norm_type))

    props = {
        'conv_op': conv_op,
        # no dropout probability given -> no dropout layer at all
        'dropout_op': None if dropout_p is None else dropout_op,
        'norm_op': norm_op,
        'norm_op_kwargs': {'eps': 1e-5, 'affine': True},
        'dropout_op_kwargs': {'p': 0 if dropout_p is None else dropout_p, 'inplace': True},
        # kernel size will be set by the network!
        'conv_op_kwargs': {'stride': 1, 'dilation': 1, 'bias': True},
    }

    if nonlin == "LeakyReLU":
        props['nonlin'] = nn.LeakyReLU
        props['nonlin_kwargs'] = {'negative_slope': 1e-2, 'inplace': True}
    elif nonlin == "ReLU":
        props['nonlin'] = nn.ReLU
        props['nonlin_kwargs'] = {'inplace': True}
    else:
        raise ValueError("nonlin must be 'LeakyReLU' or 'ReLU', got %s" % str(nonlin))

    return props
class PlainConvUNetEncoder(nn.Module):
    def __init__(self, input_channels, base_num_features, num_blocks_per_stage, feat_map_mul_on_downscale,
                 pool_op_kernel_sizes, conv_kernel_sizes, props, default_return_skips=True,
                 max_num_features=480):
        """
        Encoder of the modular U-Net: a sequence of StackedConvLayers stages. This one includes the bottleneck
        layer (the last stage)!
        Following UNet building blocks can be added by utilizing the properties this class exposes (TODO).

        Stage i has min(int(base_num_features * feat_map_mul_on_downscale ** i), max_num_features) output
        features. pool_op_kernel_sizes[i] is handed to StackedConvLayers as its last positional argument
        (NOTE(review): presumably the stride of the stage's first conv - confirm in conv_blocks).

        :param input_channels: number of channels of the network input
        :param base_num_features: number of output features of the first stage
        :param num_blocks_per_stage: conv blocks per stage; a single value is used for every stage, a list/tuple
            must have one entry per stage
        :param feat_map_mul_on_downscale: factor by which the number of features grows from stage to stage
        :param pool_op_kernel_sizes: one entry per stage
        :param conv_kernel_sizes: one entry per stage (same length as pool_op_kernel_sizes)
        :param props: layer configuration dict, e.g. from get_default_network_config
        :param default_return_skips: what forward returns when return_skips is not given: all stage outputs
            (True) or only the last one (False)
        :param max_num_features: upper bound on the number of features of any stage
        """
        super(PlainConvUNetEncoder, self).__init__()

        self.default_return_skips = default_return_skips
        self.props = props

        self.stages = []
        self.stage_output_features = []
        self.stage_pool_kernel_size = []
        self.stage_conv_op_kernel_size = []

        assert len(pool_op_kernel_sizes) == len(conv_kernel_sizes)

        num_stages = len(conv_kernel_sizes)

        if not isinstance(num_blocks_per_stage, (list, tuple)):
            num_blocks_per_stage = [num_blocks_per_stage] * num_stages
        else:
            assert len(num_blocks_per_stage) == num_stages

        self.num_blocks_per_stage = num_blocks_per_stage  # the decoder reads this to derive its own defaults

        current_input_features = input_channels
        for stage in range(num_stages):
            current_output_features = min(int(base_num_features * feat_map_mul_on_downscale ** stage),
                                          max_num_features)
            current_kernel_size = conv_kernel_sizes[stage]
            current_pool_kernel_size = pool_op_kernel_sizes[stage]

            current_stage = StackedConvLayers(current_input_features, current_output_features, current_kernel_size,
                                              props, num_blocks_per_stage[stage], current_pool_kernel_size)

            self.stages.append(current_stage)
            self.stage_output_features.append(current_output_features)
            self.stage_conv_op_kernel_size.append(current_kernel_size)
            self.stage_pool_kernel_size.append(current_pool_kernel_size)

            # the next stage consumes this stage's output
            current_input_features = current_output_features

        self.stages = nn.ModuleList(self.stages)
        self.output_features = current_output_features

    def forward(self, x, return_skips=None):
        """
        Run x through all stages.

        :param x: network input
        :param return_skips: True -> return the list of all stage outputs, ordered from the first stage to the
            bottleneck; False -> return only the bottleneck output; None -> use self.default_return_skips
        :return: list of stage outputs, or the output of the last stage
        """
        # resolve the flag first so that collecting the skips follows the caller's request, not the default
        if return_skips is None:
            return_skips = self.default_return_skips

        skips = []
        for s in self.stages:
            x = s(x)
            if return_skips:
                skips.append(x)

        return skips if return_skips else x

    @staticmethod
    def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                        num_modalities, pool_op_kernel_sizes, num_blocks_per_stage_encoder,
                                        feat_map_mul_on_downscale, batch_size):
        """
        Not real VRAM consumption, just a term to which the encoder's activation memory is approximately
        proportional: the input (num_modalities * voxels) plus, for every stage, num_blocks * voxels * features,
        all multiplied by batch_size.

        The first stage is evaluated at patch_size directly (pool_op_kernel_sizes[0] is not applied; presumably
        it is all ones). Each later stage divides the shape by its pool kernel size and multiplies the feature
        count by feat_map_mul_on_downscale, capped at max_num_features.

        :return: the (unitless) consumption term
        """
        npool = len(pool_op_kernel_sizes) - 1

        current_shape = np.array(patch_size)
        tmp = num_blocks_per_stage_encoder[0] * np.prod(current_shape) * base_num_features \
            + num_modalities * np.prod(current_shape)

        num_feat = base_num_features

        for p in range(1, npool + 1):
            current_shape = current_shape / np.array(pool_op_kernel_sizes[p])
            num_feat = min(num_feat * feat_map_mul_on_downscale, max_num_features)
            num_convs = num_blocks_per_stage_encoder[p]
            tmp += num_convs * np.prod(current_shape) * num_feat
        return tmp * batch_size
class PlainConvUNetDecoder(nn.Module):
    def __init__(self, previous, num_classes, num_blocks_per_stage=None, network_props=None, deep_supervision=False,
                 upscale_logits=False):
        """
        Decoder of the modular U-Net, built to mirror an encoder.

        We assume the bottleneck is part of the encoder, so each decoder stage starts with a transposed conv
        (upsample) -> concat with the matching skip -> StackedConvLayers.

        :param previous: the encoder; its stages, stage_output_features, stage_pool_kernel_size,
            stage_conv_op_kernel_size, num_blocks_per_stage and props attributes are read
        :param num_classes: number of output channels of the segmentation layer(s)
        :param num_blocks_per_stage: conv blocks per decoder stage, ordered from the stage right after the
            bottleneck to the highest-resolution one. Must have one entry less than the encoder's list; defaults to
            the encoder's values without the bottleneck, reversed
        :param network_props: layer configuration dict; defaults to previous.props
        :param deep_supervision: if True, add 1x1 conv segmentation heads to every stage except the last one and
            make forward return a list of outputs
        :param upscale_logits: if True (with deep_supervision), each auxiliary head is followed by an Upsample by
            the cumulative pool kernel size of its encoder stage
        :raises ValueError: if props['conv_op'] is neither nn.Conv2d nor nn.Conv3d
        """
        super(PlainConvUNetDecoder, self).__init__()
        self.num_classes = num_classes
        self.deep_supervision = deep_supervision

        previous_stages = previous.stages
        previous_stage_output_features = previous.stage_output_features
        previous_stage_pool_kernel_size = previous.stage_pool_kernel_size
        previous_stage_conv_op_kernel_size = previous.stage_conv_op_kernel_size

        if network_props is None:
            self.props = previous.props
        else:
            self.props = network_props

        if self.props['conv_op'] == nn.Conv2d:
            transpconv = nn.ConvTranspose2d
            upsample_mode = "bilinear"
        elif self.props['conv_op'] == nn.Conv3d:
            transpconv = nn.ConvTranspose3d
            upsample_mode = "trilinear"
        else:
            raise ValueError("unknown convolution dimensionality, conv op: %s" % str(self.props['conv_op']))

        if num_blocks_per_stage is None:
            num_blocks_per_stage = previous.num_blocks_per_stage[:-1][::-1]

        assert len(num_blocks_per_stage) == len(previous.num_blocks_per_stage) - 1

        self.stage_pool_kernel_size = previous_stage_pool_kernel_size
        self.stage_output_features = previous_stage_output_features
        self.stage_conv_op_kernel_size = previous_stage_conv_op_kernel_size

        # one stage less than the encoder: the encoder's last stage is the bottleneck
        num_stages = len(previous_stages) - 1

        tus = []
        stages = []
        deep_supervision_outputs = []

        # only used for upscale_logits: row s is the elementwise product of the pool kernel sizes of stages 0..s
        cum_upsample = np.cumprod(np.vstack(self.stage_pool_kernel_size), axis=0).astype(int)

        # i counts decoder stages, s is the matching encoder stage (from just below the bottleneck up to 0)
        for i, s in enumerate(np.arange(num_stages)[::-1]):
            features_below = previous_stage_output_features[s + 1]
            features_skip = previous_stage_output_features[s]

            # kernel size == stride == pool kernel size of the stage below
            tus.append(transpconv(features_below, features_skip, previous_stage_pool_kernel_size[s + 1],
                                  previous_stage_pool_kernel_size[s + 1], bias=False))
            # after we tu we concat features so now we have 2xfeatures_skip
            stages.append(StackedConvLayers(2 * features_skip, features_skip,
                                            previous_stage_conv_op_kernel_size[s], self.props,
                                            num_blocks_per_stage[i]))

            if deep_supervision and s != 0:
                # 1x1 conv, stride 1, no padding, dilation 1, groups 1, no bias
                seg_layer = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, False)
                if upscale_logits:
                    # interpolate only treats lists/tuples as per-axis factors, so hand over plain floats
                    # instead of a numpy row
                    scale_factor = tuple(float(f) for f in cum_upsample[s])
                    upsample = Upsample(scale_factor=scale_factor, mode=upsample_mode)
                    deep_supervision_outputs.append(nn.Sequential(seg_layer, upsample))
                else:
                    deep_supervision_outputs.append(seg_layer)

        # final head on the highest-resolution decoder stage (features of encoder stage 0)
        self.segmentation_output = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, False)

        self.tus = nn.ModuleList(tus)
        self.stages = nn.ModuleList(stages)
        self.deep_supervision_outputs = nn.ModuleList(deep_supervision_outputs)

    def forward(self, skips, gt=None, loss=None):
        """
        :param skips: encoder outputs, sorted so that the bottleneck is last in the list
        :param gt: optional target. With deep supervision, each output is replaced by loss(output, gt); without
            deep supervision gt is ignored
        :param loss: callable used when gt is given
        :return: without deep supervision the segmentation output; with deep supervision a list ordered so that
            the output of the highest-resolution stage is first (losses instead of outputs when gt is given)
        """
        # TUs and stages are ordered from the bottleneck upwards, the skips the other way around
        skips = skips[::-1]
        seg_outputs = []
        x = skips[0]  # this is the bottleneck
        last = len(self.tus) - 1

        for i, (tu, stage) in enumerate(zip(self.tus, self.stages)):
            x = tu(x)
            x = torch.cat((x, skips[i + 1]), dim=1)
            x = stage(x)
            # the last stage is served by segmentation_output below
            if self.deep_supervision and i != last:
                tmp = self.deep_supervision_outputs[i](x)
                if gt is not None:
                    tmp = loss(tmp, gt)
                seg_outputs.append(tmp)

        segmentation = self.segmentation_output(x)

        if not self.deep_supervision:
            return segmentation

        tmp = segmentation
        if gt is not None:
            tmp = loss(tmp, gt)
        seg_outputs.append(tmp)
        # collected from the bottleneck upwards; reverse so the highest-resolution output comes first
        return seg_outputs[::-1]

    @staticmethod
    def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                        num_classes, pool_op_kernel_sizes, num_blocks_per_stage_decoder,
                                        feat_map_mul_on_downscale, batch_size):
        """
        Not real VRAM consumption, just a term to which the decoder's activation memory is approximately
        proportional (+ offset for parameter storage). Every decoder stage counts (num_blocks + 1) feature maps
        (presumably the +1 is the transposed conv output - confirm) of voxels * features. The
        highest-resolution stage also counts the num_classes output. The bottleneck is not included
        (the encoder estimate covers it). The result is multiplied by batch_size.

        :param patch_size: spatial input size
        :param base_num_features: features of the highest-resolution stage
        :param max_num_features: cap on the number of features
        :param num_classes: number of output channels
        :param pool_op_kernel_sizes: pool kernel sizes of all encoder stages (first one presumably all ones)
        :param num_blocks_per_stage_decoder: conv blocks per decoder stage, ordered from the bottleneck upwards
        :param feat_map_mul_on_downscale: growth factor of the features per stage
        :param batch_size: batch size
        :return: the (unitless) consumption term
        """
        npool = len(pool_op_kernel_sizes) - 1

        current_shape = np.array(patch_size)
        tmp = (num_blocks_per_stage_decoder[-1] + 1) * np.prod(current_shape) * base_num_features \
            + num_classes * np.prod(current_shape)

        num_feat = base_num_features

        for p in range(1, npool):
            current_shape = current_shape / np.array(pool_op_kernel_sizes[p])
            num_feat = min(num_feat * feat_map_mul_on_downscale, max_num_features)
            num_convs = num_blocks_per_stage_decoder[-(p + 1)] + 1
            tmp += num_convs * np.prod(current_shape) * num_feat

        return tmp * batch_size
class PlainConvUNet(SegmentationNetwork):
    """
    Plain (non-residual) U-Net assembled from a PlainConvUNetEncoder and a matching PlainConvUNetDecoder.
    """
    # NOTE(review): presumably reference results of compute_reference_for_vram_consumption_2d/3d used when deriving
    # batch sizes; confirm they are still in sync with the current estimators
    use_this_for_batch_size_computation_2D = 1167982592.0
    use_this_for_batch_size_computation_3D = 1152286720.0

    def __init__(self, input_channels, base_num_features, num_blocks_per_stage_encoder, feat_map_mul_on_downscale,
                 pool_op_kernel_sizes, conv_kernel_sizes, props, num_classes, num_blocks_per_stage_decoder,
                 deep_supervision=False, upscale_logits=False, max_features=512, initializer=None):
        """
        :param input_channels: number of input channels (modalities)
        :param base_num_features: features of the first encoder stage
        :param num_blocks_per_stage_encoder: conv blocks per encoder stage (int or one entry per stage)
        :param feat_map_mul_on_downscale: feature growth factor per stage
        :param pool_op_kernel_sizes: one entry per encoder stage
        :param conv_kernel_sizes: one entry per encoder stage
        :param props: layer configuration dict, e.g. from get_default_network_config
        :param num_classes: number of output channels
        :param num_blocks_per_stage_decoder: conv blocks per decoder stage (one less entry than the encoder)
        :param deep_supervision: passed to the decoder
        :param upscale_logits: passed to the decoder
        :param max_features: cap on the number of features per encoder stage
        :param initializer: optional callable applied to every submodule via self.apply
        """
        super().__init__()
        self.conv_op = props['conv_op']
        self.num_classes = num_classes

        self.encoder = PlainConvUNetEncoder(
            input_channels, base_num_features, num_blocks_per_stage_encoder, feat_map_mul_on_downscale,
            pool_op_kernel_sizes, conv_kernel_sizes, props,
            default_return_skips=True, max_num_features=max_features,
        )
        self.decoder = PlainConvUNetDecoder(
            self.encoder, num_classes,
            num_blocks_per_stage=num_blocks_per_stage_decoder, network_props=props,
            deep_supervision=deep_supervision, upscale_logits=upscale_logits,
        )

        if initializer is not None:
            self.apply(initializer)

    def forward(self, x):
        """Encode x (collecting every stage output) and decode those skips into the segmentation output(s)."""
        return self.decoder(self.encoder(x))

    @staticmethod
    def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
                                        num_modalities, num_classes, pool_op_kernel_sizes, num_blocks_per_stage_encoder,
                                        num_blocks_per_stage_decoder, feat_map_mul_on_downscale, batch_size):
        """Approximate VRAM term of the whole network: the encoder estimate plus the decoder estimate."""
        common = dict(patch_size=patch_size, base_num_features=base_num_features,
                      max_num_features=max_num_features, pool_op_kernel_sizes=pool_op_kernel_sizes,
                      feat_map_mul_on_downscale=feat_map_mul_on_downscale, batch_size=batch_size)
        encoder_term = PlainConvUNetEncoder.compute_approx_vram_consumption(
            num_modalities=num_modalities, num_blocks_per_stage_encoder=num_blocks_per_stage_encoder, **common)
        decoder_term = PlainConvUNetDecoder.compute_approx_vram_consumption(
            num_classes=num_classes, num_blocks_per_stage_decoder=num_blocks_per_stage_decoder, **common)
        return encoder_term + decoder_term

    @staticmethod
    def compute_reference_for_vram_consumption_3d():
        """VRAM term of a reference 3D setup: 160x128x128 patch, 6 stages, 4 modalities, 3 classes, batch size 2."""
        return PlainConvUNet.compute_approx_vram_consumption(
            patch_size=(160, 128, 128), base_num_features=32, max_num_features=512,
            num_modalities=4, num_classes=3,
            pool_op_kernel_sizes=((1, 1, 1),) + ((2, 2, 2),) * 5,
            num_blocks_per_stage_encoder=(2,) * 6, num_blocks_per_stage_decoder=(2,) * 5,
            feat_map_mul_on_downscale=2, batch_size=2)

    @staticmethod
    def compute_reference_for_vram_consumption_2d():
        """VRAM term of a reference 2D setup: 256x256 patch (down to 4x4), 7 stages, 4 modalities, 3 classes,
        batch size 56."""
        return PlainConvUNet.compute_approx_vram_consumption(
            patch_size=(256, 256), base_num_features=32, max_num_features=512,
            num_modalities=4, num_classes=3,
            pool_op_kernel_sizes=((1, 1),) + ((2, 2),) * 6,
            num_blocks_per_stage_encoder=(2,) * 7, num_blocks_per_stage_decoder=(2,) * 6,
            feat_map_mul_on_downscale=2, batch_size=56)
if __name__ == "__main__":
    # GPU smoke test of a 2D configuration: build the network, run one SGD step with DC_and_CE_loss and render the
    # graph with hiddenlayer. Requires CUDA and the hiddenlayer package.
    num_stages = 7
    conv_op_kernel_sizes = ((3, 3),) * num_stages
    pool_op_kernel_sizes = ((1, 1),) + ((2, 2),) * (num_stages - 1)
    patch_size = (256, 256)
    batch_size = 56

    network_props = get_default_network_config(2, dropout_p=None)
    unet = PlainConvUNet(4, 32, (2,) * num_stages, 2, pool_op_kernel_sizes, conv_op_kernel_sizes,
                         network_props, 4, (2,) * (num_stages - 1), False, False, max_features=512).cuda()
    optimizer = SGD(unet.parameters(), lr=0.1, momentum=0.95)

    # return values are discarded here
    unet.compute_reference_for_vram_consumption_3d()
    unet.compute_reference_for_vram_consumption_2d()

    dummy_input = torch.rand((batch_size, 4, *patch_size)).cuda()
    # random integer labels in [0, 3]
    dummy_gt = (torch.rand((batch_size, 1, *patch_size)) * 4).round().clamp_(0, 3).cuda().long()

    optimizer.zero_grad()
    skips = unet.encoder(dummy_input)
    print([skip.shape for skip in skips])

    dice_kwargs = {'batch_dice': True, 'smooth': 1e-5, 'smooth_in_nom': True,
                   'do_bg': False, 'rebalance_weights': None, 'background_weight': 1}
    loss = DC_and_CE_loss(dice_kwargs, {})
    output = unet.decoder(skips)
    train_loss = loss(output, dummy_gt)
    train_loss.backward()
    optimizer.step()

    import hiddenlayer as hl
    hl.build_graph(unet, dummy_input).save("/home/fabian/test.pdf")

    # A 3D variant of this smoke test uses: conv kernels ((3, 3, 3),) * 6, pool kernels
    # ((1, 1, 1),) + ((2, 2, 2),) * 5, patch_size (160, 128, 128), batch size 2, encoder blocks (2,) * 6,
    # decoder blocks (2,) * 5 and get_default_network_config(3, dropout_p=None); everything else as above.