# Two types of reconstruction layers: 1. original residual layers, 2. residual layers with contrast and adaptive attention (CCA layer)
import torch
import torch.nn as nn
import torch.nn.init as init
def initialize_weights(net_l, scale=1):
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)
def make_layer(block, n_layers):
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)
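# A minimal usage sketch (assumption, not part of the original file): `make_layer`
# calls `block()` with no arguments, so constructor arguments must be bound
# beforehand, e.g. with a lambda or functools.partial; the 64-channel conv and
# depth of 3 below are purely illustrative.
def _demo_make_layer():
    trunk = make_layer(lambda: nn.Conv2d(64, 64, kernel_size=3, padding=1), n_layers=3)
    initialize_weights(trunk, scale=0.1)  # scaled Kaiming init, as used for residual trunks
    return trunk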
class ResidualBlock_noBN(nn.Module):
"""Residual block w/o BN
---Conv-ReLU-Conv-+-
|________________|
"""
def __init__(self, nf=64):
super(ResidualBlock_noBN, self).__init__()
self.conv1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
self.conv2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
"""
Args:
x: with shape of [b, c, t, h, w]
Returns: processed features with shape [b, c, t, h, w]
"""
identity = x
out = self.lrelu(self.conv1(x))
out = self.conv2(out)
out = identity + out
# Remove ReLU at the end of the residual block
# http://torch.ch/blog/2016/02/04/resnets.html
return out
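# A minimal sketch (assumption, not part of the original file) showing how the
# 2D residual block is typically stacked with `make_layer` and run on a
# [b, c, h, w] feature tensor; nf=64 and the 5-block depth are illustrative.
def _demo_residual_stack():
    body = make_layer(lambda: ResidualBlock_noBN(nf=64), n_layers=5)
    initialize_weights(body, scale=0.1)
    feat = torch.randn(2, 64, 32, 32)  # [b, c, h, w]
    return body(feat)                  # shape preserved: [2, 64, 32, 32]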
class ResBlock_noBN_new(nn.Module):
    def __init__(self, nf):
        super(ResBlock_noBN_new, self).__init__()
        self.c1 = nn.Conv3d(nf, nf // 4, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=True)
        self.d1 = nn.Conv3d(nf // 4, nf // 4, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1),
                            bias=True)  # dilation rate=1
        self.d2 = nn.Conv3d(nf // 4, nf // 4, kernel_size=(1, 3, 3), stride=1, padding=(0, 2, 2), dilation=(1, 2, 2),
                            bias=True)  # dilation rate=2
        self.d3 = nn.Conv3d(nf // 4, nf // 4, kernel_size=(1, 3, 3), stride=1, padding=(0, 4, 4), dilation=(1, 4, 4),
                            bias=True)  # dilation rate=4
        self.d4 = nn.Conv3d(nf // 4, nf // 4, kernel_size=(1, 3, 3), stride=1, padding=(0, 8, 8), dilation=(1, 8, 8),
                            bias=True)  # dilation rate=8
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.c2 = nn.Conv3d(nf, nf, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=True)
    def forward(self, x):
        output1 = self.act(self.c1(x))
        d1 = self.d1(output1)
        d2 = self.d2(output1)
        d3 = self.d3(output1)
        d4 = self.d4(output1)
        # hierarchically accumulate the dilated branches before concatenation
        add1 = d1 + d2
        add2 = add1 + d3
        add3 = add2 + d4
        combine = torch.cat([d1, add1, add2, add3], dim=1)
        output2 = self.c2(self.act(combine))
        output = x + output2
        # remove ReLU at the end of the residual block
        # http://torch.ch/blog/2016/02/04/resnets.html
        return output
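# A minimal sketch (assumption, not part of the original file): the dilated
# 1x3x3 branches use matching padding, so the block preserves the
# [b, c, t, h, w] shape of a video feature tensor; nf=64 and the tensor sizes
# below are illustrative.
def _demo_resblock_3d():
    block = ResBlock_noBN_new(nf=64)
    feat = torch.randn(1, 64, 5, 32, 32)  # [b, c, t, h, w]
    return block(feat)                    # shape preserved: [1, 64, 5, 32, 32]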
class CCALayer(nn.Module):  # new
    '''Residual block w/o BN
    --conv--contrast-conv--x---
       |     \--mean--|    |
       |___________________|
    '''

    def __init__(self, nf=64):
        super(CCALayer, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv_du = nn.Sequential(
            nn.Conv2d(nf, 4, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(4, nf, 1, padding=0, bias=True),
            nn.Tanh()  # change from `Sigmoid` to `Tanh` to make the output between -1 and 1
        )
        self.contrast = stdv_channels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # initialization
        initialize_weights([self.conv1, self.conv_du], 0.1)
    def forward(self, x):
        identity = x
        out = self.lrelu(self.conv1(x))
        out = self.conv2(out)
        # per-channel descriptor [b, c, 1, 1]: channel-wise std (contrast) plus channel-wise mean
        out = self.contrast(out) + self.avg_pool(out)
        # channel attention weights in [-1, 1] from the 1x1 conv bottleneck
        out_channel = self.conv_du(out)
        out_channel = out_channel * out
        # broadcast the [b, c, 1, 1] channel response over the spatial dims of the identity
        out_last = out_channel + identity
        return out_last
def mean_channels(F):
    assert F.dim() == 4, 'Your dim is {} but not 4'.format(F.dim())
    spatial_sum = F.sum(3, keepdim=True).sum(2, keepdim=True)
    return spatial_sum / (F.size(2) * F.size(3))  # mean over the height and width of each channel's feature map
def stdv_channels(F):
    assert F.dim() == 4, 'Your dim is {} but not 4'.format(F.dim())
    F_mean = mean_channels(F)
    F_variance = (F - F_mean).pow(2).sum(3, keepdim=True).sum(2, keepdim=True) / (F.size(2) * F.size(3))
    return F_variance.pow(0.5)
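# A minimal sanity-check sketch (assumption, not part of the original file):
# stdv_channels returns a per-channel spatial standard deviation of shape
# [b, c, 1, 1], and CCALayer preserves the input feature shape; nf=64 and the
# tensor sizes are illustrative.
if __name__ == '__main__':
    feat = torch.randn(2, 64, 32, 32)   # [b, c, h, w]
    print(stdv_channels(feat).shape)    # torch.Size([2, 64, 1, 1])
    cca = CCALayer(nf=64)
    print(cca(feat).shape)              # torch.Size([2, 64, 32, 32])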