# eP-ALM/TimeSformer/timesformer/models/nonlocal_helper.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Non-local helper"""
import torch
import torch.nn as nn


class Nonlocal(nn.Module):
"""
Builds Non-local Neural Networks as a generic family of building
blocks for capturing long-range dependencies. Non-local Network
computes the response at a position as a weighted sum of the
features at all positions. This building block can be plugged into
many computer vision architectures.
More details in the paper: https://arxiv.org/pdf/1711.07971.pdf
"""

    def __init__(
self,
dim,
dim_inner,
pool_size=None,
instantiation="softmax",
zero_init_final_conv=False,
zero_init_final_norm=True,
norm_eps=1e-5,
norm_momentum=0.1,
norm_module=nn.BatchNorm3d,
):
"""
Args:
            dim (int): number of channels of the input.
            dim_inner (int): number of channels used inside the Non-local
                block (for the theta, phi, and g projections).
            pool_size (list): kernel size of the spatio-temporal max pooling,
                given in the order [temporal kernel size, spatial kernel size
                (height), spatial kernel size (width)]. If pool_size is None
                (default), no pooling is used.
            instantiation (string): supports two instantiation methods:
                "dot_product": normalize the affinity matrix by the number of
                    spatio-temporal positions.
                "softmax": normalize the affinity matrix with a Softmax.
            zero_init_final_conv (bool): if True, zero-initialize the final
                convolution of the Non-local block.
            zero_init_final_norm (bool): if True, zero-initialize the final
                batch norm of the Non-local block.
            norm_eps (float): epsilon for the normalization layer.
            norm_momentum (float): momentum for the normalization layer.
            norm_module (nn.Module): nn.Module used for the normalization
                layer; nn.BatchNorm3d by default.
"""
super(Nonlocal, self).__init__()
self.dim = dim
self.dim_inner = dim_inner
self.pool_size = pool_size
self.instantiation = instantiation
self.use_pool = (
False
if pool_size is None
else any((size > 1 for size in pool_size))
)
self.norm_eps = norm_eps
self.norm_momentum = norm_momentum
self._construct_nonlocal(
zero_init_final_conv, zero_init_final_norm, norm_module
)

    def _construct_nonlocal(
self, zero_init_final_conv, zero_init_final_norm, norm_module
):
# Three convolution heads: theta, phi, and g.
self.conv_theta = nn.Conv3d(
self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
)
self.conv_phi = nn.Conv3d(
self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
)
self.conv_g = nn.Conv3d(
self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
)
# Final convolution output.
self.conv_out = nn.Conv3d(
self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0
)
# Zero initializing the final convolution output.
self.conv_out.zero_init = zero_init_final_conv
# TODO: change the name to `norm`
self.bn = norm_module(
num_features=self.dim,
eps=self.norm_eps,
momentum=self.norm_momentum,
)
# Zero initializing the final bn.
self.bn.transform_final_bn = zero_init_final_norm
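        # Note: `zero_init` and `transform_final_bn` set above are only flags
        # on the submodules; the actual zero initialization is expected to be
        # applied by a weight-initialization helper elsewhere in the codebase.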
        # Optionally add spatio-temporal pooling.
if self.use_pool:
self.pool = nn.MaxPool3d(
kernel_size=self.pool_size,
stride=self.pool_size,
padding=[0, 0, 0],
)

    def forward(self, x):
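        """
        Args:
            x (tensor): input tensor of shape (N, C, T, H, W).
        Returns:
            tensor: output of shape (N, C, T, H, W), equal to the input plus
                the Non-local response (residual connection).
        """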
x_identity = x
N, C, T, H, W = x.size()
theta = self.conv_theta(x)
# Perform temporal-spatial pooling to reduce the computation.
if self.use_pool:
x = self.pool(x)
phi = self.conv_phi(x)
g = self.conv_g(x)
theta = theta.view(N, self.dim_inner, -1)
phi = phi.view(N, self.dim_inner, -1)
g = g.view(N, self.dim_inner, -1)
# (N, C, TxHxW) * (N, C, TxHxW) => (N, TxHxW, TxHxW).
theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi))
        # Following the original Non-local paper, there are two main ways to
        # normalize the affinity tensor:
        #   1) Softmax normalization (normalize over the exponentials).
        #   2) dot_product normalization (divide by the number of positions).
if self.instantiation == "softmax":
# Normalizing the affinity tensor theta_phi before softmax.
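            # (The 1/sqrt(dim_inner) scaling below matches the temperature
            # scaling used in scaled dot-product attention.)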
theta_phi = theta_phi * (self.dim_inner ** -0.5)
theta_phi = nn.functional.softmax(theta_phi, dim=2)
elif self.instantiation == "dot_product":
spatial_temporal_dim = theta_phi.shape[2]
theta_phi = theta_phi / spatial_temporal_dim
else:
raise NotImplementedError(
"Unknown norm type {}".format(self.instantiation)
)
# (N, TxHxW, TxHxW) * (N, C, TxHxW) => (N, C, TxHxW).
theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g))
# (N, C, TxHxW) => (N, C, T, H, W).
theta_phi_g = theta_phi_g.view(N, self.dim_inner, T, H, W)
p = self.conv_out(theta_phi_g)
p = self.bn(p)
return x_identity + p
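

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): instantiate a
    # Non-local block and run a forward pass on a random clip-shaped tensor.
    # The channel sizes, pool_size, and tensor shape below are illustrative
    # assumptions only.
    block = Nonlocal(dim=64, dim_inner=32, pool_size=[1, 2, 2])
    clip = torch.rand(2, 64, 4, 16, 16)  # (N, C, T, H, W)
    out = block(clip)
    # The residual design keeps the output shape identical to the input.
    print(out.shape)  # torch.Size([2, 64, 4, 16, 16])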