import itertools
import unittest

import numpy as np
import torch
from pytorchvideo.layers.convolutions import (
    Conv2plus1d,
    ConvReduce3D,
    create_conv_2plus1d,
)
from torch import nn


class TestConvReduce3D(unittest.TestCase):
    def setUp(self):
        super().setUp()
        torch.set_rng_state(torch.manual_seed(42).get_state())

    def test_create_stack_conv(self):
        """
        Test ConvReduce3D.
        """
        for input_dim, output_dim in itertools.product((2, 4), (4, 8, 16)):
            model = ConvReduce3D(
                in_channels=input_dim,
                out_channels=output_dim,
                kernel_size=((1, 1, 1), (3, 3, 3), (1, 3, 3)),
                stride=((1, 1, 1), (1, 1, 1), None),
                padding=((0, 0, 0), (1, 1, 1), (0, 1, 1)),
                dilation=((2, 2, 2), (1, 1, 1), None),
                groups=(1, 2, None),
                bias=(True, False, None),
            )
            model_gt_list = [
                nn.Conv3d(
                    in_channels=input_dim,
                    out_channels=output_dim,
                    kernel_size=(1, 1, 1),
                    stride=(1, 1, 1),
                    padding=(0, 0, 0),
                    dilation=(2, 2, 2),
                    groups=1,
                    bias=True,
                ),
                nn.Conv3d(
                    in_channels=input_dim,
                    out_channels=output_dim,
                    kernel_size=(3, 3, 3),
                    stride=(1, 1, 1),
                    padding=(1, 1, 1),
                    dilation=(1, 1, 1),
                    groups=2,
                    bias=False,
                ),
                nn.Conv3d(
                    in_channels=input_dim,
                    out_channels=output_dim,
                    kernel_size=(1, 3, 3),
                    padding=(0, 1, 1),
                ),
            ]
            model.convs[0].load_state_dict(
                model_gt_list[0].state_dict(), strict=True
            )
            model.convs[1].load_state_dict(
                model_gt_list[1].state_dict(), strict=True
            )
            model.convs[2].load_state_dict(
                model_gt_list[2].state_dict(), strict=True
            )

            # Test forwarding: ConvReduce3D should produce the same output shape
            # as the summed outputs of the three ground-truth convolutions.
            for tensor in TestConvReduce3D._get_inputs(input_dim):
                if tensor.shape[1] != input_dim:
                    with self.assertRaises(RuntimeError):
                        output_tensor = model(tensor)
                    continue
                else:
                    output_tensor = model(tensor)
                    output_gt = []
                    for ind in range(3):
                        output_gt.append(model_gt_list[ind](tensor))
                    output_tensor_gt = torch.stack(output_gt, dim=0).sum(
                        dim=0, keepdim=False
                    )

                self.assertEqual(
                    output_tensor.shape,
                    output_tensor_gt.shape,
                    "Output shape {} is different from expected shape {}".format(
                        output_tensor.shape, output_tensor_gt.shape
                    ),
                )

    @staticmethod
    def _get_inputs(input_dim: int = 3) -> torch.Tensor:
        """
        Provide different tensors as test cases.

        Yields:
            (torch.Tensor): tensor as test case input.
        """
        shapes = (
            # Valid inputs: the channel dimension matches input_dim.
            (1, input_dim, 3, 7, 7),
            (1, input_dim, 5, 7, 7),
            (1, input_dim, 7, 7, 7),
            (2, input_dim, 3, 7, 7),
            (4, input_dim, 3, 7, 7),
            (8, input_dim, 3, 7, 7),
            (2, input_dim, 3, 7, 14),
            (2, input_dim, 3, 14, 7),
            (2, input_dim, 3, 14, 14),
            # Invalid inputs: the channel dimension differs from input_dim,
            # so the forward pass is expected to raise a RuntimeError.
            (8, input_dim * 2, 3, 7, 7),
            (8, input_dim * 4, 5, 7, 7),
        )
        for shape in shapes:
            yield torch.rand(shape)


class TestConv2plus1d(unittest.TestCase):
    def setUp(self):
        super().setUp()
        torch.set_rng_state(torch.manual_seed(42).get_state())

    def test_create_2plus1d_conv(self):
        """
        Test Conv2plus1d.
        """
        for input_dim, output_dim in itertools.product((2, 4), (4, 8, 16)):
            model = Conv2plus1d(
                conv_t=nn.Conv3d(
                    in_channels=input_dim,
                    out_channels=output_dim,
                    kernel_size=(3, 1, 1),
                    stride=(2, 1, 1),
                    padding=(1, 0, 0),
                    bias=False,
                ),
                norm=nn.BatchNorm3d(output_dim),
                activation=nn.ReLU(),
                conv_xy=nn.Conv3d(
                    in_channels=output_dim,
                    out_channels=output_dim,
                    kernel_size=(1, 3, 3),
                    stride=(1, 2, 2),
                    padding=(0, 1, 1),
                    bias=False,
                ),
            )

            model_gt = create_conv_2plus1d(
                in_channels=input_dim,
                out_channels=output_dim,
                kernel_size=(3, 3, 3),
                stride=(2, 2, 2),
                padding=(1, 1, 1),
                bias=False,
                norm=nn.BatchNorm3d,
                norm_eps=1e-5,
                norm_momentum=0.1,
                activation=nn.ReLU,
            )

            model.load_state_dict(model_gt.state_dict(), strict=True)

            # Test forwarding: the hand-assembled Conv2plus1d should match the
            # module built by create_conv_2plus1d once their weights are shared.
            # Pass input_dim so that the valid-shape cases actually exercise the
            # success path (the default of 3 would never match input_dim here).
            for input_tensor in TestConv2plus1d._get_inputs(input_dim):
                with torch.no_grad():
                    if input_tensor.shape[1] != input_dim:
                        with self.assertRaises(RuntimeError):
                            output_tensor = model(input_tensor)
                        continue
                    else:
                        output_tensor = model(input_tensor)
                        output_tensor_gt = model_gt(input_tensor)
                self.assertEqual(
                    output_tensor.shape,
                    output_tensor_gt.shape,
                    "Output shape {} is different from expected shape {}".format(
                        output_tensor.shape, output_tensor_gt.shape
                    ),
                )
                self.assertTrue(
                    np.allclose(output_tensor.numpy(), output_tensor_gt.numpy())
                )

    @staticmethod
    def _get_inputs(input_dim: int = 3) -> torch.Tensor:
        """
        Provide different tensors as test cases.

        Yields:
            (torch.Tensor): tensor as test case input.
        """
        shapes = (
            # Valid inputs: the channel dimension matches input_dim.
            (1, input_dim, 3, 7, 7),
            (1, input_dim, 5, 7, 7),
            (1, input_dim, 7, 7, 7),
            (2, input_dim, 3, 7, 7),
            (4, input_dim, 3, 7, 7),
            (8, input_dim, 3, 7, 7),
            (2, input_dim, 3, 7, 14),
            (2, input_dim, 3, 14, 7),
            (2, input_dim, 3, 14, 14),
            # Invalid inputs: the channel dimension differs from input_dim,
            # so the forward pass is expected to raise a RuntimeError.
            (8, input_dim * 2, 3, 7, 7),
            (8, input_dim * 4, 5, 7, 7),
        )
        for shape in shapes:
            yield torch.rand(shape)