# ONNX image upscaler — Gradio demo (Hugging Face Space).
import numpy as np | |
import cv2 | |
import onnxruntime | |
import gradio as gr | |
# HTML rendered below the Gradio interface (passed as the `article`
# argument): a centered "Buy Me a Coffee" call-to-action with a
# Lemon Squeezy checkout link.
article_text = """
<div style="text-align: center;">
    <p>Enjoying the tool? Buy me a coffee and get exclusive prompt guides!</p>
    <p><i>Instantly unlock helpful tips for creating better prompts!</i></p>
    <div style="display: flex; justify-content: center;">
        <a href="https://piczify.lemonsqueezy.com/buy/0f5206fa-68e8-42f6-9ca8-4f80c587c83e">
            <img src="https://www.buymeacoffee.com/assets/img/custom_images/yellow_img.png"
                 alt="Buy Me a Coffee"
                 style="height: 40px; width: auto; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2); border-radius: 10px;">
        </a>
    </div>
</div>
"""
def pre_process(img: np.ndarray) -> np.ndarray:
    """Convert an H x W x C uint8 image into a 1 x C x H x W float32 batch.

    Only the first three channels are kept, so a BGRA input loses its
    alpha channel here (the caller upscales alpha separately).
    """
    # H, W, C -> C, H, W (keep at most the first three channels)
    img = np.transpose(img[:, :, 0:3], (2, 0, 1))
    # C, H, W -> 1, C, H, W (add the batch dimension the model expects)
    img = np.expand_dims(img, axis=0).astype(np.float32)
    return img
def post_process(img: np.ndarray) -> np.ndarray:
    """Convert a 1 x C x H x W model output back to an H x W x C uint8 image.

    The channel order is reversed (undoing the RGB/BGR swap applied on
    the way in) and values are cast to uint8.
    """
    # 1, C, H, W -> C, H, W (drop the batch dimension)
    img = np.squeeze(img)
    # C, H, W -> H, W, C, then reverse channel order and cast to uint8
    img = np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
    return img
def inference(model_path: str, img_array: np.ndarray) -> np.ndarray:
    """Run a single-input ONNX model on img_array and return its first output.

    A fresh InferenceSession is created per call.
    NOTE(review): threads are pinned to 1 per op — presumably to cap CPU
    usage on shared hosting; confirm before changing.
    """
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1
    ort_session = onnxruntime.InferenceSession(model_path, options)
    # Bind the array to the model's (single) declared input name.
    ort_inputs = {ort_session.get_inputs()[0].name: img_array}
    # None -> fetch all outputs; the model's first output is the image.
    ort_outs = ort_session.run(None, ort_inputs)
    return ort_outs[0]
def convert_pil_to_cv2(image):
    """Convert a PIL image (or array-like) to an OpenCV-style ndarray.

    RGB input becomes BGR; RGBA becomes BGRA (alpha stays in channel 3);
    2-D grayscale input is returned unchanged.
    """
    arr = np.array(image)
    if arr.ndim == 2:
        # Grayscale: no channel order to swap. The previous version indexed
        # a third axis unconditionally and crashed on L-mode images.
        return arr
    if arr.shape[2] == 4:
        # RGBA -> BGRA: reverse only the color channels. The previous
        # [:, :, ::-1] reversed all four, moving alpha to channel 0 while
        # the caller still read alpha from channel 3.
        return np.dstack((arr[:, :, 2::-1], arr[:, :, 3]))
    # RGB -> BGR
    return arr[:, :, ::-1].copy()
def upscale(image, model):
    """Upscale an input image with the selected ONNX model.

    Parameters
    ----------
    image : PIL.Image.Image
        Input image from the Gradio image widget.
    model : str
        Model name, resolved to ``models/<model>.ort`` on disk.

    Returns
    -------
    np.ndarray
        Upscaled image (BGR, or BGRA when the input carried alpha).

    Raises
    ------
    ValueError
        If the converted image has an unsupported channel count.
    """
    model_path = f"models/{model}.ort"
    img = convert_pil_to_cv2(image)
    if img.ndim == 2:
        # Grayscale: promote to 3 channels so the model input shape matches.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    if img.shape[2] == 4:
        # Upscale alpha separately: replicate it to 3 channels, run it
        # through the model, then collapse back to a single channel.
        alpha = img[:, :, 3]  # GRAY
        alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR)  # BGR
        alpha_output = post_process(inference(model_path, pre_process(alpha)))  # BGR
        alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)  # GRAY

        # Upscale the color channels, then re-attach the upscaled alpha.
        img = img[:, :, 0:3]  # BGR
        image_output = post_process(inference(model_path, pre_process(img)))  # BGR
        image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA)  # BGRA
        image_output[:, :, 3] = alpha_output
    elif img.shape[2] == 3:
        image_output = post_process(inference(model_path, pre_process(img)))  # BGR
    else:
        # Previously fell through and raised UnboundLocalError on return;
        # fail with an explicit message instead.
        raise ValueError(f"Unsupported channel count: {img.shape[2]}")
    return image_output
# Gradio wiring: image + model picker in, upscaled image out.
# Fix the preview height so large upscales don't blow up the layout.
css = ".output-image, .input-image, .image-preview {height: 480px !important} "
# Names must match the .ort files under models/ (see upscale()).
model_choices = ["modelx2", "modelx2 25 JXL", "modelx4", "minecraft_modelx4"]

gr.Interface(
    fn=upscale,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(
            model_choices,
            type="value",
            value="modelx4",
            label="Choose Upscaler",
        ),
    ],
    outputs="image",
    description="""
<div style="text-align: center;">
    <h1>Image Upscaler PRO ⚡</h1>
    <a href="https://arxiv.org/abs/2105.09750">
        <img src="https://img.shields.io/badge/arXiv-2105.09750-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
    </a>
    <p>Anchor-based Plain Net for Mobile Image Super-Resolution</p>
</div>
""",
    article=article_text,
    allow_flagging="never",
    css=css,
).launch()