File size: 2,452 Bytes
3133fdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import unittest

import torch
from pytorchvideo.layers import make_fusion_layer


class TestFusion(unittest.TestCase):
    """Unit tests for the layers returned by ``make_fusion_layer``.

    Every test fuses the same pair of fixed ``(3, 2, 2)`` float32 tensors and
    checks both the fused values and the ``output_dim`` the layer reports.
    """

    def setUp(self):
        """Seed torch's RNG and build the two input tensors shared by all tests."""
        super().setUp()
        # manual_seed() reseeds the default generator in place, so there is no
        # need to read its state back and re-apply it.
        torch.manual_seed(42)

        self.fake_input_1 = torch.tensor(
            [[[4, -2], [3, 0]], [[0, 2], [4, 3]], [[3, 1], [5, 2]]],
            dtype=torch.float32,
        )
        self.fake_input_2 = torch.tensor(
            [[[1, 2], [3, 4]], [[5, 6], [6, 5]], [[4, 3], [2, 1]]],
            dtype=torch.float32,
        )

    def _make_layer(self, method):
        """Build a fusion layer for ``method`` sized by both inputs' last dim."""
        feature_dims = [self.fake_input_1.shape[-1], self.fake_input_2.shape[-1]]
        return make_fusion_layer(method, feature_dims)

    def test_reduce_fusion_layers(self):
        """"max", "sum" and "prod" combine the inputs element-wise.

        The reported ``output_dim`` is expected to equal the inputs' last
        dimension for each of these methods.
        """
        expected_output_for_method = {
            "max": torch.tensor(
                [[[4, 2], [3, 4]], [[5, 6], [6, 5]], [[4, 3], [5, 2]]],
                dtype=torch.float32,
            ),
            "sum": torch.tensor(
                [[[5, 0], [6, 4]], [[5, 8], [10, 8]], [[7, 4], [7, 3]]],
                dtype=torch.float32,
            ),
            "prod": torch.tensor(
                [[[4, -4], [9, 0]], [[0, 12], [24, 15]], [[12, 3], [10, 2]]],
                dtype=torch.float32,
            ),
        }

        for method, expected_output in expected_output_for_method.items():
            # subTest labels a failure with the offending method and lets the
            # remaining methods still be checked.
            with self.subTest(method=method):
                model = self._make_layer(method)
                output = model([self.fake_input_1, self.fake_input_2])
                self.assertTrue(
                    torch.equal(output, expected_output),
                    msg=f"{method!r} fusion returned {output}, "
                    f"expected {expected_output}",
                )
                self.assertEqual(model.output_dim, self.fake_input_1.shape[-1])

    def test_concat_fusion(self):
        """"concat" joins the inputs along the last dimension.

        The reported ``output_dim`` is expected to be the sum of the inputs'
        last dimensions.
        """
        model = self._make_layer("concat")
        input_list = [self.fake_input_1, self.fake_input_2]
        output = model(input_list)

        expected_output = torch.cat(input_list, dim=-1)
        self.assertTrue(
            torch.equal(output, expected_output),
            msg=f"'concat' fusion returned {output}, expected {expected_output}",
        )

        expected_dim = self.fake_input_1.shape[-1] + self.fake_input_2.shape[-1]
        self.assertEqual(model.output_dim, expected_dim)

    def test_temporal_concat_fusion(self):
        """"temporal_concat" joins the inputs along the second-to-last dimension.

        The last dimension is untouched by that concatenation, so the reported
        ``output_dim`` is expected to equal an input's last dimension.
        """
        model = self._make_layer("temporal_concat")
        input_list = [self.fake_input_1, self.fake_input_2]
        output = model(input_list)

        expected_output = torch.cat(input_list, dim=-2)
        self.assertTrue(
            torch.equal(output, expected_output),
            msg=f"'temporal_concat' fusion returned {output}, "
            f"expected {expected_output}",
        )
        self.assertEqual(model.output_dim, self.fake_input_2.shape[-1])