# pylint: disable=missing-module-docstring,invalid-name
# pylint: disable=missing-docstring
# pylint: disable=line-too-long

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization`_ .

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    For a 4d input of shape :math:`(N, C, H, W)`, the mean and standard
    deviation are computed per sample over the channel and spatial dimensions
    :math:`(C, H, W)`. :math:`\gamma` and :math:`\beta` are learnable
    per-channel affine parameters that are broadcast over the spatial
    dimensions.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which apply a
        single scalar scale and bias per entire channel/plane, this layer
        applies a per-channel scale and bias to every spatial location.

        This layer uses statistics computed from input data in both training
        and evaluation modes.

    Args:
        features: number of channels :math:`C` of the expected 4d input
        eps: a value added to the denominator for numerical stability.
            Default: 1e-12
        center: a boolean value that when set to ``True``, this module has a
            learnable per-channel bias :math:`\beta`, initialized to zeros.
            Default: ``True``
        scale: a boolean value that when set to ``True``, this module has a
            learnable per-channel weight :math:`\gamma`, initialized to ones.
            Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 5, 10, 10)
        >>> # With learnable per-channel parameters
        >>> m = LayerNorm(5)
        >>> # Without learnable parameters
        >>> m = LayerNorm(5, center=False, scale=False)
        >>> # Applying the module
        >>> output = m(input)

    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
    """
    __constants__ = ['features', 'weight', 'bias', 'eps', 'center', 'scale']

    def __init__(self, features, eps=1e-12, center=True, scale=True):
        super(LayerNorm, self).__init__()
        self.features = features
        self.eps = eps
        self.center = center
        self.scale = scale

        if self.scale:
            self.weight = nn.Parameter(torch.Tensor(self.features))
        else:
            self.register_parameter('weight', None)

        if self.center:
            self.bias = nn.Parameter(torch.Tensor(self.features))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        if self.scale:
            nn.init.ones_(self.weight)
        if self.center:
            nn.init.zeros_(self.bias)

    def adjust_parameter(self, tensor, parameter):
        # Broadcast a per-channel parameter of shape (C,) to (C, H, W) so it
        # matches the normalized_shape expected by F.layer_norm.
        return torch.repeat_interleave(
            torch.repeat_interleave(
                parameter.view(-1, 1, 1),
                repeats=tensor.shape[2],
                dim=1),
            repeats=tensor.shape[3],
            dim=2
        )

    def forward(self, input):
        normalized_shape = (self.features, input.shape[2], input.shape[3])
        # guard against scale=False / center=False, where weight or bias is None
        weight = self.adjust_parameter(input, self.weight) if self.weight is not None else None
        bias = self.adjust_parameter(input, self.bias) if self.bias is not None else None
        return F.layer_norm(
            input, normalized_shape, weight, bias, self.eps)

    def extra_repr(self):
        return '{features}, eps={eps}, ' \
            'center={center}, scale={scale}'.format(**self.__dict__)
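

# A minimal usage sketch (added for illustration, not part of the original
# module): normalize a B x C x H x W activation map with per-channel affine
# parameters; the shapes are assumptions chosen for this example.
def _layer_norm_example():
    x = torch.randn(2, 8, 16, 16)
    layer_norm = LayerNorm(features=8)
    y = layer_norm(x)
    assert y.shape == x.shape  # normalization preserves the input shape
    return y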


def gaussian_filter_1d(tensor, dim, sigma, truncate=4, kernel_size=None, padding_mode='replicate', padding_value=0.0):
    """Blurs `tensor` along dimension `dim` with a 1d gaussian kernel of
    standard deviation `sigma` (differentiable with respect to `sigma`)."""
    sigma = torch.as_tensor(sigma, device=tensor.device, dtype=tensor.dtype)

    if kernel_size is not None:
        kernel_size = torch.as_tensor(kernel_size, device=tensor.device, dtype=torch.int64)
    else:
        kernel_size = torch.as_tensor(2 * torch.ceil(truncate * sigma) + 1, device=tensor.device, dtype=torch.int64)

    kernel_size = kernel_size.detach()
    kernel_size_int = kernel_size.detach().cpu().numpy()

    mean = (torch.as_tensor(kernel_size, dtype=tensor.dtype) - 1) / 2

    grid = torch.arange(kernel_size, device=tensor.device) - mean

    # reshape the grid so it can be used as a kernel for F.conv1d
    kernel_shape = (1, 1, kernel_size)
    grid = grid.view(kernel_shape)
    grid = grid.detach()

    # move the convolved dimension to the end so that all other dimensions
    # can be folded into the batch dimension
    source_shape = tensor.shape
    tensor = torch.movedim(tensor, dim, len(source_shape) - 1)
    dim_last_shape = tensor.shape
    assert tensor.shape[-1] == source_shape[dim]

    # we need reshape instead of view for batches like B x C x H x W
    tensor = tensor.reshape(-1, 1, source_shape[dim])

    padding = (math.ceil((kernel_size_int - 1) / 2), math.ceil((kernel_size_int - 1) / 2))
    tensor_ = F.pad(tensor, padding, padding_mode, padding_value)

    # create gaussian kernel from grid using current sigma
    kernel = torch.exp(-0.5 * (grid / sigma) ** 2)
    kernel = kernel / kernel.sum()

    # convolve input with gaussian kernel
    tensor_ = F.conv1d(tensor_, kernel)

    # restore the original dimension order and shape
    tensor_ = tensor_.view(dim_last_shape)
    tensor_ = torch.movedim(tensor_, len(source_shape) - 1, dim)
    assert tensor_.shape == source_shape

    return tensor_
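

# A minimal usage sketch (added for illustration): a separable 2d blur built
# from two 1d passes over an image batch. The shapes and sigma value are
# assumptions chosen for this example.
def _gaussian_filter_1d_example():
    x = torch.rand(1, 3, 32, 32)
    blurred = gaussian_filter_1d(x, dim=2, sigma=2.0)         # blur along height
    blurred = gaussian_filter_1d(blurred, dim=3, sigma=2.0)   # blur along width
    assert blurred.shape == x.shape
    return blurred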


class GaussianFilterNd(nn.Module):
    """A differentiable gaussian filter"""

    def __init__(self, dims, sigma, truncate=4, kernel_size=None, padding_mode='replicate', padding_value=0.0,
                 trainable=False):
        """Creates a gaussian filter that is applied along the given dimensions

        Args:
            dims ([int]): the dimensions to which the gaussian filter is applied. Negative values won't work
            sigma (float): standard deviation of the gaussian filter (blur size)
            truncate (float, optional): truncate the filter at this many standard deviations (default: 4.0).
                This has no effect if `kernel_size` is explicitly set
            kernel_size (int, optional): size of the gaussian kernel convolved with the input
            padding_mode (string, optional): padding mode implemented by `torch.nn.functional.pad`
            padding_value (float, optional): value used for constant padding
            trainable (bool, optional): if ``True``, `sigma` becomes a trainable parameter (default: ``False``)
        """
        # IDEA determine input_dims dynamically for every input
        super(GaussianFilterNd, self).__init__()

        self.dims = dims
        self.sigma = nn.Parameter(torch.tensor(sigma, dtype=torch.float32), requires_grad=trainable)  # default: no optimization
        self.truncate = truncate
        self.kernel_size = kernel_size

        # setup padding
        self.padding_mode = padding_mode
        self.padding_value = padding_value

    def forward(self, tensor):
        """Applies the gaussian filter to the given tensor"""
        for dim in self.dims:
            tensor = gaussian_filter_1d(
                tensor,
                dim=dim,
                sigma=self.sigma,
                truncate=self.truncate,
                kernel_size=self.kernel_size,
                padding_mode=self.padding_mode,
                padding_value=self.padding_value,
            )

        return tensor
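

# A minimal usage sketch (added for illustration): a 2d blur over the spatial
# dimensions of a B x C x H x W tensor; the shapes and sigma are assumptions
# chosen for this example.
def _gaussian_filter_nd_example():
    blur = GaussianFilterNd(dims=[2, 3], sigma=3.0, trainable=False)
    x = torch.rand(1, 1, 64, 64)
    return blur(x)  # same shape as x, blurred along H and W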


class Conv2dMultiInput(nn.Module):
    """Applies a separate 2d convolution to each input tensor and sums the results.

    `in_channels` is a list with one channel count per input; inputs with a
    falsy channel count are skipped entirely.
    """
    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        for k, _in_channels in enumerate(in_channels):
            if _in_channels:
                setattr(self, f'conv_part{k}', nn.Conv2d(_in_channels, out_channels, kernel_size, bias=bias))

    def forward(self, tensors):
        assert len(tensors) == len(self.in_channels)

        out = None
        for k, (count, tensor) in enumerate(zip(self.in_channels, tensors)):
            if not count:
                continue
            _out = getattr(self, f'conv_part{k}')(tensor)
            if out is None:
                out = _out
            else:
                out += _out

        return out

    # def extra_repr(self):
    #     return f'{self.in_channels}'
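

# A minimal usage sketch (added for illustration): combine two feature streams
# with different channel counts into one output map. The channel counts and
# shapes are assumptions chosen for this example.
def _conv2d_multi_input_example():
    conv = Conv2dMultiInput(in_channels=[4, 2], out_channels=8, kernel_size=1)
    a = torch.randn(1, 4, 16, 16)
    b = torch.randn(1, 2, 16, 16)
    return conv([a, b])  # shape (1, 8, 16, 16)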


class LayerNormMultiInput(nn.Module):
    """Applies a separate `LayerNorm` to each input tensor.

    `features` is a list with one channel count per input; inputs with a falsy
    channel count are expected to be ``None`` and are passed through as ``None``.
    """
    __constants__ = ['features', 'weight', 'bias', 'eps', 'center', 'scale']

    def __init__(self, features, eps=1e-12, center=True, scale=True):
        super().__init__()
        self.features = features
        self.eps = eps
        self.center = center
        self.scale = scale

        for k, _features in enumerate(features):
            if _features:
                setattr(self, f'layernorm_part{k}', LayerNorm(_features, eps=eps, center=center, scale=scale))

    def forward(self, tensors):
        assert len(tensors) == len(self.features)

        out = []
        for k, (count, tensor) in enumerate(zip(self.features, tensors)):
            if not count:
                assert tensor is None
                out.append(None)
                continue
            out.append(getattr(self, f'layernorm_part{k}')(tensor))

        return out
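

# A minimal usage sketch (added for illustration): normalize two feature
# streams independently; a stream declared with 0 features must be ``None``.
def _layer_norm_multi_input_example():
    norm = LayerNormMultiInput(features=[4, 0])
    a = torch.randn(1, 4, 16, 16)
    normed_a, normed_none = norm([a, None])
    assert normed_none is None
    return normed_a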


class Bias(nn.Module):
    """Adds a learnable per-channel bias to a B x C x H x W tensor."""
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.bias = nn.Parameter(torch.zeros(channels))

    def forward(self, tensor):
        # np.newaxis is None, so this broadcasts the bias to shape (1, C, 1, 1)
        return tensor + self.bias[np.newaxis, :, np.newaxis, np.newaxis]

    def extra_repr(self):
        return f'channels={self.channels}'
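

# A minimal usage sketch (added for illustration): add a per-channel offset.
def _bias_example():
    layer = Bias(channels=4)
    x = torch.zeros(1, 4, 2, 2)
    return layer(x)  # every channel c holds the value of layer.bias[c]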


class SelfAttention(nn.Module):
    """Self attention layer

    adapted from https://discuss.pytorch.org/t/attention-in-image-classification/80147/3
    """
    def __init__(self, in_channels, out_channels=None, key_channels=None, activation=None, skip_connection_with_convolution=False, return_attention=True):
        super().__init__()
        self.in_channels = in_channels
        if out_channels is None:
            out_channels = in_channels
        self.out_channels = out_channels
        if key_channels is None:
            key_channels = in_channels // 8
        self.key_channels = key_channels
        self.activation = activation
        self.skip_connection_with_convolution = skip_connection_with_convolution
        if not self.skip_connection_with_convolution:
            if self.out_channels != self.in_channels:
                raise ValueError("out_channels has to be equal to in_channels with true skip connection!")
        self.return_attention = return_attention

        self.query_conv = nn.Conv2d(in_channels=in_channels, out_channels=key_channels, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_channels, out_channels=key_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        if self.skip_connection_with_convolution:
            self.skip_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs:
            x: input feature maps (B x C x H x W)
        returns:
            out: gated self-attention value + skip connection
            attention: B x N x N (N = H*W), if `return_attention` is set
        """
        m_batchsize, C, height, width = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, height * width).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, height * width)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N
        attention = self.softmax(energy)  # B x N x N
        proj_value = self.value_conv(x).view(m_batchsize, -1, height * width)  # B x C_out x N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, self.out_channels, height, width)

        if self.skip_connection_with_convolution:
            skip_connection = self.skip_conv(x)
        else:
            skip_connection = x
        out = self.gamma * out + skip_connection

        if self.activation is not None:
            out = self.activation(out)

        if self.return_attention:
            return out, attention

        return out
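

# A minimal usage sketch (added for illustration): attention over an 8x8
# feature map. With in_channels=16, the default key_channels is 16 // 8 = 2.
def _self_attention_example():
    attn = SelfAttention(in_channels=16)
    x = torch.randn(1, 16, 8, 8)
    out, attention = attn(x)
    assert out.shape == x.shape
    assert attention.shape == (1, 64, 64)  # N x N with N = 8 * 8
    return out, attention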


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self attention layer

    adapted from https://discuss.pytorch.org/t/attention-in-image-classification/80147/3
    """
    def __init__(self, in_channels, heads, out_channels=None, key_channels=None, activation=None, skip_connection_with_convolution=False):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttention(
            in_channels=in_channels,
            out_channels=out_channels,
            key_channels=key_channels,
            activation=activation,
            skip_connection_with_convolution=skip_connection_with_convolution,
            return_attention=False,
        ) for _ in range(heads)])

    def forward(self, tensor):
        # concatenate the per-head outputs along the channel dimension
        outs = [head(tensor) for head in self.heads]
        out = torch.cat(outs, dim=1)
        return out
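

# A minimal usage sketch (added for illustration): two heads whose outputs are
# concatenated along the channel dimension, giving 2 * in_channels channels.
def _multi_head_self_attention_example():
    mhsa = MultiHeadSelfAttention(in_channels=16, heads=2)
    x = torch.randn(1, 16, 8, 8)
    out = mhsa(x)
    assert out.shape == (1, 32, 8, 8)
    return out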


class FlexibleScanpathHistoryEncoding(nn.Module):
    """A convolutional layer which works for different numbers of previous fixations.

    Nonexistent fixations deactivate the respective convolutions; the bias is
    added per fixation (if the given fixation is present).
    """
    def __init__(self, in_fixations, channels_per_fixation, out_channels, kernel_size, bias=True):
        super().__init__()
        self.in_fixations = in_fixations
        self.channels_per_fixation = channels_per_fixation
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.bias = bias
        self.convolutions = nn.ModuleList([
            nn.Conv2d(
                in_channels=self.channels_per_fixation,
                out_channels=self.out_channels,
                kernel_size=self.kernel_size,
                bias=self.bias
            ) for _ in range(in_fixations)
        ])

    def forward(self, tensor):
        results = None
        # a fixation is considered valid if its channels are not NaN
        valid_fixations = ~torch.isnan(
            tensor[:, :self.in_fixations, 0, 0]
        )

        for fixation_index in range(self.in_fixations):
            valid_indices = valid_fixations[:, fixation_index]
            if not torch.any(valid_indices):
                continue
            # channels are interleaved per fixation, hence the strided slice
            this_input = tensor[
                valid_indices,
                fixation_index::self.in_fixations
            ]
            this_result = self.convolutions[fixation_index](
                this_input
            )
            # TODO: this will break if no data point
            # in the batch has even a single fixation,
            # but that's not a case I intend to train
            # anyway.
            if results is None:
                b, _, _, _ = tensor.shape
                _, _, h, w = this_result.shape
                results = torch.zeros(
                    (b, self.out_channels, h, w),
                    dtype=tensor.dtype,
                    device=tensor.device
                )
            results[valid_indices] += this_result

        return results
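

# A minimal usage sketch (added for illustration): encode up to 3 previous
# fixations with 1 channel each; a missing fixation is marked with NaNs.
# The shapes are assumptions chosen for this example.
def _scanpath_encoding_example():
    enc = FlexibleScanpathHistoryEncoding(
        in_fixations=3, channels_per_fixation=1, out_channels=8, kernel_size=1)
    x = torch.randn(2, 3, 16, 16)
    x[1, 2] = float('nan')  # second sample only has two valid fixations
    out = enc(x)
    assert out.shape == (2, 8, 16, 16)
    return out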