| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class wConv2d(nn.Module):
    """2-D convolution whose kernel is scaled element-wise by a fixed mask.

    A symmetric 1-D vector ``alfa = [den..., 1.0, reversed(den)...]`` is built
    from ``den`` and the mask ``Phi`` is its outer product with itself. On
    every forward pass the learnable ``weight`` is multiplied by ``Phi``
    before being handed to ``F.conv2d``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Side length of the square kernel. Must equal
            ``2 * len(den) + 1``, otherwise ``Phi`` cannot match the kernel.
        den: One side of the symmetric vector. May be a sequence of numbers,
            a NumPy array, a tensor, or a single scalar (treated as a
            one-element sequence).
        stride: Stride forwarded to ``F.conv2d``.
        padding: Padding forwarded to ``F.conv2d`` (defaults to 1).
        groups: Number of channel groups forwarded to ``F.conv2d``.
        bias: If true, a zero-initialised learnable bias is added.

    Raises:
        ValueError: If the shape of ``Phi`` differs from
            ``(kernel_size, kernel_size)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, den,
                 stride=1, padding=1, groups=1, bias=False):
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.kernel_size = kernel_size
        self.groups = groups

        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels // groups,
                        kernel_size, kernel_size)
        )
        nn.init.kaiming_normal_(self.weight, mode='fan_out',
                                nonlinearity='relu')
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None

        # Build ``den`` in the weight's dtype: a float64 input (NumPy array,
        # double tensor) would otherwise promote ``weight * Phi`` to float64
        # and make F.conv2d reject float32 activations. ``detach`` keeps the
        # buffer free of any autograd history carried by a tensor ``den``.
        side = torch.as_tensor(den, dtype=self.weight.dtype).detach()
        if side.ndim == 0:
            # A bare scalar is the natural way to describe a 3x3 kernel.
            side = side.reshape(1)
        centre = torch.ones(1, dtype=side.dtype, device=side.device)
        alfa = torch.cat([side, centre, torch.flip(side, dims=[0])])

        self.register_buffer('alfa', alfa)
        self.register_buffer('Phi', torch.outer(self.alfa, self.alfa))

        if self.Phi.shape != (kernel_size, kernel_size):
            raise ValueError(f"Phi shape {self.Phi.shape} must match kernel size ({kernel_size}, {kernel_size})")

    def forward(self, x):
        """Convolve ``x`` with the ``Phi``-scaled kernel and return the result."""
        # Phi multiplies the weight, so it is the weight's device and dtype
        # that it has to agree with; this is a no-op in the usual case where
        # the buffer has travelled with the module.
        phi = self.Phi.to(device=self.weight.device, dtype=self.weight.dtype)
        scaled_weight = self.weight * phi
        return F.conv2d(x, scaled_weight, bias=self.bias, stride=self.stride,
                        padding=self.padding, groups=self.groups)