File size: 3,829 Bytes
3bbb319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.nn.modules.batchnorm import _BatchNorm

from mmpose.models.backbones import HRNet
from mmpose.models.backbones.hrnet import HRModule
from mmpose.models.backbones.resnet import BasicBlock, Bottleneck


def is_block(modules):
    """Return whether ``modules`` is an ``HRModule`` building block.

    Args:
        modules (nn.Module): The module to inspect.

    Returns:
        bool: True for ``HRModule`` instances (or subclasses), else False.
    """
    block_types = (HRModule, )
    return isinstance(modules, block_types)


def is_norm(modules):
    """Return whether ``modules`` is a batch-normalization layer.

    Any subclass of ``_BatchNorm`` (e.g. ``BatchNorm2d``) qualifies.

    Args:
        modules (nn.Module): The module to inspect.

    Returns:
        bool: True for ``_BatchNorm`` instances, else False.
    """
    norm_types = (_BatchNorm, )
    return isinstance(modules, norm_types)


def all_zeros(modules):
    """Check whether a module's weight (and bias, if any) is all zeros.

    Args:
        modules (nn.Module): A module exposing a ``weight`` parameter,
            e.g. a norm or conv layer.

    Returns:
        bool: True if ``weight`` is entirely zero and ``bias`` is either
        absent, ``None``, or entirely zero; otherwise False.
    """
    weight = modules.weight.data
    weight_zero = torch.equal(weight, torch.zeros_like(weight))

    # Layers built without a bias (e.g. ``Conv2d(bias=False)``) still have
    # a ``bias`` attribute whose value is ``None``, so ``hasattr`` alone is
    # not enough: dereferencing ``.data`` on it would raise AttributeError.
    # Treat a missing or ``None`` bias as trivially zero.
    bias = getattr(modules, 'bias', None)
    if bias is None:
        bias_zero = True
    else:
        bias_zero = torch.equal(bias.data, torch.zeros_like(bias.data))

    return weight_zero and bias_zero


def test_hrmodule():
    """A single-branch HRModule must preserve its input's shape."""
    batch, channels, height, width = 2, 64, 56, 56
    module = HRModule(
        num_branches=1,
        blocks=BasicBlock,
        num_blocks=(4, ),
        in_channels=[channels],
        num_channels=(channels, ))

    inputs = [torch.randn(batch, channels, height, width)]
    outputs = module(inputs)
    expected_shape = torch.Size([batch, channels, height, width])
    assert outputs[0].shape == expected_shape


def _assert_frozen(module):
    """Assert every ``_BatchNorm`` in ``module`` is in eval mode and every
    parameter of ``module`` has ``requires_grad`` disabled."""
    for mod in module.modules():
        if isinstance(mod, _BatchNorm):
            assert mod.training is False
    for param in module.parameters():
        assert param.requires_grad is False


def test_hrnet_backbone():
    """Exercise HRNet forward, zero-init residuals, and stage freezing."""
    extra = dict(
        stage1=dict(
            num_modules=1,
            num_branches=1,
            block='BOTTLENECK',
            num_blocks=(4, ),
            num_channels=(64, )),
        stage2=dict(
            num_modules=1,
            num_branches=2,
            block='BASIC',
            num_blocks=(4, 4),
            num_channels=(32, 64)),
        stage3=dict(
            num_modules=4,
            num_branches=3,
            block='BASIC',
            num_blocks=(4, 4, 4),
            num_channels=(32, 64, 128)),
        stage4=dict(
            num_modules=3,
            num_branches=4,
            block='BASIC',
            num_blocks=(4, 4, 4, 4),
            num_channels=(32, 64, 128, 256)))

    model = HRNet(extra, in_channels=3)

    imgs = torch.randn(2, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 1
    assert feat[0].shape == torch.Size([2, 32, 56, 56])

    # Test HRNet zero initialization of residual
    model = HRNet(extra, in_channels=3, zero_init_residual=True)
    model.init_weights()
    for m in model.modules():
        if isinstance(m, Bottleneck):
            assert all_zeros(m.norm3)
    model.train()

    imgs = torch.randn(2, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 1
    assert feat[0].shape == torch.Size([2, 32, 56, 56])

    # Test HRNet with the first three stages frozen
    frozen_stages = 3
    model = HRNet(extra, in_channels=3, frozen_stages=frozen_stages)
    model.init_weights()
    # train() is called so the test verifies frozen BN layers stay in eval
    # mode even after the model is switched to training mode.
    model.train()
    if frozen_stages >= 0:
        assert model.norm1.training is False
        assert model.norm2.training is False
        for layer in [model.conv1, model.norm1, model.conv2, model.norm2]:
            for param in layer.parameters():
                assert param.requires_grad is False

    for i in range(1, frozen_stages + 1):
        # Stage 1 is registered as ``layer1``; later stages as ``stage{i}``.
        stage_name = 'layer1' if i == 1 else f'stage{i}'
        _assert_frozen(getattr(model, stage_name))

        # Only stages below 4 have a ``transition{i}`` to check. These
        # checks must live inside the ``if``; otherwise, for i == 4, the
        # stage module would be silently re-checked instead.
        if i < 4:
            _assert_frozen(getattr(model, f'transition{i}'))