super-image / app.py
Eugene Siow
Add updates to gr.Interface.
c1d104f
raw
history blame contribute delete
No virus
3.78 kB
from functools import lru_cache
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from super_image import (ImageLoader, EdsrModel, MsrnModel, MdsrModel, AwsrnModel, A2nModel, CarnModel, PanModel,
                         HanModel, DrlnModel, RcanModel)
# Text shown on the Gradio page: heading, subtitle and an HTML footer with links.
title = "super-image"
description = "State of the Art Image Super-Resolution Models."
# Each link is separated by " | "; the first separator previously lacked its
# leading space and rendered as "Github Repo| Documentation".
article = (
    "<p style='text-align: center'><a href='https://github.com/eugenesiow/super-image'>Github Repo</a> "
    "| <a href='https://eugenesiow.github.io/super-image/'>Documentation</a> "
    "| <a href='https://github.com/eugenesiow/super-image#scale-x2'>Models</a></p>"
)
# Display name -> (super_image model class, Hugging Face Hub checkpoint id).
_MODEL_REGISTRY = {
    'EDSR-base': (EdsrModel, 'eugenesiow/edsr-base'),
    'EDSR': (EdsrModel, 'eugenesiow/edsr'),
    'MSRN': (MsrnModel, 'eugenesiow/msrn'),
    'MDSR': (MdsrModel, 'eugenesiow/mdsr'),
    'AWSRN-BAM': (AwsrnModel, 'eugenesiow/awsrn-bam'),
    'A2N': (A2nModel, 'eugenesiow/a2n'),
    'CARN': (CarnModel, 'eugenesiow/carn'),
    'PAN': (PanModel, 'eugenesiow/pan'),
    'HAN': (HanModel, 'eugenesiow/han'),
    'DRLN': (DrlnModel, 'eugenesiow/drln'),
    'RCAN': (RcanModel, 'eugenesiow/rcan'),
}
# Used for any name not in the registry.
_DEFAULT_MODEL = (EdsrModel, 'eugenesiow/edsr-base')


# Memoized so the startup warm-up actually keeps models in memory and each
# request does not reload weights. Bounded because model_name is whatever
# string the client sends.
@lru_cache(maxsize=16)
def get_model(model_name, scale):
    """Return the pretrained super-resolution model for a name and scale.

    Args:
        model_name: display name such as 'EDSR-base' or 'MSRN'. Unknown names
            fall back to the EDSR-base checkpoint.
        scale: integer upscaling factor passed to ``from_pretrained``.

    Returns:
        The model instance from ``<ModelClass>.from_pretrained``. Repeated
        calls with the same arguments return the same cached instance.
    """
    model_cls, checkpoint = _MODEL_REGISTRY.get(model_name, _DEFAULT_MODEL)
    return model_cls.from_pretrained(checkpoint, scale=scale)
def inference(img, scale_str, model_name):
    """Upscale a PIL image with the selected super-resolution model.

    Inputs larger than 1024px on either side are first shrunk (keeping the
    aspect ratio) to fit within 1024x1024.

    Args:
        img: input PIL image from the Gradio image component, or None when
            nothing was provided.
        scale_str: upscaling factor label such as 'x2', 'x3' or 'x4', or None.
        model_name: model display name, forwarded to ``get_model``.

    Returns:
        The upscaled image as an RGB PIL image. Returns None when the image or
        scale is missing, or when the forward pass fails; in the failure case
        the error is printed.
    """
    if img is None or scale_str is None:
        return None
    max_res = 1024
    scale = int(scale_str.replace('x', ''))
    width, height = img.size
    print(width, height)
    if width > max_res or height > max_res:
        # Image.thumbnail resizes in place and returns None, so its result must
        # not be assigned back to img. Copy first to leave the caller's image
        # untouched. LANCZOS is the filter ANTIALIAS aliased; the ANTIALIAS
        # name was removed in Pillow 10.
        img = img.copy()
        img.thumbnail((max_res, max_res), Image.LANCZOS)
    model = get_model(model_name, scale)
    try:
        inputs = ImageLoader.load_image(img)
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            preds = model(inputs)
        # First item of the batch, CHW -> HWC, rescaled to 0-255.
        pred = preds[0].detach().cpu().numpy().transpose((1, 2, 0)) * 255.0
        # Outputs are not guaranteed to stay inside [0, 1]. Clip before the
        # uint8 cast so out-of-range pixels saturate instead of wrapping
        # around (e.g. 256 -> 0).
        pred = np.clip(pred, 0.0, 255.0).round()
        return Image.fromarray(pred.astype('uint8'), 'RGB')
    except Exception as e:
        # Deliberate best-effort boundary: report the error and show no output.
        print(e)
        return None
# Sample low-resolution images used as Gradio examples, keyed by the local
# filename each is saved under.
_EXAMPLE_BASE_URL = 'http://people.rennes.inria.fr/Aline.Roumy/results/images_SR_BMVC12/input_groundtruth/'
_EXAMPLE_IMAGES = {
    'baby.bmp': _EXAMPLE_BASE_URL + 'baby_mini_d3_gaussian.bmp',
    'woman.bmp': _EXAMPLE_BASE_URL + 'woman_mini_d3_gaussian.bmp',
    'bird.bmp': _EXAMPLE_BASE_URL + 'bird_mini_d4_gaussian.bmp',
}
for _filename, _url in _EXAMPLE_IMAGES.items():
    # Skip files that are already on disk so restarts don't download them
    # again, and the app can still start if the remote host is unreachable.
    if not Path(_filename).exists():
        torch.hub.download_url_to_file(_url, _filename)
# Models offered in the UI dropdown. An earlier list also included 'DRLN',
# 'EDSR' and 'MDSR'; get_model still supports them. They are presumably left
# out to limit startup time and memory (confirm before re-adding).
models = ['EDSR-base', 'A2N', 'PAN', 'AWSRN-BAM', 'MSRN']
# Integer upscaling factors offered for every model.
scales = [2, 3, 4]

# Call get_model once for every offered (model, scale) pair at startup, so
# the pretrained weights are obtained before the first request rather than
# during it.
for model_name in models:
    for scale in scales:
        get_model(model_name, scale)
# Build and launch the web UI. The order of the input components must match
# the parameters of inference(img, scale_str, model_name).
gr.Interface(
    inference,
    [
        gr.inputs.Image(type="pil", label="Input"),
        # Choices are derived from `scales` ("x2", "x3", "x4") so the UI stays
        # in sync with the warm-up loop. Defaults ensure inference never
        # receives None for the scale or model.
        gr.inputs.Radio([f"x{s}" for s in scales], default="x2", label='scale'),
        gr.inputs.Dropdown(choices=models, default=models[0], label='Model'),
    ],
    gr.outputs.Image(type="pil", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[
        ['baby.bmp', 'x2', 'EDSR-base'],
        ['woman.bmp', 'x3', 'MSRN'],
        ['bird.bmp', 'x4', 'PAN'],
    ],
    allow_flagging='never',
).launch(debug=False)