danielsapit committed
Commit 547bce6
1 Parent(s): a8f51b3

Create old_app.py

Files changed (1)
  1. old_app.py +158 -0
old_app.py ADDED
@@ -0,0 +1,158 @@
+import gradio as gr
+import os.path
+import numpy as np
+from collections import OrderedDict
+import torch
+import cv2
+from PIL import Image, ImageOps
+import utils_image as util
+from network_fbcnn import FBCNN as net
+import requests
+
+for model_path in ['fbcnn_gray.pth', 'fbcnn_color.pth']:
+    if os.path.exists(model_path):
+        print(f'{model_path} exists.')
+    else:
+        url = 'https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/{}'.format(os.path.basename(model_path))
+        r = requests.get(url, allow_redirects=True)
+        open(model_path, 'wb').write(r.content)
+
+def inference(input_img, is_gray, input_quality, enable_zoom, zoom, x_shift, y_shift, state):
+
+    if is_gray:
+        n_channels = 1  # set 1 for grayscale image, set 3 for color image
+        model_name = 'fbcnn_gray.pth'
+    else:
+        n_channels = 3  # set 1 for grayscale image, set 3 for color image
+        model_name = 'fbcnn_color.pth'
+    nc = [64, 128, 256, 512]
+    nb = 4
+
+
+    input_quality = 100 - input_quality
+
+    model_path = model_name
+
+    if os.path.exists(model_path):
+        print(f'loading model from {model_path}')
+    else:
+        os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)  # dirname is '' for a bare filename
+        url = 'https://github.com/jiaxi-jiang/FBCNN/releases/download/v1.0/{}'.format(os.path.basename(model_path))
+        r = requests.get(url, allow_redirects=True)
+        open(model_path, 'wb').write(r.content)
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # ----------------------------------------
+    # load model
+    # ----------------------------------------
+    if (not enable_zoom) or (state[1] is None):
+        model = net(in_nc=n_channels, out_nc=n_channels, nc=nc, nb=nb, act_mode='R')
+        model.load_state_dict(torch.load(model_path), strict=True)
+        model.eval()
+        for k, v in model.named_parameters():
+            v.requires_grad = False
+        model = model.to(device)
+
+        test_results = OrderedDict()
+        test_results['psnr'] = []
+        test_results['ssim'] = []
+        test_results['psnrb'] = []
+
+        # ------------------------------------
+        # (1) img_L
+        # ------------------------------------
+
+        if n_channels == 1:
+            open_cv_image = Image.fromarray(input_img)
+            open_cv_image = ImageOps.grayscale(open_cv_image)
+            open_cv_image = np.array(open_cv_image)  # PIL to OpenCV image
+            open_cv_image = np.expand_dims(open_cv_image, axis=2)  # HxWx1
+        elif n_channels == 3:
+            open_cv_image = np.array(input_img)  # PIL to OpenCV image
+            if open_cv_image.ndim == 2:
+                open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_GRAY2RGB)  # GGG
+            else:
+                open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB)  # RGB
+
+        img_L = util.uint2tensor4(open_cv_image)
+        img_L = img_L.to(device)
+
+        # ------------------------------------
+        # (2) img_E
+        # ------------------------------------
+
+        img_E, QF = model(img_L)
+        QF = 1 - QF
+        img_E = util.tensor2single(img_E)
+        img_E = util.single2uint(img_E)
+
+        qf_input = torch.tensor([[1 - input_quality/100]]).cuda() if device == torch.device('cuda') else torch.tensor([[1 - input_quality/100]])
+        img_E, QF = model(img_L, qf_input)
+        QF = 1 - QF
+        img_E = util.tensor2single(img_E)
+        img_E = util.single2uint(img_E)
+
+        if img_E.ndim == 3:
+            img_E = img_E[:, :, [2, 1, 0]]
+
+        print("--inference finished")
+    if (state[1] is not None) and enable_zoom:
+        img_E = state[1]
+    out_img = Image.fromarray(img_E)
+    out_img_w, out_img_h = out_img.size  # output image size
+    zoom = zoom/100
+    x_shift = x_shift/100
+    y_shift = y_shift/100
+    zoom_w, zoom_h = out_img_w*zoom, out_img_h*zoom
+    zoom_left, zoom_right = int((out_img_w - zoom_w)*x_shift), int(zoom_w + (out_img_w - zoom_w)*x_shift)
+    zoom_top, zoom_bottom = int((out_img_h - zoom_h)*y_shift), int(zoom_h + (out_img_h - zoom_h)*y_shift)
+    if (state[0] is None) or not enable_zoom:
+        in_img = Image.fromarray(input_img)
+        state[0] = input_img
+    else:
+        in_img = Image.fromarray(state[0])
+    in_img = in_img.crop((zoom_left, zoom_top, zoom_right, zoom_bottom))
+    in_img = in_img.resize((int(zoom_w/zoom), int(zoom_h/zoom)), Image.NEAREST)
+    out_img = out_img.crop((zoom_left, zoom_top, zoom_right, zoom_bottom))
+    out_img = out_img.resize((int(zoom_w/zoom), int(zoom_h/zoom)), Image.NEAREST)
+
+    return img_E, in_img, out_img, [state[0], img_E]
+
+gr.Interface(
+    fn=inference,
+    inputs=[gr.inputs.Image(label="Input Image"),
+            gr.inputs.Checkbox(label="Grayscale (Check this if your image is grayscale)"),
+            gr.inputs.Slider(minimum=1, maximum=100, step=1, label="Intensity (Higher = stronger JPEG artifact removal)"),
+            gr.inputs.Checkbox(default=False, label="Edit Zoom preview (This is optional. "
+                               "After the image result is loaded, check this to edit zoom parameters "
+                               "so that the input image will not be processed when the submit button is pressed.)"),
+            gr.inputs.Slider(minimum=10, maximum=100, step=1, default=50, label="Zoom Image "
+                             "(Use this to see the image quality up close. "
+                             "100 = original size)"),
+            gr.inputs.Slider(minimum=0, maximum=100, step=1, label="Zoom preview horizontal shift "
+                             "(Increase to shift to the right)"),
+            gr.inputs.Slider(minimum=0, maximum=100, step=1, label="Zoom preview vertical shift "
+                             "(Increase to shift downwards)"),
+            gr.inputs.State(default=[None, None], label="\t")
+            ],
+    outputs=[gr.outputs.Image(label="Result"),
+             gr.outputs.Image(label="Before:"),
+             gr.outputs.Image(label="After:"),
+             "state"],
+    examples=[["doraemon.jpg", False, 60, False, 42, 50, 50],
+              ["tomandjerry.jpg", False, 60, False, 40, 57, 44],
+              ["somepanda.jpg", True, 100, False, 30, 8, 24],
+              ["cemetry.jpg", False, 70, False, 20, 76, 62],
+              ["michelangelo_david.jpg", True, 30, False, 12, 53, 27],
+              ["elon_musk.jpg", False, 45, False, 15, 33, 30],
+              ["text.jpg", True, 70, False, 50, 11, 29]],
+    title="JPEG Artifacts Removal [FBCNN]",
+    description="Gradio Demo for JPEG Artifacts Removal. To use it, simply upload your image, "
+                "or click one of the examples to load them. Check out the paper and the original GitHub repo at the links below. "
+                "JPEG artifacts are noticeable distortions of images caused by JPEG lossy compression. "
+                "This is not a super-resolution AI but a JPEG compression artifact remover.",
+    article="<p style='text-align: center;'><a href='https://github.com/jiaxi-jiang/FBCNN'>FBCNN GitHub Repo</a><br>"
+            "<a href='https://arxiv.org/abs/2109.14573'>Towards Flexible Blind JPEG Artifacts Removal (FBCNN, ICCV 2021)</a></p>",
+    allow_flagging="never"
+).launch(enable_queue=True)
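
For quick manual testing, a minimal sketch of calling inference() directly (with the definitions from this file available in a Python session) could look like the lines below; "input.jpg", "restored.png", and the slider values are placeholder choices for illustration, not part of this commit:

import numpy as np
from PIL import Image

# gr.inputs.Image hands inference() a NumPy array, so read the test image the same way.
img = np.array(Image.open("input.jpg").convert("RGB"))
result, before, after, state = inference(
    img,            # input_img
    False,          # is_gray
    60,             # input_quality ("Intensity" slider)
    False,          # enable_zoom
    50, 50, 50,     # zoom, x_shift, y_shift
    [None, None])   # state, as declared in gr.inputs.State
Image.fromarray(result).save("restored.png")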