dragonSwing committed
Commit 0fe2a53
1 parent: 2832c43

Add application files

.gitignore ADDED
@@ -0,0 +1,200 @@
+ # Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,metals
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,metals
+
+ ### Metals ###
+ .metals/
+ .bloop/
+ project/**/metals.sbt
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ ### Python Patch ###
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+ poetry.toml
+
+ # ruff
+ .ruff_cache/
+
+ # LSP config files
+ pyrightconfig.json
+
+ ### VisualStudioCode ###
+ .vscode/*
+ !.vscode/settings.json
+ !.vscode/tasks.json
+ !.vscode/launch.json
+ !.vscode/extensions.json
+ !.vscode/*.code-snippets
+
+ # Local History for Visual Studio Code
+ .history/
+
+ # Built Visual Studio Code Extensions
+ *.vsix
+
+ ### VisualStudioCode Patch ###
+ # Ignore all local history of files
+ .history
+ .ionide
+
+ # End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,metals
Dockerfile ADDED
@@ -0,0 +1,24 @@
+ FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
+
+ # Arguments to build the Docker image with CUDA support
+ ARG USE_CUDA=0
+ ARG TORCH_ARCH=
+
+ ENV AM_I_DOCKER=True
+ ENV BUILD_WITH_CUDA="${USE_CUDA}"
+ ENV TORCH_CUDA_ARCH_LIST="${TORCH_ARCH}"
+ ENV CUDA_HOME=/usr/local/cuda-11.6/
+
+ RUN apt-get update && apt-get install --no-install-recommends -y \
+     wget ffmpeg=7:* libsm6=2:* libxext6=2:* git=1:* nano=2.* vim=2:* \
+     && apt-get clean && apt-get autoremove && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /app
+ COPY . /app
+
+ RUN pip install -r requirements.txt
+
+ # Expose the Gradio port (7860 is Gradio's default; change it if needed)
+ EXPOSE 7860
+ CMD ["python3", "app.py"]
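To build and run this image, something like the following should work (the image tag is hypothetical, and the CUDA build arguments only matter when GPU support is wanted):

docker build --build-arg USE_CUDA=1 --build-arg TORCH_ARCH="8.6" -t isr-app .
docker run --gpus all -p 7860:7860 isr-app

The port mapping assumes Gradio's default 7860, matching the EXPOSE line above.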
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Binh Le
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,135 @@
+ import cv2
+ import gradio as gr
+ from utils import get_upsampler, get_face_enhancer
+
+
+ def inference(img, task, model_name, scale):
+     if scale > 4:
+         scale = 4  # avoid too large scale values
+     try:
+         img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
+
+         h, w = img.shape[0:2]
+         if h > 3500 or w > 3500:
+             raise gr.Error(f"image too large: {w} * {h}")
+
+         if (h < 300 and w < 300) and model_name != "srcnn":
+             img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
+             return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+         if task == "face":
+             upsample_model_name = "realesr-general-x4v3"
+         else:
+             upsample_model_name = model_name
+         upsampler = get_upsampler(upsample_model_name)
+
+         if task == "face":
+             face_enhancer = get_face_enhancer(model_name, scale, upsampler)
+         else:
+             face_enhancer = None
+
+         try:
+             if face_enhancer is not None:
+                 _, _, output = face_enhancer.enhance(
+                     img, has_aligned=False, only_center_face=False, paste_back=True
+                 )
+             else:
+                 output, _ = upsampler.enhance(img, outscale=scale)
+         except RuntimeError as error:
+             raise gr.Error(str(error))
+
+         output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
+         return output
+     except gr.Error:
+         raise  # don't re-wrap errors we raised ourselves
+     except Exception as error:
+         raise gr.Error(f"global exception: {error}")
+
+
+ def on_task_change(task):
+     if task == "general":
+         return gr.Dropdown.update(
+             choices=[
+                 "srcnn",
+                 "RealESRGAN_x2plus",
+                 "RealESRGAN_x4plus",
+                 "RealESRNet_x4plus",
+                 "realesr-general-x4v3",
+             ],
+             value="realesr-general-x4v3",
+         )
+     elif task == "face":
+         return gr.Dropdown.update(
+             choices=["GFPGANv1.3", "GFPGANv1.4", "RestoreFormer"], value="GFPGANv1.4"
+         )
+     elif task == "anime":
+         return gr.Dropdown.update(
+             choices=["srcnn", "RealESRGAN_x4plus_anime_6B", "realesr-animevideov3"],
+             value="RealESRGAN_x4plus_anime_6B",
+         )
+
+
+ title = "ISR: General Image Super Resolution"
+ description = r"""Gradio demo for <a href='https://github.com/TencentARC/GFPGAN' target='_blank'><b>GFPGAN: Towards Real-World Blind Face Restoration with Generative Facial Prior</b></a>.<br>
+ It can be used to restore your **old photos** or improve **AI-generated faces**.<br>
+ To use it, simply upload your image.<br>
+ If GFPGAN is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/GFPGAN' target='_blank'>GitHub Repo</a> and recommend it to your friends 😊
+ """
+ article = r"""
+ <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_GFPGAN' alt='visitor badge'></center>
+ """
+
+ with gr.Blocks(css="style.css", title=title) as demo:
+     with gr.Row(elem_classes=["container"]):
+         with gr.Column(scale=2):
+             input_image = gr.Image(type="filepath", label="Input")
+             task = gr.Dropdown(
+                 ["general", "face", "anime"],
+                 type="value",
+                 value="general",
+                 label="task",
+             )
+             model_name = gr.Dropdown(
+                 [
+                     "srcnn",
+                     "RealESRGAN_x2plus",
+                     "RealESRGAN_x4plus",
+                     "RealESRNet_x4plus",
+                     "realesr-general-x4v3",
+                 ],
+                 type="value",
+                 value="realesr-general-x4v3",
+                 label="model",
+             )
+             scale = gr.Slider(
+                 minimum=1.5,
+                 maximum=4,
+                 value=2,
+                 step=0.5,
+                 label="Scale factor",
+                 info="Scaling factor",
+             )
+             run_btn = gr.Button(value="Submit")
+
+         with gr.Column(scale=3):
+             output_image = gr.Image(type="numpy", label="Output image")
+
+     with gr.Row(elem_classes=["container"]):
+         gr.Examples(
+             [
+                 ["examples/landscape.jpg", "general", 2],
+                 ["examples/cat.jpg", "general", 2],
+                 ["examples/cat2.jpg", "face", 2],
+                 ["examples/AI-generate.png", "face", 2],
+                 ["examples/Blake_Lively.png", "face", 2],
+                 ["examples/old_image.jpg", "face", 2],
+                 ["examples/naruto.png", "anime", 2],
+                 ["examples/luffy2.jpg", "anime", 2],
+             ],
+             [input_image, task, scale],
+         )
+
+     run_btn.click(inference, [input_image, task, model_name, scale], [output_image])
+     task.change(on_task_change, [task], [model_name])
+
+ demo.queue(concurrency_count=4).launch(debug=True, share=True)
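Once the app is running (python app.py), it can also be driven programmatically; a minimal sketch using the gradio_client package, where the URL and fn_index are assumptions about the local deployment:

from gradio_client import Client

client = Client("http://127.0.0.1:7860/")  # assumed local URL
result = client.predict(
    "examples/cat.jpg",      # input image (filepath)
    "general",               # task
    "realesr-general-x4v3",  # model
    2,                       # scale factor
    fn_index=0,              # assumed index of the run_btn.click handler
)
print(result)  # path to the upsampled output written by the server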
config.py ADDED
@@ -0,0 +1,5 @@
+ import os
+
+
+ WEIGHT_DIR = "weights"
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
examples/AI-generate.png ADDED
examples/Blake_Lively.png ADDED
examples/cat.jpg ADDED
examples/cat2.jpg ADDED
examples/landscape.jpg ADDED
examples/luffy.jpg ADDED
examples/luffy2.jpg ADDED
examples/naruto.png ADDED
examples/old_image.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ accelerate
+ gradio
+ matplotlib
+ numpy
+ opencv_python
+ Pillow
+ requests
+ torch
+ torchvision
+ transformers
+ imutils
+ argparse
+ tqdm
+ basicsr
+ facexlib
+ gfpgan
+ realesrgan
+ lmdb
+ pyyaml
+ yapf
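These requirements are unpinned, so installs resolve to whatever versions are current at install time; the basicsr/gfpgan/realesrgan stack can be sensitive to the torch version that gets pulled in, so running pip install -r requirements.txt in a fresh virtual environment is the safest setup.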
srcnn.py ADDED
@@ -0,0 +1,94 @@
+ from typing import Tuple, Union
+ import cv2
+ import torch
+ import numpy as np
+ from torch import nn
+ from torchvision import transforms as T
+
+
+ class SRCNN(nn.Module):
+     def __init__(
+         self,
+         input_channels=3,
+         output_channels=3,
+         input_size=33,
+         label_size=21,
+         scale=2,
+         device=None,
+     ):
+         super().__init__()
+         self.input_size = input_size
+         self.label_size = label_size
+         self.pad = (self.input_size - self.label_size) // 2
+         self.scale = scale
+         self.model = nn.Sequential(
+             nn.Conv2d(input_channels, 64, 9),
+             nn.ReLU(),
+             nn.Conv2d(64, 32, 1),
+             nn.ReLU(),
+             nn.Conv2d(32, output_channels, 5),
+             nn.ReLU(),
+         )
+         self.transform = T.Compose(
+             [T.ToTensor()]  # scale pixel values to [0, 1]
+         )
+
+         if device is None:
+             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.device = device
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.model(x)
+
+     @torch.no_grad()
+     def pre_process(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+         if torch.is_tensor(x):
+             return x / 255.0
+         else:
+             return self.transform(x)
+
+     @torch.no_grad()
+     def post_process(self, x: torch.Tensor) -> torch.Tensor:
+         return x.clip(0, 1) * 255.0
+
+     @torch.no_grad()
+     def enhance(self, image: np.ndarray, outscale: float = 2) -> Tuple[np.ndarray, None]:
+         (h, w) = image.shape[:2]
+         scale_w = int((w - w % self.label_size + self.input_size) * self.scale)
+         scale_h = int((h - h % self.label_size + self.input_size) * self.scale)
+         # resize the input image using bicubic interpolation
+         scaled = cv2.resize(image, (scale_w, scale_h), interpolation=cv2.INTER_CUBIC)
+         # preprocessing
+         in_tensor = self.pre_process(scaled)  # (C, H, W)
+         out_tensor = torch.zeros_like(in_tensor)  # (C, H, W)
+
+         # slide a window from left-to-right and top-to-bottom
+         for y in range(0, scale_h - self.input_size + 1, self.label_size):
+             for x in range(0, scale_w - self.input_size + 1, self.label_size):
+                 # crop the ROI from the scaled image
+                 crop = in_tensor[:, y : y + self.input_size, x : x + self.input_size]
+                 # make a prediction on the crop and store it in the output tensor
+                 crop_inp = crop.unsqueeze(0).to(self.device)
+                 pred = self.forward(crop_inp).cpu().squeeze()
+                 out_tensor[
+                     :,
+                     y + self.pad : y + self.pad + self.label_size,
+                     x + self.pad : x + self.pad + self.label_size,
+                 ] = pred
+
+         out_tensor = self.post_process(out_tensor)
+         output = out_tensor.permute(1, 2, 0).numpy()  # (C, H, W) to (H, W, C)
+         output = output[self.pad : -self.pad * 2, self.pad : -self.pad * 2]
+         output = np.clip(output, 0, 255).astype("uint8")
+
+         # use OpenCV to resample the image if the scaling factor differs from 2
+         if outscale != 2:
+             interpolation = cv2.INTER_AREA if outscale < 2 else cv2.INTER_LANCZOS4
+             h, w = output.shape[0:2]
+             output = cv2.resize(
+                 output,
+                 (int(w * outscale / 2), int(h * outscale / 2)),
+                 interpolation=interpolation,
+             )
+
+         return output, None
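The patch geometry follows from the three unpadded convolutions: kernel sizes 9, 1, and 5 shrink each spatial side by 8 + 0 + 4 = 12 pixels, so a 33-pixel crop yields a 21-pixel prediction and pad = (33 - 21) // 2 = 6. A quick shape check (randomly initialized weights, no checkpoint required):

import torch
from srcnn import SRCNN

model = SRCNN()  # defaults: input_size=33, label_size=21
x = torch.randn(1, 3, 33, 33)  # a single 33x33 RGB crop
print(model(x).shape)  # torch.Size([1, 3, 21, 21]) since 33 - 8 - 4 = 21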
style.css ADDED
@@ -0,0 +1,11 @@
+ .container {
+     max-width: 1368px;
+     margin-left: auto;
+     margin-right: auto;
+ }
+
+ #row-flex {
+     display: flex;
+     align-items: center;
+     justify-content: center;
+ }
upsample.py ADDED
@@ -0,0 +1,144 @@
+ import argparse
+ import cv2
+ import os
+
+ from imutils import paths
+ from tqdm import tqdm
+ from config import *
+ from utils import get_face_enhancer, get_upsampler
+
+
+ def process(image_path, upsampler_name, face_enhancer_name=None, scale=2, device="cpu"):
+     if scale > 4:
+         scale = 4  # avoid too large scale values
+     try:
+         img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
+
+         h, w = img.shape[0:2]
+         if h > 3500 or w > 3500:
+             output = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+             return output
+
+         if (h < 300 and w < 300) and upsampler_name != "srcnn":
+             img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
+             return img
+
+         upsampler = get_upsampler(upsampler_name, device=device)
+
+         if face_enhancer_name:
+             face_enhancer = get_face_enhancer(
+                 face_enhancer_name, scale, upsampler, device=device
+             )
+         else:
+             face_enhancer = None
+
+         try:
+             if face_enhancer is not None:
+                 _, _, output = face_enhancer.enhance(
+                     img, has_aligned=False, only_center_face=False, paste_back=True
+                 )
+             else:
+                 output, _ = upsampler.enhance(img, outscale=scale)
+         except RuntimeError as error:
+             print(f"Runtime error: {error}")
+             return None  # output would be undefined here; skip this image
+
+         return output
+     except Exception as error:
+         print(f"global exception: {error}")
+
+
+ def main(args: argparse.Namespace) -> None:
+     device = args.device
+     scale = args.scale
+
+     upsampler_name = args.upsampler
+     face_enhancer_name = args.face_enhancer
+
+     if face_enhancer_name and ("srcnn" in upsampler_name or "anime" in upsampler_name):
+         print(
+             "Warning: the SRCNN and anime models aren't compatible with face enhancement; disabling it."
+         )
+         face_enhancer_name = None
+
+     os.makedirs(args.output, exist_ok=True)
+     if not os.path.exists(args.input):
+         raise ValueError("The input path doesn't exist!")
+     elif not os.path.isdir(args.input):
+         image_paths = [args.input]
+     else:
+         image_paths = paths.list_images(args.input)
+
+     with tqdm(image_paths) as pbar:
+         for image_path in pbar:
+             filename = os.path.basename(image_path)
+             pbar.set_postfix_str(f"Processing {image_path}")
+             upsampled_image = process(
+                 image_path=image_path,
+                 upsampler_name=upsampler_name,
+                 face_enhancer_name=face_enhancer_name,
+                 scale=scale,
+                 device=device,
+             )
+             if upsampled_image is not None:
+                 save_path = os.path.join(args.output, filename)
+                 cv2.imwrite(save_path, upsampled_image)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="Runs image super-resolution on an input image or directory of images"
+     )
+
+     parser.add_argument(
+         "--input",
+         "-i",
+         type=str,
+         required=True,
+         help="Path to either a single input image or folder of images.",
+     )
+
+     parser.add_argument(
+         "--output",
+         "-o",
+         type=str,
+         required=True,
+         help="Path to the output directory.",
+     )
+
+     parser.add_argument(
+         "--upsampler",
+         type=str,
+         default="realesr-general-x4v3",
+         choices=[
+             "srcnn",
+             "RealESRGAN_x2plus",
+             "RealESRGAN_x4plus",
+             "RealESRNet_x4plus",
+             "realesr-general-x4v3",
+             "RealESRGAN_x4plus_anime_6B",
+             "realesr-animevideov3",
+         ],
+         help="The type of upsampler model to load",
+     )
+
+     parser.add_argument(
+         "--face-enhancer",
+         type=str,
+         choices=["GFPGANv1.3", "GFPGANv1.4", "RestoreFormer"],
+         help="The type of face enhancer model to load",
+     )
+
+     parser.add_argument(
+         "--scale",
+         type=float,
+         default=2,
+         choices=[1.5, 2, 2.5, 3, 3.5, 4],
+         help="scaling factor",
+     )
+     parser.add_argument(
+         "--device", type=str, default="cuda", help="The device to run upsampling on."
+     )
+     args = parser.parse_args()
+     main(args)
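A representative invocation of the script above (the paths are illustrative; model weights are downloaded into weights/ on first use):

python upsample.py -i examples/cat.jpg -o results --upsampler realesr-general-x4v3 --scale 2 --device cpu

Adding --face-enhancer GFPGANv1.4 layers GFPGAN face restoration on top of whichever upsampler was selected.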
utils.py ADDED
@@ -0,0 +1,165 @@
+ import os
+ import torch
+ from basicsr.utils.download_util import load_file_from_url
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
+ from gfpgan.utils import GFPGANer
+ from realesrgan.utils import RealESRGANer
+
+ from config import *
+ from srcnn import SRCNN
+
+
+ def get_upsampler(model_name, device=None):
+     if model_name == "RealESRGAN_x4plus":  # x4 RRDBNet model
+         model = RRDBNet(
+             num_in_ch=3,
+             num_out_ch=3,
+             num_feat=64,
+             num_block=23,
+             num_grow_ch=32,
+             scale=4,
+         )
+         netscale = 4
+         file_url = [
+             "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
+         ]
+     elif model_name == "RealESRNet_x4plus":  # x4 RRDBNet model
+         model = RRDBNet(
+             num_in_ch=3,
+             num_out_ch=3,
+             num_feat=64,
+             num_block=23,
+             num_grow_ch=32,
+             scale=4,
+         )
+         netscale = 4
+         file_url = [
+             "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth"
+         ]
+     elif model_name == "RealESRGAN_x4plus_anime_6B":  # x4 RRDBNet model with 6 blocks
+         model = RRDBNet(
+             num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4
+         )
+         netscale = 4
+         file_url = [
+             "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
+         ]
+     elif model_name == "RealESRGAN_x2plus":  # x2 RRDBNet model
+         model = RRDBNet(
+             num_in_ch=3,
+             num_out_ch=3,
+             num_feat=64,
+             num_block=23,
+             num_grow_ch=32,
+             scale=2,
+         )
+         netscale = 2
+         file_url = [
+             "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
+         ]
+     elif model_name == "realesr-animevideov3":  # x4 VGG-style model (XS size)
+         model = SRVGGNetCompact(
+             num_in_ch=3,
+             num_out_ch=3,
+             num_feat=64,
+             num_conv=16,
+             upscale=4,
+             act_type="prelu",
+         )
+         netscale = 4
+         file_url = [
+             "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth"
+         ]
+     elif model_name == "realesr-general-x4v3":  # x4 VGG-style model (S size)
+         model = SRVGGNetCompact(
+             num_in_ch=3,
+             num_out_ch=3,
+             num_feat=64,
+             num_conv=32,
+             upscale=4,
+             act_type="prelu",
+         )
+         netscale = 4
+         file_url = [
+             "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
+             "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
+         ]
+     elif model_name == "srcnn":
+         model = SRCNN(device=device)
+         model_path = os.path.join(ROOT_DIR, WEIGHT_DIR, model_name + ".pth")
+         model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
+         model.eval()  # inference only
+         model.to(model.device)  # ensure weights sit on the device enhance() uses
+         return model
+     else:
+         raise ValueError(f"Unsupported model {model_name}.")
+
+     model_path = os.path.join(ROOT_DIR, WEIGHT_DIR, model_name + ".pth")
+     if not os.path.exists(model_path):
+         print(f"Downloading weights for model {model_name}")
+
+         for url in file_url:
+             # model_path will be updated to the downloaded location
+             model_path = load_file_from_url(
+                 url=url,
+                 model_dir=os.path.join(ROOT_DIR, WEIGHT_DIR),
+                 progress=True,
+                 file_name=None,
+             )
+
+     if model_name != "realesr-general-x4v3":
+         dni_weight = None
+     else:
+         # blend the standard and denoising (wdn) variants 50/50
+         dni_weight = [0.5, 0.5]
+         wdn_model_path = model_path.replace(
+             "realesr-general-x4v3", "realesr-general-wdn-x4v3"
+         )
+         model_path = [model_path, wdn_model_path]
+
+     half = "cuda" in str(device)
+
+     return RealESRGANer(
+         scale=netscale,
+         model_path=model_path,
+         dni_weight=dni_weight,
+         model=model,
+         half=half,
+         device=device,
+     )
+
+
+ def get_face_enhancer(model_name, upscale=2, bg_upsampler=None, device=None):
+     if model_name == "GFPGANv1.3":
+         arch = "clean"
+         channel_multiplier = 2
+         file_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
+     elif model_name == "GFPGANv1.4":
+         arch = "clean"
+         channel_multiplier = 2
+         file_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
+     elif model_name == "RestoreFormer":
+         arch = "RestoreFormer"
+         channel_multiplier = 2
+         file_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
+     else:
+         raise ValueError(f"Unsupported model {model_name}.")
+
+     model_path = os.path.join(ROOT_DIR, WEIGHT_DIR, model_name + ".pth")
+     if not os.path.exists(model_path):
+         print(f"Downloading weights for model {model_name}")
+         model_path = load_file_from_url(
+             url=file_url,
+             model_dir=os.path.join(ROOT_DIR, WEIGHT_DIR),
+             progress=True,
+             file_name=None,
+         )
+
+     return GFPGANer(
+         model_path=model_path,
+         upscale=upscale,
+         arch=arch,
+         channel_multiplier=channel_multiplier,
+         bg_upsampler=bg_upsampler,
+         device=device,
+     )
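Both factory functions are self-contained, so the models can also be used headlessly outside Gradio; a minimal sketch (the file names are illustrative, and weights download into weights/ on first use):

import cv2
from utils import get_upsampler

upsampler = get_upsampler("realesr-general-x4v3", device="cpu")

img = cv2.imread("examples/cat.jpg", cv2.IMREAD_UNCHANGED)  # BGR
output, _ = upsampler.enhance(img, outscale=2)
cv2.imwrite("cat_x2.png", output)  # output is also BGR, so imwrite saves it correctly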
weights/srcnn.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:df34daf51c338db25f197d8abcd003c0b3f109c1e8b4ca33d111862b30437bf3
+ size 82119