|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange |
|
|
|
|
|
class BayarConv2d(nn.Module):
    """2-D convolution whose kernels follow a Bayar-style constraint.

    For every (output channel, input channel) filter, the tap at flat index
    ``kernel_size**2 // 2`` is fixed at ``-magnitude`` and is not trained.
    On every forward pass, the remaining ``kernel_size**2 - 1`` learnable taps
    are rescaled so that they sum to ``magnitude``. Each filter therefore sums
    to zero. No bias term is used.

    Args:
        in_channles: Number of input channels. The misspelled name is kept so
            existing keyword callers keep working.
        out_channels: Number of output channels.
        kernel_size: Side length of the square kernel; must be > 1. The fixed
            tap is the geometric center only for odd sizes (for even sizes it
            lands at row ``kernel_size // 2``, column 0).
        stride: Forwarded to ``torch.nn.functional.conv2d``.
        padding: Forwarded to ``torch.nn.functional.conv2d``.
        magnitude: Constraint scale: the fixed tap is ``-magnitude`` and the
            learnable taps of each filter sum to ``magnitude``.
    """

    def __init__(
        self,
        in_channles: int,
        out_channels: int,
        kernel_size: int = 5,
        stride: int = 1,
        padding: int = 0,
        magnitude: float = 1.0,
    ):
        super().__init__()
        assert kernel_size > 1, "Bayar conv kernel size must be greater than 1"

        self.in_channels = in_channles
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.magnitude = magnitude

        # Weights are stored as (out_channels, in_channels, taps) so that the
        # assembled kernel matches the (out, in, kH, kW) layout F.conv2d
        # expects, for any combination of channel counts.
        self.center_weight = nn.Parameter(
            torch.ones(self.out_channels, self.in_channels, 1) * -1.0 * magnitude,
            requires_grad=False,
        )
        self.kernel_weight = nn.Parameter(
            torch.rand((self.out_channels, self.in_channels, kernel_size**2 - 1)),
            requires_grad=True,
        )

    def _constraint_weight(self) -> torch.Tensor:
        """Project the learnable taps onto the constraint and build the kernel.

        Side effect: ``kernel_weight`` is overwritten via ``.data`` (outside
        autograd) so that each filter's taps sum to ``magnitude``. Repeated
        calls leave the taps unchanged up to floating-point rounding.

        Returns:
            Kernel of shape (out_channels, in_channels, kernel_size, kernel_size).
        """
        # Projection step: done on .data so it is not recorded in the graph;
        # gradients reach kernel_weight only through the concatenation below.
        # NOTE(review): if a filter's taps ever sum to 0 this yields inf/nan.
        taps = self.kernel_weight.data
        self.kernel_weight.data = taps / taps.sum(dim=-1, keepdim=True) * self.magnitude

        center_idx = self.kernel_size**2 // 2
        full_kernel = torch.cat(
            [
                self.kernel_weight[:, :, :center_idx],
                self.center_weight,
                self.kernel_weight[:, :, center_idx:],
            ],
            dim=2,
        )
        # Row-major split of the flat tap axis into a square kernel.
        return full_kernel.reshape(
            *full_kernel.shape[:2], self.kernel_size, self.kernel_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` with the freshly constrained kernel (no bias)."""
        return nn.functional.conv2d(
            x, self._constraint_weight(), stride=self.stride, padding=self.padding
        )
|
|
|
|
|
if __name__ == "__main__":
    # Smoke run; fall back to CPU so it also works on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    bayer_conv2d = BayarConv2d(3, 3, 3, magnitude=1).to(device)
    bayer_conv2d._constraint_weight()

    i = torch.rand(16, 3, 16, 16).to(device)
    o = bayer_conv2d(i)
    print(o.shape)
|
|