# GitHub repository: https://github.com/bilibili/ailab/blob/main/Real-CUGAN/README_EN.md
# Code snippet (with certain modifications) from:
# https://github.com/bilibili/ailab/blob/main/Real-CUGAN/VapourSynth/upcunet_v3_vs.py
import os
import sys
from time import time as ttime, sleep

import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F


def _init_weights(module):
    """Initialise the weights of every sub-module of ``module`` in place.

    Conv2d / ConvTranspose2d weights get Kaiming-normal init (fan_out, relu);
    their biases are left at PyTorch's default initialisation. The nn.Linear
    branch (normal weights, zero bias) is kept from the upstream code; none of
    the networks in this file register a Linear layer, so it does not fire here.
    """
    for m in module.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


class UNet_Full(nn.Module):
    """2x upscaler: UNet1 (transposed-conv output) followed by UNet2, with a
    residual connection from UNet1's (trimmed) output to UNet2's output."""

    def __init__(self):
        super().__init__()
        self.unet1 = UNet1(3, 3, deconv=True)
        self.unet2 = UNet2(3, 3, deconv=False)

    def forward(self, x):
        """Upscale ``x`` by a factor of 2.

        Args:
            x: 4-D tensor unpacked as (n, c, h0, w0). Both U-Nets are built with
                3 input channels, so c is expected to be 3.

        Returns:
            Tensor with spatial size (2 * h0, 2 * w0).

        Note:
            Reflect padding requires each pad amount (18, or 19 on the
            bottom/right edge for odd sizes) to be smaller than the matching
            input dimension, so very small inputs make ``F.pad`` raise.
        """
        n, c, h0, w0 = x.shape
        # Round H and W up to the next even number so the stride-2 convs divide evenly.
        ph = ((h0 - 1) // 2 + 1) * 2
        pw = ((w0 - 1) // 2 + 1) * 2
        # 18 px reflect margin on every side, plus the rounding pixel (if any) on bottom/right.
        x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), 'reflect')
        x1 = self.unet1(x)
        x2 = self.unet2(x1)
        # The valid (unpadded) convs in unet2 shrink its output 40 px per dimension
        # relative to its input; trim 20 px per side so the residual add lines up.
        x1 = F.pad(x1, (-20, -20, -20, -20))
        output = torch.add(x2, x1)
        if w0 != pw or h0 != ph:
            # Drop the extra row/column introduced by the even-size rounding.
            output = output[:, :, :h0 * 2, :w0 * 2]
        return output


class SEBlock(nn.Module):
    """Squeeze-and-excitation channel gate.

    Global-average-pools each channel, passes the pooled vector through a
    1x1 conv -> ReLU -> 1x1 conv -> sigmoid bottleneck with
    ``in_channels // reduction`` hidden channels, and rescales the input
    channels by the result.
    """

    def __init__(self, in_channels, reduction=8, bias=False):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels // reduction, 1, 1, 0, bias=bias)
        self.conv2 = nn.Conv2d(in_channels // reduction, in_channels, 1, 1, 0, bias=bias)

    def forward(self, x):
        if x.dtype == torch.float16:
            # Pool in float32 and cast back (presumably for numerical accuracy in half precision).
            x0 = torch.mean(x.float(), dim=(2, 3), keepdim=True).half()
        else:
            x0 = torch.mean(x, dim=(2, 3), keepdim=True)
        x0 = self.conv1(x0)
        x0 = F.relu(x0, inplace=True)
        x0 = self.conv2(x0)
        x0 = torch.sigmoid(x0)
        x = torch.mul(x, x0)
        return x


class UNetConv(nn.Module):
    """Two unpadded 3x3 convs with LeakyReLU(0.1), optionally followed by an SEBlock.

    Each 3x3 conv has no padding, so the output is 4 px smaller than the
    input in each spatial dimension.
    """

    def __init__(self, in_channels, mid_channels, out_channels, se):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(mid_channels, out_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
        )
        if se:
            self.seblock = SEBlock(out_channels, reduction=8, bias=True)
        else:
            self.seblock = None

    def forward(self, x):
        z = self.conv(x)
        if self.seblock is not None:
            z = self.seblock(z)
        return z


class UNet1(nn.Module):
    """Shallow U-Net with one down/up level.

    With ``deconv=True`` the final layer is a stride-2 transposed conv,
    i.e. this stage performs the spatial upsampling.
    """

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        # Registration order is kept from upstream: it fixes the order in
        # which _init_weights visits layers and the state_dict key names.
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 128, 64, se=True)
        self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        if deconv:
            self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
        else:
            self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)
        _init_weights(self)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2(x2)
        x2 = self.conv2_up(x2)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        # Crop the skip connection to the size of the upsampled path.
        x1 = F.pad(x1, (-4, -4, -4, -4))
        x3 = self.conv3(x1 + x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        z = self.conv_bottom(x3)
        return z


class UNet2(nn.Module):
    """Deeper U-Net with two down/up levels and SE blocks on the inner stages."""

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        # Registration order is kept from upstream (init order / state_dict keys).
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 64, 128, se=True)
        self.conv2_down = nn.Conv2d(128, 128, 2, 2, 0)
        self.conv3 = UNetConv(128, 256, 128, se=True)
        self.conv3_up = nn.ConvTranspose2d(128, 128, 2, 2, 0)
        self.conv4 = UNetConv(128, 64, 64, se=True)
        self.conv4_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv5 = nn.Conv2d(64, 64, 3, 1, 0)
        if deconv:
            self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
        else:
            self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)
        _init_weights(self)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2(x2)
        x3 = self.conv2_down(x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        x3 = self.conv3(x3)
        x3 = self.conv3_up(x3)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        # Crop skip connections to match the (smaller) upsampled paths.
        x2 = F.pad(x2, (-4, -4, -4, -4))
        x4 = self.conv4(x2 + x3)
        x4 = self.conv4_up(x4)
        x4 = F.leaky_relu(x4, 0.1, inplace=True)
        x1 = F.pad(x1, (-16, -16, -16, -16))
        x5 = self.conv5(x1 + x4)
        x5 = F.leaky_relu(x5, 0.1, inplace=True)
        z = self.conv_bottom(x5)
        return z


def main():
    """Smoke test: build UNet_Full on CUDA, print its parameter count (in
    thousands) and the wall-clock time of one forward pass on a random
    1x3x180x180 input."""
    root_path = os.path.abspath('.')
    sys.path.append(root_path)
    # Imported only for its side effects ("Manage GPU to choose") -- presumably
    # selects the visible GPU; verify against the project's opt module.
    from opt import opt  # noqa: F401
    import time

    model = UNet_Full().cuda()
    model.eval()
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print(f"CuNet has param {pytorch_total_params//1000} K params")

    # Time a single inference pass. CUDA kernels run asynchronously, so
    # synchronize before reading the clock on both ends of the measurement.
    x = torch.randn((1, 3, 180, 180)).cuda()
    with torch.no_grad():
        torch.cuda.synchronize()
        start = time.perf_counter()
        x = model(x)
        torch.cuda.synchronize()
        total = time.perf_counter() - start
    print("output size is ", x.shape)
    print(total)


if __name__ == "__main__":
    main()