###########################################################################################
# Filename: realesrgan.py
# Description: Upscale images using a pretrained Real-ESRGAN model.
###########################################################################################
#
# Import libraries.
#
# Import OpenCV for image processing.
import cv2
# Import NumPy for numerical operations on arrays.
import numpy as np
# Import the os module for filesystem paths.
import os
# Import PyTorch for deep learning.
import torch
# Import a utility function for downloading pretrained weights.
from basicsr.utils.download_util import load_file_from_url
# Import the functional module from PyTorch's neural network library.
from torch.nn import functional as F
###########################################################################################
# Define the root directory of the project (one level above this file).
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
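# For example (hypothetical layout): if this file lives at /project/src/realesrgan.py,
# ROOT_DIR resolves to /project, so downloaded weights land in /project/weights.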
###########################################################################################
class RealEsrGan:
    def __init__(
        self,
        scale, # Upsampling scale factor used in the network.
        model_path, # The path (or URL) to the pretrained model weights.
        dni_weight=None, # Blend weights for interpolating between two networks.
        model=None, # The network architecture to load the pretrained weights into.
        pre_pad=10, # Pad the input images to avoid border artifacts.
        half=False, # Whether to use half precision during inference.
        device=None, # Device to run inference on: cpu or cuda.
        gpu_id=None, # ID of the GPU to use when more than one is available.
    ):
        self.scale = scale
        self.model_path = model_path
        self.dni_weight = dni_weight
        self.model = model
        self.pre_pad = pre_pad
        self.half = half
        self.device = device
        self.gpu_id = gpu_id
        self.mod_scale = None
        # Initialize the device from the user's preference, defaulting to CUDA when available.
        if self.device is None:
            if self.gpu_id is not None and torch.cuda.is_available():
                self.device = torch.device(f"cuda:{self.gpu_id}")
            else:
                self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load the RealESRGAN weights from the specified path(s) or URL.
        if isinstance(self.model_path, list):
            # Two model paths: interpolate between the two networks.
            assert len(self.model_path) == len(self.dni_weight)
            loadnet = self.dni(self.model_path[0], self.model_path[1], self.dni_weight)
        else:
            # Download the weights first if the model path is a URL.
            if self.model_path.startswith("https://"):
                self.model_path = load_file_from_url(
                    url=self.model_path,
                    model_dir=os.path.join(ROOT_DIR, "weights"),
                    progress=True,
                    file_name=None,
                )
            loadnet = torch.load(self.model_path, map_location=torch.device("cpu"))
        # Use params_ema if available, otherwise fall back to params.
        if "params_ema" in loadnet:
            keyname = "params_ema"
        else:
            keyname = "params"
        # Load the weights into the network.
        model.load_state_dict(loadnet[keyname], strict=True)
        # Put the model in evaluation mode.
        model.eval()
        # Move the model to the selected device.
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()
    def dni(self, net_a, net_b, dni_weight, key="params", loc="cpu"):
        # Deep Network Interpolation (DNI): blend the parameters of two networks.
        # Load the parameters of network A, mapped to the specified device.
        net_a = torch.load(net_a, map_location=torch.device(loc))
        # Load the parameters of network B, mapped to the specified device.
        net_b = torch.load(net_b, map_location=torch.device(loc))
        # Interpolate each parameter of network A with its counterpart in network B.
        for k, v_a in net_a[key].items():
            net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
        # Return the interpolated state dict.
        return net_a
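    # A minimal usage sketch for dni (the weight paths are hypothetical): blending a
    # model with a variant of itself, where dni_weight=[0.9, 0.1] keeps 90% of the
    # first network's parameters:
    #   loadnet = self.dni("weights/model_a.pth", "weights/model_b.pth", [0.9, 0.1])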
    def pre_process(self, img):
        # Convert the image to a PyTorch tensor in CHW order.
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        # Add a batch dimension and move the tensor to the selected device.
        self.img = img.unsqueeze(0).to(self.device)
        # If half precision is enabled, convert the tensor to half precision.
        if self.half:
            self.img = self.img.half()
        # Apply reflective padding to the image if pre_pad is not zero.
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
        # Set mod_scale based on the scale factor.
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        # If mod_scale is set, pad so height and width are divisible by it.
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            # Calculate the padding required to make each dimension divisible by mod_scale.
            if h % self.mod_scale != 0:
                self.mod_pad_h = self.mod_scale - h % self.mod_scale
            if w % self.mod_scale != 0:
                self.mod_pad_w = self.mod_scale - w % self.mod_scale
            # Apply reflective padding on the bottom and right edges.
            self.img = F.pad(
                self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect"
            )
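    # Worked example (assuming scale=2, hence mod_scale=2): if the tensor entering the
    # divisibility step is 301x499, then mod_pad_h = 2 - (301 % 2) = 1 and
    # mod_pad_w = 2 - (499 % 2) = 1, yielding a 302x500 tensor whose sides divide
    # evenly by 2; post_process later crops the extra rows and columns back off.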
    def process(self):
        # Run the model on the pre-processed image; gradients are not needed at inference.
        with torch.no_grad():
            self.output = self.model(self.img)
    def post_process(self):
        # If a modification scale was used, crop away the divisibility padding.
        if self.mod_scale is not None:
            # Get the height and width of the output tensor.
            _, _, h, w = self.output.size()
            # Crop the output tensor by the (upscaled) modification padding.
            self.output = self.output[
                :,
                :,
                0 : h - self.mod_pad_h * self.scale,
                0 : w - self.mod_pad_w * self.scale,
            ]
        # If pre-padding was applied, crop it away as well.
        if self.pre_pad != 0:
            # Get the height and width of the output tensor.
            _, _, h, w = self.output.size()
            # Crop the output tensor by the (upscaled) pre-padding.
            self.output = self.output[
                :,
                :,
                0 : h - self.pre_pad * self.scale,
                0 : w - self.pre_pad * self.scale,
            ]
        # Return the cropped output tensor.
        return self.output
    def enhance(self, img, upscale=None, alpha_upsampler="realesrgan"):
        # Get the height and width of the input image.
        h_input, w_input = img.shape[0:2]
        img = img.astype(np.float32)
        # Treat the input as 16-bit if its values exceed the 8-bit range (uint8 maxes out at 255).
        if np.max(img) > 256:
            max_range = 65535
            print("\tInput is a 16-bit image")
        else:
            max_range = 255
        # Normalize the image to the range [0, 1].
        img = img / max_range
        # Identify the image mode based on its number of channels.
        if len(img.shape) == 2:
            img_mode = "L" # Grayscale image.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:
            img_mode = "RGBA" # RGBA image with an alpha channel.
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # Expand the single-channel alpha to three channels if the model will upsample it.
            if alpha_upsampler == "realesrgan":
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = "RGB" # RGB image.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Pre-process the image (tensor conversion and padding).
        self.pre_process(img)
        # Run inference on the image.
        self.process()
        # Post-process the output (crop the padding) and retrieve the enhanced tensor.
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        # Convert from RGB back to BGR and from CHW to HWC layout.
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        # Convert the output back to grayscale if the original image was grayscale.
        if img_mode == "L":
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
        # Process the alpha channel if the original image was RGBA.
        if img_mode == "RGBA":
            # Check whether RealESRGAN should be used for alpha channel upsampling.
            if alpha_upsampler == "realesrgan":
                # Pre-process the alpha channel just like the color image.
                self.pre_process(alpha)
                # Run inference on the alpha channel.
                self.process()
                # Post-process the alpha channel and retrieve the enhanced output.
                output_alpha = self.post_process()
                # Convert the alpha channel output to a NumPy array in the range [0, 1].
                output_alpha = (
                    output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                )
                # Convert from RGB back to BGR and from CHW to HWC layout.
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                # Collapse the alpha channel back to a single grayscale plane.
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:
                # Otherwise, resize the alpha channel with linear interpolation.
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(
                    alpha,
                    (w * self.scale, h * self.scale),
                    interpolation=cv2.INTER_LINEAR,
                )
            # Convert the output image to BGRA and attach the processed alpha channel.
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha
        # Rescale pixel values back to the original bit depth.
        if max_range == 65535:
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)
        # Resize the output if a scale other than the network's native scale is requested.
        if upscale is not None and upscale != float(self.scale):
            output = cv2.resize(
                output,
                (
                    int(w_input * upscale),
                    int(h_input * upscale),
                ),
                interpolation=cv2.INTER_LANCZOS4,
            )
        return output, img_mode
###########################################################################################
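# A minimal usage sketch, not part of the original module: the weight path below is
# hypothetical, and the RRDBNet configuration is an assumption (it matches the
# official RealESRGAN_x4plus release; substitute your own architecture and weights).
if __name__ == "__main__":
    from basicsr.archs.rrdbnet_arch import RRDBNet

    # Build the x4 RRDBNet architecture that the RealESRGAN_x4plus weights expect.
    net = RRDBNet(
        num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
    )
    upscaler = RealEsrGan(
        scale=4,
        model_path="weights/RealESRGAN_x4plus.pth", # Hypothetical local path.
        model=net,
    )
    # Read the image unchanged so 16-bit and alpha inputs are preserved.
    img = cv2.imread("input.png", cv2.IMREAD_UNCHANGED)
    output, img_mode = upscaler.enhance(img, upscale=4)
    cv2.imwrite("output.png", output)
###########################################################################################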