File size: 4,405 Bytes
0fe2a53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import argparse
import cv2
import os

from imutils import paths
from tqdm import tqdm
from config import *
from utils import get_face_enhancer, get_upsampler


def process(image_path, upsampler_name, face_enhancer_name=None, scale=2, device="cpu"):
    """Upscale a single image file and return the resulting array.

    Args:
        image_path: Path of the image, read with ``cv2.imread`` in
            ``IMREAD_UNCHANGED`` mode (channel order as loaded by OpenCV).
        upsampler_name: Model name forwarded to ``get_upsampler``.
        face_enhancer_name: Optional model name forwarded to
            ``get_face_enhancer``; when set, the face enhancer (built around
            the upsampler) produces the output instead of the bare upsampler.
        scale: Requested output scale; values above 4 are clamped to 4.
        device: Device string forwarded to the model loaders.

    Returns:
        The processed image, or ``None`` if the file could not be read or
        processing failed (the reason is printed). Images larger than 3500px
        on a side are returned unchanged. Images smaller than 300px on both
        sides are Lanczos-resized 2x unless ``upsampler_name`` is "srcnn".
        NOTE(review): the model branches are presumably channel-order
        preserving (BGR in, BGR out) — confirm in ``utils``.
    """
    if scale > 4:
        scale = 4  # avoid too large scale value
    try:
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread reports unreadable/unsupported files by returning None
            # instead of raising.
            print(f"Could not read image: {image_path}")
            return None

        h, w = img.shape[0:2]
        if h > 3500 or w > 3500:
            # Already large enough: skip upscaling. Return the array as loaded
            # so its channel order matches the other branches; the caller
            # saves with cv2.imwrite, which expects OpenCV's BGR order.
            return img

        if (h < 300 and w < 300) and upsampler_name != "srcnn":
            # NOTE(review): this fast path always upscales by 2x and ignores
            # `scale` — confirm that is intended.
            return cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        upsampler = get_upsampler(upsampler_name, device=device)
        face_enhancer = (
            get_face_enhancer(face_enhancer_name, scale, upsampler, device=device)
            if face_enhancer_name
            else None
        )

        try:
            if face_enhancer is not None:
                # Third element of the result is the full image with the
                # restored faces pasted back (paste_back=True).
                _, _, output = face_enhancer.enhance(
                    img, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                output, _ = upsampler.enhance(img, outscale=scale)
        except RuntimeError as error:
            # `output` is unset here; bail out explicitly rather than hitting
            # an UnboundLocalError on the return below.
            print(f"Runtime error: {error}")
            return None

        return output
    except Exception as error:
        # Best-effort batch processing: report and skip this image.
        print(f"global exception: {error}")
        return None


def main(args: argparse.Namespace) -> None:
    """Upscale every image referenced by ``args.input`` into ``args.output``.

    ``args.input`` may be a single image file or a directory, which is
    searched with ``imutils.paths.list_images``. Each result is written into
    ``args.output`` under the input file's base name. Images for which
    ``process`` returns ``None`` are skipped.

    Raises:
        ValueError: if ``args.input`` does not exist.
    """
    device = args.device
    scale = args.scale

    upsampler_name = args.upsampler
    face_enhancer_name = args.face_enhancer

    if face_enhancer_name and ("srcnn" in upsampler_name or "anime" in upsampler_name):
        print(
            "Warnings: SRCNN and Anime model aren't compatible with face enhance. We will turn it off for you"
        )
        face_enhancer_name = None

    # Validate the input before touching the filesystem, so a bad --input does
    # not leave an empty output directory behind.
    if not os.path.exists(args.input):
        raise ValueError("The input path doesn't exist!")
    if os.path.isdir(args.input):
        # Materialise the generator so tqdm knows the total count.
        image_paths = list(paths.list_images(args.input))
    else:
        image_paths = [args.input]

    os.makedirs(args.output, exist_ok=True)

    with tqdm(image_paths) as pbar:
        for image_path in pbar:
            pbar.set_postfix_str(f"Processing {image_path}")
            upsampled_image = process(
                image_path=image_path,
                upsampler_name=upsampler_name,
                face_enhancer_name=face_enhancer_name,
                scale=scale,
                device=device,
            )
            if upsampled_image is None:
                continue

            save_path = os.path.join(args.output, os.path.basename(image_path))
            # cv2.imwrite signals some failures (e.g. an unwritable path) by
            # returning False rather than raising.
            if not cv2.imwrite(save_path, upsampled_image):
                print(f"Failed to write {save_path}")


def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this upscaling script.

    Kept outside the ``__main__`` guard so the CLI definition can be inspected
    or reused without running ``main``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Upscales an input image or directory of images with a "
            "super-resolution model, optionally restoring faces"
        )
    )

    parser.add_argument(
        "--input",
        "-i",
        type=str,
        required=True,
        help="Path to either a single input image or folder of images.",
    )

    parser.add_argument(
        "--output",
        "-o",
        type=str,
        required=True,
        help="Path to the output directory.",
    )

    parser.add_argument(
        "--upsampler",
        type=str,
        default="realesr-general-x4v3",
        choices=[
            "srcnn",
            "RealESRGAN_x2plus",
            "RealESRGAN_x4plus",
            "RealESRNet_x4plus",
            "realesr-general-x4v3",
            "RealESRGAN_x4plus_anime_6B",
            "realesr-animevideov3",
        ],
        help="The type of upsampler model to load",
    )

    parser.add_argument(
        "--face-enhancer",
        type=str,
        choices=["GFPGANv1.3", "GFPGANv1.4", "RestoreFormer"],
        help="The type of face enhancer model to load",
    )

    # `type` is applied before the `choices` check, so "2" -> 2.0, which
    # compares equal to the int 2 in the list.
    parser.add_argument(
        "--scale",
        type=float,
        default=2,
        choices=[1.5, 2, 2.5, 3, 3.5, 4],
        help="scaling factor",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="The device to run upsampling on."
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)