Spaces:
Running
Running
File size: 5,072 Bytes
aea73e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
# -*- coding: utf-8 -*-
# Residual block as defined in:
# He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning
# for image recognition." In Proceedings of the IEEE conference on computer vision
# and pattern recognition, pp. 770-778. 2016.
#
# Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net)
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import torch
import torch.nn as nn
from collections import OrderedDict
from models.utils.tf_utils import TFSamepaddingLayer
class ResidualBlock(nn.Module):
    """Pre-activation residual block as defined in:
    He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning
    for image recognition." In Proceedings of the IEEE conference on computer vision
    and pattern recognition, pp. 770-778. 2016.

    The block stacks ``unit_count`` units. Each unit is
    ``[BN -> ReLU ->] conv1 -> BN -> ReLU -> TF-"same" pad -> conv2 -> BN -> ReLU -> conv3``,
    where the leading pre-activation (BN + ReLU) is omitted for the first unit.
    Every unit's output is added to a running shortcut, and the block is
    concluded by a BN + ReLU (``blk_bna``).

    Args:
        in_ch (int): Number of input channels.
        unit_ksize (list[int]): Kernel sizes of conv1, conv2 and conv3 (three entries).
        unit_ch (list[int]): Output channels of conv1, conv2 and conv3 (three entries);
            ``unit_ch[-1]`` is the block's output channel count.
        unit_count (int): Number of residual units in the block.
        stride (int, optional): Stride applied by conv2 of the first unit and by the
            projection shortcut. Defaults to 1.
    """

    def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, stride=1):
        super().__init__()
        assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info"

        self.nr_unit = unit_count
        self.in_ch = in_ch
        self.unit_ch = unit_ch

        # ! For inference only so init values for batchnorm may not match tensorflow
        unit_in_ch = in_ch
        self.units = nn.ModuleList()
        for idx in range(unit_count):
            self.units.append(
                self._build_unit(idx, unit_in_ch, unit_ksize, unit_ch, stride)
            )
            # Every unit after the first consumes the previous unit's output.
            unit_in_ch = unit_ch[-1]

        # A 1x1 projection is needed whenever the identity cannot be added
        # directly: the channel count changes or the first unit is strided.
        if in_ch != unit_ch[-1] or stride != 1:
            self.shortcut = nn.Conv2d(in_ch, unit_ch[-1], 1, stride=stride, bias=False)
        else:
            self.shortcut = None

        # Final BN + ReLU concluding the block.
        self.blk_bna = nn.Sequential(
            OrderedDict(
                [
                    ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                    ("relu", nn.ReLU(inplace=True)),
                ]
            )
        )

    @staticmethod
    def _build_unit(idx, unit_in_ch, unit_ksize, unit_ch, stride):
        """Create residual unit ``idx`` as an ``nn.Sequential`` with fixed layer names.

        The names (``"preact/bn"``, ``"conv1"``, ...) become the submodule names and
        hence the ``state_dict`` keys, so they must stay unchanged.
        """
        # Only the first unit applies ``stride`` (padding and conv2 must agree).
        conv2_stride = stride if idx == 0 else 1

        layers = []
        # * has bna to conclude each previous block so
        # * must not put preact for the first unit of this block
        if idx != 0:
            layers += [
                ("preact/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                ("preact/relu", nn.ReLU(inplace=True)),
            ]
        layers += [
            (
                "conv1",
                nn.Conv2d(
                    unit_in_ch,
                    unit_ch[0],
                    unit_ksize[0],
                    stride=1,
                    padding=0,
                    bias=False,
                ),
            ),
            ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)),
            ("conv1/relu", nn.ReLU(inplace=True)),
            (
                "conv2/pad",
                TFSamepaddingLayer(ksize=unit_ksize[1], stride=conv2_stride),
            ),
            (
                "conv2",
                nn.Conv2d(
                    unit_ch[0],
                    unit_ch[1],
                    unit_ksize[1],
                    stride=conv2_stride,
                    padding=0,
                    bias=False,
                ),
            ),
            ("conv2/bn", nn.BatchNorm2d(unit_ch[1], eps=1e-5)),
            ("conv2/relu", nn.ReLU(inplace=True)),
            (
                "conv3",
                nn.Conv2d(
                    unit_ch[1],
                    unit_ch[2],
                    unit_ksize[2],
                    stride=1,
                    padding=0,
                    bias=False,
                ),
            ),
        ]
        return nn.Sequential(OrderedDict(layers))

    def out_ch(self):
        """Return the number of output channels of the block (``unit_ch[-1]``)."""
        return self.unit_ch[-1]

    def init_weights(self):
        """Initialize all submodules in place.

        - ``nn.Conv2d``: Kaiming (He) normal init (``fan_out``, ReLU gain).
        - Modules whose class name contains "norm": weight set to 1 and bias to 0.
          This is skipped for missing parameters, e.g. ``affine=False``.
        - Modules whose class name contains "linear": bias (if any) set to 0.
        """
        for m in self.modules():
            classname = m.__class__.__name__.lower()
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if "norm" in classname:
                # Norm layers built without affine parameters have weight/bias None.
                if getattr(m, "weight", None) is not None:
                    nn.init.constant_(m.weight, 1)
                if getattr(m, "bias", None) is not None:
                    nn.init.constant_(m.bias, 0)
            if "linear" in classname and getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, prev_feat, freeze=False):
        """Run the residual block.

        Args:
            prev_feat (torch.Tensor): Input feature map with ``in_ch`` channels.
                It must be 4-D (N, C, H, W), since the BatchNorm2d layers require it.
            freeze (bool, optional): Only used in training mode. If True, the units
                run with gradients disabled; if False, they run with gradients
                explicitly enabled. The shortcut projection and the final BN + ReLU
                are unaffected. Defaults to False.

        Returns:
            torch.Tensor: ``blk_bna`` applied to the accumulated residual sum.
        """
        shortcut = prev_feat if self.shortcut is None else self.shortcut(prev_feat)

        for unit in self.units:
            if self.training:
                with torch.set_grad_enabled(not freeze):
                    new_feat = unit(prev_feat)
            else:
                new_feat = unit(prev_feat)
            # The running sum is both this unit's output and the next unit's shortcut.
            prev_feat = new_feat + shortcut
            shortcut = prev_feat

        return self.blk_bna(prev_feat)
|