realesrgan / inference /real_esrgan.py
GeorgiosIoannouCoder's picture
Create inference directory
9ab5f71 verified
raw
history blame
11.8 kB
###########################################################################################
# Filename: real_esrgan.py
# Description: Upscale images using the trained REALESRGAN model.
###########################################################################################
#
# Import libraries.
#
# Import OpenCV library for image processing.
import cv2
# Import the math module for mathematical operations.
import math
# Import NumPy for numerical operations on arrays.
import numpy as np
# Import the os module for operating system functionalities.
import os
# Import the queue module for implementing queues.
import queue
# Import the threading module for multi-threading support.
import threading
# Import PyTorch for deep learning.
import torch
# Import a utility function for downloading files.
from basicsr.utils.download_util import load_file_from_url
# Import functional module from PyTorch's neural network library.
from torch.nn import functional as F
###########################################################################################
# Project root: the parent of the directory that contains this file.
ROOT_DIR = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
###########################################################################################
class RealEsrGan:
    """Upscale images with a pretrained Real-ESRGAN style network.

    The instance owns the network and keeps the intermediate tensors of the
    most recent run on itself (``self.img``, ``self.output``), so a single
    instance must not be used from several threads at the same time.
    """

    def __init__(
        self,
        scale,  # Upsampling scale factor used in the networks.
        model_path,  # The path to the pretrained model.
        dni_weight=None,  # Performing the interpolation between two networks.
        model=None,  # The network architecture that receives the weights.
        pre_pad=10,  # Pad the input images to avoid border artifacts.
        half=False,  # Whether to use half precision during inference or not.
        device=None,  # What device to run inference on. cpu or cuda.
        gpu_id=None,  # ID of GPU to be used if there are more than one GPUs.
    ):
        """Load the weights into ``model`` and prepare it for inference.

        Args:
            scale: Upsampling factor of the network.
            model_path: Local checkpoint path, an ``https://`` URL (downloaded
                into ``<ROOT_DIR>/weights``), or a list of two local
                checkpoint paths to be blended with ``dni_weight``.
            dni_weight: Pair of blend weights, required when ``model_path``
                is a list.
            model: Network instance whose ``load_state_dict`` receives the
                weights. Despite the ``None`` default it must be provided.
            pre_pad: Number of pixels of reflective padding added to the
                bottom and right edges before inference.
            half: Run the network and its input in half precision.
            device: Device to run on. When ``None`` it is chosen from CUDA
                availability and ``gpu_id``.
            gpu_id: CUDA device index, used only when ``device`` is ``None``.
        """
        self.scale = scale
        self.model_path = model_path
        self.dni_weight = dni_weight
        self.model = model
        self.pre_pad = pre_pad
        self.half = half
        self.device = device
        self.gpu_id = gpu_id
        self.mod_scale = None
        # Initialize device based on GPU availability and user preference.
        # An explicitly supplied device always wins.
        if self.device is None:
            if not torch.cuda.is_available():
                device_name = "cpu"
            elif self.gpu_id is None:
                device_name = "cuda"
            else:
                # Compare against None (not truthiness) so that GPU index 0
                # is honoured as an explicit choice.
                device_name = f"cuda:{self.gpu_id}"
            self.device = torch.device(device_name)
        # Load the RealESRGAN model from the specified path or URL.
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        if isinstance(self.model_path, list):
            assert len(self.model_path) == len(
                self.dni_weight
            ), "model_path and dni_weight must have the same length."
            loadnet = self.dni(self.model_path[0], self.model_path[1], self.dni_weight)
        else:
            # Download model if model path is a URL.
            if self.model_path.startswith("https://"):
                self.model_path = load_file_from_url(
                    url=self.model_path,
                    model_dir=os.path.join(ROOT_DIR, "weights"),
                    progress=True,
                    file_name=None,
                )
            # Load from self.model_path: after a download it holds the local
            # file path, whereas the constructor argument is still the URL.
            loadnet = torch.load(self.model_path, map_location=torch.device("cpu"))
        # Use params_ema if available, otherwise use params.
        keyname = "params_ema" if "params_ema" in loadnet else "params"
        # Load model weights.
        model.load_state_dict(loadnet[keyname], strict=True)
        # Put the model in evaluation mode.
        model.eval()
        # Move the model to the specified device.
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def dni(self, net_a, net_b, dni_weight, key="params", loc="cpu"):
        """Blend two checkpoints linearly (deep network interpolation).

        Args:
            net_a: Path of the first checkpoint.
            net_b: Path of the second checkpoint.
            dni_weight: Pair ``(weight_a, weight_b)`` applied to the
                parameters of ``net_a`` and ``net_b`` respectively.
            key: Name of the checkpoint entry that holds the parameters.
            loc: Device name the checkpoints are mapped to while loading.

        Returns:
            The loaded checkpoint of ``net_a`` whose ``key`` entry has been
            replaced, parameter by parameter, with the weighted sum.
        """
        # Load both checkpoints onto the requested device.
        net_a = torch.load(net_a, map_location=torch.device(loc))
        net_b = torch.load(net_b, map_location=torch.device(loc))
        weight_a, weight_b = dni_weight[0], dni_weight[1]
        # Only existing keys are reassigned, so the dictionary size is stable
        # while it is being iterated.
        for k, v_a in net_a[key].items():
            net_a[key][k] = weight_a * v_a + weight_b * net_b[key][k]
        # Return the updated checkpoint.
        return net_a

    def pre_process(self, img):
        """Turn an H x W x C array into the padded input tensor ``self.img``.

        Args:
            img: NumPy array with channels last.
        """
        # Convert image to PyTorch tensor and move channels first.
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        # Add a batch dimension and move the tensor to the specified device.
        self.img = img.unsqueeze(0).to(self.device)
        # If half precision is enabled, convert the tensor to half precision.
        if self.half:
            self.img = self.img.half()
        # Apply reflective padding (right and bottom only) if pre_pad is set.
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
        # Set mod_scale based on the scale factor.
        # NOTE(review): presumably the x2 and x1 networks shrink their input
        # internally and therefore need divisible sizes - confirm against the
        # model architecture.
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        # Pad so that height and width are divisible by mod_scale.
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if h % self.mod_scale != 0:
                self.mod_pad_h = self.mod_scale - h % self.mod_scale
            if w % self.mod_scale != 0:
                self.mod_pad_w = self.mod_scale - w % self.mod_scale
            self.img = F.pad(
                self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect"
            )

    def process(self):
        """Run the network on ``self.img`` and store the result in ``self.output``."""
        # Inference only: skip autograd bookkeeping so no graph is kept alive.
        with torch.no_grad():
            self.output = self.model(self.img)

    def post_process(self):
        """Crop the padding added by ``pre_process`` from ``self.output``.

        Returns:
            The cropped output tensor (also stored in ``self.output``).
        """
        # Remove the divisibility padding, scaled to the output resolution.
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[
                :,
                :,
                0 : h - self.mod_pad_h * self.scale,
                0 : w - self.mod_pad_w * self.scale,
            ]
        # Remove the pre-padding, scaled to the output resolution.
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[
                :,
                :,
                0 : h - self.pre_pad * self.scale,
                0 : w - self.pre_pad * self.scale,
            ]
        return self.output

    def _upsample_to_bgr(self, rgb):
        """Run the full pipeline on an RGB array and return an H x W x 3 BGR array in [0, 1]."""
        self.pre_process(rgb)
        self.process()
        out = self.post_process()
        # Drop only the batch dimension so 1-pixel-wide results keep their shape.
        out = out.detach().squeeze(0).float().cpu().clamp_(0, 1).numpy()
        # Reverse the channel order (RGB -> BGR) and move channels last.
        return np.transpose(out[[2, 1, 0], :, :], (1, 2, 0))

    def enhance(self, img, upscale=None, alpha_upsampler="realesrgan"):
        """Upscale an image.

        Args:
            img: NumPy array, either H x W (gray) or H x W x C. Colour input
                is treated as BGR(A), since it is converted with
                ``COLOR_BGR2RGB`` before inference.
            upscale: Final scale relative to the input size. When it differs
                from the network scale, the result is resized with Lanczos
                interpolation.
            alpha_upsampler: ``"realesrgan"`` to upscale the alpha channel
                with the network; any other value resizes it linearly.

        Returns:
            Tuple ``(output, img_mode)`` where ``output`` is a ``uint8`` or
            ``uint16`` array and ``img_mode`` is ``"L"``, ``"RGB"`` or
            ``"RGBA"``.
        """
        # Get the height and width of the input image.
        h_input, w_input = img.shape[0:2]
        img = img.astype(np.float32)
        # Heuristic: any value above 256 means the input is a 16-bit image.
        if np.max(img) > 256:
            max_range = 65535
            print("\tInput is a 16-bit image")
        else:
            max_range = 255
        # Normalize the image to the range [0, 1].
        img = img / max_range
        # Identify the image mode based on its number of channels.
        if len(img.shape) == 2:
            img_mode = "L"  # Gray image.
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:
            img_mode = "RGBA"  # Image with alpha channel.
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # The network expects three channels, so replicate the alpha plane.
            if alpha_upsampler == "realesrgan":
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = "RGB"  # Colour image without alpha.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Pre-process, run the network and post-process the colour planes.
        output_img = self._upsample_to_bgr(img)
        # Convert output image back to grayscale if the original image was grayscale.
        if img_mode == "L":
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
        # Process alpha channel if the original image had an alpha channel.
        if img_mode == "RGBA":
            if alpha_upsampler == "realesrgan":
                # Upscale the replicated alpha plane with the network and
                # collapse it back to a single channel.
                output_alpha = self._upsample_to_bgr(alpha)
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:
                # Resize the alpha channel using linear interpolation.
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(
                    alpha,
                    (w * self.scale, h * self.scale),
                    interpolation=cv2.INTER_LINEAR,
                )
            # Convert output image to BGRA format and assign the processed alpha channel.
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha
        # Scale the values back to the bit depth of the input.
        if max_range == 65535:
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)
        # Resize the output image if a different scale is specified.
        if upscale is not None and upscale != float(self.scale):
            output = cv2.resize(
                output,
                (
                    int(w_input * upscale),
                    int(h_input * upscale),
                ),
                interpolation=cv2.INTER_LANCZOS4,
            )
        return output, img_mode
###########################################################################################