# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Video models."""

import math
import torch
import torch.nn as nn

import timesformer.utils.weight_init_helper as init_helper
from timesformer.models.batchnorm_helper import get_norm

from . import head_helper, resnet_helper, stem_helper
from .build import MODEL_REGISTRY

import math
from torch.nn import ReplicationPad3d
from torch import einsum
from einops import rearrange, reduce, repeat
import copy
import numpy as np
from timesformer.models.vit import vit_base_patch16_224

# Number of blocks for different stages given the model depth.
_MODEL_STAGE_DEPTH = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3)}

# Basis of temporal kernel sizes for each of the stage.
_TEMPORAL_KERNEL_BASIS = {
    "c2d": [
        [[1]],  # conv1 temporal kernel.
        [[1]],  # res2 temporal kernel.
        [[1]],  # res3 temporal kernel.
        [[1]],  # res4 temporal kernel.
        [[1]],  # res5 temporal kernel.
    ],
    "c2d_nopool": [
        [[1]],  # conv1 temporal kernel.
        [[1]],  # res2 temporal kernel.
        [[1]],  # res3 temporal kernel.
        [[1]],  # res4 temporal kernel.
        [[1]],  # res5 temporal kernel.
    ],
    "i3d": [
        [[5]],  # conv1 temporal kernel.
        [[3]],  # res2 temporal kernel.
        [[3, 1]],  # res3 temporal kernel.
        [[3, 1]],  # res4 temporal kernel.
        [[1, 3]],  # res5 temporal kernel.
    ],
    "i3d_nopool": [
        [[5]],  # conv1 temporal kernel.
        [[3]],  # res2 temporal kernel.
        [[3, 1]],  # res3 temporal kernel.
        [[3, 1]],  # res4 temporal kernel.
        [[1, 3]],  # res5 temporal kernel.
    ],
    "slow": [
        [[1]],  # conv1 temporal kernel.
        [[1]],  # res2 temporal kernel.
        [[1]],  # res3 temporal kernel.
        [[3]],  # res4 temporal kernel.
        [[3]],  # res5 temporal kernel.
    ],
    "slowfast": [
        [[1], [5]],  # conv1 temporal kernel for slow and fast pathway.
        [[1], [3]],  # res2 temporal kernel for slow and fast pathway.
        [[1], [3]],  # res3 temporal kernel for slow and fast pathway.
        [[3], [3]],  # res4 temporal kernel for slow and fast pathway.
        [[3], [3]],  # res5 temporal kernel for slow and fast pathway.
    ],
    "x3d": [
        [[5]],  # conv1 temporal kernels.
        [[3]],  # res2 temporal kernels.
        [[3]],  # res3 temporal kernels.
        [[3]],  # res4 temporal kernels.
        [[3]],  # res5 temporal kernels.
    ],
}

# Kernel/stride of the max-pool applied after res2, one entry per pathway
# ([T, H, W]).
_POOL1 = {
    "c2d": [[2, 1, 1]],
    "c2d_nopool": [[1, 1, 1]],
    "i3d": [[2, 1, 1]],
    "i3d_nopool": [[1, 1, 1]],
    "slow": [[1, 1, 1]],
    "slowfast": [[1, 1, 1], [1, 1, 1]],
    "x3d": [[1, 1, 1]],
}


class FuseFastToSlow(nn.Module):
    """
    Fuses the information from the Fast pathway to the Slow pathway. Given the
    tensors from Slow pathway and Fast pathway, fuse information from Fast to
    Slow, then return the fused tensors from Slow and Fast pathway in order.
    """

    def __init__(
        self,
        dim_in,
        fusion_conv_channel_ratio,
        fusion_kernel,
        alpha,
        eps=1e-5,
        bn_mmt=0.1,
        inplace_relu=True,
        norm_module=nn.BatchNorm3d,
    ):
        """
        Args:
            dim_in (int): the channel dimension of the input.
            fusion_conv_channel_ratio (int): channel ratio for the convolution
                used to fuse from Fast pathway to Slow pathway.
            fusion_kernel (int): kernel size of the convolution used to fuse
                from Fast pathway to Slow pathway.
            alpha (int): the frame rate ratio between the Fast and Slow pathway.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Note that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
            inplace_relu (bool): if True, calculate the relu on the original
                input without allocating new memory.
            norm_module (nn.Module): nn.Module for the normalization layer. The
                default is nn.BatchNorm3d.
        """
        super(FuseFastToSlow, self).__init__()
        # Temporal conv with stride `alpha` brings the Fast pathway down to the
        # Slow pathway's frame rate before concatenation.
        self.conv_f2s = nn.Conv3d(
            dim_in,
            dim_in * fusion_conv_channel_ratio,
            kernel_size=[fusion_kernel, 1, 1],
            stride=[alpha, 1, 1],
            padding=[fusion_kernel // 2, 0, 0],
            bias=False,
        )
        self.bn = norm_module(
            num_features=dim_in * fusion_conv_channel_ratio,
            eps=eps,
            momentum=bn_mmt,
        )
        self.relu = nn.ReLU(inplace_relu)

    def forward(self, x):
        """
        Args:
            x (list): [slow, fast] pathway tensors.
        Returns:
            list: [slow concatenated with the fused fast features along
                dim 1, fast unchanged].
        """
        x_s = x[0]
        x_f = x[1]
        fuse = self.conv_f2s(x_f)
        fuse = self.bn(fuse)
        fuse = self.relu(fuse)
        x_s_fuse = torch.cat([x_s, fuse], 1)
        return [x_s_fuse, x_f]


@MODEL_REGISTRY.register()
class SlowFast(nn.Module):
    """
    SlowFast model builder for SlowFast network.

    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf
    """

    def __init__(self, cfg):
        """
        The `__init__` method of any subclass should also contain these
            arguments.
        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(SlowFast, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.num_pathways = 2
        self._construct_network(cfg)
        init_helper.init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _construct_network(self, cfg):
        """
        Builds a SlowFast model. The first pathway is the Slow pathway and the
            second pathway is the Fast pathway.
        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        assert cfg.MODEL.ARCH in _POOL1.keys()
        pool_size = _POOL1[cfg.MODEL.ARCH]
        assert len({len(pool_size), self.num_pathways}) == 1
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()

        (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]

        num_groups = cfg.RESNET.NUM_GROUPS
        width_per_group = cfg.RESNET.WIDTH_PER_GROUP
        dim_inner = num_groups * width_per_group
        # Channel-count divisor for the Fast->Slow fused features that get
        # concatenated onto the Slow pathway input of each stage.
        out_dim_ratio = (
            cfg.SLOWFAST.BETA_INV // cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO
        )

        temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

        self.s1 = stem_helper.VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[width_per_group, width_per_group // cfg.SLOWFAST.BETA_INV],
            kernel=[temp_kernel[0][0] + [7, 7], temp_kernel[0][1] + [7, 7]],
            stride=[[1, 2, 2]] * 2,
            padding=[
                [temp_kernel[0][0][0] // 2, 3, 3],
                [temp_kernel[0][1][0] // 2, 3, 3],
            ],
            norm_module=self.norm_module,
        )
        self.s1_fuse = FuseFastToSlow(
            width_per_group // cfg.SLOWFAST.BETA_INV,
            cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO,
            cfg.SLOWFAST.FUSION_KERNEL_SZ,
            cfg.SLOWFAST.ALPHA,
            norm_module=self.norm_module,
        )

        self.s2 = resnet_helper.ResStage(
            dim_in=[
                width_per_group + width_per_group // out_dim_ratio,
                width_per_group // cfg.SLOWFAST.BETA_INV,
            ],
            dim_out=[
                width_per_group * 4,
                width_per_group * 4 // cfg.SLOWFAST.BETA_INV,
            ],
            dim_inner=[dim_inner, dim_inner // cfg.SLOWFAST.BETA_INV],
            temp_kernel_sizes=temp_kernel[1],
            stride=cfg.RESNET.SPATIAL_STRIDES[0],
            num_blocks=[d2] * 2,
            num_groups=[num_groups] * 2,
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
            nonlocal_group=cfg.NONLOCAL.GROUP[0],
            nonlocal_pool=cfg.NONLOCAL.POOL[0],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
            norm_module=self.norm_module,
        )
        self.s2_fuse = FuseFastToSlow(
            width_per_group * 4 // cfg.SLOWFAST.BETA_INV,
            cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO,
            cfg.SLOWFAST.FUSION_KERNEL_SZ,
            cfg.SLOWFAST.ALPHA,
            norm_module=self.norm_module,
        )

        for pathway in range(self.num_pathways):
            pool = nn.MaxPool3d(
                kernel_size=pool_size[pathway],
                stride=pool_size[pathway],
                padding=[0, 0, 0],
            )
            self.add_module("pathway{}_pool".format(pathway), pool)

        self.s3 = resnet_helper.ResStage(
            dim_in=[
                width_per_group * 4 + width_per_group * 4 // out_dim_ratio,
                width_per_group * 4 // cfg.SLOWFAST.BETA_INV,
            ],
            dim_out=[
                width_per_group * 8,
                width_per_group * 8 // cfg.SLOWFAST.BETA_INV,
            ],
            dim_inner=[dim_inner * 2, dim_inner * 2 // cfg.SLOWFAST.BETA_INV],
            temp_kernel_sizes=temp_kernel[2],
            stride=cfg.RESNET.SPATIAL_STRIDES[1],
            num_blocks=[d3] * 2,
            num_groups=[num_groups] * 2,
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
            nonlocal_group=cfg.NONLOCAL.GROUP[1],
            nonlocal_pool=cfg.NONLOCAL.POOL[1],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
            norm_module=self.norm_module,
        )
        self.s3_fuse = FuseFastToSlow(
            width_per_group * 8 // cfg.SLOWFAST.BETA_INV,
            cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO,
            cfg.SLOWFAST.FUSION_KERNEL_SZ,
            cfg.SLOWFAST.ALPHA,
            norm_module=self.norm_module,
        )

        self.s4 = resnet_helper.ResStage(
            dim_in=[
                width_per_group * 8 + width_per_group * 8 // out_dim_ratio,
                width_per_group * 8 // cfg.SLOWFAST.BETA_INV,
            ],
            dim_out=[
                width_per_group * 16,
                width_per_group * 16 // cfg.SLOWFAST.BETA_INV,
            ],
            dim_inner=[dim_inner * 4, dim_inner * 4 // cfg.SLOWFAST.BETA_INV],
            temp_kernel_sizes=temp_kernel[3],
            stride=cfg.RESNET.SPATIAL_STRIDES[2],
            num_blocks=[d4] * 2,
            num_groups=[num_groups] * 2,
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
            nonlocal_group=cfg.NONLOCAL.GROUP[2],
            nonlocal_pool=cfg.NONLOCAL.POOL[2],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
            norm_module=self.norm_module,
        )
        self.s4_fuse = FuseFastToSlow(
            width_per_group * 16 // cfg.SLOWFAST.BETA_INV,
            cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO,
            cfg.SLOWFAST.FUSION_KERNEL_SZ,
            cfg.SLOWFAST.ALPHA,
            norm_module=self.norm_module,
        )

        self.s5 = resnet_helper.ResStage(
            dim_in=[
                width_per_group * 16 + width_per_group * 16 // out_dim_ratio,
                width_per_group * 16 // cfg.SLOWFAST.BETA_INV,
            ],
            dim_out=[
                width_per_group * 32,
                width_per_group * 32 // cfg.SLOWFAST.BETA_INV,
            ],
            dim_inner=[dim_inner * 8, dim_inner * 8 // cfg.SLOWFAST.BETA_INV],
            temp_kernel_sizes=temp_kernel[4],
            stride=cfg.RESNET.SPATIAL_STRIDES[3],
            num_blocks=[d5] * 2,
            num_groups=[num_groups] * 2,
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
            nonlocal_group=cfg.NONLOCAL.GROUP[3],
            nonlocal_pool=cfg.NONLOCAL.POOL[3],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
            norm_module=self.norm_module,
        )

        if cfg.DETECTION.ENABLE:
            self.head = head_helper.ResNetRoIHead(
                dim_in=[
                    width_per_group * 32,
                    width_per_group * 32 // cfg.SLOWFAST.BETA_INV,
                ],
                num_classes=cfg.MODEL.NUM_CLASSES,
                pool_size=[
                    [
                        cfg.DATA.NUM_FRAMES
                        // cfg.SLOWFAST.ALPHA
                        // pool_size[0][0],
                        1,
                        1,
                    ],
                    [cfg.DATA.NUM_FRAMES // pool_size[1][0], 1, 1],
                ],
                resolution=[[cfg.DETECTION.ROI_XFORM_RESOLUTION] * 2] * 2,
                scale_factor=[cfg.DETECTION.SPATIAL_SCALE_FACTOR] * 2,
                dropout_rate=cfg.MODEL.DROPOUT_RATE,
                act_func=cfg.MODEL.HEAD_ACT,
                aligned=cfg.DETECTION.ALIGNED,
            )
            # `forward` looks the head up via `self.head_name`, so it must be
            # set here too; the module itself stays registered as `head`.
            self.head_name = "head"
        else:
            head = head_helper.ResNetBasicHead(
                dim_in=[
                    width_per_group * 32,
                    width_per_group * 32 // cfg.SLOWFAST.BETA_INV,
                ],
                num_classes=cfg.MODEL.NUM_CLASSES,
                pool_size=[None, None]
                if cfg.MULTIGRID.SHORT_CYCLE
                else [
                    [
                        cfg.DATA.NUM_FRAMES
                        // cfg.SLOWFAST.ALPHA
                        // pool_size[0][0],
                        cfg.DATA.TRAIN_CROP_SIZE // 32 // pool_size[0][1],
                        cfg.DATA.TRAIN_CROP_SIZE // 32 // pool_size[0][2],
                    ],
                    [
                        cfg.DATA.NUM_FRAMES // pool_size[1][0],
                        cfg.DATA.TRAIN_CROP_SIZE // 32 // pool_size[1][1],
                        cfg.DATA.TRAIN_CROP_SIZE // 32 // pool_size[1][2],
                    ],
                ],  # None for AdaptiveAvgPool3d((1, 1, 1))
                dropout_rate=cfg.MODEL.DROPOUT_RATE,
                act_func=cfg.MODEL.HEAD_ACT,
            )
            self.head_name = "head{}".format(cfg.TASK)
            self.add_module(self.head_name, head)

    def forward(self, x, bboxes=None):
        """
        Run the two-pathway backbone followed by the head.

        Args:
            x (list): [slow, fast] input tensors, one per pathway.
            bboxes: RoI boxes passed to the head when detection is enabled;
                ignored otherwise.
        Returns:
            The head output.
        """
        x = self.s1(x)
        x = self.s1_fuse(x)
        x = self.s2(x)
        x = self.s2_fuse(x)
        for pathway in range(self.num_pathways):
            pool = getattr(self, "pathway{}_pool".format(pathway))
            x[pathway] = pool(x[pathway])
        x = self.s3(x)
        x = self.s3_fuse(x)
        x = self.s4(x)
        x = self.s4_fuse(x)
        x = self.s5(x)

        head = getattr(self, self.head_name)
        if self.enable_detection:
            x = head(x, bboxes)
        else:
            x = head(x)
        return x


@MODEL_REGISTRY.register()
class ResNet(nn.Module):
    """
    ResNet model builder. It builds a ResNet like network backbone without
    lateral connection (C2D, I3D, Slow).

    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf

    Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
    "Non-local neural networks."
    https://arxiv.org/pdf/1711.07971.pdf
    """

    def __init__(self, cfg):
        """
        The `__init__` method of any subclass should also contain these
            arguments.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(ResNet, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.num_pathways = 1
        self._construct_network(cfg)
        init_helper.init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _construct_network(self, cfg):
        """
        Builds a single pathway ResNet model.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        assert cfg.MODEL.ARCH in _POOL1.keys()
        pool_size = _POOL1[cfg.MODEL.ARCH]
        assert len({len(pool_size), self.num_pathways}) == 1
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()

        (d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]

        num_groups = cfg.RESNET.NUM_GROUPS
        width_per_group = cfg.RESNET.WIDTH_PER_GROUP
        dim_inner = num_groups * width_per_group

        temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

        self.s1 = stem_helper.VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[width_per_group],
            kernel=[temp_kernel[0][0] + [7, 7]],
            stride=[[1, 2, 2]],
            padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
            norm_module=self.norm_module,
        )

        self.s2 = resnet_helper.ResStage(
            dim_in=[width_per_group],
            dim_out=[width_per_group * 4],
            dim_inner=[dim_inner],
            temp_kernel_sizes=temp_kernel[1],
            stride=cfg.RESNET.SPATIAL_STRIDES[0],
            num_blocks=[d2],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
            nonlocal_group=cfg.NONLOCAL.GROUP[0],
            nonlocal_pool=cfg.NONLOCAL.POOL[0],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
            norm_module=self.norm_module,
        )

        for pathway in range(self.num_pathways):
            pool = nn.MaxPool3d(
                kernel_size=pool_size[pathway],
                stride=pool_size[pathway],
                padding=[0, 0, 0],
            )
            self.add_module("pathway{}_pool".format(pathway), pool)

        self.s3 = resnet_helper.ResStage(
            dim_in=[width_per_group * 4],
            dim_out=[width_per_group * 8],
            dim_inner=[dim_inner * 2],
            temp_kernel_sizes=temp_kernel[2],
            stride=cfg.RESNET.SPATIAL_STRIDES[1],
            num_blocks=[d3],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
            nonlocal_group=cfg.NONLOCAL.GROUP[1],
            nonlocal_pool=cfg.NONLOCAL.POOL[1],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
            norm_module=self.norm_module,
        )

        self.s4 = resnet_helper.ResStage(
            dim_in=[width_per_group * 8],
            dim_out=[width_per_group * 16],
            dim_inner=[dim_inner * 4],
            temp_kernel_sizes=temp_kernel[3],
            stride=cfg.RESNET.SPATIAL_STRIDES[2],
            num_blocks=[d4],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
            nonlocal_group=cfg.NONLOCAL.GROUP[2],
            nonlocal_pool=cfg.NONLOCAL.POOL[2],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
            norm_module=self.norm_module,
        )

        self.s5 = resnet_helper.ResStage(
            dim_in=[width_per_group * 16],
            dim_out=[width_per_group * 32],
            dim_inner=[dim_inner * 8],
            temp_kernel_sizes=temp_kernel[4],
            stride=cfg.RESNET.SPATIAL_STRIDES[3],
            num_blocks=[d5],
            num_groups=[num_groups],
            num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
            nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
            nonlocal_group=cfg.NONLOCAL.GROUP[3],
            nonlocal_pool=cfg.NONLOCAL.POOL[3],
            instantiation=cfg.NONLOCAL.INSTANTIATION,
            trans_func_name=cfg.RESNET.TRANS_FUNC,
            stride_1x1=cfg.RESNET.STRIDE_1X1,
            inplace_relu=cfg.RESNET.INPLACE_RELU,
            dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
            norm_module=self.norm_module,
        )

        if self.enable_detection:
            self.head = head_helper.ResNetRoIHead(
                dim_in=[width_per_group * 32],
                num_classes=cfg.MODEL.NUM_CLASSES,
                pool_size=[[cfg.DATA.NUM_FRAMES // pool_size[0][0], 1, 1]],
                resolution=[[cfg.DETECTION.ROI_XFORM_RESOLUTION] * 2],
                scale_factor=[cfg.DETECTION.SPATIAL_SCALE_FACTOR],
                dropout_rate=cfg.MODEL.DROPOUT_RATE,
                act_func=cfg.MODEL.HEAD_ACT,
                aligned=cfg.DETECTION.ALIGNED,
            )
            # `forward` looks the head up via `self.head_name`, so it must be
            # set here too; the module itself stays registered as `head`.
            self.head_name = "head"
        else:
            head = head_helper.ResNetBasicHead(
                dim_in=[width_per_group * 32],
                num_classes=cfg.MODEL.NUM_CLASSES,
                # One pool_size entry per pathway, matching `dim_in` (this
                # model has a single pathway).
                pool_size=[None]
                if cfg.MULTIGRID.SHORT_CYCLE
                else [
                    [
                        cfg.DATA.NUM_FRAMES // pool_size[0][0],
                        cfg.DATA.TRAIN_CROP_SIZE // 32 // pool_size[0][1],
                        cfg.DATA.TRAIN_CROP_SIZE // 32 // pool_size[0][2],
                    ]
                ],  # None for AdaptiveAvgPool3d((1, 1, 1))
                dropout_rate=cfg.MODEL.DROPOUT_RATE,
                act_func=cfg.MODEL.HEAD_ACT,
            )
            self.head_name = "head{}".format(cfg.TASK)
            self.add_module(self.head_name, head)

    def forward(self, x, bboxes=None):
        """
        Run the single-pathway backbone followed by the head.

        Args:
            x (list): single-element list holding the input tensor.
            bboxes: RoI boxes passed to the head when detection is enabled;
                ignored otherwise.
        Returns:
            The head output.
        """
        x = self.s1(x)
        x = self.s2(x)
        for pathway in range(self.num_pathways):
            pool = getattr(self, "pathway{}_pool".format(pathway))
            x[pathway] = pool(x[pathway])
        x = self.s3(x)
        x = self.s4(x)
        x = self.s5(x)

        head = getattr(self, self.head_name)
        if self.enable_detection:
            x = head(x, bboxes)
        else:
            x = head(x)
        return x


@MODEL_REGISTRY.register()
class X3D(nn.Module):
    """
    X3D model builder. It builds an X3D network backbone, which is a ResNet.

    Christoph Feichtenhofer.
    "X3D: Expanding Architectures for Efficient Video Recognition."
    https://arxiv.org/abs/2004.04730
    """

    def __init__(self, cfg):
        """
        The `__init__` method of any subclass should also contain these
            arguments.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(X3D, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        self.num_pathways = 1

        # Each stage doubles the channel count of the previous one (rounded
        # to a multiple of 8).
        exp_stage = 2.0
        self.dim_c1 = cfg.X3D.DIM_C1

        self.dim_res2 = (
            self._round_width(self.dim_c1, exp_stage, divisor=8)
            if cfg.X3D.SCALE_RES2
            else self.dim_c1
        )
        self.dim_res3 = self._round_width(self.dim_res2, exp_stage, divisor=8)
        self.dim_res4 = self._round_width(self.dim_res3, exp_stage, divisor=8)
        self.dim_res5 = self._round_width(self.dim_res4, exp_stage, divisor=8)

        self.block_basis = [
            # blocks, c, stride
            [1, self.dim_res2, 2],
            [2, self.dim_res3, 2],
            [5, self.dim_res4, 2],
            [3, self.dim_res5, 2],
        ]
        self._construct_network(cfg)
        init_helper.init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _round_width(self, width, multiplier, min_depth=8, divisor=8):
        """
        Round width of filters based on width multiplier.

        Scales `width` by `multiplier`, rounds to the nearest multiple of
        `divisor` (but not below `min_depth`), and adds one more `divisor` if
        rounding dropped the result below 90% of the scaled width. A falsy
        `multiplier` returns `width` unchanged.
        """
        if not multiplier:
            return width

        width *= multiplier
        min_depth = min_depth or divisor
        new_filters = max(
            min_depth, int(width + divisor / 2) // divisor * divisor
        )
        if new_filters < 0.9 * width:
            new_filters += divisor
        return int(new_filters)

    def _round_repeats(self, repeats, multiplier):
        """
        Round number of layers based on depth multiplier.

        Returns ceil(multiplier * repeats), or `repeats` unchanged when
        `multiplier` is falsy.
        """
        if not multiplier:
            return repeats
        return int(math.ceil(multiplier * repeats))

    def _construct_network(self, cfg):
        """
        Builds a single pathway X3D model.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        Raises:
            NotImplementedError: if detection is enabled (no X3D detection
                head is available).
        """
        assert cfg.MODEL.ARCH in _POOL1.keys()
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()

        num_groups = cfg.RESNET.NUM_GROUPS

        w_mul = cfg.X3D.WIDTH_FACTOR
        d_mul = cfg.X3D.DEPTH_FACTOR
        dim_res1 = self._round_width(self.dim_c1, w_mul)

        temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

        self.s1 = stem_helper.VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[dim_res1],
            kernel=[temp_kernel[0][0] + [3, 3]],
            stride=[[1, 2, 2]],
            padding=[[temp_kernel[0][0][0] // 2, 1, 1]],
            norm_module=self.norm_module,
            stem_func_name="x3d_stem",
        )

        # blob_in = s1
        dim_in = dim_res1
        for stage, block in enumerate(self.block_basis):
            dim_out = self._round_width(block[1], w_mul)
            dim_inner = int(cfg.X3D.BOTTLENECK_FACTOR * dim_out)

            n_rep = self._round_repeats(block[0], d_mul)
            prefix = "s{}".format(
                stage + 2
            )  # start w res2 to follow convention

            s = resnet_helper.ResStage(
                dim_in=[dim_in],
                dim_out=[dim_out],
                dim_inner=[dim_inner],
                temp_kernel_sizes=temp_kernel[1],
                stride=[block[2]],
                num_blocks=[n_rep],
                # Channel-wise (depthwise) 3x3x3 conv uses one group per
                # inner channel.
                num_groups=[dim_inner]
                if cfg.X3D.CHANNELWISE_3x3x3
                else [num_groups],
                num_block_temp_kernel=[n_rep],
                nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
                nonlocal_group=cfg.NONLOCAL.GROUP[0],
                nonlocal_pool=cfg.NONLOCAL.POOL[0],
                instantiation=cfg.NONLOCAL.INSTANTIATION,
                trans_func_name=cfg.RESNET.TRANS_FUNC,
                stride_1x1=cfg.RESNET.STRIDE_1X1,
                norm_module=self.norm_module,
                dilation=cfg.RESNET.SPATIAL_DILATIONS[stage],
                # Drop-connect rate grows linearly with stage depth.
                drop_connect_rate=cfg.MODEL.DROPCONNECT_RATE
                * (stage + 2)
                / (len(self.block_basis) + 1),
            )
            dim_in = dim_out
            self.add_module(prefix, s)

        if self.enable_detection:
            # Previously this was a bare `NotImplementedError` expression,
            # which silently built a model with no head.
            raise NotImplementedError(
                "Detection is not implemented for X3D."
            )
        else:
            spat_sz = int(math.ceil(cfg.DATA.TRAIN_CROP_SIZE / 32.0))
            self.head = head_helper.X3DHead(
                dim_in=dim_out,
                dim_inner=dim_inner,
                dim_out=cfg.X3D.DIM_C5,
                num_classes=cfg.MODEL.NUM_CLASSES,
                pool_size=[cfg.DATA.NUM_FRAMES, spat_sz, spat_sz],
                dropout_rate=cfg.MODEL.DROPOUT_RATE,
                act_func=cfg.MODEL.HEAD_ACT,
                bn_lin5_on=cfg.X3D.BN_LIN5,
            )

    def forward(self, x, bboxes=None):
        """
        Apply every child module in registration order (s1, s2..s5, head).

        `bboxes` is accepted for interface compatibility with the other
        builders but is unused.
        """
        for module in self.children():
            x = module(x)
        return x