LSZTT commited on
Commit
29add7c
1 Parent(s): 4cd5cd5

Upload SE.py

Browse files
Files changed (1) hide show
  1. models/SE.py +39 -0
models/SE.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import init
5
+
6
+
7
+
8
+ class SEAttention(nn.Module):
9
+
10
+ def __init__(self, channel=512,reduction=16):
11
+ super().__init__()
12
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
13
+ self.fc = nn.Sequential(
14
+ nn.Linear(channel, channel // reduction, bias=False),
15
+ nn.ReLU(inplace=True),
16
+ nn.Linear(channel // reduction, channel, bias=False),
17
+ nn.Sigmoid()
18
+ )
19
+
20
+
21
+ def init_weights(self):
22
+ for m in self.modules():
23
+ if isinstance(m, nn.Conv2d):
24
+ init.kaiming_normal_(m.weight, mode='fan_out')
25
+ if m.bias is not None:
26
+ init.constant_(m.bias, 0)
27
+ elif isinstance(m, nn.BatchNorm2d):
28
+ init.constant_(m.weight, 1)
29
+ init.constant_(m.bias, 0)
30
+ elif isinstance(m, nn.Linear):
31
+ init.normal_(m.weight, std=0.001)
32
+ if m.bias is not None:
33
+ init.constant_(m.bias, 0)
34
+
35
+ def forward(self, x):
36
+ b, c, _, _ = x.size()
37
+ y = self.avg_pool(x).view(b, c)
38
+ y = self.fc(y).view(b, c, 1, 1)
39
+ return x * y.expand_as(x)