|
|
|
|
|
|
|
|
|
from functools import partial |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from imaginaire.layers import Conv2dBlock |
|
|
|
|
|
class NonLocal2dBlock(nn.Module): |
|
r"""Self attention Layer |
|
|
|
Args: |
|
in_channels (int): Number of channels in the input tensor. |
|
scale (bool, optional, default=True): If ``True``, scale the |
|
output by a learnable parameter. |
|
clamp (bool, optional, default=``False``): If ``True``, clamp the |
|
scaling parameter to (-1, 1). |
|
weight_norm_type (str, optional, default='none'): |
|
Type of weight normalization. |
|
``'none'``, ``'spectral'``, ``'weight'`` |
|
or ``'weight_demod'``. |
|
""" |
|
|
|
def __init__(self, |
|
in_channels, |
|
scale=True, |
|
clamp=False, |
|
weight_norm_type='none'): |
|
super(NonLocal2dBlock, self).__init__() |
|
self.clamp = clamp |
|
self.gamma = nn.Parameter(torch.zeros(1)) if scale else 1.0 |
|
self.in_channels = in_channels |
|
base_conv2d_block = partial(Conv2dBlock, |
|
kernel_size=1, |
|
stride=1, |
|
padding=0, |
|
weight_norm_type=weight_norm_type) |
|
self.theta = base_conv2d_block(in_channels, in_channels // 8) |
|
self.phi = base_conv2d_block(in_channels, in_channels // 8) |
|
self.g = base_conv2d_block(in_channels, in_channels // 2) |
|
self.out_conv = base_conv2d_block(in_channels // 2, in_channels) |
|
self.softmax = nn.Softmax(dim=-1) |
|
self.max_pool = nn.MaxPool2d(2) |
|
|
|
def forward(self, x): |
|
r""" |
|
|
|
Args: |
|
x (tensor) : input feature maps (B X C X W X H) |
|
Returns: |
|
(tuple): |
|
- out (tensor) : self attention value + input feature |
|
- attention (tensor): B x N x N (N is Width*Height) |
|
""" |
|
n, c, h, w = x.size() |
|
theta = self.theta(x).view(n, -1, h * w).permute(0, 2, 1) |
|
|
|
phi = self.phi(x) |
|
phi = self.max_pool(phi).view(n, -1, h * w // 4) |
|
|
|
energy = torch.bmm(theta, phi) |
|
attention = self.softmax(energy) |
|
|
|
g = self.g(x) |
|
g = self.max_pool(g).view(n, -1, h * w // 4) |
|
|
|
out = torch.bmm(g, attention.permute(0, 2, 1)) |
|
out = out.view(n, c // 2, h, w) |
|
out = self.out_conv(out) |
|
|
|
if self.clamp: |
|
out = self.gamma.clamp(-1, 1) * out + x |
|
else: |
|
out = self.gamma * out + x |
|
return out |
|
|