import torch
import torch.nn as nn

from .bayar_conv import BayarConv2d
from .srm_conv import SRMConv2d


class EarlyFusionPreFilter(nn.Module):
    """Fuse Bayar, SRM and unfiltered views of an input into a 3-channel map.

    The input is passed through three parallel branches (a ``BayarConv2d``,
    an ``SRMConv2d`` and an identity), the branch outputs are concatenated
    along ``dim=1``, and a 1x1 convolution projects the 9-channel stack down
    to 3 channels.

    Args:
        bayar_magnitude: Forwarded to ``BayarConv2d`` as ``magnitude``.
        srm_clip: Forwarded to ``SRMConv2d`` as ``clip``.
    """

    def __init__(self, bayar_magnitude: float, srm_clip: float):
        super().__init__()

        # NOTE(review): the positional arguments (3, 3, 5) are presumably
        # in_channels, out_channels and kernel_size -- confirm against
        # BayarConv2d's signature.
        self.bayar_filter = BayarConv2d(
            3,
            3,
            5,
            stride=1,
            padding=2,
            magnitude=bayar_magnitude,
        )
        self.srm_filter = SRMConv2d(stride=1, padding=2, clip=srm_clip)

        # Pass-through branch: keeps the unfiltered input in the fused stack.
        self.rgb_filter = nn.Identity()

        # 1x1 convolution reducing the concatenated 9 channels to 3.
        # Assumes each branch contributes 3 channels -- TODO confirm for the
        # Bayar and SRM filters, whose definitions live in other modules.
        self.map = nn.Conv2d(
            in_channels=9,
            out_channels=3,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def forward(self, x):
        """Run all three branches on ``x`` and return the fused projection.

        Args:
            x: Input tensor fed unchanged to every branch; presumably a
                3-channel image batch laid out as (N, C, H, W) -- verify
                against the caller.

        Returns:
            Output of the 1x1 convolution applied to the channel-wise
            concatenation of the Bayar, SRM and identity branch outputs
            (in that order).
        """
        fused = torch.cat(
            [
                self.bayar_filter(x),
                self.srm_filter(x),
                self.rgb_filter(x),
            ],
            dim=1,
        )
        return self.map(fused)