WSCL / models /srm_conv.py
yhzhai's picture
release code
482ab8a
import numpy as np
import torch
import torch.nn as nn
class SRMConv2d(nn.Module):
def __init__(self, stride: int = 1, padding: int = 2, clip: float = 2):
super().__init__()
self.stride = stride
self.padding = padding
self.clip = clip
self.conv = self._get_srm_filter()
def _get_srm_filter(self):
filter1 = [
[0, 0, 0, 0, 0],
[0, -1, 2, -1, 0],
[0, 2, -4, 2, 0],
[0, -1, 2, -1, 0],
[0, 0, 0, 0, 0],
]
filter2 = [
[-1, 2, -2, 2, -1],
[2, -6, 8, -6, 2],
[-2, 8, -12, 8, -2],
[2, -6, 8, -6, 2],
[-1, 2, -2, 2, -1],
]
filter3 = [
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 1, -2, 1, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
]
q = [4.0, 12.0, 2.0]
filter1 = np.asarray(filter1, dtype=float) / q[0]
filter2 = np.asarray(filter2, dtype=float) / q[1]
filter3 = np.asarray(filter3, dtype=float) / q[2]
filters = [
[filter1, filter1, filter1],
[filter2, filter2, filter2],
[filter3, filter3, filter3],
]
filters = torch.tensor(filters).float()
conv2d = nn.Conv2d(
3,
3,
kernel_size=5,
stride=self.stride,
padding=self.padding,
padding_mode="zeros",
)
conv2d.weight = nn.Parameter(filters, requires_grad=False)
conv2d.bias = nn.Parameter(torch.zeros_like(conv2d.bias), requires_grad=False)
return conv2d
def forward(self, x):
x = self.conv(x)
if self.clip != 0.0:
x = x.clamp(-self.clip, self.clip)
return x
if __name__ == "__main__":
srm = SRMConv2d()
x = torch.rand((63, 3, 64, 64))
x = srm(x)