| | |
| | |
| | |
| | |
| | import torch.nn as nn |
| | import torch |
| | import torch.nn.functional as F |
| | from .CTrans import ChannelTransformer |
| |
|
def get_activation(activation_type):
    """Return a new activation module selected by name, ignoring case.

    The name is compared case-insensitively against the public names of
    ``torch.nn`` (so ``'relu'``, ``'ReLU'`` and ``'RELU'`` all select
    ``nn.ReLU``). Only ``nn.Module`` subclasses are accepted, which keeps
    lowercase submodules such as ``nn.init`` or ``nn.functional`` from being
    called by mistake. The matched class is instantiated with no arguments.

    Args:
        activation_type: Name of a ``torch.nn`` module class, in any casing.

    Returns:
        An instance of the matching module, or ``nn.ReLU()`` when no
        ``torch.nn`` module class has that name.
    """
    wanted = activation_type.lower()
    # torch.nn class names are CamelCase, so a direct lookup of the lowercased
    # name never matches; compare lowercased names on both sides instead.
    for name in dir(nn):
        if name.lower() != wanted:
            continue
        candidate = getattr(nn, name)
        if isinstance(candidate, type) and issubclass(candidate, nn.Module):
            return candidate()
    return nn.ReLU()
| |
|
def _make_nConv(in_channels, out_channels, nb_Conv, activation='ReLU'):
    """Build a sequential stack of ``ConvBatchNorm`` layers.

    The first layer maps ``in_channels`` to ``out_channels``; each of the
    remaining ``nb_Conv - 1`` layers maps ``out_channels`` to
    ``out_channels``. One layer is always created, even if ``nb_Conv`` is
    smaller than 1.

    Returns:
        An ``nn.Sequential`` holding the layers in construction order.
    """
    # Input width seen by each layer, first layer first.
    input_widths = [in_channels] + [out_channels] * (nb_Conv - 1)
    stack = [
        ConvBatchNorm(width, out_channels, activation)
        for width in input_widths
    ]
    return nn.Sequential(*stack)
| |
|
class ConvBatchNorm(nn.Module):
    """Convolution => batch normalisation => activation.

    The convolution uses a 3x3 kernel with padding 1. The activation module
    is obtained from ``get_activation(activation)``.
    """

    def __init__(self, in_channels, out_channels, activation='ReLU'):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
        )
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = get_activation(activation)

    def forward(self, x):
        """Apply convolution, normalisation and activation to ``x`` in order."""
        return self.activation(self.norm(self.conv(x)))
| |
|
class DownBlock(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a stack of conv layers.

    The conv stack is built by ``_make_nConv`` with ``nb_Conv`` layers and
    maps ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super().__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x):
        """Pool ``x`` and pass the result through the conv stack."""
        pooled = self.maxpool(x)
        return self.nConvs(pooled)
| |
|
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        """Return a view of ``x`` shaped ``(x.size(0), -1)``."""
        leading = x.size(0)
        # ``view`` (not ``reshape``) on purpose: no copy is ever made.
        return x.view(leading, -1)
| |
|
class CCA(nn.Module):
    """
    CCA Block

    Rescales the channels of ``x`` with a sigmoid gate. The gate is the mean
    of two linear projections: one of the spatially averaged ``x`` and one of
    the spatially averaged ``g``. A ReLU is applied to the rescaled result.

    Args:
        F_g: Number of channels of ``g``.
        F_x: Number of channels of ``x`` (also the width of the gate).
    """

    def __init__(self, F_g, F_x):
        super().__init__()
        self.mlp_x = nn.Sequential(Flatten(), nn.Linear(F_x, F_x))
        self.mlp_g = nn.Sequential(Flatten(), nn.Linear(F_g, F_x))
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _global_avg_pool(feat):
        """Average ``feat`` over its full extent in dimensions 2 and 3."""
        window = (feat.size(2), feat.size(3))
        return F.avg_pool2d(feat, window, stride=window)

    def forward(self, g, x):
        """Return ``x`` scaled per channel by the gate, then passed through ReLU."""
        att_from_x = self.mlp_x(self._global_avg_pool(x))
        att_from_g = self.mlp_g(self._global_avg_pool(g))
        gate = torch.sigmoid((att_from_x + att_from_g) / 2.0)
        # Append two singleton dims so the per-channel gate expands to x's shape.
        gate = gate.unsqueeze(-1).unsqueeze(-1).expand_as(x)
        return self.relu(x * gate)
| |
|
class UpBlock_attention(nn.Module):
    """Decoder stage: upsample, gate the skip connection with CCA, then convolve.

    ``x`` is upsampled by a factor of 2. The skip features are rescaled by a
    ``CCA`` block configured for ``in_channels // 2`` channels on both
    inputs. The gated skip features and the upsampled features are
    concatenated along dimension 1 and fed to a conv stack that maps
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, nb_Conv, activation='ReLU'):
        super().__init__()
        half_width = in_channels // 2
        self.up = nn.Upsample(scale_factor=2)
        self.coatt = CCA(F_g=half_width, F_x=half_width)
        self.nConvs = _make_nConv(in_channels, out_channels, nb_Conv, activation)

    def forward(self, x, skip_x):
        """Fuse the upsampled ``x`` with the gated ``skip_x`` and convolve."""
        upsampled = self.up(x)
        gated_skip = self.coatt(g=upsampled, x=skip_x)
        fused = torch.cat([gated_skip, upsampled], dim=1)
        return self.nConvs(fused)
| |
|
class UCTransNet(nn.Module):
    """Encoder/decoder network with a ``ChannelTransformer`` on the skip paths.

    The encoder is one ``ConvBatchNorm`` followed by four ``DownBlock``
    stages. The first four encoder outputs are passed through the
    ``ChannelTransformer`` before being used as skip connections by four
    ``UpBlock_attention`` decoder stages. A 1x1 convolution produces
    ``n_classes`` output channels.

    Args:
        config: Object providing ``base_channel`` and ``patch_sizes``; it is
            also forwarded to ``ChannelTransformer``.
        n_channels: Number of input channels.
        n_classes: Number of output channels.
        img_size: Forwarded to ``ChannelTransformer``.
        vis: When true, ``forward`` also returns the transformer's
            attention weights.
    """

    def __init__(self, config, n_channels=3, n_classes=1, img_size=224, vis=False):
        super().__init__()
        self.vis = vis
        self.n_channels = n_channels
        self.n_classes = n_classes

        base = config.base_channel

        # Encoder.
        self.inc = ConvBatchNorm(n_channels, base)
        self.down1 = DownBlock(base, base * 2, nb_Conv=2)
        self.down2 = DownBlock(base * 2, base * 4, nb_Conv=2)
        self.down3 = DownBlock(base * 4, base * 8, nb_Conv=2)
        self.down4 = DownBlock(base * 8, base * 8, nb_Conv=2)

        # Transformer applied to the four skip-connection feature maps.
        self.mtc = ChannelTransformer(
            config,
            vis,
            img_size,
            channel_num=[base, base * 2, base * 4, base * 8],
            patchSize=config.patch_sizes,
        )

        # Decoder; each stage receives skip + upsampled channels concatenated.
        self.up4 = UpBlock_attention(base * 16, base * 4, nb_Conv=2)
        self.up3 = UpBlock_attention(base * 8, base * 2, nb_Conv=2)
        self.up2 = UpBlock_attention(base * 4, base, nb_Conv=2)
        self.up1 = UpBlock_attention(base * 2, base, nb_Conv=2)

        # Output head.
        self.outc = nn.Conv2d(base, n_classes, kernel_size=(1, 1), stride=(1, 1))
        self.last_activation = nn.Sigmoid()

    def forward(self, x):
        """Run the network on ``x``.

        Returns:
            The output of the 1x1 head, passed through a sigmoid when
            ``n_classes == 1``. When ``self.vis`` is true, a tuple of that
            output and the attention weights returned by the transformer.
        """
        x = x.float()

        enc1 = self.inc(x)
        enc2 = self.down1(enc1)
        enc3 = self.down2(enc2)
        enc4 = self.down3(enc3)
        bottleneck = self.down4(enc4)

        enc1, enc2, enc3, enc4, att_weights = self.mtc(enc1, enc2, enc3, enc4)

        dec = self.up4(bottleneck, enc4)
        dec = self.up3(dec, enc3)
        dec = self.up2(dec, enc2)
        dec = self.up1(dec, enc1)

        logits = self.outc(dec)
        if self.n_classes == 1:
            # Single-channel output is squashed to (0, 1) by the sigmoid.
            logits = self.last_activation(logits)

        if self.vis:
            return logits, att_weights
        return logits
| |
|
| |
|
| |
|
| |
|
| |
|