#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Video models."""

import torch
import torch.nn as nn

from pytorchvideo.layers.swish import Swish


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """
    Stochastic Depth per sample.

    Zeroes the whole tensor for a random subset of samples along the first
    dimension and rescales the surviving samples by 1 / keep_prob.

    Args:
        x (torch.Tensor): input tensor. The first dimension is treated as the
            sample dimension; any number of trailing dimensions is allowed.
        drop_prob (float): probability of dropping each sample.
        training (bool): the function is the identity unless this is True.
    Returns:
        torch.Tensor: a tensor with the same shape as ``x``.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One mask value per sample, broadcast over all remaining dimensions
    # (works with tensors of any rank, not just 2D ConvNets).
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # keep_prob + U[0, 1) is >= 1 with probability keep_prob, so flooring it
    # yields a 0/1 mask that keeps each sample with probability keep_prob.
    mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize
    output = x.div(keep_prob) * mask
    return output


class Nonlocal(nn.Module):
    """
    Builds Non-local Neural Networks as a generic family of building
    blocks for capturing long-range dependencies. Non-local Network
    computes the response at a position as a weighted sum of the
    features at all positions. This building block can be plugged into
    many computer vision architectures.
    More details in the paper: https://arxiv.org/pdf/1711.07971.pdf
    """

    def __init__(
        self,
        dim,
        dim_inner,
        pool_size=None,
        instantiation="softmax",
        zero_init_final_conv=False,
        zero_init_final_norm=True,
        norm_eps=1e-5,
        norm_momentum=0.1,
        norm_module=nn.BatchNorm3d,
    ):
        """
        Args:
            dim (int): number of channels of the input (and of the output).
            dim_inner (int): number of channels inside of the Non-local block.
            pool_size (list): the kernel size of spatial temporal pooling,
                temporal pool kernel size, spatial pool kernel size, spatial
                pool kernel size in order. By default pool_size is None,
                then there would be no pooling used. Pooling is also skipped
                when every entry is <= 1.
            instantiation (string): supports two different instantiation method:
                "dot_product": dividing the correlation matrix by the number
                    of (possibly pooled) positions it is computed over.
                "softmax": normalizing correlation matrix with Softmax (after
                    scaling by dim_inner ** -0.5).
            zero_init_final_conv (bool): stored on the final convolution as
                its ``zero_init`` attribute. Presumably consumed by external
                weight-init code to zero-initialize it - confirm.
            zero_init_final_norm (bool): stored on the final norm layer as its
                ``transform_final_bn`` attribute. Presumably consumed by
                external weight-init code to zero-initialize it - confirm.
            norm_eps (float): epsilon of the final normalization layer.
            norm_momentum (float): momentum of the final normalization layer.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
        """
        super(Nonlocal, self).__init__()
        self.dim = dim
        self.dim_inner = dim_inner
        self.pool_size = pool_size
        self.instantiation = instantiation
        self.use_pool = (
            False if pool_size is None else any(size > 1 for size in pool_size)
        )
        self.norm_eps = norm_eps
        self.norm_momentum = norm_momentum
        self._construct_nonlocal(
            zero_init_final_conv, zero_init_final_norm, norm_module
        )

    def _construct_nonlocal(
        self, zero_init_final_conv, zero_init_final_norm, norm_module
    ):
        # Three convolution heads: theta, phi, and g.
        self.conv_theta = nn.Conv3d(
            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
        )
        self.conv_phi = nn.Conv3d(
            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
        )
        self.conv_g = nn.Conv3d(
            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
        )

        # Final convolution output.
        self.conv_out = nn.Conv3d(
            self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0
        )
        # Zero initializing the final convolution output.
        self.conv_out.zero_init = zero_init_final_conv

        # TODO: change the name to `norm`
        self.bn = norm_module(
            num_features=self.dim,
            eps=self.norm_eps,
            momentum=self.norm_momentum,
        )
        # Zero initializing the final bn.
        self.bn.transform_final_bn = zero_init_final_norm

        # Optional to add the spatial-temporal pooling.
        if self.use_pool:
            self.pool = nn.MaxPool3d(
                kernel_size=self.pool_size,
                stride=self.pool_size,
                padding=[0, 0, 0],
            )

    def forward(self, x):
        x_identity = x
        N, _, T, H, W = x.size()

        # theta is computed on the un-pooled input, so the output keeps the
        # full T x H x W resolution.
        theta = self.conv_theta(x)

        # Perform temporal-spatial pooling to reduce the computation.
        if self.use_pool:
            x = self.pool(x)

        phi = self.conv_phi(x)
        g = self.conv_g(x)

        theta = theta.view(N, self.dim_inner, -1)
        phi = phi.view(N, self.dim_inner, -1)
        g = g.view(N, self.dim_inner, -1)

        # (N, C, TxHxW) * (N, C, P) => (N, TxHxW, P), where P is the number
        # of (possibly pooled) positions of phi.
        theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi))
        # For original Non-local paper, there are two main ways to normalize
        # the affinity tensor:
        #   1) Softmax normalization (norm on exp).
        #   2) dot_product normalization.
        if self.instantiation == "softmax":
            # Normalizing the affinity tensor theta_phi before softmax.
            theta_phi = theta_phi * (self.dim_inner**-0.5)
            theta_phi = nn.functional.softmax(theta_phi, dim=2)
        elif self.instantiation == "dot_product":
            spatial_temporal_dim = theta_phi.shape[2]
            theta_phi = theta_phi / spatial_temporal_dim
        else:
            raise NotImplementedError("Unknown norm type {}".format(self.instantiation))

        # (N, TxHxW, P) * (N, C, P) => (N, C, TxHxW).
        theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g))

        # (N, C, TxHxW) => (N, C, T, H, W).
        theta_phi_g = theta_phi_g.view(N, self.dim_inner, T, H, W)

        p = self.conv_out(theta_phi_g)
        p = self.bn(p)
        return x_identity + p


class SE(nn.Module):
    """
    Squeeze-and-Excitation (SE) block: AvgPool, FC, ReLU (default) or Swish,
    FC, Sigmoid. The resulting per-channel gates rescale the input.
    """

    def _round_width(self, width, multiplier, min_width=8, divisor=8):
        """
        Round width of filters based on width multiplier
        Args:
            width (int): the channel dimensions of the input.
            multiplier (float): the multiplication factor. If falsy, ``width``
                is returned unchanged.
            min_width (int): the minimum width after multiplication.
            divisor (int): the new width should be divisible by divisor.
        Returns:
            int: the rounded width.
        """
        if not multiplier:
            return width

        width *= multiplier
        min_width = min_width or divisor
        width_out = max(min_width, int(width + divisor / 2) // divisor * divisor)
        # Avoid rounding down by more than 10%.
        if width_out < 0.9 * width:
            width_out += divisor
        return int(width_out)

    def __init__(self, dim_in, ratio, relu_act=True):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            ratio (float): the channel reduction ratio for squeeze; the
                squeezed width is dim_in * ratio rounded by ``_round_width``.
            relu_act (bool): whether to use ReLU activation instead
                of Swish (default: ReLU).
        """
        super(SE, self).__init__()
        # NOTE: forward() runs the children in registration order, so the
        # order of the assignments below defines the computation.
        self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        dim_fc = self._round_width(dim_in, ratio)
        self.fc1 = nn.Conv3d(dim_in, dim_fc, 1, bias=True)
        self.fc1_act = nn.ReLU() if relu_act else Swish()
        self.fc2 = nn.Conv3d(dim_fc, dim_in, 1, bias=True)

        self.fc2_sig = nn.Sigmoid()

    def forward(self, x):
        x_in = x
        for module in self.children():
            x = module(x)
        return x_in * x


def get_trans_func(name):
    """
    Retrieves the transformation module by name.

    Args:
        name (str): one of "bottleneck_transform", "basic_transform",
            "x3d_transform".
    Returns:
        type: the transformation class.
    """
    trans_funcs = {
        "bottleneck_transform": BottleneckTransform,
        "basic_transform": BasicTransform,
        "x3d_transform": X3DTransform,
    }
    assert name in trans_funcs, "Transformation function '{}' not supported".format(
        name
    )
    return trans_funcs[name]


class BasicTransform(nn.Module):
    """
    Basic transformation: Tx3x3, 1x3x3, where T is the size of temporal kernel.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        dim_inner=None,
        num_groups=1,
        stride_1x1=None,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        dilation=1,
        norm_module=nn.BatchNorm3d,
        block_idx=0,
    ):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            dim_out (int): the channel dimension of the output.
            temp_kernel_size (int): the temporal kernel sizes of the first
                convolution in the basic block.
            stride (int): the spatial stride of the first convolution.
            dim_inner (None): the inner dimension would not be used in
                BasicTransform.
            num_groups (int): number of groups for the convolution. Number of
                group is always 1 for BasicTransform.
            stride_1x1 (None): stride_1x1 will not be used in BasicTransform.
            inplace_relu (bool): if True, calculate the relu on the original
                input without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Noted that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            dilation (int): spatial dilation of the second (1x3x3) convolution.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            block_idx (int): index of the block in its stage; unused here,
                accepted for interface compatibility with the other transforms.
        """
        super(BasicTransform, self).__init__()
        self.temp_kernel_size = temp_kernel_size
        self._inplace_relu = inplace_relu
        self._eps = eps
        self._bn_mmt = bn_mmt
        self._construct(dim_in, dim_out, stride, dilation, norm_module)

    def _construct(self, dim_in, dim_out, stride, dilation, norm_module):
        # Tx3x3, BN, ReLU.
        self.a = nn.Conv3d(
            dim_in,
            dim_out,
            kernel_size=[self.temp_kernel_size, 3, 3],
            stride=[1, stride, stride],
            padding=[int(self.temp_kernel_size // 2), 1, 1],
            bias=False,
        )
        self.a_bn = norm_module(
            num_features=dim_out, eps=self._eps, momentum=self._bn_mmt
        )
        self.a_relu = nn.ReLU(inplace=self._inplace_relu)
        # 1x3x3, BN.
        self.b = nn.Conv3d(
            dim_out,
            dim_out,
            kernel_size=[1, 3, 3],
            stride=[1, 1, 1],
            padding=[0, dilation, dilation],
            dilation=[1, dilation, dilation],
            bias=False,
        )

        self.b.final_conv = True

        self.b_bn = norm_module(
            num_features=dim_out, eps=self._eps, momentum=self._bn_mmt
        )

        self.b_bn.transform_final_bn = True

    def forward(self, x):
        x = self.a(x)
        x = self.a_bn(x)
        x = self.a_relu(x)

        x = self.b(x)
        x = self.b_bn(x)
        return x


class X3DTransform(nn.Module):
    """
    X3D transformation: 1x1x1, Tx3x3 (channelwise, num_groups=dim_in), 1x1x1,
        augmented with (optional) SE (squeeze-excitation) on the Tx3x3 output.
        T is the temporal kernel size.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        dim_inner,
        num_groups,
        stride_1x1=False,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        dilation=1,
        norm_module=nn.BatchNorm3d,
        se_ratio=0.0625,
        swish_inner=True,
        block_idx=0,
    ):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            dim_out (int): the channel dimension of the output.
            temp_kernel_size (int): the temporal kernel sizes of the middle
                convolution in the bottleneck.
            stride (int): the stride of the bottleneck.
            dim_inner (int): the inner dimension of the block.
            num_groups (int): number of groups for the convolution.
                num_groups=1 is for standard ResNet like networks,
                and num_groups>1 is for ResNeXt like networks.
            stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise
                apply stride to the 3x3 conv.
            inplace_relu (bool): if True, calculate the relu on the original
                input without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Noted that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            dilation (int): size of dilation.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            se_ratio (float): if > 0, apply SE to the Tx3x3 conv, with the SE
                channel dimensionality being se_ratio times the Tx3x3 conv dim
                (rounded). SE is only added on blocks with an even block_idx.
            swish_inner (bool): if True, apply swish to the Tx3x3 conv, otherwise
                apply ReLU to the Tx3x3 conv.
            block_idx (int): index of the block in its stage; decides whether
                SE is added.
        """
        super(X3DTransform, self).__init__()
        self.temp_kernel_size = temp_kernel_size
        self._inplace_relu = inplace_relu
        self._eps = eps
        self._bn_mmt = bn_mmt
        self._se_ratio = se_ratio
        self._swish_inner = swish_inner
        self._stride_1x1 = stride_1x1
        self._block_idx = block_idx
        self._construct(
            dim_in,
            dim_out,
            stride,
            dim_inner,
            num_groups,
            dilation,
            norm_module,
        )

    def _construct(
        self,
        dim_in,
        dim_out,
        stride,
        dim_inner,
        num_groups,
        dilation,
        norm_module,
    ):
        # NOTE: forward() runs the children in registration order, so the
        # order of the assignments below defines the computation.
        (str1x1, str3x3) = (stride, 1) if self._stride_1x1 else (1, stride)

        # 1x1x1, BN, ReLU.
        self.a = nn.Conv3d(
            dim_in,
            dim_inner,
            kernel_size=[1, 1, 1],
            stride=[1, str1x1, str1x1],
            padding=[0, 0, 0],
            bias=False,
        )
        self.a_bn = norm_module(
            num_features=dim_inner, eps=self._eps, momentum=self._bn_mmt
        )
        self.a_relu = nn.ReLU(inplace=self._inplace_relu)

        # Tx3x3, BN, ReLU.
        self.b = nn.Conv3d(
            dim_inner,
            dim_inner,
            [self.temp_kernel_size, 3, 3],
            stride=[1, str3x3, str3x3],
            padding=[int(self.temp_kernel_size // 2), dilation, dilation],
            groups=num_groups,
            bias=False,
            dilation=[1, dilation, dilation],
        )
        self.b_bn = norm_module(
            num_features=dim_inner, eps=self._eps, momentum=self._bn_mmt
        )

        # Apply SE attention on every other block (block_idx 0, 2, 4, ...).
        use_se = (self._block_idx + 1) % 2 == 1
        if self._se_ratio > 0.0 and use_se:
            self.se = SE(dim_inner, self._se_ratio)

        if self._swish_inner:
            self.b_relu = Swish()
        else:
            self.b_relu = nn.ReLU(inplace=self._inplace_relu)

        # 1x1x1, BN.
        self.c = nn.Conv3d(
            dim_inner,
            dim_out,
            kernel_size=[1, 1, 1],
            stride=[1, 1, 1],
            padding=[0, 0, 0],
            bias=False,
        )
        self.c_bn = norm_module(
            num_features=dim_out, eps=self._eps, momentum=self._bn_mmt
        )
        self.c_bn.transform_final_bn = True

    def forward(self, x):
        for block in self.children():
            x = block(x)
        return x


class BottleneckTransform(nn.Module):
    """
    Bottleneck transformation: Tx1x1, 1x3x3, 1x1x1, where T is the size of
        temporal kernel.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        dim_inner,
        num_groups,
        stride_1x1=False,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        dilation=1,
        norm_module=nn.BatchNorm3d,
        block_idx=0,
    ):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            dim_out (int): the channel dimension of the output.
            temp_kernel_size (int): the temporal kernel sizes of the first
                convolution in the bottleneck.
            stride (int): the stride of the bottleneck.
            dim_inner (int): the inner dimension of the block.
            num_groups (int): number of groups for the convolution.
                num_groups=1 is for standard ResNet like networks,
                and num_groups>1 is for ResNeXt like networks.
            stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise
                apply stride to the 3x3 conv.
            inplace_relu (bool): if True, calculate the relu on the original
                input without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Noted that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            dilation (int): size of dilation.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            block_idx (int): index of the block in its stage; unused here,
                accepted for interface compatibility with the other transforms.
        """
        super(BottleneckTransform, self).__init__()
        self.temp_kernel_size = temp_kernel_size
        self._inplace_relu = inplace_relu
        self._eps = eps
        self._bn_mmt = bn_mmt
        self._stride_1x1 = stride_1x1
        self._construct(
            dim_in,
            dim_out,
            stride,
            dim_inner,
            num_groups,
            dilation,
            norm_module,
        )

    def _construct(
        self,
        dim_in,
        dim_out,
        stride,
        dim_inner,
        num_groups,
        dilation,
        norm_module,
    ):
        (str1x1, str3x3) = (stride, 1) if self._stride_1x1 else (1, stride)

        # Tx1x1, BN, ReLU.
        self.a = nn.Conv3d(
            dim_in,
            dim_inner,
            kernel_size=[self.temp_kernel_size, 1, 1],
            stride=[1, str1x1, str1x1],
            padding=[int(self.temp_kernel_size // 2), 0, 0],
            bias=False,
        )
        self.a_bn = norm_module(
            num_features=dim_inner, eps=self._eps, momentum=self._bn_mmt
        )
        self.a_relu = nn.ReLU(inplace=self._inplace_relu)

        # 1x3x3, BN, ReLU.
        self.b = nn.Conv3d(
            dim_inner,
            dim_inner,
            [1, 3, 3],
            stride=[1, str3x3, str3x3],
            padding=[0, dilation, dilation],
            groups=num_groups,
            bias=False,
            dilation=[1, dilation, dilation],
        )
        self.b_bn = norm_module(
            num_features=dim_inner, eps=self._eps, momentum=self._bn_mmt
        )
        self.b_relu = nn.ReLU(inplace=self._inplace_relu)

        # 1x1x1, BN.
        self.c = nn.Conv3d(
            dim_inner,
            dim_out,
            kernel_size=[1, 1, 1],
            stride=[1, 1, 1],
            padding=[0, 0, 0],
            bias=False,
        )
        self.c.final_conv = True

        self.c_bn = norm_module(
            num_features=dim_out, eps=self._eps, momentum=self._bn_mmt
        )
        self.c_bn.transform_final_bn = True

    def forward(self, x):
        # Explicitly forward every layer.
        # Branch2a.
        x = self.a(x)
        x = self.a_bn(x)
        x = self.a_relu(x)

        # Branch2b.
        x = self.b(x)
        x = self.b_bn(x)
        x = self.b_relu(x)

        # Branch2c
        x = self.c(x)
        x = self.c_bn(x)
        return x


class ResBlock(nn.Module):
    """
    Residual block.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        trans_func,
        dim_inner,
        num_groups=1,
        stride_1x1=False,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        dilation=1,
        norm_module=nn.BatchNorm3d,
        block_idx=0,
        drop_connect_rate=0.0,
    ):
        """
        ResBlock class constructs residual blocks. More details can be found in:
            Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun.
            "Deep residual learning for image recognition."
            https://arxiv.org/abs/1512.03385
        Args:
            dim_in (int): the channel dimensions of the input.
            dim_out (int): the channel dimension of the output.
            temp_kernel_size (int): the temporal kernel sizes of the middle
                convolution in the bottleneck.
            stride (int): the stride of the bottleneck.
            trans_func (type): transformation class used to construct the
                residual branch (e.g. the result of ``get_trans_func``).
            dim_inner (int): the inner dimension of the block.
            num_groups (int): number of groups for the convolution.
                num_groups=1 is for standard ResNet like networks,
                and num_groups>1 is for ResNeXt like networks.
            stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise
                apply stride to the 3x3 conv.
            inplace_relu (bool): calculate the relu on the original input
                without allocating new memory.
            eps (float): epsilon for batch norm (shortcut and residual branch).
            bn_mmt (float): momentum for batch norm (shortcut and residual
                branch). Noted that BN momentum in PyTorch = 1 - BN momentum
                in Caffe2.
            dilation (int): size of dilation.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            block_idx (int): index of the block in its stage, forwarded to
                ``trans_func``.
            drop_connect_rate (float): probability of dropping the residual
                branch per sample (stochastic depth) during training.
        """
        super(ResBlock, self).__init__()
        self._inplace_relu = inplace_relu
        self._eps = eps
        self._bn_mmt = bn_mmt
        self._drop_connect_rate = drop_connect_rate
        self._construct(
            dim_in,
            dim_out,
            temp_kernel_size,
            stride,
            trans_func,
            dim_inner,
            num_groups,
            stride_1x1,
            inplace_relu,
            dilation,
            norm_module,
            block_idx,
        )

    def _construct(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        trans_func,
        dim_inner,
        num_groups,
        stride_1x1,
        inplace_relu,
        dilation,
        norm_module,
        block_idx,
    ):
        # Use skip connection with projection if dim or res change.
        if (dim_in != dim_out) or (stride != 1):
            self.branch1 = nn.Conv3d(
                dim_in,
                dim_out,
                kernel_size=1,
                stride=[1, stride, stride],
                padding=0,
                bias=False,
                dilation=1,
            )
            self.branch1_bn = norm_module(
                num_features=dim_out, eps=self._eps, momentum=self._bn_mmt
            )
        # Forward eps / bn_mmt so the residual branch's norm layers honour the
        # same settings as the shortcut's norm layer.
        self.branch2 = trans_func(
            dim_in,
            dim_out,
            temp_kernel_size,
            stride,
            dim_inner,
            num_groups,
            stride_1x1=stride_1x1,
            inplace_relu=inplace_relu,
            eps=self._eps,
            bn_mmt=self._bn_mmt,
            dilation=dilation,
            norm_module=norm_module,
            block_idx=block_idx,
        )
        self.relu = nn.ReLU(inplace=self._inplace_relu)

    def forward(self, x):
        f_x = self.branch2(x)
        if self.training and self._drop_connect_rate > 0.0:
            # drop_path is a no-op unless `training` is passed explicitly.
            f_x = drop_path(f_x, self._drop_connect_rate, training=self.training)
        if hasattr(self, "branch1"):
            x = self.branch1_bn(self.branch1(x)) + f_x
        else:
            x = x + f_x
        x = self.relu(x)
        return x


class ResStage(nn.Module):
    """
    Stage of 3D ResNet. It expects to have one or more tensors as input for
        single pathway (C2D, I3D, Slow), and multi-pathway (SlowFast) cases.
        More details can be found here:

        Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
        "SlowFast networks for video recognition."
        https://arxiv.org/pdf/1812.03982.pdf
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        stride,
        temp_kernel_sizes,
        num_blocks,
        dim_inner,
        num_groups,
        num_block_temp_kernel,
        nonlocal_inds,
        nonlocal_group,
        nonlocal_pool,
        dilation,
        instantiation="softmax",
        trans_func_name="bottleneck_transform",
        stride_1x1=False,
        inplace_relu=True,
        norm_module=nn.BatchNorm3d,
        drop_connect_rate=0.0,
    ):
        """
        The `__init__` method of any subclass should also contain these arguments.
        ResStage builds p streams, where p can be greater or equal to one.
        Args:
            dim_in (list): list of p the channel dimensions of the input.
                Different channel dimensions control the input dimension of
                different pathways.
            dim_out (list): list of p the channel dimensions of the output.
                Different channel dimensions control the output dimension of
                different pathways.
            temp_kernel_sizes (list): list of p lists of temporal kernel sizes
                of the convolution in the bottleneck. Each inner list is
                repeated to cover the pathway's blocks. Different
                temp_kernel_sizes control different pathway.
            stride (list): list of the p strides of the bottleneck. Only the
                first block of each pathway is strided. Different stride
                control different pathway.
            num_blocks (list): list of p numbers of blocks for each of the
                pathway.
            dim_inner (list): list of the p inner channel dimensions of the
                blocks. Different channel dimensions control the inner
                dimension of different pathways.
            num_groups (list): list of number of p groups for the convolution.
                num_groups=1 is for standard ResNet like networks,
                and num_groups>1 is for ResNeXt like networks.
            num_block_temp_kernel (list): extend the temp_kernel_sizes to
                num_block_temp_kernel blocks, then fill temporal kernel size
                of 1 for the rest of the layers.
            nonlocal_inds (list): list of p collections of block indices. If a
                collection is empty, no nonlocal layer will be added to that
                pathway; otherwise a nonlocal layer is added after each listed
                block.
            nonlocal_group (list): list of number of p nonlocal groups. Each
                number controls how to fold temporal dimension to batch
                dimension before applying nonlocal transformation.
                https://github.com/facebookresearch/video-nonlocal-net.
            nonlocal_pool (list): list of p pooling kernel sizes passed to the
                nonlocal layers of each pathway (see ``Nonlocal.pool_size``).
            dilation (list): size of dilation for each pathway.
            instantiation (string): different instantiation for nonlocal layer.
                Supports two different instantiation method:
                    "dot_product": dividing the correlation matrix by the
                        number of positions.
                    "softmax": normalizing correlation matrix with Softmax.
            trans_func_name (string): name of the transformation function
                applied on the network (see ``get_trans_func``).
            stride_1x1 (bool): if True, apply stride to 1x1 conv, otherwise
                apply stride to the 3x3 conv.
            inplace_relu (bool): calculate the relu on the original input
                without allocating new memory.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
            drop_connect_rate (float): stochastic depth drop rate, passed
                unchanged to every block in this stage.
        """
        super(ResStage, self).__init__()
        assert all(
            num_block_temp_kernel[i] <= num_blocks[i]
            for i in range(len(temp_kernel_sizes))
        )
        self.num_blocks = num_blocks
        self.nonlocal_group = nonlocal_group
        self._drop_connect_rate = drop_connect_rate
        # Per pathway: cycle the given kernel sizes over the first
        # num_block_temp_kernel blocks and use 1 for the remaining blocks.
        self.temp_kernel_sizes = [
            (temp_kernel_sizes[i] * num_blocks[i])[: num_block_temp_kernel[i]]
            + [1] * (num_blocks[i] - num_block_temp_kernel[i])
            for i in range(len(temp_kernel_sizes))
        ]
        # All per-pathway arguments must describe the same number of pathways.
        assert (
            len(
                {
                    len(dim_in),
                    len(dim_out),
                    len(temp_kernel_sizes),
                    len(stride),
                    len(num_blocks),
                    len(dim_inner),
                    len(num_groups),
                    len(num_block_temp_kernel),
                    len(nonlocal_inds),
                    len(nonlocal_group),
                }
            )
            == 1
        )
        self.num_pathways = len(self.num_blocks)
        self._construct(
            dim_in,
            dim_out,
            stride,
            dim_inner,
            num_groups,
            trans_func_name,
            stride_1x1,
            inplace_relu,
            nonlocal_inds,
            nonlocal_pool,
            instantiation,
            dilation,
            norm_module,
        )

    def _construct(
        self,
        dim_in,
        dim_out,
        stride,
        dim_inner,
        num_groups,
        trans_func_name,
        stride_1x1,
        inplace_relu,
        nonlocal_inds,
        nonlocal_pool,
        instantiation,
        dilation,
        norm_module,
    ):
        # Retrieve the transformation function (loop invariant).
        trans_func = get_trans_func(trans_func_name)
        for pathway in range(self.num_pathways):
            for i in range(self.num_blocks[pathway]):
                # Construct the block; only the first block changes the
                # channel count and applies the stride.
                res_block = ResBlock(
                    dim_in[pathway] if i == 0 else dim_out[pathway],
                    dim_out[pathway],
                    self.temp_kernel_sizes[pathway][i],
                    stride[pathway] if i == 0 else 1,
                    trans_func,
                    dim_inner[pathway],
                    num_groups[pathway],
                    stride_1x1=stride_1x1,
                    inplace_relu=inplace_relu,
                    dilation=dilation[pathway],
                    norm_module=norm_module,
                    block_idx=i,
                    drop_connect_rate=self._drop_connect_rate,
                )
                self.add_module("pathway{}_res{}".format(pathway, i), res_block)
                if i in nonlocal_inds[pathway]:
                    nln = Nonlocal(
                        dim_out[pathway],
                        dim_out[pathway] // 2,
                        nonlocal_pool[pathway],
                        instantiation=instantiation,
                        norm_module=norm_module,
                    )
                    self.add_module(
                        "pathway{}_nonlocal{}".format(pathway, i), nln
                    )

    def forward(self, inputs):
        output = []
        for pathway in range(self.num_pathways):
            x = inputs[pathway]
            for i in range(self.num_blocks[pathway]):
                m = getattr(self, "pathway{}_res{}".format(pathway, i))
                x = m(x)
                if hasattr(self, "pathway{}_nonlocal{}".format(pathway, i)):
                    nln = getattr(self, "pathway{}_nonlocal{}".format(pathway, i))
                    b, c, t, h, w = x.shape
                    if self.nonlocal_group[pathway] > 1:
                        # Fold temporal dimension into batch dimension.
                        x = x.permute(0, 2, 1, 3, 4)
                        x = x.reshape(
                            b * self.nonlocal_group[pathway],
                            t // self.nonlocal_group[pathway],
                            c,
                            h,
                            w,
                        )
                        x = x.permute(0, 2, 1, 3, 4)
                    x = nln(x)
                    if self.nonlocal_group[pathway] > 1:
                        # Fold back to temporal dimension.
                        x = x.permute(0, 2, 1, 3, 4)
                        x = x.reshape(b, t, c, h, w)
                        x = x.permute(0, 2, 1, 3, 4)
            output.append(x)
        return output