|
|
| import pytest
|
| import torch
|
| from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
|
|
|
| from mmaction.models import X3D
|
| from mmaction.testing import check_norm_state, generate_backbone_demo_inputs
|
|
|
|
|
def _check_forward_shape(model, input_shape, expected_shape):
    """Run ``model`` on demo inputs and assert the output feature shape.

    Args:
        model (torch.nn.Module): Backbone under test.
        input_shape (tuple[int]): Shape passed to
            ``generate_backbone_demo_inputs``.
        expected_shape (tuple[int]): Expected shape of the output feature.
    """
    imgs = generate_backbone_demo_inputs(input_shape)
    if torch.__version__ == 'parrots':
        # Under parrots the forward pass is only exercised on GPU
        # (presumably its 3D conv has no CPU kernel -- confirm); without
        # CUDA the shape check is skipped.
        if not torch.cuda.is_available():
            return
        # nn.Module.cuda() moves the module in place and returns it.
        model = model.cuda()
        imgs = imgs.cuda()
    feat = model(imgs)
    assert feat.shape == torch.Size(expected_shape)


def test_x3d_backbone():
    """Test x3d backbone."""
    # Invalid configurations must be rejected with an AssertionError at
    # construction time.
    invalid_cfgs = [
        dict(num_stages=0),
        dict(num_stages=5),
        # spatial_strides shorter than num_stages
        dict(spatial_strides=(1, 2), num_stages=4),
        dict(se_style=None),
        dict(se_style='half', se_ratio=0),
    ]
    for cfg in invalid_cfgs:
        with pytest.raises(AssertionError):
            X3D(gamma_w=1.0, gamma_b=2.25, gamma_d=2.2, **cfg)

    # With norm_eval=True the norm layers are expected to stay out of
    # training mode after train(); with norm_eval=False they follow it.
    # gamma_d=2.2 / 5.0 are the x3d_s / x3d_l depth settings.
    for norm_eval in (True, False):
        for gamma_d in (2.2, 5.0):
            x3d = X3D(
                gamma_w=1.0,
                gamma_b=2.25,
                gamma_d=gamma_d,
                norm_eval=norm_eval)
            x3d.init_weights()
            x3d.train()
            assert check_norm_state(x3d.modules(), not norm_eval)

    # frozen_stages=1 must freeze the stem (conv1_s, conv1_t) and layer1:
    # their BN layers stay in eval mode and their parameters need no grad.
    frozen_stages = 1
    x3d_s_frozen = X3D(
        gamma_w=1.0,
        gamma_b=2.25,
        gamma_d=2.2,
        norm_eval=False,
        frozen_stages=frozen_stages)
    x3d_s_frozen.init_weights()
    x3d_s_frozen.train()
    assert x3d_s_frozen.conv1_t.bn.training is False
    for stem in (x3d_s_frozen.conv1_s, x3d_s_frozen.conv1_t):
        for param in stem.parameters():
            assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        layer = getattr(x3d_s_frozen, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False

    # After init_weights(), the BN following each block's conv3 is expected
    # to be zero-initialized (presumably X3D's zero_init_residual default --
    # confirm against the X3D implementation).
    for m in x3d_s_frozen.modules():
        if hasattr(m, 'conv3'):
            assert torch.equal(m.conv3.bn.weight,
                               torch.zeros_like(m.conv3.bn.weight))
            assert torch.equal(m.conv3.bn.bias,
                               torch.zeros_like(m.conv3.bn.bias))

    # Forward pass: the temporal length is preserved and the spatial size
    # shrinks by 32x (64 -> 2, 96 -> 3).
    _check_forward_shape(
        x3d_s_frozen, (1, 3, 13, 64, 64), (1, 432, 13, 2, 2))
    _check_forward_shape(
        x3d_s_frozen, (1, 3, 16, 96, 96), (1, 432, 16, 3, 3))
|
|
|