import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def _srm_base_filters():
    """Return the three fixed SRM residual filters stacked as a (3, 1, 5, 5) array.

    Order along axis 0:
      0. KB  - 3x3 residual kernel (embedded in 5x5), divided by 4.
      1. KV  - 5x5 residual kernel, divided by 12.
      2. hor - horizontal second-order residual [1, -2, 1], divided by 2.
    Each filter is scaled so that its centre tap equals -1.
    """
    # filter1: KB
    kb = np.asarray([[0, 0, 0, 0, 0],
                     [0, -1, 2, -1, 0],
                     [0, 2, -4, 2, 0],
                     [0, -1, 2, -1, 0],
                     [0, 0, 0, 0, 0]], dtype=float) / 4.
    # filter2: KV
    kv = np.asarray([[-1, 2, -2, 2, -1],
                     [2, -6, 8, -6, 2],
                     [-2, 8, -12, 8, -2],
                     [2, -6, 8, -6, 2],
                     [-1, 2, -2, 2, -1]], dtype=float) / 12.
    # filter3: horizontal 2nd order
    hor = np.asarray([[0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0],
                      [0, 1, -2, 1, 0],
                      [0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0]], dtype=float) / 2.
    # Insert a singleton "input channel" axis: (3, 1, 5, 5).
    return np.stack([kb, kv, hor])[:, np.newaxis]


class SRMConv2d_simple(nn.Module):
    """Apply the three fixed SRM filters across all input channels.

    The kernel has shape (3, inc, 5, 5) and is used in an ordinary (groups=1)
    convolution, so each of the 3 output channels is one SRM filter applied to
    every input channel and summed. Responses are clamped to [-3, 3].

    Args:
        inc: number of input channels.
        learnable: if True, the SRM kernel is trainable (``requires_grad``).
    """

    def __init__(self, inc=3, learnable=False):
        super().__init__()
        # Truncate residuals to [-3, 3].
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3, inc, 5, 5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)

    def forward(self, x):
        """Filter ``x`` and return a (B, 3, H, W) tensor of clamped residuals.

        x: channels-first images of shape (B, inc, H, W), as required by
        ``F.conv2d``.
        """
        # padding=2 keeps H and W unchanged for the 5x5 kernels.
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        return self.truc(out)

    def _build_kernel(self, inc):
        """Return the SRM kernel of shape (3, inc, 5, 5) as float32."""
        filters = np.repeat(_srm_base_filters(), inc, axis=1)
        return torch.tensor(filters, dtype=torch.float32)


class SRMConv2d_Separate(nn.Module):
    """Apply the three SRM filters to each input channel separately, then mix.

    Each of the ``inc`` input channels is convolved with all three SRM filters
    (grouped conv, groups=inc), giving 3*inc channels. These are clamped to
    [-3, 3] and projected to ``outc`` channels by a 1x1 conv, BatchNorm and ReLU.

    Args:
        inc: number of input channels.
        outc: number of output channels.
        learnable: if True, the SRM kernel is trainable (``requires_grad``).
    """

    def __init__(self, inc, outc, learnable=False):
        super().__init__()
        self.inc = inc
        # Truncate residuals to [-3, 3].
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3*inc, 1, 5, 5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
        self.out_conv = nn.Sequential(
            nn.Conv2d(3 * inc, outc, kernel_size=1, stride=1, padding=0,
                      dilation=1, groups=1, bias=False),
            nn.BatchNorm2d(outc),
            nn.ReLU(inplace=True),
        )
        for layer in self.out_conv.children():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, a=1)

    def forward(self, x):
        """Return a (B, outc, H, W) feature map.

        x: channels-first images of shape (B, inc, H, W), as required by
        ``F.conv2d``.
        """
        # groups=inc: output channels [3c, 3c+3) are KB, KV, hor applied to
        # input channel c only.
        out = F.conv2d(x, self.kernel, stride=1, padding=2, groups=self.inc)
        out = self.truc(out)
        return self.out_conv(out)

    def _build_kernel(self, inc):
        """Return the grouped SRM kernel of shape (3*inc, 1, 5, 5) as float32.

        The three base filters are tiled (not element-repeated) so that, with
        ``groups=inc``, every input channel's group receives KB, KV and hor in
        that order. ``np.repeat(..., axis=0)`` would instead give channel 0 the
        KB filter three times, channel 1 KV three times, and so on.
        """
        filters = np.tile(_srm_base_filters(), (inc, 1, 1, 1))
        return torch.tensor(filters, dtype=torch.float32)