import torch from torch import nn from modeling.base import BaseNetwork from modeling.ifrnet import Flatten from modules.blocks import DestyleResBlock, Destyler, ResBlock class UNet(BaseNetwork): def __init__(self, base_n_channels): super(UNet, self).__init__() self.ds_res1 = ResBlock(channels_in=3, channels_out=base_n_channels, kernel_size=5, stride=1, padding=2) self.ds_res2 = ResBlock(channels_in=base_n_channels, channels_out=base_n_channels * 2, kernel_size=3, stride=2, padding=1) self.ds_res3 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1) self.ds_res4 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 4, kernel_size=3, stride=2, padding=1) self.ds_res5 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1) self.ds_res6 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 8, kernel_size=3, stride=2, padding=1) self.upsample = nn.UpsamplingNearest2d(scale_factor=2.0) self.res1 = ResBlock(channels_in=base_n_channels * 8, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1) self.res2 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 4, kernel_size=3, stride=1, padding=1) self.res3 = ResBlock(channels_in=base_n_channels * 4, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1) self.res4 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels * 2, kernel_size=3, stride=1, padding=1) self.res5 = ResBlock(channels_in=base_n_channels * 2, channels_out=base_n_channels, kernel_size=3, stride=1, padding=1) self.conv1 = nn.Conv2d(base_n_channels, 3, kernel_size=3, stride=1, padding=1) self.init_weights(init_type="normal", gain=0.02) def forward(self, x): out = self.ds_res1(x) out = self.ds_res2(out) out = self.ds_res3(out) out = self.ds_res4(out) out = self.ds_res5(out) aux = self.ds_res6(out) out = self.upsample(aux) out = self.res1(out) out = self.res2(out) out = self.upsample(out) out = self.res3(out) out = self.res4(out) out = self.upsample(out) out = self.res5(out) out = self.conv1(out) return out, aux if __name__ == '__main__': import torchvision x = torch.rand((2, 3, 256, 256)).cuda() unet = UNet(32, 32).cuda() vgg16 = torchvision.models.vgg16(pretrained=True).features.eval().cuda() with torch.no_grad(): vgg_feat = vgg16(x) out = unet(x, vgg_feat) print(out.size())