import os
from enum import IntEnum
from pathlib import Path
from tempfile import mktemp
from typing import IO, Dict, Type

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gradio import Interface, inputs, outputs

# Device string used for loading weights and running inference.
DEVICE = "cpu"
# Directory next to this file that is scanned for model weights.
WEIGHTS_PATH = Path(__file__).parent / "weights"
# Weight name (file name without its extension) -> path, for every regular
# file in WEIGHTS_PATH whose extension ends with "pth".
AVALIABLE_WEIGHTS = {
    path.stem: path
    for path in WEIGHTS_PATH.iterdir()
    if path.is_file() and path.suffix.endswith("pth")
}


class ScaleMode(IntEnum):
    """Supported upscaling factors.

    The member name is matched as a prefix of the weight name and the member
    value is the integer scale factor.
    """

    up2x = 2
    up3x = 3
    up4x = 4


class TileMode(IntEnum):
    """Tiling strategies; the value is passed as ``tile_mode`` to the models.

    0 runs the whole image at once, 1 halves the longer side, 2/3/4 split
    both height and width into halves/thirds/quarters.
    """

    full = 0
    half = 1
    quarter = 2
    ninth = 3
    sixteenth = 4


def _spatial_mean(t):
    """Return the mean of ``t`` over dims 2 and 3, keeping them as size 1.

    float16 tensors are averaged in float32 and cast back to float16.
    """
    if t.dtype == torch.float16:
        return torch.mean(t.float(), dim=(2, 3), keepdim=True).half()
    return torch.mean(t, dim=(2, 3), keepdim=True)


def _init_weights(net):
    """Initialise every conv / transposed-conv / linear layer inside ``net``."""
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


class SEBlock(nn.Module):
    """Channel gate: two 1x1 convolutions applied to a per-channel mean."""

    def __init__(self, in_channels, reduction=8, bias=False):
        super().__init__()
        squeezed = in_channels // reduction
        self.conv1 = nn.Conv2d(in_channels, squeezed, 1, 1, 0, bias=bias)
        self.conv2 = nn.Conv2d(squeezed, in_channels, 1, 1, 0, bias=bias)

    def forward(self, x):
        """Gate ``x`` using the spatial mean of ``x`` itself."""
        return self.forward_mean(x, _spatial_mean(x))

    def forward_mean(self, x, x0):
        """Gate ``x`` using an externally supplied per-channel mean ``x0``.

        The tiled inference path uses this to apply one mean, averaged over
        all tiles, to every tile.
        """
        x0 = self.conv1(x0)
        x0 = F.relu(x0, inplace=True)
        x0 = self.conv2(x0)
        x0 = torch.sigmoid(x0)
        return torch.mul(x, x0)


class UNetConv(nn.Module):
    """Two unpadded 3x3 conv + LeakyReLU(0.1) layers, optionally followed by an SEBlock."""

    def __init__(self, in_channels, mid_channels, out_channels, se):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(mid_channels, out_channels, 3, 1, 0),
            nn.LeakyReLU(0.1, inplace=True),
        )
        self.seblock = SEBlock(out_channels, reduction=8, bias=True) if se else None

    def forward(self, x):
        z = self.conv(x)
        if self.seblock is not None:
            z = self.seblock(z)
        return z


class _UNet1Base(nn.Module):
    """Shared implementation of the one-level U-Net variants.

    Subclasses only choose the transposed-convolution arguments of the output
    layer through ``_DECONV_ARGS``.
    """

    # (kernel_size, stride, padding) of ``conv_bottom`` when ``deconv`` is true.
    _DECONV_ARGS = None

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 128, 64, se=True)
        self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        if deconv:
            kernel_size, stride, padding = self._DECONV_ARGS
            self.conv_bottom = nn.ConvTranspose2d(
                64, out_channels, kernel_size, stride, padding
            )
        else:
            self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)
        _init_weights(self)

    def forward(self, x):
        """Run the whole network, using each input's own SE statistics."""
        x1, x2 = self.forward_a(x)
        x2 = self.conv2.seblock(x2)
        return self.forward_b(x1, x2)

    def forward_a(self, x):
        """First stage: return the skip features and the pre-SE conv2 features."""
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2.conv(x2)
        return x1, x2

    def forward_b(self, x1, x2):
        """Second stage: upsample the gated ``x2``, merge with ``x1``, project out."""
        x2 = self.conv2_up(x2)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        # Negative padding crops the skip features to the size of x2.
        x1 = F.pad(x1, (-4, -4, -4, -4))
        x3 = self.conv3(x1 + x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        return self.conv_bottom(x3)


class UNet1(_UNet1Base):
    """One-level U-Net whose deconv output layer uses kernel 4, stride 2, padding 3."""

    _DECONV_ARGS = (4, 2, 3)


class UNet1x3(_UNet1Base):
    """One-level U-Net whose deconv output layer uses kernel 5, stride 3, padding 2."""

    _DECONV_ARGS = (5, 3, 2)


class UNet2(nn.Module):
    """Two-level U-Net; conv2, conv3 and conv4 each end with an SEBlock."""

    def __init__(self, in_channels, out_channels, deconv):
        super().__init__()
        self.conv1 = UNetConv(in_channels, 32, 64, se=False)
        self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
        self.conv2 = UNetConv(64, 64, 128, se=True)
        self.conv2_down = nn.Conv2d(128, 128, 2, 2, 0)
        self.conv3 = UNetConv(128, 256, 128, se=True)
        self.conv3_up = nn.ConvTranspose2d(128, 128, 2, 2, 0)
        self.conv4 = UNetConv(128, 64, 64, se=True)
        self.conv4_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
        self.conv5 = nn.Conv2d(64, 64, 3, 1, 0)
        if deconv:
            self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
        else:
            self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)
        _init_weights(self)

    def forward(self, x):
        """Run the whole network, using each input's own SE statistics."""
        x1, x2 = self.forward_a(x)
        x2 = self.conv2.seblock(x2)
        x3 = self.forward_b(x2)
        x3 = self.conv3.seblock(x3)
        x4 = self.forward_c(x2, x3)
        x4 = self.conv4.seblock(x4)
        return self.forward_d(x1, x4)

    # The staged methods below stop right before each SEBlock so that the
    # tiled path can apply a mean shared by all tiles via forward_mean().

    def forward_a(self, x):
        """Return the level-1 skip features and the pre-SE conv2 features."""
        x1 = self.conv1(x)
        x2 = self.conv1_down(x1)
        x2 = F.leaky_relu(x2, 0.1, inplace=True)
        x2 = self.conv2.conv(x2)
        return x1, x2

    def forward_b(self, x2):
        """Downsample the gated ``x2`` and return the pre-SE conv3 features."""
        x3 = self.conv2_down(x2)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        return self.conv3.conv(x3)

    def forward_c(self, x2, x3):
        """Upsample the gated ``x3``, merge with ``x2``, return pre-SE conv4 features."""
        x3 = self.conv3_up(x3)
        x3 = F.leaky_relu(x3, 0.1, inplace=True)
        x2 = F.pad(x2, (-4, -4, -4, -4))
        return self.conv4.conv(x2 + x3)

    def forward_d(self, x1, x4):
        """Upsample the gated ``x4``, merge with ``x1`` and produce the output."""
        x4 = self.conv4_up(x4)
        x4 = F.leaky_relu(x4, 0.1, inplace=True)
        x1 = F.pad(x1, (-16, -16, -16, -16))
        x5 = self.conv5(x1 + x4)
        x5 = F.leaky_relu(x5, 0.1, inplace=True)
        return self.conv_bottom(x5)


class UpCunet2x(nn.Module):
    """2x upscaler: UNet1 followed by UNet2 whose output is added as a residual.

    Original author's note: tiling is "perfect", i.e. lossless throughout.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.unet1 = UNet1(in_channels, out_channels, deconv=True)
        self.unet2 = UNet2(in_channels, out_channels, deconv=False)

    @staticmethod
    def _ceil_to(value, multiple):
        """Round ``value`` up to the nearest multiple of ``multiple``."""
        return ((value - 1) // multiple + 1) * multiple

    def _crop_size(self, h0, w0, tile_mode):
        """Return the (height, width) of one tile for ``tile_mode`` 1..4.

        Each tile side is even: the image side is first rounded up to a
        multiple of ``2 * divisor`` and then divided by ``divisor``.

        Raises:
            ValueError: if ``tile_mode`` is not one of 1, 2, 3, 4.
        """
        if tile_mode == 1:
            # Halve the longer side only.
            div_h, div_w = (1, 2) if w0 >= h0 else (2, 1)
        elif tile_mode in (2, 3, 4):
            # Split both sides into halves / thirds / quarters.
            div_h = div_w = int(tile_mode)
        else:
            raise ValueError(f"Unsupported tile_mode: {tile_mode!r}")
        return (
            self._ceil_to(h0, 2 * div_h) // div_h,
            self._ceil_to(w0, 2 * div_w) // div_w,
        )

    def forward(self, x, tile_mode):
        """Upscale ``x`` by 2.

        Args:
            x: 4-D input tensor; unpacked as (n, c, h0, w0).
            tile_mode: 0 processes the whole image in one pass; 1..4 select
                progressively smaller tiles (see ``_crop_size``).

        Returns:
            Tensor with spatial size (h0 * 2, w0 * 2).

        Raises:
            ValueError: if ``tile_mode`` is not in 0..4.
        """
        n, c, h0, w0 = x.shape
        if tile_mode == 0:  # no tiling
            # Pad height and width up to a multiple of 2, plus an 18 px border.
            ph = self._ceil_to(h0, 2)
            pw = self._ceil_to(w0, 2)
            x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), "reflect")
            x = self.unet1.forward(x)
            x0 = self.unet2.forward(x)
            x1 = F.pad(x, (-20, -20, -20, -20))
            x = torch.add(x0, x1)
            if w0 != pw or h0 != ph:
                x = x[:, :, : h0 * 2, : w0 * 2]
            return x
        return self._forward_tiled(x, self._crop_size(h0, w0, tile_mode))

    def _forward_tiled(self, x, crop_size):
        """Run the network tile by tile with SE means shared across all tiles.

        The pipeline is executed in five passes over the tiles. Each pass
        stops right before an SEBlock, averages that block's input statistic
        over every tile, and the next pass applies the averaged statistic.
        """
        n, c, h0, w0 = x.shape
        crop_h, crop_w = crop_size
        ph = self._ceil_to(h0, crop_h)
        pw = self._ceil_to(w0, crop_w)
        x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), "reflect")
        n, c, h, w = x.shape

        # Top-left corner of every tile, in row-major order.
        origins = [
            (i, j)
            for i in range(0, h - 36, crop_h)
            for j in range(0, w - 36, crop_w)
        ]
        n_patch = len(origins)

        def zeros(channels):
            return torch.zeros((n, channels, 1, 1), device=x.device, dtype=x.dtype)

        tiles = {}

        # Pass 1: first half of unet1; collect the mean for unet1.conv2's SE.
        mean_u1c2 = zeros(64)
        for i, j in origins:
            # Each tile carries a 36 px overlap (18 px border on every side).
            patch = x[:, :, i : i + crop_h + 36, j : j + crop_w + 36]
            skip, feat = self.unet1.forward_a(patch)
            mean_u1c2 += _spatial_mean(feat)
            tiles[i, j] = (skip, feat)
        mean_u1c2 /= n_patch

        # Pass 2: finish unet1, start unet2; collect the mean for unet2.conv2's SE.
        mean_u2c2 = zeros(128)
        for key in origins:
            skip, feat = tiles[key]
            feat = self.unet1.conv2.seblock.forward_mean(feat, mean_u1c2)
            unet1_out = self.unet1.forward_b(skip, feat)
            x1, x2 = self.unet2.forward_a(unet1_out)
            mean_u2c2 += _spatial_mean(x2)
            tiles[key] = (unet1_out, x1, x2)
        mean_u2c2 /= n_patch

        # Pass 3: collect the mean for unet2.conv3's SE.
        mean_u2c3 = zeros(128)
        for key in origins:
            unet1_out, x1, x2 = tiles[key]
            x2 = self.unet2.conv2.seblock.forward_mean(x2, mean_u2c2)
            x3 = self.unet2.forward_b(x2)
            mean_u2c3 += _spatial_mean(x3)
            tiles[key] = (unet1_out, x1, x2, x3)
        mean_u2c3 /= n_patch

        # Pass 4: collect the mean for unet2.conv4's SE.
        mean_u2c4 = zeros(64)
        for key in origins:
            unet1_out, x1, x2, x3 = tiles[key]
            x3 = self.unet2.conv3.seblock.forward_mean(x3, mean_u2c3)
            x4 = self.unet2.forward_c(x2, x3)
            mean_u2c4 += _spatial_mean(x4)
            tiles[key] = (unet1_out, x1, x4)
        mean_u2c4 /= n_patch

        # Pass 5: finish unet2 and add the cropped unet1 output as a residual.
        upscaled = {}
        for key in origins:
            unet1_out, x1, x4 = tiles[key]
            x4 = self.unet2.conv4.seblock.forward_mean(x4, mean_u2c4)
            x0 = self.unet2.forward_d(x1, x4)  # final output of unet2
            residual = F.pad(unet1_out, (-20, -20, -20, -20))
            upscaled[key] = torch.add(x0, residual)
        del tiles
        torch.cuda.empty_cache()

        # Stitch the upscaled tiles back together.
        res = torch.zeros(
            (n, c, h * 2 - 72, w * 2 - 72), device=x.device, dtype=x.dtype
        )
        for (i, j), tile in upscaled.items():
            res[
                :, :, i * 2 : i * 2 + crop_h * 2, j * 2 : j * 2 + crop_w * 2
            ] = tile
        del upscaled
        torch.cuda.empty_cache()

        if w0 != pw or h0 != ph:
            res = res[:, :, : h0 * 2, : w0 * 2]
        return res
def _tile_mean(t):
    """Return the mean of ``t`` over dims 2 and 3, keeping them as size 1.

    float16 tensors are averaged in float32 and cast back to float16.
    """
    if t.dtype == torch.float16:
        return torch.mean(t.float(), dim=(2, 3), keepdim=True).half()
    return torch.mean(t, dim=(2, 3), keepdim=True)


def _round_up(value, multiple):
    """Round ``value`` up to the nearest multiple of ``multiple``."""
    return ((value - 1) // multiple + 1) * multiple


def _tile_crop_size(h0, w0, tile_mode, unit):
    """Return the (height, width) of one tile for ``tile_mode`` 1..4.

    Each tile side is a multiple of ``unit``: the image side is rounded up to
    a multiple of ``unit * divisor`` and then divided by ``divisor``.

    Raises:
        ValueError: if ``tile_mode`` is not one of 1, 2, 3, 4.
    """
    if tile_mode == 1:
        # Halve the longer side only.
        div_h, div_w = (1, 2) if w0 >= h0 else (2, 1)
    elif tile_mode in (2, 3, 4):
        # Split both sides into halves / thirds / quarters.
        div_h = div_w = int(tile_mode)
    else:
        raise ValueError(f"Unsupported tile_mode: {tile_mode!r}")
    return (
        _round_up(h0, unit * div_h) // div_h,
        _round_up(w0, unit * div_w) // div_w,
    )


def _tiled_cunet_forward(unet1, unet2, x, crop_size, pad, scale, head=None):
    """Run ``unet1`` then ``unet2`` tile by tile with SE means shared by all tiles.

    The pipeline is executed in five passes over the tiles. Each pass stops
    right before an SEBlock, averages that block's input statistic over every
    tile, and the next pass applies the averaged statistic.

    Args:
        unet1, unet2: the two sub-networks (staged ``forward_a..d`` API).
        x: 4-D input tensor, not yet padded.
        crop_size: (height, width) of one tile in input pixels.
        pad: reflect-padding border in input pixels; tiles overlap by 2 * pad.
        scale: integer upscaling factor of the assembled output.
        head: optional callable applied to each tile after the residual add.

    Returns:
        Tensor with spatial size (h0 * scale, w0 * scale).
    """
    n, c, h0, w0 = x.shape
    crop_h, crop_w = crop_size
    margin = 2 * pad
    ph = _round_up(h0, crop_h)
    pw = _round_up(w0, crop_w)
    x = F.pad(x, (pad, pad + pw - w0, pad, pad + ph - h0), "reflect")
    n, c, h, w = x.shape

    # Top-left corner of every tile, in row-major order.
    origins = [
        (i, j)
        for i in range(0, h - margin, crop_h)
        for j in range(0, w - margin, crop_w)
    ]
    n_patch = len(origins)

    def zeros(channels):
        return torch.zeros((n, channels, 1, 1), device=x.device, dtype=x.dtype)

    tiles = {}

    # Pass 1: first half of unet1; collect the mean for unet1.conv2's SE.
    mean_u1c2 = zeros(64)
    for i, j in origins:
        patch = x[:, :, i : i + crop_h + margin, j : j + crop_w + margin]
        skip, feat = unet1.forward_a(patch)
        mean_u1c2 += _tile_mean(feat)
        tiles[i, j] = (skip, feat)
    mean_u1c2 /= n_patch

    # Pass 2: finish unet1, start unet2; collect the mean for unet2.conv2's SE.
    mean_u2c2 = zeros(128)
    for key in origins:
        skip, feat = tiles[key]
        feat = unet1.conv2.seblock.forward_mean(feat, mean_u1c2)
        unet1_out = unet1.forward_b(skip, feat)
        x1, x2 = unet2.forward_a(unet1_out)
        mean_u2c2 += _tile_mean(x2)
        tiles[key] = (unet1_out, x1, x2)
    mean_u2c2 /= n_patch

    # Pass 3: collect the mean for unet2.conv3's SE.
    mean_u2c3 = zeros(128)
    for key in origins:
        unet1_out, x1, x2 = tiles[key]
        x2 = unet2.conv2.seblock.forward_mean(x2, mean_u2c2)
        x3 = unet2.forward_b(x2)
        mean_u2c3 += _tile_mean(x3)
        tiles[key] = (unet1_out, x1, x2, x3)
    mean_u2c3 /= n_patch

    # Pass 4: collect the mean for unet2.conv4's SE.
    mean_u2c4 = zeros(64)
    for key in origins:
        unet1_out, x1, x2, x3 = tiles[key]
        x3 = unet2.conv3.seblock.forward_mean(x3, mean_u2c3)
        x4 = unet2.forward_c(x2, x3)
        mean_u2c4 += _tile_mean(x4)
        tiles[key] = (unet1_out, x1, x4)
    mean_u2c4 /= n_patch

    # Pass 5: finish unet2 and add the cropped unet1 output as a residual.
    upscaled = {}
    for key in origins:
        unet1_out, x1, x4 = tiles[key]
        x4 = unet2.conv4.seblock.forward_mean(x4, mean_u2c4)
        x0 = unet2.forward_d(x1, x4)  # final output of unet2
        residual = F.pad(unet1_out, (-20, -20, -20, -20))
        tile = torch.add(x0, residual)
        if head is not None:
            tile = head(tile)
        upscaled[key] = tile
    del tiles
    torch.cuda.empty_cache()

    # Stitch the upscaled tiles back together.
    res = torch.zeros(
        (n, c, (h - margin) * scale, (w - margin) * scale),
        device=x.device,
        dtype=x.dtype,
    )
    for (i, j), tile in upscaled.items():
        top, left = i * scale, j * scale
        res[:, :, top : top + crop_h * scale, left : left + crop_w * scale] = tile
    del upscaled
    torch.cuda.empty_cache()

    if w0 != pw or h0 != ph:
        res = res[:, :, : h0 * scale, : w0 * scale]
    return res


class UpCunet3x(nn.Module):
    """3x upscaler: UNet1x3 followed by UNet2 whose output is added as a residual.

    Original author's note: tiling is "perfect", i.e. lossless throughout.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.unet1 = UNet1x3(in_channels, out_channels, deconv=True)
        self.unet2 = UNet2(in_channels, out_channels, deconv=False)

    def forward(self, x, tile_mode):
        """Upscale ``x`` by 3.

        Args:
            x: 4-D input tensor; unpacked as (n, c, h0, w0).
            tile_mode: 0 processes the whole image in one pass; 1..4 select
                progressively smaller tiles.

        Raises:
            ValueError: if ``tile_mode`` is not in 0..4.
        """
        n, c, h0, w0 = x.shape
        if tile_mode == 0:  # no tiling
            # Pad height and width up to a multiple of 4, plus a 14 px border.
            ph = _round_up(h0, 4)
            pw = _round_up(w0, 4)
            x = F.pad(x, (14, 14 + pw - w0, 14, 14 + ph - h0), "reflect")
            x = self.unet1.forward(x)
            x0 = self.unet2.forward(x)
            x1 = F.pad(x, (-20, -20, -20, -20))
            x = torch.add(x0, x1)
            if w0 != pw or h0 != ph:
                x = x[:, :, : h0 * 3, : w0 * 3]
            return x
        crop_size = _tile_crop_size(h0, w0, tile_mode, unit=4)
        return _tiled_cunet_forward(
            self.unet1, self.unet2, x, crop_size, pad=14, scale=3
        )


class UpCunet4x(nn.Module):
    """4x upscaler: UNet1 + residual UNet2 on 64 channels, then conv + PixelShuffle.

    The nearest-neighbour 4x upsampling of the input is added to the result.
    Original author's note: tiling is "perfect", i.e. lossless throughout.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.unet1 = UNet1(in_channels, 64, deconv=True)
        self.unet2 = UNet2(64, 64, deconv=False)
        self.ps = nn.PixelShuffle(2)
        self.conv_final = nn.Conv2d(64, 12, 3, 1, padding=0, bias=True)

    def _head(self, t):
        """Project to 12 channels, crop a 1 px border and pixel-shuffle by 2."""
        t = self.conv_final(t)
        t = F.pad(t, (-1, -1, -1, -1))
        return self.ps(t)

    def forward(self, x, tile_mode):
        """Upscale ``x`` by 4.

        Args:
            x: 4-D input tensor; unpacked as (n, c, h0, w0).
            tile_mode: 0 processes the whole image in one pass; 1..4 select
                progressively smaller tiles.

        Raises:
            ValueError: if ``tile_mode`` is not in 0..4.
        """
        n, c, h0, w0 = x.shape
        x00 = x  # untouched input for the nearest-neighbour skip connection
        if tile_mode == 0:  # no tiling
            # Pad height and width up to a multiple of 2, plus a 19 px border.
            ph = _round_up(h0, 2)
            pw = _round_up(w0, 2)
            x = F.pad(x, (19, 19 + pw - w0, 19, 19 + ph - h0), "reflect")
            x = self.unet1.forward(x)
            x0 = self.unet2.forward(x)
            x1 = F.pad(x, (-20, -20, -20, -20))
            x = self._head(torch.add(x0, x1))
            if w0 != pw or h0 != ph:
                x = x[:, :, : h0 * 4, : w0 * 4]
            x += F.interpolate(x00, scale_factor=4, mode="nearest")
            return x
        crop_size = _tile_crop_size(h0, w0, tile_mode, unit=2)
        res = _tiled_cunet_forward(
            self.unet1, self.unet2, x, crop_size, pad=19, scale=4, head=self._head
        )
        res += F.interpolate(x00, scale_factor=4, mode="nearest")
        return res


# Class name -> class, for every nn.Module subclass defined or imported above.
# The globals are snapshotted first so the mapping is built from a stable view.
models: Dict[str, Type[nn.Module]] = {
    obj.__name__: obj
    for obj in list(globals().values())
    if isinstance(obj, type) and issubclass(obj, nn.Module)
}


class RealWaifuUpScaler:
    """Load an ``UpCunet{scale}x`` model and upscale HxWx3 uint8 frames with it."""

    def __init__(self, scale: int, weight_path: str, half: bool, device: str):
        """Build the model for ``scale`` and load its weights.

        Args:
            scale: upscaling factor; selects the class ``UpCunet{scale}x``.
            weight_path: path of the state-dict file to load.
            half: run the model in float16 when true.
            device: torch device string.
        """
        # NOTE: torch.load unpickles the file; only load trusted weight files.
        weight = torch.load(weight_path, map_location=device)
        model = models[f"UpCunet{scale}x"]()
        if half:
            model = model.half()
        self.model = model.to(device)
        self.model.load_state_dict(weight, strict=True)
        self.model.eval()
        self.half = half
        self.device = device

    def np2tensor(self, np_frame):
        """Convert an HxWxC array into a 1xCxHxW tensor scaled by 1/255."""
        tensor = (
            torch.from_numpy(np.transpose(np_frame, (2, 0, 1)))
            .unsqueeze(0)
            .to(self.device)
        )
        tensor = tensor.half() if self.half else tensor.float()
        return tensor / 255

    def tensor2np(self, tensor):
        """Convert a 1xCxHxW tensor into an HxWxC uint8 array (x255, clamped)."""
        # Only the batch dimension is dropped, so size-1 spatial dims survive.
        # .float() is a no-op for float32 and avoids rounding in float16.
        scaled = tensor.data.squeeze(0).float() * 255.0
        array = scaled.round().clamp_(0, 255).byte().cpu().numpy()
        return np.transpose(array, (1, 2, 0))

    def __call__(self, frame, tile_mode):
        """Upscale one frame without tracking gradients."""
        with torch.no_grad():
            tensor = self.np2tensor(frame)
            result = self.tensor2np(self.model(tensor, tile_mode))
        return result


input_image = inputs.File(label="Input image")
half_precision = inputs.Checkbox(
    label="Half precision (NOT work for CPU)", default=False
)
model_weight = inputs.Dropdown(sorted(AVALIABLE_WEIGHTS), label="Choice model weight")
tile_mode = inputs.Radio([mode.name for mode in TileMode], label="Output tile mode")
output_image = outputs.Image(label="Output image preview")
output_file = outputs.File(label="Output image file")


def main(file: IO[bytes], half: bool, weight: str, tile: str):
    """Gradio callback: upscale the uploaded image with the chosen weights.

    Args:
        file: uploaded file object; only ``file.name`` (its path) is used.
        half: run the model in float16.
        weight: key of ``AVALIABLE_WEIGHTS``; must start with a ``ScaleMode``
            member name, which determines the scale factor.
        tile: name of a ``TileMode`` member.

    Returns:
        ``(result, path)``: the upscaled RGB image array for the preview and
        the path of a temporary file holding the same image.

    Raises:
        ValueError: if the weight name has no known scale prefix or the
            uploaded file cannot be decoded as an image.
        RuntimeError: if the upscaled image cannot be written.
    """
    # Local import: the file header only imports the race-prone ``mktemp``.
    from tempfile import mkstemp

    scale = next(
        (mode.value for mode in ScaleMode if weight.startswith(mode.name)), None
    )
    if scale is None:
        raise ValueError(f"Cannot infer the scale factor from weight name {weight!r}")

    upscaler = RealWaifuUpScaler(
        scale, weight_path=str(AVALIABLE_WEIGHTS[weight]), half=half, device=DEVICE
    )

    frame = cv2.imread(file.name)
    if frame is None:
        # cv2.imread signals an unreadable or unsupported file by returning None.
        raise ValueError(f"Cannot read image file: {file.name}")

    # OpenCV decodes to BGR; the model is fed RGB.
    result = upscaler(frame[:, :, [2, 1, 0]], TileMode[tile])

    _, ext = os.path.splitext(file.name)
    # cv2.imwrite picks the encoder from the extension, so one is required.
    fd, out_path = mkstemp(suffix=ext or ".png")
    os.close(fd)
    # ``result`` is RGB, while cv2.imwrite expects BGR: swap the channels back
    # so the saved file has the same colours as the preview.
    if not cv2.imwrite(out_path, np.ascontiguousarray(result[:, :, ::-1])):
        raise RuntimeError(f"Failed to write the upscaled image to {out_path}")
    return result, out_path


interface = Interface(
    main,
    inputs=[input_image, half_precision, model_weight, tile_mode],
    outputs=[output_image, output_file],
)
interface.launch()