from collections import OrderedDict

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
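
# Overview (summary of what follows): an ICTNet backbone, i.e. an FC-DenseNet
# ("Tiramisu") style encoder-decoder in which each dense block is followed by a
# squeeze-and-excitation recalibration (SELayer), with optional memory-efficient
# gradient checkpointing via the efficient=True flag.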


def humanbytes(B):
    """Return the given bytes as a human-friendly KB, MB, GB, or TB string."""
    B = float(B)
    KB = float(1024)
    MB = float(KB ** 2)
    GB = float(KB ** 3)
    TB = float(KB ** 4)

    if B < KB:
        return '{0} {1}'.format(B, 'Byte' if B == 1 else 'Bytes')
    elif KB <= B < MB:
        return '{0:.2f} KB'.format(B / KB)
    elif MB <= B < GB:
        return '{0:.2f} MB'.format(B / MB)
    elif GB <= B < TB:
        return '{0:.2f} GB'.format(B / GB)
    else:
        return '{0:.2f} TB'.format(B / TB)
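
# Example: humanbytes(2 * 1024 ** 2) -> '2.00 MB'; humanbytes(1) -> '1.0 Byte'.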


def get_preact_conv(in_channels, out_channels, kernel_size=3, padding=1, dropout_2d=0.2):
    block = nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
        nn.Dropout2d(dropout_2d)
    )
    return block


def _dense_layer_function_factory(norm, relu, conv):
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output

    return bn_function
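
# The closure returned by _dense_layer_function_factory is the unit that
# torch.utils.checkpoint recomputes during the backward pass when a DenseLayer
# runs with efficient=True: the concatenation, BatchNorm, ReLU and convolution
# are re-executed instead of keeping their intermediate activations in memory
# (the memory-efficient DenseNet trick).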


class DenseLayer(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_2d=0.2, efficient=False):
        super(DenseLayer, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        self.dropout_2d = dropout_2d
        self.efficient = efficient

    def forward(self, *prev_features):
        dense_layer_function = _dense_layer_function_factory(self.norm, self.relu, self.conv)
        if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
            # Trade compute for memory: recompute the BN-ReLU-conv during backward.
            new_features = torch.utils.checkpoint.checkpoint(dense_layer_function, *prev_features)
        else:
            new_features = dense_layer_function(*prev_features)
        if 0 < self.dropout_2d:
            new_features = F.dropout2d(new_features, p=self.dropout_2d, training=self.training)
        return new_features


class SELayer(nn.Module):
    def __init__(self, in_channels, ratio):
        super(SELayer, self).__init__()
        self.block = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(in_channels, in_channels // ratio),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // ratio, in_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        excitation = self.block(x)
        # Scale out of place: an in-place multiply here can trip autograd, which
        # still needs the pre-scaling tensor to back-propagate through the
        # excitation branch.
        x = x * excitation[:, :, None, None]
        return x
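
# DenseBlock below follows the usual DenseNet pattern: each DenseLayer sees the
# concatenation of the block input and all previously produced feature maps, so
# the channel count grows by growth_rate per layer. With return_only_new=True
# only the newly produced feature maps are returned (this is used in the
# bottleneck and on the upsampling path); otherwise the block input is
# concatenated in as well. The result is recalibrated by an SELayer.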


class DenseBlock(nn.Module):
    def __init__(self, in_channels, n_layers, growth_rate, dropout_2d, return_only_new=False, efficient=False):
        super(DenseBlock, self).__init__()
        assert 0 < n_layers, "n_layers should be at least 1"
        self.in_channels = in_channels
        self.return_only_new = return_only_new

        channels = in_channels
        self.layers = torch.nn.ModuleList()
        for _ in range(n_layers):
            layer = DenseLayer(channels, growth_rate, dropout_2d=dropout_2d, efficient=efficient)
            self.layers.append(layer)
            channels += growth_rate

        if return_only_new:
            se_layer_in_channel = channels - in_channels
        else:
            se_layer_in_channel = channels
        self.se_layer = SELayer(se_layer_in_channel, ratio=1)

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            new_features = layer(*features)
            features.append(new_features)

        if self.return_only_new:
            features = features[1:]

        features = torch.cat(features, 1)
        features = self.se_layer(features)

        return features
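
# Transition down: a pre-activated 1x1 convolution followed by 2x2 max pooling,
# which halves the spatial resolution between dense blocks on the encoder path.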


def get_transition_down(in_channels, out_channels, dropout_2d=0.2):
    block = nn.Sequential(
        get_preact_conv(in_channels, out_channels, kernel_size=1, padding=0, dropout_2d=dropout_2d),
        nn.MaxPool2d(kernel_size=2, stride=2)
    )
    return block


def cat_non_matching(x1, x2):
    # Pad x2 so its spatial size matches x1, then concatenate along the channel dim.
    diffY = x1.size()[2] - x2.size()[2]
    diffX = x1.size()[3] - x2.size()[3]

    x2 = F.pad(x2, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))

    x = torch.cat([x1, x2], dim=1)
    return x


class TransitionUp(nn.Module):
    def __init__(self, in_channels, n_filters_keep):
        super(TransitionUp, self).__init__()
        self.conv_transpose_2d = nn.ConvTranspose2d(in_channels, n_filters_keep, kernel_size=4, stride=2, padding=1)

    def forward(self, x, skip_connection):
        x = self.conv_transpose_2d(x)
        x = cat_non_matching(x, skip_connection)
        return x


class ICTNetBackbone(nn.Module):
    """
    ICTNet backbone: https://theictlab.org/lp/2019ICTNet.

    An FC-DenseNet ("Tiramisu") encoder-decoder whose dense blocks are followed
    by squeeze-and-excitation layers. The preset_model argument selects one of
    the standard FC-DenseNet configurations; any other value keeps the explicit
    n_pool, growth_rate and n_layers_per_block arguments.
    """

    def __init__(self, preset_model='FC-DenseNet56', in_channels=3, out_channels=2, n_filters_first_conv=48,
                 n_pool=5, growth_rate=12, n_layers_per_block=4, dropout_2d=0.2, efficient=False):
        super().__init__()

        if preset_model == 'FC-DenseNet56':
            n_pool = 5
            growth_rate = 12
            n_layers_per_block = 4
        elif preset_model == 'FC-DenseNet67':
            n_pool = 5
            growth_rate = 16
            n_layers_per_block = 5
        elif preset_model == 'FC-DenseNet103':
            n_pool = 5
            growth_rate = 16
            n_layers_per_block = [4, 5, 7, 10, 12, 15, 12, 10, 7, 5, 4]
        # Any other preset_model value keeps the explicit arguments as given.

        if isinstance(n_layers_per_block, list):
            assert len(n_layers_per_block) == 2 * n_pool + 1
        elif isinstance(n_layers_per_block, int):
            n_layers_per_block = [n_layers_per_block] * (2 * n_pool + 1)
        else:
            raise ValueError("n_layers_per_block must be an int or a list of length 2 * n_pool + 1")

        self.first_conv = nn.Conv2d(in_channels, n_filters_first_conv, 3, padding=1)

        # Downsampling path: dense blocks (keeping their input) followed by transition downs.
        channels = n_filters_first_conv
        self.down_dense_blocks = torch.nn.ModuleList()
        self.transition_downs = torch.nn.ModuleList()
        skip_connection_channels = []
        for i in range(n_pool):
            self.down_dense_blocks.append(DenseBlock(in_channels=channels, n_layers=n_layers_per_block[i],
                                                     growth_rate=growth_rate, dropout_2d=dropout_2d,
                                                     return_only_new=False, efficient=efficient))
            channels += growth_rate * n_layers_per_block[i]
            skip_connection_channels.append(channels)

            self.transition_downs.append(get_transition_down(in_channels=channels, out_channels=channels,
                                                             dropout_2d=dropout_2d))

        # Bottleneck: only the newly produced feature maps are passed upwards.
        self.bottleneck_dense_block = DenseBlock(in_channels=channels, n_layers=n_layers_per_block[n_pool],
                                                 growth_rate=growth_rate, dropout_2d=dropout_2d,
                                                 return_only_new=True, efficient=efficient)
        up_in_channels = n_layers_per_block[n_pool] * growth_rate

        # Upsampling path: transposed convolutions, concatenation with the skip
        # connections, and dense blocks that again return only new feature maps.
        self.transition_ups = torch.nn.ModuleList()
        self.up_dense_blocks = torch.nn.ModuleList()
        for i in range(n_pool):
            n_filters_keep = growth_rate * n_layers_per_block[n_pool + i]
            self.transition_ups.append(TransitionUp(in_channels=up_in_channels, n_filters_keep=n_filters_keep))
            up_out_channels = skip_connection_channels[n_pool - i - 1] + n_filters_keep

            self.up_dense_blocks.append(
                DenseBlock(in_channels=up_out_channels, n_layers=n_layers_per_block[n_pool + i + 1],
                           growth_rate=growth_rate, dropout_2d=dropout_2d, return_only_new=True,
                           efficient=efficient))
            up_in_channels = growth_rate * n_layers_per_block[n_pool + i + 1]

        self.final_conv = nn.Conv2d(up_in_channels, out_channels, 1, padding=0)

    def forward(self, x):
        stack = self.first_conv(x)

        skip_connection_list = []

        for down_dense_block, transition_down in zip(self.down_dense_blocks, self.transition_downs):
            stack = down_dense_block(stack)
            skip_connection_list.append(stack)
            stack = transition_down(stack)

        skip_connection_list = skip_connection_list[::-1]

        stack = self.bottleneck_dense_block(stack)

        for transition_up, up_dense_block, skip_connection in zip(self.transition_ups, self.up_dense_blocks,
                                                                  skip_connection_list):
            stack = transition_up(stack, skip_connection)
            stack = up_dense_block(stack)

        stack = self.final_conv(stack)

        result = OrderedDict()
        result["out"] = stack

        return result


def count_trainable_params(model):
    count = 0
    for param in model.parameters():
        if param.requires_grad:
            count += param.numel()
    return count


def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    b = 2
    c = 3
    h = 512
    w = 512
    features = 32

    x = torch.rand((b, c, h, w), device=device)
    print("x: ", x.shape, x.min().item(), x.max().item())

    backbone = ICTNetBackbone(out_channels=features, preset_model="FC-DenseNet103", dropout_2d=0.0, efficient=True)
    print("ICTNetBackbone has {} trainable params".format(count_trainable_params(backbone)))

    backbone.to(device)
    result = backbone(x)
    y = result["out"]
    print("y: ", y.shape)
    print("------")

    print("Back-prop:")
    loss = torch.sum(y)
    loss.backward()


if __name__ == "__main__":
    main()