shubham5027 commited on
Commit
8ccd726
·
verified ·
1 Parent(s): 1d21ce7

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +114 -0
  2. requirements.txt +97 -0
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import sys
4
+ from torchvision.transforms import functional
5
+ sys.modules["torchvision.transforms.functional_tensor"] = functional
6
+
7
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
8
+ from gfpgan.utils import GFPGANer
9
+ from realesrgan.utils import RealESRGANer
10
+
11
+ import torch
12
+ import cv2
13
+ import gradio as gr
14
+
15
+
16
+ #Download Required Models
17
+ if not os.path.exists('realesr-general-x4v3.pth'):
18
+ os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
19
+ if not os.path.exists('GFPGANv1.2.pth'):
20
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P .")
21
+ if not os.path.exists('GFPGANv1.3.pth'):
22
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P .")
23
+ if not os.path.exists('GFPGANv1.4.pth'):
24
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
25
+ if not os.path.exists('RestoreFormer.pth'):
26
+ os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P .")
27
+
28
+
29
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
30
+ model_path = 'realesr-general-x4v3.pth'
31
+ half = True if torch.cuda.is_available() else False
32
+ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
33
+
34
+
35
+ # Save Image to the Directory
36
+ # os.makedirs('output', exist_ok=True)
37
+
38
+ def upscaler(img, version, scale):
39
+
40
+ try:
41
+
42
+ img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
43
+ if len(img.shape) == 3 and img.shape[2] == 4:
44
+ img_mode = 'RGBA'
45
+ elif len(img.shape) == 2:
46
+ img_mode = None
47
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
48
+ else:
49
+ img_mode = None
50
+
51
+
52
+ h, w = img.shape[0:2]
53
+ if h < 300:
54
+ img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
55
+
56
+
57
+ face_enhancer = GFPGANer(
58
+ model_path=f'{version}.pth',
59
+ upscale=2,
60
+ arch='RestoreFormer' if version=='RestoreFormer' else 'clean',
61
+ channel_multiplier=2,
62
+ bg_upsampler=upsampler
63
+ )
64
+
65
+
66
+ try:
67
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
68
+ except RuntimeError as error:
69
+ print('Error', error)
70
+
71
+
72
+ try:
73
+ if scale != 2:
74
+ interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
75
+ h, w = img.shape[0:2]
76
+ output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
77
+ except Exception as error:
78
+ print('wrong scale input.', error)
79
+
80
+ # Save Image to the Directory
81
+ # ext = os.path.splitext(os.path.basename(str(img)))[1]
82
+ # if img_mode == 'RGBA':
83
+ # ext = 'png'
84
+ # else:
85
+ # ext = 'jpg'
86
+ #
87
+ # save_path = f'output/out.{ext}'
88
+ # cv2.imwrite(save_path, output)
89
+ # return output, save_path
90
+
91
+ output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
92
+ return output
93
+ except Exception as error:
94
+ print('global exception', error)
95
+ return None, None
96
+
97
+ if __name__ == "__main__":
98
+
99
+ title = "Image Upscaler & Restoring [GFPGAN Algorithm]"
100
+
101
+ demo = gr.Interface(
102
+ upscaler, [
103
+ gr.Image(type="filepath", label="Input"),
104
+ gr.Radio(['GFPGANv1.2', 'GFPGANv1.3', 'GFPGANv1.4', 'RestoreFormer'], type="value", label='version'),
105
+ gr.Number(label="Rescaling factor"),
106
+ ], [
107
+ gr.Image(type="numpy", label="Output"),
108
+ ],
109
+ title=title,
110
+ allow_flagging="never"
111
+ )
112
+
113
+ demo.queue()
114
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ addict==2.4.0
3
+ aiofiles==23.2.1
4
+ altair==5.3.0
5
+ annotated-types==0.6.0
6
+ anyio==4.3.0
7
+ attrs==23.2.0
8
+ basicsr==1.4.2
9
+ certifi==2024.2.2
10
+ charset-normalizer==3.3.2
11
+ click==8.1.7
12
+ contourpy==1.2.1
13
+ cycler==0.12.1
14
+ facexlib==0.3.0
15
+ fastapi==0.110.2
16
+ ffmpy==0.3.2
17
+ filelock==3.13.4
18
+ filterpy==1.4.5
19
+ fonttools==4.51.0
20
+ fsspec==2024.3.1
21
+ future==1.0.0
22
+ gfpgan==1.3.8
23
+ gradio==4.28.3
24
+ gradio_client==0.16.0
25
+ grpcio==1.62.2
26
+ h11==0.14.0
27
+ httpcore==1.0.5
28
+ httpx==0.27.0
29
+ huggingface-hub==0.22.2
30
+ idna==3.7
31
+ imageio==2.34.1
32
+ importlib_metadata==7.1.0
33
+ importlib_resources==6.4.0
34
+ Jinja2==3.1.3
35
+ jsonschema==4.21.1
36
+ jsonschema-specifications==2023.12.1
37
+ kiwisolver==1.4.5
38
+ lazy_loader==0.4
39
+ llvmlite==0.42.0
40
+ lmdb==1.4.1
41
+ Markdown==3.6
42
+ markdown-it-py==3.0.0
43
+ MarkupSafe==2.1.5
44
+ matplotlib==3.8.4
45
+ mdurl==0.1.2
46
+ mpmath==1.3.0
47
+ networkx==3.3
48
+ numba==0.59.1
49
+ numpy==1.26.4
50
+ opencv-python==4.9.0.80
51
+ orjson==3.10.1
52
+ packaging==24.0
53
+ pandas==2.2.2
54
+ pillow==10.3.0
55
+ platformdirs==4.2.1
56
+ protobuf==4.25.3
57
+ pydantic==2.7.1
58
+ pydantic_core==2.18.2
59
+ pydub==0.25.1
60
+ Pygments==2.17.2
61
+ pyparsing==3.1.2
62
+ python-dateutil==2.9.0.post0
63
+ python-multipart==0.0.9
64
+ pytz==2024.1
65
+ PyYAML==6.0.1
66
+ realesrgan==0.3.0
67
+ referencing==0.35.0
68
+ requests==2.31.0
69
+ rich==13.7.1
70
+ rpds-py==0.18.0
71
+ ruff==0.4.2
72
+ scikit-image==0.23.2
73
+ scipy==1.13.0
74
+ semantic-version==2.10.0
75
+ shellingham==1.5.4
76
+ six==1.16.0
77
+ sniffio==1.3.1
78
+ starlette==0.37.2
79
+ sympy==1.12
80
+ tb-nightly==2.17.0a20240428
81
+ tensorboard-data-server==0.7.2
82
+ tifffile==2024.4.24
83
+ tomli==2.0.1
84
+ tomlkit==0.12.0
85
+ toolz==0.12.1
86
+ torch==2.3.0
87
+ torchvision==0.18.0
88
+ tqdm==4.66.2
89
+ typer==0.12.3
90
+ typing_extensions==4.11.0
91
+ tzdata==2024.1
92
+ urllib3==2.2.1
93
+ uvicorn==0.29.0
94
+ websockets==11.0.3
95
+ Werkzeug==3.0.2
96
+ yapf==0.40.2
97
+ zipp==3.18.1