AlekseyKorshuk commited on
Commit
b21faa4
1 Parent(s): c53354c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from super_image import ImageLoader, EdsrModel, MsrnModel, MdsrModel, AwsrnModel, A2nModel, CarnModel, PanModel, \
7
+ HanModel, DrlnModel, RcanModel
8
+
9
+ title = "super-image"
10
+ description = "State of the Art Image Super-Resolution Models."
11
+ article = "<p style='text-align: center'><a href='https://github.com/eugenesiow/super-image'>Github Repo</a>" \
12
+ "| <a href='https://eugenesiow.github.io/super-image/'>Documentation</a> " \
13
+ "| <a href='https://github.com/eugenesiow/super-image#scale-x2'>Models</a></p>"
14
+
15
+
16
+ def get_model(model_name, scale):
17
+ if model_name == 'EDSR':
18
+ model = EdsrModel.from_pretrained('eugenesiow/edsr', scale=scale)
19
+ elif model_name == 'MSRN':
20
+ model = MsrnModel.from_pretrained('eugenesiow/msrn', scale=scale)
21
+ elif model_name == 'MDSR':
22
+ model = MdsrModel.from_pretrained('eugenesiow/mdsr', scale=scale)
23
+ elif model_name == 'AWSRN-BAM':
24
+ model = AwsrnModel.from_pretrained('eugenesiow/awsrn-bam', scale=scale)
25
+ elif model_name == 'A2N':
26
+ model = A2nModel.from_pretrained('eugenesiow/a2n', scale=scale)
27
+ elif model_name == 'CARN':
28
+ model = CarnModel.from_pretrained('eugenesiow/carn', scale=scale)
29
+ elif model_name == 'PAN':
30
+ model = PanModel.from_pretrained('eugenesiow/pan', scale=scale)
31
+ elif model_name == 'HAN':
32
+ model = HanModel.from_pretrained('eugenesiow/han', scale=scale)
33
+ elif model_name == 'DRLN':
34
+ model = DrlnModel.from_pretrained('eugenesiow/drln', scale=scale)
35
+ elif model_name == 'RCAN':
36
+ model = RcanModel.from_pretrained('eugenesiow/rcan', scale=scale)
37
+ else:
38
+ model = EdsrModel.from_pretrained('eugenesiow/edsr-base', scale=scale)
39
+ return model
40
+
41
+
42
+ def inference(img, scale_str, model_name):
43
+ max_res = 1024
44
+ scale = int(scale_str.replace('x', ''))
45
+ width, height = img.size
46
+ print(width, height)
47
+ if width > max_res or height > max_res:
48
+ img = img.thumbnail((max_res, max_res), Image.ANTIALIAS)
49
+ model = get_model(model_name, scale)
50
+ try:
51
+ inputs = ImageLoader.load_image(img)
52
+ preds = model(inputs)
53
+ preds = preds.data.cpu().numpy()
54
+ pred = preds[0].transpose((1, 2, 0)) * 255.0
55
+ return Image.fromarray(pred.astype('uint8'), 'RGB')
56
+ except Exception as e:
57
+ print(e)
58
+ return None
59
+
60
+
61
+ torch.hub.download_url_to_file('http://people.rennes.inria.fr/Aline.Roumy/results/images_SR_BMVC12/input_groundtruth/baby_mini_d3_gaussian.bmp',
62
+ 'baby.bmp')
63
+ torch.hub.download_url_to_file('http://people.rennes.inria.fr/Aline.Roumy/results/images_SR_BMVC12/input_groundtruth/woman_mini_d3_gaussian.bmp',
64
+ 'woman.bmp')
65
+ torch.hub.download_url_to_file('http://people.rennes.inria.fr/Aline.Roumy/results/images_SR_BMVC12/input_groundtruth/bird_mini_d4_gaussian.bmp',
66
+ 'bird.bmp')
67
+
68
+ # models = ['EDSR-base', 'DRLN', 'EDSR', 'MDSR', 'A2N', 'PAN', 'AWSRN-BAM', 'MSRN']
69
+ models = ['EDSR-base', 'A2N', 'PAN', 'AWSRN-BAM', 'MSRN']
70
+ scales = [2, 3, 4]
71
+ for model_name in models:
72
+ for scale in scales:
73
+ get_model(model_name, scale)
74
+
75
+ gr.Interface(
76
+ inference,
77
+ [
78
+ gr.inputs.Image(type="pil", label="Input"),
79
+ gr.inputs.Radio(["x2", "x3", "x4"], label='scale'),
80
+ gr.inputs.Dropdown(choices=models,
81
+ label='Model')
82
+ ],
83
+ gr.outputs.Image(type="pil", label="Output"),
84
+ title=title,
85
+ description=description,
86
+ article=article,
87
+ examples=[
88
+ ['baby.bmp', 'x2', 'EDSR-base'],
89
+ ['woman.bmp', 'x3', 'MSRN'],
90
+ ['bird.bmp', 'x4', 'PAN']
91
+ ],
92
+ enable_queue=True,
93
+ allow_flagging=False,
94
+ ).launch(debug=False)