|
|
|
|
|
|
|
"""Video models.""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
from pytorchvideo.layers.swish import Swish |
|
|
|
def drop_path(x, drop_prob: float = 0.0, training: bool = False): |
|
""" |
|
    Stochastic Depth per sample: randomly drop the residual path of each sample
    with probability `drop_prob` during training; kept samples are rescaled by
    1 / (1 - drop_prob) to preserve the expected value.
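
    A minimal usage sketch (shapes are illustrative):

        >>> x = torch.rand(4, 64, 8, 7, 7)
        >>> y = drop_path(x, drop_prob=0.2, training=True)  # same shape as x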
|
""" |
|
if drop_prob == 0.0 or not training: |
|
return x |
|
keep_prob = 1 - drop_prob |
|
    # Build a per-sample binary mask, broadcast over all non-batch dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize
|
output = x.div(keep_prob) * mask |
|
return output |
|
|
|
class Nonlocal(nn.Module): |
|
""" |
|
Builds Non-local Neural Networks as a generic family of building |
|
blocks for capturing long-range dependencies. Non-local Network |
|
computes the response at a position as a weighted sum of the |
|
features at all positions. This building block can be plugged into |
|
many computer vision architectures. |
|
More details in the paper: https://arxiv.org/pdf/1711.07971.pdf |
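
    A minimal usage sketch (shapes are illustrative):

        >>> nln = Nonlocal(dim=64, dim_inner=32, pool_size=[1, 2, 2])
        >>> x = torch.rand(2, 64, 4, 16, 16)  # (N, C, T, H, W)
        >>> out = nln(x)  # residual output with the same shape as x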
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim, |
|
dim_inner, |
|
pool_size=None, |
|
instantiation="softmax", |
|
zero_init_final_conv=False, |
|
zero_init_final_norm=True, |
|
norm_eps=1e-5, |
|
norm_momentum=0.1, |
|
norm_module=nn.BatchNorm3d, |
|
): |
|
""" |
|
Args: |
|
dim (int): number of dimension for the input. |
|
dim_inner (int): number of dimension inside of the Non-local block. |
|
            pool_size (list): kernel size for spatiotemporal max pooling, given
                as [temporal kernel size, spatial kernel size, spatial kernel
                size]. If pool_size is None (the default), no pooling is used.
|
            instantiation (string): supports two different instantiation methods:
                "dot_product": normalize the correlation matrix by the number
                    of positions.
                "softmax": normalize the correlation matrix with softmax.
|
            zero_init_final_conv (bool): if True, zero-initialize the final
                convolution of the Non-local block.
            zero_init_final_norm (bool): if True, zero-initialize the final
                batch norm of the Non-local block.
            norm_eps (float): epsilon for the normalization layer.
            norm_momentum (float): momentum for the normalization layer.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
|
""" |
|
super(Nonlocal, self).__init__() |
|
self.dim = dim |
|
self.dim_inner = dim_inner |
|
self.pool_size = pool_size |
|
self.instantiation = instantiation |
|
        self.use_pool = pool_size is not None and any(size > 1 for size in pool_size)
|
self.norm_eps = norm_eps |
|
self.norm_momentum = norm_momentum |
|
self._construct_nonlocal( |
|
zero_init_final_conv, zero_init_final_norm, norm_module |
|
) |
|
|
|
def _construct_nonlocal( |
|
self, zero_init_final_conv, zero_init_final_norm, norm_module |
|
): |
|
|
|
        # Theta, phi, and g embeddings: 1x1x1 convs projecting dim -> dim_inner.
        self.conv_theta = nn.Conv3d(
|
self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.conv_phi = nn.Conv3d( |
|
self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.conv_g = nn.Conv3d( |
|
self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0 |
|
) |
|
|
|
|
|
        # Final 1x1x1 conv projecting dim_inner back to dim.
        self.conv_out = nn.Conv3d(
|
self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0 |
|
) |
|
|
|
        # Record whether the final conv should be zero-initialized (see
        # zero_init_final_conv).
        self.conv_out.zero_init = zero_init_final_conv
|
|
|
|
|
self.bn = norm_module( |
|
num_features=self.dim, |
|
eps=self.norm_eps, |
|
momentum=self.norm_momentum, |
|
) |
|
|
|
        # Record whether the final BN should be zero-initialized (see
        # zero_init_final_norm).
        self.bn.transform_final_bn = zero_init_final_norm
|
|
|
|
|
if self.use_pool: |
|
self.pool = nn.MaxPool3d( |
|
kernel_size=self.pool_size, |
|
stride=self.pool_size, |
|
padding=[0, 0, 0], |
|
) |
|
|
|
def forward(self, x): |
|
x_identity = x |
|
N, C, T, H, W = x.size() |
|
|
|
theta = self.conv_theta(x) |
|
|
|
|
|
        # Optionally reduce the spatiotemporal resolution of phi and g with max pooling.
        if self.use_pool:
|
x = self.pool(x) |
|
|
|
phi = self.conv_phi(x) |
|
g = self.conv_g(x) |
|
|
|
theta = theta.view(N, self.dim_inner, -1) |
|
phi = phi.view(N, self.dim_inner, -1) |
|
g = g.view(N, self.dim_inner, -1) |
|
|
|
|
|
        # (N, C, THW) x (N, C, T'H'W') => (N, THW, T'H'W') affinity matrix,
        # where the primed sizes are the (optionally pooled) phi/g resolution.
        theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi))
|
|
|
|
|
|
|
|
|
        # Normalize the affinity matrix, either with a scaled softmax or by the
        # number of positions (dot_product).
        if self.instantiation == "softmax":
|
|
|
theta_phi = theta_phi * (self.dim_inner**-0.5) |
|
theta_phi = nn.functional.softmax(theta_phi, dim=2) |
|
elif self.instantiation == "dot_product": |
|
spatial_temporal_dim = theta_phi.shape[2] |
|
theta_phi = theta_phi / spatial_temporal_dim |
|
else: |
|
raise NotImplementedError("Unknown norm type {}".format(self.instantiation)) |
|
|
|
|
|
        # (N, THW, T'H'W') x (N, C, T'H'W') => (N, C, THW).
        theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g))
|
|
|
|
|
        # (N, C, THW) => (N, C, T, H, W).
        theta_phi_g = theta_phi_g.view(N, self.dim_inner, T, H, W)
|
|
|
p = self.conv_out(theta_phi_g) |
|
p = self.bn(p) |
|
return x_identity + p |
|
|
|
class SE(nn.Module): |
|
"""Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid.""" |
|
|
|
def _round_width(self, width, multiplier, min_width=8, divisor=8): |
|
""" |
|
Round width of filters based on width multiplier |
|
Args: |
|
width (int): the channel dimensions of the input. |
|
multiplier (float): the multiplication factor. |
|
min_width (int): the minimum width after multiplication. |
|
divisor (int): the new width should be dividable by divisor. |
|
""" |
|
if not multiplier: |
|
return width |
|
|
|
width *= multiplier |
|
min_width = min_width or divisor |
|
width_out = max(min_width, int(width + divisor / 2) // divisor * divisor) |
|
if width_out < 0.9 * width: |
|
width_out += divisor |
|
return int(width_out) |
|
|
|
def __init__(self, dim_in, ratio, relu_act=True): |
|
""" |
|
Args: |
|
dim_in (int): the channel dimensions of the input. |
|
ratio (float): the channel reduction ratio for squeeze. |
|
            relu_act (bool): if True (the default), use ReLU as the squeeze
                activation; otherwise use Swish.
|
""" |
|
super(SE, self).__init__() |
|
self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) |
|
dim_fc = self._round_width(dim_in, ratio) |
|
self.fc1 = nn.Conv3d(dim_in, dim_fc, 1, bias=True) |
|
self.fc1_act = nn.ReLU() if relu_act else Swish() |
|
self.fc2 = nn.Conv3d(dim_fc, dim_in, 1, bias=True) |
|
|
|
self.fc2_sig = nn.Sigmoid() |
|
|
|
def forward(self, x): |
|
x_in = x |
|
        # Apply the SE branch (avg_pool -> fc1 -> fc1_act -> fc2 -> fc2_sig)
        # in module registration order.
        for module in self.children():
|
x = module(x) |
|
return x_in * x |
|
|
|
|
|
|
|
|
|
def get_trans_func(name): |
|
""" |
|
Retrieves the transformation module by name. |
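
    A minimal usage sketch:

        >>> trans = get_trans_func("bottleneck_transform")
        >>> trans is BottleneckTransform
        True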
|
""" |
|
trans_funcs = { |
|
"bottleneck_transform": BottleneckTransform, |
|
"basic_transform": BasicTransform, |
|
"x3d_transform": X3DTransform, |
|
} |
|
assert ( |
|
name in trans_funcs.keys() |
|
), "Transformation function '{}' not supported".format(name) |
|
return trans_funcs[name] |
|
|
|
|
|
class BasicTransform(nn.Module): |
|
""" |
|
    Basic transformation: Tx3x3, 1x3x3, where T is the size of the temporal kernel.
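
    A minimal construction sketch (values are illustrative):

        >>> block = BasicTransform(dim_in=64, dim_out=64, temp_kernel_size=3, stride=1)
        >>> y = block(torch.rand(2, 64, 8, 56, 56))  # => (2, 64, 8, 56, 56)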
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim_in, |
|
dim_out, |
|
temp_kernel_size, |
|
stride, |
|
dim_inner=None, |
|
num_groups=1, |
|
stride_1x1=None, |
|
inplace_relu=True, |
|
eps=1e-5, |
|
bn_mmt=0.1, |
|
dilation=1, |
|
norm_module=nn.BatchNorm3d, |
|
block_idx=0, |
|
): |
|
""" |
|
Args: |
|
dim_in (int): the channel dimensions of the input. |
|
dim_out (int): the channel dimension of the output. |
|
temp_kernel_size (int): the temporal kernel sizes of the first |
|
convolution in the basic block. |
|
stride (int): the stride of the bottleneck. |
|
dim_inner (None): the inner dimension would not be used in |
|
BasicTransform. |
|
num_groups (int): number of groups for the convolution. Number of |
|
group is always 1 for BasicTransform. |
|
stride_1x1 (None): stride_1x1 will not be used in BasicTransform. |
|
inplace_relu (bool): if True, calculate the relu on the original |
|
input without allocating new memory. |
|
eps (float): epsilon for batch norm. |
|
            bn_mmt (float): momentum for batch norm. Note that BN momentum in
|
PyTorch = 1 - BN momentum in Caffe2. |
|
norm_module (nn.Module): nn.Module for the normalization layer. The |
|
default is nn.BatchNorm3d. |
|
""" |
|
super(BasicTransform, self).__init__() |
|
self.temp_kernel_size = temp_kernel_size |
|
self._inplace_relu = inplace_relu |
|
self._eps = eps |
|
self._bn_mmt = bn_mmt |
|
self._construct(dim_in, dim_out, stride, dilation, norm_module) |
|
|
|
def _construct(self, dim_in, dim_out, stride, dilation, norm_module): |
|
|
|
        # Tx3x3 conv, BN, ReLU.
        self.a = nn.Conv3d(
|
dim_in, |
|
dim_out, |
|
kernel_size=[self.temp_kernel_size, 3, 3], |
|
stride=[1, stride, stride], |
|
padding=[int(self.temp_kernel_size // 2), 1, 1], |
|
bias=False, |
|
) |
|
self.a_bn = norm_module( |
|
num_features=dim_out, eps=self._eps, momentum=self._bn_mmt |
|
) |
|
self.a_relu = nn.ReLU(inplace=self._inplace_relu) |
|
|
|
        # 1x3x3 conv, BN.
        self.b = nn.Conv3d(
|
dim_out, |
|
dim_out, |
|
kernel_size=[1, 3, 3], |
|
stride=[1, 1, 1], |
|
padding=[0, dilation, dilation], |
|
dilation=[1, dilation, dilation], |
|
bias=False, |
|
) |
|
|
|
self.b.final_conv = True |
|
|
|
self.b_bn = norm_module( |
|
num_features=dim_out, eps=self._eps, momentum=self._bn_mmt |
|
) |
|
|
|
self.b_bn.transform_final_bn = True |
|
|
|
def forward(self, x): |
|
x = self.a(x) |
|
x = self.a_bn(x) |
|
x = self.a_relu(x) |
|
|
|
x = self.b(x) |
|
x = self.b_bn(x) |
|
return x |
|
|
|
|
|
class X3DTransform(nn.Module): |
|
""" |
|
    X3D transformation: 1x1x1, Tx3x3 (channelwise, num_groups=dim_in), 1x1x1,
    augmented with an optional squeeze-excitation (SE) block on the Tx3x3 output.
    T is the temporal kernel size (defaulting to 3).
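
    A minimal construction sketch (values are illustrative):

        >>> block = X3DTransform(
        ...     dim_in=24, dim_out=24, temp_kernel_size=3, stride=1,
        ...     dim_inner=54, num_groups=54,
        ... )
        >>> y = block(torch.rand(2, 24, 4, 16, 16))  # => (2, 24, 4, 16, 16)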
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim_in, |
|
dim_out, |
|
temp_kernel_size, |
|
stride, |
|
dim_inner, |
|
num_groups, |
|
stride_1x1=False, |
|
inplace_relu=True, |
|
eps=1e-5, |
|
bn_mmt=0.1, |
|
dilation=1, |
|
norm_module=nn.BatchNorm3d, |
|
se_ratio=0.0625, |
|
swish_inner=True, |
|
block_idx=0, |
|
): |
|
""" |
|
Args: |
|
dim_in (int): the channel dimensions of the input. |
|
dim_out (int): the channel dimension of the output. |
|
temp_kernel_size (int): the temporal kernel sizes of the middle |
|
convolution in the bottleneck. |
|
stride (int): the stride of the bottleneck. |
|
dim_inner (int): the inner dimension of the block. |
|
num_groups (int): number of groups for the convolution. num_groups=1 |
|
is for standard ResNet like networks, and num_groups>1 is for |
|
ResNeXt like networks. |
|
stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise |
|
apply stride to the 3x3 conv. |
|
inplace_relu (bool): if True, calculate the relu on the original |
|
input without allocating new memory. |
|
eps (float): epsilon for batch norm. |
|
            bn_mmt (float): momentum for batch norm. Note that BN momentum in
|
PyTorch = 1 - BN momentum in Caffe2. |
|
dilation (int): size of dilation. |
|
norm_module (nn.Module): nn.Module for the normalization layer. The |
|
default is nn.BatchNorm3d. |
|
se_ratio (float): if > 0, apply SE to the Tx3x3 conv, with the SE |
|
channel dimensionality being se_ratio times the Tx3x3 conv dim. |
|
swish_inner (bool): if True, apply swish to the Tx3x3 conv, otherwise |
|
apply ReLU to the Tx3x3 conv. |
|
""" |
|
super(X3DTransform, self).__init__() |
|
self.temp_kernel_size = temp_kernel_size |
|
self._inplace_relu = inplace_relu |
|
self._eps = eps |
|
self._bn_mmt = bn_mmt |
|
self._se_ratio = se_ratio |
|
self._swish_inner = swish_inner |
|
self._stride_1x1 = stride_1x1 |
|
self._block_idx = block_idx |
|
self._construct( |
|
dim_in, |
|
dim_out, |
|
stride, |
|
dim_inner, |
|
num_groups, |
|
dilation, |
|
norm_module, |
|
) |
|
|
|
def _construct( |
|
self, |
|
dim_in, |
|
dim_out, |
|
stride, |
|
dim_inner, |
|
num_groups, |
|
dilation, |
|
norm_module, |
|
): |
|
(str1x1, str3x3) = (stride, 1) if self._stride_1x1 else (1, stride) |
|
|
|
|
|
        # 1x1x1 pointwise conv, BN, ReLU.
        self.a = nn.Conv3d(
|
dim_in, |
|
dim_inner, |
|
kernel_size=[1, 1, 1], |
|
stride=[1, str1x1, str1x1], |
|
padding=[0, 0, 0], |
|
bias=False, |
|
) |
|
self.a_bn = norm_module( |
|
num_features=dim_inner, eps=self._eps, momentum=self._bn_mmt |
|
) |
|
self.a_relu = nn.ReLU(inplace=self._inplace_relu) |
|
|
|
|
|
        # Tx3x3 channelwise conv (depthwise when num_groups == dim_inner), BN.
        self.b = nn.Conv3d(
|
dim_inner, |
|
dim_inner, |
|
[self.temp_kernel_size, 3, 3], |
|
stride=[1, str3x3, str3x3], |
|
padding=[int(self.temp_kernel_size // 2), dilation, dilation], |
|
groups=num_groups, |
|
bias=False, |
|
dilation=[1, dilation, dilation], |
|
) |
|
self.b_bn = norm_module( |
|
num_features=dim_inner, eps=self._eps, momentum=self._bn_mmt |
|
) |
|
|
|
|
|
        # Apply SE only on every other block (even block indices).
        use_se = (self._block_idx % 2) == 0
|
if self._se_ratio > 0.0 and use_se: |
|
self.se = SE(dim_inner, self._se_ratio) |
|
|
|
if self._swish_inner: |
|
self.b_relu = Swish() |
|
else: |
|
self.b_relu = nn.ReLU(inplace=self._inplace_relu) |
|
|
|
|
|
        # 1x1x1 pointwise conv, BN (flagged for zero init as the final BN).
        self.c = nn.Conv3d(
|
dim_inner, |
|
dim_out, |
|
kernel_size=[1, 1, 1], |
|
stride=[1, 1, 1], |
|
padding=[0, 0, 0], |
|
bias=False, |
|
) |
|
self.c_bn = norm_module( |
|
num_features=dim_out, eps=self._eps, momentum=self._bn_mmt |
|
) |
|
self.c_bn.transform_final_bn = True |
|
|
|
def forward(self, x): |
|
        # Run every child module in registration order:
        # a, a_bn, a_relu, b, b_bn, (optional) se, b_relu, c, c_bn.
        for block in self.children():
|
x = block(x) |
|
return x |
|
|
|
|
|
class BottleneckTransform(nn.Module): |
|
""" |
|
    Bottleneck transformation: Tx1x1, 1x3x3, 1x1x1, where T is the size of the
    temporal kernel.
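
    A minimal construction sketch (values are illustrative):

        >>> block = BottleneckTransform(
        ...     dim_in=64, dim_out=256, temp_kernel_size=3, stride=1,
        ...     dim_inner=64, num_groups=1,
        ... )
        >>> y = block(torch.rand(2, 64, 8, 56, 56))  # => (2, 256, 8, 56, 56)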
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim_in, |
|
dim_out, |
|
temp_kernel_size, |
|
stride, |
|
dim_inner, |
|
num_groups, |
|
stride_1x1=False, |
|
inplace_relu=True, |
|
eps=1e-5, |
|
bn_mmt=0.1, |
|
dilation=1, |
|
norm_module=nn.BatchNorm3d, |
|
block_idx=0, |
|
): |
|
""" |
|
Args: |
|
dim_in (int): the channel dimensions of the input. |
|
dim_out (int): the channel dimension of the output. |
|
temp_kernel_size (int): the temporal kernel sizes of the first |
|
convolution in the bottleneck. |
|
stride (int): the stride of the bottleneck. |
|
dim_inner (int): the inner dimension of the block. |
|
num_groups (int): number of groups for the convolution. num_groups=1 |
|
is for standard ResNet like networks, and num_groups>1 is for |
|
ResNeXt like networks. |
|
stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise |
|
apply stride to the 3x3 conv. |
|
inplace_relu (bool): if True, calculate the relu on the original |
|
input without allocating new memory. |
|
eps (float): epsilon for batch norm. |
|
            bn_mmt (float): momentum for batch norm. Note that BN momentum in
|
PyTorch = 1 - BN momentum in Caffe2. |
|
dilation (int): size of dilation. |
|
norm_module (nn.Module): nn.Module for the normalization layer. The |
|
default is nn.BatchNorm3d. |
|
""" |
|
super(BottleneckTransform, self).__init__() |
|
self.temp_kernel_size = temp_kernel_size |
|
self._inplace_relu = inplace_relu |
|
self._eps = eps |
|
self._bn_mmt = bn_mmt |
|
self._stride_1x1 = stride_1x1 |
|
self._construct( |
|
dim_in, |
|
dim_out, |
|
stride, |
|
dim_inner, |
|
num_groups, |
|
dilation, |
|
norm_module, |
|
) |
|
|
|
def _construct( |
|
self, |
|
dim_in, |
|
dim_out, |
|
stride, |
|
dim_inner, |
|
num_groups, |
|
dilation, |
|
norm_module, |
|
): |
|
(str1x1, str3x3) = (stride, 1) if self._stride_1x1 else (1, stride) |
|
|
|
|
|
        # Tx1x1 conv, BN, ReLU.
        self.a = nn.Conv3d(
|
dim_in, |
|
dim_inner, |
|
kernel_size=[self.temp_kernel_size, 1, 1], |
|
stride=[1, str1x1, str1x1], |
|
padding=[int(self.temp_kernel_size // 2), 0, 0], |
|
bias=False, |
|
) |
|
self.a_bn = norm_module( |
|
num_features=dim_inner, eps=self._eps, momentum=self._bn_mmt |
|
) |
|
self.a_relu = nn.ReLU(inplace=self._inplace_relu) |
|
|
|
|
|
        # 1x3x3 (grouped) conv, BN, ReLU.
        self.b = nn.Conv3d(
|
dim_inner, |
|
dim_inner, |
|
[1, 3, 3], |
|
stride=[1, str3x3, str3x3], |
|
padding=[0, dilation, dilation], |
|
groups=num_groups, |
|
bias=False, |
|
dilation=[1, dilation, dilation], |
|
) |
|
self.b_bn = norm_module( |
|
num_features=dim_inner, eps=self._eps, momentum=self._bn_mmt |
|
) |
|
self.b_relu = nn.ReLU(inplace=self._inplace_relu) |
|
|
|
|
|
        # 1x1x1 conv, BN (flagged for zero init as the final BN).
        self.c = nn.Conv3d(
|
dim_inner, |
|
dim_out, |
|
kernel_size=[1, 1, 1], |
|
stride=[1, 1, 1], |
|
padding=[0, 0, 0], |
|
bias=False, |
|
) |
|
self.c.final_conv = True |
|
|
|
self.c_bn = norm_module( |
|
num_features=dim_out, eps=self._eps, momentum=self._bn_mmt |
|
) |
|
self.c_bn.transform_final_bn = True |
|
|
|
def forward(self, x): |
|
|
|
|
|
x = self.a(x) |
|
x = self.a_bn(x) |
|
x = self.a_relu(x) |
|
|
|
|
|
x = self.b(x) |
|
x = self.b_bn(x) |
|
x = self.b_relu(x) |
|
|
|
|
|
x = self.c(x) |
|
x = self.c_bn(x) |
|
return x |
|
|
|
|
|
class ResBlock(nn.Module): |
|
""" |
|
Residual block. |
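
    A minimal construction sketch (values are illustrative):

        >>> block = ResBlock(
        ...     dim_in=64, dim_out=256, temp_kernel_size=3, stride=1,
        ...     trans_func=BottleneckTransform, dim_inner=64,
        ... )
        >>> y = block(torch.rand(2, 64, 8, 56, 56))  # => (2, 256, 8, 56, 56)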
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim_in, |
|
dim_out, |
|
temp_kernel_size, |
|
stride, |
|
trans_func, |
|
dim_inner, |
|
num_groups=1, |
|
stride_1x1=False, |
|
inplace_relu=True, |
|
eps=1e-5, |
|
bn_mmt=0.1, |
|
dilation=1, |
|
norm_module=nn.BatchNorm3d, |
|
block_idx=0, |
|
drop_connect_rate=0.0, |
|
): |
|
""" |
|
        ResBlock class constructs residual blocks. More details can be found in:
|
Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. |
|
"Deep residual learning for image recognition." |
|
https://arxiv.org/abs/1512.03385 |
|
Args: |
|
dim_in (int): the channel dimensions of the input. |
|
dim_out (int): the channel dimension of the output. |
|
temp_kernel_size (int): the temporal kernel sizes of the middle |
|
convolution in the bottleneck. |
|
stride (int): the stride of the bottleneck. |
|
trans_func (string): transform function to be used to construct the |
|
bottleneck. |
|
dim_inner (int): the inner dimension of the block. |
|
num_groups (int): number of groups for the convolution. num_groups=1 |
|
is for standard ResNet like networks, and num_groups>1 is for |
|
ResNeXt like networks. |
|
stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise |
|
apply stride to the 3x3 conv. |
|
inplace_relu (bool): calculate the relu on the original input |
|
without allocating new memory. |
|
eps (float): epsilon for batch norm. |
|
            bn_mmt (float): momentum for batch norm. Note that BN momentum in
|
PyTorch = 1 - BN momentum in Caffe2. |
|
dilation (int): size of dilation. |
|
norm_module (nn.Module): nn.Module for the normalization layer. The |
|
default is nn.BatchNorm3d. |
|
drop_connect_rate (float): basic rate at which blocks are dropped, |
|
linearly increases from input to output blocks. |
|
""" |
|
super(ResBlock, self).__init__() |
|
self._inplace_relu = inplace_relu |
|
self._eps = eps |
|
self._bn_mmt = bn_mmt |
|
self._drop_connect_rate = drop_connect_rate |
|
self._construct( |
|
dim_in, |
|
dim_out, |
|
temp_kernel_size, |
|
stride, |
|
trans_func, |
|
dim_inner, |
|
num_groups, |
|
stride_1x1, |
|
inplace_relu, |
|
dilation, |
|
norm_module, |
|
block_idx, |
|
) |
|
|
|
def _construct( |
|
self, |
|
dim_in, |
|
dim_out, |
|
temp_kernel_size, |
|
stride, |
|
trans_func, |
|
dim_inner, |
|
num_groups, |
|
stride_1x1, |
|
inplace_relu, |
|
dilation, |
|
norm_module, |
|
block_idx, |
|
): |
|
|
|
if (dim_in != dim_out) or (stride != 1): |
|
self.branch1 = nn.Conv3d( |
|
dim_in, |
|
dim_out, |
|
kernel_size=1, |
|
stride=[1, stride, stride], |
|
padding=0, |
|
bias=False, |
|
dilation=1, |
|
) |
|
self.branch1_bn = norm_module( |
|
num_features=dim_out, eps=self._eps, momentum=self._bn_mmt |
|
) |
|
self.branch2 = trans_func( |
|
dim_in, |
|
dim_out, |
|
temp_kernel_size, |
|
stride, |
|
dim_inner, |
|
num_groups, |
|
stride_1x1=stride_1x1, |
|
inplace_relu=inplace_relu, |
|
dilation=dilation, |
|
norm_module=norm_module, |
|
block_idx=block_idx, |
|
) |
|
        self.relu = nn.ReLU(inplace=self._inplace_relu)
|
|
|
def forward(self, x): |
|
f_x = self.branch2(x) |
|
if self.training and self._drop_connect_rate > 0.0: |
|
            # Apply stochastic depth to the residual branch (training only).
            f_x = drop_path(f_x, self._drop_connect_rate, self.training)
|
if hasattr(self, "branch1"): |
|
x = self.branch1_bn(self.branch1(x)) + f_x |
|
else: |
|
x = x + f_x |
|
x = self.relu(x) |
|
return x |
|
|
|
|
|
class ResStage(nn.Module): |
|
""" |
|
Stage of 3D ResNet. It expects to have one or more tensors as input for |
|
single pathway (C2D, I3D, Slow), and multi-pathway (SlowFast) cases. |
|
More details can be found here: |
|
|
|
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He. |
|
"SlowFast networks for video recognition." |
|
https://arxiv.org/pdf/1812.03982.pdf |
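
    A minimal single-pathway construction sketch (values are illustrative):

        >>> stage = ResStage(
        ...     dim_in=[64], dim_out=[256], stride=[1], temp_kernel_sizes=[[3]],
        ...     num_blocks=[3], dim_inner=[64], num_groups=[1],
        ...     num_block_temp_kernel=[3], nonlocal_inds=[[]], nonlocal_group=[1],
        ...     nonlocal_pool=[[1, 2, 2]], dilation=[1],
        ... )
        >>> (out,) = stage([torch.rand(2, 64, 8, 56, 56)])  # => (2, 256, 8, 56, 56)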
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim_in, |
|
dim_out, |
|
stride, |
|
temp_kernel_sizes, |
|
num_blocks, |
|
dim_inner, |
|
num_groups, |
|
num_block_temp_kernel, |
|
nonlocal_inds, |
|
nonlocal_group, |
|
nonlocal_pool, |
|
dilation, |
|
instantiation="softmax", |
|
trans_func_name="bottleneck_transform", |
|
stride_1x1=False, |
|
inplace_relu=True, |
|
norm_module=nn.BatchNorm3d, |
|
drop_connect_rate=0.0, |
|
): |
|
""" |
|
The `__init__` method of any subclass should also contain these arguments. |
|
ResStage builds p streams, where p can be greater or equal to one. |
|
Args: |
|
            dim_in (list): list of the p channel dimensions of the input.
                Different channel dimensions control the input dimension of
                different pathways.
            dim_out (list): list of the p channel dimensions of the output.
                Different channel dimensions control the output dimension of
                different pathways.
|
temp_kernel_sizes (list): list of the p temporal kernel sizes of the |
|
convolution in the bottleneck. Different temp_kernel_sizes |
|
control different pathway. |
|
stride (list): list of the p strides of the bottleneck. Different |
|
stride control different pathway. |
|
num_blocks (list): list of p numbers of blocks for each of the |
|
pathway. |
|
dim_inner (list): list of the p inner channel dimensions of the |
|
input. Different channel dimensions control the input dimension |
|
of different pathways. |
|
num_groups (list): list of number of p groups for the convolution. |
|
num_groups=1 is for standard ResNet like networks, and |
|
num_groups>1 is for ResNeXt like networks. |
|
            num_block_temp_kernel (list): extend temp_kernel_sizes to
                num_block_temp_kernel blocks, then use a temporal kernel size
                of 1 for the remaining blocks.
|
            nonlocal_inds (list): if the list is empty, no nonlocal layer is
                added. Otherwise, a nonlocal layer is added after each block
                whose index is in the list.
|
dilation (list): size of dilation for each pathway. |
|
nonlocal_group (list): list of number of p nonlocal groups. Each |
|
number controls how to fold temporal dimension to batch |
|
dimension before applying nonlocal transformation. |
|
https://github.com/facebookresearch/video-nonlocal-net. |
|
instantiation (string): different instantiation for nonlocal layer. |
|
                Supports two different instantiation methods:
                    "dot_product": normalize the correlation matrix by the
                        number of positions.
                    "softmax": normalize the correlation matrix with softmax.
|
            trans_func_name (string): name of the transformation function applied
                to the network.
|
norm_module (nn.Module): nn.Module for the normalization layer. The |
|
default is nn.BatchNorm3d. |
|
drop_connect_rate (float): basic rate at which blocks are dropped, |
|
linearly increases from input to output blocks. |
|
""" |
|
super(ResStage, self).__init__() |
|
assert all( |
|
( |
|
num_block_temp_kernel[i] <= num_blocks[i] |
|
for i in range(len(temp_kernel_sizes)) |
|
) |
|
) |
|
self.num_blocks = num_blocks |
|
self.nonlocal_group = nonlocal_group |
|
self._drop_connect_rate = drop_connect_rate |
|
self.temp_kernel_sizes = [ |
|
(temp_kernel_sizes[i] * num_blocks[i])[: num_block_temp_kernel[i]] |
|
+ [1] * (num_blocks[i] - num_block_temp_kernel[i]) |
|
for i in range(len(temp_kernel_sizes)) |
|
] |
|
assert ( |
|
len( |
|
{ |
|
len(dim_in), |
|
len(dim_out), |
|
len(temp_kernel_sizes), |
|
len(stride), |
|
len(num_blocks), |
|
len(dim_inner), |
|
len(num_groups), |
|
len(num_block_temp_kernel), |
|
len(nonlocal_inds), |
|
len(nonlocal_group), |
|
} |
|
) |
|
== 1 |
|
) |
|
self.num_pathways = len(self.num_blocks) |
|
self._construct( |
|
dim_in, |
|
dim_out, |
|
stride, |
|
dim_inner, |
|
num_groups, |
|
trans_func_name, |
|
stride_1x1, |
|
inplace_relu, |
|
nonlocal_inds, |
|
nonlocal_pool, |
|
instantiation, |
|
dilation, |
|
norm_module, |
|
) |
|
|
|
def _construct( |
|
self, |
|
dim_in, |
|
dim_out, |
|
stride, |
|
dim_inner, |
|
num_groups, |
|
trans_func_name, |
|
stride_1x1, |
|
inplace_relu, |
|
nonlocal_inds, |
|
nonlocal_pool, |
|
instantiation, |
|
dilation, |
|
norm_module, |
|
): |
|
for pathway in range(self.num_pathways): |
|
for i in range(self.num_blocks[pathway]): |
|
|
|
trans_func = get_trans_func(trans_func_name) |
|
|
|
res_block = ResBlock( |
|
dim_in[pathway] if i == 0 else dim_out[pathway], |
|
dim_out[pathway], |
|
self.temp_kernel_sizes[pathway][i], |
|
stride[pathway] if i == 0 else 1, |
|
trans_func, |
|
dim_inner[pathway], |
|
num_groups[pathway], |
|
stride_1x1=stride_1x1, |
|
inplace_relu=inplace_relu, |
|
dilation=dilation[pathway], |
|
norm_module=norm_module, |
|
block_idx=i, |
|
drop_connect_rate=self._drop_connect_rate, |
|
) |
|
self.add_module("pathway{}_res{}".format( |
|
pathway, i), res_block) |
|
if i in nonlocal_inds[pathway]: |
|
nln = Nonlocal( |
|
dim_out[pathway], |
|
dim_out[pathway] // 2, |
|
nonlocal_pool[pathway], |
|
instantiation=instantiation, |
|
norm_module=norm_module, |
|
) |
|
self.add_module( |
|
"pathway{}_nonlocal{}".format(pathway, i), nln) |
|
|
|
def forward(self, inputs): |
|
output = [] |
|
for pathway in range(self.num_pathways): |
|
x = inputs[pathway] |
|
for i in range(self.num_blocks[pathway]): |
|
m = getattr(self, "pathway{}_res{}".format(pathway, i)) |
|
x = m(x) |
|
if hasattr(self, "pathway{}_nonlocal{}".format(pathway, i)): |
|
nln = getattr( |
|
self, "pathway{}_nonlocal{}".format(pathway, i)) |
|
b, c, t, h, w = x.shape |
|
                    if self.nonlocal_group[pathway] > 1:
                        # Fold the temporal dimension into the batch dimension.
|
|
|
x = x.permute(0, 2, 1, 3, 4) |
|
x = x.reshape( |
|
b * self.nonlocal_group[pathway], |
|
t // self.nonlocal_group[pathway], |
|
c, |
|
h, |
|
w, |
|
) |
|
x = x.permute(0, 2, 1, 3, 4) |
|
x = nln(x) |
|
                    if self.nonlocal_group[pathway] > 1:
                        # Fold the temporal dimension back out of the batch dimension.
|
|
|
x = x.permute(0, 2, 1, 3, 4) |
|
x = x.reshape(b, t, c, h, w) |
|
x = x.permute(0, 2, 1, 3, 4) |
|
output.append(x) |
|
|
|
return output |
|
|