Spaces:
Running
Running
from typing import Sequence, Union | |
import torch | |
import torch.nn as nn | |
from monai.networks.blocks.convolutions import Convolution, ResidualUnit | |
from monai.networks.layers.factories import Act, Norm | |
from monai.networks.layers.simplelayers import SkipConnection | |
from monai.utils import alias, export | |
class UNet_double(nn.Module):
    """U-Net with one shared encoder and two separate decoder heads.

    Both heads (``self.model1`` and ``self.model2``) are ``nn.Sequential``
    trees that reuse the *same* down-sampling layer objects and the same
    bottom layer, so those weights are shared between the heads. Every
    up-sampling layer is created separately for each head. ``forward`` runs
    only the first head. ``forward_train`` runs both heads and concatenates
    their outputs along dim 1.

    Args:
        dimensions: number of spatial dimensions, forwarded to MONAI's
            ``Convolution`` / ``ResidualUnit``.
        in_channels: number of input channels.
        out_channels: number of output channels produced by *each* head.
        channels: channel count per encoder level, outermost first. The last
            entry is the output width of the bottom layer. At least 2
            entries are required.
        strides: stride for each down/up level. At least
            ``len(channels) - 1`` entries are required; extra entries are
            never read.
        kernel_size: kernel size of the encoder convolutions, and of the
            residual unit appended to each decoder level.
        up_kernel_size: kernel size of the transposed decoder convolutions.
        num_res_units: if > 0, the encoder uses ``ResidualUnit`` blocks with
            this many subunits, and every transposed convolution is followed
            by a one-subunit ``ResidualUnit``.
        act: activation spec forwarded to the MONAI blocks.
        norm: normalization spec forwarded to the MONAI blocks.
        dropout: dropout spec forwarded to the MONAI blocks.

    Raises:
        ValueError: if ``channels`` has fewer than 2 entries, or ``strides``
            has fewer than ``len(channels) - 1`` entries.
    """

    def __init__(
        self,
        dimensions: int,
        in_channels: int,
        out_channels: int,
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Union[Sequence[int], int] = 3,
        up_kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 0,
        act=Act.PRELU,
        norm=Norm.INSTANCE,
        dropout=0.0,
    ) -> None:
        super().__init__()
        # Validate up front. Otherwise a too-short `channels`/`strides`
        # surfaces as an opaque IndexError deep inside the recursion below.
        if len(channels) < 2:
            raise ValueError("the length of `channels` should be no less than 2.")
        if len(strides) < len(channels) - 1:
            raise ValueError("the length of `strides` should be at least `len(channels) - 1`.")

        self.dimensions = dimensions
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.strides = strides
        self.kernel_size = kernel_size
        self.up_kernel_size = up_kernel_size
        self.num_res_units = num_res_units
        self.act = act
        self.norm = norm
        self.dropout = dropout

        def _create_block(
            inc: int, outc: int, channels: Sequence[int], strides: Sequence[int], is_top: bool
        ) -> "tuple[nn.Sequential, nn.Sequential]":
            """Recursively build one U-Net level for both heads.

            Returns ``(head1_level, head2_level)``. The two levels share the
            same ``down`` module (and, at the bottom, the same bottom layer),
            but each gets its own up-sampling layer.
            """
            c = channels[0]
            s = strides[0]
            if len(channels) > 2:
                # Recurse first: each head gets its own inner sub-network.
                sub1, sub2 = _create_block(c, c, channels[1:], strides[1:], False)
                # SkipConnection concatenates this level's features (c) with
                # the inner level's output (c), hence 2 * c input channels.
                upc = c * 2
            else:
                # The next level is the bottom. Both heads share one bottom layer.
                sub1 = sub2 = self._get_bottom_layer(c, channels[1])
                upc = c + channels[1]
            # Construction order (inner, down, up1, up2) is kept so that
            # RNG-driven weight initialisation stays reproducible.
            down = self._get_down_layer(inc, c, s, is_top)
            up1 = self._get_up_layer(upc, outc, s, is_top)
            up2 = self._get_up_layer(upc, outc, s, is_top)
            return (
                nn.Sequential(down, SkipConnection(sub1), up1),
                nn.Sequential(down, SkipConnection(sub2), up2),
            )

        self.model1, self.model2 = _create_block(in_channels, out_channels, self.channels, self.strides, True)
        self.activation = nn.Sigmoid()

    def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:
        """Build one encoder layer.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            strides: convolution stride.
            is_top: True if this is the top block. Not used here; it is
                accepted for symmetry with ``_get_up_layer``.

        Returns:
            A ``ResidualUnit`` when ``num_res_units > 0``, else a ``Convolution``.
        """
        if self.num_res_units > 0:
            return ResidualUnit(
                self.dimensions,
                in_channels,
                out_channels,
                strides=strides,
                kernel_size=self.kernel_size,
                subunits=self.num_res_units,
                act=self.act,
                norm=self.norm,
                dropout=self.dropout,
            )
        return Convolution(
            self.dimensions,
            in_channels,
            out_channels,
            strides=strides,
            kernel_size=self.kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
        )

    def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module:
        """Build the bottom layer: a stride-1, non-top encoder layer.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
        """
        return self._get_down_layer(in_channels, out_channels, 1, False)

    def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module:
        """Build one decoder layer: a transposed convolution, optionally followed by a residual unit.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            strides: stride of the transposed convolution.
            is_top: True if this is the top block. At the top, the final
                convolution is built with ``conv_only`` / ``last_conv_only``,
                so the head emits raw outputs for the sigmoid to consume.

        Returns:
            The transposed ``Convolution``, or ``nn.Sequential(conv, ResidualUnit)``
            when ``num_res_units > 0``.
        """
        conv: Union[Convolution, nn.Sequential]
        conv = Convolution(
            self.dimensions,
            in_channels,
            out_channels,
            strides=strides,
            kernel_size=self.up_kernel_size,
            act=self.act,
            norm=self.norm,
            dropout=self.dropout,
            conv_only=is_top and self.num_res_units == 0,
            is_transposed=True,
        )
        if self.num_res_units > 0:
            ru = ResidualUnit(
                self.dimensions,
                out_channels,
                out_channels,
                strides=1,
                kernel_size=self.kernel_size,
                subunits=1,
                act=self.act,
                norm=self.norm,
                dropout=self.dropout,
                last_conv_only=is_top,
            )
            conv = nn.Sequential(conv, ru)
        return conv

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Inference path: sigmoid of the first head's output only."""
        return self.activation(self.model1(x))

    def train_step(self, image, segment, criterion, segbox=None):
        """Compute the training loss for both heads.

        The two head outputs (see ``forward_train``) are compared by
        ``criterion`` against ``segment`` and ``segbox`` concatenated along
        dim 1. Presumably each target has ``out_channels`` channels, so the
        target matches the ``2 * out_channels`` prediction; TODO confirm
        against callers.

        Raises:
            ValueError: if ``segbox`` is None. The second head needs a
                target, and ``torch.cat`` cannot take None. The default is
                kept only for backward compatibility of the signature.
        """
        if segbox is None:
            raise ValueError("train_step requires `segbox`: the second head is trained against it.")
        forwarded = self.forward_train(image)
        target = torch.cat((segment, segbox), dim=1)
        return criterion(forwarded, target)

    def forward_train(self, image):
        """Run both heads and return the sigmoid of their outputs concatenated along dim 1."""
        # The heads share encoder modules but are run independently,
        # so the shared encoder is evaluated twice per call.
        x1 = self.model1(image)
        x2 = self.model2(image)
        return self.activation(torch.cat((x1, x2), dim=1))