Chris Xiao
upload files
2ca2f68
raw
history blame
1.57 kB
import monai
import torch
import itk
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Sequence, Tuple, Union
from monai.networks.layers.factories import Act, Norm
def segNet(
spatial_dim: int,
in_channel: int,
out_channel: int,
channel: Sequence[int],
stride: Sequence[int],
num_res_unit: int = 0,
acts: Union[Tuple, str] = Act.PRELU,
norms: Union[Tuple, str] = Norm.INSTANCE,
dropouts: float = 0.0,
):
seg_net = monai.networks.nets.UNet(
spatial_dims=spatial_dim, # spatial dims
in_channels=in_channel, # input channels
out_channels=out_channel, # output channels
channels=channel, # channel sequence
strides=stride, # convolutional strides
dropout=dropouts,
act=acts,
norm=norms,
num_res_units=num_res_unit
)
return seg_net
def regNet(
spatial_dim: int,
in_channel: int,
out_channel: int,
channel: Sequence[int],
stride: Sequence[int],
num_res_unit: int = 0,
acts: Union[Tuple, str] = Act.PRELU,
norms: Union[Tuple, str] = Norm.INSTANCE,
dropouts: float = 0.0,
):
reg_net = monai.networks.nets.UNet(
spatial_dims=spatial_dim, # spatial dims
in_channels=in_channel, # input channels
out_channels=out_channel, # output channels
channels=channel, # channel sequence
strides=stride, # convolutional strides
dropout=dropouts,
act=acts,
norm=norms,
num_res_units=num_res_unit
)
return reg_net