File size: 939 Bytes
1c3f916
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch.nn as nn

from torch.nn import functional as F
from timm import create_model


__all__ = ['NoiseTransformer']


class NoiseTransformer(nn.Module):
    """Transform a 4-channel noise map with a pretrained Swin-Tiny backbone.

    Pipeline: resize the input to the backbone's 224x224 input size, project
    4 -> 3 channels with a 1x1 conv, run ``swin.forward_features``, resize the
    features to ``resolution`` and project them to 4 channels with a 1x1 conv.
    Optionally the input is added back as a residual.

    Args:
        resolution: ``(height, width)`` of the output map. With
            ``residual=True`` the input must already have this spatial size
            so the skip connection can be added.
        pretrained: Forwarded to ``timm.create_model``. Set to ``False`` to
            skip downloading backbone weights, e.g. when the full state dict
            is loaded from a checkpoint right after construction.
    """

    # Spatial input size implied by the backbone name
    # "swin_tiny_patch4_window7_224".
    _SWIN_INPUT_SIZE = (224, 224)

    def __init__(self, resolution=(128, 96), *, pretrained=True):
        super().__init__()
        # Snapshot the output size. A lambda closing over the caller's object
        # would read it lazily and also make the module unpicklable.
        self.resolution = (resolution[0], resolution[1])
        # Registration order (upconv, downconv, swin) is kept so state_dict
        # keys and seeded parameter initialisation stay unchanged.
        # 1x1 conv mapping the resized backbone features to 4 output channels.
        # NOTE(review): in_channels=7 must equal axis 1 of
        # swin.forward_features(...). That holds if timm returns a
        # channels-last (B, 7, 7, 768) map, in which case the interpolation
        # treats the 7-row axis as channels. Confirm against the pinned timm
        # version.
        self.upconv = nn.Conv2d(7, 4, kernel_size=1, stride=1, padding=0)
        # 1x1 conv mapping the 4-channel input to the 3 channels fed to Swin.
        self.downconv = nn.Conv2d(4, 3, kernel_size=1, stride=1, padding=0)
        self.swin = create_model("swin_tiny_patch4_window7_224", pretrained=pretrained)

    def upsample(self, x):
        """Resize ``x`` (N, C, H, W) to the backbone input size (nearest)."""
        return F.interpolate(x, self._SWIN_INPUT_SIZE)

    def downsample(self, x):
        """Resize the last two axes of ``x`` to ``self.resolution`` (nearest)."""
        return F.interpolate(x, self.resolution)

    def forward(self, x, residual=False):
        """Run the noise transformation.

        Args:
            x: 4-D tensor ``(N, 4, H, W)``. Four channels are required by
                ``downconv``.
            residual: If true, add ``x`` to the output. ``x`` must then have
                spatial size ``self.resolution``.

        Returns:
            Tensor of shape ``(N, 4, *self.resolution)``.
        """
        features = self.swin.forward_features(self.downconv(self.upsample(x)))
        out = self.upconv(self.downsample(features))
        if residual:
            out = out + x
        return out