|
|
|
|
|
from typing import Sequence, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer |
|
|
|
|
|
class UNesTBlock(nn.Module):
    """Decoder upsampling block with a skip connection for UNesT.

    Upsamples the input with a transposed convolution, concatenates the
    result with the skip feature map along the channel axis, and fuses the
    two with a UNet convolution block (residual or basic).
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            stride: convolution stride.
                NOTE(review): currently accepted but unused by this block —
                confirm against callers before removing.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
            norm_name: feature normalization type and arguments.
            res_block: bool argument to determine if residual block is used.
        """
        # Zero-arg super() for consistency with the other blocks in this file.
        super().__init__()

        # The transposed convolution's stride equals its kernel size, so the
        # spatial resolution is scaled by exactly ``upsample_kernel_size``.
        upsample_stride = upsample_kernel_size
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )

        # After concatenation with the skip tensor the channel count doubles,
        # hence ``2 * out_channels`` input channels for the fusion block.
        conv_cls = UnetResBlock if res_block else UnetBasicBlock
        self.conv_block = conv_cls(
            spatial_dims,
            2 * out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            norm_name=norm_name,
        )

    def forward(self, inp, skip):
        """Upsample ``inp``, concatenate with ``skip`` on dim 1, and convolve."""
        out = self.transp_conv(inp)
        out = torch.cat((out, skip), dim=1)
        return self.conv_block(out)
|
|
|
|
|
class UNestUpBlock(nn.Module):
    """Stack of transposed-convolution upsampling stages for the UNesT decoder.

    An initial transposed convolution maps ``in_channels`` to
    ``out_channels``; it is followed by ``num_layer`` further stages. Each
    stage is either a transposed convolution paired with a UNet convolution
    block (when ``conv_block`` is True) or a plain 1x1 transposed
    convolution (when it is False).
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_layer: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        conv_block: bool = False,
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            num_layer: number of upsampling blocks.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
            norm_name: feature normalization type and arguments.
            conv_block: bool argument to determine if convolutional block is used.
            res_block: bool argument to determine if residual block is used.
        """
        super().__init__()

        # The transposed convolution's stride equals its kernel size, so the
        # spatial resolution is scaled by exactly ``upsample_kernel_size``.
        upsample_stride = upsample_kernel_size
        self.transp_conv_init = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )
        if conv_block:
            # Single selector replaces the two duplicated res/basic branches.
            # FIX: the conv blocks previously hard-coded ``spatial_dims=3``,
            # ignoring the constructor's ``spatial_dims`` argument (wrong for
            # any non-3D use); they now honor the parameter.
            block_cls = UnetResBlock if res_block else UnetBasicBlock
            self.blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        get_conv_layer(
                            spatial_dims,
                            out_channels,
                            out_channels,
                            kernel_size=upsample_kernel_size,
                            stride=upsample_stride,
                            conv_only=True,
                            is_transposed=True,
                        ),
                        block_cls(
                            spatial_dims=spatial_dims,
                            in_channels=out_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            norm_name=norm_name,
                        ),
                    )
                    for _ in range(num_layer)
                ]
            )
        else:
            # No conv blocks requested: plain 1x1, stride-1 transposed convs.
            self.blocks = nn.ModuleList(
                [
                    get_conv_layer(
                        spatial_dims,
                        out_channels,
                        out_channels,
                        kernel_size=1,
                        stride=1,
                        conv_only=True,
                        is_transposed=True,
                    )
                    for _ in range(num_layer)
                ]
            )

    def forward(self, x):
        """Apply the initial upsampling conv, then each stacked stage in order."""
        x = self.transp_conv_init(x)
        for blk in self.blocks:
            x = blk(x)
        return x
|
|
|
|
|
class UNesTConvBlock(nn.Module):
    """
    UNesT block with skip connections
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            norm_name: feature normalization type and arguments.
            res_block: bool argument to determine if residual block is used.
        """
        super().__init__()

        # Pick the residual or the basic UNet convolution variant; both take
        # the same keyword arguments, so a single construction suffices.
        block_type = UnetResBlock if res_block else UnetBasicBlock
        self.layer = block_type(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            norm_name=norm_name,
        )

    def forward(self, inp):
        """Run the wrapped convolution block on ``inp``."""
        return self.layer(inp)
|
|