|
|
|
|
|
import itertools |
|
import unittest |
|
from typing import Iterable |
|
|
|
import numpy as np |
|
import torch |
|
from pytorchvideo.layers.nonlocal_net import create_nonlocal, NonLocal |
|
from torch import nn |
|
|
|
|
|
class TestNonlocal(unittest.TestCase):
    """Tests for the ``NonLocal`` block and the ``create_nonlocal`` builder."""

    def setUp(self):
        """Seed torch's default RNG so each test draws the same random inputs."""
        super().setUp()
        # `torch.manual_seed` seeds the default generator itself; reading that
        # generator's state and writing it straight back changes nothing.
        torch.manual_seed(42)

    def test_build_nonlocal(self):
        """
        Test Nonlocal model builder.

        Builds ``NonLocal`` from kernel-size-1 ``nn.Conv3d`` layers for every
        combination of ``dim_in``, ``dim_inner``, pooling, normalization and
        instantiation. Inputs whose size along dim 1 equals ``dim_in`` must
        come back with an unchanged shape; any other input must raise
        ``RuntimeError``.
        """
        for dim_in, dim_inner, pool, norm, instantiation in itertools.product(
            (4, 8),
            (2, 4),
            (None, nn.MaxPool3d(2)),
            (None, nn.BatchNorm3d),
            ("dot_product", "softmax"),
        ):
            # One sub-test per combination so a failure names the offending
            # parameters instead of hiding them inside the sweep.
            with self.subTest(
                dim_in=dim_in,
                dim_inner=dim_inner,
                pool=pool,
                norm=norm,
                instantiation=instantiation,
            ):
                model = NonLocal(
                    conv_theta=nn.Conv3d(
                        dim_in, dim_inner, kernel_size=1, stride=1, padding=0
                    ),
                    conv_phi=nn.Conv3d(
                        dim_in, dim_inner, kernel_size=1, stride=1, padding=0
                    ),
                    conv_g=nn.Conv3d(
                        dim_in, dim_inner, kernel_size=1, stride=1, padding=0
                    ),
                    conv_out=nn.Conv3d(
                        dim_inner, dim_in, kernel_size=1, stride=1, padding=0
                    ),
                    pool=pool,
                    norm=norm(dim_in) if norm is not None else None,
                    instantiation=instantiation,
                )

                # Test forwarding.
                for input_tensor in TestNonlocal._get_inputs(input_dim=dim_in):
                    if input_tensor.shape[1] != dim_in:
                        # Dim 1 does not match what the convolutions were
                        # built for, so the forward pass has to fail.
                        with self.assertRaises(RuntimeError):
                            model(input_tensor)
                        continue

                    output_tensor = model(input_tensor)

                    input_shape = input_tensor.shape
                    output_shape = output_tensor.shape

                    self.assertEqual(
                        input_shape,
                        output_shape,
                        "Input shape {} is different from output shape {}".format(
                            input_shape, output_shape
                        ),
                    )

    def test_nonlocal_builder(self):
        """
        Test builder `create_nonlocal`.

        For every parameter combination, builds one model with
        ``create_nonlocal`` and a reference ``NonLocal`` assembled by hand,
        loads the reference weights into the built model (``strict=True``),
        and checks that both produce outputs of the same shape and values.
        """
        for dim_in, dim_inner, pool_size, norm, instantiation in itertools.product(
            (4, 8),
            (2, 4),
            ((1, 1, 1), (2, 2, 2)),
            (None, nn.BatchNorm3d),
            ("dot_product", "softmax"),
        ):
            with self.subTest(
                dim_in=dim_in,
                dim_inner=dim_inner,
                pool_size=pool_size,
                norm=norm,
                instantiation=instantiation,
            ):
                # Hand-built reference layers.
                conv_theta = nn.Conv3d(
                    dim_in, dim_inner, kernel_size=1, stride=1, padding=0
                )
                conv_phi = nn.Conv3d(
                    dim_in, dim_inner, kernel_size=1, stride=1, padding=0
                )
                conv_g = nn.Conv3d(
                    dim_in, dim_inner, kernel_size=1, stride=1, padding=0
                )
                conv_out = nn.Conv3d(
                    dim_inner, dim_in, kernel_size=1, stride=1, padding=0
                )

                norm_model = None if norm is None else norm(num_features=dim_in)

                # The reference only gets a pooling layer when some pooling
                # dimension is larger than 1.
                if isinstance(pool_size, Iterable) and any(
                    size > 1 for size in pool_size
                ):
                    pool_model = nn.MaxPool3d(
                        kernel_size=pool_size, stride=pool_size, padding=[0, 0, 0]
                    )
                else:
                    pool_model = None

                model = create_nonlocal(
                    dim_in=dim_in,
                    dim_inner=dim_inner,
                    pool_size=pool_size,
                    instantiation=instantiation,
                    norm=norm,
                )

                model_gt = NonLocal(
                    conv_theta=conv_theta,
                    conv_phi=conv_phi,
                    conv_g=conv_g,
                    conv_out=conv_out,
                    pool=pool_model,
                    norm=norm_model,
                    instantiation=instantiation,
                )

                # Explicitly use strict mode so any mismatch between the two
                # models' parameter names fails here.
                model.load_state_dict(model_gt.state_dict(), strict=True)

                # Test forwarding.
                for input_tensor in TestNonlocal._get_inputs(input_dim=dim_in):
                    with torch.no_grad():
                        if input_tensor.shape[1] != dim_in:
                            with self.assertRaises(RuntimeError):
                                model(input_tensor)
                            continue

                        output_tensor = model(input_tensor)
                        output_tensor_gt = model_gt(input_tensor)

                    self.assertEqual(
                        output_tensor.shape,
                        output_tensor_gt.shape,
                        "Output shape {} is different from expected shape {}".format(
                            output_tensor.shape, output_tensor_gt.shape
                        ),
                    )
                    self.assertTrue(
                        np.allclose(output_tensor.numpy(), output_tensor_gt.numpy()),
                        "Output of create_nonlocal differs from the reference NonLocal",
                    )

    @staticmethod
    def _get_inputs(input_dim: int = 8) -> Iterable[torch.Tensor]:
        """
        Provide different tensors as test cases.

        Args:
            input_dim (int): size of dim 1 for the inputs that the tests
                expect to forward successfully.

        Yield:
            (torch.Tensor): tensor as test case input.
        """
        # Prepare random tensor as test cases.
        shapes = (
            # Dim 1 equals input_dim: the tests expect forward to succeed.
            (1, input_dim, 5, 7, 7),
            (2, input_dim, 5, 7, 7),
            (4, input_dim, 5, 7, 7),
            # NOTE(review): same shape as the previous entry; kept so the case
            # list stays exactly as it was.
            (4, input_dim, 5, 7, 7),
            (4, input_dim, 7, 7, 7),
            (4, input_dim, 7, 7, 14),
            (4, input_dim, 7, 14, 7),
            (4, input_dim, 7, 14, 14),
            # Dim 1 differs from input_dim: the tests expect a RuntimeError.
            (8, input_dim * 2, 3, 7, 7),
            (8, input_dim * 4, 5, 7, 7),
        )
        for shape in shapes:
            yield torch.rand(shape)
|
|