# LKCell / models/utils/dense.py -- uploaded by qingke1, commit aea73e2 ("initial commit")
# -*- coding: utf-8 -*-
# Dense Block as defined in:
# Huang, Gao, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q. Weinberger.
# "Densely connected convolutional networks." In Proceedings of the IEEE conference
# on computer vision and pattern recognition, pp. 4700-4708. 2017.
#
# Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net)
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import torch
import torch.nn as nn
from collections import OrderedDict
class DenseBlock(nn.Module):
    """Dense Block as defined in:
    Huang, Gao, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q. Weinberger.
    "Densely connected convolutional networks." In Proceedings of the IEEE conference
    on computer vision and pattern recognition, pp. 4700-4708. 2017.

    Only performs `valid` convolution (``padding=0``), so every unit shrinks the
    spatial size. The running feature map is center-cropped to the new features'
    spatial size before the two are concatenated along dim 1.

    Args:
        in_ch (int): Number of channels of the input feature map.
        unit_ksize (list): Kernel sizes per unit: ``unit_ksize[0]`` for ``conv1``,
            ``unit_ksize[1]`` for ``conv2``. Must have the same length as ``unit_ch``.
        unit_ch (list): Output channels per unit: ``unit_ch[0]`` for ``conv1``,
            ``unit_ch[1]`` for ``conv2``. Each unit appends ``unit_ch[1]``
            channels to the running feature map.
        unit_count (int): Number of dense units.
        split (int, optional): ``groups`` argument of ``conv2``. Defaults to 1.
    """

    def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, split=1):
        super().__init__()
        assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info"

        self.nr_unit = unit_count
        self.in_ch = in_ch
        self.unit_ch = unit_ch

        # ! For inference only so init values for batchnorm may not match tensorflow
        unit_in_ch = in_ch
        self.units = nn.ModuleList()
        for _ in range(unit_count):
            # Layer names are kept verbatim: they are the state_dict keys used
            # when loading pretrained weights.
            self.units.append(
                nn.Sequential(
                    OrderedDict(
                        [
                            ("preact_bna/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                            ("preact_bna/relu", nn.ReLU(inplace=True)),
                            (
                                "conv1",
                                nn.Conv2d(
                                    unit_in_ch,
                                    unit_ch[0],
                                    unit_ksize[0],
                                    stride=1,
                                    padding=0,
                                    bias=False,
                                ),
                            ),
                            ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)),
                            ("conv1/relu", nn.ReLU(inplace=True)),
                            # padding=0 keeps conv2 `valid`; forward() compensates
                            # for the spatial shrinkage by cropping.
                            (
                                "conv2",
                                nn.Conv2d(
                                    unit_ch[0],
                                    unit_ch[1],
                                    unit_ksize[1],
                                    groups=split,
                                    stride=1,
                                    padding=0,
                                    bias=False,
                                ),
                            ),
                        ]
                    )
                )
            )
            # The next unit sees the previous input plus this unit's conv2 output.
            unit_in_ch += unit_ch[1]

        self.blk_bna = nn.Sequential(
            OrderedDict(
                [
                    ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                    ("relu", nn.ReLU(inplace=True)),
                ]
            )
        )

    def out_ch(self):
        """Return the number of channels produced by ``forward``.

        Matches the channel count accumulated in ``__init__``: each unit adds
        ``unit_ch[1]`` channels (the output of ``conv2``).
        """
        return self.in_ch + self.nr_unit * self.unit_ch[1]

    def init_weights(self):
        """Initialize the parameters of all submodules in place.

        - ``nn.Conv2d``: Kaiming (He) normal weights, ``mode="fan_out"``,
          ReLU gain.
        - Normalization layers (class name contains "norm"): weight 1, bias 0.
        - Linear layers (class name contains "linear"): bias 0; weights untouched.

        Parameters that are absent or ``None`` (e.g. non-affine norm layers)
        are skipped.
        """
        for m in self.modules():
            classname = m.__class__.__name__.lower()
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if "norm" in classname:
                weight = getattr(m, "weight", None)
                bias = getattr(m, "bias", None)
                if weight is not None:
                    nn.init.constant_(weight, 1)
                if bias is not None:
                    nn.init.constant_(bias, 0)
            if "linear" in classname:
                bias = getattr(m, "bias", None)
                if bias is not None:
                    nn.init.constant_(bias, 0)

    def forward(self, prev_feat):
        """Run all dense units, then the final BatchNorm + ReLU.

        Args:
            prev_feat: Input feature map; channels are concatenated along
                dim 1, so it is presumably NCHW -- TODO confirm with callers.

        Returns:
            Tensor with ``out_ch()`` channels and the spatial size of the last
            unit's output.
        """
        for unit in self.units:
            new_feat = unit(prev_feat)
            # The valid convs shrank H/W: crop the running features to match.
            prev_feat = crop_to_shape(prev_feat, new_feat)
            prev_feat = torch.cat([prev_feat, new_feat], dim=1)
        return self.blk_bna(prev_feat)
# helper functions for cropping
def crop_op(x, cropping, data_format="NCHW"):
    """Center crop the two spatial axes of a 4-D image batch.

    Args:
        x: 4-D array/tensor to crop.
        cropping: ``(crop_h, crop_w)``, the total number of rows and columns to
            remove. For an odd amount, the extra row/column is taken from the
            bottom/right.
        data_format: ``"NCHW"`` (spatial axes 2 and 3) or ``"NHWC"`` (spatial
            axes 1 and 2). Any value other than ``"NCHW"`` is treated as NHWC.

    Returns:
        The center-cropped slice of ``x``. A zero crop on an axis leaves that
        axis untouched.
    """
    crop_t = cropping[0] // 2
    crop_b = cropping[0] - crop_t
    crop_l = cropping[1] // 2
    crop_r = cropping[1] - crop_l

    h_axis, w_axis = (2, 3) if data_format == "NCHW" else (1, 2)
    # Use explicit stop indices: a negative stop of "-0" would produce an
    # empty slice whenever nothing has to be removed from that axis.
    rows = slice(crop_t, x.shape[h_axis] - crop_b)
    cols = slice(crop_l, x.shape[w_axis] - crop_r)

    if data_format == "NCHW":
        return x[:, :, rows, cols]
    return x[:, rows, cols, :]
def crop_to_shape(x, y, data_format="NCHW"):
    """Center crop ``x`` so that its spatial size matches that of ``y``.

    Only the two spatial axes are cropped. The batch and channel axes of ``x``
    are left as-is and need not match ``y``.

    Args:
        x: 4-D array/tensor to crop.
        y: 4-D array/tensor whose spatial size is the target. Its spatial dims
            must not exceed those of ``x``.
        data_format: ``"NCHW"`` or ``"NHWC"``; forwarded to ``crop_op``.

    Returns:
        ``x`` center-cropped to the spatial size of ``y``.

    Raises:
        AssertionError: if ``y`` is spatially larger than ``x``.
    """
    x_shape = x.shape
    y_shape = y.shape
    if data_format == "NCHW":
        crop_shape = (x_shape[2] - y_shape[2], x_shape[3] - y_shape[3])
    else:
        crop_shape = (x_shape[1] - y_shape[1], x_shape[2] - y_shape[2])
    # Validate the axes that are actually cropped. Batch/channel sizes may
    # legitimately differ (DenseBlock.forward passes tensors whose channel
    # counts differ).
    assert (
        crop_shape[0] >= 0 and crop_shape[1] >= 0
    ), "Ensure that y dimensions are smaller than x dimensions!"
    return crop_op(x, crop_shape, data_format)