import itertools
import math
import unittest

import torch
from pytorchvideo.layers import PositionalEncoding, SpatioTemporalClsPositionalEncoding


class TestPositionalEncoding(unittest.TestCase):
    def setUp(self):
        super().setUp()
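        # Fix the RNG state so the randomly generated test inputs are reproducible.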
        torch.set_rng_state(torch.manual_seed(42).get_state())

        self.batch_size = 4
        self.seq_len = 16
        self.feature_dim = 8
        self.fake_input = torch.randn(
            (self.batch_size, self.seq_len, self.feature_dim)
        ).float()
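        # Boolean mask of shape (len(lengths), seq_len): entry [n, t] is True
        # when position t lies within the valid length lengths[n].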
        lengths = torch.Tensor([16, 0, 14, 15, 16, 16, 16, 16])
        self.mask = torch.lt(
            torch.arange(self.seq_len)[None, :], lengths[:, None].long()
        )

    def test_positional_encoding(self):
        model = PositionalEncoding(self.feature_dim, self.seq_len)
        output = model(self.fake_input)
        delta = output - self.fake_input
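
        # Reference sinusoidal table (Vaswani et al., 2017):
        #   pe[pos, 2i]     = sin(pos / 10000^(2i / feature_dim))
        #   pe[pos, 2i + 1] = cos(pos / 10000^(2i / feature_dim))
        # PositionalEncoding adds this table to its input, so `delta` should
        # reproduce it for every sequence in the batch.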
        pe = torch.zeros(self.seq_len, self.feature_dim, dtype=torch.float)
        position = torch.arange(0, self.seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.feature_dim, 2).float()
            * (-math.log(10000.0) / self.feature_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        for n in range(0, self.batch_size):
            self.assertTrue(torch.allclose(delta[n], pe, atol=1e-6))

    def test_positional_encoding_with_different_pe_and_data_dimensions(self):
        """Test that the model executes even when the input sequence length
        differs from the sequence length used to initialize the positional
        encoding module."""

        positional_encoding_seq_len = self.seq_len * 3
        model = PositionalEncoding(self.feature_dim, positional_encoding_seq_len)
        output = model(self.fake_input)
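
        # A positional-encoding table longer than the input sequence is fine:
        # only the first seq_len positions are added, so `delta` should still
        # match the reference table.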
        delta = output - self.fake_input
        pe = torch.zeros(self.seq_len, self.feature_dim, dtype=torch.float)
        position = torch.arange(0, self.seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.feature_dim, 2).float()
            * (-math.log(10000.0) / self.feature_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        for n in range(0, self.batch_size):
            self.assertTrue(torch.allclose(delta[n], pe, atol=1e-6))
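
        # A positional-encoding table shorter than the input cannot cover all
        # positions, so the forward pass is expected to fail.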
        positional_encoding_seq_len = self.seq_len // 2
        model = PositionalEncoding(self.feature_dim, positional_encoding_seq_len)
        with self.assertRaises(AssertionError):
            output = model(self.fake_input)

    def test_SpatioTemporalClsPositionalEncoding(self):
        batch_dim = 4
        dim = 16
        video_shape = (1, 2, 4)
        video_sum = video_shape[0] * video_shape[1] * video_shape[2]
        has_cls = True
        model = SpatioTemporalClsPositionalEncoding(
            embed_dim=dim,
            patch_embed_shape=video_shape,
            has_cls=has_cls,
        )
        fake_input = torch.rand(batch_dim, video_sum, dim)
        output = model(fake_input)
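        # The cls token adds one extra position, so the sequence dimension
        # grows from video_sum to video_sum + 1.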
        output_gt_shape = (batch_dim, video_sum + 1, dim)
        self.assertEqual(tuple(output.shape), output_gt_shape)

    def test_SpatioTemporalClsPositionalEncoding_nocls(self):
        batch_dim = 4
        dim = 16
        video_shape = (1, 2, 4)
        video_sum = video_shape[0] * video_shape[1] * video_shape[2]
        has_cls = False
        model = SpatioTemporalClsPositionalEncoding(
            embed_dim=dim,
            patch_embed_shape=video_shape,
            has_cls=has_cls,
        )
        fake_input = torch.rand(batch_dim, video_sum, dim)
        output = model(fake_input)
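        # Without a cls token, the sequence length is unchanged.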
        output_gt_shape = (batch_dim, video_sum, dim)
        self.assertEqual(tuple(output.shape), output_gt_shape)

    def test_SpatioTemporalClsPositionalEncoding_mismatch(self):
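        # patch_embed_shape must describe a 3-dimensional patch grid; a
        # 2-element shape should be rejected at construction time.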
        with self.assertRaises(AssertionError):
            SpatioTemporalClsPositionalEncoding(
                embed_dim=16,
                patch_embed_shape=(1, 2),
            )

    def test_SpatioTemporalClsPositionalEncoding_scriptable(self):
        iter_embed_dim = [1, 2, 4, 32]
        iter_patch_embed_shape = [(1, 1, 1), (1, 2, 4), (32, 16, 1)]
        iter_sep_pos_embed = [True, False]
        iter_has_cls = [True, False]
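
        # Sweep all combinations of embed_dim, patch_embed_shape, sep_pos_embed,
        # and has_cls, and check that the TorchScript-compiled module matches
        # eager execution numerically.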
        for (
            embed_dim,
            patch_embed_shape,
            sep_pos_embed,
            has_cls,
        ) in itertools.product(
            iter_embed_dim,
            iter_patch_embed_shape,
            iter_sep_pos_embed,
            iter_has_cls,
        ):
            stcpe = SpatioTemporalClsPositionalEncoding(
                embed_dim=embed_dim,
                patch_embed_shape=patch_embed_shape,
                sep_pos_embed=sep_pos_embed,
                has_cls=has_cls,
            )
            stcpe_scripted = torch.jit.script(stcpe)
            batch_dim = 4
            video_dim = math.prod(patch_embed_shape)
            fake_input = torch.rand(batch_dim, video_dim, embed_dim)
            expected = stcpe(fake_input)
            actual = stcpe_scripted(fake_input)
            torch.testing.assert_close(expected, actual)