# Hosting-page header (not code): repository path LS/models/SGE.py,
# user LSZTT, commit f5e6066 "Upload 60 files" (verified), file size 1.82 kB.
import numpy as np
import torch
from torch import nn
from torch.nn import init
class SpatialGroupEnhance(nn.Module):
    """Group-wise spatial gating for 4-D feature maps.

    The channels of the input are split into ``groups`` equally sized groups.
    For every group a single-channel attention map is built from the
    similarity between each spatial position and the group's globally
    average-pooled descriptor. The map is standardized over the spatial
    positions, scaled and shifted by a learnable per-group ``weight`` and
    ``bias``, squashed with a sigmoid and multiplied back onto the group's
    features.

    Args:
        groups: Number of channel groups. The channel dimension of the input
            passed to :meth:`forward` must be divisible by this value.
    """

    def __init__(self, groups: int = 8) -> None:
        super().__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Per-group affine applied to the standardized attention map. Both
        # start at zero, so a freshly constructed module gates every position
        # with sigmoid(0) = 0.5.
        self.weight = nn.Parameter(torch.zeros(1, groups, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, groups, 1, 1))
        self.sig = nn.Sigmoid()
        self.init_weights()

    def init_weights(self) -> None:
        """Initialise any Conv2d / BatchNorm2d / Linear sub-modules.

        The sub-modules registered in ``__init__`` (adaptive average pooling
        and sigmoid) match none of these types, so for this class as written
        the loop changes nothing; it only takes effect if such layers are
        added, e.g. by a subclass before this method is called again.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the group-wise spatial gate.

        Args:
            x: Tensor of shape ``(b, c, h, w)``; ``c`` must be divisible by
                ``self.groups``. Any memory layout is accepted (including
                ``channels_last``).

        Returns:
            Tensor of shape ``(b, c, h, w)``: the input scaled element-wise
            by the per-group sigmoid gate.

        Raises:
            ValueError: If ``x`` is not 4-dimensional (shape unpacking fails).
            RuntimeError: If ``c`` is not divisible by ``self.groups`` (the
                grouping reshape is impossible).
        """
        b, c, h, w = x.shape
        groups = self.groups

        # reshape (not view) so that non-contiguous inputs such as
        # channels_last tensors are copied when needed instead of raising.
        x = x.reshape(b * groups, -1, h, w)            # b*g, c//g, h, w

        # Similarity of every position with the group's pooled descriptor.
        xn = x * self.avg_pool(x)                      # b*g, c//g, h, w
        xn = xn.sum(dim=1, keepdim=True)               # b*g, 1, h, w

        # Standardize each group's map over its spatial positions.
        t = xn.reshape(b * groups, -1)                 # b*g, h*w
        t = t - t.mean(dim=1, keepdim=True)            # b*g, h*w
        if h * w > 1:
            std = t.std(dim=1, keepdim=True) + 1e-5
            t = t / std                                # b*g, h*w
        # else: with a single spatial position the centered map is already
        # zero and the unbiased std is NaN; dividing would poison the output,
        # so the map is left at zero and the gate reduces to sigmoid(bias).

        t = t.reshape(b, groups, h, w)                 # b, g, h, w
        t = t * self.weight + self.bias                # b, g, h, w
        t = t.reshape(b * groups, 1, h, w)             # b*g, 1, h, w

        x = x * self.sig(t)                            # b*g, c//g, h, w
        return x.reshape(b, c, h, w)
if __name__ == '__main__':
    # Smoke test: push a random batch through the module and print the
    # output shape. `features` replaces the former local named `input`,
    # which shadowed the Python builtin of the same name.
    features = torch.randn(50, 512, 7, 7)
    sge = SpatialGroupEnhance(groups=8)
    enhanced = sge(features)
    print(enhanced.shape)