# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
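
"""Unit tests for PositionalEncoding and SpatioTemporalClsPositionalEncoding."""
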
import itertools
import math
import unittest

import torch
from pytorchvideo.layers import PositionalEncoding, SpatioTemporalClsPositionalEncoding


class TestPositionalEncoding(unittest.TestCase):
    def setUp(self):
        super().setUp()
        torch.set_rng_state(torch.manual_seed(42).get_state())
        self.batch_size = 4
        self.seq_len = 16
        self.feature_dim = 8
        self.fake_input = torch.randn(
            (self.batch_size, self.seq_len, self.feature_dim)
        ).float()
        lengths = torch.Tensor([16, 0, 14, 15, 16, 16, 16, 16])
        self.mask = torch.lt(
            torch.arange(self.seq_len)[None, :], lengths[:, None].long()
        )

    def test_positional_encoding(self):
        model = PositionalEncoding(self.feature_dim, self.seq_len)
        output = model(self.fake_input)
        delta = output - self.fake_input

        # Rebuild the expected sinusoidal table and check that it matches the
        # additive term produced by the layer.
        pe = torch.zeros(self.seq_len, self.feature_dim, dtype=torch.float)
        position = torch.arange(0, self.seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.feature_dim, 2).float()
            * (-math.log(10000.0) / self.feature_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        for n in range(0, self.batch_size):
            self.assertTrue(torch.allclose(delta[n], pe, atol=1e-6))

    def test_positional_encoding_with_different_pe_and_data_dimensions(self):
        """Test that the model executes even if the input data dimensions
        differ from the dimensions of the initialized positional encoding."""
        # When self.seq_len < positional_encoding_seq_len, the pe table is added
        # to the input.
        positional_encoding_seq_len = self.seq_len * 3
        model = PositionalEncoding(self.feature_dim, positional_encoding_seq_len)
        output = model(self.fake_input)
        delta = output - self.fake_input

        pe = torch.zeros(self.seq_len, self.feature_dim, dtype=torch.float)
        position = torch.arange(0, self.seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.feature_dim, 2).float()
            * (-math.log(10000.0) / self.feature_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        for n in range(0, self.batch_size):
            self.assertTrue(torch.allclose(delta[n], pe, atol=1e-6))

        # When self.seq_len > positional_encoding_seq_len, an AssertionError is raised.
        positional_encoding_seq_len = self.seq_len // 2
        model = PositionalEncoding(self.feature_dim, positional_encoding_seq_len)
        with self.assertRaises(AssertionError):
            output = model(self.fake_input)

    def test_SpatioTemporalClsPositionalEncoding(self):
        # Test with cls token.
        batch_dim = 4
        dim = 16
        video_shape = (1, 2, 4)
        video_sum = video_shape[0] * video_shape[1] * video_shape[2]
        has_cls = True
        model = SpatioTemporalClsPositionalEncoding(
            embed_dim=dim,
            patch_embed_shape=video_shape,
            has_cls=has_cls,
        )
        fake_input = torch.rand(batch_dim, video_sum, dim)
        output = model(fake_input)
        # The cls token adds one position to the sequence dimension.
        output_gt_shape = (batch_dim, video_sum + 1, dim)
        self.assertEqual(tuple(output.shape), output_gt_shape)

    def test_SpatioTemporalClsPositionalEncoding_nocls(self):
        # Test without cls token.
        batch_dim = 4
        dim = 16
        video_shape = (1, 2, 4)
        video_sum = video_shape[0] * video_shape[1] * video_shape[2]
        has_cls = False
        model = SpatioTemporalClsPositionalEncoding(
            embed_dim=dim,
            patch_embed_shape=video_shape,
            has_cls=has_cls,
        )
        fake_input = torch.rand(batch_dim, video_sum, dim)
        output = model(fake_input)
        output_gt_shape = (batch_dim, video_sum, dim)
        self.assertEqual(tuple(output.shape), output_gt_shape)

    def test_SpatioTemporalClsPositionalEncoding_mismatch(self):
        # A patch_embed_shape with the wrong number of dimensions raises an
        # AssertionError.
        with self.assertRaises(AssertionError):
            SpatioTemporalClsPositionalEncoding(
                embed_dim=16,
                patch_embed_shape=(1, 2),
            )

    def test_SpatioTemporalClsPositionalEncoding_scriptable(self):
        iter_embed_dim = [1, 2, 4, 32]
        iter_patch_embed_shape = [(1, 1, 1), (1, 2, 4), (32, 16, 1)]
        iter_sep_pos_embed = [True, False]
        iter_has_cls = [True, False]
        for (
            embed_dim,
            patch_embed_shape,
            sep_pos_embed,
            has_cls,
        ) in itertools.product(
            iter_embed_dim,
            iter_patch_embed_shape,
            iter_sep_pos_embed,
            iter_has_cls,
        ):
            stcpe = SpatioTemporalClsPositionalEncoding(
                embed_dim=embed_dim,
                patch_embed_shape=patch_embed_shape,
                sep_pos_embed=sep_pos_embed,
                has_cls=has_cls,
            )
            stcpe_scripted = torch.jit.script(stcpe)

            batch_dim = 4
            video_dim = math.prod(patch_embed_shape)
            fake_input = torch.rand(batch_dim, video_dim, embed_dim)
            expected = stcpe(fake_input)
            actual = stcpe_scripted(fake_input)
            torch.testing.assert_allclose(expected, actual)
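

if __name__ == "__main__":
    # Convenience entry point so the file can also be run directly with
    # `python test_layers_positional_encoding.py`; unittest/pytest discovery
    # works as well.
    unittest.main()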