# isr / upsample.py
# dragonSwing's picture
# Add application files
# 0fe2a53
# raw
# history blame contribute delete
# No virus
# 4.41 kB
import argparse
import cv2
import os
from imutils import paths
from tqdm import tqdm
from config import *
from utils import get_face_enhancer, get_upsampler
def process(image_path, upsampler_name, face_enhancer_name=None, scale=2, device="cpu"):
    """Upscale one image file and return the resulting pixel array.

    Args:
        image_path: Path of the image, read with ``cv2.imread``.
        upsampler_name: Model name forwarded to ``get_upsampler``.
        face_enhancer_name: Optional model name forwarded to
            ``get_face_enhancer``; when falsy only the upsampler is run.
        scale: Requested output scale; values above 4 are clamped to 4.
        device: Device string forwarded to the model factories.

    Returns:
        The processed image in the channel order ``cv2.imread`` produced
        (so it can be handed straight to ``cv2.imwrite``), or ``None`` when
        the image cannot be read or processing fails.
    """
    scale = min(scale, 4)  # avoid too large scale value

    try:
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals an unreadable/unsupported file by returning
            # None instead of raising.
            print(f"global exception: could not read image {image_path}")
            return None

        h, w = img.shape[:2]

        if h > 3500 or w > 3500:
            # Images with a side above 3500 px are returned without
            # upscaling. Return them untouched: converting BGR -> RGB here
            # would swap channels once the caller saves with cv2.imwrite,
            # and would raise for grayscale / 4-channel inputs.
            return img

        if (h < 300 and w < 300) and upsampler_name != "srcnn":
            # Small images get a plain 2x Lanczos resize instead of a model
            # pass (not applied when the SRCNN upsampler is selected).
            return cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        upsampler = get_upsampler(upsampler_name, device=device)
        face_enhancer = (
            get_face_enhancer(face_enhancer_name, scale, upsampler, device=device)
            if face_enhancer_name
            else None
        )

        try:
            if face_enhancer is not None:
                _, _, output = face_enhancer.enhance(
                    img, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                output, _ = upsampler.enhance(img, outscale=scale)
        except RuntimeError as error:
            # `output` was never bound on this path, so report the failure
            # and bail out explicitly rather than touching it.
            print(f"Runtime error: {error}")
            return None

        return output
    except Exception as error:
        # Best-effort batch processing: one bad image must not abort the run.
        print(f"global exception: {error}")
        return None
def main(args: argparse.Namespace) -> None:
    """Upscale the image(s) named by ``args.input`` into ``args.output``.

    Args:
        args: Parsed command-line options. Reads ``input``, ``output``,
            ``upsampler``, ``face_enhancer``, ``scale`` and ``device``.

    Raises:
        ValueError: If ``args.input`` does not exist.
    """
    device = args.device
    scale = args.scale
    upsampler_name = args.upsampler
    face_enhancer_name = args.face_enhancer

    if face_enhancer_name and ("srcnn" in upsampler_name or "anime" in upsampler_name):
        print(
            "Warnings: SRCNN and Anime model aren't compatible with face enhance. We will turn it off for you"
        )
        face_enhancer_name = None

    # Validate the input before touching the filesystem so a bad path does
    # not leave an empty output directory behind.
    if not os.path.exists(args.input):
        raise ValueError("The input path doesn't exist!")

    if os.path.isdir(args.input):
        # Materialise the generator so tqdm knows the total, and sort it so
        # the processing order is deterministic.
        image_paths = sorted(paths.list_images(args.input))
    else:
        image_paths = [args.input]

    os.makedirs(args.output, exist_ok=True)

    with tqdm(image_paths) as pbar:
        for image_path in pbar:
            pbar.set_postfix_str(f"Processing {image_path}")

            upsampled_image = process(
                image_path=image_path,
                upsampler_name=upsampler_name,
                face_enhancer_name=face_enhancer_name,
                scale=scale,
                device=device,
            )
            if upsampled_image is None:
                continue

            # NOTE: outputs are keyed by base name only, so files with the
            # same name in different sub-directories overwrite each other.
            save_path = os.path.join(args.output, os.path.basename(image_path))
            # cv2.imwrite reports failure through its return value.
            if not cv2.imwrite(save_path, upsampled_image):
                print(f"Failed to write {save_path}")
if __name__ == "__main__":
    # Command-line entry point: build the option parser, parse argv and
    # hand the resulting namespace to main().
    parser = argparse.ArgumentParser(
        description=(
            "Upscales an input image or a directory of images, optionally enhancing faces"
        )
    )
    parser.add_argument(
        "--input",
        "-i",
        type=str,
        required=True,
        help="Path to either a single input image or folder of images.",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        required=True,
        help="Path to the output directory.",
    )
    parser.add_argument(
        "--upsampler",
        type=str,
        default="realesr-general-x4v3",
        choices=[
            "srcnn",
            "RealESRGAN_x2plus",
            "RealESRGAN_x4plus",
            "RealESRNet_x4plus",
            "realesr-general-x4v3",
            "RealESRGAN_x4plus_anime_6B",
            "realesr-animevideov3",
        ],
        help="The type of upsampler model to load",
    )
    # No default: face enhancement stays off unless a model is named.
    parser.add_argument(
        "--face-enhancer",
        type=str,
        choices=["GFPGANv1.3", "GFPGANv1.4", "RestoreFormer"],
        help="The type of face enhancer model to load",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=2,
        choices=[1.5, 2, 2.5, 3, 3.5, 4],
        help="scaling factor",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="The device to run upsampling on."
    )

    args = parser.parse_args()
    main(args)