# isr/upsample.py
# dragonSwing — commit 0fe2a53 "Add application files"
import argparse
import cv2
import os
from imutils import paths
from tqdm import tqdm
from config import *
from utils import get_face_enhancer, get_upsampler
def process(image_path, upsampler_name, face_enhancer_name=None, scale=2, device="cpu"):
    """Upscale a single image file and return the processed image.

    Args:
        image_path: Path of the image, read with ``cv2.imread`` using
            ``cv2.IMREAD_UNCHANGED``.
        upsampler_name: Model name forwarded to ``get_upsampler``.
        face_enhancer_name: Optional model name forwarded to
            ``get_face_enhancer``. When falsy, only the upsampler is used.
        scale: Requested output scale. Values above 4 are clamped to 4.
        device: Device string forwarded to the model factories.

    Returns:
        The processed image in the channel order ``cv2.imread`` produced, so
        it can be written directly with ``cv2.imwrite``. Returns ``None`` when
        the file cannot be read or processing fails; the error is printed.
        NOTE(review): this assumes the upsampler/face-enhancer ``enhance``
        methods preserve OpenCV's BGR channel order. Confirm against ``utils``.
    """
    if scale > 4:
        scale = 4  # avoid too large scale value
    try:
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals an unreadable or unsupported file by returning None.
            print(f"Could not read image: {image_path}")
            return None

        h, w = img.shape[0:2]
        if h > 3500 or w > 3500:
            # Too large to upscale: return the image untouched. Keep OpenCV's
            # channel order so callers writing with cv2.imwrite (like main)
            # do not get swapped red/blue channels.
            return img
        if (h < 300 and w < 300) and upsampler_name != "srcnn":
            # NOTE(review): tiny images skip the model entirely and are only
            # Lanczos-resized by a fixed factor of 2, ignoring ``scale``.
            # Confirm this is intended.
            return cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        upsampler = get_upsampler(upsampler_name, device=device)
        face_enhancer = (
            get_face_enhancer(face_enhancer_name, scale, upsampler, device=device)
            if face_enhancer_name
            else None
        )

        try:
            if face_enhancer is not None:
                # The face enhancer receives the upsampler (via
                # get_face_enhancer) and returns the pasted-back full image
                # as its third result.
                _, _, output = face_enhancer.enhance(
                    img, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                output, _ = upsampler.enhance(img, outscale=scale)
        except RuntimeError as error:
            print(f"Runtime error: {error}")
            # ``output`` was never assigned, so report failure explicitly.
            return None
        return output
    except Exception as error:
        # Deliberate best-effort boundary: report the error and let the caller
        # skip this image instead of aborting the whole batch.
        print(f"global exception: {error}")
        return None
def main(args: argparse.Namespace) -> None:
    """Upscale every image under ``args.input`` and save the results in ``args.output``.

    ``args`` must provide ``input``, ``output``, ``upsampler``,
    ``face_enhancer``, ``scale`` and ``device``, as produced by this script's
    command-line parser. ``args.input`` may be a single image file or a
    directory scanned with ``imutils.paths.list_images``. Each result is saved
    under its original basename. Images that fail to process are skipped.

    Raises:
        ValueError: If ``args.input`` does not exist.
    """
    device = args.device
    scale = args.scale
    upsampler_name = args.upsampler
    face_enhancer_name = args.face_enhancer

    if face_enhancer_name and ("srcnn" in upsampler_name or "anime" in upsampler_name):
        print(
            "Warnings: SRCNN and Anime model aren't compatible with face enhance. We will turn it off for you"
        )
        face_enhancer_name = None

    # Validate the input before touching the filesystem, so a bad path does not
    # leave an empty output directory behind.
    if not os.path.exists(args.input):
        raise ValueError("The input directory doesn't exist!")
    if os.path.isdir(args.input):
        # Materialise the generator so tqdm knows the total and can show progress.
        image_paths = list(paths.list_images(args.input))
    else:
        image_paths = [args.input]

    os.makedirs(args.output, exist_ok=True)

    with tqdm(image_paths) as pbar:
        for image_path in pbar:
            filename = os.path.basename(image_path)
            pbar.set_postfix_str(f"Processing {image_path}")
            upsampled_image = process(
                image_path=image_path,
                upsampler_name=upsampler_name,
                face_enhancer_name=face_enhancer_name,
                scale=scale,
                device=device,
            )
            if upsampled_image is None:
                continue  # process() already reported the failure
            # NOTE(review): only the basename is kept. If list_images recurses
            # into subdirectories, images with the same name overwrite each other.
            save_path = os.path.join(args.output, filename)
            # cv2.imwrite returns False instead of raising, e.g. for an
            # unsupported extension.
            if not cv2.imwrite(save_path, upsampled_image):
                print(f"Failed to write {save_path}")
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser whose namespace ``main`` consumes."""
    parser = argparse.ArgumentParser(
        description=(
            "Upscale an input image or directory of images with a super-resolution "
            "model, optionally enhancing faces"
        )
    )
    parser.add_argument(
        "--input",
        "-i",
        type=str,
        required=True,
        help="Path to either a single input image or folder of images.",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        required=True,
        help="Path to the output directory.",
    )
    parser.add_argument(
        "--upsampler",
        type=str,
        default="realesr-general-x4v3",
        choices=[
            "srcnn",
            "RealESRGAN_x2plus",
            "RealESRGAN_x4plus",
            "RealESRNet_x4plus",
            "realesr-general-x4v3",
            "RealESRGAN_x4plus_anime_6B",
            "realesr-animevideov3",
        ],
        help="The type of upsampler model to load",
    )
    parser.add_argument(
        "--face-enhancer",
        type=str,
        choices=["GFPGANv1.3", "GFPGANv1.4", "RestoreFormer"],
        help="The type of face enhancer model to load",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=2,
        # Parsed as float; float("2") == 2, so integer-valued choices still match.
        choices=[1.5, 2, 2.5, 3, 3.5, 4],
        help="scaling factor",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="The device to run upsampling on."
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)