|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from HD_BET.utils import softmax_helper |
|
|
|
|
|
class EncodingModule(nn.Module):
    """Pre-activation residual block made of two 3D convolutions.

    The forward pass computes
    ``norm -> leaky ReLU -> conv1 -> (dropout) -> norm -> leaky ReLU -> conv2``
    and adds the block input to the result.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
        filter_size: Kernel size of both convolutions. The padding is
            ``(filter_size - 1) // 2`` with stride 1.
        dropout_p: Probability used by ``nn.Dropout3d`` after the first
            convolution. ``None`` or ``0`` disables dropout.
        leakiness: Negative slope of the leaky ReLU.
        conv_bias: Whether the convolutions carry a bias term.
        inst_norm_affine: Whether the instance norms have affine parameters.
        lrelu_inplace: Whether the leaky ReLU operates in place.
    """

    def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True,
                 inst_norm_affine=True, lrelu_inplace=True):
        nn.Module.__init__(self)
        self.dropout_p = dropout_p
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness

        padding = (filter_size - 1) // 2
        self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
        self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, padding, bias=self.conv_bias)
        # nn.Dropout3d cannot be built with p=None; forward() skips the layer
        # in that case, so any valid placeholder probability works here.
        self.dropout = nn.Dropout3d(0.0 if dropout_p is None else dropout_p)
        # bn_2 normalizes the output of conv1, which has out_channels channels.
        self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, padding, bias=self.conv_bias)

    def forward(self, x):
        """Apply the residual block to ``x`` and return ``conv2(...) + x``.

        The residual addition requires the output of ``conv2`` to be
        broadcastable with ``x`` (e.g. ``in_channels == out_channels``).
        """
        skip = x
        x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        x = self.conv1(x)
        if self.dropout_p is not None and self.dropout_p > 0:
            x = self.dropout(x)
        x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        x = self.conv2(x)
        x = x + skip
        return x
|
|
|
|
|
class Upsample(nn.Module):
    """Module wrapper around ``torch.nn.functional.interpolate``.

    Args:
        size: Target output size, forwarded to ``interpolate``.
        scale_factor: Multiplier for the spatial size, forwarded to
            ``interpolate``.
        mode: Interpolation algorithm name, forwarded to ``interpolate``.
        align_corners: Forwarded to ``interpolate`` when ``mode`` is one of
            the interpolating modes; ignored otherwise.
    """

    # Modes for which ``interpolate`` accepts an ``align_corners`` value.
    _ALIGN_CORNERS_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True):
        super(Upsample, self).__init__()
        self.align_corners = align_corners
        self.mode = mode
        self.scale_factor = scale_factor
        self.size = size

    def forward(self, x):
        """Return ``x`` resampled according to the stored settings."""
        # ``interpolate`` raises if align_corners is set for e.g. 'nearest',
        # so the default mode/align_corners combination must drop the flag.
        if self.mode in self._ALIGN_CORNERS_MODES:
            align_corners = self.align_corners
        else:
            align_corners = None
        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
                                         align_corners=align_corners)
|
|
|
|
|
class LocalizationModule(nn.Module): |
|
def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True, |
|
lrelu_inplace=True): |
|
nn.Module.__init__(self) |
|
self.lrelu_inplace = lrelu_inplace |
|
self.inst_norm_affine = inst_norm_affine |
|
self.conv_bias = conv_bias |
|
self.leakiness = leakiness |
|
self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias) |
|
self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) |
|
self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias) |
|
self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True) |
|
|
|
def forward(self, x): |
|
x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace) |
|
x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace) |
|
return x |
|
|
|
|
|
class UpsamplingModule(nn.Module):
    """Double the spatial size, then convolve, normalize and activate.

    The input is upsampled by a factor of 2 with trilinear interpolation
    (``align_corners=True``), passed through a 3x3x3 convolution mapping
    ``in_channels`` to ``out_channels``, instance-normalized and fed to a
    leaky ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        leakiness: Negative slope of the leaky ReLU.
        conv_bias: Whether the convolution carries a bias term.
        inst_norm_affine: Whether the instance norm has affine parameters.
        lrelu_inplace: Whether the leaky ReLU operates in place.
    """

    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True):
        super(UpsamplingModule, self).__init__()
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness

        self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True)
        self.upsample_conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1,
                                       bias=conv_bias)
        self.bn = nn.InstanceNorm3d(out_channels, affine=inst_norm_affine, track_running_stats=True)

    def forward(self, x):
        """Return the upsampled, convolved, normalized and activated ``x``."""
        enlarged = self.upsample(x)
        normalized = self.bn(self.upsample_conv(enlarged))
        return F.leaky_relu(normalized, negative_slope=self.leakiness, inplace=self.lrelu_inplace)
|
|
|
|
|
class DownsamplingModule(nn.Module): |
|
def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True, |
|
lrelu_inplace=True): |
|
nn.Module.__init__(self) |
|
self.lrelu_inplace = lrelu_inplace |
|
self.inst_norm_affine = inst_norm_affine |
|
self.conv_bias = conv_bias |
|
self.leakiness = leakiness |
|
self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) |
|
self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias) |
|
|
|
def forward(self, x): |
|
x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace) |
|
b = self.downsample(x) |
|
return x, b |
|
|
|
|
|
class Network(nn.Module):
    """3D encoder/decoder segmentation network with skip connections.

    The encoder alternates ``EncodingModule`` blocks with stride-2
    ``DownsamplingModule`` blocks (four downsamplings, doubling the filter
    count each time). The decoder upsamples four times, concatenating the
    matching encoder skip tensor after each upsampling. Segmentation heads
    are attached after ``loc2``, after ``loc3`` and at the end of the decoder.

    Args:
        num_classes: Number of output channels of each segmentation head.
        num_input_channels: Number of channels of the input tensor.
        base_filters: Filter count of the first level.
        dropout_p: Dropout probability passed to every ``EncodingModule``.
        final_nonlin: Callable applied to the output of every segmentation
            head.
        leakiness: Negative slope of all leaky ReLUs.
        conv_bias: Whether convolutions carry a bias term (the 1x1x1
            segmentation heads never do).
        inst_norm_affine: Whether instance norms have affine parameters.
        lrelu_inplace: Whether leaky ReLUs operate in place.
        do_ds: If true, ``forward`` returns all three segmentation outputs;
            otherwise only the final one.
    """

    def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3,
                 final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True, do_ds=True):
        super(Network, self).__init__()

        self.do_ds = do_ds
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.final_nonlin = final_nonlin

        # Forward the constructor's hyper-parameters to every sub-module so
        # that non-default values apply to the whole network, not only to the
        # layers created directly in this class.
        block_kwargs = dict(leakiness=leakiness, conv_bias=conv_bias, inst_norm_affine=inst_norm_affine,
                            lrelu_inplace=lrelu_inplace)

        # NOTE: layers are created in the original order so parameter
        # registration order is unchanged.
        self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias)

        # Encoder.
        self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, **block_kwargs)
        self.down1 = DownsamplingModule(base_filters, base_filters * 2, **block_kwargs)

        self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, **block_kwargs)
        self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, **block_kwargs)

        self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, **block_kwargs)
        self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, **block_kwargs)

        self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, **block_kwargs)
        self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, **block_kwargs)

        self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, **block_kwargs)

        # Decoder. Each localization module receives the concatenation of a
        # skip tensor and the upsampled tensor, hence twice the channels.
        self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine,
                                                   track_running_stats=True)
        self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, **block_kwargs)

        self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, **block_kwargs)
        self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, **block_kwargs)

        self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, **block_kwargs)
        self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False)
        self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, **block_kwargs)

        self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, **block_kwargs)
        self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
        self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, **block_kwargs)

        self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
        self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine,
                                               track_running_stats=True)
        self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
        self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine,
                                               track_running_stats=True)
        self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)

    def forward(self, x):
        """Run the network on ``x``.

        Returns:
            If ``do_ds`` is true, a list ``[final, loc3_seg, loc2_seg]`` (the
            final output first, then the two intermediate heads). Otherwise
            only the final segmentation output.

        NOTE(review): the skip concatenations need the upsampled tensors to
        match the encoder tensors spatially, which presumably requires each
        spatial size to be divisible by 16 - confirm against callers.
        """
        seg_outputs = []

        # Encoder path; each down module yields (skip, downsampled).
        x = self.init_conv(x)
        x = self.context1(x)

        skip1, x = self.down1(x)
        x = self.context2(x)

        skip2, x = self.down2(x)
        x = self.context3(x)

        skip3, x = self.down3(x)
        x = self.context4(x)

        skip4, x = self.down4(x)
        x = self.context5(x)

        # Decoder path.
        x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
        x = self.up1(x)

        x = torch.cat((skip4, x), dim=1)
        x = self.loc1(x)
        x = self.up2(x)

        x = torch.cat((skip3, x), dim=1)
        x = self.loc2(x)
        loc2_seg = self.final_nonlin(self.loc2_seg(x))
        seg_outputs.append(loc2_seg)
        x = self.up3(x)

        x = torch.cat((skip2, x), dim=1)
        x = self.loc3(x)
        loc3_seg = self.final_nonlin(self.loc3_seg(x))
        seg_outputs.append(loc3_seg)
        x = self.up4(x)

        x = torch.cat((skip1, x), dim=1)
        x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness,
                         inplace=self.lrelu_inplace)
        x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness,
                         inplace=self.lrelu_inplace)
        x = self.final_nonlin(self.seg_layer(x))
        seg_outputs.append(x)

        if self.do_ds:
            # Reverse so the final output comes first.
            return seg_outputs[::-1]
        else:
            return seg_outputs[-1]
|
|