# WSCL / models / bayar_conv.py
# yhzhai's picture
# release code
# 482ab8a
# raw
# history blame
# 2.1 kB
import torch
import torch.nn as nn
from einops import rearrange
class BayarConv2d(nn.Module):
    """Constrained 2-D convolution ("Bayar conv").

    For every (output channel, input channel) pair the layer owns one
    ``kernel_size x kernel_size`` filter with two constraints that are
    re-imposed on every call to :meth:`_constraint_weight`:

    * the centre tap is fixed to ``-magnitude`` and never trained;
    * the remaining ``kernel_size**2 - 1`` taps are trainable and are
      rescaled so that they sum to ``+magnitude``.

    Consequently each filter sums to (approximately) zero.

    Parameters are stored as ``(out_channels, in_channels, taps)``, which is
    the layout ``torch.nn.functional.conv2d`` expects for its weight.  When
    ``in_channels == out_channels`` the parameter shapes are the same as with
    an ``(in_channels, out_channels, taps)`` layout, so state dicts of such
    square layers remain loadable.

    Args:
        in_channles: Number of input channels.  (The misspelling is kept so
            that existing keyword callers keep working.)
        out_channels: Number of output channels.
        kernel_size: Side length of the square kernel; must be > 1.
        stride: Stride forwarded to ``conv2d``.
        padding: Padding forwarded to ``conv2d``.
        magnitude: Absolute value of the centre tap and target sum of the
            off-centre taps.

    Raises:
        AssertionError: If ``kernel_size`` is not greater than 1.
    """

    def __init__(
        self,
        in_channles: int,
        out_channels: int,
        kernel_size: int = 5,
        stride: int = 1,
        padding: int = 0,
        magnitude: float = 1.0,
    ):
        super().__init__()
        assert kernel_size > 1, "Bayar conv kernel size must be greater than 1"

        self.in_channels = in_channles
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.magnitude = magnitude

        # Fixed centre tap (-magnitude); registered as a non-trainable
        # parameter so it follows the module across devices and is saved
        # in the state dict.
        self.center_weight = nn.Parameter(
            torch.ones(self.out_channels, self.in_channels, 1) * -1.0 * magnitude,
            requires_grad=False,
        )
        # Trainable off-centre taps, flattened in row-major order with the
        # centre position left out.
        self.kernel_weight = nn.Parameter(
            torch.rand((self.out_channels, self.in_channels, kernel_size**2 - 1)),
            requires_grad=True,
        )

    def _constraint_weight(self) -> torch.Tensor:
        """Re-normalise the trainable taps and assemble the full kernel.

        The trainable taps are rescaled in place (outside of autograd) so
        that, per filter, they sum to ``magnitude``.  The fixed centre tap is
        then spliced into the middle of the flattened taps and the result is
        reshaped into a conv2d weight.

        Note: if the off-centre taps of a filter sum to zero the rescaling
        divides by zero and yields inf/NaN weights.

        Returns:
            Tensor of shape
            ``(out_channels, in_channels, kernel_size, kernel_size)``;
            gradients flow from it to ``kernel_weight`` only.
        """
        # Projection step: done under no_grad so it is not recorded in the
        # autograd graph; the slices taken below are what gets differentiated.
        with torch.no_grad():
            tap_sum = self.kernel_weight.sum(dim=2, keepdim=True)
            self.kernel_weight.div_(tap_sum).mul_(self.magnitude)

        # Index of the centre element in the flattened k*k kernel.
        center_idx = self.kernel_size**2 // 2
        full_kernel = torch.cat(
            [
                self.kernel_weight[:, :, :center_idx],
                self.center_weight,
                self.kernel_weight[:, :, center_idx:],
            ],
            dim=2,
        )
        # Unflatten the last axis into (rows, cols).
        return full_kernel.reshape(
            self.out_channels, self.in_channels, self.kernel_size, self.kernel_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` with the freshly constrained kernel.

        Args:
            x: Input passed straight to ``conv2d``; its channel dimension
                must equal ``in_channels``.

        Returns:
            The result of ``conv2d`` using the configured stride and padding.
        """
        return nn.functional.conv2d(
            x, self._constraint_weight(), stride=self.stride, padding=self.padding
        )
if __name__ == "__main__":
    # Smoke test: build a 3->3 channel layer with a 3x3 kernel and push one
    # random batch through it.  Falls back to the CPU when no CUDA device is
    # available instead of crashing.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    bayar_conv2d = BayarConv2d(3, 3, 3, magnitude=1).to(device)
    bayar_conv2d._constraint_weight()
    i = torch.rand(16, 3, 16, 16).to(device)
    o = bayar_conv2d(i)
    # 3x3 kernel, stride 1, no padding: each spatial side shrinks by 2.
    assert tuple(o.shape) == (16, 3, 14, 14)