import numpy as np import torch import torch.nn as nn class SRMConv2d(nn.Module): def __init__(self, stride: int = 1, padding: int = 2, clip: float = 2): super().__init__() self.stride = stride self.padding = padding self.clip = clip self.conv = self._get_srm_filter() def _get_srm_filter(self): filter1 = [ [0, 0, 0, 0, 0], [0, -1, 2, -1, 0], [0, 2, -4, 2, 0], [0, -1, 2, -1, 0], [0, 0, 0, 0, 0], ] filter2 = [ [-1, 2, -2, 2, -1], [2, -6, 8, -6, 2], [-2, 8, -12, 8, -2], [2, -6, 8, -6, 2], [-1, 2, -2, 2, -1], ] filter3 = [ [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, -2, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], ] q = [4.0, 12.0, 2.0] filter1 = np.asarray(filter1, dtype=float) / q[0] filter2 = np.asarray(filter2, dtype=float) / q[1] filter3 = np.asarray(filter3, dtype=float) / q[2] filters = [ [filter1, filter1, filter1], [filter2, filter2, filter2], [filter3, filter3, filter3], ] filters = torch.tensor(filters).float() conv2d = nn.Conv2d( 3, 3, kernel_size=5, stride=self.stride, padding=self.padding, padding_mode="zeros", ) conv2d.weight = nn.Parameter(filters, requires_grad=False) conv2d.bias = nn.Parameter(torch.zeros_like(conv2d.bias), requires_grad=False) return conv2d def forward(self, x): x = self.conv(x) if self.clip != 0.0: x = x.clamp(-self.clip, self.clip) return x if __name__ == "__main__": srm = SRMConv2d() x = torch.rand((63, 3, 64, 64)) x = srm(x)