import torch from torch import nn class EMA(nn.Module): def __init__(self, channels, factor=8): super(EMA, self).__init__() self.groups = factor assert channels // self.groups > 0 self.softmax = nn.Softmax(-1) self.agp = nn.AdaptiveAvgPool2d((1, 1)) self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) self.pool_w = nn.AdaptiveAvgPool2d((1, None)) self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups) self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0) self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1) def forward(self, x): b, c, h, w = x.size() group_x = x.reshape(b * self.groups, -1, h, w) # b*g,c//g,h,w x_h = self.pool_h(group_x) x_w = self.pool_w(group_x).permute(0, 1, 3, 2) hw = self.conv1x1(torch.cat([x_h, x_w], dim=2)) x_h, x_w = torch.split(hw, [h, w], dim=2) x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid()) x2 = self.conv3x3(group_x) x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1)) x12 = x2.reshape(b * self.groups, c // self.groups, -1) # b*g, c//g, hw x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1)) x22 = x1.reshape(b * self.groups, c // self.groups, -1) # b*g, c//g, hw weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w) return (group_x * weights.sigmoid()).reshape(b, c, h, w)