mvsoom's picture
Upload folder using huggingface_hub
3133fdb
raw
history blame contribute delete
4.54 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import os
import unittest
import torch
from pytorchvideo.layers.swish import Swish
from pytorchvideo.models.x3d import create_x3d, create_x3d_bottleneck_block
from torch import nn
class TestX3d(unittest.TestCase):
def setUp(self):
super().setUp()
torch.set_rng_state(torch.manual_seed(42).get_state())
def test_create_x3d(self):
"""
To test different versions of X3D, set the input to:
X3D-XS: (4, 160, 2.0, 2.2, 2.25)
X3D-S: (13, 160, 2.0, 2.2, 2.25)
X3D-M: (16, 224, 2.0, 2.2, 2.25)
X3D-L: (16, 312, 2.0, 5.0, 2.25)
Each of the parameters corresponds to input_clip_length, input_crop_size,
width_factor, depth_factor and bottleneck_factor.
"""
for (
input_clip_length,
input_crop_size,
width_factor,
depth_factor,
bottleneck_factor,
) in [
(4, 160, 2.0, 2.2, 2.25),
]:
model = create_x3d(
input_clip_length=input_clip_length,
input_crop_size=input_crop_size,
model_num_class=400,
dropout_rate=0.5,
width_factor=width_factor,
depth_factor=depth_factor,
norm=nn.BatchNorm3d,
activation=nn.ReLU,
stem_dim_in=12,
stem_conv_kernel_size=(5, 3, 3),
stem_conv_stride=(1, 2, 2),
stage_conv_kernel_size=((3, 3, 3),) * 4,
stage_spatial_stride=(2, 2, 2, 2),
stage_temporal_stride=(1, 1, 1, 1),
bottleneck=create_x3d_bottleneck_block,
bottleneck_factor=bottleneck_factor,
se_ratio=0.0625,
inner_act=Swish,
head_dim_out=2048,
head_pool_act=nn.ReLU,
head_bn_lin5_on=False,
head_activation=nn.Softmax,
)
# Test forwarding.
for tensor in TestX3d._get_inputs(input_clip_length, input_crop_size):
if tensor.shape[1] != 3:
with self.assertRaises(RuntimeError):
out = model(tensor)
continue
out = model(tensor)
output_shape = out.shape
output_shape_gt = (tensor.shape[0], 400)
self.assertEqual(
output_shape,
output_shape_gt,
"Output shape {} is different from expected shape {}".format(
output_shape, output_shape_gt
),
)
def test_load_hubconf(self):
path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"..",
)
for (input_clip_length, input_crop_size, model_name) in [
(4, 160, "x3d_xs"),
(13, 160, "x3d_s"),
(16, 224, "x3d_m"),
]:
model = torch.hub.load(
repo_or_dir=path,
source="local",
model=model_name,
pretrained=False,
head_output_with_global_average=True,
)
self.assertIsNotNone(model)
# Test forwarding.
for tensor in TestX3d._get_inputs(input_clip_length, input_crop_size):
if tensor.shape[1] != 3:
with self.assertRaises(RuntimeError):
out = model(tensor)
continue
out = model(tensor)
output_shape = out.shape
output_shape_gt = (tensor.shape[0], 400)
self.assertEqual(
output_shape,
output_shape_gt,
"Output shape {} is different from expected shape {}".format(
output_shape, output_shape_gt
),
)
@staticmethod
def _get_inputs(clip_length: int = 4, crop_size: int = 160) -> torch.tensor:
"""
Provide different tensors as test cases.
Yield:
(torch.tensor): tensor as test case input.
"""
# Prepare random inputs as test cases.
shapes = (
(1, 3, clip_length, crop_size, crop_size),
(2, 3, clip_length, crop_size, crop_size),
)
for shape in shapes:
yield torch.rand(shape)