File size: 1,638 Bytes
561c629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch.nn as nn


class ResBlock(nn.Module):
    """Residual block without batch normalization.

    Layout (with ``shortcut=True``)::

        ---Conv-ReLU-Conv-+-
         |________________|

    The residual branch is two 3x3 convolutions (stride 1, padding 1) with a
    ReLU between them, so channel count and spatial size are preserved. The
    branch output is multiplied by ``res_scale``. When ``shortcut`` is True
    the block input is added back; otherwise the scaled branch output is
    returned by itself.

    Args:
        num_feats (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Used to scale the residual before addition.
            Default: 1.0.
        bias (bool): Whether both convolutions have a learnable bias.
            Default: True.
        shortcut (bool): Whether to add the identity shortcut to the scaled
            residual. Default: True.
    """

    def __init__(self, num_feats=64, res_scale=1.0, bias=True, shortcut=True):
        super().__init__()
        self.shortcut = shortcut
        self.res_scale = res_scale
        # Shared settings for both convs: padding=1 with a 3x3 kernel keeps
        # the spatial size unchanged.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv1 = nn.Conv2d(num_feats, num_feats, **conv_kwargs)
        self.conv2 = nn.Conv2d(num_feats, num_feats, **conv_kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w), where c is
                expected to equal ``num_feats``.

        Returns:
            Tensor: Output tensor with the same shape as ``x``.
        """
        branch = self.relu(self.conv1(x))
        branch = self.conv2(branch) * self.res_scale
        if not self.shortcut:
            return branch
        return x + branch


class ResBlockWrapper(ResBlock):
    """Adapter that applies :class:`ResBlock` to transformer-style tokens.

    The input is a flattened token sequence of shape (B, L, C) with
    L == H * W. The tokens are folded back into a (B, C, H, W) feature map,
    passed through the residual block, and flattened back to (B, H * W, C).

    Args:
        num_feats (int): Channel number C of the features.
        bias (bool): Whether both convolutions have a learnable bias.
            Default: True.
        shortcut (bool): Whether to add the identity shortcut to the scaled
            residual. Default: True.
        res_scale (float): Used to scale the residual before addition.
            Default: 1.0.
    """

    def __init__(self, num_feats, bias=True, shortcut=True, res_scale=1.0):
        # res_scale is appended last so existing positional callers
        # (num_feats, bias, shortcut) keep working unchanged.
        super().__init__(
            num_feats=num_feats,
            res_scale=res_scale,
            bias=bias,
            shortcut=shortcut,
        )

    def forward(self, x, x_size):
        """Forward function.

        Args:
            x (Tensor): Token tensor with shape (B, L, C), where L must
                equal H * W.
            x_size (tuple[int, int]): Spatial size (H, W) used to fold the
                tokens back into a feature map.

        Returns:
            Tensor: Output tokens with shape (B, H * W, C).
        """
        H, W = x_size
        B, _, C = x.shape
        # view splits L into (H, W) in row-major order, which is the inverse
        # of the flatten(2) below; it raises if L != H * W.
        x = x.view(B, H, W, C).permute(0, 3, 1, 2)
        x = super().forward(x)
        # (B, C, H, W) -> (B, C, H * W) -> (B, H * W, C)
        return x.flatten(2).permute(0, 2, 1)