File size: 1,715 Bytes
3bbb319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Copyright (c) OpenMMLab. All rights reserved.

import torch
from addict import Dict

from mmpose.models.detectors import GestureRecognizer


def test_gesture_recognizer_forward():
    """Smoke-test ``GestureRecognizer`` on a two-stream (RGB + depth) clip.

    An I3D backbone is configured per modality, with a ``MultiModalSSAHead``
    classifier on top. The training epoch is set to 11, which is beyond
    ``ssa_start_epoch=10``. A training forward pass is then expected to
    return a loss dict containing ``'ssa_loss'``. An inference forward pass
    under ``torch.no_grad()`` is only checked for running without error.
    """

    def i3d_stream(channels):
        # Both modality streams share the same I3D settings; only the number
        # of input channels differs (3 for RGB, 1 for depth).
        return dict(type='I3D', in_channels=channels, expansion=0.25)

    head_cfg = dict(
        type='MultiModalSSAHead',
        num_classes=25,
        avg_pool_kernel=(1, 2, 2),
        in_channels=256)

    model_cfg = dict(
        type='GestureRecognizer',
        pretrained=None,
        modality=['rgb', 'depth'],
        backbone=dict(rgb=i3d_stream(3), depth=i3d_stream(1)),
        cls_head=head_cfg,
        train_cfg=dict(beta=2, lambda_=1e-3, ssa_start_epoch=10),
        test_cfg=dict(),
    )

    detector = GestureRecognizer(
        model_cfg['backbone'],
        None,  # presumably the neck slot; this config defines no neck
        model_cfg['cls_head'],
        model_cfg['train_cfg'],
        model_cfg['test_cfg'],
        model_cfg['modality'],
        model_cfg['pretrained'],
    )
    # Move past ssa_start_epoch (10) so the SSA loss term is expected below.
    detector.set_train_epoch(11)

    # One clip per modality, drawn in the same order as the backbone streams.
    rgb_clip = torch.randn(1, 3, 16, 112, 112)
    depth_clip = torch.randn(1, 1, 16, 112, 112)
    video = [rgb_clip, depth_clip]
    labels = torch.tensor([1]).long()

    # Assign via attribute rather than Dict(data=...) so the inner mapping
    # stays a plain dict (addict's constructor would convert it).
    img_metas = Dict()
    img_metas.data = dict(modality=['rgb', 'depth'])

    # Training pass: must yield a loss dict that includes the SSA term.
    losses = detector.forward(video, labels, img_metas, return_loss=True)
    assert isinstance(losses, dict)
    assert 'ssa_loss' in losses

    # Inference pass: result is not inspected; it only has to run.
    with torch.no_grad():
        detector.forward(
            video, labels, img_metas=img_metas, return_loss=False)