| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Tests for model.py.""" |
|
|
| import functools |
| from absl.testing import absltest |
| from absl.testing import parameterized |
| from jax import random |
| import jax.numpy as jnp |
| from scenic.projects.token_learner import model |
|
|
|
|
class TokenLearnerTest(parameterized.TestCase):
  """Shape checks for the modules defined in the token-learner model.py."""

  @parameterized.named_parameters(
      ('32_tokens', 32),
      ('111_tokens', 111),
  )
  def test_dynamic_tokenizer(self, num_tokens):
    """Checks the output shape of TokenLearnerModule.

    Args:
      num_tokens: Number of tokens the module is configured to produce.
    """
    inputs = jnp.ones((4, 224, 224, 64))

    def make_tokenizer():
      # A fresh module instance is built for init and again for apply.
      return model.TokenLearnerModule(num_tokens=num_tokens)

    variables = make_tokenizer().init(random.PRNGKey(0), inputs)
    tokens = make_tokenizer().apply(variables, inputs)

    # Leading (batch) and trailing (feature) sizes are carried over from the
    # input; the middle axis must hold exactly `num_tokens` entries.
    self.assertEqual(
        tokens.shape, (inputs.shape[0], num_tokens, inputs.shape[-1]))

  @parameterized.named_parameters(
      ('encoder_image', (2, 16, 192), 'dynamic', 1, 8, model.EncoderMod),
      ('encoder_video_temporal_dims_1',
       (2, 16, 192), 'video', 1, 8, model.EncoderMod),
      ('encoder_video_temporal_dims_2',
       (2, 32, 192), 'video', 2, 8, model.EncoderMod),
      ('encoder_video_temporal_dims_4',
       (2, 64, 192), 'video', 4, 8, model.EncoderMod),
      ('encoder_fusion_image',
       (2, 16, 192), 'dynamic', 1, 8, model.EncoderModFuser),
      ('encoder_fusion_video_temporal_dims_1',
       (2, 16, 192), 'video', 1, 8, model.EncoderModFuser),
      ('encoder_fusion_video_temporal_dims_2',
       (2, 32, 192), 'video', 2, 8, model.EncoderModFuser),
      ('encoder_fusion_video_temporal_dims_4',
       (2, 64, 192), 'video', 4, 8, model.EncoderModFuser),
  )
  def test_encoder(self, input_shape, tokenizer_type,
                   temporal_dimensions, num_tokens, encoder_function):
    """Checks encoder output shapes, with and without TokenFuser.

    Args:
      input_shape: Shape of the all-ones array fed to the encoder.
      tokenizer_type: Tokenizer variant passed to the encoder; this test
        handles 'dynamic' and 'video'.
      temporal_dimensions: Value forwarded as `temporal_dimensions`; for
        'video' it multiplies the expected token count.
      num_tokens: Value forwarded as `num_tokens`.
      encoder_function: Encoder class under test, either `model.EncoderMod`
        or `model.EncoderModFuser`.

    Raises:
      ValueError: If the tokenizer type or the encoder class is not one of
        the cases this test knows how to check.
    """
    encoder_kwargs = dict(
        num_layers=3,
        mlp_dim=192,
        num_heads=3,
        tokenizer_type=tokenizer_type,
        temporal_dimensions=temporal_dimensions,
        num_tokens=num_tokens,
        tokenlearner_loc=2,
    )
    inputs = jnp.ones(input_shape)
    variables = encoder_function(**encoder_kwargs).init(
        random.PRNGKey(0), inputs)
    outputs = encoder_function(**encoder_kwargs).apply(variables, inputs)

    without_fuser = encoder_function == model.EncoderMod
    if without_fuser and tokenizer_type == 'dynamic':
      expected_shape = (input_shape[0], num_tokens, input_shape[2])
    elif without_fuser and tokenizer_type == 'video':
      # The expected token count scales with the temporal dimensions.
      expected_shape = (
          input_shape[0], num_tokens * temporal_dimensions, input_shape[2])
    elif without_fuser:
      raise ValueError('Unknown tokenizer type.')
    elif encoder_function == model.EncoderModFuser:
      # With the fuser, the output is expected to match the input shape.
      expected_shape = input_shape
    else:
      raise ValueError('Unknown encoder function.')

    self.assertEqual(outputs.shape, expected_shape)
|
|
|
|
if __name__ == '__main__':
  # Run the tests through absl's runner when this file is executed directly.
  absltest.main()
|
|