File size: 4,542 Bytes
3133fdb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import os
import unittest
from typing import Iterator

import torch
from pytorchvideo.layers.swish import Swish
from pytorchvideo.models.x3d import create_x3d, create_x3d_bottleneck_block
from torch import nn
class TestX3d(unittest.TestCase):
    """Forward-pass shape tests for X3D models built directly and via hubconf."""

    def setUp(self):
        super().setUp()
        # Seed the global RNG so the random inputs from `_get_inputs` are
        # reproducible across runs.
        torch.manual_seed(42)

    def test_create_x3d(self):
        """
        Build X3D with `create_x3d` and check the output shape of a forward pass.

        Only the X3D-XS configuration is exercised by default. To test other
        versions of X3D, add the corresponding tuple to `configs`:
            X3D-XS: (4, 160, 2.0, 2.2, 2.25)
            X3D-S: (13, 160, 2.0, 2.2, 2.25)
            X3D-M: (16, 224, 2.0, 2.2, 2.25)
            X3D-L: (16, 312, 2.0, 5.0, 2.25)

        Each of the parameters corresponds to input_clip_length, input_crop_size,
        width_factor, depth_factor and bottleneck_factor.
        """
        configs = [
            (4, 160, 2.0, 2.2, 2.25),
        ]
        for (
            input_clip_length,
            input_crop_size,
            width_factor,
            depth_factor,
            bottleneck_factor,
        ) in configs:
            # subTest reports which configuration failed and lets the
            # remaining configurations still run.
            with self.subTest(
                input_clip_length=input_clip_length,
                input_crop_size=input_crop_size,
                width_factor=width_factor,
                depth_factor=depth_factor,
                bottleneck_factor=bottleneck_factor,
            ):
                model = create_x3d(
                    input_clip_length=input_clip_length,
                    input_crop_size=input_crop_size,
                    model_num_class=400,
                    dropout_rate=0.5,
                    width_factor=width_factor,
                    depth_factor=depth_factor,
                    norm=nn.BatchNorm3d,
                    activation=nn.ReLU,
                    stem_dim_in=12,
                    stem_conv_kernel_size=(5, 3, 3),
                    stem_conv_stride=(1, 2, 2),
                    stage_conv_kernel_size=((3, 3, 3),) * 4,
                    stage_spatial_stride=(2, 2, 2, 2),
                    stage_temporal_stride=(1, 1, 1, 1),
                    bottleneck=create_x3d_bottleneck_block,
                    bottleneck_factor=bottleneck_factor,
                    se_ratio=0.0625,
                    inner_act=Swish,
                    head_dim_out=2048,
                    head_pool_act=nn.ReLU,
                    head_bn_lin5_on=False,
                    head_activation=nn.Softmax,
                )

                # Test forwarding.
                self._check_forward(model, input_clip_length, input_crop_size)

    def test_load_hubconf(self):
        """
        Load X3D variants through `torch.hub.load` from the local checkout
        (the parent directory of this file) and check the output shape of a
        forward pass.
        """
        path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            "..",
        )
        for (input_clip_length, input_crop_size, model_name) in [
            (4, 160, "x3d_xs"),
            (13, 160, "x3d_s"),
            (16, 224, "x3d_m"),
        ]:
            with self.subTest(model_name=model_name):
                model = torch.hub.load(
                    repo_or_dir=path,
                    source="local",
                    model=model_name,
                    pretrained=False,
                    head_output_with_global_average=True,
                )
                self.assertIsNotNone(model)

                # Test forwarding.
                self._check_forward(model, input_clip_length, input_crop_size)

    def _check_forward(
        self,
        model: nn.Module,
        input_clip_length: int,
        input_crop_size: int,
        num_classes: int = 400,
    ) -> None:
        """
        Run `model` on every tensor from `_get_inputs` and verify the output shape.

        Args:
            model (nn.Module): model under test.
            input_clip_length (int): temporal length of the generated inputs.
            input_crop_size (int): spatial height/width of the generated inputs.
            num_classes (int): expected size of the second output dimension.
        """
        # Only shapes are inspected, so no autograd graph is needed.
        with torch.no_grad():
            for tensor in TestX3d._get_inputs(input_clip_length, input_crop_size):
                if tensor.shape[1] != 3:
                    # Inputs whose channel dimension is not 3 are expected
                    # to be rejected by the model.
                    with self.assertRaises(RuntimeError):
                        model(tensor)
                    continue

                out = model(tensor)

                output_shape = out.shape
                output_shape_gt = (tensor.shape[0], num_classes)
                self.assertEqual(
                    output_shape,
                    output_shape_gt,
                    "Output shape {} is different from expected shape {}".format(
                        output_shape, output_shape_gt
                    ),
                )

    @staticmethod
    def _get_inputs(
        clip_length: int = 4, crop_size: int = 160
    ) -> Iterator[torch.Tensor]:
        """
        Provide different tensors as test cases.

        Args:
            clip_length (int): size of the third dimension of each tensor.
            crop_size (int): size of the fourth and fifth dimensions of each tensor.

        Yield:
            (torch.Tensor): random tensor of shape
                (batch, 3, clip_length, crop_size, crop_size), for batch sizes
                1 and 2.
        """
        # Prepare random inputs as test cases.
        shapes = (
            (1, 3, clip_length, crop_size, crop_size),
            (2, 3, clip_length, crop_size, crop_size),
        )
        for shape in shapes:
            yield torch.rand(shape)
|