File size: 3,780 Bytes
3997eb3
a26597f
095fbac
a26597f
037d730
a26597f
 
 
 
884899c
a26597f
 
 
 
 
41506bf
a26597f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41506bf
 
 
 
9eb308e
41506bf
30a37ef
 
9eb308e
 
41506bf
3cd9827
 
 
 
 
 
 
 
 
a26597f
 
 
 
 
 
f0575e8
 
16660bd
d7b9b40
3178329
41506bf
 
 
 
a26597f
 
 
 
 
 
89d3a8e
a26597f
 
e7c6334
a26597f
 
 
 
 
65689c6
d5624a0
a26597f
c1d104f
30a37ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import cv2
import torch
import numpy as np
import gradio as gr
from PIL import Image
from super_image import ImageLoader, EdsrModel, MsrnModel, MdsrModel, AwsrnModel, A2nModel, CarnModel, PanModel, \
    HanModel, DrlnModel, RcanModel

# Static text for the demo page: heading, tagline, and an HTML footer linking
# to the project's repository, documentation and model list.
title = "super-image"
description = "State of the Art Image Super-Resolution Models."
article = (
    "<p style='text-align: center'>"
    "<a href='https://github.com/eugenesiow/super-image'>Github Repo</a>"
    "| <a href='https://eugenesiow.github.io/super-image/'>Documentation</a> "
    "| <a href='https://github.com/eugenesiow/super-image#scale-x2'>Models</a>"
    "</p>"
)


def get_model(model_name, scale):
    """Load the pretrained super-resolution model selected by ``model_name``.

    Args:
        model_name: One of 'EDSR', 'MSRN', 'MDSR', 'AWSRN-BAM', 'A2N', 'CARN',
            'PAN', 'HAN', 'DRLN' or 'RCAN'. Any other value (for example
            'EDSR-base') selects the 'eugenesiow/edsr-base' checkpoint.
        scale: Upscaling factor, forwarded to ``from_pretrained`` as ``scale``.

    Returns:
        Whatever the chosen model class's ``from_pretrained`` returns.
    """
    # model name -> (model class, pretrained checkpoint id)
    checkpoints = {
        'EDSR': (EdsrModel, 'eugenesiow/edsr'),
        'MSRN': (MsrnModel, 'eugenesiow/msrn'),
        'MDSR': (MdsrModel, 'eugenesiow/mdsr'),
        'AWSRN-BAM': (AwsrnModel, 'eugenesiow/awsrn-bam'),
        'A2N': (A2nModel, 'eugenesiow/a2n'),
        'CARN': (CarnModel, 'eugenesiow/carn'),
        'PAN': (PanModel, 'eugenesiow/pan'),
        'HAN': (HanModel, 'eugenesiow/han'),
        'DRLN': (DrlnModel, 'eugenesiow/drln'),
        'RCAN': (RcanModel, 'eugenesiow/rcan'),
    }
    # Unrecognised names fall back to the EDSR base checkpoint.
    fallback = (EdsrModel, 'eugenesiow/edsr-base')
    model_cls, checkpoint_id = checkpoints.get(model_name, fallback)
    return model_cls.from_pretrained(checkpoint_id, scale=scale)


def inference(img, scale_str, model_name):
    """Upscale ``img`` with the selected super-resolution model.

    Args:
        img: PIL image to upscale, or ``None`` when nothing was supplied.
            If either side exceeds 1024 px the image is shrunk in place
            (aspect ratio preserved) before being fed to the model.
        scale_str: Upscaling factor written with an ``x`` prefix, e.g. ``'x2'``.
        model_name: Model identifier forwarded to ``get_model``.

    Returns:
        The upscaled picture as an RGB ``PIL.Image``, or ``None`` when no
        image was given or inference raised (the exception is printed).
    """
    if img is None:
        # Nothing to process; mirror the "return None on failure" contract.
        return None

    max_res = 1024
    scale = int(scale_str.replace('x', ''))
    width, height = img.size
    print(width, height)
    if width > max_res or height > max_res:
        # Image.thumbnail() resizes in place and returns None, so its result
        # must not be assigned back to ``img``. LANCZOS is the filter that
        # used to be exposed as ANTIALIAS (that alias is gone in Pillow 10).
        img.thumbnail((max_res, max_res), Image.LANCZOS)
    model = get_model(model_name, scale)
    try:
        inputs = ImageLoader.load_image(img)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            preds = model(inputs)
        preds = preds.data.cpu().numpy()
        # First batch item, channel-first -> channel-last, rescaled to 0..255
        # (assumes the model emits values nominally in [0, 1] - TODO confirm).
        pred = preds[0].transpose((1, 2, 0)) * 255.0
        # Clamp before the uint8 cast so over/undershoot does not wrap around.
        pred = np.clip(pred, 0, 255)
        return Image.fromarray(pred.astype('uint8'), 'RGB')
    except Exception as e:
        # Best-effort demo endpoint: report the problem and show no output.
        print(e)
        return None


# Sample low-resolution images referenced by the interface's examples,
# as (source URL, local file name) pairs.
_EXAMPLE_IMAGES = (
    ('http://people.rennes.inria.fr/Aline.Roumy/results/images_SR_BMVC12/input_groundtruth/baby_mini_d3_gaussian.bmp',
     'baby.bmp'),
    ('http://people.rennes.inria.fr/Aline.Roumy/results/images_SR_BMVC12/input_groundtruth/woman_mini_d3_gaussian.bmp',
     'woman.bmp'),
    ('http://people.rennes.inria.fr/Aline.Roumy/results/images_SR_BMVC12/input_groundtruth/bird_mini_d4_gaussian.bmp',
     'bird.bmp'),
)
for _url, _filename in _EXAMPLE_IMAGES:
    torch.hub.download_url_to_file(_url, _filename)

# Models offered in the UI. An earlier, longer list also contained
# 'DRLN', 'EDSR' and 'MDSR'; those are currently disabled.
models = ['EDSR-base', 'A2N', 'PAN', 'AWSRN-BAM', 'MSRN']
scales = [2, 3, 4]

# Load every offered model at every scale once at start-up. The returned
# objects are discarded - presumably this only warms the local weight cache
# so the first request does not pay for the download (TODO confirm).
for model_name in models:
    for scale in scales:
        get_model(model_name, scale)

# Input widgets, in the order `inference` receives them: image, scale, model.
demo_inputs = [
    gr.inputs.Image(type="pil", label="Input"),
    gr.inputs.Radio(["x2", "x3", "x4"], label='scale'),
    gr.inputs.Dropdown(choices=models, label='Model'),
]
demo_output = gr.outputs.Image(type="pil", label="Output")

# Pre-filled example rows: (image file, scale, model).
demo_examples = [
    ['baby.bmp', 'x2', 'EDSR-base'],
    ['woman.bmp', 'x3', 'MSRN'],
    ['bird.bmp', 'x4', 'PAN'],
]

demo = gr.Interface(
    fn=inference,
    inputs=demo_inputs,
    outputs=demo_output,
    title=title,
    description=description,
    article=article,
    examples=demo_examples,
    allow_flagging='never',
)
demo.launch(debug=False)