USR-DA / model /encoder.py
DS
dump shiet
e5b70eb
raw
history blame
2.76 kB
import torch.nn as nn
import torch.nn.functional as F
class Encoder(nn.Module):
    """Three-stage convolutional feature encoder.

    Each stage is a 1x1 conv that sets the stage width (16, 32, 64), followed
    by two 3x3 convs with padding 1. Every conv is followed by ReLU. A 2x2
    max-pool (stride 2, ``ceil_mode=True``) sits between consecutive stages.
    The output therefore has 64 channels, and each spatial dimension is
    ``ceil(ceil(d / 2) / 2)`` of the input's.

    Args:
        in_channels: number of channels in the input image. Defaults to 3,
            the value the original implementation hard-coded.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        # Attribute names and the per-stage Sequential layout are unchanged,
        # so state_dict keys (conv_featmap_k.{0,2,4}.{weight,bias}) still
        # match checkpoints saved from the original class.
        self.conv_featmap_1 = self._make_stage(in_channels, 16)
        self.conv_featmap_2 = self._make_stage(16, 32)
        self.conv_featmap_3 = self._make_stage(32, 64)

    @staticmethod
    def _make_stage(in_ch, out_ch):
        """Build conv1x1 -> ReLU -> (conv3x3 -> ReLU) x 2 as one Sequential."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
        )

    def forward(self, img):
        """Encode ``img`` and return only the deepest (64-channel) feature map.

        Args:
            img: image tensor with ``in_channels`` channels, in a layout that
                ``nn.Conv2d`` accepts (e.g. (N, C, H, W)).

        Returns:
            The output of the third stage. The intermediate feature maps are
            computed but not returned.
        """
        featmap_1 = self.conv_featmap_1(img)
        featmap_2 = self.conv_featmap_2(self.maxpool(featmap_1))
        featmap_3 = self.conv_featmap_3(self.maxpool(featmap_2))
        return featmap_3
class Encoder_RRDB(nn.Module):
    """Plain stack of five 3x3 convolutions at a constant width.

    Layout: conv3x3 -> [LeakyReLU(0.2) -> conv3x3] x 4. There is no
    activation after the final conv. All convs use padding 1 and stride 1,
    so the spatial size is preserved.

    NOTE(review): despite its name, this class contains no
    residual-in-residual dense blocks. It is a straight conv stack.

    Args:
        num_feat: output channels of every conv. Defaults to 16.
        in_channels: number of channels in the input image. Defaults to 3,
            the value the original implementation hard-coded.
    """

    def __init__(self, num_feat=16, in_channels=3):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=num_feat, kernel_size=3, padding=1, bias=True),
        ]
        # Appending activation + conv pairs reproduces the original Sequential
        # indices (convs at 0, 2, 4, 6, 8), so state_dict keys are unchanged.
        # In-place LeakyReLU is safe here because each one only overwrites the
        # preceding conv's output, which that conv's backward pass does not need.
        for _ in range(4):
            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
            layers.append(
                nn.Conv2d(in_channels=num_feat, out_channels=num_feat, kernel_size=3, padding=1, bias=True)
            )
        self.conv_featmap = nn.Sequential(*layers)

    def forward(self, img):
        """Return the conv-stack features of ``img`` (``num_feat`` channels, same H/W)."""
        return self.conv_featmap(img)