File size: 1,515 Bytes
d3dbf03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmaction.models import AAGCN
from mmaction.utils import register_all_modules


def test_aagcn_backbone():
    """Check the AAGCN backbone's output shape on several skeleton layouts.

    For each layout a fresh ``AAGCN`` is built with the 'spatial' graph
    mode, its weights are initialised, and a random input of shape
    ``(batch_size, num_person, num_frames, num_joints, 3)`` is fed through
    it. The output is asserted to have shape
    ``(2, 2, 256, 38, num_joints)``. Finally the same check is repeated on
    the coco layout with ``gcn_attention=False``.
    """
    register_all_modules()

    graph_mode = 'spatial'
    batch_size, num_person, num_frames = 2, 2, 150

    # (layout name, number of joints fed to the model for that layout)
    layout_specs = (
        ('openpose', 18),  # openpose-18 layout
        ('nturgb+d', 25),
        ('coco', 17),
    )

    coco_inputs = None
    for layout, num_joints in layout_specs:
        backbone = AAGCN(graph_cfg=dict(layout=layout, mode=graph_mode))
        backbone.init_weights()
        input_shape = (batch_size, num_person, num_frames, num_joints, 3)
        skeletons = torch.randn(*input_shape)
        feats = backbone(skeletons)
        assert feats.shape == torch.Size([2, 2, 256, 38, num_joints])
        if layout == 'coco':
            # kept for the attention-free check below
            coco_inputs = skeletons

    # custom settings
    # disable the attention module to degenerate AAGCN to AGCN
    plain_cfg = dict(layout='coco', mode=graph_mode)
    backbone = AAGCN(graph_cfg=plain_cfg, gcn_attention=False)
    backbone.init_weights()
    feats = backbone(coco_inputs)
    assert feats.shape == torch.Size([2, 2, 256, 38, 17])