| |
|
| | import copy
|
| | import unittest
|
| |
|
| | import numpy as np
|
| | import pytest
|
| | import torch
|
| | from mmengine.structures import InstanceData
|
| | from mmengine.testing import assert_dict_has_keys
|
| | from numpy.testing import assert_array_equal
|
| |
|
| | from mmaction.datasets.transforms import (FormatAudioShape, FormatGCNInput,
|
| | FormatShape, PackActionInputs,
|
| | Transpose)
|
| | from mmaction.registry import TRANSFORMS
|
| | from mmaction.structures import ActionDataSample
|
| | from mmaction.utils import register_all_modules
|
| |
|
| | register_all_modules()
|
| |
|
| |
|
| | class TestPackActionInputs(unittest.TestCase):
|
| |
|
| | def test_transform(self):
|
| |
|
| | with self.assertRaises(ValueError):
|
| | results = PackActionInputs()(dict())
|
| |
|
| |
|
| | results = dict(keypoint=np.random.randn(2, 300, 17, 3), label=1)
|
| | transform = PackActionInputs()
|
| | results = transform(results)
|
| | self.assertIn('inputs', results)
|
| | self.assertIn('data_samples', results)
|
| | self.assertIsInstance(results['inputs'], torch.Tensor)
|
| | self.assertEqual(results['inputs'].shape, (2, 300, 17, 3))
|
| | self.assertEqual(results['data_samples'].gt_label,
|
| | torch.LongTensor([1]))
|
| |
|
| |
|
| | results = dict(heatmap_imgs=np.random.randn(2, 17, 56, 56), label=1)
|
| | transform = PackActionInputs()
|
| | results = transform(results)
|
| | self.assertIn('inputs', results)
|
| | self.assertIn('data_samples', results)
|
| | self.assertIsInstance(results['inputs'], torch.Tensor)
|
| | self.assertEqual(results['inputs'].shape, (2, 17, 56, 56))
|
| | self.assertEqual(results['data_samples'].gt_label,
|
| | torch.LongTensor([1]))
|
| |
|
| |
|
| | results = dict(audios=np.random.randn(3, 1, 128, 80), label=[1])
|
| | transform = PackActionInputs()
|
| | results = transform(results)
|
| | self.assertIn('inputs', results)
|
| | self.assertIn('data_samples', results)
|
| | self.assertEqual(results['inputs'].shape, (3, 1, 128, 80))
|
| | self.assertIsInstance(results['inputs'], torch.Tensor)
|
| |
|
| |
|
| | results = dict(text=np.random.randn(77))
|
| | transform = PackActionInputs()
|
| | results = transform(results)
|
| | self.assertIn('inputs', results)
|
| | self.assertIn('data_samples', results)
|
| | self.assertEqual(results['inputs'].shape, (77, ))
|
| | self.assertIsInstance(results['inputs'], torch.Tensor)
|
| |
|
| |
|
| | data = dict(
|
| | imgs=np.random.randn(2, 256, 256, 3),
|
| | label=[1],
|
| | filename='test.txt',
|
| | original_shape=(256, 256, 3),
|
| | img_shape=(256, 256, 3),
|
| | flip_direction='vertical')
|
| |
|
| | transform = PackActionInputs()
|
| | results = transform(copy.deepcopy(data))
|
| | self.assertIn('inputs', results)
|
| | self.assertIn('data_samples', results)
|
| | self.assertIsInstance(results['inputs'], torch.Tensor)
|
| | self.assertIsInstance(results['data_samples'], ActionDataSample)
|
| | self.assertEqual(results['data_samples'].img_shape, (256, 256, 3))
|
| | self.assertEqual(results['data_samples'].gt_label,
|
| | torch.LongTensor([1]))
|
| |
|
| |
|
| | data['imgs'] = data['imgs'].mean(-1)
|
| | results = transform(copy.deepcopy(data))
|
| | self.assertIn('inputs', results)
|
| | self.assertIsInstance(results['inputs'], torch.Tensor)
|
| | self.assertEqual(results['inputs'].shape, (2, 256, 256))
|
| |
|
| |
|
| | data = dict(
|
| | imgs=np.random.randn(256, 256, 3),
|
| | gt_bboxes=np.array([[0, 0, 340, 224]]),
|
| | gt_labels=[1],
|
| | proposals=np.array([[0, 0, 340, 224]]),
|
| | filename='test.txt')
|
| |
|
| | transform = PackActionInputs()
|
| | results = transform(copy.deepcopy(data))
|
| | self.assertIn('inputs', results)
|
| | self.assertIsInstance(results['inputs'], torch.Tensor)
|
| | self.assertIn('data_samples', results)
|
| | self.assertIsInstance(results['data_samples'], ActionDataSample)
|
| | self.assertIsInstance(results['data_samples'].gt_instances,
|
| | InstanceData)
|
| | self.assertIsInstance(results['data_samples'].proposals, InstanceData)
|
| |
|
| |
|
| | data = dict(
|
| | imgs=np.random.randn(2, 256, 256, 3), text=np.random.randn(77))
|
| |
|
| | transform = PackActionInputs(collect_keys=('imgs', 'text'))
|
| | results = transform(copy.deepcopy(data))
|
| | self.assertIn('inputs', results)
|
| | self.assertIn('data_samples', results)
|
| | self.assertIsInstance(results['inputs'], dict)
|
| | self.assertEqual(results['inputs']['imgs'].shape, (2, 256, 256, 3))
|
| | self.assertEqual(results['inputs']['text'].shape, (77, ))
|
| |
|
| | def test_repr(self):
|
| | cfg = dict(
|
| | type='PackActionInputs', meta_keys=['flip_direction', 'img_shape'])
|
| | transform = TRANSFORMS.build(cfg)
|
| | self.assertEqual(
|
| | repr(transform), 'PackActionInputs(collect_keys=None, '
|
| | "meta_keys=['flip_direction', 'img_shape'])")
|
| |
|
| |
|
| | class TestPackLocalizationInputs(unittest.TestCase):
|
| |
|
| | def test_transform(self):
|
| |
|
| | data = dict(
|
| | raw_feature=np.random.randn(400, 5),
|
| | gt_bbox=np.array([[0.1, 0.3], [0.375, 0.625]]),
|
| | filename='test.txt')
|
| |
|
| | cfg = dict(type='PackLocalizationInputs', keys=('gt_bbox', ))
|
| | transform = TRANSFORMS.build(cfg)
|
| | results = transform(copy.deepcopy(data))
|
| | self.assertIn('inputs', results)
|
| | self.assertIsInstance(results['inputs'], torch.Tensor)
|
| | self.assertIn('data_samples', results)
|
| | self.assertIsInstance(results['data_samples'], ActionDataSample)
|
| | self.assertIsInstance(results['data_samples'].gt_instances,
|
| | InstanceData)
|
| |
|
| | del data['raw_feature']
|
| | with self.assertRaises(ValueError):
|
| | transform(copy.deepcopy(data))
|
| |
|
| |
|
| | data['bsp_feature'] = np.random.randn(100, 32)
|
| | results = transform(copy.deepcopy(data))
|
| | self.assertIn('inputs', results)
|
| | self.assertIsInstance(results['inputs'], torch.Tensor)
|
| | self.assertIn('data_samples', results)
|
| | self.assertIsInstance(results['data_samples'], ActionDataSample)
|
| | self.assertIsInstance(results['data_samples'].gt_instances,
|
| | InstanceData)
|
| |
|
| | def test_repr(self):
|
| | cfg = dict(
|
| | type='PackLocalizationInputs',
|
| | meta_keys=['video_name', 'feature_frame'])
|
| | transform = TRANSFORMS.build(cfg)
|
| | self.assertEqual(
|
| | repr(transform),
|
| | "PackLocalizationInputs(meta_keys=['video_name', 'feature_frame'])"
|
| | )
|
| |
|
| |
|
| | def test_transpose():
|
| | results = dict(imgs=np.random.randn(256, 256, 3))
|
| | keys = ['imgs']
|
| | order = [2, 0, 1]
|
| | transpose = Transpose(keys, order)
|
| | results = transpose(results)
|
| | assert results['imgs'].shape == (3, 256, 256)
|
| | assert repr(transpose) == transpose.__class__.__name__ + \
|
| | f'(keys={keys}, order={order})'
|
| |
|
| |
|
| | def test_format_shape():
|
| | with pytest.raises(ValueError):
|
| |
|
| | FormatShape('NHWC')
|
| |
|
| |
|
| | results = dict(
|
| | imgs=np.random.randn(3, 224, 224, 3), num_clips=1, clip_len=3)
|
| | format_shape = FormatShape('NCHW')
|
| | assert format_shape(results)['input_shape'] == (3, 3, 224, 224)
|
| |
|
| |
|
| | results = dict(
|
| | imgs=np.random.randn(3, 224, 224, 2),
|
| | num_clips=1,
|
| | clip_len=3,
|
| | modality='Flow')
|
| | format_shape = FormatShape('NCHW')
|
| | assert format_shape(results)['input_shape'] == (1, 6, 224, 224)
|
| |
|
| |
|
| | results = dict(
|
| | imgs=np.random.randn(3, 224, 224, 3), num_clips=1, clip_len=3)
|
| | format_shape = FormatShape('NCTHW')
|
| | assert format_shape(results)['input_shape'] == (1, 3, 3, 224, 224)
|
| |
|
| |
|
| | results = dict(
|
| | imgs=np.random.randn(18, 224, 224, 3), num_clips=2, clip_len=3)
|
| | assert format_shape(results)['input_shape'] == (6, 3, 3, 224, 224)
|
| | target_keys = ['imgs', 'input_shape']
|
| | assert assert_dict_has_keys(results, target_keys)
|
| |
|
| |
|
| | results = dict(
|
| | imgs=np.random.randn(6, 224, 224, 3),
|
| | heatmap_imgs=np.random.randn(12, 17, 56, 56),
|
| | num_clips=2,
|
| | clip_len=dict(RGB=3, Pose=6))
|
| |
|
| | results = format_shape(results)
|
| | assert results['input_shape'] == (2, 3, 3, 224, 224)
|
| | assert results['heatmap_input_shape'] == (2, 17, 6, 56, 56)
|
| |
|
| | assert repr(format_shape) == "FormatShape(input_format='NCTHW')"
|
| |
|
| |
|
| | results = dict(
|
| | imgs=np.random.randn(12, 17, 56, 56), num_clips=2, clip_len=6)
|
| | format_shape = FormatShape('NCTHW_Heatmap')
|
| | assert format_shape(results)['input_shape'] == (2, 17, 6, 56, 56)
|
| |
|
| |
|
| | results = dict(
|
| | imgs=np.random.randn(72, 224, 224, 3),
|
| | num_clips=9,
|
| | clip_len=1,
|
| | num_proposals=8)
|
| | format_shape = FormatShape('NPTCHW')
|
| | assert format_shape(results)['input_shape'] == (8, 9, 3, 224, 224)
|
| |
|
| |
|
| | def test_format_audio_shape():
|
| | with pytest.raises(ValueError):
|
| |
|
| | FormatAudioShape('XXXX')
|
| |
|
| |
|
| | results = dict(audios=np.random.randn(3, 128, 8))
|
| | format_shape = FormatAudioShape('NCTF')
|
| | assert format_shape(results)['input_shape'] == (3, 1, 128, 8)
|
| | assert repr(format_shape) == format_shape.__class__.__name__ + \
|
| | "(input_format='NCTF')"
|
| |
|
| |
|
| | def test_format_gcn_input():
|
| | with pytest.raises(AssertionError):
|
| | FormatGCNInput(mode='invalid')
|
| |
|
| | results = dict(
|
| | keypoint=np.random.randn(2, 10, 17, 2),
|
| | keypoint_score=np.random.randn(2, 10, 17))
|
| | format_shape = FormatGCNInput(num_person=2, mode='zero')
|
| | results = format_shape(results)
|
| | assert results['keypoint'].shape == (1, 2, 10, 17, 3)
|
| | assert repr(format_shape) == 'FormatGCNInput(num_person=2, mode=zero)'
|
| |
|
| | results = dict(keypoint=np.random.randn(2, 40, 25, 3), num_clips=4)
|
| | format_shape = FormatGCNInput(num_person=2, mode='zero')
|
| | results = format_shape(results)
|
| | assert results['keypoint'].shape == (4, 2, 10, 25, 3)
|
| |
|
| | results = dict(keypoint=np.random.randn(1, 10, 25, 3))
|
| | format_shape = FormatGCNInput(num_person=2, mode='zero')
|
| | results = format_shape(results)
|
| | assert results['keypoint'].shape == (1, 2, 10, 25, 3)
|
| | assert_array_equal(results['keypoint'][:, 1], np.zeros((1, 10, 25, 3)))
|
| |
|
| | results = dict(keypoint=np.random.randn(1, 10, 25, 3))
|
| | format_shape = FormatGCNInput(num_person=2, mode='loop')
|
| | results = format_shape(results)
|
| | assert results['keypoint'].shape == (1, 2, 10, 25, 3)
|
| | assert_array_equal(results['keypoint'][:, 1], results['keypoint'][:, 0])
|
| |
|
| | results = dict(keypoint=np.random.randn(3, 10, 25, 3))
|
| | format_shape = FormatGCNInput(num_person=2, mode='zero')
|
| | results = format_shape(results)
|
| | assert results['keypoint'].shape == (1, 2, 10, 25, 3)
|
| |
|