File size: 2,854 Bytes
6e92463
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a625565
 
 
 
 
6e92463
 
 
 
 
 
 
a625565
6e92463
 
 
 
 
 
9eab909
 
6e92463
 
 
 
 
 
 
a625565
6e92463
 
 
a625565
6e92463
 
 
 
a625565
6e92463
 
 
 
9eab909
36a1be4
219dbcc
9eab909
36a1be4
9eab909
36a1be4
 
3f6e15f
219dbcc
3f6e15f
6e92463
3f6e15f
9eab909
 
 
 
 
 
 
 
 
 
3f6e15f
 
9eab909
6e92463
 
3f6e15f
219dbcc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import numpy as np
import cv2
import onnxruntime
import gradio as gr


def pre_process(img: np.array) -> np.array:
    """Convert an H x W x C uint8 image into a 1 x C x H x W float32 batch.

    Only the first three channels are kept; any alpha channel is dropped.
    """
    # Drop extra channels, then move the channel axis to the front.
    chw = img[:, :, 0:3].transpose((2, 0, 1))
    # Prepend the batch dimension expected by the ONNX model.
    return chw[np.newaxis, ...].astype(np.float32)


def post_process(img: np.array) -> np.array:
    """Convert a 1 x C x H x W model output into an H x W x C uint8 image.

    The batch dimension is removed, channels are moved to the last axis, and
    the channel order is reversed (BGR <-> RGB) before casting to uint8.
    """
    hwc = np.squeeze(img).transpose((1, 2, 0))
    # Reverse channel order and cast to the display dtype.
    return hwc[:, :, ::-1].astype(np.uint8)


def inference(model_path: str, img_array: np.array) -> np.array:
    """Run the ONNX model at ``model_path`` on a preprocessed batch.

    Returns the model's first output array.
    """
    opts = onnxruntime.SessionOptions()
    # Keep execution single-threaded for predictable resource usage.
    opts.intra_op_num_threads = 1
    opts.inter_op_num_threads = 1
    session = onnxruntime.InferenceSession(model_path, opts)
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: img_array})
    return outputs[0]


def convert_pil_to_cv2(image):
    """Convert a PIL image (or array-like) to an OpenCV-style ndarray.

    Grayscale (2-D) input is returned as-is (a copy); 3-channel RGB becomes
    BGR; 4-channel RGBA becomes BGRA with alpha kept in the last channel.

    The previous implementation crashed on 2-D input (too many indices) and
    reversed ALL four channels of RGBA input, producing ABGR and misplacing
    the alpha plane.
    """
    arr = np.array(image)
    if arr.ndim == 2:
        # Grayscale: nothing to reorder; copy so the result owns its data.
        return arr.copy()
    if arr.shape[2] == 4:
        # RGBA -> BGRA: swap only the color channels, keep alpha last.
        # (A plain [:, :, ::-1] would yield ABGR.)
        return arr[:, :, [2, 1, 0, 3]]
    # RGB -> BGR
    return arr[:, :, ::-1].copy()


def upscale(image, model):
    """Upscale a PIL image using the selected ONNX model.

    Supports grayscale, RGB and RGBA inputs. For RGBA, the color and alpha
    planes are upscaled separately and recombined so transparency survives.

    Fixes over the previous version: grayscale input no longer crashes inside
    the RGB->BGR conversion before the ``ndim`` check could run; RGBA alpha is
    kept in the last channel (the old full channel reversal produced ABGR, so
    ``img[:, :, 3]`` actually read the red plane); and an unsupported channel
    count now raises instead of hitting an unbound ``image_output``.

    Raises:
        ValueError: if the image has an unsupported number of channels.
    """
    model_path = f"models/{model}.ort"
    img = np.array(image)
    if img.ndim == 2:
        # Grayscale -> BGR so the model receives the 3-channel shape it expects.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif img.shape[2] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    elif img.shape[2] == 4:
        # Swap only color channels; alpha stays in the last position.
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
    else:
        raise ValueError(f"Unsupported number of channels: {img.shape[2]}")

    if img.shape[2] == 4:
        # Upscale the alpha plane separately, replicated to 3 channels so it
        # can pass through the same model, then collapse back to one channel.
        alpha = cv2.cvtColor(img[:, :, 3], cv2.COLOR_GRAY2BGR)
        alpha_output = post_process(inference(model_path, pre_process(alpha)))
        alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)

        image_output = post_process(inference(model_path, pre_process(img[:, :, 0:3])))
        image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA)
        image_output[:, :, 3] = alpha_output
    else:
        image_output = post_process(inference(model_path, pre_process(img)))

    return image_output


# Example inputs shown in the UI: [image path, preselected model] pairs,
# five per sample set.
examples = [[f"examples/example_{n}.png", "modelx4"] for n in range(1, 6)]
examples.extend([f"examples_x2/example_{n}.png", "modelx2"] for n in range(1, 6))
examples.extend([f"examples_x2_25/example-{n}.png", ""] for n in range(1, 6))
examples.extend(
    [f"minecraft_examples/minecraft-{n}.png", "minecraft_modelx4"]
    for n in range(1, 6)
)


# Constrain preview heights so large upscaled images don't blow up the layout.
css = ".output-image, .input-image, .image-preview {height: 480px !important} "
# Model names map to files under models/<name>.ort (see upscale()).
model_choices = ["modelx2", "modelx2 25 JXL", "modelx4", "minecraft_modelx4"]

# Build and launch the Gradio demo.
# NOTE(review): gr.inputs.* is the legacy (pre-3.x) Gradio component
# namespace — confirm against the pinned gradio version before upgrading.
gr.Interface(
    fn=upscale,
    inputs=[
        gr.inputs.Image(type="pil"),
        gr.inputs.Radio(
            model_choices,
            type="value",
            default=None,
            label="Choose a Model",
            optional=False,
        ),
    ],
    outputs="image",
    examples=examples,
    examples_per_page=5,
    title="Image Upscaling 🦆",
    allow_flagging="never",
    css=css,
).launch()