"""Upscale an image, or every image in a directory, with a super-resolution model.

Optionally runs a face enhancer on top of the upsampler. Results are written
to an output directory.
"""
import argparse
import os

import cv2
from imutils import paths
from tqdm import tqdm

# Kept from the original; nothing in this file references config names directly.
# NOTE(review): confirm whether utils relies on it before removing.
from config import *  # noqa: F401,F403
from utils import get_face_enhancer, get_upsampler

# Upper bound for the requested scale factor (avoids too large scale values).
_MAX_SCALE = 4
# Images with either side above this are returned unchanged (not upscaled).
_MAX_SIDE = 3500
# Images with both sides below this skip the model and get a plain 2x resize.
_SMALL_SIDE = 300


def process(image_path, upsampler_name, face_enhancer_name=None, scale=2, device="cpu"):
    """Upscale a single image file and return the result.

    Args:
        image_path: Path of the image to read with ``cv2.imread`` (unchanged mode,
            so alpha/grayscale layouts are kept as read).
        upsampler_name: Name passed to ``utils.get_upsampler``.
        face_enhancer_name: Optional name passed to ``utils.get_face_enhancer``;
            when falsy, only the upsampler is used.
        scale: Requested output scale; values above ``_MAX_SCALE`` are clamped.
        device: Device string forwarded to the model loaders.

    Returns:
        The processed image in the same channel order OpenCV read it in (so it
        can be saved directly with ``cv2.imwrite``), or ``None`` if the image
        could not be read or processing failed (the reason is printed).
    """
    scale = min(scale, _MAX_SCALE)  # avoid too large scale value

    try:
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals unreadable/unsupported files by returning None.
            print(f"Could not read image: {image_path}")
            return None

        h, w = img.shape[0:2]
        if h > _MAX_SIDE or w > _MAX_SIDE:
            # Too large to upscale: pass it through untouched. Do NOT convert to
            # RGB here -- the caller saves with cv2.imwrite, which expects BGR.
            return img

        if h < _SMALL_SIDE and w < _SMALL_SIDE and upsampler_name != "srcnn":
            # Small images bypass the model and are resized 2x with Lanczos.
            # NOTE(review): this ignores `scale`; confirm that is intended.
            return cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        # NOTE(review): models are (re)loaded for every image; consider caching
        # if get_upsampler/get_face_enhancer do not cache internally.
        upsampler = get_upsampler(upsampler_name, device=device)
        face_enhancer = (
            get_face_enhancer(face_enhancer_name, scale, upsampler, device=device)
            if face_enhancer_name
            else None
        )

        try:
            if face_enhancer is not None:
                _, _, output = face_enhancer.enhance(
                    img, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                output, _ = upsampler.enhance(img, outscale=scale)
        except RuntimeError as error:
            print(f"Runtime error: {error}")
            # `output` was never assigned; returning it would raise
            # UnboundLocalError and hide the real error.
            return None
        return output
    except Exception as error:
        print(f"global exception: {error}")
        return None


def main(args: argparse.Namespace) -> None:
    """Upscale every image referenced by ``args.input`` into ``args.output``.

    ``args.input`` may be a single file or a directory (searched recursively
    by ``imutils.paths.list_images``). For directories, each result keeps its
    path relative to the input directory, so same-named files in different
    subfolders do not overwrite each other.

    Raises:
        ValueError: if ``args.input`` does not exist.
    """
    device = args.device
    scale = args.scale
    upsampler_name = args.upsampler
    face_enhancer_name = args.face_enhancer

    if face_enhancer_name and ("srcnn" in upsampler_name or "anime" in upsampler_name):
        print(
            "Warnings: SRCNN and Anime model aren't compatible with face enhance. "
            "We will turn it off for you"
        )
        face_enhancer_name = None

    # Validate the input before creating anything on disk.
    if not os.path.exists(args.input):
        raise ValueError(f"The input path doesn't exist: {args.input}")
    input_is_dir = os.path.isdir(args.input)
    if input_is_dir:
        # Materialize the generator so tqdm knows the total count.
        image_paths = list(paths.list_images(args.input))
    else:
        image_paths = [args.input]

    os.makedirs(args.output, exist_ok=True)

    with tqdm(image_paths) as pbar:
        for image_path in pbar:
            pbar.set_postfix_str(f"Processing {image_path}")
            upsampled_image = process(
                image_path=image_path,
                upsampler_name=upsampler_name,
                face_enhancer_name=face_enhancer_name,
                scale=scale,
                device=device,
            )
            if upsampled_image is None:
                continue

            if input_is_dir:
                rel_path = os.path.relpath(image_path, args.input)
            else:
                rel_path = os.path.basename(image_path)
            save_path = os.path.join(args.output, rel_path)
            os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(save_path, upsampled_image):
                print(f"Failed to write image: {save_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=(
            "Upscale an input image or a directory of images with a "
            "super-resolution model, optionally enhancing faces"
        )
    )
    parser.add_argument(
        "--input",
        "-i",
        type=str,
        required=True,
        help="Path to either a single input image or folder of images.",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        required=True,
        help="Path to the output directory.",
    )
    parser.add_argument(
        "--upsampler",
        type=str,
        default="realesr-general-x4v3",
        choices=[
            "srcnn",
            "RealESRGAN_x2plus",
            "RealESRGAN_x4plus",
            "RealESRNet_x4plus",
            "realesr-general-x4v3",
            "RealESRGAN_x4plus_anime_6B",
            "realesr-animevideov3",
        ],
        help="The type of upsampler model to load",
    )
    parser.add_argument(
        "--face-enhancer",
        type=str,
        choices=["GFPGANv1.3", "GFPGANv1.4", "RestoreFormer"],
        help="The type of face enhancer model to load",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=2,
        choices=[1.5, 2, 2.5, 3, 3.5, 4],
        help="scaling factor",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="The device to run upsampling on."
    )
    args = parser.parse_args()
    main(args)