mvsoom's picture
Upload folder using huggingface_hub
3133fdb
raw
history blame contribute delete
2.45 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
import torch
from pytorchvideo.layers import make_fusion_layer
class TestFusion(unittest.TestCase):
    """Unit tests for the layers returned by ``make_fusion_layer``.

    Every test fuses the same two fixture tensors and checks both the fused
    output and the ``output_dim`` attribute reported by the layer.
    """

    def setUp(self):
        """Seed torch's RNG and build the two float32 fixture tensors."""
        super().setUp()
        # ``torch.manual_seed`` already seeds the default generator, so there
        # is no need to read its state back and re-apply it.
        torch.manual_seed(42)
        self.fake_input_1 = torch.tensor(
            [[[4, -2], [3, 0]], [[0, 2], [4, 3]], [[3, 1], [5, 2]]],
            dtype=torch.float32,
        )
        self.fake_input_2 = torch.tensor(
            [[[1, 2], [3, 4]], [[5, 6], [6, 5]], [[4, 3], [2, 1]]],
            dtype=torch.float32,
        )

    def _assert_tensors_equal(self, actual, expected, context):
        """Fail with a readable message unless ``actual`` equals ``expected``.

        Args:
            actual: tensor produced by the layer under test.
            expected: tensor the layer is expected to produce.
            context: short label (the fusion method) included in the message.
        """
        self.assertTrue(
            torch.equal(actual, expected),
            msg=(
                f"{context!r} fusion produced an unexpected tensor:\n"
                f"actual={actual}\nexpected={expected}"
            ),
        )

    def test_reduce_fusion_layers(self):
        """Check the "max", "sum" and "prod" fusion methods.

        The expected tensors are the element-wise max, sum and product of the
        two fixtures; ``output_dim`` must equal the inputs' last dimension.
        """
        inputs = [self.fake_input_1, self.fake_input_2]
        input_dims = [tensor.shape[-1] for tensor in inputs]
        expected_output_for_method = {
            "max": torch.tensor(
                [[[4, 2], [3, 4]], [[5, 6], [6, 5]], [[4, 3], [5, 2]]],
                dtype=torch.float32,
            ),
            "sum": torch.tensor(
                [[[5, 0], [6, 4]], [[5, 8], [10, 8]], [[7, 4], [7, 3]]],
                dtype=torch.float32,
            ),
            "prod": torch.tensor(
                [[[4, -4], [9, 0]], [[0, 12], [24, 15]], [[12, 3], [10, 2]]],
                dtype=torch.float32,
            ),
        }
        for method, expected_output in expected_output_for_method.items():
            # subTest keeps the remaining methods running after a failure and
            # names the failing method in the report.
            with self.subTest(method=method):
                model = make_fusion_layer(method, input_dims)
                output = model(inputs)
                self._assert_tensors_equal(output, expected_output, method)
                self.assertEqual(model.output_dim, self.fake_input_1.shape[-1])

    def test_concat_fusion(self):
        """Check that "concat" fusion matches ``torch.cat`` on the last dim.

        ``output_dim`` must equal the sum of the inputs' last dimensions.
        """
        input_list = [self.fake_input_1, self.fake_input_2]
        input_dims = [tensor.shape[-1] for tensor in input_list]
        model = make_fusion_layer("concat", input_dims)
        output = model(input_list)
        expected_output = torch.cat(input_list, dim=-1)
        self._assert_tensors_equal(output, expected_output, "concat")
        self.assertEqual(model.output_dim, sum(input_dims))

    def test_temporal_concat_fusion(self):
        """Check that "temporal_concat" fusion matches ``torch.cat`` on dim -2.

        ``output_dim`` must equal the last dimension of the second input.
        """
        input_list = [self.fake_input_1, self.fake_input_2]
        input_dims = [tensor.shape[-1] for tensor in input_list]
        model = make_fusion_layer("temporal_concat", input_dims)
        output = model(input_list)
        expected_output = torch.cat(input_list, dim=-2)
        self._assert_tensors_equal(output, expected_output, "temporal_concat")
        self.assertEqual(model.output_dim, self.fake_input_2.shape[-1])