File size: 1,715 Bytes
3bbb319 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from addict import Dict
from mmpose.models.detectors import GestureRecognizer
def test_gesture_recognizer_forward():
    """Run GestureRecognizer.forward with and without loss computation.

    A recognizer is built from an rgb and a depth I3D backbone config plus a
    MultiModalSSAHead config. One random clip per modality is then pushed
    through ``forward`` twice: once with ``return_loss=True`` (the result
    must be a dict containing ``'ssa_loss'``) and once with
    ``return_loss=False`` under ``torch.no_grad()``.
    """
    # The two modalities differ only in their input channel count.
    channels_per_modality = (('rgb', 3), ('depth', 1))

    backbone_cfg = {
        name: dict(type='I3D', in_channels=channels, expansion=0.25)
        for name, channels in channels_per_modality
    }
    head_cfg = dict(
        type='MultiModalSSAHead',
        num_classes=25,
        avg_pool_kernel=(1, 2, 2),
        in_channels=256,
    )
    train_cfg = dict(beta=2, lambda_=1e-3, ssa_start_epoch=10)
    test_cfg = dict()

    # Arguments are passed positionally, in the same order as before.
    # NOTE(review): the second (None) and last (None, "pretrained")
    # positions are presumably the neck and the checkpoint path -- confirm
    # against the GestureRecognizer signature.
    recognizer = GestureRecognizer(
        backbone_cfg,
        None,
        head_cfg,
        train_cfg,
        test_cfg,
        ['rgb', 'depth'],
        None,
    )

    # Epoch 11 is past ssa_start_epoch=10; presumably this is what makes the
    # 'ssa_loss' entry asserted below appear -- confirm against the head.
    recognizer.set_train_epoch(11)

    # One random clip per modality, rgb first and depth second.
    # NOTE(review): layout looks like (batch, channels, frames, H, W) --
    # confirm against the backbone.
    clips = [
        torch.randn(1, channels, 16, 112, 112)
        for _, channels in channels_per_modality
    ]
    target = torch.tensor([1]).long()

    metas = Dict()
    metas.data = dict(modality=['rgb', 'depth'])

    # Training-mode forward: expect a dict of losses with the SSA term.
    train_out = recognizer.forward(clips, target, metas, return_loss=True)
    assert isinstance(train_out, dict)
    assert 'ssa_loss' in train_out

    # Inference-mode forward: only checked for running without error.
    with torch.no_grad():
        recognizer.forward(
            clips, target, img_metas=metas, return_loss=False)
|