File size: 1,482 Bytes
985cc7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import mmcv
import pytest
import torch

from mmseg.models.utils.se_layer import SELayer


def test_se_layer():
    """Check ``SELayer`` construction and a forward pass for two configs.

    Three things are verified:

    * passing a one-element tuple as ``act_cfg`` raises ``AssertionError``;
    * ``SELayer(16)`` builds two 1x1, stride-1, zero-padding convs whose
      activations are ``ReLU`` (conv1) and ``HSigmoid`` (conv2);
    * ``SELayer(16, act_cfg=dict(type='ReLU'))`` builds the same convs with
      ``ReLU`` on both.

    For each valid layer a random ``(1, 16, 64, 64)`` input must come back
    with the same shape.
    """
    # test act_cfg assertion.
    with pytest.raises(AssertionError):
        SELayer(32, act_cfg=(dict(type='ReLU'), ))

    def _verify(layer, gate_act_cls):
        """Check both convs of ``layer`` and run one forward pass.

        ``gate_act_cls`` is the activation class expected on ``conv2``;
        ``conv1`` is expected to use ``torch.nn.ReLU`` in every case.
        """
        expected = (('conv1', torch.nn.ReLU), ('conv2', gate_act_cls))
        for attr_name, act_cls in expected:
            branch = getattr(layer, attr_name)
            assert branch.conv.kernel_size == (1, 1)
            assert branch.conv.stride == (1, 1)
            assert branch.conv.padding == (0, 0)
            assert isinstance(branch.activate, act_cls)

        # The layer must preserve the input's shape.
        feats = torch.rand(1, 16, 64, 64)
        assert layer(feats).shape == (1, 16, 64, 64)

    # test config with channels = 16.
    _verify(SELayer(16), mmcv.cnn.HSigmoid)

    # test config with channels = 16, act_cfg = dict(type='ReLU').
    _verify(SELayer(16, act_cfg=dict(type='ReLU')), torch.nn.ReLU)