52Hz commited on
Commit
82e582c
1 Parent(s): 0ad3230

Create block.py

Browse files
Files changed (1) hide show
  1. model/block.py +146 -0
model/block.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ ##########################################################################
4
+ def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
5
+ layer = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias, stride=stride)
6
+ return layer
7
+
8
+
9
+ def conv3x3(in_chn, out_chn, bias=True):
10
+ layer = nn.Conv2d(in_chn, out_chn, kernel_size=3, stride=1, padding=1, bias=bias)
11
+ return layer
12
+
13
+
14
+ def conv_down(in_chn, out_chn, bias=False):
15
+ layer = nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias)
16
+ return layer
17
+
18
+ ##########################################################################
19
+ ## Supervised Attention Module (RAM)
20
+ class SAM(nn.Module):
21
+ def __init__(self, n_feat, kernel_size, bias):
22
+ super(SAM, self).__init__()
23
+ self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
24
+ self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)
25
+ self.conv3 = conv(3, n_feat, kernel_size, bias=bias)
26
+
27
+ def forward(self, x, x_img):
28
+ x1 = self.conv1(x)
29
+ img = self.conv2(x) + x_img
30
+ x2 = torch.sigmoid(self.conv3(img))
31
+ x1 = x1 * x2
32
+ x1 = x1 + x
33
+ return x1, img
34
+
35
+ ##########################################################################
36
+ ## Spatial Attention
37
+ class SALayer(nn.Module):
38
+ def __init__(self, kernel_size=7):
39
+ super(SALayer, self).__init__()
40
+ self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
41
+ self.sigmoid = nn.Sigmoid()
42
+
43
+ def forward(self, x):
44
+ avg_out = torch.mean(x, dim=1, keepdim=True)
45
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
46
+ y = torch.cat([avg_out, max_out], dim=1)
47
+ y = self.conv1(y)
48
+ y = self.sigmoid(y)
49
+ return x * y
50
+
51
+ # Spatial Attention Block (SAB)
52
+ class SAB(nn.Module):
53
+ def __init__(self, n_feat, kernel_size, reduction, bias, act):
54
+ super(SAB, self).__init__()
55
+ modules_body = [conv(n_feat, n_feat, kernel_size, bias=bias), act, conv(n_feat, n_feat, kernel_size, bias=bias)]
56
+ self.body = nn.Sequential(*modules_body)
57
+ self.SA = SALayer(kernel_size=7)
58
+
59
+ def forward(self, x):
60
+ res = self.body(x)
61
+ res = self.SA(res)
62
+ res += x
63
+ return res
64
+
65
+ ##########################################################################
66
+ ## Pixel Attention
67
+ class PALayer(nn.Module):
68
+ def __init__(self, channel, reduction=16, bias=False):
69
+ super(PALayer, self).__init__()
70
+ self.pa = nn.Sequential(
71
+ nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
72
+ nn.ReLU(inplace=True),
73
+ nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias), # channel <-> 1
74
+ nn.Sigmoid()
75
+ )
76
+
77
+ def forward(self, x):
78
+ y = self.pa(x)
79
+ return x * y
80
+
81
+ ## Pixel Attention Block (PAB)
82
+ class PAB(nn.Module):
83
+ def __init__(self, n_feat, kernel_size, reduction, bias, act):
84
+ super(PAB, self).__init__()
85
+ modules_body = [conv(n_feat, n_feat, kernel_size, bias=bias), act, conv(n_feat, n_feat, kernel_size, bias=bias)]
86
+ self.PA = PALayer(n_feat, reduction, bias=bias)
87
+ self.body = nn.Sequential(*modules_body)
88
+
89
+ def forward(self, x):
90
+ res = self.body(x)
91
+ res = self.PA(res)
92
+ res += x
93
+ return res
94
+
95
+ ##########################################################################
96
+ ## Channel Attention Layer
97
+ class CALayer(nn.Module):
98
+ def __init__(self, channel, reduction=16, bias=False):
99
+ super(CALayer, self).__init__()
100
+ # global average pooling: feature --> point
101
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
102
+ # feature channel downscale and upscale --> channel weight
103
+ self.conv_du = nn.Sequential(
104
+ nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
105
+ nn.ReLU(inplace=True),
106
+ nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
107
+ nn.Sigmoid()
108
+ )
109
+
110
+ def forward(self, x):
111
+ y = self.avg_pool(x)
112
+ y = self.conv_du(y)
113
+ return x * y
114
+
115
+ ## Channel Attention Block (CAB)
116
+ class CAB(nn.Module):
117
+ def __init__(self, n_feat, kernel_size, reduction, bias, act):
118
+ super(CAB, self).__init__()
119
+ modules_body = [conv(n_feat, n_feat, kernel_size, bias=bias), act, conv(n_feat, n_feat, kernel_size, bias=bias)]
120
+
121
+ self.CA = CALayer(n_feat, reduction, bias=bias)
122
+ self.body = nn.Sequential(*modules_body)
123
+
124
+ def forward(self, x):
125
+ res = self.body(x)
126
+ res = self.CA(res)
127
+ res += x
128
+ return res
129
+
130
+
131
+ if __name__ == "__main__":
132
+ import time
133
+ from thop import profile
134
+ # layer = CAB(64, 3, 4, False, nn.PReLU())
135
+ layer = PAB(64, 3, 4, False, nn.PReLU())
136
+ # layer = SAB(64, 3, 4, False, nn.PReLU())
137
+ for idx, m in enumerate(layer.modules()):
138
+ print(idx, "-", m)
139
+ s = time.time()
140
+
141
+ rgb = torch.ones(1, 64, 256, 256, dtype=torch.float, requires_grad=False)
142
+ out = layer(rgb)
143
+ flops, params = profile(layer, inputs=(rgb,))
144
+ print('parameters:', params)
145
+ print('flops', flops)
146
+ print('time: {:.4f}ms'.format((time.time()-s)*10))