tongyujun's picture
Upload 641 files
8c6b5ee verified
raw
history blame contribute delete
904 Bytes
import torch.nn as nn
from torch.nn import functional as F
__all__ = ["Attention"]
class Attention(nn.Module):
    """Attention from `"Dynamic Domain Generalization" <https://github.com/MetaVisionLab/DDG>`_.

    Average-pools the input over its last two dimensions, feeds the pooled
    vector through a bottleneck ``Linear -> ReLU -> Linear`` stack and
    returns a softmax over the ``out_features`` outputs.

    Args:
        in_channels: Size of the channel dimension of the input; also the
            input width of the first linear layer.
        out_features: Number of attention weights produced per sample.
        squeeze: Width of the hidden bottleneck layer. Any falsy value
            (``None`` or ``0``) selects the default ``in_channels // 16``.
        bias: Whether both linear layers carry a bias term.

    Raises:
        ValueError: If the resolved bottleneck width is not positive, e.g.
            ``in_channels < 16`` with no explicit ``squeeze``.
    """

    def __init__(
        self,
        in_channels: int,
        out_features: int,
        squeeze=None,
        bias: bool = True
    ):
        super().__init__()
        # Falsy ``squeeze`` (None / 0) falls back to a 16x channel reduction.
        self.squeeze = squeeze if squeeze else in_channels // 16
        # Explicit raise instead of ``assert``: asserts vanish under
        # ``python -O``, which would silently build a zero-width bottleneck.
        if self.squeeze <= 0:
            raise ValueError(
                f"squeeze must be a positive integer, got {self.squeeze} "
                f"(in_channels={in_channels}); pass `squeeze` explicitly "
                "when in_channels < 16"
            )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, self.squeeze, bias=bias)
        self.fc2 = nn.Linear(self.squeeze, out_features, bias=bias)
        self.sf = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return softmax attention weights for ``x``.

        Args:
            x: Tensor whose last two dimensions are pooled away and whose
                remaining trailing dimension must equal ``in_channels``
                (any layout accepted by ``nn.AdaptiveAvgPool2d``).

        Returns:
            Tensor of shape ``x.shape[:-3] + (out_features,)`` whose last
            dimension sums to 1.
        """
        # Pool the two trailing dims to 1x1, then drop them.
        pooled = self.avg_pool(x).view(x.shape[:-2])
        # In-place ReLU is safe here: it acts on the fresh fc1 output only.
        hidden = F.relu(self.fc1(pooled), inplace=True)
        logits = self.fc2(hidden)
        return self.sf(logits)