# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. """ResNe(X)t 3D stem helper.""" import torch.nn as nn def get_stem_func(name): """ Retrieves the stem module by name. """ trans_funcs = {"x3d_stem": X3DStem, "basic_stem": ResNetBasicStem} assert ( name in trans_funcs.keys() ), "Transformation function '{}' not supported".format(name) return trans_funcs[name] class VideoModelStem(nn.Module): """ Video 3D stem module. Provides stem operations of Conv, BN, ReLU, MaxPool on input data tensor for one or multiple pathways. """ def __init__( self, dim_in, dim_out, kernel, stride, padding, inplace_relu=True, eps=1e-5, bn_mmt=0.1, norm_module=nn.BatchNorm3d, stem_func_name="basic_stem", ): """ The `__init__` method of any subclass should also contain these arguments. List size of 1 for single pathway models (C2D, I3D, Slow and etc), list size of 2 for two pathway models (SlowFast). Args: dim_in (list): the list of channel dimensions of the inputs. dim_out (list): the output dimension of the convolution in the stem layer. kernel (list): the kernels' size of the convolutions in the stem layers. Temporal kernel size, height kernel size, width kernel size in order. stride (list): the stride sizes of the convolutions in the stem layer. Temporal kernel stride, height kernel size, width kernel size in order. padding (list): the paddings' sizes of the convolutions in the stem layer. Temporal padding size, height padding size, width padding size in order. inplace_relu (bool): calculate the relu on the original input without allocating new memory. eps (float): epsilon for batch norm. bn_mmt (float): momentum for batch norm. Noted that BN momentum in PyTorch = 1 - BN momentum in Caffe2. norm_module (nn.Module): nn.Module for the normalization layer. The default is nn.BatchNorm3d. stem_func_name (string): name of the the stem function applied on input to the network. """ super(VideoModelStem, self).__init__() assert ( len( { len(dim_in), len(dim_out), len(kernel), len(stride), len(padding), } ) == 1 ), "Input pathway dimensions are not consistent." self.num_pathways = len(dim_in) self.kernel = kernel self.stride = stride self.padding = padding self.inplace_relu = inplace_relu self.eps = eps self.bn_mmt = bn_mmt # Construct the stem layer. self._construct_stem(dim_in, dim_out, norm_module, stem_func_name) def _construct_stem(self, dim_in, dim_out, norm_module, stem_func_name): trans_func = get_stem_func(stem_func_name) for pathway in range(len(dim_in)): stem = trans_func( dim_in[pathway], dim_out[pathway], self.kernel[pathway], self.stride[pathway], self.padding[pathway], self.inplace_relu, self.eps, self.bn_mmt, norm_module, ) self.add_module("pathway{}_stem".format(pathway), stem) def forward(self, x): assert ( len(x) == self.num_pathways ), "Input tensor does not contain {} pathway".format(self.num_pathways) for pathway in range(len(x)): m = getattr(self, "pathway{}_stem".format(pathway)) x[pathway] = m(x[pathway]) return x class ResNetBasicStem(nn.Module): """ ResNe(X)t 3D stem module. Performs spatiotemporal Convolution, BN, and Relu following by a spatiotemporal pooling. """ def __init__( self, dim_in, dim_out, kernel, stride, padding, inplace_relu=True, eps=1e-5, bn_mmt=0.1, norm_module=nn.BatchNorm3d, ): """ The `__init__` method of any subclass should also contain these arguments. Args: dim_in (int): the channel dimension of the input. Normally 3 is used for rgb input, and 2 or 3 is used for optical flow input. dim_out (int): the output dimension of the convolution in the stem layer. kernel (list): the kernel size of the convolution in the stem layer. temporal kernel size, height kernel size, width kernel size in order. stride (list): the stride size of the convolution in the stem layer. temporal kernel stride, height kernel size, width kernel size in order. padding (int): the padding size of the convolution in the stem layer, temporal padding size, height padding size, width padding size in order. inplace_relu (bool): calculate the relu on the original input without allocating new memory. eps (float): epsilon for batch norm. bn_mmt (float): momentum for batch norm. Noted that BN momentum in PyTorch = 1 - BN momentum in Caffe2. norm_module (nn.Module): nn.Module for the normalization layer. The default is nn.BatchNorm3d. """ super(ResNetBasicStem, self).__init__() self.kernel = kernel self.stride = stride self.padding = padding self.inplace_relu = inplace_relu self.eps = eps self.bn_mmt = bn_mmt # Construct the stem layer. self._construct_stem(dim_in, dim_out, norm_module) def _construct_stem(self, dim_in, dim_out, norm_module): self.conv = nn.Conv3d( dim_in, dim_out, self.kernel, stride=self.stride, padding=self.padding, bias=False, ) self.bn = norm_module( num_features=dim_out, eps=self.eps, momentum=self.bn_mmt ) self.relu = nn.ReLU(self.inplace_relu) self.pool_layer = nn.MaxPool3d( kernel_size=[1, 3, 3], stride=[1, 2, 2], padding=[0, 1, 1] ) def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.relu(x) x = self.pool_layer(x) return x class X3DStem(nn.Module): """ X3D's 3D stem module. Performs a spatial followed by a depthwise temporal Convolution, BN, and Relu following by a spatiotemporal pooling. """ def __init__( self, dim_in, dim_out, kernel, stride, padding, inplace_relu=True, eps=1e-5, bn_mmt=0.1, norm_module=nn.BatchNorm3d, ): """ The `__init__` method of any subclass should also contain these arguments. Args: dim_in (int): the channel dimension of the input. Normally 3 is used for rgb input, and 2 or 3 is used for optical flow input. dim_out (int): the output dimension of the convolution in the stem layer. kernel (list): the kernel size of the convolution in the stem layer. temporal kernel size, height kernel size, width kernel size in order. stride (list): the stride size of the convolution in the stem layer. temporal kernel stride, height kernel size, width kernel size in order. padding (int): the padding size of the convolution in the stem layer, temporal padding size, height padding size, width padding size in order. inplace_relu (bool): calculate the relu on the original input without allocating new memory. eps (float): epsilon for batch norm. bn_mmt (float): momentum for batch norm. Noted that BN momentum in PyTorch = 1 - BN momentum in Caffe2. norm_module (nn.Module): nn.Module for the normalization layer. The default is nn.BatchNorm3d. """ super(X3DStem, self).__init__() self.kernel = kernel self.stride = stride self.padding = padding self.inplace_relu = inplace_relu self.eps = eps self.bn_mmt = bn_mmt # Construct the stem layer. self._construct_stem(dim_in, dim_out, norm_module) def _construct_stem(self, dim_in, dim_out, norm_module): self.conv_xy = nn.Conv3d( dim_in, dim_out, kernel_size=(1, self.kernel[1], self.kernel[2]), stride=(1, self.stride[1], self.stride[2]), padding=(0, self.padding[1], self.padding[2]), bias=False, ) self.conv = nn.Conv3d( dim_out, dim_out, kernel_size=(self.kernel[0], 1, 1), stride=(self.stride[0], 1, 1), padding=(self.padding[0], 0, 0), bias=False, groups=dim_out, ) self.bn = norm_module( num_features=dim_out, eps=self.eps, momentum=self.bn_mmt ) self.relu = nn.ReLU(self.inplace_relu) def forward(self, x): x = self.conv_xy(x) x = self.conv(x) x = self.bn(x) x = self.relu(x) return x