| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from HD_BET.utils import softmax_helper | |
class EncodingModule(nn.Module):
    """Pre-activation residual block used on the encoder path.

    Computes ``x + conv2(lrelu(bn_2(dropout(conv1(lrelu(bn_1(x)))))))``, where
    the dropout step is applied only when ``dropout_p`` is a positive number.
    Because the unmodified input is added back at the end, the residual sum
    only works when ``in_channels == out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
        filter_size: Cubic kernel size. Padding is ``(filter_size - 1) // 2``
            with stride 1, so spatial size is preserved for odd sizes.
        dropout_p: ``nn.Dropout3d`` probability; ``None`` or ``0`` disables dropout.
        leakiness: Negative slope of the LeakyReLU activations.
        conv_bias: Whether the convolutions carry a bias term.
        inst_norm_affine: Whether the InstanceNorm layers have learnable affine parameters.
        lrelu_inplace: Whether LeakyReLU is computed in place.
    """

    def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True,
                 inst_norm_affine=True, lrelu_inplace=True):
        nn.Module.__init__(self)
        self.dropout_p = dropout_p
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        padding = (filter_size - 1) // 2
        self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
        self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, padding, bias=self.conv_bias)
        # nn.Dropout3d rejects p=None; forward() skips dropout in that case, so a
        # zero-probability placeholder keeps the attribute present and harmless.
        self.dropout = nn.Dropout3d(dropout_p if dropout_p is not None else 0.0)
        # bn_2 normalizes conv1's output, which has out_channels channels.
        self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, padding, bias=self.conv_bias)

    def forward(self, x):
        skip = x
        # bn_1 returns a new tensor, so the in-place activation does not touch `skip`.
        x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        x = self.conv1(x)
        if self.dropout_p is not None and self.dropout_p > 0:
            x = self.dropout(x)
        x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        x = self.conv2(x)
        x = x + skip
        return x
class Upsample(nn.Module):
    """Module wrapper around ``torch.nn.functional.interpolate``.

    Args:
        size: Target output spatial size (mutually exclusive with ``scale_factor``).
        scale_factor: Multiplier for the spatial size.
        mode: Interpolation mode forwarded to ``interpolate``.
        align_corners: Forwarded to ``interpolate`` for interpolating modes.
            It is ignored (passed as ``None``) for ``nearest``, ``area`` and
            ``nearest-exact``, because ``interpolate`` raises if it is set for those.
    """

    # Modes for which F.interpolate refuses any non-None align_corners value.
    _NON_ALIGNABLE_MODES = ("nearest", "area", "nearest-exact")

    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True):
        super(Upsample, self).__init__()
        self.align_corners = align_corners
        self.mode = mode
        self.scale_factor = scale_factor
        self.size = size

    def forward(self, x):
        align_corners = None if self.mode in self._NON_ALIGNABLE_MODES else self.align_corners
        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
                                         align_corners=align_corners)
class LocalizationModule(nn.Module):
    """Two conv -> InstanceNorm -> LeakyReLU stages.

    Stage 1 is a 3x3x3 convolution (stride 1, padding 1) that keeps
    ``in_channels``. Stage 2 is a 1x1x1 convolution that maps to
    ``out_channels``. Spatial size is unchanged by either stage.
    """

    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True):
        super().__init__()
        self.leakiness = leakiness
        self.conv_bias = conv_bias
        self.inst_norm_affine = inst_norm_affine
        self.lrelu_inplace = lrelu_inplace
        # Registration order (conv1, bn_1, conv2, bn_2) is kept so parameter
        # initialization consumes the RNG in the same sequence.
        self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=conv_bias)
        self.bn_1 = nn.InstanceNorm3d(in_channels, affine=inst_norm_affine, track_running_stats=True)
        self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=conv_bias)
        self.bn_2 = nn.InstanceNorm3d(out_channels, affine=inst_norm_affine, track_running_stats=True)

    def forward(self, x):
        stages = ((self.conv1, self.bn_1), (self.conv2, self.bn_2))
        out = x
        for conv, norm in stages:
            out = F.leaky_relu(norm(conv(out)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        return out
class UpsamplingModule(nn.Module):
    """Double the spatial size, then conv -> InstanceNorm -> LeakyReLU.

    Upsampling is trilinear with ``align_corners=True`` and a scale factor of 2.
    It is followed by a 3x3x3 convolution (stride 1, padding 1) from
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True):
        super().__init__()
        self.leakiness = leakiness
        self.conv_bias = conv_bias
        self.inst_norm_affine = inst_norm_affine
        self.lrelu_inplace = lrelu_inplace
        self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True)
        self.upsample_conv = nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=conv_bias)
        self.bn = nn.InstanceNorm3d(out_channels, affine=inst_norm_affine, track_running_stats=True)

    def forward(self, x):
        enlarged = self.upsample(x)
        normalized = self.bn(self.upsample_conv(enlarged))
        return F.leaky_relu(normalized, negative_slope=self.leakiness, inplace=self.lrelu_inplace)
class DownsamplingModule(nn.Module):
    """InstanceNorm -> LeakyReLU, then a stride-2 3x3x3 convolution.

    ``forward`` returns a pair ``(activated, downsampled)``. ``activated`` is
    the normalized and activated input at full resolution, which the network
    keeps as a skip connection. ``downsampled`` is the stride-2 convolution
    of it, with ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True):
        super().__init__()
        self.leakiness = leakiness
        self.conv_bias = conv_bias
        self.inst_norm_affine = inst_norm_affine
        self.lrelu_inplace = lrelu_inplace
        self.bn = nn.InstanceNorm3d(in_channels, affine=inst_norm_affine, track_running_stats=True)
        self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=conv_bias)

    def forward(self, x):
        activated = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        return activated, self.downsample(activated)
class Network(nn.Module):
    """3D encoder/decoder segmentation network with residual encoder blocks.

    The encoder has five ``EncodingModule`` stages separated by four
    stride-2 ``DownsamplingModule`` s, with channel widths ``base_filters * (1, 2, 4, 8, 16)``.
    The decoder upsamples four times, concatenating the matching encoder
    skip at each level. For the concatenations to line up, every spatial
    dimension of the input must be divisible by 16.

    Segmentation heads:
        * ``loc2_seg`` at 1/4 of the input resolution.
        * ``loc3_seg`` at 1/2 of the input resolution.
        * ``seg_layer`` at full resolution.

    ``final_nonlin`` is applied to every head that is computed.

    Args:
        num_classes: Output channels of every segmentation head.
        num_input_channels: Channels of the input volume.
        base_filters: Channel width of the first stage.
        dropout_p: Dropout probability used inside every ``EncodingModule``.
        final_nonlin: Callable applied to each head's logits.
        leakiness, conv_bias, inst_norm_affine, lrelu_inplace: Forwarded to
            every submodule and to the layers created directly here.
        do_ds: If True, ``forward`` returns ``[full_res, half_res, quarter_res]``.
            Otherwise it returns only the full-resolution output, and the
            auxiliary heads are not evaluated.
    """

    def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3,
                 final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True, do_ds=True):
        super(Network, self).__init__()
        self.do_ds = do_ds
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.final_nonlin = final_nonlin

        # Forward the constructor's configuration to every submodule instead of
        # hard-coding the defaults, so non-default arguments actually take effect.
        common = dict(leakiness=leakiness, conv_bias=conv_bias, inst_norm_affine=inst_norm_affine,
                      lrelu_inplace=lrelu_inplace)
        bf = base_filters

        # Registration order is kept identical so state_dict keys and parameter
        # initialization order match previously trained checkpoints.
        self.init_conv = nn.Conv3d(num_input_channels, bf, 3, 1, 1, bias=self.conv_bias)
        self.context1 = EncodingModule(bf, bf, 3, dropout_p, **common)
        self.down1 = DownsamplingModule(bf, bf * 2, **common)
        self.context2 = EncodingModule(2 * bf, 2 * bf, 3, dropout_p, **common)
        self.down2 = DownsamplingModule(2 * bf, bf * 4, **common)
        self.context3 = EncodingModule(4 * bf, 4 * bf, 3, dropout_p, **common)
        self.down3 = DownsamplingModule(4 * bf, bf * 8, **common)
        self.context4 = EncodingModule(8 * bf, 8 * bf, 3, dropout_p, **common)
        self.down4 = DownsamplingModule(8 * bf, bf * 16, **common)
        self.context5 = EncodingModule(16 * bf, 16 * bf, 3, dropout_p, **common)
        self.bn_after_context5 = nn.InstanceNorm3d(16 * bf, affine=self.inst_norm_affine, track_running_stats=True)

        # Each LocalizationModule takes upsampled + skip channels (concatenated)
        # and halves them.
        self.up1 = UpsamplingModule(16 * bf, 8 * bf, **common)
        self.loc1 = LocalizationModule(16 * bf, 8 * bf, **common)
        self.up2 = UpsamplingModule(8 * bf, 4 * bf, **common)
        self.loc2 = LocalizationModule(8 * bf, 4 * bf, **common)
        self.loc2_seg = nn.Conv3d(4 * bf, num_classes, 1, 1, 0, bias=False)
        self.up3 = UpsamplingModule(4 * bf, 2 * bf, **common)
        self.loc3 = LocalizationModule(4 * bf, 2 * bf, **common)
        self.loc3_seg = nn.Conv3d(2 * bf, num_classes, 1, 1, 0, bias=False)
        self.up4 = UpsamplingModule(2 * bf, 1 * bf, **common)
        self.end_conv_1 = nn.Conv3d(2 * bf, 2 * bf, 3, 1, 1, bias=self.conv_bias)
        self.end_conv_1_bn = nn.InstanceNorm3d(2 * bf, affine=self.inst_norm_affine, track_running_stats=True)
        self.end_conv_2 = nn.Conv3d(2 * bf, 2 * bf, 3, 1, 1, bias=self.conv_bias)
        self.end_conv_2_bn = nn.InstanceNorm3d(2 * bf, affine=self.inst_norm_affine, track_running_stats=True)
        self.seg_layer = nn.Conv3d(2 * bf, num_classes, 1, 1, 0, bias=False)

    def forward(self, x):
        seg_outputs = []

        # Encoder: each DownsamplingModule returns (full-res activation kept as skip, downsampled tensor).
        x = self.init_conv(x)
        x = self.context1(x)
        skip1, x = self.down1(x)
        x = self.context2(x)
        skip2, x = self.down2(x)
        x = self.context3(x)
        skip3, x = self.down3(x)
        x = self.context4(x)
        skip4, x = self.down4(x)
        x = self.context5(x)
        x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)

        # Decoder.
        x = self.up1(x)
        x = torch.cat((skip4, x), dim=1)
        x = self.loc1(x)

        x = self.up2(x)
        x = torch.cat((skip3, x), dim=1)
        x = self.loc2(x)
        if self.do_ds:
            # Auxiliary head is only needed for deep supervision; skip it otherwise.
            seg_outputs.append(self.final_nonlin(self.loc2_seg(x)))

        x = self.up3(x)
        x = torch.cat((skip2, x), dim=1)
        x = self.loc3(x)
        if self.do_ds:
            seg_outputs.append(self.final_nonlin(self.loc3_seg(x)))

        x = self.up4(x)
        x = torch.cat((skip1, x), dim=1)
        x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness,
                         inplace=self.lrelu_inplace)
        x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness,
                         inplace=self.lrelu_inplace)
        x = self.final_nonlin(self.seg_layer(x))
        seg_outputs.append(x)

        if self.do_ds:
            # Highest resolution first: [full, 1/2, 1/4].
            return seg_outputs[::-1]
        return seg_outputs[-1]