katielink's picture
complete the model package
cd6dcce
#!/usr/bin/env python3
from typing import Sequence, Tuple, Union
import torch
import torch.nn as nn
from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer
class UNesTBlock(nn.Module):
    """
    Decoder stage with a skip input.

    The input is upsampled by a transposed convolution (``in_channels`` ->
    ``out_channels``). The result is concatenated with ``skip`` along dim 1, and
    the concatenation is refined by a conv block that maps
    ``2 * out_channels`` back to ``out_channels``.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of channels of the tensor passed as ``inp``.
            out_channels: number of output channels; ``skip`` must also have this
                many channels.
            kernel_size: kernel size of the refinement conv block.
            stride: accepted for signature compatibility; not used (the refinement
                block always runs with stride 1).
            upsample_kernel_size: kernel size of the transposed convolution; its
                stride is set to the same value.
            norm_name: feature normalization type and arguments.
            res_block: use ``UnetResBlock`` when True, otherwise ``UnetBasicBlock``.
        """
        super().__init__()
        # Upsampling layer: transposed conv whose stride equals its kernel size.
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_kernel_size,
            conv_only=True,
            is_transposed=True,
        )
        refine_cls = UnetResBlock if res_block else UnetBasicBlock
        # The refinement block sees upsampled features + skip features, each
        # carrying out_channels channels.
        self.conv_block = refine_cls(  # type: ignore
            spatial_dims,
            2 * out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            norm_name=norm_name,
        )

    def forward(self, inp, skip):
        """Upsample ``inp``, concatenate ``skip`` on dim 1, then refine."""
        upsampled = self.transp_conv(inp)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv_block(merged)
class UNestUpBlock(nn.Module):
    """
    Multi-stage upsampling path.

    The input first passes through a transposed convolution
    (``in_channels`` -> ``out_channels``). It then passes through ``num_layer``
    further stages applied one after another, each operating on
    ``out_channels`` channels.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_layer: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        conv_block: bool = False,
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions. Used for every layer,
                including the per-stage conv blocks.
            in_channels: number of input channels.
            out_channels: number of output channels (also the width of every stage).
            num_layer: number of stages applied after the initial transposed
                convolution.
            kernel_size: kernel size of the per-stage conv blocks.
            stride: stride of the per-stage conv blocks.
            upsample_kernel_size: kernel size of the transposed convolutions; their
                stride is set to the same value.
            norm_name: feature normalization type and arguments.
            conv_block: if True, each stage is a transposed convolution followed by
                a conv block; otherwise each stage is a single transposed
                convolution with kernel size 1 and stride 1.
            res_block: when ``conv_block`` is True, selects ``UnetResBlock`` (True)
                or ``UnetBasicBlock`` (False) for the per-stage conv block.
        """
        super().__init__()
        # Transposed convolutions in this block use stride == kernel size.
        upsample_stride = upsample_kernel_size
        self.transp_conv_init = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            conv_only=True,
            is_transposed=True,
        )
        if conv_block:
            block_cls = UnetResBlock if res_block else UnetBasicBlock
            # Each stage keeps the (transposed conv, conv block) pair inside an
            # nn.Sequential, so parameter names ("blocks.<i>.0.*", "blocks.<i>.1.*")
            # stay the same as before, e.g. for loading existing state dicts.
            self.blocks = nn.ModuleList(
                [
                    nn.Sequential(
                        get_conv_layer(
                            spatial_dims,
                            out_channels,
                            out_channels,
                            kernel_size=upsample_kernel_size,
                            stride=upsample_stride,
                            conv_only=True,
                            is_transposed=True,
                        ),
                        block_cls(  # type: ignore
                            # Previously hard-coded to 3, which mismatched the
                            # transposed convolutions whenever spatial_dims != 3.
                            spatial_dims=spatial_dims,
                            in_channels=out_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            norm_name=norm_name,
                        ),
                    )
                    for _ in range(num_layer)
                ]
            )
        else:
            # Per-stage transposed convolutions with kernel size 1 and stride 1.
            self.blocks = nn.ModuleList(
                [
                    get_conv_layer(
                        spatial_dims,
                        out_channels,
                        out_channels,
                        kernel_size=1,
                        stride=1,
                        conv_only=True,
                        is_transposed=True,
                    )
                    for _ in range(num_layer)
                ]
            )

    def forward(self, x):
        """Apply the initial transposed convolution, then every stage in order."""
        x = self.transp_conv_init(x)
        for blk in self.blocks:
            x = blk(x)
        return x
class UNesTConvBlock(nn.Module):
    """
    Thin wrapper around a single MONAI conv block.

    Wraps ``UnetResBlock`` when ``res_block`` is True, otherwise
    ``UnetBasicBlock``. ``forward`` takes one input tensor; no skip tensor is
    consumed here.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        res_block: bool = False,
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            norm_name: feature normalization type and arguments.
            res_block: use ``UnetResBlock`` when True, otherwise ``UnetBasicBlock``.
        """
        super().__init__()
        layer_cls = UnetResBlock if res_block else UnetBasicBlock
        self.layer = layer_cls(  # type: ignore
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            norm_name=norm_name,
        )

    def forward(self, inp):
        """Run ``inp`` through the wrapped conv block."""
        return self.layer(inp)