# Used for Models Genesis import math import torch import torch.nn as nn import torch.nn.functional as F from backbones.classifier import FracClassifier class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm): def _check_input_dim(self, input): if input.dim() != 5: raise ValueError('expected 5D input (got {}D input)'.format(input.dim())) def forward(self, input): self._check_input_dim(input) return F.batch_norm( input, self.running_mean, self.running_var, self.weight, self.bias, True, self.momentum, self.eps) class LUConv(nn.Module): def __init__(self, in_chan, out_chan, act): super(LUConv, self).__init__() self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1) self.bn1 = ContBatchNorm3d(out_chan) if act == 'relu': self.activation = nn.ReLU(out_chan) elif act == 'prelu': self.activation = nn.PReLU(out_chan) elif act == 'elu': self.activation = nn.ELU(inplace=True) else: raise def forward(self, x): out = self.activation(self.bn1(self.conv1(x))) return out def _make_nConv(in_channel, depth, act, double_chnnel=False): if double_chnnel: layer1 = LUConv(in_channel, 32 * (2 ** (depth+1)),act) layer2 = LUConv(32 * (2 ** (depth+1)), 32 * (2 ** (depth+1)),act) else: layer1 = LUConv(in_channel, 32*(2**depth),act) layer2 = LUConv(32*(2**depth), 32*(2**depth)*2,act) return nn.Sequential(layer1,layer2) class DownTransition(nn.Module): def __init__(self, in_channel,depth, act): super(DownTransition, self).__init__() self.ops = _make_nConv(in_channel, depth,act) self.maxpool = nn.MaxPool3d(2) self.current_depth = depth def forward(self, x): if self.current_depth == 3: out = self.ops(x) out_before_pool = out else: out_before_pool = self.ops(x) out = self.maxpool(out_before_pool) return out, out_before_pool class UpTransition(nn.Module): def __init__(self, inChans, outChans, depth,act): super(UpTransition, self).__init__() self.depth = depth self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2) self.ops = _make_nConv(inChans+ outChans//2,depth, act, double_chnnel=True) def forward(self, x, skip_x): out_up_conv = self.up_conv(x) concat = torch.cat((out_up_conv,skip_x),1) out = self.ops(concat) return out class OutputTransition(nn.Module): def __init__(self, inChans, n_labels): super(OutputTransition, self).__init__() self.final_conv = nn.Conv3d(inChans, n_labels, kernel_size=1) #self.sigmoid = nn.Sigmoid() def forward(self, x): out = torch.sigmoid(self.final_conv(x)) return out class UNet3D(nn.Module): # the number of convolutions in each layer corresponds # to what is in the actual prototxt, not the intent def __init__(self, input_size, n_class=1, act='relu', in_channels=1): super(UNet3D, self).__init__() self.down_tr64 = DownTransition(in_channels,0,act) self.down_tr128 = DownTransition(64,1,act) self.down_tr256 = DownTransition(128,2,act) self.down_tr512 = DownTransition(256,3,act) # Classification self.classifier = FracClassifier(encoder_channels=512, final_channels=n_class, linear_kernel=int(math.pow(input_size / 32, 3) * 512)) def forward(self, x): self.out64, _ = self.down_tr64(x) self.out128, _ = self.down_tr128(self.out64) self.out256, _ = self.down_tr256(self.out128) self.out512, _ = self.down_tr512(self.out256) self.out = self.classifier(self.out512) return self.out