hamg22 commited on
Commit
28d826a
1 Parent(s): b898db5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +236 -0
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.sys.path
3
+
4
+ !pip install opencv-python
5
+
6
+ !pip install basicsr
7
+ !pip install facexlib
8
+ !pip install gfpgan
9
+ !pip install tqdm
10
+ !pip install -U gradio
11
+
12
+ %pip install realesrgan
13
+
14
+ import gradio as gr
15
+ import cv2
16
+ import numpy
17
+ import os
18
+ import random
19
+ from basicsr.archs.rrdbnet_arch import RRDBNet
20
+ from basicsr.utils.download_util import load_file_from_url
21
+
22
+ from realesrgan import RealESRGANer
23
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
24
+
25
+ last_file = None
26
+ img_mode = "RGBA"
27
+
28
+
29
+ def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
30
+ """Real-ESRGAN function to restore (and upscale) images.
31
+ """
32
+ if not img:
33
+ return
34
+
35
+ # Define model parameters
36
+ if model_name == 'RealESRGAN_x4plus': # x4 RRDBNet model
37
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
38
+ netscale = 4
39
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
40
+ elif model_name == 'RealESRNet_x4plus': # x4 RRDBNet model
41
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
42
+ netscale = 4
43
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
44
+ elif model_name == 'RealESRGAN_x4plus_anime_6B': # x4 RRDBNet model with 6 blocks
45
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
46
+ netscale = 4
47
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
48
+ elif model_name == 'RealESRGAN_x2plus': # x2 RRDBNet model
49
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
50
+ netscale = 2
51
+ file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
52
+ elif model_name == 'realesr-general-x4v3': # x4 VGG-style model (S size)
53
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
54
+ netscale = 4
55
+ file_url = [
56
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
57
+ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
58
+ ]
59
+
60
+ # Determine model paths
61
+ model_path = os.path.join('weights', model_name + '.pth')
62
+ if not os.path.isfile(model_path):
63
+ ROOT_DIR = os.path.dirname(os.path.abspath("."))
64
+ for url in file_url:
65
+ # model_path will be updated
66
+ model_path = load_file_from_url(
67
+ url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
68
+
69
+ # Use dni to control the denoise strength
70
+ dni_weight = None
71
+ if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
72
+ wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
73
+ model_path = [model_path, wdn_model_path]
74
+ dni_weight = [denoise_strength, 1 - denoise_strength]
75
+
76
+ # Restorer Class
77
+ upsampler = RealESRGANer(
78
+ scale=netscale,
79
+ model_path=model_path,
80
+ dni_weight=dni_weight,
81
+ model=model,
82
+ tile=0,
83
+ tile_pad=10,
84
+ pre_pad=10,
85
+ half=False,
86
+ gpu_id=None
87
+ )
88
+
89
+ # Use GFPGAN for face enhancement
90
+ if face_enhance:
91
+ from gfpgan import GFPGANer
92
+ face_enhancer = GFPGANer(
93
+ model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
94
+ upscale=outscale,
95
+ arch='clean',
96
+ channel_multiplier=2,
97
+ bg_upsampler=upsampler)
98
+
99
+ # Convert the input PIL image to cv2 image, so that it can be processed by realesrgan
100
+ cv_img = numpy.array(img)
101
+ img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)
102
+
103
+ # Apply restoration
104
+ try:
105
+ if face_enhance:
106
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
107
+ else:
108
+ output, _ = upsampler.enhance(img, outscale=outscale)
109
+ except RuntimeError as error:
110
+ print('Error', error)
111
+ print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
112
+ else:
113
+ # Save restored image and return it to the output Image component
114
+ if img_mode == 'RGBA': # RGBA images should be saved in png format
115
+ extension = 'png'
116
+ else:
117
+ extension = 'jpg'
118
+
119
+ out_filename = f"output_{rnd_string(8)}.{extension}"
120
+ cv2.imwrite(out_filename, output)
121
+ global last_file
122
+ last_file = out_filename
123
+ return out_filename
124
+
125
+
126
+ def rnd_string(x):
127
+ """Returns a string of 'x' random characters
128
+ """
129
+ characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
130
+ result = "".join((random.choice(characters)) for i in range(x))
131
+ return result
132
+
133
+
134
+ def reset():
135
+ """Resets the Image components of the Gradio interface and deletes
136
+ the last processed image
137
+ """
138
+ global last_file
139
+ if last_file:
140
+ print(f"Deleting {last_file} ...")
141
+ os.remove(last_file)
142
+ last_file = None
143
+ return gr.update(value=None), gr.update(value=None)
144
+
145
+
146
+ def has_transparency(img):
147
+ """This function works by first checking to see if a "transparency" property is defined
148
+ in the image's info -- if so, we return "True". Then, if the image is using indexed colors
149
+ (such as in GIFs), it gets the index of the transparent color in the palette
150
+ (img.info.get("transparency", -1)) and checks if it's used anywhere in the canvas
151
+ (img.getcolors()). If the image is in RGBA mode, then presumably it has transparency in
152
+ it, but it double-checks by getting the minimum and maximum values of every color channel
153
+ (img.getextrema()), and checks if the alpha channel's smallest value falls below 255.
154
+ https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
155
+ """
156
+ if img.info.get("transparency", None) is not None:
157
+ return True
158
+ if img.mode == "P":
159
+ transparent = img.info.get("transparency", -1)
160
+ for _, index in img.getcolors():
161
+ if index == transparent:
162
+ return True
163
+ elif img.mode == "RGBA":
164
+ extrema = img.getextrema()
165
+ if extrema[3][0] < 255:
166
+ return True
167
+ return False
168
+
169
+
170
+ def image_properties(img):
171
+ """Returns the dimensions (width and height) and color mode of the input image and
172
+ also sets the global img_mode variable to be used by the realesrgan function
173
+ """
174
+ global img_mode
175
+ if img:
176
+ try:
177
+ img_mode = img.mode # Get the color mode directly
178
+ properties = f"Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
179
+ return properties
180
+ except Exception as e:
181
+ print(f"Error processing image: {e}")
182
+
183
+ return "Invalid image"
184
+
185
+ def main():
186
+ # Gradio Interface
187
+ with gr.Blocks(title="PixelUP Demo", theme="dark") as demo:
188
+
189
+ gr.Markdown(
190
+ """# <div align="center"> </div>
191
+ <div align="center"><img width="200" height="74" src="https://i.imgur.com/UzFFDdL.png"></div>
192
+ """
193
+ )
194
+
195
+ with gr.Accordion("Options/Parameters"):
196
+ with gr.Row():
197
+ model_name = gr.Dropdown(label="Choose Model",
198
+ choices=["RealESRGAN_x4plus", "RealESRNet_x4plus", "RealESRGAN_x4plus_anime_6B",
199
+ "RealESRGAN_x2plus", "realesr-general-x4v3"],
200
+ value="realesr-general-x4v3", show_label=True)
201
+ denoise_strength = gr.Slider(label="Denoise Strength",
202
+ minimum=0, maximum=1, step=0.1, value=0.5)
203
+ outscale = gr.Slider(label="Image Upscaling Factor",
204
+ minimum=1, maximum=10, step=1, value=2, show_label=True)
205
+ face_enhance = gr.Checkbox(label="Face Enhancement using GFPGAN ",
206
+ value=False, show_label=True)
207
+
208
+ with gr.Row():
209
+ with gr.Group():
210
+ input_image = gr.Image(label="Source Image", type="pil", image_mode="RGBA")
211
+ input_image_properties = gr.Textbox(label="Image Properties", max_lines=1)
212
+ output_image = gr.Image(label="Restored Image", image_mode="RGBA")
213
+ with gr.Row():
214
+ restore_btn = gr.Button("Restore Image")
215
+ reset_btn = gr.Button("Reset")
216
+
217
+ # Event listeners:
218
+ input_image.change(fn=image_properties, inputs=input_image, outputs=input_image_properties)
219
+ restore_btn.click(fn=realesrgan,
220
+ inputs=[input_image, model_name, denoise_strength, face_enhance, outscale],
221
+ outputs=output_image)
222
+ reset_btn.click(fn=reset, inputs=[], outputs=[output_image, input_image])
223
+ # reset_btn.click(None, inputs=[], outputs=[input_image], _js="() => (null)\n")
224
+ # Undocumented method to clear a component's value using Javascript
225
+
226
+ gr.Markdown(
227
+ """*Please note that support for animated GIFs is not yet implemented. Should an animated GIF is chosen for restoration,
228
+ the demo will output only the first frame saved in PNG format (to preserve probable transparency).*
229
+ """
230
+ )
231
+
232
+ demo.launch(share=True)
233
+
234
+
235
+ if __name__ == "__main__":
236
+ main()