File size: 787 Bytes
482ab8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn

from .bayar_conv import BayarConv2d
from .srm_conv import SRMConv2d


class EarlyFusionPreFilter(nn.Module):
    """Fuse Bayar, SRM and raw-RGB views of an input into a 3-channel tensor.

    Three parallel branches see the same input:

    * ``bayar_filter`` -- ``BayarConv2d`` with 3 in / 3 out channels, a 5x5
      kernel, stride 1 and padding 2;
    * ``srm_filter``   -- ``SRMConv2d`` with stride 1 and padding 2;
    * ``rgb_filter``   -- ``nn.Identity`` (the input passed through as-is).

    Their outputs are concatenated along dim 1 and projected to 3 channels by
    a 1x1 convolution (``map``, 9 -> 3 channels).

    Args:
        bayar_magnitude: forwarded to ``BayarConv2d`` as ``magnitude``.
        srm_clip: forwarded to ``SRMConv2d`` as ``clip``.
    """

    def __init__(self, bayar_magnitude: float, srm_clip: float):
        super().__init__()
        channels = 3
        # Keep the attribute names and the construction order as they are:
        # the names are the state_dict keys of saved checkpoints, and
        # nn.Conv2d draws its initial weights from torch's global RNG, so
        # reordering would change seeded initialization.
        self.bayar_filter = BayarConv2d(
            channels, channels, 5, stride=1, padding=2, magnitude=bayar_magnitude
        )
        self.srm_filter = SRMConv2d(stride=1, padding=2, clip=srm_clip)
        self.rgb_filter = nn.Identity()
        # 1x1 projection over the concatenated branches (3 x 3 = 9 channels).
        # NOTE(review): assumes SRMConv2d emits 3 channels -- confirm.
        self.map = nn.Conv2d(
            in_channels=3 * channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every branch on ``x`` and fuse the results with ``map``.

        Args:
            x: input batch; presumably (N, 3, H, W) -- TODO confirm.

        Returns:
            ``self.map(torch.cat([bayar(x), srm(x), x], dim=1))``.
        """
        # The branch order fixes the channel layout that ``map`` consumes;
        # torch.cat also requires all branch outputs to agree on every
        # dimension except dim 1.
        branches = (self.bayar_filter, self.srm_filter, self.rgb_filter)
        features = [branch(x) for branch in branches]
        return self.map(torch.cat(features, dim=1))