File size: 4,542 Bytes
3133fdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import os
import unittest
from typing import Iterator

import torch
from pytorchvideo.layers.swish import Swish
from pytorchvideo.models.x3d import create_x3d, create_x3d_bottleneck_block
from torch import nn


class TestX3d(unittest.TestCase):
    def setUp(self):
        super().setUp()
        torch.set_rng_state(torch.manual_seed(42).get_state())

    def test_create_x3d(self):
        """
        To test different versions of X3D, set the input to:
        X3D-XS: (4, 160, 2.0, 2.2, 2.25)
        X3D-S: (13, 160, 2.0, 2.2, 2.25)
        X3D-M: (16, 224, 2.0, 2.2, 2.25)
        X3D-L: (16, 312, 2.0, 5.0, 2.25)

        Each of the parameters corresponds to input_clip_length, input_crop_size,
        width_factor, depth_factor and bottleneck_factor.
        """
        for (
            input_clip_length,
            input_crop_size,
            width_factor,
            depth_factor,
            bottleneck_factor,
        ) in [
            (4, 160, 2.0, 2.2, 2.25),
        ]:
            model = create_x3d(
                input_clip_length=input_clip_length,
                input_crop_size=input_crop_size,
                model_num_class=400,
                dropout_rate=0.5,
                width_factor=width_factor,
                depth_factor=depth_factor,
                norm=nn.BatchNorm3d,
                activation=nn.ReLU,
                stem_dim_in=12,
                stem_conv_kernel_size=(5, 3, 3),
                stem_conv_stride=(1, 2, 2),
                stage_conv_kernel_size=((3, 3, 3),) * 4,
                stage_spatial_stride=(2, 2, 2, 2),
                stage_temporal_stride=(1, 1, 1, 1),
                bottleneck=create_x3d_bottleneck_block,
                bottleneck_factor=bottleneck_factor,
                se_ratio=0.0625,
                inner_act=Swish,
                head_dim_out=2048,
                head_pool_act=nn.ReLU,
                head_bn_lin5_on=False,
                head_activation=nn.Softmax,
            )

            # Test forwarding.
            for tensor in TestX3d._get_inputs(input_clip_length, input_crop_size):
                if tensor.shape[1] != 3:
                    with self.assertRaises(RuntimeError):
                        out = model(tensor)
                    continue

                out = model(tensor)

                output_shape = out.shape
                output_shape_gt = (tensor.shape[0], 400)

                self.assertEqual(
                    output_shape,
                    output_shape_gt,
                    "Output shape {} is different from expected shape {}".format(
                        output_shape, output_shape_gt
                    ),
                )

    def test_load_hubconf(self):
        path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            "..",
        )
        for (input_clip_length, input_crop_size, model_name) in [
            (4, 160, "x3d_xs"),
            (13, 160, "x3d_s"),
            (16, 224, "x3d_m"),
        ]:
            model = torch.hub.load(
                repo_or_dir=path,
                source="local",
                model=model_name,
                pretrained=False,
                head_output_with_global_average=True,
            )
            self.assertIsNotNone(model)

            # Test forwarding.
            for tensor in TestX3d._get_inputs(input_clip_length, input_crop_size):
                if tensor.shape[1] != 3:
                    with self.assertRaises(RuntimeError):
                        out = model(tensor)
                    continue

                out = model(tensor)

                output_shape = out.shape
                output_shape_gt = (tensor.shape[0], 400)

                self.assertEqual(
                    output_shape,
                    output_shape_gt,
                    "Output shape {} is different from expected shape {}".format(
                        output_shape, output_shape_gt
                    ),
                )

    @staticmethod
    def _get_inputs(clip_length: int = 4, crop_size: int = 160) -> torch.tensor:
        """
        Provide different tensors as test cases.

        Yield:
            (torch.tensor): tensor as test case input.
        """
        # Prepare random inputs as test cases.
        shapes = (
            (1, 3, clip_length, crop_size, crop_size),
            (2, 3, clip_length, crop_size, crop_size),
        )
        for shape in shapes:
            yield torch.rand(shape)