#!/usr/bin/env python3
"""
The 3D NEST transformer based segmentation model
MASI Lab, Vanderbilty University
Authors: Xin Yu, Yinchi Zhou, Yucheng Tang, Bennett Landman
The NEST code is partly from
Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and
Interpretable Visual Understanding
https://arxiv.org/pdf/2105.12723.pdf
"""
from typing import Sequence, Tuple, Union
import torch
import torch.nn as nn
from monai.networks.blocks import Convolution
from monai.networks.blocks.dynunet_block import UnetOutBlock
# from scripts.networks.swin_transformer_3d import SwinTransformer3D
from scripts.networks.nest_transformer_3D import NestTransformer3D
from scripts.networks.unest_block import UNesTBlock, UNesTConvBlock, UNestUpBlock
# from monai.networks.blocks.unetr_block import UnetstrBasicBlock, UnetrPrUpBlock, UnetResBlock


class UNesT(nn.Module):
"""
UNesT model implementation
"""
def __init__(
self,
in_channels: int,
out_channels: int,
img_size: Sequence[int] = (96, 96, 96),
feature_size: int = 16,
patch_size: int = 2,
depths: Sequence[int] = (2, 2, 2, 2),
num_heads: Sequence[int] = (3, 6, 12, 24),
window_size: Sequence[int] = (7, 7, 7),
norm_name: Union[Tuple, str] = "instance",
conv_block: bool = False,
res_block: bool = True,
# featResBlock: bool = False,
dropout_rate: float = 0.0,
) -> None:
"""
Args:
in_channels: dimension of input channels.
out_channels: dimension of output channels.
img_size: dimension of input image.
feature_size: dimension of network feature size.
hidden_size: dimension of hidden layer.
mlp_dim: dimension of feedforward layer.
num_heads: number of attention heads.
pos_embed: position embedding layer type.
norm_name: feature normalization type and arguments.
conv_block: bool argument to determine if convolutional block is used.
res_block: bool argument to determine if residual block is used.
dropout_rate: faction of the input units to drop.
"""
super().__init__()
        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")
self.embed_dim = [128, 256, 512]
        # NOTE: the transformer hyperparameters below are hard-coded for the
        # 96^3 configuration; the `img_size`, `patch_size`, `depths`,
        # `num_heads` and `window_size` constructor arguments are currently
        # not forwarded to the backbone.
        self.nestViT = NestTransformer3D(
            img_size=96,
            in_chans=in_channels,
            patch_size=4,
            num_levels=3,
            embed_dims=(128, 256, 512),
            num_heads=(4, 8, 16),
            depths=(2, 2, 8),
            num_classes=1000,
            mlp_ratio=4.0,
            qkv_bias=True,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            drop_path_rate=0.5,
            norm_layer=None,
            act_layer=None,
            pad_type="",
            weight_init="",
            global_pool="avg",
        )
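        # For a 96^3 input with patch_size=4, the three hierarchy levels
        # produce feature maps of (128, 24^3), (256, 12^3) and (512, 6^3);
        # these feed the encoder/decoder blocks below (see the shape
        # annotations in `forward`).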
self.encoder1 = UNesTConvBlock(
spatial_dims=3,
            in_channels=in_channels,
out_channels=feature_size * 2,
kernel_size=3,
stride=1,
norm_name=norm_name,
res_block=res_block,
)
self.encoder2 = UNestUpBlock(
spatial_dims=3,
in_channels=self.embed_dim[0],
out_channels=feature_size * 4,
num_layer=1,
kernel_size=3,
stride=1,
upsample_kernel_size=2,
norm_name=norm_name,
conv_block=False,
res_block=False,
)
self.encoder3 = UNesTConvBlock(
spatial_dims=3,
in_channels=self.embed_dim[0],
out_channels=8 * feature_size,
kernel_size=3,
stride=1,
norm_name=norm_name,
res_block=res_block,
)
self.encoder4 = UNesTConvBlock(
spatial_dims=3,
in_channels=self.embed_dim[1],
out_channels=16 * feature_size,
kernel_size=3,
stride=1,
norm_name=norm_name,
res_block=res_block,
)
self.decoder5 = UNesTBlock(
spatial_dims=3,
in_channels=2 * self.embed_dim[2],
out_channels=feature_size * 32,
stride=1,
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder4 = UNesTBlock(
spatial_dims=3,
in_channels=self.embed_dim[2],
out_channels=feature_size * 16,
stride=1,
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder3 = UNesTBlock(
spatial_dims=3,
in_channels=feature_size * 16,
out_channels=feature_size * 8,
stride=1,
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder2 = UNesTBlock(
spatial_dims=3,
in_channels=feature_size * 8,
out_channels=feature_size * 4,
stride=1,
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
self.decoder1 = UNesTBlock(
spatial_dims=3,
in_channels=feature_size * 4,
out_channels=feature_size * 2,
stride=1,
kernel_size=3,
upsample_kernel_size=2,
norm_name=norm_name,
res_block=res_block,
)
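        # Bottleneck: a strided convolution that downsamples the deepest
        # transformer feature map (512 channels at 6^3 -> 1024 at 3^3 for a
        # 96^3 input; see the shape annotations in `forward`).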
self.encoder10 = Convolution(
spatial_dims=3,
in_channels=32 * feature_size,
out_channels=64 * feature_size,
strides=2,
adn_ordering="ADN",
dropout=0.0,
)
self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) # type: ignore

    def proj_feat(self, x, hidden_size, feat_size):
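        """Reshape flattened transformer tokens of shape (B, N, hidden_size)
        into a 5D feature map (B, hidden_size, D, H, W); e.g. (B, 216, 512)
        with feat_size=(6, 6, 6) becomes (B, 512, 6, 6, 6)."""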
x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
x = x.permute(0, 4, 1, 2, 3).contiguous()
return x

    def load_from(self, weights):
        """Copy pretrained transformer weights from a UNETR-style checkpoint.

        Note: this helper references ``self.vit``, which is not defined by
        this class (the backbone is ``self.nestViT``); it is kept for
        compatibility with ViT-style pretrained weights and is not used by
        the NesT backbone.
        """
        with torch.no_grad():
            # copy weights from the patch embedding
            self.vit.patch_embedding.position_embeddings.copy_(
                weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"]
            )
            self.vit.patch_embedding.cls_token.copy_(
                weights["state_dict"]["module.transformer.patch_embedding.cls_token"]
            )
            self.vit.patch_embedding.patch_embeddings[1].weight.copy_(
                weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings_3d.1.weight"]
            )
            self.vit.patch_embedding.patch_embeddings[1].bias.copy_(
                weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings_3d.1.bias"]
            )
            # copy weights from the encoding blocks (default: 12 blocks)
            for bname, block in self.vit.blocks.named_children():
                block.loadFrom(weights, n_block=bname)
            # last norm layer of the transformer
            self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"])
            self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"])

def forward(self, x_in):
x, hidden_states_out = self.nestViT(x_in)
enc0 = self.encoder1(x_in) # 2, 32, 96, 96, 96
x1 = hidden_states_out[0] # 2, 128, 24, 24, 24
enc1 = self.encoder2(x1) # 2, 64, 48, 48, 48
x2 = hidden_states_out[1] # 2, 128, 24, 24, 24
enc2 = self.encoder3(x2) # 2, 128, 24, 24, 24
x3 = hidden_states_out[2] # 2, 256, 12, 12, 12
enc3 = self.encoder4(x3) # 2, 256, 12, 12, 12
x4 = hidden_states_out[3]
enc4 = x4 # 2, 512, 6, 6, 6
dec4 = x # 2, 512, 6, 6, 6
dec4 = self.encoder10(dec4) # 2, 1024, 3, 3, 3
dec3 = self.decoder5(dec4, enc4) # 2, 512, 6, 6, 6
dec2 = self.decoder4(dec3, enc3) # 2, 256, 12, 12, 12
dec1 = self.decoder3(dec2, enc2) # 2, 128, 24, 24, 24
dec0 = self.decoder2(dec1, enc1) # 2, 64, 48, 48, 48
out = self.decoder1(dec0, enc0) # 2, 32, 96, 96, 96
logits = self.out(out)
return logits
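

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch only, assuming the `scripts.networks` modules
# imported above are on the Python path. The output channel count (14) is an
# arbitrary example value.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = UNesT(in_channels=1, out_channels=14)
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 1, 96, 96, 96)  # (batch, channel, D, H, W)
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([1, 14, 96, 96, 96])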