File size: 1,914 Bytes
482ab8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
import torch
import torch.nn as nn


class SRMConv2d(nn.Module):
    """Fixed (non-trainable) bank of three 5x5 SRM high-pass kernels as a conv layer.

    The layer is a 3 -> 3 channel ``nn.Conv2d`` whose weights are frozen. Each
    output channel applies the same kernel to all three input channels, so
    output channel ``k`` equals kernel ``k`` applied to the sum of the input
    channels. The kernels are divided by 4, 12 and 2 respectively. The
    response is then clamped to ``[-clip, clip]`` unless ``clip`` is 0.

    Args:
        stride: Convolution stride.
        padding: Zero padding on each spatial side. The default of 2 keeps the
            spatial size unchanged for stride 1 with the 5x5 kernels.
        clip: Non-negative symmetric clamp bound for the output; ``0`` disables
            clamping.

    Raises:
        ValueError: If ``clip`` is negative.
    """

    def __init__(self, stride: int = 1, padding: int = 2, clip: float = 2):
        super().__init__()
        if clip < 0:
            # torch.clamp with min > max collapses every element to ``max``,
            # so a negative bound would silently yield a constant output.
            raise ValueError(f"clip must be non-negative, got {clip}")
        self.stride = stride
        self.padding = padding
        self.clip = clip
        self.conv = self._get_srm_filter()

    def _get_srm_filter(self) -> nn.Conv2d:
        """Build the frozen 3 -> 3 convolution holding the SRM kernels.

        Returns:
            An ``nn.Conv2d`` with weight shape (3, 3, 5, 5) and an all-zero bias,
            both wrapped as parameters with ``requires_grad=False``.
        """
        filter1 = [
            [0, 0, 0, 0, 0],
            [0, -1, 2, -1, 0],
            [0, 2, -4, 2, 0],
            [0, -1, 2, -1, 0],
            [0, 0, 0, 0, 0],
        ]
        filter2 = [
            [-1, 2, -2, 2, -1],
            [2, -6, 8, -6, 2],
            [-2, 8, -12, 8, -2],
            [2, -6, 8, -6, 2],
            [-1, 2, -2, 2, -1],
        ]
        filter3 = [
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 1, -2, 1, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ]
        # Per-kernel normalisation divisors.
        q = [4.0, 12.0, 2.0]
        kernels = [
            np.asarray(raw, dtype=float) / scale
            for raw, scale in zip((filter1, filter2, filter3), q)
        ]
        # Shape (out=3, in=3, 5, 5): each output channel repeats its kernel on
        # all three input channels. Stacking into a single ndarray before the
        # conversion avoids torch's slow path for nested lists of ndarrays.
        weight = np.stack([np.stack([k, k, k]) for k in kernels])
        filters = torch.from_numpy(weight).float()
        conv2d = nn.Conv2d(
            3,
            3,
            kernel_size=5,
            stride=self.stride,
            padding=self.padding,
            padding_mode="zeros",
        )
        # Replace the randomly initialised parameters with the fixed kernels and
        # a zero bias; keeping the bias preserves the ``conv.bias`` state key.
        conv2d.weight = nn.Parameter(filters, requires_grad=False)
        conv2d.bias = nn.Parameter(torch.zeros_like(conv2d.bias), requires_grad=False)
        return conv2d

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SRM kernels to ``x`` and clamp the response.

        Args:
            x: Input with 3 channels, e.g. (N, 3, H, W); the conv is built with
                3 input channels.

        Returns:
            The 3-channel filtered tensor, clamped to ``[-clip, clip]`` unless
            ``clip`` is 0.
        """
        x = self.conv(x)
        if self.clip != 0.0:
            x = x.clamp(-self.clip, self.clip)
        return x


if __name__ == "__main__":
    # Smoke run: push a random batch of 63 three-channel 64x64 inputs through the layer.
    layer = SRMConv2d()
    batch = torch.rand(63, 3, 64, 64)
    response = layer(batch)