# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import torch
import torch.nn as nn
from functools import partial
import math
import warnings
from itertools import repeat
import collections.abc
from typing import Tuple, Union

from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock, UnetrPrUpBlock
from monai.networks.blocks.dynunet_block import get_conv_layer


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.", stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


#%%
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
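
# Hedged usage sketch (not part of the original file): Mlp above is SegFormer's Mix-FFN,
# i.e. Linear -> depth-wise 3x3 conv over the (H, W) token grid -> GELU -> Linear.
# Shapes assume the stage-1 token sequence of a 224x224 image (H = W = 56). DWConv is
# defined further below, so this sketch only works once the whole module has been loaded.
# mixffn = Mlp(in_features=64, hidden_features=256)
# tokens = torch.randn(2, 56 * 56, 64)   # (B, N, C)
# out = mixffn(tokens, 56, 56)           # -> (2, 3136, 64)
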
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
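
# Hedged sanity-check sketch (not in the original): with sr_ratio > 1 the keys/values are
# computed from a strided-conv-reduced token grid, so attention cost drops from O(N^2) to
# O(N * N / sr_ratio^2). Example values for stage 1 of a 224x224 input (N = 56 * 56):
# attn = Attention(dim=64, num_heads=1, sr_ratio=8)
# x = torch.randn(2, 56 * 56, 64)
# y = attn(x, 56, 56)   # output stays (2, 3136, 64); the K/V sequence inside is only 7*7 = 49 tokens
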
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x


#%%
class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)  # [2, 3, 224, 224] -> [2, 64, 56, 56]
        # print(f"{x.shape=}")
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # [2, 64, 56, 56] -> [2, 3136, 64]
        # print(f"{x.shape=}")
        x = self.norm(x)  # [2, 3136, 64] -> [2, 3136, 64]
        # print(f"{x.shape=}")
        return x, H, W


# embed_dims = [64, 128, 256, 512]
# patch_embed1 = OverlapPatchEmbed(img_size=224, patch_size=7, stride=4, in_chans=in_chans, embed_dim=64)
# x1, H, W = patch_embed1(input_img)
# x1 = x1.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
# patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
#                                  embed_dim=embed_dims[1])
# x2, H, W = patch_embed2(x1)
# x2 = x2.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
# patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
#                                  embed_dim=embed_dims[2])
# x3, H, W = patch_embed3(x2)
# x3 = x3.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
# patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
#                                  embed_dim=embed_dims[3])
# x4, H, W = patch_embed4(x3)
# x4 = x4.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
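
# Hedged shape walk-through of the commented demo above, assuming input_img of shape
# (2, 3, 224, 224) and embed_dims = [64, 128, 256, 512] (values not asserted in the original):
# x1: (2, 3136, 64)  -> reshaped to (2, 64, 56, 56)    # stride 4
# x2: (2, 784, 128)  -> reshaped to (2, 128, 28, 28)   # stride 8
# x3: (2, 196, 256)  -> reshaped to (2, 256, 14, 14)   # stride 16
# x4: (2, 49, 512)   -> reshaped to (2, 512, 7, 7)     # stride 32
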
#%%
class MixVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        # self.num_classes = num_classes
        self.depths = depths

        # patch_embed
        self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
                                              embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
                                              embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
                                              embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
                                              embed_dim=embed_dims[3])

        # transformer encoder
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.ModuleList([Block(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])
            for i in range(depths[0])])
        self.norm1 = norm_layer(embed_dims[0])

        cur += depths[0]
        self.block2 = nn.ModuleList([Block(
            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[1])
            for i in range(depths[1])])
        self.norm2 = norm_layer(embed_dims[1])

        cur += depths[1]
        self.block3 = nn.ModuleList([Block(
            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[2])
            for i in range(depths[2])])
        self.norm3 = norm_layer(embed_dims[2])

        cur += depths[2]
        self.block4 = nn.ModuleList([Block(
            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[3])
            for i in range(depths[3])])
        self.norm4 = norm_layer(embed_dims[3])

        # classification head
        # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            # logger = get_root_logger()
            # load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
            # load_checkpoint(self, pretrained, map_location='cpu', strict=False)
            # NOTE: this only reads the checkpoint file; the returned state dict is not applied to the model.
            torch.load(pretrained, map_location='cpu')

    def reset_drop_path(self, drop_path_rate):
        # NOTE: drop_path is currently nn.Identity() in Block, so setting drop_prob here has no effect.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]

    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False
    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better

    def get_classifier(self):
        return self.head

    # def reset_classifier(self, num_classes, global_pool=''):
    #     self.num_classes = num_classes
    #     self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]
        outs = []

        # stage 1
        x, H, W = self.patch_embed1(x)
        for i, blk in enumerate(self.block1):
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 2
        x, H, W = self.patch_embed2(x)
        for i, blk in enumerate(self.block2):
            x = blk(x, H, W)
        x = self.norm2(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 3
        x, H, W = self.patch_embed3(x)
        for i, blk in enumerate(self.block3):
            x = blk(x, H, W)
        x = self.norm3(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 4
        x, H, W = self.patch_embed4(x)
        for i, blk in enumerate(self.block4):
            x = blk(x, H, W)
        x = self.norm4(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs

    def forward(self, x):
        x = self.forward_features(x)
        # x = self.head(x)
        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class mit_b0(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b0, self).__init__(
            patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b1(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b1, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b2(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b2, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b3(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b3, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b4(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b4, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b5(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b5, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
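
# Hedged usage sketch (not in the original): each mit_bX variant returns a list of four
# feature maps at strides 4/8/16/32. For a 224x224 input to mit_b0 (embed_dims [32, 64, 160, 256])
# the expected shapes would be:
# backbone = mit_b0()
# feats = backbone(torch.randn(1, 3, 224, 224))
# for f in feats:
#     print(f.shape)   # (1, 32, 56, 56), (1, 64, 28, 28), (1, 160, 14, 14), (1, 256, 7, 7)
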
#%% B2
class MiT_B2_UNet_MultiHead(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 regress_class: int = 1,
                 img_size: Tuple[int, int] = (256, 256),
                 feature_size: int = 16,
                 spatial_dims: int = 2,
                 # hidden_size: int = 768,
                 # mlp_dim: int = 3072,
                 num_heads=[1, 2, 4, 8],
                 # pos_embed: str = "perceptron",
                 norm_name: Union[Tuple, str] = "instance",
                 conv_block: bool = False,
                 res_block: bool = True,
                 dropout_rate: float = 0.0,
                 debug: bool = False
                 ):
        super().__init__()
        self.debug = debug
        # NOTE: despite the attribute name, depths=[3, 4, 6, 3] is the MiT-B2 configuration.
        self.mit_b3 = MixVisionTransformer(img_size=img_size, patch_size=4,
                                           embed_dims=[feature_size * 2, feature_size * 4,
                                                       feature_size * 8, feature_size * 16],
                                           num_heads=num_heads, mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
                                           norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3],
                                           sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.encoder2 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=2 * feature_size,
            out_channels=2 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.encoder3 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=4 * feature_size,
            out_channels=4 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.encoder4 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=8 * feature_size,
            out_channels=8 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.encoder5 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=16 * feature_size,
            out_channels=16 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 16,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.transp_conv = get_conv_layer(
            spatial_dims=2,
            in_channels=feature_size * 2,
            out_channels=feature_size * 2,
            kernel_size=3,
            stride=2,
            conv_only=True,
            is_transposed=True,
        )
        self.decoder1 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.out_interior = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=out_channels)  # type: ignore
        self.out_dist = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=1)  # type: ignore

    def forward(self, x_in):
        hidden_states_out = self.mit_b3(x_in)  # list of 4 multi-scale feature maps (strides 4, 8, 16, 32)
        enc1 = self.encoder1(x_in)  # (B, 16, 256, 256)
        x1 = hidden_states_out[0]  # (B, 32, 64, 64)
        enc2 = self.encoder2(x1)  # (B, 32, 64, 64)
        x2 = hidden_states_out[1]  # (B, 64, 32, 32)
        enc3 = self.encoder3(x2)  # (B, 64, 32, 32)
        x3 = hidden_states_out[2]  # (B, 128, 16, 16)
        enc4 = self.encoder4(x3)  # (B, 128, 16, 16)
        x4 = hidden_states_out[3]  # (B, 256, 8, 8)
        enc5 = self.encoder5(x4)  # (B, 256, 8, 8)
        # print(f"{enc1.shape=}, {enc2.shape=}, {enc3.shape=}, {enc4.shape=}, {enc5.shape=}")
        dec4 = self.decoder4(enc5, enc4)  # (B, 128, 16, 16); up -> cat -> ResConv
        dec3 = self.decoder3(dec4, enc3)  # (B, 64, 32, 32)
        dec2 = self.decoder2(dec3, enc2)  # (B, 32, 64, 64)
        dec2_up = self.transp_conv(dec2)  # (B, 32, 128, 128)
        dec1 = self.decoder1(dec2_up, enc1)  # (B, 16, 256, 256)
        logits = self.out_interior(dec1)
        dist = self.out_dist(dec1)
        if self.debug:
            return hidden_states_out, enc1, enc2, enc3, enc4, dec4, dec3, dec2, dec1, logits
        else:
            return logits, dist
        # print(f"{dec1.shape=}, {dec2.shape=}, {dec3.shape=}, {dec4.shape=}, {logits.shape=}")
img_size = 256
in_chans = 3
B = 2
input_img = torch.randn((B, in_chans, img_size, img_size))
b2 = MiT_B2_UNet_MultiHead(3, 3, img_size=img_size)
logits, dist = b2(input_img)
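# Hedged shape check for the quick test above (not in the original): with a 256x256 input,
# out_channels=3 and the extra single-channel distance head, the expected outputs would be
# print(logits.shape, dist.shape)   # torch.Size([2, 3, 256, 256]) torch.Size([2, 1, 256, 256])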
#%% B3
class MiT_B3_UNet_MultiHead(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 regress_class: int = 1,
                 img_size: Tuple[int, int] = (256, 256),
                 feature_size: int = 16,
                 spatial_dims: int = 2,
                 # hidden_size: int = 768,
                 # mlp_dim: int = 3072,
                 num_heads=[1, 2, 4, 8],
                 # pos_embed: str = "perceptron",
                 norm_name: Union[Tuple, str] = "instance",
                 conv_block: bool = False,
                 res_block: bool = True,
                 dropout_rate: float = 0.0,
                 debug: bool = False
                 ):
        super().__init__()
        self.debug = debug
        self.mit_b3 = MixVisionTransformer(img_size=img_size, patch_size=4,
                                           embed_dims=[feature_size * 2, feature_size * 4,
                                                       feature_size * 8, feature_size * 16],
                                           num_heads=num_heads, mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
                                           norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3],
                                           sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.encoder2 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=2 * feature_size,
            out_channels=2 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.encoder3 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=4 * feature_size,
            out_channels=4 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.encoder4 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=8 * feature_size,
            out_channels=8 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.encoder5 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=16 * feature_size,
            out_channels=16 * feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=True,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 16,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.transp_conv = get_conv_layer(
            spatial_dims=2,
            in_channels=feature_size * 2,
            out_channels=feature_size * 2,
            kernel_size=3,
            stride=2,
            conv_only=True,
            is_transposed=True,
        )
        self.decoder1 = UnetrUpBlock(
            spatial_dims=2,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.out_interior = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=out_channels)  # type: ignore
        self.out_dist = UnetOutBlock(spatial_dims=2, in_channels=feature_size, out_channels=1)  # type: ignore

    def forward(self, x_in):
        hidden_states_out = self.mit_b3(x_in)  # list of 4 multi-scale feature maps (strides 4, 8, 16, 32)
        enc1 = self.encoder1(x_in)  # (B, 16, 256, 256)
        x1 = hidden_states_out[0]  # (B, 32, 64, 64)
        enc2 = self.encoder2(x1)  # (B, 32, 64, 64)
        x2 = hidden_states_out[1]  # (B, 64, 32, 32)
        enc3 = self.encoder3(x2)  # (B, 64, 32, 32)
        x3 = hidden_states_out[2]  # (B, 128, 16, 16)
        enc4 = self.encoder4(x3)  # (B, 128, 16, 16)
        x4 = hidden_states_out[3]  # (B, 256, 8, 8)
        enc5 = self.encoder5(x4)  # (B, 256, 8, 8)
        # print(f"{enc1.shape=}, {enc2.shape=}, {enc3.shape=}, {enc4.shape=}, {enc5.shape=}")
        dec4 = self.decoder4(enc5, enc4)  # (B, 128, 16, 16); up -> cat -> ResConv
        dec3 = self.decoder3(dec4, enc3)  # (B, 64, 32, 32)
        dec2 = self.decoder2(dec3, enc2)  # (B, 32, 64, 64)
        dec2_up = self.transp_conv(dec2)  # (B, 32, 128, 128)
        dec1 = self.decoder1(dec2_up, enc1)  # (B, 16, 256, 256)
        logits = self.out_interior(dec1)
        dist = self.out_dist(dec1)
        if self.debug:
            return hidden_states_out, enc1, enc2, enc3, enc4, dec4, dec3, dec2, dec1, logits
        else:
            return logits, dist
        # print(f"{dec1.shape=}, {dec2.shape=}, {dec3.shape=}, {dec4.shape=}, {logits.shape=}")


#%% head
class MLPEmbedding(nn.Module):
    """
    Linear Embedding used in head
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x


class All_MLP_Head(nn.Module):
    """
    All-MLP head in SegFormer:
    Simple and Efficient Design for Semantic Segmentation with Transformers
    """

    def __init__(self,
                 in_channels=[64, 128, 320, 512],  # channel number of multi-scale features
                 in_index=[0, 1, 2, 3],
                 feature_strides=[4, 8, 16, 32],
                 dropout_ratio=0.1,
                 num_classes=3,
                 embedding_dim=768,
                 output_hidden_states=False):
        super().__init__()
        self.in_channels = in_channels
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.in_index = in_index
        self.feature_strides = feature_strides
        self.dropout_ratio = dropout_ratio
        self.num_classes = num_classes
        self.output_hidden_states = output_hidden_states

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        # unify channel number to 768
        self.linear_c4 = MLPEmbedding(input_dim=c4_in_channels, embed_dim=embedding_dim)
        self.linear_c3 = MLPEmbedding(input_dim=c3_in_channels, embed_dim=embedding_dim)
        self.linear_c2 = MLPEmbedding(input_dim=c2_in_channels, embed_dim=embedding_dim)
        self.linear_c1 = MLPEmbedding(input_dim=c1_in_channels, embed_dim=embedding_dim)

        self.linear_fuse = nn.Conv2d(in_channels=embedding_dim * 4,  # 4: number of blocks
                                     out_channels=embedding_dim,
                                     kernel_size=1, bias=False)
        self.batch_norm = nn.BatchNorm2d(embedding_dim)
        self.activation = nn.ReLU()
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(self.dropout_ratio)
        else:
            # keep forward() valid when dropout is disabled
            self.dropout = nn.Identity()
        self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)

    def forward(self, inputs):
        # x = self._transform_inputs(inputs)  # len=4, 1/4, 1/8, 1/16, 1/32
        c1, c2, c3, c4 = inputs

        ############## MLP decoder on C1-C4 ###########
        n, _, h, w = c4.shape

        # normalize channel number and resample to 1/4 HxW
        _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
        _c4 = nn.functional.interpolate(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
        _c3 = nn.functional.interpolate(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)
        _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
        _c2 = nn.functional.interpolate(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])

        # concatenate features
        hidden_states = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
        hidden_states = self.batch_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # predict results
        x = self.linear_pred(hidden_states)

        if self.output_hidden_states:
            return x, hidden_states
        else:
            return x


#%% test different networks
# img_size = 256
# in_chans = 3
# B = 2
# input_img = torch.randn((B, in_chans, img_size, img_size))

# b3 = mit_b3_demo(img_size=img_size)
# b3_out = b3(input_img)
# for feature in b3_out:
#     print(f"{feature.shape=}")

# head = All_MLP_Head()
# outputs = head(b3_out)
# print(f"{outputs.shape = }")
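
# Hedged end-to-end sketch (not in the original): a SegFormer-style forward pass would pair a
# MiT backbone with All_MLP_Head and then upsample the 1/4-resolution logits back to the input
# size; mit_b3 and the All_MLP_Head defaults share the [64, 128, 320, 512] channel layout, so no
# extra wiring is assumed here.
# backbone = mit_b3()
# head = All_MLP_Head(num_classes=3)
# img = torch.randn(2, 3, 256, 256)
# seg = head(backbone(img))   # (2, 3, 64, 64), i.e. 1/4 of the input resolution
# seg = nn.functional.interpolate(seg, size=img.shape[2:], mode='bilinear', align_corners=False)  # (2, 3, 256, 256)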