|
|
|
|
|
import unittest |
|
|
|
import torch |
|
from pytorchvideo.layers import make_fusion_layer |
|
|
|
|
|
class TestFusion(unittest.TestCase):
    """Unit tests for the layers returned by ``make_fusion_layer``.

    Every test fuses the same two fixed float tensors of shape ``(3, 2, 2)``
    and checks both the fused tensor and the layer's reported ``output_dim``.
    """

    def setUp(self):
        super().setUp()
        # Seed the default RNG so any randomness in the layers under test is
        # reproducible across runs.
        torch.manual_seed(42)

        self.fake_input_1 = torch.tensor(
            [[[4, -2], [3, 0]], [[0, 2], [4, 3]], [[3, 1], [5, 2]]],
            dtype=torch.float32,
        )
        self.fake_input_2 = torch.tensor(
            [[[1, 2], [3, 4]], [[5, 6], [6, 5]], [[4, 3], [2, 1]]],
            dtype=torch.float32,
        )
        # Per-input feature sizes (last dimension), as passed to
        # ``make_fusion_layer`` by every test below.
        self.input_dims = [self.fake_input_1.shape[-1], self.fake_input_2.shape[-1]]

    def test_reduce_fusion_layers(self):
        """Check that "max", "sum" and "prod" fuse element-wise.

        Each method must keep the feature size of the inputs as ``output_dim``.
        """
        expected_output_for_method = {
            "max": torch.tensor(
                [[[4, 2], [3, 4]], [[5, 6], [6, 5]], [[4, 3], [5, 2]]],
                dtype=torch.float32,
            ),
            "sum": torch.tensor(
                [[[5, 0], [6, 4]], [[5, 8], [10, 8]], [[7, 4], [7, 3]]],
                dtype=torch.float32,
            ),
            "prod": torch.tensor(
                [[[4, -4], [9, 0]], [[0, 12], [24, 15]], [[12, 3], [10, 2]]],
                dtype=torch.float32,
            ),
        }

        for method, expected_output in expected_output_for_method.items():
            # subTest reports each failing method separately instead of
            # aborting the loop at the first failure with no method name.
            with self.subTest(method=method):
                model = make_fusion_layer(method, self.input_dims)
                output = model([self.fake_input_1, self.fake_input_2])
                self.assertTrue(
                    torch.equal(output, expected_output),
                    f"{method!r} fusion produced {output}, expected {expected_output}",
                )
                self.assertEqual(model.output_dim, self.fake_input_1.shape[-1])

    def test_concat_fusion(self):
        """Check that "concat" joins the inputs along the last dimension.

        The resulting ``output_dim`` must be the sum of the input feature sizes.
        """
        model = make_fusion_layer("concat", self.input_dims)
        input_list = [self.fake_input_1, self.fake_input_2]

        output = model(input_list)
        expected_output = torch.cat(input_list, dim=-1)
        self.assertTrue(
            torch.equal(output, expected_output),
            f"concat fusion produced {output}, expected {expected_output}",
        )

        expected_shape = self.fake_input_1.shape[-1] + self.fake_input_2.shape[-1]
        self.assertEqual(model.output_dim, expected_shape)

    def test_temporal_concat_fusion(self):
        """Check that "temporal_concat" joins the inputs along dimension -2.

        The feature size (last dimension) is left unchanged.
        """
        model = make_fusion_layer("temporal_concat", self.input_dims)
        input_list = [self.fake_input_1, self.fake_input_2]

        output = model(input_list)
        expected_output = torch.cat(input_list, dim=-2)
        self.assertTrue(
            torch.equal(output, expected_output),
            f"temporal_concat fusion produced {output}, expected {expected_output}",
        )

        self.assertEqual(model.output_dim, self.fake_input_2.shape[-1])
|
|