# Hugging Face Space — running on ZeroGPU hardware ("Running on Zero").
import math

import gradio as gr
# Keep `spaces` imported before torch-backed packages such as image_gen_aux:
# ZeroGPU requires the `spaces` package to be loaded before CUDA is touched.
import spaces
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from image_gen_aux import UpscaleWithModel
from image_gen_aux.utils import load_image
MODELS = { | |
"4xNomosWebPhotoRealPLKSR": "Phips/4xNomosWebPhoto_RealPLKSR", | |
"4xRealESRGAN": "luca115/4xRealESRGAN", | |
"4xRealHATGANSharper": "luca115/Real_HAT_GAN_SHARPER", | |
"4xSwinIRLarge": "luca115/4xSwinIRLarge", | |
} | |
def get_duration( | |
image, model_selection | |
): | |
width, height = image.size | |
pixel = width * height | |
if model_selection in ["4xNomosWebPhotoRealPLKSR", "4xRealESRGAN"]: | |
return math.ceil((pixel * 10) / 1_000_000) + 3 | |
else: | |
return math.ceil((pixel * 30) / 1_000_000) + 3 | |
# On ZeroGPU hardware a GPU is attached only while a @spaces.GPU function
# runs; get_duration sizes the requested slot from the same arguments.
@spaces.GPU(duration=get_duration)
def upscale_image(image, model_selection):
    """Upscale ``image`` with the model chosen in the dropdown.

    The selected model is instantiated via ``UpscaleWithModel.from_pretrained``
    on every call and moved to CUDA. Inference runs with 1024x1024 tiling
    (presumably to bound GPU memory on large inputs).

    Args:
        image: The uploaded image (a PIL image from ``gr.Image(type="pil")``),
            or ``None`` if nothing was uploaded.
        model_selection: A key of ``MODELS``.

    Returns:
        ``(original, upscaled)``, the before/after pair for the image slider.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image before upscaling.")
    original = load_image(image)
    upscaler = UpscaleWithModel.from_pretrained(MODELS[model_selection]).to("cuda")
    upscaled = upscaler(original, tiling=True, tile_width=1024, tile_height=1024)
    return original, upscaled
def clear_result():
    """Blank out the output slider ahead of a fresh upscale.

    Returns:
        A Gradio update that sets the target component's value to ``None``.
    """
    emptied = gr.update(value=None)
    return emptied
# Page header rendered as raw HTML at the top of the app.
title = """<h1 align="center">Best Upscaling Models</h1>
<div align="center">A collection of my favorite non-diffusion-based upscaling models. For diffusion-based methods, check out these <a href="https://upsampler.com">creative image upscalers and enhancers</a>.</div>
"""

# Two-column layout: input image, model picker and trigger button on the
# left; the before/after comparison slider on the right.
with gr.Blocks() as demo:
    gr.HTML(title)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            model_selection = gr.Dropdown(
                label="Model",
                choices=list(MODELS),
                value="4xSwinIRLarge",
            )
            run_button = gr.Button("Upscale")

        with gr.Column():
            result = gr.ImageSlider(
                label="Generated Image",
                interactive=False,
                format="png",
            )

    # Empty the slider first so the previous comparison is not left on
    # screen while the new upscale runs, then run the upscaler itself.
    run_button.click(clear_result, None, result).then(
        upscale_image, [input_image, model_selection], result
    )
# Gradio creates its FastAPI app inside launch(), and when run as a script
# launch() blocks until the server shuts down. A middleware added to the
# returned app afterwards therefore never applies while the app is serving
# (Starlette also rejects add_middleware once the app has started). Hand it
# to the app constructor through app_kwargs instead.
#
# allow_credentials is off: with allow_origins=["*"], Starlette answers
# cookie-bearing requests by echoing the caller's Origin together with
# Allow-Credentials, letting any site make credentialed requests. This app
# configures no auth, so credentials are not needed.
# NOTE(review): Gradio installs its own CORS middleware as well — confirm
# this permissive policy is still wanted on top of it.
app, local_url, share_url = demo.launch(
    share=True,
    app_kwargs={
        "middleware": [
            Middleware(
                CORSMiddleware,
                allow_origins=["*"],
                allow_credentials=False,
                allow_methods=["*"],
                allow_headers=["*"],
            )
        ]
    },
)