| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from rscd.models.backbones.vmamba import VSSM, LayerNorm2d, VSSBlock, Permute |
| |
|
| |
|
class ChangeDecoder(nn.Module):
    """Multi-scale bi-temporal change decoder built from ``VSSBlock`` pipelines.

    For each of the four feature levels the pre- and post-event maps are
    combined in three ways, each followed by a 1x1 conv (to 128 channels) and
    one ``VSSBlock``:

    * ``st_block_<k>1`` -- channel concatenation ``[pre, post]`` (``2 * C`` in);
    * ``st_block_<k>2`` -- column interleaving along the width axis
      (``pre`` on even columns, ``post`` on odd columns, width ``2 * W``);
    * ``st_block_<k>3`` -- side-by-side along the width axis
      (``pre`` in columns ``[0, W)``, ``post`` in ``[W, 2W)``).

    The fused map plus the two de-interleaved halves of each width-doubled
    output (5 x 128 channels) are reduced to 128 channels by ``fuse_layer_<k>``.
    Levels are then merged top-down (level 4 -> 1) by bilinear upsampling,
    addition and a ``ResBlock`` smoothing layer.

    Args:
        encoder_dims: Per-level channel counts; ``encoder_dims[-4]`` ..
            ``encoder_dims[-1]`` are the channels of feature levels 1 .. 4.
        channel_first: Forwarded to ``VSSBlock``. When False the tensors are
            permuted around each ``VSSBlock`` (``Permute(0, 2, 3, 1)`` before,
            ``Permute(0, 3, 1, 2)`` after).
        norm_layer: Normalization layer class forwarded to ``VSSBlock``.
        ssm_act_layer: Activation class for the ``VSSBlock`` SSM branch.
        mlp_act_layer: Activation class for the ``VSSBlock`` MLP branch.
        **kwargs: Must contain the ``VSSBlock`` hyper-parameters
            ``ssm_d_state``, ``ssm_ratio``, ``ssm_dt_rank``, ``ssm_conv``,
            ``ssm_conv_bias``, ``ssm_drop_rate``, ``ssm_init``,
            ``forward_type``, ``mlp_ratio``, ``mlp_drop_rate``, ``gmlp`` and
            ``use_checkpoint``; any other keys are ignored.
    """

    def __init__(self, encoder_dims, channel_first, norm_layer, ssm_act_layer, mlp_act_layer, **kwargs):
        super(ChangeDecoder, self).__init__()

        def make_st_block(in_channels):
            # 1x1 projection to 128 channels followed by one VSSBlock; when the
            # VSSBlock is not channel-first, permute around it and back.
            return nn.Sequential(
                nn.Conv2d(kernel_size=1, in_channels=in_channels, out_channels=128),
                Permute(0, 2, 3, 1) if not channel_first else nn.Identity(),
                VSSBlock(hidden_dim=128, drop_path=0.1, norm_layer=norm_layer, channel_first=channel_first,
                         ssm_d_state=kwargs['ssm_d_state'], ssm_ratio=kwargs['ssm_ratio'],
                         ssm_dt_rank=kwargs['ssm_dt_rank'], ssm_act_layer=ssm_act_layer,
                         ssm_conv=kwargs['ssm_conv'], ssm_conv_bias=kwargs['ssm_conv_bias'],
                         ssm_drop_rate=kwargs['ssm_drop_rate'], ssm_init=kwargs['ssm_init'],
                         forward_type=kwargs['forward_type'], mlp_ratio=kwargs['mlp_ratio'],
                         mlp_act_layer=mlp_act_layer, mlp_drop_rate=kwargs['mlp_drop_rate'],
                         gmlp=kwargs['gmlp'], use_checkpoint=kwargs['use_checkpoint']),
                Permute(0, 3, 1, 2) if not channel_first else nn.Identity(),
            )

        def make_fuse_layer():
            # Reduce the five 128-channel maps of one level back to 128 channels.
            return nn.Sequential(nn.Conv2d(kernel_size=1, in_channels=128 * 5, out_channels=128),
                                 nn.BatchNorm2d(128), nn.ReLU())

        # Attribute names and creation order are kept stable so that
        # state_dict keys and parameter-initialisation order do not change.
        # Suffix 1: channel concat (2*C inputs); 2: interleave; 3: side-by-side.
        self.st_block_41 = make_st_block(encoder_dims[-1] * 2)
        self.st_block_42 = make_st_block(encoder_dims[-1])
        self.st_block_43 = make_st_block(encoder_dims[-1])

        self.st_block_31 = make_st_block(encoder_dims[-2] * 2)
        self.st_block_32 = make_st_block(encoder_dims[-2])
        self.st_block_33 = make_st_block(encoder_dims[-2])

        self.st_block_21 = make_st_block(encoder_dims[-3] * 2)
        self.st_block_22 = make_st_block(encoder_dims[-3])
        self.st_block_23 = make_st_block(encoder_dims[-3])

        self.st_block_11 = make_st_block(encoder_dims[-4] * 2)
        self.st_block_12 = make_st_block(encoder_dims[-4])
        self.st_block_13 = make_st_block(encoder_dims[-4])

        self.fuse_layer_4 = make_fuse_layer()
        self.fuse_layer_3 = make_fuse_layer()
        self.fuse_layer_2 = make_fuse_layer()
        self.fuse_layer_1 = make_fuse_layer()

        self.smooth_layer_3 = ResBlock(in_channels=128, out_channels=128, stride=1)
        self.smooth_layer_2 = ResBlock(in_channels=128, out_channels=128, stride=1)
        self.smooth_layer_1 = ResBlock(in_channels=128, out_channels=128, stride=1)

    def _upsample_add(self, x, y):
        """Bilinearly resize ``x`` to ``y``'s spatial size and add ``y``."""
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear') + y

    @staticmethod
    def _interleave_width(pre, post):
        """Interleave two (B, C, H, W) maps column-wise into (B, C, H, 2W).

        Column ``2w`` holds ``pre[..., w]`` and column ``2w + 1`` holds
        ``post[..., w]``. The result inherits the inputs' device and dtype
        (the previous ``torch.empty(...).cuda()`` forced CUDA and float32).
        """
        return torch.stack((pre, post), dim=-1).flatten(start_dim=-2)

    def _fuse_stage(self, pre_feat, post_feat, st_block_1, st_block_2, st_block_3, fuse_layer):
        """Apply the three bi-temporal fusions of one level and merge them.

        Returns the output of ``fuse_layer`` applied to the channel-wise
        concatenation of: the channel-concat branch, the even/odd columns of
        the interleaved branch, and the left/right halves of the side-by-side
        branch.
        """
        width = pre_feat.size(-1)

        p_concat = st_block_1(torch.cat([pre_feat, post_feat], dim=1))
        p_interleave = st_block_2(self._interleave_width(pre_feat, post_feat))
        p_side = st_block_3(torch.cat([pre_feat, post_feat], dim=-1))

        return fuse_layer(torch.cat([
            p_concat,
            p_interleave[:, :, :, ::2], p_interleave[:, :, :, 1::2],
            p_side[:, :, :, :width], p_side[:, :, :, width:],
        ], dim=1))

    def forward(self, pre_features, post_features):
        """Decode two four-level feature pyramids into one 128-channel map.

        Args:
            pre_features: Sequence of exactly four NCHW tensors
                ``(feat_1, feat_2, feat_3, feat_4)`` for the pre-event image.
            post_features: Same layout for the post-event image; each level
                must match the corresponding ``pre_features`` shape.

        Returns:
            A 128-channel tensor at the spatial size of level 1.
        """
        pre_feat_1, pre_feat_2, pre_feat_3, pre_feat_4 = pre_features
        post_feat_1, post_feat_2, post_feat_3, post_feat_4 = post_features

        # Stage I: level 4, no top-down input yet (and no smoothing).
        p4 = self._fuse_stage(pre_feat_4, post_feat_4,
                              self.st_block_41, self.st_block_42, self.st_block_43, self.fuse_layer_4)

        # Stage II: level 3, merged with the upsampled level-4 result.
        p3 = self._fuse_stage(pre_feat_3, post_feat_3,
                              self.st_block_31, self.st_block_32, self.st_block_33, self.fuse_layer_3)
        p3 = self.smooth_layer_3(self._upsample_add(p4, p3))

        # Stage III: level 2.
        p2 = self._fuse_stage(pre_feat_2, post_feat_2,
                              self.st_block_21, self.st_block_22, self.st_block_23, self.fuse_layer_2)
        p2 = self.smooth_layer_2(self._upsample_add(p3, p2))

        # Stage IV: level 1.
        p1 = self._fuse_stage(pre_feat_1, post_feat_1,
                              self.st_block_11, self.st_block_12, self.st_block_13, self.fuse_layer_1)
        p1 = self.smooth_layer_1(self._upsample_add(p2, p1))

        return p1
| |
|
| | |
class ResBlock(nn.Module):
    """Basic residual block: 3x3 conv-BN-ReLU, 3x3 conv-BN, shortcut add, ReLU.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both 3x3 convolutions.
        stride: Stride of the first convolution (the second always uses 1).
        downsample: Optional callable applied to the input to build the
            shortcut. When ``None`` the input itself is added, so its shape
            must then match the main branch output.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)
        # Creation order is kept so parameter initialisation order is unchanged.
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        # Main branch first, then the shortcut, mirroring the evaluation order.
        residual = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)
| |
|
class CMDecoder(nn.Module):
    """Change-map head: ``ChangeDecoder`` + 1x1 two-class classifier.

    ``kwargs`` is a flat configuration dict:

    * ``norm_layer``, ``ssm_act_layer``, ``mlp_act_layer`` -- layer names
      (case-insensitive) resolved to classes; unknown names resolve to ``None``.
    * ``dims`` -- either a base width (scaled by ``2 ** i`` for each of the
      ``len(depths)`` stages) or an explicit per-stage sequence of widths.
    * ``depths`` -- only its length is used here (number of stages).
    * Every key except the three layer names is forwarded to ``ChangeDecoder``.
    """

    def __init__(self, **kwargs):
        super(CMDecoder, self).__init__()

        _NORMLAYERS = dict(
            ln=nn.LayerNorm,
            ln2d=LayerNorm2d,
            bn=nn.BatchNorm2d,
        )

        _ACTLAYERS = dict(
            silu=nn.SiLU,
            gelu=nn.GELU,
            relu=nn.ReLU,
            sigmoid=nn.Sigmoid,
        )

        norm_name = kwargs['norm_layer'].lower()
        norm_layer = _NORMLAYERS.get(norm_name, None)
        ssm_act_layer = _ACTLAYERS.get(kwargs['ssm_act_layer'].lower(), None)
        mlp_act_layer = _ACTLAYERS.get(kwargs['mlp_act_layer'].lower(), None)

        clean_kwargs = {k: v for k, v in kwargs.items() if k not in ['norm_layer', 'ssm_act_layer', 'mlp_act_layer']}
        self.decoder = ChangeDecoder(
            encoder_dims=self._encoder_dims(kwargs['dims'], len(kwargs['depths'])),
            # Previously hard-coded to True, which breaks norm_layer='ln'.
            channel_first=self._is_channel_first(norm_name),
            norm_layer=norm_layer,
            ssm_act_layer=ssm_act_layer,
            mlp_act_layer=mlp_act_layer,
            **clean_kwargs
        )

        # ChangeDecoder emits 128 channels; classify into 2 classes per pixel.
        self.main_clf = nn.Conv2d(in_channels=128, out_channels=2, kernel_size=1)

    @staticmethod
    def _encoder_dims(dims, num_stages):
        """Return the per-stage encoder channel counts.

        A scalar ``dims`` is doubled per stage (``int(dims * 2 ** i)`` for
        ``i`` in ``range(num_stages)``); a sequence is used as given.
        """
        import numbers

        if isinstance(dims, numbers.Number):
            return [int(dims * 2 ** i_layer) for i_layer in range(num_stages)]
        return [int(d) for d in dims]

    @staticmethod
    def _is_channel_first(norm_name):
        """Return whether the decoder's VSSBlocks should run on NCHW tensors.

        ``nn.LayerNorm`` ('ln') normalises the trailing dimension, so it needs
        the channel-last path; every other name keeps the previous
        channel-first behaviour ('bn' is ``nn.BatchNorm2d``, which normalises
        dim 1; 'ln2d' is presumably channel-first too -- confirm against
        LayerNorm2d).
        """
        return norm_name.lower() != 'ln'

    def _upsample_add(self, x, y):
        """Bilinearly resize ``x`` to ``y``'s spatial size and add ``y``."""
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear') + y

    def forward(self, xs):
        """Predict 2-class change logits at the requested output size.

        Args:
            xs: Tuple ``(pre_features, post_features, pre_data_size)``; the
                feature pyramids go to ``ChangeDecoder`` and ``pre_data_size``
                is the ``size`` passed to the final bilinear interpolation.

        Returns:
            A 2-channel tensor resized to ``pre_data_size``.
        """
        pre_features, post_features, pre_data_size = xs

        output = self.decoder(pre_features, post_features)

        output = self.main_clf(output)
        output = F.interpolate(output, size=pre_data_size, mode='bilinear')
        return output