Spaces:
Sleeping
Sleeping
import torch.nn as nn | |
import torch | |
__all__ = ['SE3d'] | |
class Swish(nn.Module): | |
def forward(self,x): | |
return x * torch.sigmoid(x) | |
class SE3d(nn.Module): | |
def __init__(self, channel, reduction=8, use_relu=False): | |
super().__init__() | |
self.fc = nn.Sequential( | |
nn.Linear(channel, channel // reduction, bias=False), | |
nn.ReLU(True) if use_relu else Swish() , | |
nn.Linear(channel // reduction, channel, bias=False), | |
nn.Sigmoid() | |
) | |
def forward(self, inputs): | |
return inputs * self.fc(inputs.mean(-1).mean(-1).mean(-1)).view(inputs.shape[0], inputs.shape[1], 1, 1, 1) | |