test / modules /postprocess /realesrgan_model_arch.py
bilegentile's picture
Upload folder using huggingface_hub
c19ca42 verified
import os
import math
import queue
import threading
import cv2
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn, TimeElapsedColumn
from modules import devices
from modules.shared import log, console
from modules.upscaler import compile_upscaler
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class RealESRGANer():
"""A helper class for upsampling images with RealESRGAN.
Args:
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
model (nn.Module): The defined network. Default: None.
tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
input images into tiles, and then process each of them. Finally, they will be merged into one image.
0 denotes for do not use tile. Default: 0.
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
half (float): Whether to use half precision during inference. Default: False.
"""
def __init__(self,
name,
scale,
model_path,
dni_weight=None,
model=None,
tile=0,
tile_pad=10,
pre_pad=10,
half=False,
device=None,
gpu_id=None):
self.name = name
self.scale = scale
self.tile_size = tile
self.tile_pad = tile_pad
self.pre_pad = pre_pad
self.mod_scale = None
self.half = half
# initialize model
if gpu_id:
self.device = torch.device(
f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
else:
self.device = devices.device_esrgan if device is None else device
if isinstance(model_path, list):
# dni
assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.'
loadnet = self.dni(model_path[0], model_path[1], dni_weight)
else:
# if the model_path starts with https, it will first download models to the folder: weights
if model_path.startswith('https://'):
from modules.modelloader import load_file_from_url
model_path = load_file_from_url(url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
loadnet = torch.load(model_path, map_location=torch.device('cpu'))
log.info(f"Upscaler loaded: type={self.name} model={model_path}")
# prefer to use params_ema
if 'params_ema' in loadnet:
keyname = 'params_ema'
else:
keyname = 'params'
model.load_state_dict(loadnet[keyname], strict=True)
model.eval()
if self.half:
model = model.half()
self.model = model.to(self.device)
self.model = compile_upscaler(self.model)
def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
"""Deep network interpolation.
``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``
"""
net_a = torch.load(net_a, map_location=torch.device(loc))
net_b = torch.load(net_b, map_location=torch.device(loc))
for k, v_a in net_a[key].items():
net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
return net_a
def pre_process(self, img):
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible
"""
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
self.img = img.unsqueeze(0).to(self.device)
if self.half:
self.img = self.img.half()
# pre_pad
if self.pre_pad != 0:
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
# mod pad for divisible borders
if self.scale == 2:
self.mod_scale = 2
elif self.scale == 1:
self.mod_scale = 4
if self.mod_scale is not None:
self.mod_pad_h, self.mod_pad_w = 0, 0
_, _, h, w = self.img.size()
if (h % self.mod_scale != 0):
self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
if (w % self.mod_scale != 0):
self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
def process(self):
# model inference
self.output = self.model(self.img)
def tile_process(self):
"""It will first crop input images to tiles, and then process each tile.
Finally, all the processed tiles are merged into one images.
Modified from: https://github.com/ata4/esrgan-launcher
"""
batch, channel, height, width = self.img.shape
output_height = height * self.scale
output_width = width * self.scale
output_shape = (batch, channel, output_height, output_width)
# start with black image
self.output = self.img.new_zeros(output_shape)
tiles_x = math.ceil(width / self.tile_size)
tiles_y = math.ceil(height / self.tile_size)
# loop over all tiles
with Progress(TextColumn('[cyan]{task.description}'), BarColumn(), TaskProgressColumn(), TimeRemainingColumn(), TimeElapsedColumn(), console=console) as progress:
task = progress.add_task(description="Upscaling", total=tiles_y * tiles_x)
with torch.no_grad():
for y in range(tiles_y):
for x in range(tiles_x):
# extract tile from input image
ofs_x = x * self.tile_size
ofs_y = y * self.tile_size
# input tile area on total image
input_start_x = ofs_x
input_end_x = min(ofs_x + self.tile_size, width)
input_start_y = ofs_y
input_end_y = min(ofs_y + self.tile_size, height)
# input tile area on total image with padding
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
input_end_x_pad = min(input_end_x + self.tile_pad, width)
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
input_end_y_pad = min(input_end_y + self.tile_pad, height)
# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
tile_idx = y * tiles_x + x + 1 # noqa
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
# upscale tile
try:
output_tile = self.model(input_tile)
except Exception as e:
log.error(f'Upscale error: type=R-ESRGAN {e}')
# output tile area on total image
output_start_x = input_start_x * self.scale
output_end_x = input_end_x * self.scale
output_start_y = input_start_y * self.scale
output_end_y = input_end_y * self.scale
# output tile area without padding
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
# put tile into output image
self.output[:, :, output_start_y:output_end_y,
output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
output_start_x_tile:output_end_x_tile]
progress.update(task, advance=1, description="Upscaling")
def post_process(self):
# remove extra pad
if self.mod_scale is not None:
_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
# remove prepad
if self.pre_pad != 0:
_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
return self.output
@torch.no_grad()
def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
h_input, w_input = img.shape[0:2]
# img: numpy
img = img.astype(np.float32)
if np.max(img) > 256: # 16-bit image
max_range = 65535
else:
max_range = 255
img = img / max_range
if len(img.shape) == 2: # gray image
img_mode = 'L'
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
elif img.shape[2] == 4: # RGBA image with alpha channel
img_mode = 'RGBA'
alpha = img[:, :, 3]
img = img[:, :, 0:3]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if alpha_upsampler == 'realesrgan':
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
else:
img_mode = 'RGB'
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# ------------------- process image (without the alpha channel) ------------------- #
self.pre_process(img)
if self.tile_size > 0:
self.tile_process()
else:
self.process()
output_img = self.post_process()
output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
if img_mode == 'L':
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
# ------------------- process the alpha channel if necessary ------------------- #
if img_mode == 'RGBA':
if alpha_upsampler == 'realesrgan':
self.pre_process(alpha)
if self.tile_size > 0:
self.tile_process()
else:
self.process()
output_alpha = self.post_process()
output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
else: # use the cv2 resize for alpha channel
h, w = alpha.shape[0:2]
output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
# merge the alpha channel
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
output_img[:, :, 3] = output_alpha
# ------------------------------ return ------------------------------ #
if max_range == 65535: # 16-bit image
output = (output_img * 65535.0).round().astype(np.uint16)
else:
output = (output_img * 255.0).round().astype(np.uint8)
if outscale is not None and outscale != float(self.scale):
output = cv2.resize(
output, (
int(w_input * outscale),
int(h_input * outscale),
), interpolation=cv2.INTER_LANCZOS4)
return output, img_mode
class PrefetchReader(threading.Thread):
"""Prefetch images.
Args:
img_list (list[str]): A image list of image paths to be read.
num_prefetch_queue (int): Number of prefetch queue.
"""
def __init__(self, img_list, num_prefetch_queue):
super().__init__()
self.que = queue.Queue(num_prefetch_queue)
self.img_list = img_list
def run(self):
for img_path in self.img_list:
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
self.que.put(img)
self.que.put(None)
def __next__(self):
next_item = self.que.get()
if next_item is None:
raise StopIteration
return next_item
def __iter__(self):
return self
class IOConsumer(threading.Thread):
def __init__(self, opt, que, qid):
super().__init__()
self._queue = que
self.qid = qid
self.opt = opt
def run(self):
while True:
msg = self._queue.get()
if isinstance(msg, str) and msg == 'quit':
break
output = msg['output']
save_path = msg['save_path']
cv2.imwrite(save_path, output)
class SRVGGNetCompact(nn.Module):
"""A compact VGG-style network structure for super-resolution.
It is a compact network structure, which performs upsampling in the last layer and no convolution is
conducted on the HR feature space.
Args:
num_in_ch (int): Channel number of inputs. Default: 3.
num_out_ch (int): Channel number of outputs. Default: 3.
num_feat (int): Channel number of intermediate features. Default: 64.
num_conv (int): Number of convolution layers in the body network. Default: 16.
upscale (int): Upsampling factor. Default: 4.
act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
"""
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
super(SRVGGNetCompact, self).__init__()
self.num_in_ch = num_in_ch
self.num_out_ch = num_out_ch
self.num_feat = num_feat
self.num_conv = num_conv
self.upscale = upscale
self.act_type = act_type
self.body = nn.ModuleList()
# the first conv
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
# the first activation
if act_type == 'relu':
activation = nn.ReLU(inplace=True)
elif act_type == 'prelu':
activation = nn.PReLU(num_parameters=num_feat)
elif act_type == 'leakyrelu':
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.body.append(activation)
# the body structure
for _ in range(num_conv):
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
# activation
if act_type == 'relu':
activation = nn.ReLU(inplace=True)
elif act_type == 'prelu':
activation = nn.PReLU(num_parameters=num_feat)
elif act_type == 'leakyrelu':
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.body.append(activation)
# the last conv
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
# upsample
self.upsampler = nn.PixelShuffle(upscale)
def forward(self, x):
out = x
for i in range(0, len(self.body)):
out = self.body[i](out)
out = self.upsampler(out)
# add the nearest upsampled image, so that the network learns the residual
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
out += base
return out