# Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from functools import partial
import torch
import torch.nn as nn
from imaginaire.layers import Conv2dBlock


class NonLocal2dBlock(nn.Module):
r"""Self attention Layer
Args:
in_channels (int): Number of channels in the input tensor.
scale (bool, optional, default=True): If ``True``, scale the
output by a learnable parameter.
        clamp (bool, optional, default=False): If ``True``, clamp the
scaling parameter to (-1, 1).
weight_norm_type (str, optional, default='none'):
Type of weight normalization.
``'none'``, ``'spectral'``, ``'weight'``
or ``'weight_demod'``.
"""
def __init__(self,
in_channels,
scale=True,
clamp=False,
weight_norm_type='none'):
super(NonLocal2dBlock, self).__init__()
        self.clamp = clamp
        if scale:
            # Learnable scale, initialized to zero so the attention branch
            # starts out disabled.
            self.gamma = nn.Parameter(torch.zeros(1))
        else:
            # Constant scale kept as a buffer (rather than a plain float)
            # so that ``clamp`` and device moves still work.
            self.register_buffer('gamma', torch.ones(1))
self.in_channels = in_channels
base_conv2d_block = partial(Conv2dBlock,
kernel_size=1,
stride=1,
padding=0,
weight_norm_type=weight_norm_type)
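        # 1x1 convolution blocks for the query (theta), key (phi) and
        # value (g) projections, plus the output projection back to
        # in_channels.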
self.theta = base_conv2d_block(in_channels, in_channels // 8)
self.phi = base_conv2d_block(in_channels, in_channels // 8)
self.g = base_conv2d_block(in_channels, in_channels // 2)
self.out_conv = base_conv2d_block(in_channels // 2, in_channels)
self.softmax = nn.Softmax(dim=-1)
        self.max_pool = nn.MaxPool2d(2)

    def forward(self, x):
r"""
Args:
x (tensor) : input feature maps (B X C X W X H)
Returns:
(tuple):
- out (tensor) : self attention value + input feature
- attention (tensor): B x N x N (N is Width*Height)
"""
n, c, h, w = x.size()
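        # Query: project to C // 8 channels and flatten to (B, H*W, C//8).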
theta = self.theta(x).view(n, -1, h * w).permute(0, 2, 1)
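        # Key: project to C // 8 channels; 2x2 max-pooling shortens the
        # key sequence to H*W/4 to save memory.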
phi = self.phi(x)
phi = self.max_pool(phi).view(n, -1, h * w // 4)
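        # Affinity between each query position and the pooled key
        # positions: (B, H*W, H*W/4), normalized over the pooled axis.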
energy = torch.bmm(theta, phi)
attention = self.softmax(energy)
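        # Value: project to C // 2 channels, pooled and flattened like
        # the key.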
g = self.g(x)
g = self.max_pool(g).view(n, -1, h * w // 4)
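        # Weighted sum of values, then restore the spatial layout.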
out = torch.bmm(g, attention.permute(0, 2, 1))
out = out.view(n, c // 2, h, w)
out = self.out_conv(out)
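        # Residual connection scaled by gamma (optionally clamped).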
if self.clamp:
out = self.gamma.clamp(-1, 1) * out + x
else:
out = self.gamma * out + x
return out
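

# A minimal usage sketch, assuming the imaginaire package is importable
# so that Conv2dBlock resolves. The input height and width must be even
# because phi and g are 2x2 max-pooled.
if __name__ == '__main__':
    block = NonLocal2dBlock(in_channels=64)
    x = torch.randn(2, 64, 32, 32)
    out = block(x)
    # The residual connection preserves the input shape.
    assert out.shape == x.shape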