#!/usr/bin/env python3

from typing import Sequence, Tuple, Union

import torch
import torch.nn as nn
from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer


def _transposed_conv(
    spatial_dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[Sequence[int], int],
    stride: Union[Sequence[int], int],
) -> nn.Module:
    """Build a bare (``conv_only``) transposed convolution via ``get_conv_layer``."""
    return get_conv_layer(
        spatial_dims,
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        conv_only=True,
        is_transposed=True,
    )


class UNesTBlock(nn.Module):
    """
    Decoder block that merges an upsampled feature map with a skip connection.

    The input is passed through a transposed convolution, concatenated with
    ``skip`` along the channel axis (``dim=1``), and then processed by either a
    ``UnetResBlock`` or a ``UnetBasicBlock``.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            stride: convolution stride. Accepted for interface compatibility but not
                used by this block: the convolution block is always built with ``stride=1``.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
                The same value is also used as the transposed convolution stride.
            norm_name: feature normalization type and arguments.
            res_block: bool argument to determine if residual block is used.
        """
        super().__init__()

        # The transposed convolution uses the same value for kernel size and stride.
        upsample_stride = upsample_kernel_size
        self.transp_conv = _transposed_conv(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
        )

        # The convolution block sees the upsampled features concatenated with the
        # skip features, hence 2 * out_channels input channels.
        block_cls = UnetResBlock if res_block else UnetBasicBlock
        self.conv_block = block_cls(
            spatial_dims,
            out_channels + out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            norm_name=norm_name,
        )

    def forward(self, inp: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """
        Apply the transposed convolution to ``inp``, concatenate with ``skip`` and
        run the convolution block.

        Args:
            inp: features to pass through the transposed convolution.
            skip: skip-connection features; the number of channels should equal
                ``out_channels`` so the concatenation matches the convolution block.

        Returns:
            The output of the convolution block.
        """
        out = self.transp_conv(inp)
        out = torch.cat((out, skip), dim=1)
        return self.conv_block(out)


class UNestUpBlock(nn.Module):
    """
    Stack of transposed-convolution stages without skip connections.

    An initial transposed convolution maps ``in_channels`` to ``out_channels``;
    it is followed by ``num_layer`` further stages operating on ``out_channels``.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_layer: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        conv_block: bool = False,
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            num_layer: number of upsampling blocks.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
                The same value is also used as the transposed convolution stride.
            norm_name: feature normalization type and arguments.
            conv_block: bool argument to determine if convolutional block is used.
                When false, each stage is a single transposed convolution with
                kernel size 1 and stride 1.
            res_block: bool argument to determine if residual block is used.
                Only has an effect when ``conv_block`` is true.
        """
        super().__init__()

        upsample_stride = upsample_kernel_size
        self.transp_conv_init = _transposed_conv(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
        )

        if conv_block:
            block_cls = UnetResBlock if res_block else UnetBasicBlock
            self.blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        _transposed_conv(
                            spatial_dims,
                            out_channels,
                            out_channels,
                            kernel_size=upsample_kernel_size,
                            stride=upsample_stride,
                        ),
                        # Use the caller's spatial_dims (previously hard-coded to 3) so
                        # this block matches the transposed convolution preceding it.
                        block_cls(
                            spatial_dims=spatial_dims,
                            in_channels=out_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            norm_name=norm_name,
                        ),
                    )
                    for _ in range(num_layer)
                ]
            )
        else:
            self.blocks = nn.ModuleList(
                [
                    _transposed_conv(
                        spatial_dims,
                        out_channels,
                        out_channels,
                        kernel_size=1,
                        stride=1,
                    )
                    for _ in range(num_layer)
                ]
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the initial transposed convolution, then every stage in order."""
        x = self.transp_conv_init(x)
        for blk in self.blocks:
            x = blk(x)
        return x


class UNesTConvBlock(nn.Module):
    """
    Single convolutional block: a ``UnetResBlock`` when ``res_block`` is true,
    otherwise a ``UnetBasicBlock``.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            norm_name: feature normalization type and arguments.
            res_block: bool argument to determine if residual block is used.
        """
        super().__init__()

        block_cls = UnetResBlock if res_block else UnetBasicBlock
        self.layer = block_cls(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            norm_name=norm_name,
        )

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Run ``inp`` through the wrapped block and return its output."""
        return self.layer(inp)