Image-enhancement API developed
Browse files- .gitignore +143 -0
- README.md +95 -0
- Real-ESRGAN/gfpgan/weights/detection_Resnet50_Final.pth +3 -0
- Real-ESRGAN/gfpgan/weights/parsing_parsenet.pth +3 -0
- Real-ESRGAN/inference_realesrgan.py +166 -0
- Real-ESRGAN/options/finetune_realesrgan_x4plus.yml +188 -0
- Real-ESRGAN/options/finetune_realesrgan_x4plus_pairdata.yml +150 -0
- Real-ESRGAN/options/train_realesrgan_x2plus.yml +186 -0
- Real-ESRGAN/options/train_realesrgan_x4plus.yml +185 -0
- Real-ESRGAN/options/train_realesrnet_x2plus.yml +145 -0
- Real-ESRGAN/options/train_realesrnet_x4plus.yml +144 -0
- Real-ESRGAN/realesrgan/__init__.py +6 -0
- Real-ESRGAN/realesrgan/archs/__init__.py +10 -0
- Real-ESRGAN/realesrgan/archs/discriminator_arch.py +67 -0
- Real-ESRGAN/realesrgan/archs/srvgg_arch.py +69 -0
- Real-ESRGAN/realesrgan/data/__init__.py +10 -0
- Real-ESRGAN/realesrgan/data/realesrgan_dataset.py +192 -0
- Real-ESRGAN/realesrgan/data/realesrgan_paired_dataset.py +108 -0
- Real-ESRGAN/realesrgan/models/__init__.py +10 -0
- Real-ESRGAN/realesrgan/models/realesrgan_model.py +258 -0
- Real-ESRGAN/realesrgan/models/realesrnet_model.py +188 -0
- Real-ESRGAN/realesrgan/train.py +11 -0
- Real-ESRGAN/realesrgan/utils.py +313 -0
- Real-ESRGAN/realesrgan/version.py +5 -0
- Real-ESRGAN/scripts/extract_subimages.py +135 -0
- Real-ESRGAN/scripts/generate_meta_info.py +58 -0
- Real-ESRGAN/scripts/generate_meta_info_pairdata.py +49 -0
- Real-ESRGAN/scripts/generate_multiscale_DF2K.py +48 -0
- Real-ESRGAN/scripts/pytorch2onnx.py +36 -0
- Real-ESRGAN/setup.py +107 -0
- Real-ESRGAN/weights/README.md +3 -0
- Real-ESRGAN/weights/RealESRGAN_x2plus.pth +3 -0
- Real-ESRGAN/weights/RealESRGAN_x4plus.pth +3 -0
- api.py +377 -0
- app.py +273 -0
- environment.yml +23 -0
- run.py +352 -0
.gitignore
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
*.po
|
58 |
+
|
59 |
+
# Django stuff:
|
60 |
+
*.log
|
61 |
+
local_settings.py
|
62 |
+
db.sqlite3
|
63 |
+
db.sqlite3-journal
|
64 |
+
|
65 |
+
# Flask stuff:
|
66 |
+
instance/
|
67 |
+
.webassets-cache
|
68 |
+
|
69 |
+
# Scrapy stuff:
|
70 |
+
.scrapy
|
71 |
+
|
72 |
+
# Sphinx documentation
|
73 |
+
docs/_build/
|
74 |
+
|
75 |
+
# PyBuilder
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
.python-version
|
87 |
+
|
88 |
+
# pipenv
|
89 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
90 |
+
# Pipfile.lock
|
91 |
+
|
92 |
+
# poetry
|
93 |
+
# Poetry explicitly recommends committing the poetry.lock file
|
94 |
+
# poetry.lock
|
95 |
+
|
96 |
+
# pdm
|
97 |
+
# According to pdm-project/pdm#368, it is recommended to include pdm.lock in version control.
|
98 |
+
# pdm.lock
|
99 |
+
|
100 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
101 |
+
__pypackages__/
|
102 |
+
|
103 |
+
# Celery stuff
|
104 |
+
celerybeat-schedule
|
105 |
+
celerybeat.pid
|
106 |
+
|
107 |
+
# SageMath parsed files
|
108 |
+
*.sage.py
|
109 |
+
|
110 |
+
# Environments
|
111 |
+
.env
|
112 |
+
.venv
|
113 |
+
env/
|
114 |
+
venv/
|
115 |
+
ENV/
|
116 |
+
env.bak/
|
117 |
+
venv.bak/
|
118 |
+
|
119 |
+
# Spyder project settings
|
120 |
+
.spyderproject
|
121 |
+
.spyproject
|
122 |
+
|
123 |
+
# Rope project settings
|
124 |
+
.ropeproject
|
125 |
+
|
126 |
+
# mkdocs documentation
|
127 |
+
/site
|
128 |
+
|
129 |
+
# mypy
|
130 |
+
.mypy_cache/
|
131 |
+
.dmypy.json
|
132 |
+
dmypy.json
|
133 |
+
|
134 |
+
# Pyre type checker
|
135 |
+
.pyre/
|
136 |
+
|
137 |
+
# pytype static analysis results
|
138 |
+
.pytype/
|
139 |
+
|
140 |
+
# Cython debug symbols
|
141 |
+
cython_debug/
|
142 |
+
api_inputs/
|
143 |
+
api_outputs/
|
README.md
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Image Enhancer
|
2 |
+
|
3 |
+
High-resolution image enhancement powered by AI.
|
4 |
+
|
5 |
+
## Before / After Example
|
6 |
+
|
7 |
+
| Before | After |
|
8 |
+
| ---------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------- |
|
9 |
+
|  |  |
|
10 |
+
|
11 |
+
## Description
|
12 |
+
|
13 |
+
This application provides a simple web interface to enhance images using an AI upscaling model. Upload your low-resolution images and get high-resolution results.
|
14 |
+
|
15 |
+
## Features
|
16 |
+
|
17 |
+
- Web UI for easy image uploading and enhancement.
|
18 |
+
- API endpoint for programmatic access (`/enhance`).
|
19 |
+
- Selectable upscaling models and scale factors.
|
20 |
+
- Optional face enhancement.
|
21 |
+
- Input and output images are saved in `api_inputs` and `api_outputs` respectively.
|
22 |
+
|
23 |
+
## Installation
|
24 |
+
|
25 |
+
1. **Clone the repository:**
|
26 |
+
```bash
|
27 |
+
git clone https://huggingface.co/sayed99/Image-Enhancer
|
28 |
+
cd Image-Enhancer
|
29 |
+
```
|
30 |
+
2. **Create Conda Environment:**
|
31 |
+
Set up the necessary environment using the provided `environment.yml` file:
|
32 |
+
|
33 |
+
```bash
|
34 |
+
conda env create -f environment.yml
|
35 |
+
conda activate esrgan-env
|
36 |
+
```
|
37 |
+
|
38 |
+
_(Note: The environment name is defined within the `environment.yml` file)_
|
39 |
+
|
40 |
+
3. **Download Model Weights:**
|
41 |
+
The required model weights (`.pth` files) need to be placed in the `Real-ESRGAN/weights/` directory. Common models include:
|
42 |
+
- `RealESRGAN_x4plus.pth`
|
43 |
+
- `RealESRGAN_x2plus.pth`
|
44 |
+
You can usually find these linked from the original Real-ESRGAN repository or other model sources.
|
45 |
+
|
46 |
+
## Usage
|
47 |
+
|
48 |
+
Run the application using the provided script:
|
49 |
+
|
50 |
+
```bash
|
51 |
+
python run.py
|
52 |
+
```
|
53 |
+
|
54 |
+
This will:
|
55 |
+
|
56 |
+
1. Start the backend API server (usually on `http://localhost:8000`).
|
57 |
+
2. Start the Streamlit web interface (usually on `http://localhost:8501`).
|
58 |
+
3. Open the web interface in your default browser.
|
59 |
+
|
60 |
+
Navigate to the web interface, upload an image, select your desired options (model, scale, face enhancement), and click "Enhance Image".
|
61 |
+
|
62 |
+
## API Usage
|
63 |
+
|
64 |
+
You can also interact with the API directly.
|
65 |
+
|
66 |
+
**Enhance Endpoint:** `POST /enhance/`
|
67 |
+
|
68 |
+
**Form Data:**
|
69 |
+
|
70 |
+
- `file`: The image file to upload.
|
71 |
+
- `model_name` (optional, default: `RealESRGAN_x4plus`): Model to use (e.g., `RealESRGAN_x4plus`, `RealESRGAN_x2plus`).
|
72 |
+
- `outscale` (optional, default: `4.0`): The desired output scale factor (e.g., `2.0`, `4.0`).
|
73 |
+
- `face_enhance` (optional, default: `false`): Boolean flag to enable face enhancement.
|
74 |
+
- `fp32` (optional, default: `false`): Boolean flag to use FP32 precision.
|
75 |
+
|
76 |
+
**Example using `curl`:**
|
77 |
+
|
78 |
+
```bash
|
79 |
+
curl -X POST "http://localhost:8000/enhance/" \
|
80 |
+
-F "file=@/path/to/your/image.jpg" \
|
81 |
+
-F "model_name=RealESRGAN_x4plus" \
|
82 |
+
-F "outscale=4.0" \
|
83 |
+
-o enhanced_image.jpg
|
84 |
+
```
|
85 |
+
|
86 |
+
## Notes
|
87 |
+
|
88 |
+
- Ensure the API server is running before using the Streamlit app or sending direct API requests.
|
89 |
+
- The application uses significant resources (RAM, potentially GPU if configured).
|
90 |
+
- Input images are saved in the `api_inputs` directory.
|
91 |
+
- Output images are saved in subdirectories within the `api_outputs` directory, named by a unique request ID.
|
92 |
+
|
93 |
+
---
|
94 |
+
|
95 |
+
Powered by Real-ESRGAN
|
Real-ESRGAN/gfpgan/weights/detection_Resnet50_Final.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
|
3 |
+
size 109497761
|
Real-ESRGAN/gfpgan/weights/parsing_parsenet.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
|
3 |
+
size 85331193
|
Real-ESRGAN/inference_realesrgan.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import cv2
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
6 |
+
from basicsr.utils.download_util import load_file_from_url
|
7 |
+
|
8 |
+
from realesrgan import RealESRGANer
|
9 |
+
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
10 |
+
|
11 |
+
|
12 |
+
def main():
    """Inference demo for Real-ESRGAN.

    Parses command-line options, builds the network architecture matching the
    requested model, downloads the weights if they are not present locally,
    and upscales every image found at ``--input`` into ``--output``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
    parser.add_argument(
        '-n',
        '--model_name',
        type=str,
        default='RealESRGAN_x4plus',
        help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus | '
              'realesr-animevideov3 | realesr-general-x4v3'))
    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
    parser.add_argument(
        '-dn',
        '--denoise_strength',
        type=float,
        default=0.5,
        help=('Denoise strength. 0 for weak denoise (keep noise), 1 for strong denoise ability. '
              'Only used for the realesr-general-x4v3 model'))
    parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
    parser.add_argument(
        '--model_path', type=str, default=None, help='[Option] Model path. Usually, you do not need to specify it')
    parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
    parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
    parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
    parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
    parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
    parser.add_argument(
        '--fp32', action='store_true', help='Use fp32 precision during inference. Default: fp16 (half precision).')
    parser.add_argument(
        '--alpha_upsampler',
        type=str,
        default='realesrgan',
        help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
    parser.add_argument(
        '--ext',
        type=str,
        default='auto',
        help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
    parser.add_argument(
        '-g', '--gpu-id', type=int, default=None, help='gpu device to use (default=None) can be 0,1,2 for multi-gpu')

    args = parser.parse_args()

    # determine models according to model names (tolerate a trailing '.pth')
    args.model_name = args.model_name.split('.')[0]
    if args.model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif args.model_name == 'RealESRNet_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif args.model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif args.model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif args.model_name == 'realesr-animevideov3':  # x4 VGG-style model (XS size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth']
    elif args.model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
        ]
    else:
        # Fail fast with a clear message; previously an unknown name left
        # `model`/`netscale`/`file_url` undefined and raised NameError below.
        raise ValueError(f'Unknown model name: {args.model_name}')

    # determine model paths
    if args.model_path is not None:
        model_path = args.model_path
    else:
        model_path = os.path.join('weights', args.model_name + '.pth')
        if not os.path.isfile(model_path):
            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
            for url in file_url:
                # model_path will be updated
                model_path = load_file_from_url(
                    url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)

    # use dni (deep network interpolation) to control the denoise strength
    dni_weight = None
    if args.model_name == 'realesr-general-x4v3' and args.denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [args.denoise_strength, 1 - args.denoise_strength]

    # restorer
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=args.tile,
        tile_pad=args.tile_pad,
        pre_pad=args.pre_pad,
        half=not args.fp32,
        gpu_id=args.gpu_id)

    if args.face_enhance:  # Use GFPGAN for face enhancement
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=args.outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler)
    os.makedirs(args.output, exist_ok=True)

    if os.path.isfile(args.input):
        paths = [args.input]
    else:
        paths = sorted(glob.glob(os.path.join(args.input, '*')))

    for idx, path in enumerate(paths):
        imgname, extension = os.path.splitext(os.path.basename(path))
        print('Testing', idx, imgname)

        # IMREAD_UNCHANGED keeps a potential alpha channel; IMREAD_COLOR would
        # force 3 channels and make the RGBA check below unreachable.
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # Non-image files (or unreadable ones) previously crashed with
            # AttributeError on img.shape; skip them instead.
            print(f'Skipping {path}: not a readable image file')
            continue
        if len(img.shape) == 3 and img.shape[2] == 4:
            img_mode = 'RGBA'
        else:
            img_mode = None

        try:
            if args.face_enhance:
                _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
            else:
                output, _ = upsampler.enhance(img, outscale=args.outscale)
        except RuntimeError as error:
            print('Error', error)
            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
        else:
            if args.ext == 'auto':
                extension = extension[1:]
            else:
                extension = args.ext
            if img_mode == 'RGBA':  # RGBA images should be saved in png format
                extension = 'png'
            if args.suffix == '':
                save_path = os.path.join(args.output, f'{imgname}.{extension}')
            else:
                save_path = os.path.join(args.output, f'{imgname}_{args.suffix}.{extension}')
            cv2.imwrite(save_path, output)


if __name__ == '__main__':
    main()
|
Real-ESRGAN/options/finetune_realesrgan_x4plus.yml
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Fine-tuning configuration for Real-ESRGAN x4plus (on-the-fly degradation).
# general settings
name: finetune_RealESRGANx4plus_400k
model_type: RealESRGANModel
scale: 4
num_gpu: auto  # auto: infer from visible devices
manual_seed: 0

# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
# USM the ground-truth
l1_gt_usm: True
percep_gt_usm: True
gan_gt_usm: False

# the first degradation process
resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
resize_range: [0.15, 1.5]
gaussian_noise_prob: 0.5
noise_range: [1, 30]
poisson_scale_range: [0.05, 3]
gray_noise_prob: 0.4
jpeg_range: [30, 95]

# the second degradation process
second_blur_prob: 0.8
resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
resize_range2: [0.3, 1.2]
gaussian_noise_prob2: 0.5
noise_range2: [1, 25]
poisson_scale_range2: [0.05, 2.5]
gray_noise_prob2: 0.4
jpeg_range2: [30, 95]

gt_size: 256
queue_size: 180  # size of the training-pair pool used by the model

# dataset and data loader settings
datasets:
  train:
    name: DF2K+OST
    type: RealESRGANDataset
    dataroot_gt: datasets/DF2K
    meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
    io_backend:
      type: disk

    # first blur kernel settings
    blur_kernel_size: 21
    kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
    kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
    sinc_prob: 0.1
    blur_sigma: [0.2, 3]
    betag_range: [0.5, 4]
    betap_range: [1, 2]

    # second blur kernel settings
    blur_kernel_size2: 21
    kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
    kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
    sinc_prob2: 0.1
    blur_sigma2: [0.2, 1.5]
    betag_range2: [0.5, 4]
    betap_range2: [1, 2]

    final_sinc_prob: 0.8

    gt_size: 256
    use_hflip: True
    use_rot: False

    # data loader
    use_shuffle: true
    num_worker_per_gpu: 5
    batch_size_per_gpu: 12
    dataset_enlarge_ratio: 1
    prefetch_mode: ~

  # Uncomment these for validation
  # val:
  #   name: validation
  #   type: PairedImageDataset
  #   dataroot_gt: path_to_gt
  #   dataroot_lq: path_to_lq
  #   io_backend:
  #     type: disk

# network structures
network_g:
  type: RRDBNet
  num_in_ch: 3
  num_out_ch: 3
  num_feat: 64
  num_block: 23
  num_grow_ch: 32

network_d:
  type: UNetDiscriminatorSN
  num_in_ch: 3
  num_feat: 64
  skip_connection: True

# path
path:
  # use the pre-trained Real-ESRNet model
  pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
  param_key_g: params_ema
  strict_load_g: true
  pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
  param_key_d: params
  strict_load_d: true
  resume_state: ~

# training settings
train:
  ema_decay: 0.999
  optim_g:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]
  optim_d:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  scheduler:
    type: MultiStepLR
    milestones: [400000]
    gamma: 0.5

  total_iter: 400000
  warmup_iter: -1  # no warm up

  # losses
  pixel_opt:
    type: L1Loss
    loss_weight: 1.0
    reduction: mean
  # perceptual loss (content and style losses)
  perceptual_opt:
    type: PerceptualLoss
    layer_weights:
      # before relu
      'conv1_2': 0.1
      'conv2_2': 0.1
      'conv3_4': 1
      'conv4_4': 1
      'conv5_4': 1
    vgg_type: vgg19
    use_input_norm: true
    perceptual_weight: !!float 1.0
    style_weight: 0
    range_norm: false
    criterion: l1
  # gan loss
  gan_opt:
    type: GANLoss
    gan_type: vanilla
    real_label_val: 1.0
    fake_label_val: 0.0
    loss_weight: !!float 1e-1

  net_d_iters: 1
  net_d_init_iters: 0

# Uncomment these for validation
# validation settings
# val:
#   val_freq: !!float 5e3
#   save_img: True

#   metrics:
#     psnr: # metric name
#       type: calculate_psnr
#       crop_border: 4
#       test_y_channel: false

# logging settings
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

# dist training settings
dist_params:
  backend: nccl
  port: 29500
|
Real-ESRGAN/options/finetune_realesrgan_x4plus_pairdata.yml
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Fine-tuning configuration for Real-ESRGAN x4plus with pre-made paired data
# (no on-the-fly degradation synthesis).
# general settings
name: finetune_RealESRGANx4plus_400k_pairdata
model_type: RealESRGANModel
scale: 4
num_gpu: auto  # auto: infer from visible devices
manual_seed: 0

# USM the ground-truth
l1_gt_usm: True
percep_gt_usm: True
gan_gt_usm: False

high_order_degradation: False # do not use the high-order degradation generation process

# dataset and data loader settings
datasets:
  train:
    name: DIV2K
    type: RealESRGANPairedDataset
    dataroot_gt: datasets/DF2K
    dataroot_lq: datasets/DF2K
    meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt
    io_backend:
      type: disk

    gt_size: 256
    use_hflip: True
    use_rot: False

    # data loader
    use_shuffle: true
    num_worker_per_gpu: 5
    batch_size_per_gpu: 12
    dataset_enlarge_ratio: 1
    prefetch_mode: ~

  # Uncomment these for validation
  # val:
  #   name: validation
  #   type: PairedImageDataset
  #   dataroot_gt: path_to_gt
  #   dataroot_lq: path_to_lq
  #   io_backend:
  #     type: disk

# network structures
network_g:
  type: RRDBNet
  num_in_ch: 3
  num_out_ch: 3
  num_feat: 64
  num_block: 23
  num_grow_ch: 32

network_d:
  type: UNetDiscriminatorSN
  num_in_ch: 3
  num_feat: 64
  skip_connection: True

# path
path:
  # use the pre-trained Real-ESRNet model
  pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
  param_key_g: params_ema
  strict_load_g: true
  pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
  param_key_d: params
  strict_load_d: true
  resume_state: ~

# training settings
train:
  ema_decay: 0.999
  optim_g:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]
  optim_d:
    type: Adam
    lr: !!float 1e-4
    weight_decay: 0
    betas: [0.9, 0.99]

  scheduler:
    type: MultiStepLR
    milestones: [400000]
    gamma: 0.5

  total_iter: 400000
  warmup_iter: -1  # no warm up

  # losses
  pixel_opt:
    type: L1Loss
    loss_weight: 1.0
    reduction: mean
  # perceptual loss (content and style losses)
  perceptual_opt:
    type: PerceptualLoss
    layer_weights:
      # before relu
      'conv1_2': 0.1
      'conv2_2': 0.1
      'conv3_4': 1
      'conv4_4': 1
      'conv5_4': 1
    vgg_type: vgg19
    use_input_norm: true
    perceptual_weight: !!float 1.0
    style_weight: 0
    range_norm: false
    criterion: l1
  # gan loss
  gan_opt:
    type: GANLoss
    gan_type: vanilla
    real_label_val: 1.0
    fake_label_val: 0.0
    loss_weight: !!float 1e-1

  net_d_iters: 1
  net_d_init_iters: 0

# Uncomment these for validation
# validation settings
# val:
#   val_freq: !!float 5e3
#   save_img: True

#   metrics:
#     psnr: # metric name
#       type: calculate_psnr
#       crop_border: 4
#       test_y_channel: false

# logging settings
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~

# dist training settings
dist_params:
  backend: nccl
  port: 29500
|
Real-ESRGAN/options/train_realesrgan_x2plus.yml
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# general settings
|
2 |
+
name: train_RealESRGANx2plus_400k_B12G4
|
3 |
+
model_type: RealESRGANModel
|
4 |
+
scale: 2
|
5 |
+
num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
|
6 |
+
manual_seed: 0
|
7 |
+
|
8 |
+
# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
|
9 |
+
# USM the ground-truth
|
10 |
+
l1_gt_usm: True
|
11 |
+
percep_gt_usm: True
|
12 |
+
gan_gt_usm: False
|
13 |
+
|
14 |
+
# the first degradation process
|
15 |
+
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
16 |
+
resize_range: [0.15, 1.5]
|
17 |
+
gaussian_noise_prob: 0.5
|
18 |
+
noise_range: [1, 30]
|
19 |
+
poisson_scale_range: [0.05, 3]
|
20 |
+
gray_noise_prob: 0.4
|
21 |
+
jpeg_range: [30, 95]
|
22 |
+
|
23 |
+
# the second degradation process
|
24 |
+
second_blur_prob: 0.8
|
25 |
+
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
26 |
+
resize_range2: [0.3, 1.2]
|
27 |
+
gaussian_noise_prob2: 0.5
|
28 |
+
noise_range2: [1, 25]
|
29 |
+
poisson_scale_range2: [0.05, 2.5]
|
30 |
+
gray_noise_prob2: 0.4
|
31 |
+
jpeg_range2: [30, 95]
|
32 |
+
|
33 |
+
gt_size: 256
|
34 |
+
queue_size: 180
|
35 |
+
|
36 |
+
# dataset and data loader settings
|
37 |
+
datasets:
|
38 |
+
train:
|
39 |
+
name: DF2K+OST
|
40 |
+
type: RealESRGANDataset
|
41 |
+
dataroot_gt: datasets/DF2K
|
42 |
+
meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
43 |
+
io_backend:
|
44 |
+
type: disk
|
45 |
+
|
46 |
+
blur_kernel_size: 21
|
47 |
+
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
48 |
+
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
49 |
+
sinc_prob: 0.1
|
50 |
+
blur_sigma: [0.2, 3]
|
51 |
+
betag_range: [0.5, 4]
|
52 |
+
betap_range: [1, 2]
|
53 |
+
|
54 |
+
blur_kernel_size2: 21
|
55 |
+
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
56 |
+
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
57 |
+
sinc_prob2: 0.1
|
58 |
+
blur_sigma2: [0.2, 1.5]
|
59 |
+
betag_range2: [0.5, 4]
|
60 |
+
betap_range2: [1, 2]
|
61 |
+
|
62 |
+
final_sinc_prob: 0.8
|
63 |
+
|
64 |
+
gt_size: 256
|
65 |
+
use_hflip: True
|
66 |
+
use_rot: False
|
67 |
+
|
68 |
+
# data loader
|
69 |
+
use_shuffle: true
|
70 |
+
num_worker_per_gpu: 5
|
71 |
+
batch_size_per_gpu: 12
|
72 |
+
dataset_enlarge_ratio: 1
|
73 |
+
prefetch_mode: ~
|
74 |
+
|
75 |
+
# Uncomment these for validation
|
76 |
+
# val:
|
77 |
+
# name: validation
|
78 |
+
# type: PairedImageDataset
|
79 |
+
# dataroot_gt: path_to_gt
|
80 |
+
# dataroot_lq: path_to_lq
|
81 |
+
# io_backend:
|
82 |
+
# type: disk
|
83 |
+
|
84 |
+
# network structures
|
85 |
+
network_g:
|
86 |
+
type: RRDBNet
|
87 |
+
num_in_ch: 3
|
88 |
+
num_out_ch: 3
|
89 |
+
num_feat: 64
|
90 |
+
num_block: 23
|
91 |
+
num_grow_ch: 32
|
92 |
+
scale: 2
|
93 |
+
|
94 |
+
network_d:
|
95 |
+
type: UNetDiscriminatorSN
|
96 |
+
num_in_ch: 3
|
97 |
+
num_feat: 64
|
98 |
+
skip_connection: True
|
99 |
+
|
100 |
+
# path
|
101 |
+
path:
|
102 |
+
# use the pre-trained Real-ESRNet model
|
103 |
+
pretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pth
|
104 |
+
param_key_g: params_ema
|
105 |
+
strict_load_g: true
|
106 |
+
resume_state: ~
|
107 |
+
|
108 |
+
# training settings
|
109 |
+
train:
|
110 |
+
ema_decay: 0.999
|
111 |
+
optim_g:
|
112 |
+
type: Adam
|
113 |
+
lr: !!float 1e-4
|
114 |
+
weight_decay: 0
|
115 |
+
betas: [0.9, 0.99]
|
116 |
+
optim_d:
|
117 |
+
type: Adam
|
118 |
+
lr: !!float 1e-4
|
119 |
+
weight_decay: 0
|
120 |
+
betas: [0.9, 0.99]
|
121 |
+
|
122 |
+
scheduler:
|
123 |
+
type: MultiStepLR
|
124 |
+
milestones: [400000]
|
125 |
+
gamma: 0.5
|
126 |
+
|
127 |
+
total_iter: 400000
|
128 |
+
warmup_iter: -1 # no warm up
|
129 |
+
|
130 |
+
# losses
|
131 |
+
pixel_opt:
|
132 |
+
type: L1Loss
|
133 |
+
loss_weight: 1.0
|
134 |
+
reduction: mean
|
135 |
+
# perceptual loss (content and style losses)
|
136 |
+
perceptual_opt:
|
137 |
+
type: PerceptualLoss
|
138 |
+
layer_weights:
|
139 |
+
# before relu
|
140 |
+
'conv1_2': 0.1
|
141 |
+
'conv2_2': 0.1
|
142 |
+
'conv3_4': 1
|
143 |
+
'conv4_4': 1
|
144 |
+
'conv5_4': 1
|
145 |
+
vgg_type: vgg19
|
146 |
+
use_input_norm: true
|
147 |
+
perceptual_weight: !!float 1.0
|
148 |
+
style_weight: 0
|
149 |
+
range_norm: false
|
150 |
+
criterion: l1
|
151 |
+
# gan loss
|
152 |
+
gan_opt:
|
153 |
+
type: GANLoss
|
154 |
+
gan_type: vanilla
|
155 |
+
real_label_val: 1.0
|
156 |
+
fake_label_val: 0.0
|
157 |
+
loss_weight: !!float 1e-1
|
158 |
+
|
159 |
+
net_d_iters: 1
|
160 |
+
net_d_init_iters: 0
|
161 |
+
|
162 |
+
# Uncomment these for validation
|
163 |
+
# validation settings
|
164 |
+
# val:
|
165 |
+
# val_freq: !!float 5e3
|
166 |
+
# save_img: True
|
167 |
+
|
168 |
+
# metrics:
|
169 |
+
# psnr: # metric name
|
170 |
+
# type: calculate_psnr
|
171 |
+
# crop_border: 4
|
172 |
+
# test_y_channel: false
|
173 |
+
|
174 |
+
# logging settings
|
175 |
+
logger:
|
176 |
+
print_freq: 100
|
177 |
+
save_checkpoint_freq: !!float 5e3
|
178 |
+
use_tb_logger: true
|
179 |
+
wandb:
|
180 |
+
project: ~
|
181 |
+
resume_id: ~
|
182 |
+
|
183 |
+
# dist training settings
|
184 |
+
dist_params:
|
185 |
+
backend: nccl
|
186 |
+
port: 29500
|
Real-ESRGAN/options/train_realesrgan_x4plus.yml
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# general settings
|
2 |
+
name: train_RealESRGANx4plus_400k_B12G4
|
3 |
+
model_type: RealESRGANModel
|
4 |
+
scale: 4
|
5 |
+
num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
|
6 |
+
manual_seed: 0
|
7 |
+
|
8 |
+
# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
|
9 |
+
# USM the ground-truth
|
10 |
+
l1_gt_usm: True
|
11 |
+
percep_gt_usm: True
|
12 |
+
gan_gt_usm: False
|
13 |
+
|
14 |
+
# the first degradation process
|
15 |
+
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
16 |
+
resize_range: [0.15, 1.5]
|
17 |
+
gaussian_noise_prob: 0.5
|
18 |
+
noise_range: [1, 30]
|
19 |
+
poisson_scale_range: [0.05, 3]
|
20 |
+
gray_noise_prob: 0.4
|
21 |
+
jpeg_range: [30, 95]
|
22 |
+
|
23 |
+
# the second degradation process
|
24 |
+
second_blur_prob: 0.8
|
25 |
+
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
26 |
+
resize_range2: [0.3, 1.2]
|
27 |
+
gaussian_noise_prob2: 0.5
|
28 |
+
noise_range2: [1, 25]
|
29 |
+
poisson_scale_range2: [0.05, 2.5]
|
30 |
+
gray_noise_prob2: 0.4
|
31 |
+
jpeg_range2: [30, 95]
|
32 |
+
|
33 |
+
gt_size: 256
|
34 |
+
queue_size: 180
|
35 |
+
|
36 |
+
# dataset and data loader settings
|
37 |
+
datasets:
|
38 |
+
train:
|
39 |
+
name: DF2K+OST
|
40 |
+
type: RealESRGANDataset
|
41 |
+
dataroot_gt: datasets/DF2K
|
42 |
+
meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
43 |
+
io_backend:
|
44 |
+
type: disk
|
45 |
+
|
46 |
+
blur_kernel_size: 21
|
47 |
+
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
48 |
+
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
49 |
+
sinc_prob: 0.1
|
50 |
+
blur_sigma: [0.2, 3]
|
51 |
+
betag_range: [0.5, 4]
|
52 |
+
betap_range: [1, 2]
|
53 |
+
|
54 |
+
blur_kernel_size2: 21
|
55 |
+
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
56 |
+
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
57 |
+
sinc_prob2: 0.1
|
58 |
+
blur_sigma2: [0.2, 1.5]
|
59 |
+
betag_range2: [0.5, 4]
|
60 |
+
betap_range2: [1, 2]
|
61 |
+
|
62 |
+
final_sinc_prob: 0.8
|
63 |
+
|
64 |
+
gt_size: 256
|
65 |
+
use_hflip: True
|
66 |
+
use_rot: False
|
67 |
+
|
68 |
+
# data loader
|
69 |
+
use_shuffle: true
|
70 |
+
num_worker_per_gpu: 5
|
71 |
+
batch_size_per_gpu: 12
|
72 |
+
dataset_enlarge_ratio: 1
|
73 |
+
prefetch_mode: ~
|
74 |
+
|
75 |
+
# Uncomment these for validation
|
76 |
+
# val:
|
77 |
+
# name: validation
|
78 |
+
# type: PairedImageDataset
|
79 |
+
# dataroot_gt: path_to_gt
|
80 |
+
# dataroot_lq: path_to_lq
|
81 |
+
# io_backend:
|
82 |
+
# type: disk
|
83 |
+
|
84 |
+
# network structures
|
85 |
+
network_g:
|
86 |
+
type: RRDBNet
|
87 |
+
num_in_ch: 3
|
88 |
+
num_out_ch: 3
|
89 |
+
num_feat: 64
|
90 |
+
num_block: 23
|
91 |
+
num_grow_ch: 32
|
92 |
+
|
93 |
+
network_d:
|
94 |
+
type: UNetDiscriminatorSN
|
95 |
+
num_in_ch: 3
|
96 |
+
num_feat: 64
|
97 |
+
skip_connection: True
|
98 |
+
|
99 |
+
# path
|
100 |
+
path:
|
101 |
+
# use the pre-trained Real-ESRNet model
|
102 |
+
pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
|
103 |
+
param_key_g: params_ema
|
104 |
+
strict_load_g: true
|
105 |
+
resume_state: ~
|
106 |
+
|
107 |
+
# training settings
|
108 |
+
train:
|
109 |
+
ema_decay: 0.999
|
110 |
+
optim_g:
|
111 |
+
type: Adam
|
112 |
+
lr: !!float 1e-4
|
113 |
+
weight_decay: 0
|
114 |
+
betas: [0.9, 0.99]
|
115 |
+
optim_d:
|
116 |
+
type: Adam
|
117 |
+
lr: !!float 1e-4
|
118 |
+
weight_decay: 0
|
119 |
+
betas: [0.9, 0.99]
|
120 |
+
|
121 |
+
scheduler:
|
122 |
+
type: MultiStepLR
|
123 |
+
milestones: [400000]
|
124 |
+
gamma: 0.5
|
125 |
+
|
126 |
+
total_iter: 400000
|
127 |
+
warmup_iter: -1 # no warm up
|
128 |
+
|
129 |
+
# losses
|
130 |
+
pixel_opt:
|
131 |
+
type: L1Loss
|
132 |
+
loss_weight: 1.0
|
133 |
+
reduction: mean
|
134 |
+
# perceptual loss (content and style losses)
|
135 |
+
perceptual_opt:
|
136 |
+
type: PerceptualLoss
|
137 |
+
layer_weights:
|
138 |
+
# before relu
|
139 |
+
'conv1_2': 0.1
|
140 |
+
'conv2_2': 0.1
|
141 |
+
'conv3_4': 1
|
142 |
+
'conv4_4': 1
|
143 |
+
'conv5_4': 1
|
144 |
+
vgg_type: vgg19
|
145 |
+
use_input_norm: true
|
146 |
+
perceptual_weight: !!float 1.0
|
147 |
+
style_weight: 0
|
148 |
+
range_norm: false
|
149 |
+
criterion: l1
|
150 |
+
# gan loss
|
151 |
+
gan_opt:
|
152 |
+
type: GANLoss
|
153 |
+
gan_type: vanilla
|
154 |
+
real_label_val: 1.0
|
155 |
+
fake_label_val: 0.0
|
156 |
+
loss_weight: !!float 1e-1
|
157 |
+
|
158 |
+
net_d_iters: 1
|
159 |
+
net_d_init_iters: 0
|
160 |
+
|
161 |
+
# Uncomment these for validation
|
162 |
+
# validation settings
|
163 |
+
# val:
|
164 |
+
# val_freq: !!float 5e3
|
165 |
+
# save_img: True
|
166 |
+
|
167 |
+
# metrics:
|
168 |
+
# psnr: # metric name
|
169 |
+
# type: calculate_psnr
|
170 |
+
# crop_border: 4
|
171 |
+
# test_y_channel: false
|
172 |
+
|
173 |
+
# logging settings
|
174 |
+
logger:
|
175 |
+
print_freq: 100
|
176 |
+
save_checkpoint_freq: !!float 5e3
|
177 |
+
use_tb_logger: true
|
178 |
+
wandb:
|
179 |
+
project: ~
|
180 |
+
resume_id: ~
|
181 |
+
|
182 |
+
# dist training settings
|
183 |
+
dist_params:
|
184 |
+
backend: nccl
|
185 |
+
port: 29500
|
Real-ESRGAN/options/train_realesrnet_x2plus.yml
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# general settings
|
2 |
+
name: train_RealESRNetx2plus_1000k_B12G4
|
3 |
+
model_type: RealESRNetModel
|
4 |
+
scale: 2
|
5 |
+
num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
|
6 |
+
manual_seed: 0
|
7 |
+
|
8 |
+
# ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
|
9 |
+
gt_usm: True # USM the ground-truth
|
10 |
+
|
11 |
+
# the first degradation process
|
12 |
+
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
13 |
+
resize_range: [0.15, 1.5]
|
14 |
+
gaussian_noise_prob: 0.5
|
15 |
+
noise_range: [1, 30]
|
16 |
+
poisson_scale_range: [0.05, 3]
|
17 |
+
gray_noise_prob: 0.4
|
18 |
+
jpeg_range: [30, 95]
|
19 |
+
|
20 |
+
# the second degradation process
|
21 |
+
second_blur_prob: 0.8
|
22 |
+
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
23 |
+
resize_range2: [0.3, 1.2]
|
24 |
+
gaussian_noise_prob2: 0.5
|
25 |
+
noise_range2: [1, 25]
|
26 |
+
poisson_scale_range2: [0.05, 2.5]
|
27 |
+
gray_noise_prob2: 0.4
|
28 |
+
jpeg_range2: [30, 95]
|
29 |
+
|
30 |
+
gt_size: 256
|
31 |
+
queue_size: 180
|
32 |
+
|
33 |
+
# dataset and data loader settings
|
34 |
+
datasets:
|
35 |
+
train:
|
36 |
+
name: DF2K+OST
|
37 |
+
type: RealESRGANDataset
|
38 |
+
dataroot_gt: datasets/DF2K
|
39 |
+
meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
40 |
+
io_backend:
|
41 |
+
type: disk
|
42 |
+
|
43 |
+
blur_kernel_size: 21
|
44 |
+
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
45 |
+
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
46 |
+
sinc_prob: 0.1
|
47 |
+
blur_sigma: [0.2, 3]
|
48 |
+
betag_range: [0.5, 4]
|
49 |
+
betap_range: [1, 2]
|
50 |
+
|
51 |
+
blur_kernel_size2: 21
|
52 |
+
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
53 |
+
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
54 |
+
sinc_prob2: 0.1
|
55 |
+
blur_sigma2: [0.2, 1.5]
|
56 |
+
betag_range2: [0.5, 4]
|
57 |
+
betap_range2: [1, 2]
|
58 |
+
|
59 |
+
final_sinc_prob: 0.8
|
60 |
+
|
61 |
+
gt_size: 256
|
62 |
+
use_hflip: True
|
63 |
+
use_rot: False
|
64 |
+
|
65 |
+
# data loader
|
66 |
+
use_shuffle: true
|
67 |
+
num_worker_per_gpu: 5
|
68 |
+
batch_size_per_gpu: 12
|
69 |
+
dataset_enlarge_ratio: 1
|
70 |
+
prefetch_mode: ~
|
71 |
+
|
72 |
+
# Uncomment these for validation
|
73 |
+
# val:
|
74 |
+
# name: validation
|
75 |
+
# type: PairedImageDataset
|
76 |
+
# dataroot_gt: path_to_gt
|
77 |
+
# dataroot_lq: path_to_lq
|
78 |
+
# io_backend:
|
79 |
+
# type: disk
|
80 |
+
|
81 |
+
# network structures
|
82 |
+
network_g:
|
83 |
+
type: RRDBNet
|
84 |
+
num_in_ch: 3
|
85 |
+
num_out_ch: 3
|
86 |
+
num_feat: 64
|
87 |
+
num_block: 23
|
88 |
+
num_grow_ch: 32
|
89 |
+
scale: 2
|
90 |
+
|
91 |
+
# path
|
92 |
+
path:
|
93 |
+
pretrain_network_g: experiments/pretrained_models/RealESRGAN_x4plus.pth
|
94 |
+
param_key_g: params_ema
|
95 |
+
strict_load_g: False
|
96 |
+
resume_state: ~
|
97 |
+
|
98 |
+
# training settings
|
99 |
+
train:
|
100 |
+
ema_decay: 0.999
|
101 |
+
optim_g:
|
102 |
+
type: Adam
|
103 |
+
lr: !!float 2e-4
|
104 |
+
weight_decay: 0
|
105 |
+
betas: [0.9, 0.99]
|
106 |
+
|
107 |
+
scheduler:
|
108 |
+
type: MultiStepLR
|
109 |
+
milestones: [1000000]
|
110 |
+
gamma: 0.5
|
111 |
+
|
112 |
+
total_iter: 1000000
|
113 |
+
warmup_iter: -1 # no warm up
|
114 |
+
|
115 |
+
# losses
|
116 |
+
pixel_opt:
|
117 |
+
type: L1Loss
|
118 |
+
loss_weight: 1.0
|
119 |
+
reduction: mean
|
120 |
+
|
121 |
+
# Uncomment these for validation
|
122 |
+
# validation settings
|
123 |
+
# val:
|
124 |
+
# val_freq: !!float 5e3
|
125 |
+
# save_img: True
|
126 |
+
|
127 |
+
# metrics:
|
128 |
+
# psnr: # metric name
|
129 |
+
# type: calculate_psnr
|
130 |
+
# crop_border: 4
|
131 |
+
# test_y_channel: false
|
132 |
+
|
133 |
+
# logging settings
|
134 |
+
logger:
|
135 |
+
print_freq: 100
|
136 |
+
save_checkpoint_freq: !!float 5e3
|
137 |
+
use_tb_logger: true
|
138 |
+
wandb:
|
139 |
+
project: ~
|
140 |
+
resume_id: ~
|
141 |
+
|
142 |
+
# dist training settings
|
143 |
+
dist_params:
|
144 |
+
backend: nccl
|
145 |
+
port: 29500
|
Real-ESRGAN/options/train_realesrnet_x4plus.yml
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# general settings
|
2 |
+
name: train_RealESRNetx4plus_1000k_B12G4
|
3 |
+
model_type: RealESRNetModel
|
4 |
+
scale: 4
|
5 |
+
num_gpu: auto # auto: can infer from your visible devices automatically. official: 4 GPUs
|
6 |
+
manual_seed: 0
|
7 |
+
|
8 |
+
# ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
|
9 |
+
gt_usm: True # USM the ground-truth
|
10 |
+
|
11 |
+
# the first degradation process
|
12 |
+
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
|
13 |
+
resize_range: [0.15, 1.5]
|
14 |
+
gaussian_noise_prob: 0.5
|
15 |
+
noise_range: [1, 30]
|
16 |
+
poisson_scale_range: [0.05, 3]
|
17 |
+
gray_noise_prob: 0.4
|
18 |
+
jpeg_range: [30, 95]
|
19 |
+
|
20 |
+
# the second degradation process
|
21 |
+
second_blur_prob: 0.8
|
22 |
+
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
|
23 |
+
resize_range2: [0.3, 1.2]
|
24 |
+
gaussian_noise_prob2: 0.5
|
25 |
+
noise_range2: [1, 25]
|
26 |
+
poisson_scale_range2: [0.05, 2.5]
|
27 |
+
gray_noise_prob2: 0.4
|
28 |
+
jpeg_range2: [30, 95]
|
29 |
+
|
30 |
+
gt_size: 256
|
31 |
+
queue_size: 180
|
32 |
+
|
33 |
+
# dataset and data loader settings
|
34 |
+
datasets:
|
35 |
+
train:
|
36 |
+
name: DF2K+OST
|
37 |
+
type: RealESRGANDataset
|
38 |
+
dataroot_gt: datasets/DF2K
|
39 |
+
meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
|
40 |
+
io_backend:
|
41 |
+
type: disk
|
42 |
+
|
43 |
+
blur_kernel_size: 21
|
44 |
+
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
45 |
+
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
46 |
+
sinc_prob: 0.1
|
47 |
+
blur_sigma: [0.2, 3]
|
48 |
+
betag_range: [0.5, 4]
|
49 |
+
betap_range: [1, 2]
|
50 |
+
|
51 |
+
blur_kernel_size2: 21
|
52 |
+
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
|
53 |
+
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
|
54 |
+
sinc_prob2: 0.1
|
55 |
+
blur_sigma2: [0.2, 1.5]
|
56 |
+
betag_range2: [0.5, 4]
|
57 |
+
betap_range2: [1, 2]
|
58 |
+
|
59 |
+
final_sinc_prob: 0.8
|
60 |
+
|
61 |
+
gt_size: 256
|
62 |
+
use_hflip: True
|
63 |
+
use_rot: False
|
64 |
+
|
65 |
+
# data loader
|
66 |
+
use_shuffle: true
|
67 |
+
num_worker_per_gpu: 5
|
68 |
+
batch_size_per_gpu: 12
|
69 |
+
dataset_enlarge_ratio: 1
|
70 |
+
prefetch_mode: ~
|
71 |
+
|
72 |
+
# Uncomment these for validation
|
73 |
+
# val:
|
74 |
+
# name: validation
|
75 |
+
# type: PairedImageDataset
|
76 |
+
# dataroot_gt: path_to_gt
|
77 |
+
# dataroot_lq: path_to_lq
|
78 |
+
# io_backend:
|
79 |
+
# type: disk
|
80 |
+
|
81 |
+
# network structures
|
82 |
+
network_g:
|
83 |
+
type: RRDBNet
|
84 |
+
num_in_ch: 3
|
85 |
+
num_out_ch: 3
|
86 |
+
num_feat: 64
|
87 |
+
num_block: 23
|
88 |
+
num_grow_ch: 32
|
89 |
+
|
90 |
+
# path
|
91 |
+
path:
|
92 |
+
pretrain_network_g: experiments/pretrained_models/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth
|
93 |
+
param_key_g: params_ema
|
94 |
+
strict_load_g: true
|
95 |
+
resume_state: ~
|
96 |
+
|
97 |
+
# training settings
|
98 |
+
train:
|
99 |
+
ema_decay: 0.999
|
100 |
+
optim_g:
|
101 |
+
type: Adam
|
102 |
+
lr: !!float 2e-4
|
103 |
+
weight_decay: 0
|
104 |
+
betas: [0.9, 0.99]
|
105 |
+
|
106 |
+
scheduler:
|
107 |
+
type: MultiStepLR
|
108 |
+
milestones: [1000000]
|
109 |
+
gamma: 0.5
|
110 |
+
|
111 |
+
total_iter: 1000000
|
112 |
+
warmup_iter: -1 # no warm up
|
113 |
+
|
114 |
+
# losses
|
115 |
+
pixel_opt:
|
116 |
+
type: L1Loss
|
117 |
+
loss_weight: 1.0
|
118 |
+
reduction: mean
|
119 |
+
|
120 |
+
# Uncomment these for validation
|
121 |
+
# validation settings
|
122 |
+
# val:
|
123 |
+
# val_freq: !!float 5e3
|
124 |
+
# save_img: True
|
125 |
+
|
126 |
+
# metrics:
|
127 |
+
# psnr: # metric name
|
128 |
+
# type: calculate_psnr
|
129 |
+
# crop_border: 4
|
130 |
+
# test_y_channel: false
|
131 |
+
|
132 |
+
# logging settings
|
133 |
+
logger:
|
134 |
+
print_freq: 100
|
135 |
+
save_checkpoint_freq: !!float 5e3
|
136 |
+
use_tb_logger: true
|
137 |
+
wandb:
|
138 |
+
project: ~
|
139 |
+
resume_id: ~
|
140 |
+
|
141 |
+
# dist training settings
|
142 |
+
dist_params:
|
143 |
+
backend: nccl
|
144 |
+
port: 29500
|
Real-ESRGAN/realesrgan/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# flake8: noqa
# Package entry point for `realesrgan`.
# Star-importing the sub-packages is intentional: importing `archs`, `data`
# and `models` triggers their module-scanning __init__ files, which register
# every architecture / dataset / model class with basicsr's registries as an
# import side effect. Without these imports the registries would be empty and
# config-driven instantiation (e.g. `type: RRDBNet`) would fail.
from .archs import *
from .data import *
from .models import *
from .utils import *
# Re-export version metadata (__version__, etc.) generated in version.py.
from .version import *
|
Real-ESRGAN/realesrgan/archs/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import importlib
from basicsr.utils import scandir
from os import path as osp

# automatically scan and import arch modules for registry
# scan all the files that end with '_arch.py' under the archs folder
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
# NOTE: importing each module runs its @ARCH_REGISTRY.register() decorators,
# so merely importing this package makes every *_arch.py class available to
# basicsr's registry-based builder. The resulting module list is kept only to
# prevent the imports from being optimized away / for debugging.
_arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames]
|
Real-ESRGAN/realesrgan/archs/discriminator_arch.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm


@ARCH_REGISTRY.register()
class UNetDiscriminatorSN(nn.Module):
    """U-Net shaped discriminator with spectral normalization (SN).

    Used in Real-ESRGAN: Training Real-World Blind Super-Resolution with
    Pure Synthetic Data. The encoder halves the spatial size three times;
    the decoder upsamples back, optionally adding U-Net skip connections,
    and a few extra convolutions produce a 1-channel realness map.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features. Default: 64.
        skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        sn = spectral_norm  # every conv except the first/last is spectrally normalized
        # input projection (no SN, keeps bias)
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        # encoder: three stride-2 downsampling convs
        self.conv1 = sn(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = sn(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = sn(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
        # decoder: channel-reducing convs applied after bilinear upsampling
        self.conv4 = sn(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = sn(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = sn(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
        # head: two refinement convs plus the final 1-channel projection
        self.conv7 = sn(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = sn(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x):
        """Return a per-pixel realness score map of shape (N, 1, H, W)."""

        def act(t):
            # shared LeakyReLU used after every conv except the last
            return F.leaky_relu(t, negative_slope=0.2, inplace=True)

        def up2(t):
            # 2x bilinear upsampling used between decoder stages
            return F.interpolate(t, scale_factor=2, mode='bilinear', align_corners=False)

        # encoder path
        enc0 = act(self.conv0(x))
        enc1 = act(self.conv1(enc0))
        enc2 = act(self.conv2(enc1))
        enc3 = act(self.conv3(enc2))

        # decoder path with optional skip connections back to the encoder
        dec = act(self.conv4(up2(enc3)))
        if self.skip_connection:
            dec = dec + enc2
        dec = act(self.conv5(up2(dec)))
        if self.skip_connection:
            dec = dec + enc1
        dec = act(self.conv6(up2(dec)))
        if self.skip_connection:
            dec = dec + enc0

        # refinement head
        out = act(self.conv7(dec))
        out = act(self.conv8(out))
        return self.conv9(out)
|
Real-ESRGAN/realesrgan/archs/srvgg_arch.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn as nn
from torch.nn import functional as F


@ARCH_REGISTRY.register()
class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.

    It is a compact network structure, which performs upsampling in the last layer and no convolution is
    conducted on the HR feature space. A nearest-upsampled copy of the input is added to the output, so
    the network only has to learn the residual.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_conv (int): Number of convolution layers in the body network. Default: 16.
        upscale (int): Upsampling factor. Default: 4.
        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.

    Raises:
        ValueError: If ``act_type`` is not one of 'relu', 'prelu', 'leakyrelu'.
            (Previously an unrecognized value crashed later with an opaque NameError.)
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
        super(SRVGGNetCompact, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        self.body = nn.ModuleList()
        # the first conv
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        # the first activation
        self.body.append(self._make_activation(act_type, num_feat))

        # the body structure: num_conv (conv + activation) pairs
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            self.body.append(self._make_activation(act_type, num_feat))

        # the last conv expands channels so PixelShuffle can rearrange them into space
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    @staticmethod
    def _make_activation(act_type, num_feat):
        """Build one activation module for the given type.

        Centralizes the previously duplicated construction code and fails fast
        with a clear error on unsupported types.
        """
        if act_type == 'relu':
            return nn.ReLU(inplace=True)
        if act_type == 'prelu':
            # PReLU learns one slope per channel
            return nn.PReLU(num_parameters=num_feat)
        if act_type == 'leakyrelu':
            return nn.LeakyReLU(negative_slope=0.1, inplace=True)
        raise ValueError(f"act_type must be 'relu', 'prelu' or 'leakyrelu', but got {act_type!r}")

    def forward(self, x):
        """Super-resolve ``x`` by ``self.upscale``; output is (N, num_out_ch, H*s, W*s)."""
        out = x
        for layer in self.body:
            out = layer(out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
        out += base
        return out
|
Real-ESRGAN/realesrgan/data/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import importlib
from basicsr.utils import scandir
from os import path as osp

# automatically scan and import dataset modules for registry
# scan all the files that end with '_dataset.py' under the data folder
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
# NOTE: each import executes the module's @DATASET_REGISTRY.register()
# decorators, populating basicsr's dataset registry as a side effect. The
# list itself is only kept so the imports are not discarded.
_dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames]
|
Real-ESRGAN/realesrgan/data/realesrgan_dataset.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import os.path as osp
|
6 |
+
import random
|
7 |
+
import time
|
8 |
+
import torch
|
9 |
+
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
|
10 |
+
from basicsr.data.transforms import augment
|
11 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
12 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
13 |
+
from torch.utils import data as data
|
14 |
+
|
15 |
+
|
16 |
+
@DATASET_REGISTRY.register()
class RealESRGANDataset(data.Dataset):
    """Dataset used for Real-ESRGAN model:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It loads gt (Ground-Truth) images, and augments them.
    It also generates blur kernels and sinc kernels for generating low-quality images.
    Note that the low-quality images are processed in tensors on GPUS for faster processing.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            Please see more options in the codes.
    """

    def __init__(self, opt):
        super(RealESRGANDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.gt_folder]
            self.io_backend_opt['client_keys'] = ['gt']
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            # lmdb keys are the file stems listed in meta_info.txt
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative path to an image
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip().split(' ')[0] for line in fin]
            self.paths = [os.path.join(self.gt_folder, v) for v in paths]

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # TODO: kernel range is now hard-coded, should be in the configure file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

    def __getitem__(self, index):
        # lazily create the file client (it is not constructed in __init__ so the
        # dataset stays picklable for dataloader worker processes)
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        img_bytes = None
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
            except (IOError, OSError) as e:
                # logger.warning: logger.warn is a deprecated alias
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change another file to read
                # random.randint is inclusive on BOTH ends, so the upper bound must be
                # len - 1; the original `randint(0, len)` could raise IndexError
                index = random.randint(0, self.__len__() - 1)
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        if img_bytes is None:
            # all retries failed: raise a clear error instead of crashing later with
            # an unrelated NameError on the unbound img_bytes
            raise IOError(f'Failed to read image after retries: {gt_path}')
        img_gt = imfrombytes(img_bytes, float32=True)

        # -------------------- Do augmentation for training: flip, rotation -------------------- #
        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

        # crop or pad to 400
        # TODO: 400 is hard-coded. You may change it accordingly
        h, w = img_gt.shape[0:2]
        crop_pad_size = 400
        # pad (reflect borders so padded content stays photographic)
        if h < crop_pad_size or w < crop_pad_size:
            pad_h = max(0, crop_pad_size - h)
            pad_w = max(0, crop_pad_size - w)
            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
        # crop
        if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
            h, w = img_gt.shape[0:2]
            # randomly choose top and left coordinates
            top = random.randint(0, h - crop_pad_size)
            left = random.randint(0, w - crop_pad_size)
            img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob']:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma, [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel to the maximum size 21 so all kernels in a batch stack cleanly
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob2']:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)

        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            # identity kernel: convolving with the pulse tensor has no blurring effect
            sinc_kernel = self.pulse_tensor

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
        return return_d

    def __len__(self):
        return len(self.paths)
|
Real-ESRGAN/realesrgan/data/realesrgan_paired_dataset.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
|
3 |
+
from basicsr.data.transforms import augment, paired_random_crop
|
4 |
+
from basicsr.utils import FileClient, imfrombytes, img2tensor
|
5 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
6 |
+
from torch.utils import data as data
|
7 |
+
from torchvision.transforms.functional import normalize
|
8 |
+
|
9 |
+
|
10 |
+
@DATASET_REGISTRY.register()
class RealESRGANPairedDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    There are three modes:
    1. 'lmdb': Use lmdb files.
        If opt['io_backend'] == lmdb.
    2. 'meta_info': Use meta information file to generate paths.
        If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
    3. 'folder': Scan folders to generate paths.
        The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patched size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(RealESRGANPairedDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        # optional mean/std used to normalize tensors at the end of __getitem__
        self.mean = opt.get('mean')
        self.std = opt.get('std')

        self.gt_folder = opt['dataroot_gt']
        self.lq_folder = opt['dataroot_lq']
        self.filename_tmpl = opt.get('filename_tmpl', '{}')

        # decide how to build the list of (gt, lq) path pairs
        if self.io_backend_opt['type'] == 'lmdb':
            # mode 1: lmdb backend
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif self.opt.get('meta_info') is not None:
            # mode 2: disk backend driven by a meta_info file; each line holds
            # the relative gt and lq paths separated by ', '
            with open(self.opt['meta_info']) as fin:
                entries = [line.strip() for line in fin]
            self.paths = []
            for entry in entries:
                gt_name, lq_name = entry.split(', ')
                self.paths.append({
                    'gt_path': os.path.join(self.gt_folder, gt_name),
                    'lq_path': os.path.join(self.lq_folder, lq_name),
                })
        else:
            # mode 3: disk backend that scans the whole folder to get meta info;
            # this can be slow for folders with many files — prefer a meta txt file
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        # lazily create the file client so the dataset pickles cleanly into workers
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        gt_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(gt_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        lq_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(lq_bytes, float32=True)

        # training-time augmentation: paired random crop, then flip/rotation
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize in place when mean/std are configured
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
|
Real-ESRGAN/realesrgan/models/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
from basicsr.utils import scandir
from os import path as osp

# Automatically discover every '*_model.py' file next to this __init__ and
# import it, so that each module's @MODEL_REGISTRY.register() decorator runs
# when the package is imported.
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = []
for filename in scandir(model_folder):
    if filename.endswith('_model.py'):
        model_filenames.append(osp.splitext(osp.basename(filename))[0])
# keep references to the imported modules so registration is not repeated
_model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames]
|
Real-ESRGAN/realesrgan/models/realesrgan_model.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
5 |
+
from basicsr.data.transforms import paired_random_crop
|
6 |
+
from basicsr.models.srgan_model import SRGANModel
|
7 |
+
from basicsr.utils import DiffJPEG, USMSharp
|
8 |
+
from basicsr.utils.img_process_util import filter2D
|
9 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
10 |
+
from collections import OrderedDict
|
11 |
+
from torch.nn import functional as F
|
12 |
+
|
13 |
+
|
14 |
+
@MODEL_REGISTRY.register()
class RealESRGANModel(SRGANModel):
    """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It mainly performs:
    1. randomly synthesize LQ images in GPU tensors
    2. optimize the networks with GAN training.
    """

    def __init__(self, opt):
        super(RealESRGANModel, self).__init__(opt)
        # NOTE(review): .cuda() is hard-coded here, so this model assumes a CUDA
        # device is available — confirm against the training launcher
        self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
        self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
        # number of samples held in the training-pair pool (see _dequeue_and_enqueue)
        self.queue_size = opt.get('queue_size', 180)

    @torch.no_grad()
    def _dequeue_and_enqueue(self):
        """It is the training pair pool for increasing the diversity in a batch.

        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
        to increase the degradation diversity in a batch.
        """
        # initialize
        b, c, h, w = self.lq.size()
        if not hasattr(self, 'queue_lr'):
            # lazily allocate the queues on first call; the pool size must be a
            # multiple of the batch size so enqueue slices always fit exactly
            assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
            self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
            _, c, h, w = self.gt.size()
            self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
            self.queue_ptr = 0
        if self.queue_ptr == self.queue_size:  # the pool is full
            # do dequeue and enqueue
            # shuffle
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            # get first b samples
            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
            # update the queue
            self.queue_lr[0:b, :, :, :] = self.lq.clone()
            self.queue_gt[0:b, :, :, :] = self.gt.clone()

            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # only do enqueue
            self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
            self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
            self.queue_ptr = self.queue_ptr + b

    @torch.no_grad()
    def feed_data(self, data):
        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
        """
        if self.is_train and self.opt.get('high_order_degradation', True):
            # training data synthesis
            self.gt = data['gt'].to(self.device)
            self.gt_usm = self.usm_sharpener(self.gt)

            # per-sample degradation kernels produced by RealESRGANDataset
            self.kernel1 = data['kernel1'].to(self.device)
            self.kernel2 = data['kernel2'].to(self.device)
            self.sinc_kernel = data['sinc_kernel'].to(self.device)

            ori_h, ori_w = self.gt.size()[2:4]

            # ----------------------- The first degradation process ----------------------- #
            # blur
            out = filter2D(self.gt_usm, self.kernel1)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            # add noise (gaussian or poisson, possibly gray-channel only)
            gray_noise_prob = self.opt['gray_noise_prob']
            if np.random.uniform() < self.opt['gaussian_noise_prob']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
            out = self.jpeger(out, quality=jpeg_p)

            # ----------------------- The second degradation process ----------------------- #
            # blur (applied only with probability second_blur_prob)
            if np.random.uniform() < self.opt['second_blur_prob']:
                out = filter2D(out, self.kernel2)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range2'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob2']
            if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)

            # JPEG compression + the final sinc filter
            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
            # as one operation.
            # We consider two orders:
            #   1. [resize back + sinc filter] + JPEG compression
            #   2. JPEG compression + [resize back + sinc filter]
            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
            if np.random.uniform() < 0.5:
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
            else:
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, self.sinc_kernel)

            # clamp and round (quantize to 8-bit levels, then back to [0, 1])
            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

            # random crop
            gt_size = self.opt['gt_size']
            (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
                                                                 self.opt['scale'])

            # training pair pool
            self._dequeue_and_enqueue()
            # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
            self.gt_usm = self.usm_sharpener(self.gt)
            self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
        else:
            # for paired training or validation
            self.lq = data['lq'].to(self.device)
            if 'gt' in data:
                self.gt = data['gt'].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        # do not use the synthetic process during validation
        self.is_train = False
        super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
        self.is_train = True

    def optimize_parameters(self, current_iter):
        # usm sharpening: each loss can independently target the sharpened or
        # the original ground truth, controlled by the *_gt_usm options
        l1_gt = self.gt_usm
        percep_gt = self.gt_usm
        gan_gt = self.gt_usm
        if self.opt['l1_gt_usm'] is False:
            l1_gt = self.gt
        if self.opt['percep_gt_usm'] is False:
            percep_gt = self.gt
        if self.opt['gan_gt_usm'] is False:
            gan_gt = self.gt

        # optimize net_g (freeze the discriminator while updating the generator)
        for p in self.net_d.parameters():
            p.requires_grad = False

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_g_total = 0
        loss_dict = OrderedDict()
        # generator is updated only every net_d_iters steps and after the
        # discriminator warm-up period
        if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, l1_gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict['l_g_style'] = l_g_style
            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict['l_g_gan'] = l_g_gan

            l_g_total.backward()
            self.optimizer_g.step()

        # optimize net_d
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()
        # real
        real_d_pred = self.net_d(gan_gt)
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict['l_d_real'] = l_d_real
        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
        l_d_real.backward()
        # fake
        fake_d_pred = self.net_d(self.output.detach().clone())  # clone for pt1.9
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d_fake'] = l_d_fake
        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
        l_d_fake.backward()
        self.optimizer_d.step()

        # exponential moving average of generator weights, when enabled
        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

        self.log_dict = self.reduce_loss_dict(loss_dict)
|
Real-ESRGAN/realesrgan/models/realesrnet_model.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
5 |
+
from basicsr.data.transforms import paired_random_crop
|
6 |
+
from basicsr.models.sr_model import SRModel
|
7 |
+
from basicsr.utils import DiffJPEG, USMSharp
|
8 |
+
from basicsr.utils.img_process_util import filter2D
|
9 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
10 |
+
from torch.nn import functional as F
|
11 |
+
|
12 |
+
|
13 |
+
@MODEL_REGISTRY.register()
|
14 |
+
class RealESRNetModel(SRModel):
|
15 |
+
"""RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
16 |
+
|
17 |
+
It is trained without GAN losses.
|
18 |
+
It mainly performs:
|
19 |
+
1. randomly synthesize LQ images in GPU tensors
|
20 |
+
2. optimize the networks with GAN training.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, opt):
|
24 |
+
super(RealESRNetModel, self).__init__(opt)
|
25 |
+
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
|
26 |
+
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
|
27 |
+
self.queue_size = opt.get('queue_size', 180)
|
28 |
+
|
29 |
+
@torch.no_grad()
|
30 |
+
def _dequeue_and_enqueue(self):
|
31 |
+
"""It is the training pair pool for increasing the diversity in a batch.
|
32 |
+
|
33 |
+
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
|
34 |
+
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
|
35 |
+
to increase the degradation diversity in a batch.
|
36 |
+
"""
|
37 |
+
# initialize
|
38 |
+
b, c, h, w = self.lq.size()
|
39 |
+
if not hasattr(self, 'queue_lr'):
|
40 |
+
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
|
41 |
+
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
|
42 |
+
_, c, h, w = self.gt.size()
|
43 |
+
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
|
44 |
+
self.queue_ptr = 0
|
45 |
+
if self.queue_ptr == self.queue_size: # the pool is full
|
46 |
+
# do dequeue and enqueue
|
47 |
+
# shuffle
|
48 |
+
idx = torch.randperm(self.queue_size)
|
49 |
+
self.queue_lr = self.queue_lr[idx]
|
50 |
+
self.queue_gt = self.queue_gt[idx]
|
51 |
+
# get first b samples
|
52 |
+
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
|
53 |
+
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
|
54 |
+
# update the queue
|
55 |
+
self.queue_lr[0:b, :, :, :] = self.lq.clone()
|
56 |
+
self.queue_gt[0:b, :, :, :] = self.gt.clone()
|
57 |
+
|
58 |
+
self.lq = lq_dequeue
|
59 |
+
self.gt = gt_dequeue
|
60 |
+
else:
|
61 |
+
# only do enqueue
|
62 |
+
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
|
63 |
+
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
|
64 |
+
self.queue_ptr = self.queue_ptr + b
|
65 |
+
|
66 |
+
@torch.no_grad()
def feed_data(self, data):
    """Accept data from dataloader, and then add two-order degradations to obtain LQ images.

    In training mode (with high-order degradation enabled) the LQ image is synthesized
    on the fly from the GT image via two rounds of blur / resize / noise / JPEG, followed
    by a final sinc filter, quantization, and a paired random crop. Otherwise the paired
    LQ/GT tensors from the dataloader are used directly.

    Args:
        data (dict): From the dataloader. For synthesis it must contain 'gt', 'kernel1',
            'kernel2' and 'sinc_kernel'; for paired mode it contains 'lq' and optionally 'gt'.
    """
    if self.is_train and self.opt.get('high_order_degradation', True):
        # training data synthesis
        self.gt = data['gt'].to(self.device)
        # USM sharpen the GT images
        if self.opt['gt_usm'] is True:
            self.gt = self.usm_sharpener(self.gt)

        self.kernel1 = data['kernel1'].to(self.device)
        self.kernel2 = data['kernel2'].to(self.device)
        self.sinc_kernel = data['sinc_kernel'].to(self.device)

        ori_h, ori_w = self.gt.size()[2:4]

        # ----------------------- The first degradation process ----------------------- #
        # blur
        out = filter2D(self.gt, self.kernel1)
        # random resize: direction chosen by resize_prob, factor sampled within resize_range
        updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.opt['resize_range'][1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.opt['resize_range'][0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, scale_factor=scale, mode=mode)
        # add noise: Gaussian with probability gaussian_noise_prob, otherwise Poisson
        gray_noise_prob = self.opt['gray_noise_prob']
        if np.random.uniform() < self.opt['gaussian_noise_prob']:
            out = random_add_gaussian_noise_pt(
                out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
        else:
            out = random_add_poisson_noise_pt(
                out,
                scale_range=self.opt['poisson_scale_range'],
                gray_prob=gray_noise_prob,
                clip=True,
                rounds=False)
        # JPEG compression with a per-sample random quality from jpeg_range
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
        out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
        out = self.jpeger(out, quality=jpeg_p)

        # ----------------------- The second degradation process ----------------------- #
        # blur (applied only with probability second_blur_prob)
        if np.random.uniform() < self.opt['second_blur_prob']:
            out = filter2D(out, self.kernel2)
        # random resize
        updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
        if updown_type == 'up':
            scale = np.random.uniform(1, self.opt['resize_range2'][1])
        elif updown_type == 'down':
            scale = np.random.uniform(self.opt['resize_range2'][0], 1)
        else:
            scale = 1
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        # resize relative to the final LQ size (ori / scale), not the current size
        out = F.interpolate(
            out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
        # add noise
        gray_noise_prob = self.opt['gray_noise_prob2']
        if np.random.uniform() < self.opt['gaussian_noise_prob2']:
            out = random_add_gaussian_noise_pt(
                out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
        else:
            out = random_add_poisson_noise_pt(
                out,
                scale_range=self.opt['poisson_scale_range2'],
                gray_prob=gray_noise_prob,
                clip=True,
                rounds=False)

        # JPEG compression + the final sinc filter
        # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
        # as one operation.
        # We consider two orders:
        #   1. [resize back + sinc filter] + JPEG compression
        #   2. JPEG compression + [resize back + sinc filter]
        # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
        if np.random.uniform() < 0.5:
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
            out = filter2D(out, self.sinc_kernel)
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
        else:
            # JPEG compression
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
            out = torch.clamp(out, 0, 1)
            out = self.jpeger(out, quality=jpeg_p)
            # resize back + the final sinc filter
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
            out = filter2D(out, self.sinc_kernel)

        # clamp and round: quantize to 8-bit levels, then back to [0, 1]
        self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.

        # random crop (keeps the gt/lq crops spatially aligned)
        gt_size = self.opt['gt_size']
        self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])

        # training pair pool
        self._dequeue_and_enqueue()
        self.lq = self.lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
    else:
        # for paired training or validation
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)
            self.gt_usm = self.usm_sharpener(self.gt)
|
183 |
+
|
184 |
+
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
    """Run non-distributed validation with the synthetic degradation disabled.

    ``feed_data`` branches on ``self.is_train``, so the flag is cleared for the
    duration of validation and restored afterwards.

    Args:
        dataloader: Validation dataloader.
        current_iter (int): Current training iteration (for logging).
        tb_logger: Tensorboard logger (may be None).
        save_img (bool): Whether to save validation images.
    """
    # do not use the synthetic process during validation
    self.is_train = False
    try:
        super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
    finally:
        # restore the flag even if validation raises, so training can continue correctly
        self.is_train = True
|
Real-ESRGAN/realesrgan/train.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# flake8: noqa
import os.path as osp
from basicsr.train import train_pipeline

# These imports are required for their side effects: they register the custom
# Real-ESRGAN architectures, datasets and models in basicsr's registries.
import realesrgan.archs
import realesrgan.data
import realesrgan.models

if __name__ == '__main__':
    # root_path is the repository root, two levels above this file
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    train_pipeline(root_path)
|
Real-ESRGAN/realesrgan/utils.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import queue
|
6 |
+
import threading
|
7 |
+
import torch
|
8 |
+
from basicsr.utils.download_util import load_file_from_url
|
9 |
+
from torch.nn import functional as F
|
10 |
+
|
11 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
12 |
+
|
13 |
+
|
14 |
+
class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str | list[str]): The path to the pretrained model. It can be urls (will first download it
            automatically). A list of two paths enables deep network interpolation (dni).
        dni_weight (list[float]): Interpolation weights for dni; must have the same length as ``model_path``
            when that is a list. Default: None.
        model (nn.Module): The defined network. Default: None.
        tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
            input images into tiles, and then process each of them. Finally, they will be merged into one image.
            0 denotes for do not use tile. Default: 0.
        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
        device (torch.device): Device for inference. Default: None (auto-select cuda when available, else cpu).
        gpu_id (int): Specific GPU id to run on. Default: None.
    """

    def __init__(self,
                 scale,
                 model_path,
                 dni_weight=None,
                 model=None,
                 tile=0,
                 tile_pad=10,
                 pre_pad=10,
                 half=False,
                 device=None,
                 gpu_id=None):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize device: an explicit `device` always wins; otherwise pick the
        # requested GPU (or the default cuda device) when available, else cpu
        if gpu_id:
            self.device = torch.device(
                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
        else:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device

        if isinstance(model_path, list):
            # dni: interpolate the weights of two checkpoints
            assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the same length.'
            loadnet = self.dni(model_path[0], model_path[1], dni_weight)
        else:
            # if the model_path starts with https, it will first download models to the folder: weights
            if model_path.startswith('https://'):
                model_path = load_file_from_url(
                    url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
            loadnet = torch.load(model_path, map_location=torch.device('cpu'))

        # prefer to use params_ema (exponential-moving-average weights) when available
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)

        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'):
        """Deep network interpolation.

        ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition``

        Returns net_a's checkpoint dict with each weight replaced by the convex
        combination ``dni_weight[0] * a + dni_weight[1] * b``.
        """
        net_a = torch.load(net_a, map_location=torch.device(loc))
        net_b = torch.load(net_b, map_location=torch.device(loc))
        for k, v_a in net_a[key].items():
            net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k]
        return net_a

    def pre_process(self, img):
        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible.

        Args:
            img (np.ndarray): HWC float image. Stored on ``self.img`` as a 1CHW tensor.
        """
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad: reflect-pad the bottom/right borders to reduce border artifacts
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad so height/width are divisible by the network's downscaling factor
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

    def process(self):
        """Model inference on the whole image at once (no tiling)."""
        self.output = self.model(self.img)

    def tile_process(self):
        """It will first crop input images to tiles, and then process each tile.
        Finally, all the processed tiles are merged into one images.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                    print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
                    # Bugfix: skip this tile instead of falling through. Previously the
                    # code continued with `output_tile` either unbound (NameError on the
                    # first tile) or stale from the previous iteration (silently wrong
                    # output). The failed tile stays black in the merged image.
                    continue

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]

    def post_process(self):
        """Remove the mod pad and pre pad (scaled by ``self.scale``) from ``self.output``."""
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        """Upsample a BGR(A) or grayscale numpy image.

        Args:
            img (np.ndarray): Input image (uint8 or uint16, as read by cv2).
            outscale (float): Final rescale factor relative to the input size; when it
                differs from the network scale the result is resized with Lanczos. Default: None.
            alpha_upsampler (str): 'realesrgan' to run the alpha channel through the
                network, anything else to resize it with cv2. Default: 'realesrgan'.

        Returns:
            tuple: (output image as uint8/uint16 ndarray, detected img_mode 'L'/'RGB'/'RGBA').
        """
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        # NOTE: heuristic bit-depth detection — a 16-bit image with all values <= 256
        # would be treated as 8-bit
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        # CHW RGB -> HWC BGR
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode
|
264 |
+
|
265 |
+
|
266 |
+
class PrefetchReader(threading.Thread):
    """Background thread that reads images ahead of time into a bounded queue.

    Iterate over the instance to consume the images in order; iteration stops
    once the reader thread has exhausted ``img_list``.

    Args:
        img_list (list[str]): A image list of image paths to be read.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        for path in self.img_list:
            self.que.put(cv2.imread(path, cv2.IMREAD_UNCHANGED))
        # sentinel marking the end of the stream
        self.que.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        item = self.que.get()
        if item is None:
            raise StopIteration
        return item
|
294 |
+
|
295 |
+
|
296 |
+
class IOConsumer(threading.Thread):
    """Worker thread that consumes save requests from a queue and writes images to disk.

    Each queue message is a dict with 'output' (image array) and 'save_path' keys;
    the string ``'quit'`` terminates the worker.
    """

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        while True:
            msg = self._queue.get()
            if isinstance(msg, str) and msg == 'quit':
                break

            output = msg['output']
            target_path = msg['save_path']
            cv2.imwrite(target_path, output)
        print(f'IO worker {self.qid} is done.')
|
Real-ESRGAN/realesrgan/version.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# GENERATED VERSION FILE
|
2 |
+
# TIME: Mon Apr 14 13:27:54 2025
|
3 |
+
__version__ = '0.3.0'
|
4 |
+
__gitsha__ = 'a4abfb2'
|
5 |
+
version_info = (0, 3, 0)
|
Real-ESRGAN/scripts/extract_subimages.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
from basicsr.utils import scandir
|
7 |
+
from multiprocessing import Pool
|
8 |
+
from os import path as osp
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
|
12 |
+
def main(args):
    """A multi-thread tool to crop large images to sub-images for faster IO.

    opt (dict): Configuration dict. It contains:
        n_thread (int): Thread number.
        compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size
            and longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2.
        input_folder (str): Path to the input folder.
        save_folder (str): Path to save folder.
        crop_size (int): Crop size.
        step (int): Step for overlapped sliding window.
        thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.

    Usage:
        For each folder, run this script.
        Typically, there are GT folder and LQ folder to be processed for DIV2K dataset.
        After process, each sub_folder should have the same number of subimages.
        Remember to modify opt configurations according to your settings.
    """
    # translate the parsed CLI arguments into the option dict the worker expects
    opt = {
        'n_thread': args.n_thread,
        'compression_level': args.compression_level,
        'input_folder': args.input,
        'save_folder': args.output,
        'crop_size': args.crop_size,
        'step': args.step,
        'thresh_size': args.thresh_size,
    }
    extract_subimages(opt)
|
41 |
+
|
42 |
+
|
43 |
+
def extract_subimages(opt):
    """Crop images to subimages.

    Args:
        opt (dict): Configuration dict. It contains:
            input_folder (str): Path to the input folder.
            save_folder (str): Path to save folder.
            n_thread (int): Thread number.
    """
    input_folder = opt['input_folder']
    save_folder = opt['save_folder']
    # refuse to run on an existing output folder to avoid mixing results
    if osp.exists(save_folder):
        print(f'Folder {save_folder} already exists. Exit.')
        sys.exit(1)
    os.makedirs(save_folder)
    print(f'mkdir {save_folder} ...')

    # scan all images
    img_list = list(scandir(input_folder, full_path=True))

    pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
    pool = Pool(opt['n_thread'])
    for img_path in img_list:
        # the callback ticks the progress bar as each worker finishes
        pool.apply_async(worker, args=(img_path, opt), callback=lambda arg: pbar.update(1))
    pool.close()
    pool.join()
    pbar.close()
    print('All processes done.')
|
72 |
+
|
73 |
+
|
74 |
+
def worker(path, opt):
    """Worker for each process.

    Crops a single image into overlapping sub-images and writes them as PNG files.

    Args:
        path (str): Image path.
        opt (dict): Configuration dict. It contains:
            crop_size (int): Crop size.
            step (int): Step for overlapped sliding window.
            thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
            save_folder (str): Path to save folder.
            compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.

    Returns:
        process_info (str): Process information displayed in progress bar.
    """
    crop_size = opt['crop_size']
    step = opt['step']
    thresh_size = opt['thresh_size']
    img_name, extension = osp.splitext(osp.basename(path))

    # remove the x2, x3, x4 and x8 in the filename for DIV2K
    for tag in ('x2', 'x3', 'x4', 'x8'):
        img_name = img_name.replace(tag, '')

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)

    h, w = img.shape[0:2]
    # sliding-window start positions; append a final window flush with the border
    # when the leftover strip is larger than thresh_size
    h_space = np.arange(0, h - crop_size + 1, step)
    if h - (h_space[-1] + crop_size) > thresh_size:
        h_space = np.append(h_space, h - crop_size)
    w_space = np.arange(0, w - crop_size + 1, step)
    if w - (w_space[-1] + crop_size) > thresh_size:
        w_space = np.append(w_space, w - crop_size)

    index = 0
    for top in h_space:
        for left in w_space:
            index += 1
            patch = np.ascontiguousarray(img[top:top + crop_size, left:left + crop_size, ...])
            cv2.imwrite(
                osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), patch,
                [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
    process_info = f'Processing {img_name} ...'
    return process_info
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == '__main__':
    # command-line entry point: parse crop/threading options and run the extraction
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
    parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_HR_sub', help='Output folder')
    parser.add_argument('--crop_size', type=int, default=480, help='Crop size')
    parser.add_argument('--step', type=int, default=240, help='Step for overlapped sliding window')
    parser.add_argument(
        '--thresh_size',
        type=int,
        default=0,
        help='Threshold size. Patches whose size is lower than thresh_size will be dropped.')
    parser.add_argument('--n_thread', type=int, default=20, help='Thread number.')
    parser.add_argument('--compression_level', type=int, default=3, help='Compression level')
    args = parser.parse_args()

    main(args)
|
Real-ESRGAN/scripts/generate_meta_info.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import cv2
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
|
6 |
+
|
7 |
+
def main(args):
    """Write a meta-info txt file listing GT image paths relative to their roots.

    Args:
        args: Parsed arguments with attributes:
            input (list[str]): Input folders to scan.
            root (list[str]): Folder roots, one per input folder; written paths are
                relative to these.
            meta_info (str): Output txt path.
            check (bool): If True, read each image once to verify it is readable.
    """
    # use a context manager so the file is always closed (it was previously leaked)
    with open(args.meta_info, 'w') as txt_file:
        for folder, root in zip(args.input, args.root):
            img_paths = sorted(glob.glob(os.path.join(folder, '*')))
            for img_path in img_paths:
                status = True
                if args.check:
                    # read the image once for check, as some images may have errors
                    # Bugfix: initialize img so a raised exception does not leave it
                    # unbound (first iteration) or stale (later iterations)
                    img = None
                    try:
                        img = cv2.imread(img_path)
                    except (IOError, OSError) as error:
                        print(f'Read {img_path} error: {error}')
                        status = False
                    if img is None:
                        status = False
                        print(f'Img is None: {img_path}')
                if status:
                    # get the relative path
                    img_name = os.path.relpath(img_path, root)
                    print(img_name)
                    txt_file.write(f'{img_name}\n')
|
28 |
+
|
29 |
+
|
30 |
+
if __name__ == '__main__':
    """Generate meta info (txt file) for only Ground-Truth images.

    It can also generate meta info from several folders into one txt file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input',
        nargs='+',
        default=['datasets/DF2K/DF2K_HR', 'datasets/DF2K/DF2K_multiscale'],
        help='Input folder, can be a list')
    parser.add_argument(
        '--root',
        nargs='+',
        default=['datasets/DF2K', 'datasets/DF2K'],
        help='Folder root, should have the length as input folders')
    parser.add_argument(
        '--meta_info',
        type=str,
        default='datasets/DF2K/meta_info/meta_info_DF2Kmultiscale.txt',
        help='txt path for meta info')
    parser.add_argument('--check', action='store_true', help='Read image to check whether it is ok')
    args = parser.parse_args()

    # inputs and roots must be paired one-to-one
    assert len(args.input) == len(args.root), ('Input folder and folder root should have the same length, but got '
                                               f'{len(args.input)} and {len(args.root)}.')
    # make sure the output directory for the meta file exists
    os.makedirs(os.path.dirname(args.meta_info), exist_ok=True)

    main(args)
|
Real-ESRGAN/scripts/generate_meta_info_pairdata.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
def main(args):
    """Write a meta-info txt file for paired GT/LQ images.

    Each output line is ``<gt_relpath>, <lq_relpath>``; the two folders must
    contain the same number of (sorted) files so pairing by position is valid.

    Args:
        args: Parsed arguments with attributes:
            input (list[str]): [gt_folder, lq_folder].
            root (list[str]): [gt_root, lq_root] used for relative paths.
            meta_info (str): Output txt path.
    """
    # scan images
    img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*')))
    img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*')))

    assert len(img_paths_gt) == len(img_paths_lq), ('GT folder and LQ folder should have the same length, but got '
                                                    f'{len(img_paths_gt)} and {len(img_paths_lq)}.')

    # open after the length check and close deterministically (the file handle was
    # previously opened first and never closed)
    with open(args.meta_info, 'w') as txt_file:
        for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq):
            # get the relative paths
            img_name_gt = os.path.relpath(img_path_gt, args.root[0])
            img_name_lq = os.path.relpath(img_path_lq, args.root[1])
            print(f'{img_name_gt}, {img_name_lq}')
            txt_file.write(f'{img_name_gt}, {img_name_lq}\n')
|
21 |
+
|
22 |
+
|
23 |
+
if __name__ == '__main__':
    """This script is used to generate meta info (txt file) for paired images.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input',
        nargs='+',
        default=['datasets/DF2K/DIV2K_train_HR_sub', 'datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub'],
        help='Input folder, should be [gt_folder, lq_folder]')
    # NOTE(review): the help text below is truncated ('will use the ') — presumably it
    # meant to say the parent of each input folder is used when root is None; confirm
    parser.add_argument('--root', nargs='+', default=[None, None], help='Folder root, will use the ')
    parser.add_argument(
        '--meta_info',
        type=str,
        default='datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt',
        help='txt path for meta info')
    args = parser.parse_args()

    assert len(args.input) == 2, 'Input folder should have two elements: gt folder and lq folder'
    assert len(args.root) == 2, 'Root path should have two elements: root for gt folder and lq folder'
    # make sure the output directory for the meta file exists
    os.makedirs(os.path.dirname(args.meta_info), exist_ok=True)
    for i in range(2):
        # normalize trailing slashes, then default each root to the input's parent folder
        if args.input[i].endswith('/'):
            args.input[i] = args.input[i][:-1]
        if args.root[i] is None:
            args.root[i] = os.path.dirname(args.input[i])

    main(args)
|
Real-ESRGAN/scripts/generate_multiscale_DF2K.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
|
7 |
+
def main(args):
    """Create multi-scale copies of every image found in ``args.input``.

    Each source image is saved at 0.75x, 0.5x and 1/3x scale (suffixes
    T0..T2), plus one extra copy whose shortest edge is resized to 400 px
    (suffix T3). All outputs are PNG files written to ``args.output``.
    """
    # The three DF2K downscale factors, plus a final shortest-edge target.
    scales = [0.75, 0.5, 1 / 3]
    shortest_edge = 400

    for src_path in sorted(glob.glob(os.path.join(args.input, '*'))):
        print(src_path)
        basename = os.path.splitext(os.path.basename(src_path))[0]

        img = Image.open(src_path)
        width, height = img.size

        for scale_idx, factor in enumerate(scales):
            print(f'\t{factor:.2f}')
            target_size = (int(width * factor), int(height * factor))
            resized = img.resize(target_size, resample=Image.LANCZOS)
            resized.save(os.path.join(args.output, f'{basename}T{scale_idx}.png'))

        # One more copy whose shortest edge equals `shortest_edge`,
        # preserving the original aspect ratio.
        if width < height:
            aspect = height / width
            width = shortest_edge
            height = int(width * aspect)
        else:
            aspect = width / height
            height = shortest_edge
            width = int(height * aspect)
        resized = img.resize((int(width), int(height)), resample=Image.LANCZOS)
        resized.save(os.path.join(args.output, f'{basename}T{scale_idx + 1}.png'))
|
36 |
+
|
37 |
+
|
38 |
+
if __name__ == '__main__':
    # Generate multi-scale versions for GT images with LANCZOS resampling.
    # It is now used for the DF2K dataset (DIV2K + Flickr 2K).
    parser = argparse.ArgumentParser()
    for flag, default, help_text in (
        ('--input', 'datasets/DF2K/DF2K_HR', 'Input folder'),
        ('--output', 'datasets/DF2K/DF2K_multiscale', 'Output folder'),
    ):
        parser.add_argument(flag, type=str, default=default, help=help_text)
    args = parser.parse_args()

    # Make sure the destination folder exists before main() writes into it.
    os.makedirs(args.output, exist_ok=True)
    main(args)
|
Real-ESRGAN/scripts/pytorch2onnx.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import torch.onnx
|
4 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
5 |
+
|
6 |
+
|
7 |
+
def main(args):
    """Export a Real-ESRGAN (RRDBNet x4) checkpoint to ONNX.

    Args:
        args: Parsed CLI namespace with ``input`` (.pth checkpoint path),
            ``output`` (.onnx destination) and ``params`` (bool; True selects
            the raw 'params' weights, False the 'params_ema' EMA weights).
    """
    # An instance of the model (fixed x4 RRDBNet architecture).
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    # Select which weight set to load from the checkpoint dict.
    keyname = 'params' if args.params else 'params_ema'
    # map_location='cpu' so export also works on machines without a GPU,
    # even for checkpoints that were saved from CUDA tensors.
    model.load_state_dict(torch.load(args.input, map_location='cpu')[keyname])
    # set the train mode to false since we will only run the forward pass.
    model.train(False)
    model.cpu().eval()

    # An example input used for tracing.
    x = torch.rand(1, 3, 64, 64)
    # Export the model. Use the public torch.onnx.export API instead of the
    # private torch.onnx._export, which has been removed from recent PyTorch.
    with torch.no_grad():
        torch.onnx.export(model, x, args.output, opset_version=11, export_params=True)
        # torch.onnx.export returns None; run a forward pass to report the
        # output shape, matching the original diagnostic print.
        print(model(x).shape)
|
25 |
+
|
26 |
+
|
27 |
+
if __name__ == '__main__':
    """Convert pytorch model to onnx models"""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--input', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth', help='Input model path')
    parser.add_argument('--output', type=str, default='realesrgan-x4.onnx', help='Output onnx path')
    # BUG FIX: this was action='store_false', which made args.params default
    # to True (loading the raw 'params' weights by default) while passing
    # --params did the opposite of its help text. With store_true the default
    # is the EMA weights ('params_ema'), and passing --params opts into the
    # raw 'params' weights — matching the documented behavior.
    parser.add_argument('--params', action='store_true', help='Use params instead of params_ema')
    args = parser.parse_args()

    main(args)
|
Real-ESRGAN/setup.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
from setuptools import find_packages, setup
|
4 |
+
|
5 |
+
import os
|
6 |
+
import subprocess
|
7 |
+
import time
|
8 |
+
|
9 |
+
version_file = 'realesrgan/version.py'
|
10 |
+
|
11 |
+
|
12 |
+
def readme():
    """Return the contents of README.md for use as the long description."""
    with open('README.md', encoding='utf-8') as readme_file:
        return readme_file.read()
|
16 |
+
|
17 |
+
|
18 |
+
def get_git_hash():
    """Return the current git commit SHA, or 'unknown' if git is unavailable."""

    def _minimal_ext_cmd(cmd):
        # Run `cmd` with a stripped-down environment: keep only the variables
        # needed to locate executables and the home directory.
        env = {k: os.environ[k] for k in ('SYSTEMROOT', 'PATH', 'HOME') if os.environ.get(k) is not None}
        # LANGUAGE is used on win32; force the C locale so output is untranslated.
        env['LANGUAGE'] = 'C'
        env['LANG'] = 'C'
        env['LC_ALL'] = 'C'
        return subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]

    try:
        raw = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
        return raw.strip().decode('ascii')
    except OSError:
        # git executable not found (or not runnable) on this machine.
        return 'unknown'
|
41 |
+
|
42 |
+
|
43 |
+
def get_hash():
    """Return a short (7-char) commit hash when run inside a git checkout."""
    return get_git_hash()[:7] if os.path.exists('.git') else 'unknown'
|
50 |
+
|
51 |
+
|
52 |
+
def write_version_py():
    """Regenerate realesrgan/version.py from the VERSION file and git state."""
    template = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
    git_sha = get_hash()
    with open('VERSION', 'r') as version_source:
        short_version = version_source.read().strip()
    # Build a tuple literal: numeric components stay bare, anything else is
    # quoted (e.g. '1.2.0rc1' -> 1, 2, "0rc1").
    info_parts = []
    for part in short_version.split('.'):
        info_parts.append(part if part.isdigit() else f'"{part}"')
    version_tuple = ', '.join(info_parts)

    rendered = template.format(time.asctime(), short_version, git_sha, version_tuple)
    with open(version_file, 'w') as target:
        target.write(rendered)
|
67 |
+
|
68 |
+
|
69 |
+
def get_version():
    """Read ``__version__`` back out of the generated version file.

    Executes the version file in an explicit namespace dict. The previous
    implementation exec'd into ``locals()`` inside a function, which is
    undefined behavior per the Python data model (function locals are not
    guaranteed to reflect exec assignments) and only worked by CPython
    accident.
    """
    namespace = {}
    with open(version_file, 'r') as f:
        exec(compile(f.read(), version_file, 'exec'), namespace)
    return namespace['__version__']
|
73 |
+
|
74 |
+
|
75 |
+
def get_requirements(filename='requirements.txt'):
    """Read dependency specifiers from *filename* next to this setup script.

    Args:
        filename: Requirements file name, resolved relative to this file's
            directory (not the current working directory).

    Returns:
        list[str]: One entry per line with line terminators removed.
        ``splitlines()`` also drops the stray ``\\r`` that the previous
        ``replace('\\n', '')`` approach left behind on CRLF-edited files.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(here, filename), 'r') as f:
        return f.read().splitlines()
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == '__main__':
    # Regenerate realesrgan/version.py before packaging so the built wheel
    # embeds the current VERSION string and git SHA.
    write_version_py()
    setup(
        name='realesrgan',
        version=get_version(),
        description='Real-ESRGAN aims at developing Practical Algorithms for General Image Restoration',
        long_description=readme(),
        long_description_content_type='text/markdown',
        author='Xintao Wang',
        author_email='xintao.wang@outlook.com',
        keywords='computer vision, pytorch, image restoration, super-resolution, esrgan, real-esrgan',
        url='https://github.com/xinntao/Real-ESRGAN',
        include_package_data=True,
        # Exclude non-package working directories from the distribution.
        packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
        classifiers=[
            'Development Status :: 4 - Beta',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
        ],
        # NOTE(review): classifier says Apache while `license` says
        # BSD-3-Clause — confirm which license actually applies.
        license='BSD-3-Clause License',
        setup_requires=['cython', 'numpy'],
        install_requires=get_requirements(),
        zip_safe=False)
|
Real-ESRGAN/weights/README.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Weights
|
2 |
+
|
3 |
+
Put the downloaded weights to this folder.
|
Real-ESRGAN/weights/RealESRGAN_x2plus.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:49fafd45f8fd7aa8d31ab2a22d14d91b536c34494a5cfe31eb5d89c2fa266abb
|
3 |
+
size 67061725
|
Real-ESRGAN/weights/RealESRGAN_x4plus.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1
|
3 |
+
size 67040989
|
api.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import uuid
|
3 |
+
import gc
|
4 |
+
import subprocess
|
5 |
+
import sys
|
6 |
+
import traceback
|
7 |
+
import shutil
|
8 |
+
import logging
|
9 |
+
from typing import Optional, List
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Form
|
13 |
+
from fastapi.responses import FileResponse
|
14 |
+
from fastapi.middleware.cors import CORSMiddleware
|
15 |
+
import uvicorn
|
16 |
+
import psutil
|
17 |
+
|
18 |
+
# --- Configuration ---
|
19 |
+
SCRIPT_DIR = Path(__file__).parent.resolve()
|
20 |
+
REAL_ESRGAN_DIR = SCRIPT_DIR / "Real-ESRGAN"
|
21 |
+
INFERENCE_SCRIPT = REAL_ESRGAN_DIR / "inference_realesrgan.py"
|
22 |
+
MODEL_DIR = REAL_ESRGAN_DIR / "weights"
|
23 |
+
INPUT_DIR = SCRIPT_DIR / "api_inputs"
|
24 |
+
OUTPUT_DIR = SCRIPT_DIR / "api_outputs"
|
25 |
+
API_PORT = 8000
|
26 |
+
LOG_FILE = SCRIPT_DIR / "api.log"
|
27 |
+
|
28 |
+
# --- Setup Logging ---
|
29 |
+
logging.basicConfig(
|
30 |
+
level=logging.INFO,
|
31 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
32 |
+
handlers=[
|
33 |
+
logging.FileHandler(LOG_FILE),
|
34 |
+
logging.StreamHandler(sys.stdout) # Also print logs to console
|
35 |
+
]
|
36 |
+
)
|
37 |
+
logger = logging.getLogger(__name__)
|
38 |
+
|
39 |
+
# --- Create Directories ---
|
40 |
+
INPUT_DIR.mkdir(exist_ok=True)
|
41 |
+
OUTPUT_DIR.mkdir(exist_ok=True)
|
42 |
+
|
43 |
+
# --- FastAPI App Initialization ---
|
44 |
+
app = FastAPI(
|
45 |
+
title="Image Enhancer API",
|
46 |
+
description="API for enhancing images.",
|
47 |
+
version="1.0.0"
|
48 |
+
)
|
49 |
+
|
50 |
+
# --- CORS Middleware ---
|
51 |
+
app.add_middleware(
|
52 |
+
CORSMiddleware,
|
53 |
+
allow_origins=["*"], # Allow all origins for simplicity, adjust in production
|
54 |
+
allow_credentials=True,
|
55 |
+
allow_methods=["*"],
|
56 |
+
allow_headers=["*"],
|
57 |
+
)
|
58 |
+
|
59 |
+
# --- Global State ---
|
60 |
+
processing_lock = False
|
61 |
+
available_models = []
|
62 |
+
DEFAULT_MODEL_PREFERENCE = "RealESRGAN_x4plus" # Preferred default
|
63 |
+
|
64 |
+
# Define allowed values for API input validation
|
65 |
+
AVAILABLE_MODELS_API = ["RealESRGAN_x4plus", "RealESRGAN_x2plus"]
|
66 |
+
ALLOWED_SCALES_API = [1.0, 2.0, 4.0, 8.0]
|
67 |
+
DEFAULT_MODEL_API = "RealESRGAN_x4plus"
|
68 |
+
DEFAULT_SCALE_API = 4.0
|
69 |
+
DEFAULT_TILE_SIZE = 400 # Default tile size to use on memory error retry
|
70 |
+
|
71 |
+
def update_available_models():
    """Rescan MODEL_DIR for *.pth files and refresh the global model list."""
    global available_models
    try:
        found = sorted(weight_file.stem for weight_file in MODEL_DIR.glob("*.pth"))
        if not found:
            logger.warning(f"No model files (.pth) found in {MODEL_DIR}")
        available_models = found
        logger.info(f"Available models updated: {available_models}")
    except Exception as e:
        # On any scan failure fall back to an empty list rather than crash.
        logger.error(f"Error scanning model directory {MODEL_DIR}: {e}")
        available_models = []
|
83 |
+
|
84 |
+
# Initialize models on startup
|
85 |
+
update_available_models()
|
86 |
+
|
87 |
+
|
88 |
+
# --- Helper Functions ---
|
89 |
+
def release_lock():
    """Releases the processing lock.

    NOTE(review): `processing_lock` is a plain module-level bool, not a real
    lock primitive; two overlapping requests could race between the check and
    the set in `enhance_image`. Consider an `asyncio.Lock` — confirm intent.
    """
    global processing_lock
    processing_lock = False
    logger.info("Processing lock released.")
|
94 |
+
|
95 |
+
# --- API Endpoints ---
|
96 |
+
@app.get("/")
async def root():
    """Health/landing endpoint returning a simple identification message."""
    payload = {"message": "Image Enhancer API is running"}
    return payload
|
100 |
+
|
101 |
+
@app.get("/models/", response_model=List[str])
async def get_models():
    """List the Real-ESRGAN model names available on the server.

    Rescans the weights directory once if the cached list is empty, and
    responds with 404 when no models can be found at all.
    """
    if available_models:
        return available_models
    update_available_models()  # cached list was empty — rescan once
    if not available_models:
        raise HTTPException(status_code=404, detail=f"No models found in {MODEL_DIR}")
    return available_models
|
109 |
+
|
110 |
+
@app.post("/enhance/", response_class=FileResponse)
async def enhance_image(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    model_name: Optional[str] = Form(DEFAULT_MODEL_API),
    outscale: float = Form(DEFAULT_SCALE_API),
    face_enhance: bool = Form(False),
    fp32: bool = Form(False),
    tile: Optional[int] = Form(0)
):
    """
    Enhances an uploaded image using Real-ESRGAN.
    Automatically retries with tiling if an out-of-memory error is detected.

    Flow: acquire global lock -> validate model/scale/content-type -> save the
    upload under api_inputs/<request_id>/ -> run inference_realesrgan.py as a
    subprocess writing into api_outputs/<request_id>/ -> return the resulting
    `<stem>_out.*` file as a FileResponse. The lock is released via a
    background task on success, or immediately on any error path.

    NOTE(review): the bool-based lock has a check-then-set window between the
    `if processing_lock` test and `processing_lock = True` — confirm whether
    concurrent requests are possible in this deployment.
    NOTE(review): `tile` is declared Optional[int]; if it ever arrives as
    None, `tile > 0` below would raise TypeError — confirm clients always
    send an int or omit the field.
    """
    global processing_lock
    temp_input_path = None
    temp_output_dir_for_request = None
    temp_input_dir_for_request = None # Added for consistency

    # --- Request Handling ---
    request_id = uuid.uuid4().hex
    logger.info(f"Received enhancement request ID: {request_id}")

    # Check processing lock (single concurrent enhancement only)
    if processing_lock:
        logger.warning(f"Request {request_id}: Server busy, denying request.")
        raise HTTPException(
            status_code=429,
            detail="Server is busy processing another image. Please try again shortly."
        )
    processing_lock = True
    logger.info(f"Request {request_id}: Processing lock acquired.")

    # --- Input Validation ---
    # Validate model name against allowed list
    if model_name not in AVAILABLE_MODELS_API:
        logger.warning(f"Request {request_id}: Invalid model_name specified: '{model_name}'. Allowed: {AVAILABLE_MODELS_API}")
        release_lock()
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model name '{model_name}'. Allowed values: {AVAILABLE_MODELS_API}"
        )

    # Validate scale against allowed list
    if outscale not in ALLOWED_SCALES_API:
        logger.warning(f"Request {request_id}: Invalid outscale specified: '{outscale}'. Allowed: {ALLOWED_SCALES_API}")
        release_lock()
        raise HTTPException(
            status_code=400,
            detail=f"Invalid scale value '{outscale}'. Allowed values: {ALLOWED_SCALES_API}"
        )

    # Validate file type (only checks the client-declared Content-Type)
    if not file.content_type or not file.content_type.startswith("image/"):
        logger.warning(f"Request {request_id}: Invalid file type uploaded: {file.content_type}")
        release_lock()
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")

    # --- Model Existence Check ---
    # Check if the validated model actually exists in the scanned directory
    if model_name not in available_models:
        logger.error(f"Request {request_id}: Model '{model_name}' is allowed but not found in {MODEL_DIR}. Scanned models: {available_models}")
        update_available_models() # Try rescanning
        if model_name not in available_models:
            release_lock()
            raise HTTPException(
                status_code=500,
                detail=f"Model '{model_name}' not found on server, even though it's an allowed option. Please check server configuration."
            )

    final_model_name = model_name # Use the validated model name
    logger.info(f"Request {request_id}: Using validated model: {final_model_name}, scale: {outscale}")

    try:
        # --- File Handling ---
        # Create unique temporary paths for this request
        input_suffix = Path(file.filename).suffix if file.filename else '.png'
        # Use original filename for input file within its own request dir
        temp_input_filename = Path(file.filename).name if file.filename else f"input_{request_id}{input_suffix}"

        # Input directory for this specific request
        temp_input_dir_for_request = INPUT_DIR / request_id
        temp_input_dir_for_request.mkdir(exist_ok=True)
        temp_input_path = temp_input_dir_for_request / temp_input_filename

        # Output directory for this specific request's results
        temp_output_dir_for_request = OUTPUT_DIR / request_id
        temp_output_dir_for_request.mkdir(exist_ok=True)

        # Save uploaded file to its request-specific input dir
        try:
            logger.info(f"Request {request_id}: Saving uploaded file to {temp_input_path}")
            contents = await file.read()
            with open(temp_input_path, "wb") as buffer:
                buffer.write(contents)
            logger.info(f"Request {request_id}: Uploaded file saved successfully.")
        except Exception as e:
            logger.error(f"Request {request_id}: Failed to save uploaded file: {e}")
            raise HTTPException(status_code=500, detail="Failed to save uploaded file.")
        finally:
            await file.close() # Ensure file handle is closed

        # --- Inference Execution ---
        # Construct command (base_cmd now uses temp_input_path which includes the subdir)
        base_cmd = [
            sys.executable, str(INFERENCE_SCRIPT),
            "-i", str(temp_input_path),
            "-o", str(temp_output_dir_for_request),
            "-n", final_model_name,
            "-s", str(outscale),
        ]
        if face_enhance:
            base_cmd.append("--face_enhance")
        if fp32:
            base_cmd.append("--fp32")
        # Add tile param only if explicitly provided (> 0) or during retry
        if tile > 0:
            base_cmd.extend(["-t", str(tile)])

        logger.info(f"Request {request_id}: Preparing initial inference command...")

        # Execute the script - Attempt 1 (No Tile unless specified)
        try:
            logger.info(f"Request {request_id}: Running inference (Attempt 1): {' '.join(base_cmd)}")
            process = subprocess.run(
                base_cmd,
                capture_output=True,
                text=True,
                check=True,
                cwd=REAL_ESRGAN_DIR
            )
            logger.info(f"Request {request_id}: Inference script (Attempt 1) stdout:{process.stdout}")
            if process.stderr:
                logger.warning(f"Request {request_id}: Inference script (Attempt 1) stderr:{process.stderr}")

        except (subprocess.CalledProcessError, RuntimeError) as e:
            error_output = ""
            if isinstance(e, subprocess.CalledProcessError):
                error_output = e.stderr
                logger.error(f"Request {request_id}: Inference script failed (Attempt 1) with exit code {e.returncode}")
                logger.error(f"Request {request_id}: Stdout: {e.stdout}")
                logger.error(f"Request {request_id}: Stderr: {e.stderr}")
            else: # Handle RuntimeError which might be raised by realesrgan directly
                error_output = str(e)
                logger.error(f"Request {request_id}: Inference script raised RuntimeError (Attempt 1): {e}")

            # Check if it's a memory error and tile wasn't already manually set
            # (heuristic: substring match on the captured stderr/message)
            is_memory_error = "memory" in error_output.lower() or "cuda" in error_output.lower()
            tile_arg_present = any(arg == "-t" for arg in base_cmd)

            if is_memory_error and not tile_arg_present:
                logger.warning(f"Request {request_id}: Detected potential memory error. Retrying with tiling (tile_size={DEFAULT_TILE_SIZE})...")
                # Attempt 2 (With Tile)
                retry_cmd = base_cmd + ["-t", str(DEFAULT_TILE_SIZE)]
                try:
                    logger.info(f"Request {request_id}: Running inference (Attempt 2 - Tiled): {' '.join(retry_cmd)}")
                    process = subprocess.run(
                        retry_cmd,
                        capture_output=True,
                        text=True,
                        check=True,
                        cwd=REAL_ESRGAN_DIR
                    )
                    logger.info(f"Request {request_id}: Inference script (Attempt 2 - Tiled) stdout:{process.stdout}")
                    if process.stderr:
                        logger.warning(f"Request {request_id}: Inference script (Attempt 2 - Tiled) stderr:{process.stderr}")
                except (subprocess.CalledProcessError, RuntimeError) as e2:
                    logger.error(f"Request {request_id}: Inference script failed even on retry with tiling.")
                    # Log the second error
                    if isinstance(e2, subprocess.CalledProcessError):
                        logger.error(f"Request {request_id}: Retry Exit Code: {e2.returncode}, Stderr: {e2.stderr}")
                        error_output = e2.stderr # Use the error from the retry attempt
                    else:
                        logger.error(f"Request {request_id}: Retry RuntimeError: {e2}")
                        error_output = str(e2)
                    # Raise original error type but with potentially updated message from retry
                    raise HTTPException(status_code=500, detail=f"Image enhancement failed, even with tiling: {error_output or 'Unknown error'}")
            else:
                # Not a memory error, or tile was already specified - fail normally
                raise HTTPException(status_code=500, detail=f"Image enhancement script failed: {error_output or 'Unknown error'}")

        except Exception as e:
            # Catch any other unexpected errors during subprocess execution
            logger.error(f"Request {request_id}: Unexpected error executing inference script: {e}")
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=f"Failed to run enhancement process: {e}")

        # --- Result Handling ---
        # Find the output file (assumes script outputs one file with '_out' suffix)
        # The script `inference_realesrgan.py` saves the output as `{basename}_out.{ext}`
        original_basename = Path(temp_input_filename).stem
        expected_output_stem = f"{original_basename}_out"
        output_files = list(temp_output_dir_for_request.glob(f"{expected_output_stem}.*"))

        if not output_files:
            logger.error(f"Request {request_id}: No output file found in {temp_output_dir_for_request} matching stem {expected_output_stem}")
            raise HTTPException(status_code=500, detail="Enhancement finished, but output file not found.")

        output_path = output_files[0]
        output_media_type = f"image/{output_path.suffix.strip('.')}"
        output_filename = f"enhanced_{Path(file.filename).name}" if file.filename else f"enhanced_{request_id}{output_path.suffix}"

        logger.info(f"Request {request_id}: Enhancement successful. Output: {output_path}")

        # Schedule cleanup task (input file and the whole output dir for this request)
        # background_tasks.add_task(cleanup_files, [temp_input_path, temp_output_dir_for_request]) # Removed cleanup
        # Release lock AFTER scheduling cleanup but BEFORE returning response
        background_tasks.add_task(release_lock)

        # Return the enhanced image file
        return FileResponse(
            path=output_path,
            media_type=output_media_type,
            filename=output_filename
        )

    except HTTPException as http_exc:
        # If an HTTPException occurred (validation, busy, etc.), release lock immediately
        release_lock()
        # Re-raise the exception to be handled by FastAPI
        raise http_exc
    except Exception as e:
        error_msg = f"Request {request_id}: Unexpected error during enhancement: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        # Ensure cleanup happens even on unexpected errors (Cleanup is removed, but keep release_lock)
        # We need to potentially clean up the created input directory as well if saving failed
        # For simplicity now, inputs/outputs persist on errors too, consistent with success path
        release_lock()
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
|
340 |
+
|
341 |
+
|
342 |
+
@app.get("/status/")
async def status():
    """Report API health, the processing-lock state, and host resource usage.

    Returns a JSON-serializable dict; the directory/script existence flags
    help diagnose misconfigured deployments.
    """
    logger.info("Status check requested.")
    # Sample virtual memory once so `percent` and `available` come from the
    # same snapshot (the original called psutil.virtual_memory() twice).
    mem = psutil.virtual_memory()
    return {
        "status": "ok" if not processing_lock else "busy",
        "processing_active": processing_lock,
        "available_models": available_models,
        "memory_usage": {
            "percent": f"{mem.percent}%",
            "available": f"{mem.available / (1024**3):.2f} GB",
        },
        "cpu_usage": f"{psutil.cpu_percent()}%",
        "real_esrgan_dir_exists": REAL_ESRGAN_DIR.exists(),
        "inference_script_exists": INFERENCE_SCRIPT.exists(),
        "model_dir_exists": MODEL_DIR.exists(),
        "input_dir_exists": INPUT_DIR.exists(),
        "output_dir_exists": OUTPUT_DIR.exists(),
    }
|
361 |
+
|
362 |
+
# --- Server Execution ---
|
363 |
+
if __name__ == "__main__":
    logger.info(f"Starting Image Enhancer API server on port {API_PORT}...")
    # Log the resolved paths once at startup to ease deployment debugging.
    startup_paths = (
        ("Real-ESRGAN Directory", REAL_ESRGAN_DIR),
        ("Inference Script", INFERENCE_SCRIPT),
        ("Model Directory", MODEL_DIR),
        ("API Input Directory", INPUT_DIR),
        ("API Output Directory", OUTPUT_DIR),
    )
    for label, path in startup_paths:
        logger.info(f"{label}: {path}")
    update_available_models()  # Ensure models are listed on startup
    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=API_PORT,
        reload=False,  # Use reload carefully, can cause issues with locking/models
        log_level="info"  # Uvicorn's own log level
    )
|
app.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import requests
|
3 |
+
from PIL import Image
|
4 |
+
import io
|
5 |
+
import time
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
# --- Configuration ---
|
9 |
+
API_URL = "http://localhost:8000" # Keep API port
|
10 |
+
PAGE_TITLE = "Image Enhancer App"
|
11 |
+
PAGE_ICON = "✨"
|
12 |
+
# Define fixed model options and default
|
13 |
+
AVAILABLE_MODELS = ["RealESRGAN_x4plus", "RealESRGAN_x2plus"]
|
14 |
+
DEFAULT_MODEL = "RealESRGAN_x4plus"
|
15 |
+
# Define fixed scale options and default
|
16 |
+
AVAILABLE_SCALES = [1, 2, 4, 8]
|
17 |
+
DEFAULT_SCALE = 4
|
18 |
+
|
19 |
+
# --- Initialize Session State ---
|
20 |
+
# Use get to avoid errors if keys don't exist yet after a code change/refresh
|
21 |
+
# Seed every session-state key with a None default so reruns after a code
# change or browser refresh never hit a missing-key error.
st.session_state.setdefault('enhanced_image_data', None)
st.session_state.setdefault('enhanced_image_caption', None)
st.session_state.setdefault('download_filename', None)
st.session_state.setdefault('download_mime', None)
st.session_state.setdefault('current_file_identifier', None)
st.session_state.setdefault('error_message', None)

# --- Streamlit Page Setup ---
# Must be the first st.* rendering call on the page.
st.set_page_config(
    page_title=PAGE_TITLE,
    page_icon=PAGE_ICON,
    layout="wide",
)

st.title(PAGE_TITLE)
st.markdown("""
Enhance your images using the power of AI upscaling.
Upload an image and choose your enhancement options.
""")
|
40 |
+
|
41 |
+
# --- Helper Functions ---
|
42 |
+
# @st.cache_data # Cache the status check result for a short time? Maybe not needed.
|
43 |
+
def get_api_status():
    """Query the backend's /status/ endpoint.

    Returns:
        The decoded JSON payload on HTTP 200, or None when the API is
        unreachable or returns a non-200 status (the problem is reported
        to the user via st.error in either case).
    """
    try:
        resp = requests.get(f"{API_URL}/status/", timeout=5)
        # Treat anything other than 200 as a failure and surface it.
        if resp.status_code != 200:
            st.error(f"API Error: Status code {resp.status_code}")
            return None
        return resp.json()
    except requests.exceptions.RequestException as e:
        # Covers connection errors, timeouts, and (in recent requests
        # versions) JSON decoding failures from resp.json().
        st.error(f"API Connection Error: {e}")
        return None
|
55 |
+
|
56 |
+
# No longer fetching models from API for selection, using fixed list
|
57 |
+
# def get_available_models():
|
58 |
+
# """Fetches the list of available models from the API."""
|
59 |
+
# try:
|
60 |
+
# response = requests.get(f"{API_URL}/models/", timeout=5)
|
61 |
+
# if response.status_code == 200:
|
62 |
+
# return response.json()
|
63 |
+
# else:
|
64 |
+
# st.warning(f"Could not fetch models (Status: {response.status_code}). Using defaults.")
|
65 |
+
# return ["RealESRGAN_x4plus", "RealESRGAN_x2plus"] # Fallback
|
66 |
+
# except requests.exceptions.RequestException as e:
|
67 |
+
# st.warning(f"Could not fetch models (Error: {e}). Using defaults.")
|
68 |
+
# return ["RealESRGAN_x4plus", "RealESRGAN_x2plus"] # Fallback
|
69 |
+
|
70 |
+
# --- Sidebar Controls ---
# Builds the enhancement-option widgets and an on-demand API health check.
# The widget values (selected_model, output_scale, face_enhance, use_fp32)
# are read later by the main-area enhancement request.
with st.sidebar:
    st.header("Enhancement Options")

    # Model selection using fixed list
    # Ensure default is selected if available, otherwise fallback to first item
    default_model_index = AVAILABLE_MODELS.index(DEFAULT_MODEL) if DEFAULT_MODEL in AVAILABLE_MODELS else 0
    selected_model = st.selectbox(
        "Select Model",
        AVAILABLE_MODELS,
        index=default_model_index
    )

    # Scale factor selection using fixed list
    default_scale_index = AVAILABLE_SCALES.index(DEFAULT_SCALE) if DEFAULT_SCALE in AVAILABLE_SCALES else 0
    output_scale = st.selectbox(
        "Output Scale Factor",
        AVAILABLE_SCALES,
        index=default_scale_index
    )
    # Remove number input and auto-detect logic
    # # Determine default scale based on model name if possible
    # default_scale = 2.0 if selected_model and 'x2' in selected_model else 4.0
    # output_scale = st.number_input("Output Scale Factor", min_value=1.0, max_value=8.0, value=default_scale, step=0.1)

    # Checkboxes for boolean flags (forwarded verbatim to the API payload)
    face_enhance = st.checkbox("Enable Face Enhancement (GFPGAN)", value=False)
    use_fp32 = st.checkbox("Use FP32 Precision (Slower, More Memory)", value=False)

    st.markdown("---")
    st.header("API Status")
    if st.button("Check API Status"):
        status_info = get_api_status()
        if status_info:
            # Expected payload keys: "status", "available_models", "memory_usage".
            status_text = status_info.get("status", "unknown")
            if status_text == "ok":
                st.success("✅ API is running and ready.")
            elif status_text == "busy":
                st.warning("⏳ API is currently busy processing.")
            else:
                st.error(f"❌ API reported status: {status_text}.")

            # Display API-reported models for confirmation
            api_models = status_info.get("available_models", [])
            if api_models:
                st.write(f"**API Models Found:** {', '.join(api_models)}")
                # Check if selected model is actually available according to API
                if selected_model not in api_models:
                    st.warning(f"Selected model '{selected_model}' not found by API!")
            else:
                st.warning("Could not verify available models from API.")

            # NOTE(review): assumes memory_usage is a dict with a "percent"
            # entry — confirm against the API's /status/ response schema.
            if "memory_usage" in status_info:
                st.write(f"**Memory:** {status_info['memory_usage'].get('percent', 'N/A')}")
    # Error handling is done within get_api_status
|
125 |
+
|
126 |
+
# --- Main Area ---
# Upload → (optional) enhance via POST /enhance/ → display + download.
# Results and errors live in st.session_state so they survive Streamlit
# reruns (e.g. clicking the download button) until a new file is uploaded.

# File uploader
uploaded_file = st.file_uploader(
    "Choose an image to enhance...", type=["jpg", "jpeg", "png", "bmp", "webp"]
)

if uploaded_file is not None:
    # Use name + size as a relatively stable identifier across uploads
    current_file_identifier = f"{uploaded_file.name}-{uploaded_file.size}"

    # --- Reset state if a new file is uploaded ---
    if current_file_identifier != st.session_state.get('current_file_identifier'):
        st.session_state.enhanced_image_data = None
        st.session_state.enhanced_image_caption = None
        st.session_state.download_filename = None
        st.session_state.download_mime = None
        st.session_state.error_message = None # Clear previous errors
        st.session_state.current_file_identifier = current_file_identifier

    # Display the original image
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original Image")
        try:
            original_image = Image.open(uploaded_file)
            st.image(original_image, use_column_width=True, caption="Original")

            # --- Moved Enhance Button Here ---
            # The flag is read below in col2 on the same script run.
            enhance_button_pressed = st.button("Enhance Image ✨")

        except Exception as e:
            st.error(f"Error loading image: {e}")
            # Clear state if original image load fails
            st.session_state.enhanced_image_data = None
            st.session_state.error_message = None
            st.session_state.current_file_identifier = None
            enhance_button_pressed = False # Ensure button state is false if image fails
            st.stop() # Stop execution if image can't be loaded

    # Process image button and display enhanced result area
    with col2:
        st.subheader("Enhanced Result")

        # Use a container within col2 for dynamic content (spinner, result)
        # Button is now outside this container, in col1
        result_container = st.container()

        with result_container:
            # Trigger processing logic if button in col1 was pressed
            if enhance_button_pressed:
                st.session_state.enhanced_image_data = None # Clear previous result before trying again
                st.session_state.error_message = None # Clear previous errors
                # Show spinner *within the container*
                with st.spinner("Enhancing your image... This might take a moment."):
                    try:
                        # Prepare form data for the API request
                        # Use getvalue() to read file content for the request
                        files = {"file": (uploaded_file.name, uploaded_file.getvalue(), uploaded_file.type)}
                        payload = {
                            "model_name": selected_model,
                            "outscale": float(output_scale),
                            "face_enhance": face_enhance,
                            "fp32": use_fp32
                        }

                        # Send request to API
                        # (long timeout — upscaling large images can be slow)
                        start_time = time.time()
                        response = requests.post(f"{API_URL}/enhance/", files=files, data=payload, timeout=300)
                        end_time = time.time()

                        if response.status_code == 200:
                            # --- Store results in session state on success ---
                            # response.content is the raw enhanced image bytes
                            # (decoded below with PIL for display).
                            st.session_state.enhanced_image_data = response.content
                            st.session_state.enhanced_image_caption = f"Enhanced ({end_time - start_time:.2f}s)"

                            # Prepare filename for download
                            base, ext = Path(uploaded_file.name).stem, Path(uploaded_file.name).suffix
                            download_filename = f"{base}_enhanced_s{int(output_scale)}x{ext if ext else '.png'}" # Use int scale
                            st.session_state.download_filename = download_filename
                            st.session_state.download_mime = response.headers.get("content-type", "image/png")
                            st.session_state.error_message = None # Clear error on success

                        else:
                            # Store error details from JSON response
                            try:
                                error_details = response.json().get('detail', 'Unknown API error')
                            except requests.exceptions.JSONDecodeError:
                                error_details = response.text # Fallback to raw text
                            st.session_state.error_message = f"API Error (Status {response.status_code}): {error_details}"
                            st.session_state.enhanced_image_data = None # Ensure data is cleared on error

                    except requests.exceptions.Timeout:
                        st.session_state.error_message = "Error: The enhancement request timed out. The process may be too long or the server might be overloaded."
                        st.session_state.enhanced_image_data = None
                    except requests.exceptions.RequestException as e:
                        st.session_state.error_message = f"Error connecting to API: {e}"
                        st.session_state.enhanced_image_data = None
                    except Exception as e:
                        st.session_state.error_message = f"An unexpected error occurred: {e}"
                        st.session_state.enhanced_image_data = None
                # Spinner automatically stops here

            # --- Display results/errors/placeholder also within the container ---
            # Display Error Message (if any)
            if st.session_state.error_message:
                st.error(st.session_state.error_message)

            # Display Enhanced Image and Download Button (if available in state)
            # Use elif to prevent showing placeholder if error exists or image exists
            elif st.session_state.enhanced_image_data is not None:
                try:
                    enhanced_image = Image.open(io.BytesIO(st.session_state.enhanced_image_data))
                    st.image(
                        enhanced_image,
                        use_column_width=True,
                        caption=st.session_state.enhanced_image_caption
                    )
                    st.download_button(
                        label="Download Enhanced Image 💾",
                        data=st.session_state.enhanced_image_data,
                        file_name=st.session_state.download_filename,
                        mime=st.session_state.download_mime
                    )
                except Exception as e:
                    st.error(f"Error displaying enhanced image: {e}")
                    # Clear state if display fails to prevent repeated errors
                    st.session_state.enhanced_image_data = None
                    st.session_state.error_message = "Failed to display the enhanced image."

            # Display placeholder only if button wasn't pressed in this run AND no image/error
            # Note: enhance_button_pressed is defined in col1 now, but its state persists across the rerun
            elif not st.session_state.get('enhanced_image_data') and not st.session_state.get('error_message'):
                st.markdown("Click 'Enhance Image ✨' (below Original) to process.")

else:
    st.info("Upload an image to begin the enhancement process.")
    # --- Clear state if no file is uploaded ---
    # Removing the file from the uploader drops the previous result as well.
    st.session_state.enhanced_image_data = None
    st.session_state.enhanced_image_caption = None
    st.session_state.download_filename = None
    st.session_state.download_mime = None
    st.session_state.current_file_identifier = None
    st.session_state.error_message = None
|
270 |
+
|
271 |
+
# --- Footer ---
# Simple attribution line rendered below the main content.
st.markdown("---")
st.markdown("Powered by Real-ESRGAN") # Removed link as requested
|
environment.yml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Conda environment for the Image Enhancer application
# (backend API + Streamlit frontend + Real-ESRGAN inference).
name: esrgan-env
channels:
  - defaults
  - conda-forge
dependencies:
  - pip
  - python=3.10
  # torch pinned to 1.11 from the pytorch channel — presumably for
  # basicsr/Real-ESRGAN compatibility; confirm before upgrading.
  - pytorch::pytorch=1.11.0
  - pytorch::torchvision
  - pip:
    # Real-ESRGAN / GFPGAN inference stack
    - opencv-python==4.11.0.86
    - PyYAML
    - tqdm
    - yapf
    - basicsr-fixed
    - facexlib
    - gfpgan

    # Web service stack: FastAPI backend (api.py) + Streamlit frontend (app.py)
    - fastapi==0.104.0
    - uvicorn==0.23.2
    - streamlit==1.27.0
    - python-multipart==0.0.6
    # psutil is used by run.py for process supervision and cleanup
    - psutil
|
run.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import webbrowser
|
6 |
+
import requests
|
7 |
+
import signal
|
8 |
+
import psutil
|
9 |
+
from pathlib import Path
|
10 |
+
import platform # Import platform module
|
11 |
+
|
12 |
+
|
13 |
+
def main():
    """Run the Image Enhancer Application by starting API and Streamlit app.

    Orchestrates the full local stack:
      1. kills any stale API/Streamlit processes,
      2. launches the API backend (api.py) and the Streamlit frontend
         (app.py) as child processes,
      3. opens the web interface in a browser,
      4. supervises both children in a loop, restarting them on crash or
         after repeated failed health checks,
    until interrupted (Ctrl+C / SIGTERM), at which point everything is
    shut down via the nested helpers below.
    """
    # Change to the directory of this script so relative paths
    # ("api.py", "app.py") resolve regardless of the caller's cwd.
    script_dir = Path(__file__).parent.resolve()
    os.chdir(script_dir)

    print("\n🚀 Starting Image Enhancer Application...\n")

    # Define URLs and Ports
    api_port = 8000
    streamlit_port = 8501
    api_url = f"http://localhost:{api_port}"
    streamlit_url = f"http://localhost:{streamlit_port}"

    # Child process handles; rebound by the start_* helpers via `nonlocal`.
    api_process = None
    streamlit_process = None

    # --- Helper: Kill Process Tree ---
    def kill_proc_tree(pid, sig=signal.SIGTERM, include_parent=True, timeout=None, on_terminate=None):
        """Kill a process tree (including grandchildren) with signal `sig` and fallback signal `signal.SIGKILL`.
        Source: https://psutil.readthedocs.io/en/latest/#kill-process-tree
        """
        assert pid != os.getpid(), "won't kill myself"
        try:
            parent = psutil.Process(pid)
            children = parent.children(recursive=True)
            if include_parent:
                children.append(parent)
            for p in children:
                try:
                    p.send_signal(sig)
                except psutil.NoSuchProcess:
                    pass
            gone, alive = psutil.wait_procs(children, timeout=timeout, callback=on_terminate)
            if alive:
                # Fallback to SIGKILL for processes that didn't terminate
                for p in alive:
                    try:
                        p.kill()
                    except psutil.NoSuchProcess:
                        pass
                psutil.wait_procs(alive, timeout=1) # Wait a bit more
        except psutil.NoSuchProcess:
            pass # Process already gone

    # --- Helper: Clean Up Existing Processes ---
    def cleanup_existing_processes():
        # Scans all processes for a previously-started API server (uvicorn
        # serving api:app on api_port) or Streamlit instance running app.py,
        # and kills their process trees so the ports are free.
        print("🧹 Checking for and cleaning up existing related processes...")
        killed_count = 0
        current_pid = os.getpid()
        for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
            if proc.info['pid'] == current_pid:
                continue # Don't kill self
            try:
                cmdline = proc.info['cmdline']
                if not cmdline: continue

                # Check for Uvicorn running api.py on the specified port
                is_api = (
                    ('uvicorn' in proc.info['name'] or 'python' in proc.info['name']) and
                    any(f'api:app' in arg for arg in cmdline) and
                    any(f'--port={api_port}' in arg or f'--port {api_port}' in arg for arg in cmdline)
                )

                # Check for Streamlit running app.py
                is_streamlit = (
                    ('streamlit' in proc.info['name'] or 'python' in proc.info['name']) and
                    any('streamlit' in arg and 'run' in arg and 'app.py' in arg for arg in cmdline)
                )

                if is_api or is_streamlit:
                    service_name = "API server" if is_api else "Streamlit app"
                    print(f" Killing existing {service_name} (PID: {proc.info['pid']}) {' '.join(cmdline)[:80]}...")
                    kill_proc_tree(proc.info['pid']) # Kill the process and its children
                    killed_count += 1

            except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
                continue # Process might have died already
            except Exception as e:
                print(f" Error checking process {proc.info['pid']}: {e}")
        if killed_count > 0:
            print(f" Killed {killed_count} existing process(es).")
        else:
            print(" No conflicting processes found.")
        time.sleep(1) # Give OS a moment to release ports

    # --- Start API Server ---
    def start_api_server():
        # Launches api.py as a child process and rebinds the outer
        # api_process handle; returns the Popen object or None on failure.
        nonlocal api_process
        print(f"🔄 Starting API server (api.py) on port {api_port}...")
        cmd = [sys.executable, "api.py"] # Assuming api.py handles its own uvicorn run
        try:
            # Use Popen for non-blocking execution
            # NOTE(review): stdout/stderr are PIPEs but are only drained via
            # communicate() after the child exits — a chatty child could fill
            # the pipe buffer and block; confirm api.py logs to api.log instead.
            api_process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                bufsize=1, # Line buffered
                universal_newlines=True,
                creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if platform.system() == "Windows" else 0 # Create new group for easier cleanup on Windows
            )
            print(f" API server process started (PID: {api_process.pid})")
        except Exception as e:
            print(f"❌ Failed to start API server: {e}")
            api_process = None
        return api_process

    # --- Check API Health ---
    def check_api_health(timeout=3):
        # Returns True only if the child process is alive AND /status
        # answers 200 with a JSON status of "ok" or "busy".
        if not api_process or api_process.poll() is not None:
            return False # Process not running
        try:
            response = requests.get(f"{api_url}/status", timeout=timeout)
            return response.status_code == 200 and response.json().get("status") in ["ok", "busy"]
        except requests.exceptions.RequestException:
            return False

    # --- Start Streamlit App ---
    def start_streamlit_app():
        # Launches the Streamlit frontend headless on streamlit_port and
        # rebinds the outer streamlit_process handle.
        nonlocal streamlit_process
        print(f"🔄 Starting Streamlit web interface (app.py) on port {streamlit_port}...")
        cmd = [sys.executable, "-m", "streamlit", "run", "app.py", f"--server.port={streamlit_port}", "--server.headless=true"]
        try:
            streamlit_process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                bufsize=1,
                universal_newlines=True,
                creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if platform.system() == "Windows" else 0
            )
            print(f" Streamlit process started (PID: {streamlit_process.pid})")
        except Exception as e:
            print(f"❌ Failed to start Streamlit app: {e}")
            streamlit_process = None
        return streamlit_process

    # --- Shutdown Services ---
    def shutdown_services(api_proc, streamlit_proc):
        # Terminates both children (SIGTERM, then SIGKILL fallback) and runs
        # a final sweep for any strays.
        print("\n🛑 Shutting down services...")
        processes_to_stop = {
            "API": api_proc,
            "Streamlit": streamlit_proc
        }
        for name, proc in processes_to_stop.items():
            if proc and proc.poll() is None:
                print(f" Stopping {name} (PID: {proc.pid})...")
                try:
                    # Use the kill_proc_tree helper for robust termination
                    kill_proc_tree(proc.pid, sig=signal.SIGTERM, timeout=3)
                    print(f" {name} stopped.")
                except Exception as e:
                    print(f" Error stopping {name} (PID: {proc.pid}): {e}. Attempting force kill.")
                    try:
                        kill_proc_tree(proc.pid, sig=signal.SIGKILL)
                    except Exception as final_e:
                        print(f" Force kill also failed for {name} (PID: {proc.pid}): {final_e}")
            elif proc:
                print(f" {name} (PID: {proc.pid}) already stopped.")
            else:
                print(f" {name} was not running.")
        # Final cleanup check
        cleanup_existing_processes()
        print("✅ Application stopped.")

    # --- Main Execution Logic ---

    # Set up graceful exit handler
    # (Closure reads the *current* api_process/streamlit_process bindings,
    # so restarts performed later are still covered.)
    def handle_exit(signum, frame):
        print("\n👋 Signal received, initiating graceful shutdown...")
        shutdown_services(api_process, streamlit_process)
        sys.exit(0)

    signal.signal(signal.SIGINT, handle_exit) # Ctrl+C
    signal.signal(signal.SIGTERM, handle_exit) # Termination signal

    try:
        # Initial cleanup before starting
        cleanup_existing_processes()

        # Start API
        api_process = start_api_server()
        if not api_process:
            raise RuntimeError("Failed to start API server, cannot continue.")

        # Wait for API to be ready
        print(" Waiting for API to become available", end="")
        api_ready = False
        for i in range(20): # Increased wait time (20 secs)
            if check_api_health():
                api_ready = True
                print("\n✅ API server is running and responding.")
                break
            print(".", end="", flush=True)
            time.sleep(1)
            # Bail out early if the child died instead of just being slow.
            if api_process.poll() is not None:
                print(f"\n❌ API process terminated unexpectedly during startup (exit code: {api_process.poll()}).")
                # Attempt to read stderr
                try:
                    _, stderr_output = api_process.communicate(timeout=1)
                    print("--- API Stderr ---")
                    print(stderr_output or "<No stderr captured>")
                    print("------------------")
                except:
                    pass
                raise RuntimeError("API process failed during startup.")

        if not api_ready:
            print("\n⚠️ API server did not become responsive within the time limit. Check api.log. Proceeding anyway...")

        # Start Streamlit
        streamlit_process = start_streamlit_app()
        if not streamlit_process:
            print("❌ Failed to start Streamlit app. You may need to start it manually.")
            # Don't raise error, maybe user only wants API

        # Give Streamlit a moment
        print(" Waiting for Streamlit to initialize...")
        time.sleep(5)
        if streamlit_process and streamlit_process.poll() is None:
            print("✅ Streamlit interface should be starting.")
        elif streamlit_process:
            print(f"❌ Streamlit process terminated unexpectedly after start (exit code: {streamlit_process.poll()}). Check logs.")
        else:
            print(" Streamlit process failed to start.")

        # Open browser if Streamlit started
        if streamlit_process and streamlit_process.poll() is None:
            try:
                print(f" Opening web interface ({streamlit_url}) in your browser...")
                webbrowser.open(streamlit_url)
            except Exception as e:
                print(f" Could not open browser automatically: {e}. Please navigate to the URL manually.")
        else:
            print(f" Streamlit not running, cannot open browser.")

        # Print URLs
        print("\n📋 Application URLs:")
        print(f" - Web Interface: {streamlit_url} (if started)")
        print(f" - API Root: {api_url}")
        print(f" - API Status: {api_url}/status")
        print(f" - API Models: {api_url}/models")

        print("\n✨ Application is running.")
        print(" Monitoring services... Press Ctrl+C to stop.")
        print(" Check api.log for API server logs.")

        # Monitoring Loop
        # Every 5s: restart either child if its process has exited; every 30s
        # additionally poll the API's /status endpoint and restart the API
        # after 3 consecutive failed health checks.
        last_health_check_time = time.time()
        consecutive_api_failures = 0
        while True:
            # Check API process
            api_status = api_process.poll() if api_process else -1
            if api_status is not None:
                print(f"\n⚠️ API server process stopped unexpectedly (exit code: {api_status}). Restarting...")
                # Attempt to read stderr
                try:
                    _, stderr_output = api_process.communicate(timeout=1)
                    print("--- API Stderr ---")
                    print(stderr_output or "<No stderr captured>")
                    print("------------------")
                except:
                    pass
                api_process = start_api_server()
                if not api_process:
                    print("❌ Failed to restart API server after crash. Exiting monitoring loop.")
                    break
                time.sleep(5) # Give it time to restart
                last_health_check_time = time.time() # Reset check timer
                consecutive_api_failures = 0
                continue # Skip rest of loop iteration

            # Check Streamlit process
            if streamlit_process:
                streamlit_status = streamlit_process.poll()
                if streamlit_status is not None:
                    print(f"\n⚠️ Streamlit process stopped unexpectedly (exit code: {streamlit_status}). Restarting...")
                    # Attempt to read stderr
                    try:
                        _, stderr_output = streamlit_process.communicate(timeout=1)
                        print("--- Streamlit Stderr ---")
                        print(stderr_output or "<No stderr captured>")
                        print("----------------------")
                    except:
                        pass
                    streamlit_process = start_streamlit_app()
                    if not streamlit_process:
                        print("❌ Failed to restart Streamlit. Will not monitor Streamlit anymore.")
                    time.sleep(5) # Give it time to restart
                    continue

            # Periodic Health Check (every 30 seconds)
            current_time = time.time()
            if current_time - last_health_check_time > 30:
                if check_api_health():
                    # print(" [Health Check] API is responsive.") # Verbose logging
                    consecutive_api_failures = 0
                else:
                    consecutive_api_failures += 1
                    print(f" [Health Check] ⚠️ API failed health check #{consecutive_api_failures}.")
                    if consecutive_api_failures >= 3:
                        print(" API unresponsive for 3 consecutive checks. Restarting API server...")
                        kill_proc_tree(api_process.pid) # Force kill unresponsive API
                        api_process = start_api_server()
                        if not api_process:
                            print("❌ Failed to restart API server after health failures. Exiting monitoring loop.")
                            break
                        time.sleep(5)
                        consecutive_api_failures = 0 # Reset counter after restart

                last_health_check_time = current_time

            # Add memory check if needed (optional)
            # try:
            #     api_mem = psutil.Process(api_process.pid).memory_info().rss / (1024 * 1024) # MB
            #     if api_mem > 1024: # Example threshold: 1GB
            #         print(f" ⚠️ API memory usage high ({api_mem:.1f} MB). Restarting...")
            #         kill_proc_tree(api_process.pid)
            #         api_process = start_api_server()
            #         # ... handle failure ...
            # except psutil.NoSuchProcess:
            #     pass # Process might have just restarted

            time.sleep(5) # Check processes every 5 seconds

    except KeyboardInterrupt:
        print("\n⌨️ Ctrl+C detected.")
    except RuntimeError as e:
        print(f"\n❌ Runtime Error: {e}")
    except Exception as e:
        print(f"\n❌ An unexpected error occurred in the run script: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Ensure services are stopped on any exit
        shutdown_services(api_process, streamlit_process)
|
349 |
+
|
350 |
+
|
351 |
+
# Script entry point: only supervise processes when run directly.
if __name__ == "__main__":
    main()
|