# Copyright (c) OpenMMLab. All rights reserved.
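"""Unit tests for the PVT and PVTv2 backbones."""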
import pytest
import torch

from mmpose.models.backbones.pvt import (PVTEncoderLayer,
                                         PyramidVisionTransformer,
                                         PyramidVisionTransformerV2)


def test_pvt_block():
    # test PVT structure and forward
    block = PVTEncoderLayer(
        embed_dims=64, num_heads=4, feedforward_channels=256)
    assert block.ffn.embed_dims == 64
    assert block.attn.num_heads == 4
    assert block.ffn.feedforward_channels == 256
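    # the layer consumes flattened patch tokens of shape
    # (batch_size, H * W, embed_dims) together with the spatial shape (H, W)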
    x = torch.randn(1, 56 * 56, 64)
    x_out = block(x, (56, 56))
    assert x_out.shape == torch.Size([1, 56 * 56, 64])


def test_pvt():
    """Test PVT backbone."""

    with pytest.raises(TypeError):
        # Pretrained arg must be str or None.
        PyramidVisionTransformer(pretrained=123)

    # test invalid pretrain_img_size (must be an int or a pair of ints)
    with pytest.raises(AssertionError):
        PyramidVisionTransformer(pretrain_img_size=(224, 224, 224))

    # test 'corner' padding mode
    model = PyramidVisionTransformer(
        paddings=['corner', 'corner', 'corner', 'corner'])
    temp = torch.randn((1, 3, 32, 32))
    outs = model(temp)
    assert outs[0].shape == (1, 64, 8, 8)
    assert outs[1].shape == (1, 128, 4, 4)
    assert outs[2].shape == (1, 320, 2, 2)
    assert outs[3].shape == (1, 512, 1, 1)

    # Test absolute position embedding
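    # the position embeddings are sized for the 224x224 pretraining
    # resolution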
    temp = torch.randn((1, 3, 224, 224))
    model = PyramidVisionTransformer(
        pretrain_img_size=224, use_abs_pos_embed=True)
    model.init_weights()
    model(temp)

    # Test normal inference
    temp = torch.randn((1, 3, 32, 32))
    model = PyramidVisionTransformer()
    outs = model(temp)
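    # the four stages downsample the input by 4x, 8x, 16x and 32x and
    # produce embedding dims (64, 128, 320, 512)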
    assert outs[0].shape == (1, 64, 8, 8)
    assert outs[1].shape == (1, 128, 4, 4)
    assert outs[2].shape == (1, 320, 2, 2)
    assert outs[3].shape == (1, 512, 1, 1)

    # Test abnormal inference size (not divisible by the patch size)
    temp = torch.randn((1, 3, 33, 33))
    model = PyramidVisionTransformer()
    outs = model(temp)
    assert outs[0].shape == (1, 64, 8, 8)
    assert outs[1].shape == (1, 128, 4, 4)
    assert outs[2].shape == (1, 320, 2, 2)
    assert outs[3].shape == (1, 512, 1, 1)

    # Test non-square abnormal inference size
    temp = torch.randn((1, 3, 112, 137))
    model = PyramidVisionTransformer()
    outs = model(temp)
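    # PVT's non-overlapping patch embedding floors odd spatial sizes,
    # e.g. 137 // 4 = 34 after the first stage and 17 // 2 = 8 after the
    # third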
    assert outs[0].shape == (1, 64, 28, 34)
    assert outs[1].shape == (1, 128, 14, 17)
    assert outs[2].shape == (1, 320, 7, 8)
    assert outs[3].shape == (1, 512, 3, 4)


def test_pvtv2():
    """Test PVTv2 backbone."""

    with pytest.raises(TypeError):
        # Pretrained arg must be str or None.
        PyramidVisionTransformerV2(pretrained=123)

    # test invalid pretrain_img_size (must be an int or a pair of ints)
    with pytest.raises(AssertionError):
        PyramidVisionTransformerV2(pretrain_img_size=(224, 224, 224))

    # test loading pretrained weights
    model = PyramidVisionTransformerV2(
        embed_dims=32,
        num_layers=[2, 2, 2, 2],
        pretrained='https://github.com/whai362/PVT/'
        'releases/download/v2/pvt_v2_b0.pth')
    model.init_weights()

    # test init_cfg
    model = PyramidVisionTransformerV2(
        embed_dims=32,
        num_layers=[2, 2, 2, 2],
        init_cfg=dict(checkpoint='https://github.com/whai362/PVT/'
                      'releases/download/v2/pvt_v2_b0.pth'))
    model.init_weights()

    # test init weights from scratch
    model = PyramidVisionTransformerV2(embed_dims=32, num_layers=[2, 2, 2, 2])
    model.init_weights()

    # Test normal inference
    temp = torch.randn((1, 3, 32, 32))
    model = PyramidVisionTransformerV2()
    outs = model(temp)
    assert outs[0].shape == (1, 64, 8, 8)
    assert outs[1].shape == (1, 128, 4, 4)
    assert outs[2].shape == (1, 320, 2, 2)
    assert outs[3].shape == (1, 512, 1, 1)

    # Test abnormal inference size (not divisible by the patch size)
    temp = torch.randn((1, 3, 31, 31))
    model = PyramidVisionTransformerV2()
    outs = model(temp)
    assert outs[0].shape == (1, 64, 8, 8)
    assert outs[1].shape == (1, 128, 4, 4)
    assert outs[2].shape == (1, 320, 2, 2)
    assert outs[3].shape == (1, 512, 1, 1)

    # Test non-square abnormal inference size
    temp = torch.randn((1, 3, 112, 137))
    model = PyramidVisionTransformerV2()
    outs = model(temp)
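    # PVTv2's overlapping patch embedding pads the input, so odd spatial
    # sizes round up instead, e.g. ceil(137 / 4) = 35 after the first stage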
    assert outs[0].shape == (1, 64, 28, 35)
    assert outs[1].shape == (1, 128, 14, 18)
    assert outs[2].shape == (1, 320, 7, 9)
    assert outs[3].shape == (1, 512, 4, 5)