Spaces:
Running
Running
sd committed on
Upload 110 files
Browse files — This view is limited to 50 files because it contains too many changes.
See raw diff
- .dockerignore +15 -0
- .gitattributes +5 -35
- .gitignore +131 -0
- API_DOCUMENTATION.md +118 -0
- DOCUMENTATION.md +158 -0
- Dockerfile +39 -0
- LICENSE +35 -0
- README.md +229 -10
- app.py +358 -0
- basicsr/VERSION +1 -0
- basicsr/__init__.py +11 -0
- basicsr/archs/__init__.py +25 -0
- basicsr/archs/arcface_arch.py +245 -0
- basicsr/archs/arch_util.py +318 -0
- basicsr/archs/codeformer_arch.py +280 -0
- basicsr/archs/rrdbnet_arch.py +119 -0
- basicsr/archs/vgg_arch.py +161 -0
- basicsr/archs/vqgan_arch.py +434 -0
- basicsr/data/__init__.py +100 -0
- basicsr/data/data_sampler.py +48 -0
- basicsr/data/data_util.py +392 -0
- basicsr/data/ffhq_blind_dataset.py +299 -0
- basicsr/data/ffhq_blind_joint_dataset.py +324 -0
- basicsr/data/gaussian_kernels.py +690 -0
- basicsr/data/paired_image_dataset.py +101 -0
- basicsr/data/prefetch_dataloader.py +125 -0
- basicsr/data/transforms.py +165 -0
- basicsr/losses/__init__.py +26 -0
- basicsr/losses/loss_util.py +95 -0
- basicsr/losses/losses.py +455 -0
- basicsr/metrics/__init__.py +19 -0
- basicsr/metrics/metric_util.py +45 -0
- basicsr/metrics/psnr_ssim.py +128 -0
- basicsr/models/__init__.py +30 -0
- basicsr/models/base_model.py +322 -0
- basicsr/models/codeformer_idx_model.py +220 -0
- basicsr/models/codeformer_joint_model.py +350 -0
- basicsr/models/codeformer_model.py +332 -0
- basicsr/models/lr_scheduler.py +96 -0
- basicsr/models/sr_model.py +209 -0
- basicsr/models/vqgan_model.py +285 -0
- basicsr/ops/__init__.py +0 -0
- basicsr/ops/dcn/__init__.py +7 -0
- basicsr/ops/dcn/deform_conv.py +377 -0
- basicsr/ops/dcn/src/deform_conv_cuda.cpp +685 -0
- basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu +867 -0
- basicsr/ops/dcn/src/deform_conv_ext.cpp +164 -0
- basicsr/ops/fused_act/__init__.py +3 -0
- basicsr/ops/fused_act/fused_act.py +89 -0
- basicsr/ops/fused_act/src/fused_bias_act.cpp +26 -0
.dockerignore
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.gitignore
|
| 3 |
+
__pycache__
|
| 4 |
+
*.pyc
|
| 5 |
+
*.pyo
|
| 6 |
+
*.pyd
|
| 7 |
+
.DS_Store
|
| 8 |
+
weights/
|
| 9 |
+
results/
|
| 10 |
+
inputs/cropped_faces/
|
| 11 |
+
inputs/gray_faces/
|
| 12 |
+
inputs/masked_faces/
|
| 13 |
+
inputs/whole_imgs/
|
| 14 |
+
output/
|
| 15 |
+
web-demos/
|
.gitattributes
CHANGED
|
@@ -1,35 +1,5 @@
|
|
| 1 |
-
*.
|
| 2 |
-
*.
|
| 3 |
-
*.
|
| 4 |
-
*.
|
| 5 |
-
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
weights/facelib/tmpz5esw78c filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.vscode
|
| 2 |
+
|
| 3 |
+
# ignored files
|
| 4 |
+
version.py
|
| 5 |
+
|
| 6 |
+
# ignored files with suffix
|
| 7 |
+
*.html
|
| 8 |
+
*.png
|
| 9 |
+
*.jpeg
|
| 10 |
+
*.jpg
|
| 11 |
+
*.pt
|
| 12 |
+
*.gif
|
| 13 |
+
*.pth
|
| 14 |
+
*.dat
|
| 15 |
+
*.zip
|
| 16 |
+
|
| 17 |
+
# template
|
| 18 |
+
|
| 19 |
+
# Byte-compiled / optimized / DLL files
|
| 20 |
+
__pycache__/
|
| 21 |
+
*.py[cod]
|
| 22 |
+
*$py.class
|
| 23 |
+
|
| 24 |
+
# C extensions
|
| 25 |
+
*.so
|
| 26 |
+
|
| 27 |
+
# Distribution / packaging
|
| 28 |
+
.Python
|
| 29 |
+
build/
|
| 30 |
+
develop-eggs/
|
| 31 |
+
dist/
|
| 32 |
+
downloads/
|
| 33 |
+
eggs/
|
| 34 |
+
.eggs/
|
| 35 |
+
lib/
|
| 36 |
+
lib64/
|
| 37 |
+
parts/
|
| 38 |
+
sdist/
|
| 39 |
+
var/
|
| 40 |
+
wheels/
|
| 41 |
+
*.egg-info/
|
| 42 |
+
.installed.cfg
|
| 43 |
+
*.egg
|
| 44 |
+
MANIFEST
|
| 45 |
+
|
| 46 |
+
# PyInstaller
|
| 47 |
+
# Usually these files are written by a python script from a template
|
| 48 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 49 |
+
*.manifest
|
| 50 |
+
*.spec
|
| 51 |
+
|
| 52 |
+
# Installer logs
|
| 53 |
+
pip-log.txt
|
| 54 |
+
pip-delete-this-directory.txt
|
| 55 |
+
|
| 56 |
+
# Unit test / coverage reports
|
| 57 |
+
htmlcov/
|
| 58 |
+
.tox/
|
| 59 |
+
.coverage
|
| 60 |
+
.coverage.*
|
| 61 |
+
.cache
|
| 62 |
+
nosetests.xml
|
| 63 |
+
coverage.xml
|
| 64 |
+
*.cover
|
| 65 |
+
.hypothesis/
|
| 66 |
+
.pytest_cache/
|
| 67 |
+
|
| 68 |
+
# Translations
|
| 69 |
+
*.mo
|
| 70 |
+
*.pot
|
| 71 |
+
|
| 72 |
+
# Django stuff:
|
| 73 |
+
*.log
|
| 74 |
+
local_settings.py
|
| 75 |
+
db.sqlite3
|
| 76 |
+
|
| 77 |
+
# Flask stuff:
|
| 78 |
+
instance/
|
| 79 |
+
.webassets-cache
|
| 80 |
+
|
| 81 |
+
# Scrapy stuff:
|
| 82 |
+
.scrapy
|
| 83 |
+
|
| 84 |
+
# Sphinx documentation
|
| 85 |
+
docs/_build/
|
| 86 |
+
|
| 87 |
+
# PyBuilder
|
| 88 |
+
target/
|
| 89 |
+
|
| 90 |
+
# Jupyter Notebook
|
| 91 |
+
.ipynb_checkpoints
|
| 92 |
+
|
| 93 |
+
# pyenv
|
| 94 |
+
.python-version
|
| 95 |
+
|
| 96 |
+
# celery beat schedule file
|
| 97 |
+
celerybeat-schedule
|
| 98 |
+
|
| 99 |
+
# SageMath parsed files
|
| 100 |
+
*.sage.py
|
| 101 |
+
|
| 102 |
+
# Environments
|
| 103 |
+
.env
|
| 104 |
+
.venv
|
| 105 |
+
env/
|
| 106 |
+
venv/
|
| 107 |
+
ENV/
|
| 108 |
+
env.bak/
|
| 109 |
+
venv.bak/
|
| 110 |
+
|
| 111 |
+
# Spyder project settings
|
| 112 |
+
.spyderproject
|
| 113 |
+
.spyproject
|
| 114 |
+
|
| 115 |
+
# Rope project settings
|
| 116 |
+
.ropeproject
|
| 117 |
+
|
| 118 |
+
# mkdocs documentation
|
| 119 |
+
/site
|
| 120 |
+
|
| 121 |
+
# mypy
|
| 122 |
+
.mypy_cache/
|
| 123 |
+
|
| 124 |
+
# project
|
| 125 |
+
results/
|
| 126 |
+
experiments/
|
| 127 |
+
tb_logger/
|
| 128 |
+
run.sh
|
| 129 |
+
*debug*
|
| 130 |
+
*_old*
|
| 131 |
+
|
API_DOCUMENTATION.md
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CodeFormer API Documentation
|
| 2 |
+
|
| 3 |
+
This document describes the programmatic interface for the CodeFormer Face Restoration service.
|
| 4 |
+
|
| 5 |
+
## Base URL
|
| 6 |
+
The API is accessible at:
|
| 7 |
+
`https://esmailx50-job.hf.space` (or your specific Hugging Face Space URL)
|
| 8 |
+
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
## 1. Process Images
|
| 12 |
+
Processes one or more images for face restoration and enhancement.
|
| 13 |
+
|
| 14 |
+
- **Endpoint:** `/api/process`
|
| 15 |
+
- **Method:** `POST`
|
| 16 |
+
- **Consumes:** `multipart/form-data` OR `application/json`
|
| 17 |
+
|
| 18 |
+
### Parameters
|
| 19 |
+
| Parameter | Type | Default | Description |
|
| 20 |
+
| :--- | :--- | :--- | :--- |
|
| 21 |
+
| `fidelity` | float | `0.5` | Fidelity weight ($w$). Range [0, 1]. Lower is more "hallucinated" detail, higher is more identity preservation. |
|
| 22 |
+
| `upscale` | int | `2` | Final upscaling factor. Supported: `1`, `2`, `4`. |
|
| 23 |
+
| `background_enhance` | bool | `false` | Enhance the background using Real-ESRGAN. |
|
| 24 |
+
| `face_upsample` | bool | `false` | Upsample restored faces using Real-ESRGAN. |
|
| 25 |
+
| `return_base64` | bool | `false` | If true, includes the processed image as a base64 string in the JSON response. |
|
| 26 |
+
|
| 27 |
+
### Input Formats
|
| 28 |
+
|
| 29 |
+
#### A. Multipart Form Data (`multipart/form-data`)
|
| 30 |
+
Useful for uploading files directly.
|
| 31 |
+
- `image`: One or more image files (as a list).
|
| 32 |
+
- Other parameters as form fields.
|
| 33 |
+
|
| 34 |
+
**Example (curl):**
|
| 35 |
+
```bash
|
| 36 |
+
curl -X POST \
|
| 37 |
+
-F "image=@my_photo.jpg" \
|
| 38 |
+
-F "fidelity=0.7" \
|
| 39 |
+
-F "background_enhance=true" \
|
| 40 |
+
https://esmailx50-job.hf.space/api/process
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
#### B. JSON (`application/json`)
|
| 44 |
+
Useful for sending base64-encoded image data.
|
| 45 |
+
- `image_base64`: A single base64 string (with or without data URI prefix).
|
| 46 |
+
- `images_base64`: (Optional) A list of base64 strings for batch processing.
|
| 47 |
+
- Other parameters as JSON keys.
|
| 48 |
+
|
| 49 |
+
**Example (curl):**
|
| 50 |
+
```bash
|
| 51 |
+
curl -X POST \
|
| 52 |
+
-H "Content-Type: application/json" \
|
| 53 |
+
-d '{
|
| 54 |
+
"image_base64": "data:image/png;base64,iVBORw0KG...",
|
| 55 |
+
"fidelity": 0.5,
|
| 56 |
+
"return_base64": true
|
| 57 |
+
}' \
|
| 58 |
+
https://esmailx50-job.hf.space/api/process
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
### Success Response
|
| 62 |
+
```json
|
| 63 |
+
{
|
| 64 |
+
"status": "success",
|
| 65 |
+
"count": 1,
|
| 66 |
+
"results": [
|
| 67 |
+
{
|
| 68 |
+
"original_name": "image.png",
|
| 69 |
+
"filename": "api_result_uuid.png",
|
| 70 |
+
"image_url": "https://.../static/results/api_result_uuid.png",
|
| 71 |
+
"image_base64": "iVBORw0KG..." // Only if return_base64 was true
|
| 72 |
+
}
|
| 73 |
+
]
|
| 74 |
+
}
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
### Error Response
|
| 78 |
+
```json
|
| 79 |
+
{
|
| 80 |
+
"status": "error",
|
| 81 |
+
"message": "Detailed error message here"
|
| 82 |
+
}
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
---
|
| 86 |
+
|
| 87 |
+
## 2. Health Check
|
| 88 |
+
Checks if the service is online and returns the compute device being used.
|
| 89 |
+
|
| 90 |
+
- **Endpoint:** `/api/health`
|
| 91 |
+
- **Method:** `GET`
|
| 92 |
+
|
| 93 |
+
**Success Response:**
|
| 94 |
+
```json
|
| 95 |
+
{
|
| 96 |
+
"status": "online",
|
| 97 |
+
"device": "cuda" // or "cpu"
|
| 98 |
+
}
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
---
|
| 102 |
+
|
| 103 |
+
## CORS & Integration
|
| 104 |
+
Cross-Origin Resource Sharing (CORS) is enabled for all routes. This allows you to call the API directly from browser-based applications (React, Vue, etc.) without hitting "Same-Origin Policy" blocks.
|
| 105 |
+
|
| 106 |
+
**Javascript Example (Fetch):**
|
| 107 |
+
```javascript
|
| 108 |
+
const formData = new FormData();
|
| 109 |
+
formData.append('image', fileInput.files[0]);
|
| 110 |
+
formData.append('fidelity', '0.5');
|
| 111 |
+
|
| 112 |
+
const response = await fetch('https://esmailx50-job.hf.space/api/process', {
|
| 113 |
+
method: 'POST',
|
| 114 |
+
body: formData
|
| 115 |
+
});
|
| 116 |
+
const data = await response.json();
|
| 117 |
+
console.log(data.results[0].image_url);
|
| 118 |
+
```
|
DOCUMENTATION.md
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CodeFormer Face Restoration - Project Documentation
|
| 2 |
+
|
| 3 |
+
## 1. Introduction
|
| 4 |
+
|
| 5 |
+
**CodeFormer** is a robust blind face restoration algorithm designed to restore old, degraded, or AI-generated face images. It utilizes a **Codebook Lookup Transformer** (VQGAN-based) to predict high-quality facial features even from severe degradation, ensuring that the restored faces look natural and faithful to the original identity.
|
| 6 |
+
|
| 7 |
+
This project wraps the core CodeFormer research code into a deployable, user-friendly **Flask Web Application**, containerized with **Docker** for easy deployment on platforms like Hugging Face Spaces.
|
| 8 |
+
|
| 9 |
+
### Key Features
|
| 10 |
+
* **Blind Face Restoration:** Restores faces from low-quality inputs without knowing the specific degradation details.
|
| 11 |
+
* **Background Enhancement:** Uses **Real-ESRGAN** to upscale and enhance the non-face background regions of the image.
|
| 12 |
+
* **Face Alignment & Paste-back:** Automatically detects faces, aligns them for processing, and seamlessly blends them back into the original image.
|
| 13 |
+
* **Adjustable Fidelity:** Users can balance between restoration quality (hallucinating details) and identity fidelity (keeping the original look).
|
| 14 |
+
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
## 2. System Architecture
|
| 18 |
+
|
| 19 |
+
The application is built on a Python/PyTorch backend served via Flask.
|
| 20 |
+
|
| 21 |
+
### 2.1 Technology Stack
|
| 22 |
+
* **Framework:** Flask (Python Web Server)
|
| 23 |
+
* **Deep Learning:** PyTorch, TorchVision
|
| 24 |
+
* **Image Processing:** OpenCV, NumPy, Pillow
|
| 25 |
+
* **Core Libraries:** `basicsr` (Basic Super-Restoration), `facelib` (Face detection/utils)
|
| 26 |
+
* **Frontend:** HTML5, Bootstrap 5, Jinja2 Templates
|
| 27 |
+
* **Containerization:** Docker (CUDA-enabled)
|
| 28 |
+
|
| 29 |
+
### 2.2 Directory Structure
|
| 30 |
+
```
|
| 31 |
+
CodeFormer/
|
| 32 |
+
├── app.py # Main Flask application entry point
|
| 33 |
+
├── Dockerfile # Container configuration
|
| 34 |
+
├── requirements.txt # Python dependencies
|
| 35 |
+
├── basicsr/ # Core AI framework (Super-Resolution tools)
|
| 36 |
+
├── facelib/ # Face detection and alignment utilities
|
| 37 |
+
├── templates/ # HTML Frontend
|
| 38 |
+
│ ├── index.html # Upload interface
|
| 39 |
+
│ └── result.html # Results display
|
| 40 |
+
├── static/ # Static assets (css, js, uploads)
|
| 41 |
+
│ ├── uploads/ # Temporary storage for input images
|
| 42 |
+
│ └── results/ # Temporary storage for processed output
|
| 43 |
+
└── weights/ # Pre-trained model weights (downloaded on startup)
|
| 44 |
+
├── CodeFormer/ # CodeFormer model (.pth)
|
| 45 |
+
├── facelib/ # Detection (RetinaFace) and Parsing models
|
| 46 |
+
└── realesrgan/ # Background upscaler (Real-ESRGAN)
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
### 2.3 Logic Flow
|
| 50 |
+
1. **Input:** User uploads an image via the Web UI.
|
| 51 |
+
2. **Pre-processing (`app.py`):**
|
| 52 |
+
* Image is saved to `static/uploads`.
|
| 53 |
+
* Parameters (fidelity, upscale factor) are parsed.
|
| 54 |
+
3. **Inference Pipeline:**
|
| 55 |
+
* **Detection:** `facelib` detects faces in the image using RetinaFace.
|
| 56 |
+
* **Alignment:** Faces are cropped and aligned to a standard 512x512 resolution.
|
| 57 |
+
* **Restoration:** The **CodeFormer** model processes the aligned faces.
|
| 58 |
+
* **Upscaling (Optional):** The background is upscaled using **Real-ESRGAN**.
|
| 59 |
+
* **Paste-back:** Restored faces are warped back to their original positions and blended.
|
| 60 |
+
4. **Output:** The final image is saved to `static/results` and displayed to the user.
|
| 61 |
+
|
| 62 |
+
---
|
| 63 |
+
|
| 64 |
+
## 3. Installation & Deployment
|
| 65 |
+
|
| 66 |
+
### 3.1 Docker Deployment (Recommended)
|
| 67 |
+
The project is optimized for Docker.
|
| 68 |
+
|
| 69 |
+
**Prerequisites:** Docker, NVIDIA GPU (optional, but recommended).
|
| 70 |
+
|
| 71 |
+
1. **Build the Image:**
|
| 72 |
+
```bash
|
| 73 |
+
docker build -t codeformer-app .
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
2. **Run the Container:**
|
| 77 |
+
```bash
|
| 78 |
+
# Run on port 7860 (Standard for HF Spaces)
|
| 79 |
+
docker run -it -p 7860:7860 codeformer-app
|
| 80 |
+
```
|
| 81 |
+
*Note: To use GPU, add the `--gpus all` flag to the run command.*
|
| 82 |
+
|
| 83 |
+
### 3.2 Hugging Face Spaces Deployment
|
| 84 |
+
This repository is configured for direct deployment to Hugging Face.
|
| 85 |
+
|
| 86 |
+
1. Create a **Docker** Space on Hugging Face.
|
| 87 |
+
2. Push this entire repository to the Space's Git remote.
|
| 88 |
+
```bash
|
| 89 |
+
git remote add hf git@hf.co:spaces/USERNAME/SPACE_NAME
|
| 90 |
+
git push hf main
|
| 91 |
+
```
|
| 92 |
+
3. The Space will build (approx. 5-10 mins) and launch automatically.
|
| 93 |
+
|
| 94 |
+
### 3.3 Local Development
|
| 95 |
+
1. **Install Environment:**
|
| 96 |
+
```bash
|
| 97 |
+
conda create -n codeformer python=3.8
|
| 98 |
+
conda activate codeformer
|
| 99 |
+
pip install -r requirements.txt
|
| 100 |
+
```
|
| 101 |
+
2. **Install Basicsr:**
|
| 102 |
+
```bash
|
| 103 |
+
python basicsr/setup.py install
|
| 104 |
+
```
|
| 105 |
+
3. **Run App:**
|
| 106 |
+
```bash
|
| 107 |
+
python app.py
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
---
|
| 111 |
+
|
| 112 |
+
## 4. User Guide (Web Interface)
|
| 113 |
+
|
| 114 |
+
### 4.1 Interface Controls
|
| 115 |
+
|
| 116 |
+
* **Input Image:** Supports standard formats (JPG, PNG, WEBP). Drag and drop supported.
|
| 117 |
+
* **Fidelity Weight (w):**
|
| 118 |
+
* **Range:** 0.0 to 1.0.
|
| 119 |
+
* **0.0 (Better Quality):** The model "hallucinates" more details. Results look very sharp and high-quality but may slightly alter the person's identity (look less like the original).
|
| 120 |
+
* **1.0 (Better Identity):** The model sticks strictly to the original features. Results are faithful to the original photo but might be blurrier or contain more artifacts.
|
| 121 |
+
* **Recommended:** 0.5 is a balanced default.
|
| 122 |
+
* **Upscale Factor:**
|
| 123 |
+
* Scales the final output resolution (1x, 2x, or 4x).
|
| 124 |
+
* *Note: Higher scaling requires more VRAM.*
|
| 125 |
+
* **Enhance Background:**
|
| 126 |
+
* If checked, runs Real-ESRGAN on the non-face areas.
|
| 127 |
+
* *Recommendation:* Keep checked for full-photo restoration. Uncheck if you only care about the face or are running on limited hardware.
|
| 128 |
+
* **Upsample Face:**
|
| 129 |
+
* If checked, the restored face is also upsampled to match the background resolution.
|
| 130 |
+
|
| 131 |
+
### 4.2 Viewing Results
|
| 132 |
+
The result page features an interactive **Before/After Slider**. Drag the handle left and right to compare the pixels of the original versus the restored image directly.
|
| 133 |
+
|
| 134 |
+
---
|
| 135 |
+
|
| 136 |
+
## 5. Technical Details
|
| 137 |
+
|
| 138 |
+
### 5.1 Model Weights
|
| 139 |
+
The application automatically checks for and downloads the following weights to the `weights/` directory on startup:
|
| 140 |
+
|
| 141 |
+
| Model | Path | Description |
|
| 142 |
+
| :--- | :--- | :--- |
|
| 143 |
+
| **CodeFormer** | `weights/CodeFormer/codeformer.pth` | Main restoration model. |
|
| 144 |
+
| **RetinaFace** | `weights/facelib/detection_Resnet50_Final.pth` | Face detection. |
|
| 145 |
+
| **ParseNet** | `weights/facelib/parsing_parsenet.pth` | Face parsing (segmentation). |
|
| 146 |
+
| **Real-ESRGAN** | `weights/realesrgan/RealESRGAN_x2plus.pth` | Background upscaler (x2). |
|
| 147 |
+
|
| 148 |
+
### 5.2 Performance Notes
|
| 149 |
+
* **Memory:** The full pipeline (CodeFormer + Real-ESRGAN) requires significant RAM/VRAM. On CPU-only environments (like basic HF Spaces), processing a single image may take 30-60 seconds.
|
| 150 |
+
* **Git LFS:** Image assets in this repository are tracked with Git LFS to keep the repo size manageable.
|
| 151 |
+
|
| 152 |
+
---
|
| 153 |
+
|
| 154 |
+
## 6. Credits & References
|
| 155 |
+
|
| 156 |
+
* **Original Paper:** [Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022)](https://arxiv.org/abs/2206.11253)
|
| 157 |
+
* **Authors:** Shangchen Zhou, Kelvin C.K. Chan, Chongyi Li, Chen Change Loy (S-Lab, Nanyang Technological University).
|
| 158 |
+
* **Original Repository:** [sczhou/CodeFormer](https://github.com/sczhou/CodeFormer)
|
Dockerfile
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
|
| 2 |
+
|
| 3 |
+
WORKDIR /code
|
| 4 |
+
|
| 5 |
+
# Install system dependencies
|
| 6 |
+
RUN apt-get update && apt-get install -y \
|
| 7 |
+
libgl1 \
|
| 8 |
+
libglib2.0-0 \
|
| 9 |
+
git \
|
| 10 |
+
ninja-build \
|
| 11 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 12 |
+
|
| 13 |
+
# Copy requirements
|
| 14 |
+
COPY requirements.txt .
|
| 15 |
+
|
| 16 |
+
# Install python dependencies
|
| 17 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 18 |
+
|
| 19 |
+
# Copy application code
|
| 20 |
+
COPY . .
|
| 21 |
+
|
| 22 |
+
# Create necessary directories and set permissions
|
| 23 |
+
RUN mkdir -p weights inputs output static && \
|
| 24 |
+
chmod 777 weights inputs output static
|
| 25 |
+
|
| 26 |
+
# Install basicsr (build extensions in-place)
|
| 27 |
+
RUN python basicsr/setup.py build_ext --inplace
|
| 28 |
+
|
| 29 |
+
# Create a non-root user and switch to it
|
| 30 |
+
RUN useradd -m -u 1000 user
|
| 31 |
+
USER user
|
| 32 |
+
ENV HOME=/home/user \
|
| 33 |
+
PATH=/home/user/.local/bin:$PATH
|
| 34 |
+
|
| 35 |
+
WORKDIR /code
|
| 36 |
+
|
| 37 |
+
EXPOSE 7860
|
| 38 |
+
|
| 39 |
+
CMD ["python", "app.py"]
|
LICENSE
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
S-Lab License 1.0
|
| 2 |
+
|
| 3 |
+
Copyright 2022 S-Lab
|
| 4 |
+
|
| 5 |
+
Redistribution and use for non-commercial purpose in source and
|
| 6 |
+
binary forms, with or without modification, are permitted provided
|
| 7 |
+
that the following conditions are met:
|
| 8 |
+
|
| 9 |
+
1. Redistributions of source code must retain the above copyright
|
| 10 |
+
notice, this list of conditions and the following disclaimer.
|
| 11 |
+
|
| 12 |
+
2. Redistributions in binary form must reproduce the above copyright
|
| 13 |
+
notice, this list of conditions and the following disclaimer in
|
| 14 |
+
the documentation and/or other materials provided with the
|
| 15 |
+
distribution.
|
| 16 |
+
|
| 17 |
+
3. Neither the name of the copyright holder nor the names of its
|
| 18 |
+
contributors may be used to endorse or promote products derived
|
| 19 |
+
from this software without specific prior written permission.
|
| 20 |
+
|
| 21 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
| 22 |
+
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
| 23 |
+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
| 24 |
+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
| 25 |
+
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
| 26 |
+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
| 27 |
+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
| 28 |
+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
| 29 |
+
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 30 |
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 31 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 32 |
+
|
| 33 |
+
In the event that redistribution and/or use for commercial purpose in
|
| 34 |
+
source or binary forms, with or without modification is required,
|
| 35 |
+
please contact the contributor(s) of the work.
|
README.md
CHANGED
|
@@ -1,10 +1,229 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk: docker
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: CodeFormer
|
| 3 |
+
emoji: 👤
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_file: app.py
|
| 8 |
+
pinned: false
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
<p align="center">
|
| 12 |
+
<img src="assets/CodeFormer_logo.png" height=110>
|
| 13 |
+
</p>
|
| 14 |
+
|
| 15 |
+
## Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022)
|
| 16 |
+
|
| 17 |
+
[Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
<a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> [Hugging Face demo](https://huggingface.co/spaces/sczhou/CodeFormer) [Replicate demo](https://replicate.com/sczhou/codeformer) [OpenXLab demo](https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
[Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
|
| 24 |
+
|
| 25 |
+
S-Lab, Nanyang Technological University
|
| 26 |
+
|
| 27 |
+
<img src="assets/network.jpg" width="800px"/>
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
:star: If CodeFormer is helpful to your images or projects, please help star this repo. Thanks! :hugs:
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
### Update
|
| 34 |
+
- **2023.07.20**: Integrated to :panda_face: [OpenXLab](https://openxlab.org.cn/apps). Try out the online demo! [OpenXLab demo](https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer)
|
| 35 |
+
- **2023.04.19**: :whale: Training code and config files are now publicly available.
|
| 36 |
+
- **2023.04.09**: Add features of inpainting and colorization for cropped and aligned face images.
|
| 37 |
+
- **2023.02.10**: Include `dlib` as a new face detector option; it produces more accurate face identity.
|
| 38 |
+
- **2022.10.05**: Support video input `--input_path [YOUR_VIDEO.mp4]`. Try it to enhance your videos! :clapper:
|
| 39 |
+
- **2022.09.14**: Integrated to :hugs: [Hugging Face](https://huggingface.co/spaces). Try out the online demo! [Hugging Face demo](https://huggingface.co/spaces/sczhou/CodeFormer)
|
| 40 |
+
- **2022.09.09**: Integrated to :rocket: [Replicate](https://replicate.com/explore). Try out the online demo! [Replicate demo](https://replicate.com/sczhou/codeformer)
|
| 41 |
+
- [**More**](docs/history_changelog.md)
|
| 42 |
+
|
| 43 |
+
### TODO
|
| 44 |
+
- [x] Add training code and config files
|
| 45 |
+
- [x] Add checkpoint and script for face inpainting
|
| 46 |
+
- [x] Add checkpoint and script for face colorization
|
| 47 |
+
- [x] ~~Add background image enhancement~~
|
| 48 |
+
|
| 49 |
+
#### :panda_face: Try Enhancing Old Photos / Fixing AI-arts
|
| 50 |
+
[<img src="assets/imgsli_1.jpg" height="226px"/>](https://imgsli.com/MTI3NTE2) [<img src="assets/imgsli_2.jpg" height="226px"/>](https://imgsli.com/MTI3NTE1) [<img src="assets/imgsli_3.jpg" height="226px"/>](https://imgsli.com/MTI3NTIw)
|
| 51 |
+
|
| 52 |
+
#### Face Restoration
|
| 53 |
+
|
| 54 |
+
<img src="assets/restoration_result1.png" width="400px"/> <img src="assets/restoration_result2.png" width="400px"/>
|
| 55 |
+
<img src="assets/restoration_result3.png" width="400px"/> <img src="assets/restoration_result4.png" width="400px"/>
|
| 56 |
+
|
| 57 |
+
#### Face Color Enhancement and Restoration
|
| 58 |
+
|
| 59 |
+
<img src="assets/color_enhancement_result1.png" width="400px"/> <img src="assets/color_enhancement_result2.png" width="400px"/>
|
| 60 |
+
|
| 61 |
+
#### Face Inpainting
|
| 62 |
+
|
| 63 |
+
<img src="assets/inpainting_result1.png" width="400px"/> <img src="assets/inpainting_result2.png" width="400px"/>
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
### Dependencies and Installation
|
| 68 |
+
|
| 69 |
+
- Pytorch >= 1.7.1
|
| 70 |
+
- CUDA >= 10.1
|
| 71 |
+
- Other required packages in `requirements.txt`
|
| 72 |
+
```
|
| 73 |
+
# git clone this repository
|
| 74 |
+
git clone https://github.com/sczhou/CodeFormer
|
| 75 |
+
cd CodeFormer
|
| 76 |
+
|
| 77 |
+
# create new anaconda env
|
| 78 |
+
conda create -n codeformer python=3.8 -y
|
| 79 |
+
conda activate codeformer
|
| 80 |
+
|
| 81 |
+
# install python dependencies
|
| 82 |
+
pip3 install -r requirements.txt
|
| 83 |
+
python basicsr/setup.py develop
|
| 84 |
+
conda install -c conda-forge dlib  # only needed for face detection or cropping with dlib
|
| 85 |
+
```
|
| 86 |
+
<!-- conda install -c conda-forge dlib -->
|
| 87 |
+
|
| 88 |
+
### Quick Inference
|
| 89 |
+
|
| 90 |
+
#### Download Pre-trained Models:
|
| 91 |
+
Download the facelib and dlib pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0) | [Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder. You can manually download the pretrained models OR download by running the following command:
|
| 92 |
+
```
|
| 93 |
+
python scripts/download_pretrained_models.py facelib
|
| 94 |
+
python scripts/download_pretrained_models.py dlib  # only needed for the dlib face detector
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
Download the CodeFormer pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0) | [Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder. You can manually download the pretrained models OR download by running the following command:
|
| 98 |
+
```
|
| 99 |
+
python scripts/download_pretrained_models.py CodeFormer
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
#### Prepare Testing Data:
|
| 103 |
+
You can put the testing images in the `inputs/TestWhole` folder. If you would like to test on cropped and aligned faces, you can put them in the `inputs/cropped_faces` folder. You can get the cropped and aligned faces by running the following command:
|
| 104 |
+
```
|
| 105 |
+
# you may need to install dlib via: conda install -c conda-forge dlib
|
| 106 |
+
python scripts/crop_align_face.py -i [input folder] -o [output folder]
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
#### Testing:
|
| 111 |
+
[Note] If you want to compare CodeFormer in your paper, please run the following command indicating `--has_aligned` (for cropped and aligned face), as the command for the whole image will involve a process of face-background fusion that may damage hair texture on the boundary, which leads to unfair comparison.
|
| 112 |
+
|
| 113 |
+
Fidelity weight *w* lies in [0, 1]. Generally, a smaller *w* tends to produce a higher-quality result, while a larger *w* yields a higher-fidelity result. The results will be saved in the `results` folder.
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
🧑🏻 Face Restoration (cropped and aligned face)
|
| 117 |
+
```
|
| 118 |
+
# For cropped and aligned faces (512x512)
|
| 119 |
+
python inference_codeformer.py -w 0.5 --has_aligned --input_path [image folder]|[image path]
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
:framed_picture: Whole Image Enhancement
|
| 123 |
+
```
|
| 124 |
+
# For whole image
|
| 125 |
+
# Add '--bg_upsampler realesrgan' to enhance the background regions with Real-ESRGAN
|
| 126 |
+
# Add '--face_upsample' to further upsample the restored face with Real-ESRGAN
|
| 127 |
+
python inference_codeformer.py -w 0.7 --input_path [image folder]|[image path]
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
:clapper: Video Enhancement
|
| 131 |
+
```
|
| 132 |
+
# For Windows/Mac users, please install ffmpeg first
|
| 133 |
+
conda install -c conda-forge ffmpeg
|
| 134 |
+
```
|
| 135 |
+
```
|
| 136 |
+
# For video clips
|
| 137 |
+
# Video path should end with '.mp4'|'.mov'|'.avi'
|
| 138 |
+
python inference_codeformer.py --bg_upsampler realesrgan --face_upsample -w 1.0 --input_path [video path]
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
🌈 Face Colorization (cropped and aligned face)
|
| 142 |
+
```
|
| 143 |
+
# For cropped and aligned faces (512x512)
|
| 144 |
+
# Colorize black and white or faded photo
|
| 145 |
+
python inference_colorization.py --input_path [image folder]|[image path]
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
🎨 Face Inpainting (cropped and aligned face)
|
| 149 |
+
```
|
| 150 |
+
# For cropped and aligned faces (512x512)
|
| 151 |
+
# Inputs could be masked by white brush using an image editing app (e.g., Photoshop)
|
| 152 |
+
# (check out the examples in inputs/masked_faces)
|
| 153 |
+
python inference_inpainting.py --input_path [image folder]|[image path]
|
| 154 |
+
```
|
| 155 |
+
### Training:
|
| 156 |
+
The training commands can be found in the documents: [English](docs/train.md) **|** [简体中文](docs/train_CN.md).
|
| 157 |
+
|
| 158 |
+
### License
|
| 159 |
+
|
| 160 |
+
This project is licensed under <a rel="license" href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">NTU S-Lab License 1.0</a>. Redistribution and use should follow this license.
|
| 161 |
+
|
| 162 |
+
---
|
| 163 |
+
### 🐼 Ecosystem Applications & Deployments
|
| 164 |
+
|
| 165 |
+
CodeFormer has been widely adopted and deployed across a broad range (>20) of online applications, platforms, API services, and independent websites, and has also been integrated into many open-source projects and toolkits.
|
| 166 |
+
|
| 167 |
+
> Only demos on **Hugging Face Space**, **Replicate**, and **OpenXLab** are official deployments **maintained by the authors**. All other demos, APIs, apps, websites, and integrations listed below are **third-party (non-official)** and are not affiliated with the CodeFormer authors. Please verify their legitimacy to avoid potential financial loss.
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
#### Websites (Non-official)
|
| 171 |
+
|
| 172 |
+
⚠️⚠️⚠️ The following websites are **not official and are not operated by us**. They use our models without any license or authorization. Please verify their legitimacy to avoid potential financial loss.
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
| Website | Link | Notes |
|
| 176 |
+
|---------|------|--------|
|
| 177 |
+
| CodeFormer.net | https://codeformer.net/ | Non-official website |
|
| 178 |
+
| CodeFormer.cn | https://www.codeformer.cn/ | Non-official website |
|
| 179 |
+
| CodeFormerAI.com | https://codeformerai.com/ | Non-official website |
|
| 180 |
+
|
| 181 |
+
#### Online Demos / API Platforms
|
| 182 |
+
|
| 183 |
+
| Platform | Link | Notes |
|
| 184 |
+
|----------|------|--------|
|
| 185 |
+
| Hugging Face | https://huggingface.co/spaces/sczhou/CodeFormer | Maintained by Authors |
|
| 186 |
+
| Replicate | https://replicate.com/sczhou/codeformer | Maintained by Authors |
|
| 187 |
+
| OpenXLab | https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer |Maintained by Authors |
|
| 188 |
+
| Segmind | https://www.segmind.com/models/codeformer | Non-official |
|
| 189 |
+
| Sieve | https://www.sievedata.com/functions/sieve/codeformer | Non-official |
|
| 190 |
+
| Fal.ai | https://fal.ai/models/fal-ai/codeformer | Non-official |
|
| 191 |
+
| VaikerAI | https://vaikerai.com/sczhou/codeformer | Non-official |
|
| 192 |
+
| Scade.pro | https://www.scade.pro/processors/lucataco-codeformer | Non-official |
|
| 193 |
+
| Grandline | https://www.grandline.ai/model/codeformer | Non-official |
|
| 194 |
+
| AI Demos | https://aidemos.com/tools/codeformer | Non-official |
|
| 195 |
+
| Synexa | https://synexa.ai/explore/sczhou/codeformer | Non-official |
|
| 196 |
+
| RentPrompts | https://rentprompts.ai/models/Codeformer | Non-official |
|
| 197 |
+
| ElevaticsAI | https://elevatics.ai/models/super-resolution/codeformer | Non-official |
|
| 198 |
+
| Anakin.ai | https://anakin.ai/apps/codeformer-online-face-restoration-by-codeformer-19343 | Non-official |
|
| 199 |
+
| Relayto | https://relayto.com/explore/codeformer-yf9rj8kwc7zsr | Non-official |
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
#### Open-Source Projects & Toolkits
|
| 203 |
+
|
| 204 |
+
| Project / Toolkit | Link | Notes |
|
| 205 |
+
|-------------------|------|--------|
|
| 206 |
+
| Stable Diffusion GUI | https://nmkd.itch.io/t2i-gui | Integration |
|
| 207 |
+
| Stable Diffusion WebUI | https://github.com/AUTOMATIC1111/stable-diffusion-webui | Integration |
|
| 208 |
+
| ChaiNNer | https://github.com/chaiNNer-org/chaiNNer | Integration |
|
| 209 |
+
| PyPI | https://pypi.org/project/codeformer/ ; https://pypi.org/project/codeformer-pip/ | Python packages |
|
| 210 |
+
| ComfyUI | https://stable-diffusion-art.com/codeformer/ | Integration |
|
| 211 |
+
|
| 212 |
+
---
|
| 213 |
+
### Acknowledgement
|
| 214 |
+
|
| 215 |
+
This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Some code is borrowed from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks to the authors for their awesome work.
|
| 216 |
+
|
| 217 |
+
### Citation
|
| 218 |
+
If our work is useful for your research, please consider citing:
|
| 219 |
+
|
| 220 |
+
@inproceedings{zhou2022codeformer,
|
| 221 |
+
author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
|
| 222 |
+
title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
|
| 223 |
+
booktitle = {NeurIPS},
|
| 224 |
+
year = {2022}
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
### Contact
|
| 229 |
+
If you have any questions, please feel free to reach out to me at `shangchenzhou@gmail.com`.
|
app.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CodeFormer Flask Application
|
| 3 |
+
Deployment on Hugging Face Spaces
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import cv2
|
| 8 |
+
import torch
|
| 9 |
+
import uuid
|
| 10 |
+
import numpy as np
|
| 11 |
+
import zipfile
|
| 12 |
+
import base64
|
| 13 |
+
from flask import Flask, render_template, request, send_file, url_for, jsonify, send_from_directory
|
| 14 |
+
from flask_cors import CORS
|
| 15 |
+
from werkzeug.utils import secure_filename
|
| 16 |
+
|
| 17 |
+
from torchvision.transforms.functional import normalize
|
| 18 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
| 19 |
+
from basicsr.utils import imwrite, img2tensor, tensor2img
|
| 20 |
+
from basicsr.utils.download_util import load_file_from_url
|
| 21 |
+
from basicsr.utils.misc import gpu_is_available, get_device
|
| 22 |
+
from basicsr.utils.realesrgan_utils import RealESRGANer
|
| 23 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 24 |
+
|
| 25 |
+
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
| 26 |
+
from facelib.utils.misc import is_gray
|
| 27 |
+
|
| 28 |
+
# --- Initialization ---
|
| 29 |
+
app = Flask(__name__)
|
| 30 |
+
CORS(app) # Enable CORS for all routes
|
| 31 |
+
app.config['UPLOAD_FOLDER'] = 'static/uploads'
|
| 32 |
+
app.config['RESULT_FOLDER'] = 'static/results'
|
| 33 |
+
app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024 # 100MB limit
|
| 34 |
+
|
| 35 |
+
# Ensure directories exist
|
| 36 |
+
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
| 37 |
+
os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)
|
| 38 |
+
os.makedirs('weights/CodeFormer', exist_ok=True)
|
| 39 |
+
os.makedirs('weights/facelib', exist_ok=True)
|
| 40 |
+
os.makedirs('weights/realesrgan', exist_ok=True)
|
| 41 |
+
|
| 42 |
+
# Pretrained model URLs
|
| 43 |
+
pretrain_model_url = {
|
| 44 |
+
'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
|
| 45 |
+
'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
|
| 46 |
+
'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
|
| 47 |
+
'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth'
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
def download_weights():
    """Download every pretrained checkpoint that is missing under ``weights/``.

    Each entry names a key of ``pretrain_model_url``, the target directory and
    the expected file name. A file that already exists on disk is left alone;
    a missing one is fetched into its directory with ``load_file_from_url``.
    """
    required = (
        ('codeformer', 'weights/CodeFormer', 'codeformer.pth'),
        ('detection', 'weights/facelib', 'detection_Resnet50_Final.pth'),
        ('parsing', 'weights/facelib', 'parsing_parsenet.pth'),
        ('realesrgan', 'weights/realesrgan', 'RealESRGAN_x2plus.pth'),
    )
    for key, model_dir, file_name in required:
        if os.path.exists(f'{model_dir}/{file_name}'):
            continue
        load_file_from_url(url=pretrain_model_url[key], model_dir=model_dir, progress=True, file_name=None)
|
| 59 |
+
|
| 60 |
+
# Download weights on startup
|
| 61 |
+
print("Checking weights...")
|
| 62 |
+
download_weights()
|
| 63 |
+
|
| 64 |
+
# Global models
|
| 65 |
+
device = get_device()
|
| 66 |
+
upsampler = None
|
| 67 |
+
codeformer_net = None
|
| 68 |
+
|
| 69 |
+
def init_models():
    """Build the global Real-ESRGAN upsampler and CodeFormer network.

    Populates the module-level ``upsampler`` and ``codeformer_net``. Both
    checkpoints are read from the ``weights/`` directory, so the files must
    already exist when this is called.
    """
    global upsampler, codeformer_net

    # RealESRGAN: half precision is only requested when a GPU is available.
    half = bool(gpu_is_available())
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
    upsampler = RealESRGANer(
        scale=2,
        model_path="weights/realesrgan/RealESRGAN_x2plus.pth",
        model=model,
        tile=400,
        tile_pad=40,
        pre_pad=0,
        half=half,
    )

    # CodeFormer
    codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
        dim_embd=512,
        codebook_size=1024,
        n_head=8,
        n_layers=9,
        connect_list=["32", "64", "128", "256"],
    ).to(device)

    ckpt_path = "weights/CodeFormer/codeformer.pth"
    # Deserialize onto the CPU first: a checkpoint saved from CUDA tensors would
    # otherwise fail to load on a CPU-only host. load_state_dict() then copies the
    # values into the parameters that already live on `device`.
    checkpoint = torch.load(ckpt_path, map_location="cpu")["params_ema"]
    codeformer_net.load_state_dict(checkpoint)
    codeformer_net.eval()
    print("Models loaded successfully.")
|
| 99 |
+
|
| 100 |
+
init_models()
|
| 101 |
+
|
| 102 |
+
def process_image(img_path, background_enhance, face_upsample, upscale, codeformer_fidelity):
    """Restore the faces found in one image file with CodeFormer.

    Args:
        img_path: Path of the image file; it is read with ``cv2.imread``.
        background_enhance: When truthy, the global ``upsampler`` enhances the
            background before the faces are pasted back.
        face_upsample: When truthy, the global ``upsampler`` is also handed to
            the face helper as face upsampler.
        upscale: Requested output scale. Converted to ``int`` and clamped to
            1..4, then lowered further for large inputs (see below).
        codeformer_fidelity: Passed to the CodeFormer network as ``w``.

    Returns:
        The image produced by the face helper, or ``None`` when the file cannot
        be read as an image or any processing step raises.
    """
    try:
        # Defaults
        has_aligned = False
        only_center_face = False
        draw_box = False
        detection_model = "retinaface_resnet50"

        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread reports an unreadable / non-image file by returning None
            # instead of raising; fail here with a message that names the cause.
            print(f"Global processing error: cannot read image '{img_path}'")
            return None

        # Memory safety checks: keep the scale within 1..4 and reduce it (and
        # disable the Real-ESRGAN passes) for large inputs.
        upscale = max(1, min(int(upscale), 4))
        longest_side = max(img.shape[:2])
        if upscale > 2 and longest_side > 1000:
            upscale = 2
        if longest_side > 1500:
            upscale = 1
            background_enhance = False
            face_upsample = False

        face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model=detection_model,
            save_ext="png",
            use_parse=True,
            device=device,
        )

        bg_upsampler = upsampler if background_enhance else None
        face_upsampler = upsampler if face_upsample else None

        if has_aligned:
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
            face_helper.is_gray = is_gray(img, threshold=5)
            face_helper.cropped_faces = [img]
        else:
            face_helper.read_image(img)
            face_helper.get_face_landmarks_5(only_center_face=only_center_face, resize=640, eye_dist_threshold=5)
            face_helper.align_warp_face()

        # Face restoration
        for idx, cropped_face in enumerate(face_helper.cropped_faces):
            cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

            try:
                with torch.no_grad():
                    output = codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
                    restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
            except Exception as e:
                # Best effort: if inference fails for one face, keep the
                # unrestored crop so the remaining faces are still processed.
                print(f"Inference error: {e}")
                restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

            restored_face = restored_face.astype("uint8")
            face_helper.add_restored_face(restored_face)

        # Paste back
        if not has_aligned:
            bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] if bg_upsampler else None
            face_helper.get_inverse_affine(None)

            if face_upsample and face_upsampler:
                restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box, face_upsampler=face_upsampler)
            else:
                restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box)
        else:
            restored_img = face_helper.restored_faces[0]

        return restored_img

    except Exception as e:
        print(f"Global processing error: {e}")
        return None
|
| 178 |
+
|
| 179 |
+
# --- Routes ---
|
| 180 |
+
|
| 181 |
+
@app.route('/', methods=['GET'])
def index():
    """Render the ``index.html`` template for the site root."""
    page = render_template('index.html')
    return page
|
| 184 |
+
|
| 185 |
+
@app.route('/process', methods=['POST'])
def process():
    """Handle the HTML form upload: restore each posted image and render the results.

    Every file posted under the ``image`` field is saved to the upload folder and
    run through ``process_image``. For each image that succeeds, the result is
    written as PNG to the result folder, plus a downscaled preview when its longer
    side exceeds 1000 px. All results are then bundled into one ZIP archive.

    Returns:
        The rendered ``result.html`` page on success; a plain-text message with
        status 400 when no usable file or an invalid ``fidelity`` was posted, or
        status 500 when no image could be processed.
    """
    if 'image' not in request.files:
        return "No image uploaded", 400

    uploads = request.files.getlist('image')
    if not uploads or uploads[0].filename == '':
        return "No selected file", 400

    # Get params (same for all images)
    try:
        fidelity = float(request.form.get('fidelity', 0.5))
    except ValueError:
        return "Invalid parameters", 400
    upscale = 4  # Enforce 4x upscale
    background_enhance = 'background_enhance' in request.form
    face_upsample = 'face_upsample' in request.form

    upload_dir = app.config['UPLOAD_FOLDER']
    result_dir = app.config['RESULT_FOLDER']
    results = []

    for upload in uploads:
        if upload.filename == '':
            continue

        # Save input under a unique name
        stored_name = f"{uuid.uuid4()}_{secure_filename(upload.filename)}"
        stored_path = os.path.join(upload_dir, stored_name)
        upload.save(stored_path)

        restored = process_image(stored_path, background_enhance, face_upsample, upscale, fidelity)
        if restored is None:
            # Skip images that could not be processed
            continue

        # Save output
        download_name = "result_" + stored_name.rsplit('.', 1)[0] + ".png"
        imwrite(restored, os.path.join(result_dir, download_name))

        # Generate preview (max 1000px width/height); small results double as their own preview
        height, width = restored.shape[:2]
        longest = max(height, width)
        preview_name = download_name
        if longest > 1000:
            preview_name = "preview_" + download_name
            factor = 1000 / longest
            shrunk = cv2.resize(restored, None, fx=factor, fy=factor, interpolation=cv2.INTER_AREA)
            imwrite(shrunk, os.path.join(result_dir, preview_name))

        results.append({
            'original': stored_name,
            'preview': preview_name,
            'download': download_name
        })

    if not results:
        return "Processing failed for all images", 500

    # Create ZIP of all results
    zip_filename = f"batch_{uuid.uuid4()}.zip"
    with zipfile.ZipFile(os.path.join(result_dir, zip_filename), 'w') as archive:
        for entry in results:
            archive.write(os.path.join(result_dir, entry['download']), entry['download'])

    return render_template('result.html', results=results, zip_filename=zip_filename)
|
| 255 |
+
|
| 256 |
+
# --- API Routes ---
|
| 257 |
+
|
| 258 |
+
@app.route('/api/process', methods=['POST'])
def api_process():
    """
    API endpoint for image processing.
    Accepts:
    - multipart/form-data with one or more 'image' files.
    - application/json with 'image_base64' string (single image) or 'images_base64' list.
    Parameters (form or JSON):
    - fidelity: (float) 0-1, default 0.5.
    - background_enhance: (bool) default False.
    - face_upsample: (bool) default False.
    - upscale: (int) 1-4, default 2.
    - return_base64: (bool) default False.
    Responses (JSON):
    - 200: {"status": "success", "count": N, "results": [...]}.
    - 400: malformed request (bad JSON body, invalid parameter, invalid base64, no images).
    - 500: processing failed for every image, or an unexpected server error.
    """
    import binascii
    import traceback

    def _flag(value):
        # Form fields arrive as strings; JSON may carry real booleans or the
        # strings "true"/"false". A string counts as set only when it spells "true".
        if isinstance(value, str):
            return value.lower() == 'true'
        return bool(value)

    def _bad_request(message):
        return jsonify({"status": "error", "message": message}), 400

    try:
        is_json = request.is_json
        if is_json:
            # silent=True yields None for an unparsable body instead of raising,
            # so a malformed request is answered with 400 rather than 500.
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                return _bad_request("Request body must be a JSON object")
        else:
            data = request.form

        try:
            fidelity = float(data.get('fidelity', 0.5))
            upscale = int(data.get('upscale', 2))
        except (TypeError, ValueError):
            return _bad_request("Invalid 'fidelity' or 'upscale' parameter")

        background_enhance = _flag(data.get('background_enhance', False))
        face_upsample = _flag(data.get('face_upsample', False))
        return_base64 = _flag(data.get('return_base64', False))

        processed_images = []
        inputs = []

        # Handle JSON input
        if is_json:
            encoded = []
            if 'image_base64' in data:
                encoded.append((data['image_base64'], 'image.png'))
            elif 'images_base64' in data:
                if not isinstance(data['images_base64'], list):
                    return _bad_request("'images_base64' must be a list of base64 strings")
                for idx, img_b64 in enumerate(data['images_base64']):
                    encoded.append((img_b64, f'image_{idx}.png'))

            # Decode every payload before writing anything, so a bad entry does
            # not leave files from earlier entries behind.
            decoded = []
            for payload, name in encoded:
                if not isinstance(payload, str):
                    return _bad_request(f"'{name}' is not a base64 string")
                try:
                    # Drop an optional "data:...;base64," prefix before decoding.
                    decoded.append((base64.b64decode(payload.split(',')[-1]), name))
                except (binascii.Error, ValueError):
                    return _bad_request(f"'{name}' is not valid base64 data")

            for image_data, name in decoded:
                temp_filename = str(uuid.uuid4())
                input_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{temp_filename}.png")
                with open(input_path, 'wb') as f:
                    f.write(image_data)
                inputs.append({'path': input_path, 'name': name, 'temp_id': temp_filename})

        # Handle Multipart input
        elif 'image' in request.files:
            for file in request.files.getlist('image'):
                if file.filename != '':
                    temp_filename = str(uuid.uuid4())
                    filename = f"{temp_filename}_{secure_filename(file.filename)}"
                    input_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
                    file.save(input_path)
                    inputs.append({'path': input_path, 'name': file.filename, 'temp_id': temp_filename})

        if not inputs:
            return _bad_request("No images provided")

        for inp in inputs:
            # Process image; images that fail are left out of the response.
            result_img = process_image(inp['path'], background_enhance, face_upsample, upscale, fidelity)
            if result_img is None:
                continue

            # Save result
            output_filename = f"api_result_{inp['temp_id']}.png"
            output_path = os.path.join(app.config['RESULT_FOLDER'], output_filename)
            imwrite(result_img, output_path)

            res = {
                "original_name": inp['name'],
                "image_url": url_for('static', filename=f'results/{output_filename}', _external=True),
                "filename": output_filename
            }

            if return_base64:
                _, buffer = cv2.imencode('.png', result_img)
                res["image_base64"] = base64.b64encode(buffer).decode('utf-8')

            processed_images.append(res)

        if not processed_images:
            return jsonify({"status": "error", "message": "Processing failed for all images"}), 500

        return jsonify({
            "status": "success",
            "count": len(processed_images),
            "results": processed_images
        })

    except Exception as e:
        # Top-level boundary: log the traceback and report a server error.
        traceback.print_exc()
        return jsonify({"status": "error", "message": str(e)}), 500
|
| 351 |
+
|
| 352 |
+
@app.route('/api/health', methods=['GET'])
def health_check():
    """Return a JSON status payload that includes the module-level ``device``."""
    payload = {"status": "online", "device": str(device)}
    return jsonify(payload)
|
| 355 |
+
|
| 356 |
+
if __name__ == '__main__':
|
| 357 |
+
# Docker/HF Spaces entry point
|
| 358 |
+
app.run(host='0.0.0.0', port=7860)
|
basicsr/VERSION
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
1.3.2
|
basicsr/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/xinntao/BasicSR
|
| 2 |
+
# flake8: noqa
|
| 3 |
+
from .archs import *
|
| 4 |
+
from .data import *
|
| 5 |
+
from .losses import *
|
| 6 |
+
from .metrics import *
|
| 7 |
+
from .models import *
|
| 8 |
+
from .ops import *
|
| 9 |
+
from .train import *
|
| 10 |
+
from .utils import *
|
| 11 |
+
from .version import __gitsha__, __version__
|
basicsr/archs/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
from os import path as osp
|
| 4 |
+
|
| 5 |
+
from basicsr.utils import get_root_logger, scandir
|
| 6 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 7 |
+
|
| 8 |
+
__all__ = ['build_network']
|
| 9 |
+
|
| 10 |
+
# automatically scan and import arch modules for registry
|
| 11 |
+
# scan all the files under the 'archs' folder and collect files ending with
|
| 12 |
+
# '_arch.py'
|
| 13 |
+
arch_folder = osp.dirname(osp.abspath(__file__))
|
| 14 |
+
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
|
| 15 |
+
# import all the arch modules
|
| 16 |
+
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def build_network(opt):
    """Instantiate the architecture selected by ``opt['type']``.

    Args:
        opt (dict): Network options. The ``type`` entry names the class to look
            up in ``ARCH_REGISTRY``; all remaining entries are passed to its
            constructor as keyword arguments. The options are deep-copied first,
            so the caller's dict is left untouched.

    Returns:
        The constructed network instance.
    """
    options = deepcopy(opt)
    arch_cls = ARCH_REGISTRY.get(options.pop('type'))
    net = arch_cls(**options)
    get_root_logger().info(f'Network [{net.__class__.__name__}] is created.')
    return net
|
basicsr/archs/arcface_arch.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def conv3x3(inplanes, outplanes, stride=1):
    """Create a 3x3 convolution layer with padding 1 and no bias term.

    Args:
        inplanes (int): Channel number of inputs.
        outplanes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    conv = nn.Conv2d(
        in_channels=inplanes,
        out_channels=outplanes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return conv
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BasicBlock(nn.Module):
    """Basic residual block used in the ResNetArcFace architecture.

    The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The shortcut is
    added to it and a final ReLU is applied.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride of the first convolution. Default: 1.
        downsample (nn.Module): Module applied to the input to produce the
            shortcut. When None the input itself is the shortcut.
            Default: None.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Only the first convolution is strided.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # The shortcut is computed after the main path, from the raw input.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class IRBlock(nn.Module):
    """Improved residual block (IR Block) used in the ResNetArcFace architecture.

    The main path is BN -> conv3x3 -> BN -> PReLU -> conv3x3 -> BN, followed
    by an optional SEBlock. The shortcut is added and the same PReLU module
    is applied once more to the sum.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride of the second convolution. Default: 1.
        downsample (nn.Module): Module applied to the input to produce the
            shortcut. When None the input itself is the shortcut.
            Default: None.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block).
            Default: True.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super().__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        # The first convolution keeps the channel count; the second one
        # changes it to ``planes`` and carries the stride.
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        out = self.conv1(self.bn0(x))
        out = self.prelu(self.bn1(out))

        out = self.bn2(self.conv2(out))
        if self.use_se:
            out = self.se(out)

        # The shortcut is computed after the main path, from the raw input.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.prelu(out)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class Bottleneck(nn.Module):
    """Bottleneck block used in the ResNetArcFace architecture.

    The main path is conv1x1 -> BN -> ReLU -> conv3x3 -> BN -> ReLU ->
    conv1x1 -> BN. The last convolution widens the features to
    ``planes * expansion`` channels. The shortcut is added and a final ReLU
    is applied.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of the intermediate features; the block
            outputs ``planes * expansion`` channels.
        stride (int): Stride of the 3x3 convolution. Default: 1.
        downsample (nn.Module): Module applied to the input to produce the
            shortcut. When None the input itself is the shortcut.
            Default: None.
    """
    expansion = 4  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # The shortcut is computed after the main path, from the raw input.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class SEBlock(nn.Module):
    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.

    Each channel is averaged over its spatial extent, the per-channel
    averages go through a two-layer MLP ending in a sigmoid, and the input
    is multiplied channel-wise by the resulting gates.

    Args:
        channel (int): Channel number of inputs.
        reduction (int): Channel reduction ratio of the hidden layer.
            Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1 without spatial information
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.PReLU(),
            nn.Linear(hidden, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gates
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.

    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.

    The stem convolution takes a single-channel input. The spatial size is
    halved by the max-pool and by each of layers 2-4 (16x in total), and
    ``fc5`` is sized for an 8x8 final feature map.

    Args:
        block (str | type): Block used in the ArcFace architecture. The
            string 'IRBlock' selects :class:`IRBlock`; a block class may
            also be passed directly. The class must define ``expansion``
            and accept a ``use_se`` keyword argument.
        layers (tuple(int)): Block numbers in each layer.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block).
            Default: True.

    Raises:
        ValueError: If ``block`` is a string other than 'IRBlock'.
    """

    def __init__(self, block, layers, use_se=True):
        # Initialise nn.Module before any attribute is assigned.
        super().__init__()
        if block == 'IRBlock':
            block = IRBlock
        elif isinstance(block, str):
            # Fail early with a clear message instead of an AttributeError
            # on ``block.expansion`` further down.
            raise ValueError(f'Unsupported block name: {block!r}. Supported names: IRBlock.')
        self.inplanes = 64
        self.use_se = use_se

        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # The last stage outputs 512 * expansion channels (512 for IRBlock).
        feat_channels = 512 * block.expansion
        self.bn4 = nn.BatchNorm2d(feat_channels)
        self.dropout = nn.Dropout()
        self.fc5 = nn.Linear(feat_channels * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Stack ``num_blocks`` blocks; only the first one is strided.

        A 1x1 conv + BN shortcut is created for the first block whenever the
        stride or the channel count changes.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        layers = [block(self.inplanes, planes, stride, downsample, use_se=self.use_se)]
        # Later blocks consume what the previous block emits, which is
        # planes * expansion channels (identical to planes for IRBlock).
        self.inplanes = out_planes
        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn4(x)
        x = self.dropout(x)
        # Flatten everything except the batch dimension for the final linear layer.
        x = x.view(x.size(0), -1)
        x = self.fc5(x)
        x = self.bn5(x)

        return x
|
basicsr/archs/arch_util.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections.abc
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision
|
| 5 |
+
import warnings
|
| 6 |
+
from distutils.version import LooseVersion
|
| 7 |
+
from itertools import repeat
|
| 8 |
+
from torch import nn as nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
from torch.nn import init as init
|
| 11 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
| 12 |
+
|
| 13 |
+
from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
|
| 14 |
+
from basicsr.utils import get_root_logger
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Conv2d and Linear weights get Kaiming-normal initialization and are then
    multiplied by ``scale``. Batch-norm weights are set to 1. Every bias that
    exists on those layers is filled with ``bias_fill``. Other layer types
    are left untouched.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function.
    """
    roots = module_list if isinstance(module_list, list) else [module_list]
    for root in roots:
        for layer in root.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(layer.weight, **kwargs)
                layer.weight.data *= scale
            elif isinstance(layer, _BatchNorm):
                init.constant_(layer.weight, 1)
            else:
                continue
            # Shared tail of all three handled layer types.
            if layer.bias is not None:
                layer.bias.data.fill_(bias_fill)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.
        kwarg (dict): Keyword arguments passed to every ``basic_block`` call.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    # Each block is a separate instance, constructed in order.
    return nn.Sequential(*[basic_block(**kwarg) for _ in range(num_basic_block)])
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of::

        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if pytorch_init:
            return
        # Both convolutions are re-initialised with weights scaled by 0.1.
        default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        branch = self.conv1(x)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        # Only the residual branch is scaled; the identity path is untouched.
        return x + branch * self.res_scale
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Upsample(nn.Sequential):
    """Upsample module.

    A power-of-two scale is realised as repeated (Conv2d -> PixelShuffle(2))
    stages; scale 3 is a single (Conv2d -> PixelShuffle(3)) stage. Scale 1
    yields an empty sequence.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a positive power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        # ``scale > 0`` excludes 0, which would otherwise pass the bit test
        # and fail later with an uninformative "math domain error".
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # log2 is exact for powers of two, unlike log(scale, 2).
            num_stages = round(math.log2(scale))
            for _ in range(num_stages):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    A pixel-coordinate grid is built, offset by ``flow``, rescaled so that
    0 maps to -1 and ``size - 1`` maps to 1 along each axis, and handed to
    ``F.grid_sample``.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value. The last
            dimension is added to the (x, y) pixel coordinates.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use the True as default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, height, width = x.size()

    # Pixel-coordinate grid with (x, y) in the last dimension, shape (h, w, 2).
    rows = torch.arange(0, height).type_as(x)
    cols = torch.arange(0, width).type_as(x)
    grid_y, grid_x = torch.meshgrid(rows, cols)
    base_grid = torch.stack((grid_x, grid_y), 2).float()

    # Displace every pixel by its flow vector, then scale the grid to [-1, 1].
    sample_pos = base_grid + flow
    norm_x = 2.0 * sample_pos[:, :, :, 0] / max(width - 1, 1) - 1.0
    norm_y = 2.0 * sample_pos[:, :, :, 1] / max(height - 1, 1) - 1.0
    norm_grid = torch.stack((norm_x, norm_y), dim=3)

    # TODO, what if align_corners=False
    return F.grid_sample(
        x, norm_grid, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow according to ratio or shape.

    The flow values are rescaled by the same factors as the spatial size
    (channel 0 by the width ratio, channel 1 by the height ratio) before the
    map is interpolated. The input tensor is not modified.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Ignored for the
            'nearest', 'nearest-exact' and 'area' modes, which do not accept
            it. Default: False.

    Returns:
        Tensor: Resized flow.

    Raises:
        ValueError: If ``size_type`` is neither 'ratio' nor 'shape'.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    # F.interpolate rejects any non-None align_corners for these modes, so
    # the default align_corners=False would otherwise make them unusable.
    if interp_mode in ('nearest', 'nearest-exact', 'area'):
        align_corners = None

    # Work on a copy so the caller's tensor keeps its values.
    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
    return resized_flow
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# TODO: may write a cpp file
|
| 190 |
+
def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.

    Every ``scale`` x ``scale`` spatial patch is moved into the channel
    dimension, so the spatial size shrinks by ``scale`` and the channel
    count grows by ``scale**2``.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio. Must divide both hh and hw.

    Returns:
        Tensor: the pixel unshuffled feature, with shape
            (b, c * scale**2, hh // scale, hw // scale).
    """
    batch, chans, in_h, in_w = x.size()
    assert in_h % scale == 0 and in_w % scale == 0
    out_h, out_w = in_h // scale, in_w // scale
    # Split each spatial axis into (coarse position, offset inside the patch),
    # then move the two offset axes next to the channel axis.
    patches = x.view(batch, chans, out_h, scale, out_w, scale).permute(0, 1, 3, 5, 2, 4)
    return patches.reshape(batch, chans * (scale**2), out_h, out_w)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.

    Different from the official DCNv2Pack, which generates offsets and masks
    from the preceding features, this DCNv2Pack takes another different
    features to generate offsets and masks.

    Ref:
        Delving Deep into Deformable Alignment in Video Super-Resolution.
    """

    def forward(self, x, feat):
        """Apply the deformable convolution to ``x`` using offsets from ``feat``.

        Args:
            x (Tensor): Features to be convolved.
            feat (Tensor): Features fed to ``conv_offset`` to predict the
                offsets and the modulation mask.

        Returns:
            Tensor: Output of the modulated deformable convolution.
        """
        # conv_offset's output is split into three equal channel groups:
        # the first two form the offsets, the third is the mask (pre-sigmoid).
        predicted = self.conv_offset(feat)
        offset_a, offset_b, mask = torch.chunk(predicted, 3, dim=1)
        offset = torch.cat((offset_a, offset_b), dim=1)
        mask = torch.sigmoid(mask)

        # Warn when the mean absolute offset exceeds 50.
        offset_absmean = torch.mean(torch.abs(offset))
        if offset_absmean > 50:
            get_root_logger().warning(f'Offset abs mean is {offset_absmean}, larger than 50.')

        # torchvision's deform_conv2d is used from 0.9.0 on; older versions
        # fall back to the project's own op.
        use_torchvision = LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0')
        if use_torchvision:
            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                                 self.dilation, mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                     self.dilation, self.groups, self.deformable_groups)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with truncated-normal samples and return it.

    Samples are produced by inverse-CDF sampling of a uniform variable and
    finally clamped to ``[a, b]``. A warning is issued when ``mean`` lies
    more than two standard deviations outside ``[a, b]``.
    """
    # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    sqrt_two = math.sqrt(2.)

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / sqrt_two)) / 2.

    lowest_ok_mean = a - 2 * std
    highest_ok_mean = b + 2 * std
    if mean < lowest_ok_mean or mean > highest_ok_mean:
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # CDF values of the two cut-off points.
        low = norm_cdf((a - mean) / std)
        up = norm_cdf((b - mean) / std)

        # Steps, all in place (each op returns the tensor itself):
        #   1. uniform samples in [2*low - 1, 2*up - 1]
        #   2. inverse error function -> truncated standard normal
        #   3. scale by std * sqrt(2) and shift by mean
        #   4. clamp to make sure the values are inside [a, b]
        tensor.uniform_(2 * low - 1, 2 * up - 1).erfinv_()
        tensor.mul_(std * sqrt_two).add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill the input Tensor in place with values drawn from a truncated
    normal distribution and return it.

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py

    The values follow the normal distribution
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor`` object, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# From PyTorch
|
| 304 |
+
def _ntuple(n):
    """Return a converter that repeats a scalar ``n`` times as a tuple.

    Any iterable passed to the converter (strings included) is returned
    unchanged, without length checking or conversion to a tuple.
    """

    def parse(x):
        already_iterable = isinstance(x, collections.abc.Iterable)
        return x if already_iterable else (x,) * n

    return parse


# Ready-made converters for the common lengths; ``to_ntuple`` is the factory itself.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
|
basicsr/archs/codeformer_arch.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn, Tensor
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from typing import Optional, List
|
| 7 |
+
|
| 8 |
+
from basicsr.archs.vqgan_arch import *
|
| 9 |
+
from basicsr.utils import get_root_logger
|
| 10 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 11 |
+
|
| 12 |
+
def calc_mean_std(feat, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Statistics are taken over the spatial positions of every
    (sample, channel) pair.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.

    Returns:
        tuple[Tensor, Tensor]: Mean and standard deviation, each reshaped to
            (b, c, 1, 1).
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    batch, chans = size[:2]
    flat = feat.view(batch, chans, -1)
    # eps is added to the variance, before the square root.
    feat_std = (flat.var(dim=2) + eps).sqrt().view(batch, chans, 1, 1)
    feat_mean = flat.mean(dim=2).view(batch, chans, 1, 1)
    return feat_mean, feat_std
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Adjust the reference features to have the similar color and illuminations
    as those in the degraded features: the content features are normalized
    with their own per-channel statistics and then rescaled and shifted with
    the statistics of the style features.

    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degraded features.

    Returns:
        Tensor: Re-normalized features with the size of ``content_feat``.
    """
    target_size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)

    # Remove the content statistics ...
    centered = content_feat - content_mean.expand(target_size)
    normalized_feat = centered / content_std.expand(target_size)
    # ... and impose the style statistics.
    rescaled = normalized_feat * style_std.expand(target_size)
    return rescaled + style_mean.expand(target_size)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.

    The output has ``2 * num_pos_feats`` channels: the first half encodes the
    row position, the second half the column position, each as interleaved
    sine/cosine values.

    Args:
        num_pos_feats (int): Number of features per spatial axis. Default: 64.
        temperature (int): Base of the frequency progression. Default: 10000.
        normalize (bool): Whether to rescale positions to ``[0, scale]``.
            Default: False.
        scale (float): Scale used with ``normalize``; 2*pi when None.
            Default: None.

    Raises:
        ValueError: If ``scale`` is given while ``normalize`` is False.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x, mask=None):
        # Without a mask every position of the (n, h, w) grid counts as valid.
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        valid = ~mask
        # 1-based running counts of valid positions along rows and columns.
        y_embed = valid.cumsum(1, dtype=torch.float32)
        x_embed = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Consecutive feature pairs share one frequency.
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        def interleave_sin_cos(embed):
            # Even feature indices get sin, odd ones cos, interleaved again.
            angles = embed[:, :, :, None] / dim_t
            paired = torch.stack(
                (angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()), dim=4
            )
            return paired.flatten(3)

        pos_x = interleave_sin_cos(x_embed)
        pos_y = interleave_sin_cos(y_embed)
        # Channels-first output with the row encoding before the column encoding.
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 87 |
+
|
| 88 |
+
def _get_activation_fn(activation):
    """Return an activation function given a string.

    Args:
        activation (str): One of "relu", "gelu" or "glu".

    Returns:
        Callable: The matching function from ``torch.nn.functional``.

    Raises:
        RuntimeError: If ``activation`` is not one of the supported names.
    """
    supported = (("relu", F.relu), ("gelu", F.gelu), ("glu", F.glu))
    for name, fn in supported:
        if activation == name:
            return fn
    # The message lists every accepted name, including "glu".
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class TransformerSALayer(nn.Module):
    """Self-attention transformer layer with layer norm applied before each sub-block.

    The layer computes ``tgt + dropout(attention(norm1(tgt)))`` followed by
    ``tgt + dropout(mlp(norm2(tgt)))``.

    Args:
        embed_dim (int): Feature dimension of the tokens.
        nhead (int): Number of attention heads. Default: 8.
        dim_mlp (int): Hidden dimension of the feed-forward MLP. Default: 2048.
        dropout (float): Dropout probability used everywhere in the layer.
            Default: 0.0.
        activation (str): Name of the MLP activation. Default: "gelu".
    """

    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Implementation of Feedforward model - MLP
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add the positional embedding to ``tensor`` when one is given."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        # self attention: queries and keys carry the positional embedding,
        # the values do not.
        normed = self.norm1(tgt)
        query = key = self.with_pos_embed(normed, query_pos)
        attn_out, _ = self.self_attn(query, key, value=normed, attn_mask=tgt_mask,
                                     key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout1(attn_out)

        # ffn
        hidden = self.activation(self.linear1(self.norm2(tgt)))
        ffn_out = self.linear2(self.dropout(hidden))
        tgt = tgt + self.dropout2(ffn_out)
        return tgt
|
| 135 |
+
|
| 136 |
+
class Fuse_sft_block(nn.Module):
    """Fuse encoder features into decoder features with a scale-and-shift modulation.

    The concatenated encoder/decoder features are passed through a ResBlock,
    two convolutional branches predict a scale and a shift map from the
    result, and the decoder features are updated as
    ``dec_feat + w * (dec_feat * scale + shift)``.

    Args:
        in_ch (int): Channel number of each input feature map.
        out_ch (int): Channel number of the predicted scale and shift maps.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2*in_ch, out_ch)

        # NOTE(review): the branches' first convolutions expect in_ch channels
        # but are fed the output of encode_enc; this presumably requires
        # in_ch == out_ch - confirm against ResBlock.
        def make_branch():
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        self.scale = make_branch()
        self.shift = make_branch()

    def forward(self, enc_feat, dec_feat, w=1):
        fused = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(fused)
        shift = self.shift(fused)
        # ``w`` weights the modulation; w=0 returns dec_feat unchanged.
        return dec_feat + w * (dec_feat * scale + shift)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    """VQ auto-encoder with a transformer that predicts codebook indices.

    The encoder output is fed to a stack of ``TransformerSALayer`` modules
    whose output is projected to logits over the codebook. The top-1 code per
    token is looked up in the codebook and decoded by the generator, with
    optional fusion of encoder features at the scales in ``connect_list``.

    Args:
        dim_embd: Transformer embedding size.
        n_head: Number of attention heads per transformer layer.
        n_layers: Number of transformer layers.
        codebook_size: Number of codebook entries (size of the logits axis).
        latent_size: Number of learned positional embeddings (one per token).
        connect_list: Feature-map sizes (as strings) at which encoder features
            are fused into the generator. Defaults to
            ``['32', '64', '128', '256']``.
        fix_modules: Names of sub-modules whose parameters are frozen, or
            ``None`` to freeze nothing.
        vqgan_path: Optional checkpoint path; its ``'params_ema'`` entry is
            loaded into this module.
    """

    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
                 connect_list=None,
                 fix_modules=('quantize', 'generator'), vqgan_path=None):
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], codebook_size)

        # Build the default per call so instances never share one list object.
        if connect_list is None:
            connect_list = ['32', '64', '128', '256']

        if vqgan_path is not None:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            self.load_state_dict(
                torch.load(vqgan_path, map_location='cpu')['params_ema'])

        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd*2

        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        # The encoder latent fed to the transformer has 256 channels.
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[
            TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
            for _ in range(self.n_layers)])

        # logits_predict head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))

        # Channel count of the feature map at each spatial size.
        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }

        # after second residual block for > 16, before attn layer for ==16
        self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
        # after first residual block for > 16, before attn layer for ==16
        self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}

        # fuse_convs_dict: one fusion block per connected scale
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights to N(0, 0.02) and reset LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        """Encode ``x``, predict codebook indices, and decode them.

        Args:
            x: Input image batch in BCHW layout.
            w: Fusion weight; encoder features are fused into the generator
                only when ``w > 0``.
            detach_16: If True, detach the quantised feature before decoding
                (used for training stage III).
            code_only: If True, return early with ``(logits, lq_feat)`` (used
                for training stage II).
            adain: If True, apply ``adaptive_instance_normalization`` to the
                quantised feature using ``lq_feat`` as the style reference.

        Returns:
            ``(logits, lq_feat)`` when ``code_only`` is True, otherwise
            ``(out, logits, lq_feat)`` where ``logits`` is laid out as
            b(hw)n.
        """
        # ################### Encoder #####################
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                # keyed by the feature map's width
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb) # (hw)bn
        logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n

        if code_only: # for training stage II
            # logits doesn't need softmax before cross_entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        # The latent grid is hard-coded to 16x16 with 256 channels.
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])

        if detach_16:
            quant_feat = quant_feat.detach() # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list: # fuse after i-th block
                f_size = str(x.shape[-1])
                if w>0:
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat
|
basicsr/archs/rrdbnet_arch.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn as nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 6 |
+
from .arch_util import default_init_weights, make_layer, pixel_unshuffle
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN. Each of the first four convolutions sees the
    block input concatenated with all earlier intermediate outputs; the fifth
    maps everything back to ``num_feat`` channels.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        """Run the densely connected convolutions and add the scaled residual."""
        dense_feats = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            stacked = torch.cat(dense_feats, 1)
            dense_feats.append(self.lrelu(conv(stacked)))
        x5 = self.conv5(torch.cat(dense_feats, 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return x5 * 0.2 + x
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN. Chains three ``ResidualDenseBlock`` modules
    and adds their scaled output back onto the input.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        """Apply the three dense blocks in sequence and add the scaled residual."""
        out = x
        for dense_block in (self.rdb1, self.rdb2, self.rdb3):
            out = dense_block(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
    """Networks consisting of Residual in Residual Dense Block, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ the pixel-unshuffle (an inverse operation of pixelshuffle)
    to reduce the spatial size and enlarge the channel size before feeding
    inputs into the main ESRGAN architecture. The body then always upsamples
    by two nearest-neighbour x2 steps.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Overall scale. ``2`` and ``1`` pixel-unshuffle the input
            by 2 and 4 respectively; any other value feeds the input
            unchanged. Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64
        num_block (int): Block number in the trunk network. Defaults: 23
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        # Pixel-unshuffle by r multiplies the channel count by r * r.
        if scale == 1:
            num_in_ch = num_in_ch * 16
        elif scale == 2:
            num_in_ch = num_in_ch * 4
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Map the input image batch ``x`` to the output image batch."""
        if self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        elif self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        else:
            feat = x

        feat = self.conv_first(feat)
        # Long skip connection around the RRDB trunk.
        feat = feat + self.conv_body(self.body(feat))

        # upsample: two nearest-neighbour x2 steps, each followed by a conv
        for conv_up in (self.conv_up1, self.conv_up2):
            enlarged = F.interpolate(feat, scale_factor=2, mode='nearest')
            feat = self.lrelu(conv_up(enlarged))

        return self.conv_last(self.lrelu(self.conv_hr(feat)))
|
basicsr/archs/vgg_arch.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from collections import OrderedDict
|
| 4 |
+
from torch import nn as nn
|
| 5 |
+
from torchvision.models import vgg as vgg
|
| 6 |
+
|
| 7 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 8 |
+
|
| 9 |
+
# Local checkpoint that is loaded instead of the torchvision download when the
# file exists. The file name refers to VGG19 weights.
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
# Layer names for each plain (non-BN) VGG variant, listed in forward order.
# They are paired positionally with the layers of the torchvision model's
# `features` module.
NAMES = {
    'vgg11': [
        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
        'pool5'
    ],
    'vgg13': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
    ],
    'vgg16': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
        'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
        'pool5'
    ],
    'vgg19': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
    ]
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def insert_bn(names):
    """Insert bn layer after each conv.

    For every name containing ``'conv'`` a companion name is added right after
    it, formed by replacing ``'conv'`` with ``'bn'`` (``'conv1_1'`` ->
    ``'bn1_1'``). The input list is not modified.

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers.
    """
    with_bn = []
    for layer_name in names:
        if 'conv' not in layer_name:
            with_bn.append(layer_name)
            continue
        suffix = layer_name.replace('conv', '')
        with_bn.extend((layer_name, 'bn' + suffix))
    return with_bn
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    In this implementation, we allow users to choose whether use normalization
    in the input feature and the type of vgg network. Note that the pretrained
    path must fit the vgg type.

    Args:
        layer_name_list (list[str]): Forward function returns the corresponding
            features according to the layer_name_list.
            Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input feature must in the range [0, 1]. Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If true, the parameters of VGG network will be
            optimized. Default: False.
        remove_pooling (bool): If true, the max pooling operations in VGG net
            will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        self.names = NAMES[vgg_type.replace('_bn', '')]
        if 'bn' in vgg_type:
            self.names = insert_bn(self.names)

        # only borrow layers that will be used to avoid unused params
        max_idx = max((self.names.index(v) for v in layer_name_list), default=0)

        build_vgg = getattr(vgg, vgg_type)
        if os.path.exists(VGG_PRETRAIN_PATH):
            # Prefer the local checkpoint over the torchvision download.
            vgg_net = build_vgg(pretrained=False)
            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
            vgg_net.load_state_dict(state_dict)
        else:
            vgg_net = build_vgg(pretrained=True)

        features = vgg_net.features[:max_idx + 1]

        modified_net = OrderedDict()
        for k, v in zip(self.names, features):
            if 'pool' not in k:
                modified_net[k] = v
            elif not remove_pooling:
                # in some cases, we may want to change the default stride
                modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
            # otherwise remove_pooling is set and the pooling layer is dropped

        self.vgg_net = nn.Sequential(modified_net)

        # Train/eval mode and the requires_grad flags both follow `requires_grad`.
        trainable = bool(requires_grad)
        self.vgg_net.train(trainable)
        for param in self.parameters():
            param.requires_grad = trainable

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict: Maps each requested layer name to a copy of that layer's
                output.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        output = {}
        for key, layer in self.vgg_net._modules.items():
            x = layer(x)
            if key in self.layer_name_list:
                # clone so later in-place layers cannot alter the stored feature
                output[key] = x.clone()

        return output
|
basicsr/archs/vqgan_arch.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
| 3 |
+
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
| 4 |
+
|
| 5 |
+
'''
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import copy
|
| 11 |
+
from basicsr.utils import get_root_logger
|
| 12 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 13 |
+
|
| 14 |
+
def normalize(in_channels):
    """Return a 32-group affine ``GroupNorm`` (eps=1e-6) over ``in_channels`` channels."""
    group_norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    return group_norm
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@torch.jit.script
def swish(x):
    """Return the swish activation ``x * sigmoid(x)``, element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Define VQVAE classes
|
| 24 |
+
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantiser with a learnable codebook.

    Args:
        codebook_size: Number of codebook entries.
        emb_dim: Dimension of each codebook entry; must equal the channel
            count of the tensors passed to ``forward``.
        beta: Weight of the second term of the quantisation loss.
    """

    def __init__(self, codebook_size, emb_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        # NOTE: in forward(), beta scales the (z_q - z.detach()) term.
        self.beta = beta
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)

    def forward(self, z):
        """Quantise ``z`` to its nearest codebook entries.

        Args:
            z: Tensor in (batch, channel, height, width) layout with
                ``channel == emb_dim``.

        Returns:
            Tuple ``(z_q, loss, stats)``: the quantised tensor in the same
            layout as ``z`` (with a straight-through gradient to ``z``), the
            scalar quantisation loss, and a dict with ``perplexity``,
            ``min_encodings`` (one-hot rows), ``min_encoding_indices`` and
            ``mean_distance``.
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.emb_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
            2 * torch.matmul(z_flattened, self.embedding.weight.t())

        mean_distance = torch.mean(d)
        # find closest encodings
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)

        # One-hot rows are kept because they are returned in the stats dict.
        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # compute loss for embedding
        loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, {
            "perplexity": perplexity,
            "min_encodings": min_encodings,
            "min_encoding_indices": min_encoding_indices,
            "mean_distance": mean_distance
        }

    def get_codebook_feat(self, indices, shape):
        """Look up codebook entries for ``indices``.

        Args:
            indices: Integer tensor of codebook indices of any shape; it is
                flattened before the lookup.
            shape: ``(batch, height, width, channel)`` to reshape the result
                to before permuting it to channel-first, or ``None`` to
                return the flat ``(num_indices, emb_dim)`` matrix.

        Returns:
            The selected codebook vectors.
        """
        # A direct embedding lookup selects the same rows as multiplying a
        # dense one-hot matrix with the codebook, without materialising the
        # (num_indices, codebook_size) matrix and without forcing float32.
        z_q = self.embedding(indices.reshape(-1))

        if shape is not None:  # reshape back to match original input shape
            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()

        return z_q
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class GumbelQuantizer(nn.Module):
    """Quantiser that selects codebook entries with Gumbel-softmax sampling.

    Args:
        codebook_size: Number of codebook entries.
        emb_dim: Dimension of each codebook entry.
        num_hiddens: Channel count of the input feature map.
        straight_through: Whether to use hard (one-hot) samples while
            training; outside training, samples are always hard.
        kl_weight: Weight of the KL term returned as the loss.
        temp_init: Initial Gumbel-softmax temperature.
    """

    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
        super().__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        # projects last encoder layer to quantized logits
        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)
        self.embed = nn.Embedding(codebook_size, emb_dim)

    def forward(self, z):
        """Quantise the feature map ``z``.

        Returns:
            Tuple ``(z_q, diff, stats)``: the quantised feature map, the
            weighted KL term, and a dict holding ``min_encoding_indices``.
        """
        # Always sample hard one-hots outside training.
        hard = True
        if self.training:
            hard = self.straight_through

        logits = self.proj(z)
        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
        z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        kl_per_pixel = torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1)
        diff = self.kl_weight * kl_per_pixel.mean()

        stats = {
            "min_encoding_indices": soft_one_hot.argmax(dim=1)
        }
        return z_q, diff, stats
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class Downsample(nn.Module):
    """Halve the spatial size with a stride-2 3x3 convolution.

    Args:
        in_channels: Channel count of the input (the output has the same).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        """Zero-pad one pixel on the right and bottom, then convolve."""
        # The conv itself has no padding; pad asymmetrically by hand instead.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class Upsample(nn.Module):
    """Double the spatial size by nearest-neighbour interpolation plus a 3x3 conv.

    Args:
        in_channels: Channel count of the input (the output has the same).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Upsample ``x`` by a factor of 2 and convolve the result."""
        enlarged = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class ResBlock(nn.Module):
    """Residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a skip.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output. ``None`` (the default)
            means "same as ``in_channels``".
    """

    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # Build every layer from the resolved self.out_channels; using the raw
        # `out_channels` argument would pass None to the layers when the
        # default is used.
        self.norm1 = normalize(self.in_channels)
        self.conv1 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            # 1x1 projection so the skip path matches the output channel count.
            self.conv_out = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        """Return the residual branch output added to the (projected) input."""
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)

        return x + x_in
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    Queries, keys, values and the output projection are all 1x1 convolutions
    that keep the channel count; the result is added back onto the input.

    Args:
        in_channels: Channel count of the input (and of the output).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            # 1x1 convolution that keeps the channel count.
            return torch.nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )

        self.norm = normalize(in_channels)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, x):
        """Return ``x`` plus the projected attention output (same shape as ``x``)."""
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w).permute(0, 2, 1)   # b, hw, c
        k = k.reshape(b, c, h*w)                    # b, c, hw
        scores = torch.bmm(q, k) * (int(c)**(-0.5))
        weights = F.softmax(scores, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        attended = torch.bmm(v, weights.permute(0, 2, 1)).reshape(b, c, h, w)

        return x+self.proj_out(attended)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class Encoder(nn.Module):
    """Convolutional encoder mapping an image to an ``emb_dim``-channel feature map.

    The layer stack is: a 3x3 input convolution, then for every entry of
    ``ch_mult`` a group of ``num_res_blocks`` residual blocks (each followed
    by an attention block when the tracked resolution is listed in
    ``attn_resolutions``) and, except after the last group, a ``Downsample``
    layer; then a ResBlock / AttnBlock / ResBlock trio, a normalization
    layer and a 3x3 convolution to ``emb_dim`` channels.

    Args:
        in_channels (int): Channels of the input.
        nf (int): Base channel count; group ``i`` outputs ``nf * ch_mult[i]`` channels.
        emb_dim (int): Channels of the output.
        ch_mult (sequence of int): Channel multiplier per resolution level.
        num_res_blocks (int): Residual blocks per resolution level.
        resolution (int): Starting resolution used for the attention bookkeeping.
        attn_resolutions (container of int): Resolutions that get attention blocks.
    """

    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        curr_res = self.resolution
        in_ch_mult = (1,) + tuple(ch_mult)

        # initial convolution
        layers = [nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)]

        # residual and downsampling blocks, with attention on smaller res (16x16)
        last_level = self.num_resolutions - 1
        for level in range(self.num_resolutions):
            block_in_ch = nf * in_ch_mult[level]
            block_out_ch = nf * ch_mult[level]
            for _ in range(self.num_res_blocks):
                layers.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
                if curr_res in attn_resolutions:
                    layers.append(AttnBlock(block_in_ch))

            if level != last_level:
                layers.append(Downsample(block_in_ch))
                # bookkeeping only; presumably mirrors the size change done by Downsample
                curr_res = curr_res // 2

        # non-local attention block
        layers.extend([
            ResBlock(block_in_ch, block_in_ch),
            AttnBlock(block_in_ch),
            ResBlock(block_in_ch, block_in_ch),
        ])

        # normalise and convert to latent size
        layers.extend([
            normalize(block_in_ch),
            nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1),
        ])
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Run ``x`` through every layer of ``self.blocks`` in order."""
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class Generator(nn.Module):
    """Convolutional decoder mapping an ``emb_dim``-channel feature map to a 3-channel output.

    The layer stack is: a 3x3 input convolution to ``nf * ch_mult[-1]``
    channels, a ResBlock / AttnBlock / ResBlock trio, then for every entry
    of ``ch_mult`` (visited from last to first) a group of ``res_blocks``
    residual blocks (each followed by an attention block when the tracked
    resolution is listed in ``attn_resolutions``) and, except after the
    final group, an ``Upsample`` layer; finally a normalization layer and a
    3x3 convolution to 3 channels.

    Args:
        nf (int): Base channel count.
        emb_dim (int): Channels of the input.
        ch_mult (sequence of int): Channel multiplier per resolution level.
        res_blocks (int): Residual blocks per resolution level.
        img_size (int): Full resolution used for the attention bookkeeping.
        attn_resolutions (container of int): Resolutions that get attention blocks.
    """

    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.ch_mult = ch_mult
        self.num_resolutions = len(self.ch_mult)
        self.num_res_blocks = res_blocks
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.in_channels = emb_dim
        self.out_channels = 3

        block_in_ch = self.nf * self.ch_mult[-1]
        # resolution tracked for the attention bookkeeping, starting at the coarsest level
        curr_res = self.resolution // (2 ** (self.num_resolutions - 1))

        layers = [
            # initial conv
            nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1),
            # non-local attention block
            ResBlock(block_in_ch, block_in_ch),
            AttnBlock(block_in_ch),
            ResBlock(block_in_ch, block_in_ch),
        ]

        for level in range(self.num_resolutions - 1, -1, -1):
            block_out_ch = self.nf * self.ch_mult[level]

            for _ in range(self.num_res_blocks):
                layers.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch

                if curr_res in self.attn_resolutions:
                    layers.append(AttnBlock(block_in_ch))

            if level != 0:
                layers.append(Upsample(block_in_ch))
                curr_res = curr_res * 2

        layers.extend([
            normalize(block_in_ch),
            nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1),
        ])

        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Run ``x`` through every layer of ``self.blocks`` in order."""
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    """Encoder -> quantizer -> generator autoencoder.

    Args:
        img_size (int): Resolution handed to the encoder and the generator.
        nf (int): Base channel count of the encoder and the generator.
        ch_mult (sequence of int): Channel multiplier per resolution level.
        quantizer (str): ``"nearest"`` builds a ``VectorQuantizer``,
            ``"gumbel"`` builds a ``GumbelQuantizer``. Default: "nearest".
        res_blocks (int): Residual blocks per resolution level. Default: 2.
        attn_resolutions (list[int] | None): Resolutions that get attention
            blocks. ``None`` means ``[16]``. Default: None.
        codebook_size (int): Number of codebook entries. Default: 1024.
        emb_dim (int): Channels of the latent / codebook vectors. Default: 256.
        beta (float): Passed to ``VectorQuantizer`` ("nearest" only). Default: 0.25.
        gumbel_straight_through (bool): Passed to ``GumbelQuantizer``. Default: False.
        gumbel_kl_weight (float): Passed to ``GumbelQuantizer``. Default: 1e-8.
        model_path (str | None): Optional checkpoint to load. It must contain
            a ``'params_ema'`` or a ``'params'`` entry. Default: None.

    Raises:
        ValueError: If the checkpoint has neither ``'params_ema'`` nor ``'params'``.
    """

    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None,
                 codebook_size=1024, emb_dim=256, beta=0.25, gumbel_straight_through=False,
                 gumbel_kl_weight=1e-8, model_path=None):
        super().__init__()
        logger = get_root_logger()
        # A fresh list per instance instead of a shared mutable default argument.
        if attn_resolutions is None:
            attn_resolutions = [16]
        self.in_channels = 3
        self.nf = nf
        self.n_blocks = res_blocks
        self.codebook_size = codebook_size
        self.embed_dim = emb_dim
        self.ch_mult = ch_mult
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.quantizer_type = quantizer
        self.encoder = Encoder(
            self.in_channels,
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )
        if self.quantizer_type == "nearest":
            self.beta = beta  # 0.25
            self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
        elif self.quantizer_type == "gumbel":
            self.gumbel_num_hiddens = emb_dim
            self.straight_through = gumbel_straight_through
            self.kl_weight = gumbel_kl_weight
            self.quantize = GumbelQuantizer(
                self.codebook_size,
                self.embed_dim,
                self.gumbel_num_hiddens,
                self.straight_through,
                self.kl_weight
            )
        # NOTE(review): any other quantizer value leaves `self.quantize` undefined,
        # so `forward` would fail later; confirm whether that is intended.
        self.generator = Generator(
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )

        if model_path is not None:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            # The checkpoint is read from disk once and reused below.
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_ema' in chkpt:
                self.load_state_dict(chkpt['params_ema'])
                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
                logger.info(f'vqgan is loaded from: {model_path} [params]')
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        """Encode, quantize and decode ``x``.

        Returns:
            tuple: ``(reconstruction, codebook_loss, quant_stats)`` where the
            last two values come straight from the quantizer.
        """
        x = self.encoder(x)
        quant, codebook_loss, quant_stats = self.quantize(x)
        x = self.generator(quant)
        return x, codebook_loss, quant_stats
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# patch based discriminator
|
| 394 |
+
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
    """Patch based discriminator producing a 1-channel prediction map.

    The network is a stack of 4x4 convolutions: one stride-2 convolution
    with LeakyReLU, ``n_layers - 1`` stride-2 convolution / BatchNorm /
    LeakyReLU groups whose width doubles up to ``8 * ndf``, one stride-1
    group, and a final stride-1 convolution to a single channel.

    Args:
        nc (int): Channels of the input. Default: 3.
        ndf (int): Base number of filters. Default: 64.
        n_layers (int): Number of stride-2 convolutions. Default: 4.
        model_path (str | None): Optional checkpoint to load. It must contain
            a ``'params_d'`` or a ``'params'`` entry. Default: None.

    Raises:
        ValueError: If the checkpoint has neither ``'params_d'`` nor ``'params'``.
    """

    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
        super().__init__()

        layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
        ndf_mult = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            ndf_mult_prev = ndf_mult
            ndf_mult = min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(ndf * ndf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        ndf_mult_prev = ndf_mult
        ndf_mult = min(2 ** n_layers, 8)

        layers += [
            nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(ndf * ndf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        layers += [
            nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)]  # output 1 channel prediction map
        self.main = nn.Sequential(*layers)

        if model_path is not None:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            # The checkpoint is read from disk once and reused below.
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_d' in chkpt:
                self.load_state_dict(chkpt['params_d'])
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        """Return the prediction map computed by the convolution stack for ``x``."""
        return self.main(x)
|
basicsr/data/__init__.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
import torch
|
| 5 |
+
import torch.utils.data
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
from functools import partial
|
| 8 |
+
from os import path as osp
|
| 9 |
+
|
| 10 |
+
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
|
| 11 |
+
from basicsr.utils import get_root_logger, scandir
|
| 12 |
+
from basicsr.utils.dist_util import get_dist_info
|
| 13 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 14 |
+
|
| 15 |
+
__all__ = ['build_dataset', 'build_dataloader']

# Automatically scan and import the dataset modules for the registry:
# every entry found under this package's folder whose name ends with
# '_dataset.py' is imported as `basicsr.data.<module name>`.
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [
    osp.splitext(osp.basename(entry))[0]
    for entry in scandir(data_folder)
    if entry.endswith('_dataset.py')
]
# keep references to the imported dataset modules
_dataset_modules = [
    importlib.import_module(f'basicsr.data.{module_name}')
    for module_name in dataset_filenames
]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def build_dataset(dataset_opt):
    """Build dataset from options.

    Args:
        dataset_opt (dict): Configuration for dataset. It must contain:
            name (str): Dataset name.
            type (str): Dataset type.

    Returns:
        The dataset created by the class registered in ``DATASET_REGISTRY``
        under ``dataset_opt['type']``.
    """
    # The dataset class receives its own deep copy of the options.
    opt = deepcopy(dataset_opt)
    dataset_cls = DATASET_REGISTRY.get(opt['type'])
    dataset = dataset_cls(opt)
    get_root_logger().info(f'Dataset [{dataset.__class__.__name__}] - {opt["name"]} is built.')
    return dataset
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
    """Build dataloader.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train', 'val' or 'test'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
            pin_memory (bool, optional): Forwarded to the dataloader.
                Default: False.
            prefetch_mode (str, optional): 'cpu' selects ``PrefetchDataLoader``;
                any other value selects ``torch.utils.data.DataLoader``.
            num_prefetch_queue (int, optional): Queue size for the 'cpu'
                prefetch mode. Default: 1.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None

    Returns:
        The constructed dataloader.

    Raises:
        ValueError: If ``dataset_opt['phase']`` is not 'train', 'val' or 'test'.
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()

    if phase == 'train':
        if dist:
            # distributed training: the per-GPU values are used as they are
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']
        else:
            # non-distributed training: scale the per-GPU values by the GPU count
            multiplier = num_gpu if num_gpu != 0 else 1
            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
        dataloader_args = dict(
            dataset=dataset,
            batch_size=batch_size,
            # shuffle only when no sampler decides the order
            shuffle=sampler is None,
            num_workers=num_workers,
            sampler=sampler,
            drop_last=True)
        if seed is None:
            dataloader_args['worker_init_fn'] = None
        else:
            dataloader_args['worker_init_fn'] = partial(
                worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
    elif phase in ['val', 'test']:
        # validation / test: one sample at a time, loaded in the main process
        dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    else:
        raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode != 'cpu':
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**dataloader_args)

    # CPUPrefetcher
    num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
    get_root_logger().info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
    return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed the NumPy and Python global RNGs of one dataloader worker.

    The worker seed is ``num_workers * rank + worker_id + seed``.

    Args:
        worker_id (int): Index of the worker.
        num_workers (int): Number of workers.
        rank (int): Rank of the current process.
        seed (int): Base seed.
    """
    worker_offset = num_workers * rank + worker_id
    worker_seed = worker_offset + seed
    for seed_rng in (np.random.seed, random.seed):
        seed_rng(worker_seed)
|
basicsr/data/data_sampler.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data.sampler import Sampler
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler.
    Supports enlarging the dataset for iteration-based training, which saves
    time when the dataloader is restarted after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # samples per replica, rounded up so every replica gets the same count
        self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        rng = torch.Generator()
        rng.manual_seed(self.epoch)
        permutation = torch.randperm(self.total_size, generator=rng).tolist()

        # subsample: this replica takes every num_replicas-th entry starting
        # at its rank; positions are wrapped back into the dataset range
        dataset_size = len(self.dataset)
        own_slice = permutation[self.rank:self.total_size:self.num_replicas]
        own_indices = [position % dataset_size for position in own_slice]
        assert len(own_indices) == self.num_samples

        return iter(own_indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch, which seeds the shuffle performed in ``__iter__``."""
        self.epoch = epoch
|
basicsr/data/data_util.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from os import path as osp
|
| 6 |
+
from PIL import Image, ImageDraw
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
from basicsr.data.transforms import mod_crop
|
| 10 |
+
from basicsr.utils import img2tensor, scandir
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def read_img_seq(path, require_mod_crop=False, scale=1):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
            A folder is scanned and its entries are read in sorted order.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
    """
    img_paths = path if isinstance(path, list) else sorted(scandir(path, full_path=True))

    # read every image and divide by 255 as float32
    frames = [cv2.imread(img_path).astype(np.float32) / 255. for img_path in img_paths]
    if require_mod_crop:
        frames = [mod_crop(frame, scale) for frame in frames]

    tensors = img2tensor(frames, bgr2rgb=True, float32=True)
    return torch.stack(tensors, dim=0)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Generate an index list for reading `num_frames` frames from a sequence
    of images.

    The window is centered on `crt_idx`; positions that fall before the
    first frame or after the last frame are replaced according to `padding`.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Max number of the sequence of images (from 1).
        num_frames (int): Reading num_frames frames. Must be odd.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding mode:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    last_idx = max_frame_num - 1  # start from 0
    half_window = num_frames // 2
    window_start = crt_idx - half_window
    window_stop = crt_idx + half_window + 1

    def _resolve(position):
        """Map one window position to a frame index."""
        if 0 <= position <= last_idx:
            return position
        before_start = position < 0
        if padding == 'replicate':
            return 0 if before_start else last_idx
        if padding == 'reflection':
            return -position if before_start else last_idx * 2 - position
        if padding == 'reflection_circle':
            if before_start:
                return crt_idx + half_window - position
            return window_start - (position - last_idx)
        # 'circle'
        return num_frames + position if before_start else position - num_frames

    return [_resolve(position) for position in range(window_start, window_stop)]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    lq.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records
    1)image name (with extension),
    2)image shape,
    3)compression level, separated by a white space.
    Example: `baboon.png (120,125,3) 1`

    We use the image name without extension as the lmdb key.
    Note that we use the same key for the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[dict]: One dict per lmdb key (sorted), mapping
        ``'<input_key>_path'`` and ``'<gt_key>_path'`` to that lmdb key.

    Raises:
        ValueError: If a folder does not end with '.lmdb' or the two
            meta_info files list different keys.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    if not input_folder.endswith('.lmdb') or not gt_folder.endswith('.lmdb'):
        raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
                         f'formats. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')

    def _read_lmdb_keys(folder):
        """Return the text before the first '.' of every meta_info.txt line."""
        with open(osp.join(folder, 'meta_info.txt')) as fin:
            return [line.split('.')[0] for line in fin]

    # ensure that the two meta_info files are the same
    input_lmdb_keys = _read_lmdb_keys(input_folder)
    gt_lmdb_keys = _read_lmdb_keys(gt_folder)
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')

    return [
        {f'{input_key}_path': lmdb_key, f'{gt_key}_path': lmdb_key}
        for lmdb_key in sorted(input_lmdb_keys)
    ]
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from a meta information file.

    Each line in the meta information file contains the image name and,
    optionally, the image shape (usually for gt), separated by a white space.
    Blank lines are ignored.

    Example of a meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per listed image, mapping ``'<input_key>_path'``
        and ``'<gt_key>_path'`` to the corresponding file paths.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    with open(meta_info_file, 'r') as fin:
        # Strip before splitting: a line that holds only a file name would
        # otherwise keep its trailing newline and end up inside the paths.
        # Blank lines carry no name and are skipped.
        gt_names = [line.strip().split(' ')[0] for line in fin if line.strip()]

    paths = []
    for gt_name in gt_names:
        basename, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        paths.append({
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_name),
        })
    return paths
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Every entry found in the gt folder is paired with the input entry whose
    name is ``filename_tmpl.format(<gt basename>)`` plus the gt extension.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per gt entry, mapping ``'<input_key>_path'`` and
        ``'<gt_key>_path'`` to the corresponding file paths.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')

    # Membership is checked once per gt entry; a set keeps the whole pairing
    # linear instead of quadratic in the number of files.
    input_names = set(input_paths)
    paths = []
    for gt_path in gt_paths:
        basename, ext = osp.splitext(osp.basename(gt_path))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        assert input_name in input_names, f'{input_name} is not in {input_key}_paths.'
        paths.append({
            f'{input_key}_path': osp.join(input_folder, input_name),
            f'{gt_key}_path': osp.join(gt_folder, gt_path),
        })
    return paths
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def paths_from_folder(folder):
    """Generate paths from folder.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Every entry yielded by ``scandir(folder)``, joined onto
        ``folder``.
    """
    entries = list(scandir(folder))
    return [osp.join(folder, entry) for entry in entries]
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def paths_from_lmdb(folder):
    """Generate paths from lmdb.

    Args:
        folder (str): Folder path. It must end with '.lmdb' and contain a
            ``meta_info.txt`` file.

    Returns:
        list[str]: For every line of ``meta_info.txt``, the text before its
        first '.' (the image name without extension).

    Raises:
        ValueError: If ``folder`` does not end with '.lmdb'.
    """
    if not folder.endswith('.lmdb'):
        # The previous message ran the folder name into the following word.
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        paths = [line.split('.')[0] for line in fin]
    return paths
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The Gaussian kernel, of shape (kernel_size, kernel_size).
    """
    # Import from `scipy.ndimage` directly: the `scipy.ndimage.filters`
    # namespace is deprecated.
    from scipy.ndimage import gaussian_filter
    kernel = np.zeros((kernel_size, kernel_size))
    # set element at the middle to one, a dirac delta
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # gaussian-smooth the dirac, resulting in a gaussian filter
    return gaussian_filter(kernel, sigma)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def duf_downsample(x, kernel_size=13, scale=4):
    """Downsampling with Gaussian kernel used in the DUF official code.

    The frames are reflect-padded, convolved channel by channel with a
    Gaussian kernel of sigma ``0.4 * scale`` at stride ``scale``, and a
    border of two pixels is cropped from the result.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
            A 4-dim input is treated as a single batch entry and returned
            as a 4-dim tensor.
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    # a 4-dim input gets a temporary leading batch dimension
    squeeze_flag = x.ndim == 4
    if squeeze_flag:
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()

    # fold batch, time and channel together so every plane is filtered alone
    planes = x.view(-1, 1, h, w)
    pad = kernel_size // 2 + scale * 2
    planes = F.pad(planes, (pad, pad, pad, pad), 'reflect')

    kernel = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    kernel = torch.from_numpy(kernel).type_as(planes)[None, None]
    planes = F.conv2d(planes, kernel, stride=scale)
    planes = planes[:, :, 2:-2, 2:-2]

    out = planes.view(b, t, c, planes.size(2), planes.size(3))
    return out.squeeze(0) if squeeze_flag else out
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def brush_stroke_mask(img, color=(255, 255, 255)):
    """Paint one to three random brush strokes onto a PIL image.

    Each stroke is a polyline whose segments alternate direction around a
    mean angle. It is drawn with a random width, and a filled circle is
    stamped on every vertex so the joints are rounded.

    Args:
        img (PIL.Image.Image): Image to draw on. It is modified in place.
        color (tuple): Fill colour of the strokes. Default: (255, 255, 255).

    Returns:
        PIL.Image.Image: The same image object, with the strokes drawn on it.
    """
    min_num_vertex = 8
    max_num_vertex = 28
    mean_angle = 2 * math.pi / 5
    angle_range = 2 * math.pi / 12
    # Stroke width range for the large-mask-ratio training setting.
    # (A very large mask ratio, used for testing / late refinement, would be
    # min_width = 80 and max_width = 120.)
    min_width = 30
    max_width = 70

    def generate_mask(H, W, img=None):
        average_radius = math.sqrt(H * H + W * W) / 8
        # Draw directly on the supplied image; only fall back to a blank
        # canvas when none is given.
        mask = img if img is not None else Image.new('RGB', (W, H), 0)

        for _ in range(np.random.randint(1, 4)):
            num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
            angle_min = mean_angle - np.random.uniform(0, angle_range)
            angle_max = mean_angle + np.random.uniform(0, angle_range)
            angles = []
            for i in range(num_vertex):
                turn = np.random.uniform(angle_min, angle_max)
                # Even-numbered segments are mirrored so the stroke zig-zags.
                angles.append(2 * math.pi - turn if i % 2 == 0 else turn)

            # PIL reports size as (width, height); x must be bounded by the
            # width and y by the height, otherwise non-square images get
            # strokes confined to (or pushed outside) the wrong region.
            w, h = mask.size
            vertex = [(int(np.random.randint(0, w)), int(np.random.randint(0, h)))]
            for i in range(num_vertex):
                r = np.clip(
                    np.random.normal(loc=average_radius, scale=average_radius // 2),
                    0, 2 * average_radius)
                new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
                new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
                vertex.append((int(new_x), int(new_y)))

            draw = ImageDraw.Draw(mask)
            width = int(np.random.uniform(min_width, max_width))
            draw.line(vertex, fill=color, width=width)
            half = width // 2
            for vx, vy in vertex:
                draw.ellipse((vx - half, vy - half, vx + half, vy + half), fill=color)

        return mask

    width, height = img.size
    return generate_mask(height, width, img)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def random_ff_mask(shape, max_angle=10, max_len=100, max_width=70, times=10):
    """Generate a random free-form (brush stroke) mask.

    Between ``times - 5`` and ``times - 1`` strokes are drawn. Each stroke
    starts at a random pixel and consists of one to five connected line
    segments with random angle, length and thickness.

    Args:
        shape (tuple): (height, width) of the mask.
        max_angle (int): Exclusive upper bound of the random integer part of
            the segment angle. Default: 10.
        max_len (int): Segment lengths are drawn from
            [max_len - 10, max_len + 10). Default: 100.
        max_width (int): Brush widths are drawn from
            [max_width - 25, max_width + 5). Default: 70.
        times (int): Exclusive upper bound of the number of strokes.
            Default: 10.

    Returns:
        np.ndarray: float32 array of shape (height, width); 1.0 on the
            strokes and 0.0 elsewhere.

    Link:
        https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py
    """
    height = shape[0]
    width = shape[1]
    mask = np.zeros((height, width), np.float32)
    num_strokes = np.random.randint(times - 5, times)
    for i in range(num_strokes):
        # cv2.line takes points as (x, y) = (column, row). The row must be
        # bounded by the height and the column by the width; sampling them the
        # other way round misplaces strokes on non-square masks.
        row = np.random.randint(height)
        col = np.random.randint(width)
        for _ in range(1 + np.random.randint(5)):
            angle = 0.01 + np.random.randint(max_angle)
            if i % 2 == 0:
                angle = 2 * 3.1415926 - angle
            length = 10 + np.random.randint(max_len - 20, max_len)
            brush_w = 5 + np.random.randint(max_width - 30, max_width)
            # Plain Python ints (truncated toward zero) are passed to OpenCV.
            end_row = int(row + length * np.sin(angle))
            end_col = int(col + length * np.cos(angle))
            cv2.line(mask, (int(col), int(row)), (end_col, end_row), 1.0, int(brush_w))
            row, col = end_row, end_col
    return mask
|
basicsr/data/ffhq_blind_dataset.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import math
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os.path as osp
|
| 6 |
+
from scipy.io import loadmat
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import torch
|
| 9 |
+
import torch.utils.data as data
|
| 10 |
+
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
|
| 11 |
+
adjust_hue, adjust_saturation, normalize)
|
| 12 |
+
from basicsr.data import gaussian_kernels as gaussian_kernels
|
| 13 |
+
from basicsr.data.transforms import augment
|
| 14 |
+
from basicsr.data.data_util import paths_from_folder, brush_stroke_mask, random_ff_mask
|
| 15 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
| 16 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 17 |
+
|
| 18 |
+
@DATASET_REGISTRY.register()
class FFHQBlindDataset(data.Dataset):
    """Face dataset that pairs each ground-truth image with a synthetic input.

    The input ("in") image is produced on the fly in ``__getitem__``, either
    by a blur / downsample / noise / JPEG degradation pipeline or, when
    ``gen_inpaint_mask`` is set, by painting random brush strokes over the
    ground-truth image.

    Args:
        opt (dict): Dataset configuration. Keys read by this class:
            io_backend (dict): Arguments for ``FileClient``; must hold 'type'.
            dataroot_gt (str): Folder (or '.lmdb' directory) of GT images.
            use_hflip: Forwarded to ``augment`` as its ``hflip`` argument.
            gt_size, in_size (int): GT / input side length. Default: 512.
            mean, std (list): Normalization statistics. Default: 0.5 each.
            component_path (str): Optional file with facial component boxes.
            latent_gt_path (str): Optional file with latent codes.
            eye_enlarge_ratio, nose_enlarge_ratio, mouth_enlarge_ratio
                (float): Box enlargement factors. Default: 1.4, 1.1, 1.3.
            gen_inpaint_mask (bool): Use brush-stroke masking. Default: False.
            use_corrupt (bool): Use the degradation pipeline. Default: True.
            blur_kernel_size, blur_sigma, kernel_list, kernel_prob,
                downsample_range, noise_range, jpeg_range: Degradation
                settings; required when ``use_corrupt`` is true and
                ``gen_inpaint_mask`` is false.
            color_jitter_prob, color_jitter_pt_prob, color_jitter_shift,
                gray_prob: Optional colour augmentation settings.
            brightness, contrast, saturation, hue: Optional ranges for the
                tensor colour jitter.
    """

    def __init__(self, opt):
        super(FFHQBlindDataset, self).__init__()
        logger = get_root_logger()
        self.opt = opt
        # file client (io backend); created on first use in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.gt_size = opt.get('gt_size', 512)
        self.in_size = opt.get('in_size', 512)
        assert self.gt_size >= self.in_size, 'Wrong setting.'

        self.mean = opt.get('mean', [0.5, 0.5, 0.5])
        self.std = opt.get('std', [0.5, 0.5, 0.5])

        self.component_path = opt.get('component_path', None)
        self.latent_gt_path = opt.get('latent_gt_path', None)

        # NOTE: torch.load unpickles the file; only point these options at
        # trusted files.
        if self.component_path is not None:
            self.crop_components = True
            self.components_dict = torch.load(self.component_path)
            self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
            self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
            self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
        else:
            self.crop_components = False

        if self.latent_gt_path is not None:
            self.load_latent_gt = True
            self.latent_gt_dict = torch.load(self.latent_gt_path)
        else:
            self.load_latent_gt = False

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            self.paths = paths_from_folder(self.gt_folder)

        # inpainting mask
        self.gen_inpaint_mask = opt.get('gen_inpaint_mask', False)
        if self.gen_inpaint_mask:
            logger.info('generate mask ...')

        # perform corrupt
        self.use_corrupt = opt.get('use_corrupt', True)
        # Motion blur is hard-disabled; the 'use_motion_kernel' option is not read.
        self.use_motion_kernel = False

        if self.use_motion_kernel:
            self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
            motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
            self.motion_kernels = torch.load(motion_kernel_path)

        if self.use_corrupt and not self.gen_inpaint_mask:
            # degradation configurations
            self.blur_kernel_size = opt['blur_kernel_size']
            self.blur_sigma = opt['blur_sigma']
            self.kernel_list = opt['kernel_list']
            self.kernel_prob = opt['kernel_prob']
            self.downsample_range = opt['downsample_range']
            self.noise_range = opt['noise_range']
            self.jpeg_range = opt['jpeg_range']
            # print
            logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
            logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
            logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
            logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')

        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob', None)
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')

        # to gray
        self.gray_prob = opt.get('gray_prob', 0.0)
        if self.gray_prob is not None:
            logger.info(f'Use random gray. Prob: {self.gray_prob}')
        # The shift is configured on a 0-255 scale; images here are in [0, 1].
        self.color_jitter_shift /= 255.

    @staticmethod
    def color_jitter(img, shift):
        """jitter color: randomly jitter the RGB values, in numpy formats"""
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats

        The four adjustments are applied in a random order; an adjustment
        whose range is None is skipped.
        """
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def get_component_locations(self, name, status):
        """Return the facial component boxes of one sample.

        Args:
            name (str): Key of the sample in ``self.components_dict``.
            status (tuple): Augmentation status; ``status[0]`` tells whether
                the image was horizontally flipped.

        Returns:
            tuple[dict, dict]: Boxes for 'left_eye', 'right_eye', 'nose' and
                'mouth' as float tensors (x1, y1, x2, y2), first at GT
                resolution and then rescaled to the input resolution.
        """
        from copy import deepcopy

        parts = ('left_eye', 'right_eye', 'nose', 'mouth')
        # Work on a copy: the flip below edits the boxes in place, and doing
        # that on the cached entry would corrupt every later read of this sample.
        components_bbox = deepcopy(self.components_dict[name])
        if status[0]:  # hflip
            # exchange right and left eye
            components_bbox['left_eye'], components_bbox['right_eye'] = (
                components_bbox['right_eye'], components_bbox['left_eye'])
            # modify the width coordinate
            for part in parts:
                components_bbox[part][0] = self.gt_size - components_bbox[part][0]

        # True division keeps the ratio exact when gt_size is not an integer
        # multiple of in_size (floor division would round it down).
        gt_to_in = self.gt_size / self.in_size

        locations_gt = {}
        locations_in = {}
        for part in parts:
            mean = components_bbox[part][0:2]
            half_len = components_bbox[part][2]
            if 'eye' in part:
                half_len = half_len * self.eye_enlarge_ratio
            elif part == 'nose':
                half_len = half_len * self.nose_enlarge_ratio
            elif part == 'mouth':
                half_len = half_len * self.mouth_enlarge_ratio
            loc = np.hstack((mean - half_len + 1, mean + half_len))
            loc = torch.from_numpy(loc).float()
            locations_gt[part] = loc
            locations_in[part] = loc / gt_to_in
        return locations_gt, locations_in

    def __getitem__(self, index):
        """Load one GT image and build its input image.

        Returns:
            dict: 'in' and 'gt' (normalized tensors) and 'gt_path'; plus
                'locations_in' / 'locations_gt' when component boxes are
                loaded and 'latent_gt' when latent codes are loaded.
        """
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        # Strip the extension, whatever its length; lmdb keys carry none.
        name = osp.splitext(osp.basename(gt_path))[0]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)

        if self.load_latent_gt:
            if status[0]:
                latent_gt = self.latent_gt_dict['hflip'][name]
            else:
                latent_gt = self.latent_gt_dict['orig'][name]

        if self.crop_components:
            locations_gt, locations_in = self.get_component_locations(name, status)

        # generate in image
        img_in = img_gt
        if self.use_corrupt and not self.gen_inpaint_mask:
            # motion blur
            if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
                m_i = random.randint(0, 31)
                k = self.motion_kernels[f'{m_i:02d}']
                img_in = cv2.filter2D(img_in, -1, k)

            # gaussian blur
            kernel = gaussian_kernels.random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                self.blur_kernel_size,
                self.blur_sigma,
                self.blur_sigma,
                [-math.pi, math.pi],
                noise_range=None)
            img_in = cv2.filter2D(img_in, -1, kernel)

            # downsample
            scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
            img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)

            # noise
            if self.noise_range is not None:
                noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
                noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
                img_in = img_in + noise
                img_in = np.clip(img_in, 0, 1)

            # jpeg
            if self.jpeg_range is not None:
                jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
                encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
                _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
                img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.

            # resize to in_size
            img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)

        if self.gen_inpaint_mask:
            img_in = (img_in * 255).astype('uint8')
            img_in = brush_stroke_mask(Image.fromarray(img_in))
            # Stay in float32: dividing the uint8 array directly yields
            # float64, which cv2.cvtColor in the gray branch does not accept.
            img_in = np.array(img_in).astype(np.float32) / 255.

        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_in = self.color_jitter(img_in, self.color_jitter_shift)
        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
            img_in = np.tile(img_in[:, :, None], [1, 1, 3])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True)

        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)

        # round and clip; img_in is a torch tensor here, so clamp with torch
        # rather than relying on numpy dispatching to the tensor
        img_in = torch.clamp((img_in * 255.0).round(), 0, 255) / 255.

        # Set vgg range_norm=True if use the normalization here
        # normalize
        normalize(img_in, self.mean, self.std, inplace=True)
        normalize(img_gt, self.mean, self.std, inplace=True)

        return_dict = {'in': img_in, 'gt': img_gt, 'gt_path': gt_path}

        if self.crop_components:
            return_dict['locations_in'] = locations_in
            return_dict['locations_gt'] = locations_gt

        if self.load_latent_gt:
            return_dict['latent_gt'] = latent_gt

        return return_dict

    def __len__(self):
        """Return the number of ground-truth images."""
        return len(self.paths)
|
basicsr/data/ffhq_blind_joint_dataset.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import math
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os.path as osp
|
| 6 |
+
from scipy.io import loadmat
|
| 7 |
+
import torch
|
| 8 |
+
import torch.utils.data as data
|
| 9 |
+
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
|
| 10 |
+
adjust_hue, adjust_saturation, normalize)
|
| 11 |
+
from basicsr.data import gaussian_kernels as gaussian_kernels
|
| 12 |
+
from basicsr.data.transforms import augment
|
| 13 |
+
from basicsr.data.data_util import paths_from_folder
|
| 14 |
+
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
|
| 15 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 16 |
+
|
| 17 |
+
@DATASET_REGISTRY.register()
|
| 18 |
+
class FFHQBlindJointDataset(data.Dataset):
|
| 19 |
+
|
| 20 |
+
def __init__(self, opt):
    """Read the dataset configuration and index the ground-truth images.

    Args:
        opt (dict): Dataset configuration. 'io_backend' and 'dataroot_gt'
            are required; the degradation keys (small and ``*_large``
            variants) are required when 'use_corrupt' is true. All other
            keys fall back to the defaults visible below.
    """
    super(FFHQBlindJointDataset, self).__init__()
    log = get_root_logger()
    self.opt = opt

    # The file client (io backend) is created on first use in __getitem__.
    self.file_client = None
    self.io_backend_opt = opt['io_backend']

    self.gt_folder = opt['dataroot_gt']
    self.gt_size = opt.get('gt_size', 512)
    self.in_size = opt.get('in_size', 512)
    assert self.gt_size >= self.in_size, 'Wrong setting.'

    self.mean = opt.get('mean', [0.5, 0.5, 0.5])
    self.std = opt.get('std', [0.5, 0.5, 0.5])

    self.component_path = opt.get('component_path', None)
    self.latent_gt_path = opt.get('latent_gt_path', None)

    # Facial component boxes are optional.
    self.crop_components = self.component_path is not None
    if self.crop_components:
        self.components_dict = torch.load(self.component_path)
        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
        self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
        self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)

    # Latent codes are optional as well.
    self.load_latent_gt = self.latent_gt_path is not None
    if self.load_latent_gt:
        self.latent_gt_dict = torch.load(self.latent_gt_path)

    if self.io_backend_opt['type'] != 'lmdb':
        self.paths = paths_from_folder(self.gt_folder)
    else:
        self.io_backend_opt['db_paths'] = self.gt_folder
        if not self.gt_folder.endswith('.lmdb'):
            raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
        with open(osp.join(self.gt_folder, 'meta_info.txt')) as meta:
            self.paths = [entry.partition('.')[0] for entry in meta]

    # perform corrupt
    self.use_corrupt = opt.get('use_corrupt', True)
    # Motion blur is hard-disabled; the 'use_motion_kernel' option is not read.
    self.use_motion_kernel = False

    if self.use_motion_kernel:
        self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
        kernel_file = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
        self.motion_kernels = torch.load(kernel_file)

    if self.use_corrupt:
        # Degradation configuration: shared kernel settings, then the small
        # and the large degradation ranges. Every key is mandatory.
        for key in (
                'blur_kernel_size', 'kernel_list', 'kernel_prob',
                'blur_sigma', 'downsample_range', 'noise_range', 'jpeg_range',
                'blur_sigma_large', 'downsample_range_large',
                'noise_range_large', 'jpeg_range_large'):
            setattr(self, key, self.opt[key])

        # Report the small-degradation ranges.
        sigma_txt = ", ".join(map(str, self.blur_sigma))
        down_txt = ", ".join(map(str, self.downsample_range))
        noise_txt = ", ".join(map(str, self.noise_range))
        jpeg_txt = ", ".join(map(str, self.jpeg_range))
        log.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{sigma_txt}]')
        log.info(f'Downsample: downsample_range [{down_txt}]')
        log.info(f'Noise: [{noise_txt}]')
        log.info(f'JPEG compression: [{jpeg_txt}]')

    # color jitter
    self.color_jitter_prob = opt.get('color_jitter_prob', None)
    self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
    self.color_jitter_shift = opt.get('color_jitter_shift', 20)
    if self.color_jitter_prob is not None:
        log.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')

    # to gray
    self.gray_prob = opt.get('gray_prob', 0.0)
    if self.gray_prob is not None:
        log.info(f'Use random gray. Prob: {self.gray_prob}')

    # The shift is configured on a 0-255 scale; rescale it by 1/255.
    self.color_jitter_shift = self.color_jitter_shift / 255.
|
| 107 |
+
|
| 108 |
+
@staticmethod
|
| 109 |
+
def color_jitter(img, shift):
|
| 110 |
+
"""jitter color: randomly jitter the RGB values, in numpy formats"""
|
| 111 |
+
jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
|
| 112 |
+
img = img + jitter_val
|
| 113 |
+
img = np.clip(img, 0, 1)
|
| 114 |
+
return img
|
| 115 |
+
|
| 116 |
+
@staticmethod
|
| 117 |
+
def color_jitter_pt(img, brightness, contrast, saturation, hue):
|
| 118 |
+
"""jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
|
| 119 |
+
fn_idx = torch.randperm(4)
|
| 120 |
+
for fn_id in fn_idx:
|
| 121 |
+
if fn_id == 0 and brightness is not None:
|
| 122 |
+
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
|
| 123 |
+
img = adjust_brightness(img, brightness_factor)
|
| 124 |
+
|
| 125 |
+
if fn_id == 1 and contrast is not None:
|
| 126 |
+
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
|
| 127 |
+
img = adjust_contrast(img, contrast_factor)
|
| 128 |
+
|
| 129 |
+
if fn_id == 2 and saturation is not None:
|
| 130 |
+
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
|
| 131 |
+
img = adjust_saturation(img, saturation_factor)
|
| 132 |
+
|
| 133 |
+
if fn_id == 3 and hue is not None:
|
| 134 |
+
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
|
| 135 |
+
img = adjust_hue(img, hue_factor)
|
| 136 |
+
return img
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_component_locations(self, name, status):
|
| 140 |
+
components_bbox = self.components_dict[name]
|
| 141 |
+
if status[0]: # hflip
|
| 142 |
+
# exchange right and left eye
|
| 143 |
+
tmp = components_bbox['left_eye']
|
| 144 |
+
components_bbox['left_eye'] = components_bbox['right_eye']
|
| 145 |
+
components_bbox['right_eye'] = tmp
|
| 146 |
+
# modify the width coordinate
|
| 147 |
+
components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0]
|
| 148 |
+
components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0]
|
| 149 |
+
components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0]
|
| 150 |
+
components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0]
|
| 151 |
+
|
| 152 |
+
locations_gt = {}
|
| 153 |
+
locations_in = {}
|
| 154 |
+
for part in ['left_eye', 'right_eye', 'nose', 'mouth']:
|
| 155 |
+
mean = components_bbox[part][0:2]
|
| 156 |
+
half_len = components_bbox[part][2]
|
| 157 |
+
if 'eye' in part:
|
| 158 |
+
half_len *= self.eye_enlarge_ratio
|
| 159 |
+
elif part == 'nose':
|
| 160 |
+
half_len *= self.nose_enlarge_ratio
|
| 161 |
+
elif part == 'mouth':
|
| 162 |
+
half_len *= self.mouth_enlarge_ratio
|
| 163 |
+
loc = np.hstack((mean - half_len + 1, mean + half_len))
|
| 164 |
+
loc = torch.from_numpy(loc).float()
|
| 165 |
+
locations_gt[part] = loc
|
| 166 |
+
loc_in = loc/(self.gt_size//self.in_size)
|
| 167 |
+
locations_in[part] = loc_in
|
| 168 |
+
return locations_gt, locations_in
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def __getitem__(self, index):
    """Build one sample: the GT face plus two degraded copies of it.

    Args:
        index (int): Position in ``self.paths``.

    Returns:
        dict: 'in' (degraded with the regular settings), 'in_large_de'
        (degraded with the ``*_large`` settings), 'gt' and 'gt_path'.
        'locations_in' / 'locations_gt' are added when
        ``self.crop_components`` is set, and 'latent_gt' when
        ``self.load_latent_gt`` is set.
    """
    if self.file_client is None:
        self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

    # load gt image
    gt_path = self.paths[index]
    # Drops the last four characters of the file name (e.g. '.png').
    name = osp.basename(gt_path)[:-4]
    img_bytes = self.file_client.get(gt_path)
    img_gt = imfrombytes(img_bytes, float32=True)

    # random horizontal flip
    img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)

    if self.load_latent_gt:
        # The latent codes are stored separately for flipped and unflipped images.
        latent_key = 'hflip' if status[0] else 'orig'
        latent_gt = self.latent_gt_dict[latent_key][name]

    if self.crop_components:
        locations_gt, locations_in = self.get_component_locations(name, status)

    # generate in image and in_large (large degradation) image
    img_in = img_gt
    img_in_large = img_gt
    if self.use_corrupt:
        # Regular pipeline first, large pipeline second: this keeps the order
        # of the random draws the same as running the two pipelines inline.
        img_in = self._degrade(
            img_gt, self.blur_sigma, self.downsample_range, self.noise_range, self.jpeg_range)
        img_in_large = self._degrade(
            img_gt, self.blur_sigma_large, self.downsample_range_large,
            self.noise_range_large, self.jpeg_range_large)

    # random color jitter (only for lq)
    if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
        img_in = self.color_jitter(img_in, self.color_jitter_shift)
        img_in_large = self.color_jitter(img_in_large, self.color_jitter_shift)
    # random to gray (only for lq)
    if self.gray_prob and np.random.uniform() < self.gray_prob:
        img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
        img_in = np.tile(img_in[:, :, None], [1, 1, 3])
        img_in_large = cv2.cvtColor(img_in_large, cv2.COLOR_BGR2GRAY)
        img_in_large = np.tile(img_in_large[:, :, None], [1, 1, 3])

    # BGR to RGB, HWC to CHW, numpy to tensor
    img_in, img_in_large, img_gt = img2tensor([img_in, img_in_large, img_gt], bgr2rgb=True, float32=True)

    # random color jitter (pytorch version) (only for lq)
    if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
        brightness = self.opt.get('brightness', (0.5, 1.5))
        contrast = self.opt.get('contrast', (0.5, 1.5))
        saturation = self.opt.get('saturation', (0, 1.5))
        hue = self.opt.get('hue', (-0.1, 0.1))
        img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
        img_in_large = self.color_jitter_pt(img_in_large, brightness, contrast, saturation, hue)

    # round and clip; the images are torch tensors at this point, so clamp
    # with torch instead of routing a tensor through np.clip
    img_in = torch.clamp((img_in * 255.0).round(), 0, 255) / 255.
    img_in_large = torch.clamp((img_in_large * 255.0).round(), 0, 255) / 255.

    # Set vgg range_norm=True if use the normalization here
    # normalize
    normalize(img_in, self.mean, self.std, inplace=True)
    normalize(img_in_large, self.mean, self.std, inplace=True)
    normalize(img_gt, self.mean, self.std, inplace=True)

    return_dict = {'in': img_in, 'in_large_de': img_in_large, 'gt': img_gt, 'gt_path': gt_path}

    if self.crop_components:
        return_dict['locations_in'] = locations_in
        return_dict['locations_gt'] = locations_gt

    if self.load_latent_gt:
        return_dict['latent_gt'] = latent_gt

    return return_dict

def _degrade(self, img, blur_sigma, downsample_range, noise_range, jpeg_range):
    """Blur, downsample, add noise to and JPEG-compress ``img``, then resize it to ``in_size``.

    Args:
        img (ndarray): Image to degrade; it is not modified in place.
        blur_sigma: Sigma range passed for both axes of the random blur kernel.
        downsample_range: ``(low, high)`` bounds of the random downscale factor.
        noise_range: ``(low, high)`` noise sigma bounds on a 0-255 scale, or
            None to skip the noise step.
        jpeg_range: ``(low, high)`` JPEG quality bounds, or None to skip the
            compression step.

    Returns:
        ndarray: Degraded image of size ``(in_size, in_size)``.
    """
    # motion blur
    if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
        m_i = random.randint(0, 31)
        k = self.motion_kernels[f'{m_i:02d}']
        img = cv2.filter2D(img, -1, k)

    # gaussian blur
    kernel = gaussian_kernels.random_mixed_kernels(
        self.kernel_list,
        self.kernel_prob,
        self.blur_kernel_size,
        blur_sigma,
        blur_sigma,
        [-math.pi, math.pi],
        noise_range=None)
    img = cv2.filter2D(img, -1, kernel)

    # downsample
    scale = np.random.uniform(downsample_range[0], downsample_range[1])
    down_size = int(self.gt_size // scale)
    img = cv2.resize(img, (down_size, down_size), interpolation=cv2.INTER_LINEAR)

    # noise
    if noise_range is not None:
        noise_sigma = np.random.uniform(noise_range[0] / 255., noise_range[1] / 255.)
        noise = np.float32(np.random.randn(*(img.shape))) * noise_sigma
        img = img + noise
        img = np.clip(img, 0, 1)

    # jpeg
    if jpeg_range is not None:
        jpeg_p = np.random.uniform(jpeg_range[0], jpeg_range[1])
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
        _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
        img = np.float32(cv2.imdecode(encimg, 1)) / 255.

    # resize to in_size
    return cv2.resize(img, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def __len__(self):
    """Return the number of image paths held by the dataset."""
    num_samples = len(self.paths)
    return num_samples
|
basicsr/data/gaussian_kernels.py
ADDED
|
@@ -0,0 +1,690 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
from scipy.ndimage.interpolation import shift
|
| 5 |
+
from scipy.stats import multivariate_normal
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def sigma_matrix2(sig_x, sig_y, theta):
    """Build the covariance matrix of a rotated two-dimensional Gaussian.

    Args:
        sig_x (float): Standard deviation along the first axis before rotation.
        sig_y (float): Standard deviation along the second axis before rotation.
        theta (float): Rotation angle in radians.

    Returns:
        ndarray: The (2, 2) matrix ``U D U^T`` with
        ``D = diag(sig_x**2, sig_y**2)`` and ``U`` the rotation by ``theta``.
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]])
    variances = np.array([[sig_x**2, 0], [0, sig_y**2]])
    rotated_back = np.dot(variances, rotation.T)
    return np.dot(rotation, rotated_back)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def mesh_grid(kernel_size):
    """Create coordinate grids for a square kernel, centred at zero.

    Args:
        kernel_size (int): Side length of the kernel.

    Returns:
        xy (ndarray): shape (kernel_size, kernel_size, 2); ``xy[i, j]`` is
            the pair ``(xx[i, j], yy[i, j])``.
        xx (ndarray): shape (kernel_size, kernel_size), varies along columns.
        yy (ndarray): shape (kernel_size, kernel_size), varies along rows.
    """
    offsets = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    xx, yy = np.meshgrid(offsets, offsets)
    # Pair the two coordinate planes along a trailing axis.
    xy = np.stack((xx, yy), axis=-1)
    return xy, xx, yy
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def pdf2(sigma_matrix, grid):
    """Evaluate an un-normalized bivariate Gaussian density on a grid.

    Args:
        sigma_matrix (ndarray): Covariance matrix with the shape (2, 2).
        grid (ndarray): Coordinates with the shape (K, K, 2), as produced by
            :func:`mesh_grid`; K is the kernel size.

    Returns:
        kernel (ndarray): ``exp(-0.5 * x^T Sigma^-1 x)`` at every grid point
        (no normalization constant is applied).
    """
    precision = np.linalg.inv(sigma_matrix)
    # Quadratic form x^T Sigma^-1 x, reduced over the coordinate axis.
    quad_form = np.sum(np.dot(grid, precision) * grid, 2)
    return np.exp(-0.5 * quad_form)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def cdf2(D, grid):
    """Evaluate the standard bivariate normal CDF on a linearly mapped grid.

    Used as the skewing factor of the skew Gaussian kernel.

    Args:
        D (ndarray): Skew matrix; every grid point is multiplied by it.
        grid (ndarray): Coordinates with the shape (K, K, 2), as produced by
            :func:`mesh_grid`; K is the kernel size.

    Returns:
        cdf (ndarray): CDF of N(0, I) evaluated at ``grid . D``.
    """
    skewed_points = np.dot(grid, D)
    standard_normal = multivariate_normal([0, 0], [[1, 0], [0, 1]])
    return standard_normal.cdf(skewed_points)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=None):
    """Generate a normalized bivariate skew Gaussian kernel.

    The kernel is the product of a rotated Gaussian density and a standard
    normal CDF evaluated on the grid mapped through ``D``. See
    `A multivariate skew normal distribution`_.

    Args:
        kernel_size (int): Side length of the kernel.
        sig_x (float): Standard deviation along the first axis.
        sig_y (float): Standard deviation along the second axis.
        theta (float): Rotation angle in radians.
        D (ndarray): Skew matrix.
        grid (ndarray, optional): Coordinates with the shape (K, K, 2), as
            produced by :func:`mesh_grid`. Built from ``kernel_size`` when
            None. Default: None

    Returns:
        kernel (ndarray): Kernel whose entries sum to one.

    .. _A multivariate skew normal distribution:
        https://www.sciencedirect.com/science/article/pii/S0047259X03001313
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    gaussian_part = pdf2(sigma_matrix2(sig_x, sig_y, theta), grid)
    skew_part = cdf2(D, grid)
    unnormalized = gaussian_part * skew_part
    return unnormalized / np.sum(unnormalized)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def mass_center_shift(kernel_size, kernel):
    """Measure how far a kernel's mass center lies from the grid center.

    Args:
        kernel_size (int): Side length of the kernel.
        kernel (ndarray): Normalized kernel.

    Returns:
        delta_h (float): Offset along the rows (axis 0).
        delta_w (float): Offset along the columns (axis 1).
    """
    offsets = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    # Marginal mass per row / per column, weighted by the centred offsets.
    mass_per_row = np.sum(kernel, axis=1)
    mass_per_col = np.sum(kernel, axis=0)
    return np.dot(mass_per_row, offsets), np.dot(mass_per_col, offsets)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def bivariate_skew_Gaussian_center(kernel_size,
                                   sig_x,
                                   sig_y,
                                   theta,
                                   D,
                                   grid=None):
    """Generate a bivariate skew Gaussian kernel moved back to the center.

    The skew kernel is shifted by the negative of its mass-center offset,
    using nearest padding, and normalized again.

    Args:
        kernel_size (int): Side length of the kernel.
        sig_x (float): Standard deviation along the first axis.
        sig_y (float): Standard deviation along the second axis.
        theta (float): Rotation angle in radians.
        D (ndarray): Skew matrix.
        grid (ndarray, optional): Coordinates with the shape (K, K, 2), as
            produced by :func:`mesh_grid`. Built from ``kernel_size`` when
            None. Default: None

    Returns:
        kernel (ndarray): Centered kernel whose entries sum to one.
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    skewed = bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid)
    offset_h, offset_w = mass_center_shift(kernel_size, skewed)
    recentered = shift(skewed, [-offset_h, -offset_w], mode='nearest')
    return recentered / np.sum(recentered)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def bivariate_anisotropic_Gaussian(kernel_size,
                                   sig_x,
                                   sig_y,
                                   theta,
                                   grid=None):
    """Generate a normalized bivariate anisotropic Gaussian kernel.

    Args:
        kernel_size (int): Side length of the kernel.
        sig_x (float): Standard deviation along the first axis.
        sig_y (float): Standard deviation along the second axis.
        theta (float): Rotation angle in radians.
        grid (ndarray, optional): Coordinates with the shape (K, K, 2), as
            produced by :func:`mesh_grid`. Built from ``kernel_size`` when
            None. Default: None

    Returns:
        kernel (ndarray): Kernel whose entries sum to one.
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    covariance = sigma_matrix2(sig_x, sig_y, theta)
    unnormalized = pdf2(covariance, grid)
    return unnormalized / np.sum(unnormalized)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None):
    """Generate a normalized bivariate isotropic Gaussian kernel.

    Args:
        kernel_size (int): Side length of the kernel.
        sig (float): Standard deviation shared by both axes.
        grid (ndarray, optional): Coordinates with the shape (K, K, 2), as
            produced by :func:`mesh_grid`. Built from ``kernel_size`` when
            None. Default: None

    Returns:
        kernel (ndarray): Kernel whose entries sum to one.
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    variance = sig**2
    covariance = np.array([[variance, 0], [0, variance]])
    unnormalized = pdf2(covariance, grid)
    return unnormalized / np.sum(unnormalized)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def bivariate_generalized_Gaussian(kernel_size,
                                   sig_x,
                                   sig_y,
                                   theta,
                                   beta,
                                   grid=None):
    """Generate a normalized bivariate generalized Gaussian kernel.

    The kernel is ``exp(-0.5 * q**beta)`` with ``q = x^T Sigma^-1 x``.
    Described in `Parameter Estimation For Multivariate Generalized Gaussian
    Distributions`_ by Pascal et. al (2013).

    Args:
        kernel_size (int): Side length of the kernel.
        sig_x (float): Standard deviation along the first axis.
        sig_y (float): Standard deviation along the second axis.
        theta (float): Rotation angle in radians.
        beta (float): Shape parameter; beta = 1 gives the normal distribution.
        grid (ndarray, optional): Coordinates with the shape (K, K, 2), as
            produced by :func:`mesh_grid`. Built from ``kernel_size`` when
            None. Default: None

    Returns:
        kernel (ndarray): Kernel whose entries sum to one.

    .. _Parameter Estimation For Multivariate Generalized Gaussian Distributions:
        https://arxiv.org/abs/1302.6498
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    precision = np.linalg.inv(sigma_matrix2(sig_x, sig_y, theta))
    quad_form = np.sum(np.dot(grid, precision) * grid, 2)
    unnormalized = np.exp(-0.5 * np.power(quad_form, beta))
    return unnormalized / np.sum(unnormalized)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid=None):
    """Generate a normalized plateau-like anisotropic kernel.

    The un-normalized kernel is ``1 / (1 + q**beta)`` with
    ``q = x^T Sigma^-1 x``.

    Args:
        kernel_size (int): Side length of the kernel.
        sig_x (float): Scale along the first axis.
        sig_y (float): Scale along the second axis.
        theta (float): Rotation angle in radians.
        beta (float): Shape parameter (the exponent applied to ``q``).
        grid (ndarray, optional): Coordinates with the shape (K, K, 2), as
            produced by :func:`mesh_grid`. Built from ``kernel_size`` when
            None. Default: None

    Returns:
        kernel (ndarray): Kernel whose entries sum to one.
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    precision = np.linalg.inv(sigma_matrix2(sig_x, sig_y, theta))
    quad_form = np.sum(np.dot(grid, precision) * grid, 2)
    unnormalized = np.reciprocal(np.power(quad_form, beta) + 1)
    return unnormalized / np.sum(unnormalized)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None):
    """Generate a normalized plateau-like isotropic kernel.

    The un-normalized kernel is ``1 / (1 + q**beta)`` with
    ``q = x^T Sigma^-1 x`` and ``Sigma = diag(sig**2, sig**2)``.

    Args:
        kernel_size (int): Side length of the kernel.
        sig (float): Scale shared by both axes.
        beta (float): Shape parameter (the exponent applied to ``q``).
        grid (ndarray, optional): Coordinates with the shape (K, K, 2), as
            produced by :func:`mesh_grid`. Built from ``kernel_size`` when
            None. Default: None

    Returns:
        kernel (ndarray): Kernel whose entries sum to one.
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    variance = sig**2
    precision = np.linalg.inv(np.array([[variance, 0], [0, variance]]))
    quad_form = np.sum(np.dot(grid, precision) * grid, 2)
    unnormalized = np.reciprocal(np.power(quad_form, beta) + 1)
    return unnormalized / np.sum(unnormalized)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def random_bivariate_skew_Gaussian_center(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          noise_range=None,
                                          strict=False):
    """Randomly generate a centered bivariate skew Gaussian kernel.

    Args:
        kernel_size (int): Side length of the kernel; must be odd.
        sigma_x_range (tuple): Bounds for sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): Bounds for sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): Bounds for the angle, e.g. [-math.pi, math.pi].
        noise_range (tuple, optional): Bounds of the multiplicative kernel
            noise, e.g. [0.75, 1.25]. Default: None
        strict (bool): When True, sigma_x is forced to be the larger of the
            two sigmas and the sampled parameters are returned as well.
            Default: False

    Returns:
        kernel (ndarray), or the tuple
        ``(kernel, sigma_x, sigma_y, rotation, D)`` when ``strict`` is True.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'

    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        drawn = [sigma_x, sigma_y]
        sigma_x, sigma_y = np.max(drawn), np.min(drawn)
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    # The skew entries are bounded by 3 / (largest sigma) and drawn row by row.
    thres = 3 / np.max([sigma_x, sigma_y])
    D = [[np.random.uniform(-thres, thres) for _ in range(2)]
         for _ in range(2)]

    kernel = bivariate_skew_Gaussian_center(kernel_size, sigma_x, sigma_y,
                                            rotation, D)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    if not strict:
        return kernel
    return kernel, sigma_x, sigma_y, rotation, D
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def random_bivariate_anisotropic_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          noise_range=None,
                                          strict=False):
    """Randomly generate a bivariate anisotropic Gaussian kernel.

    Args:
        kernel_size (int): Side length of the kernel; must be odd.
        sigma_x_range (tuple): Bounds for sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): Bounds for sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): Bounds for the angle, e.g. [-math.pi, math.pi].
        noise_range (tuple, optional): Bounds of the multiplicative kernel
            noise, e.g. [0.75, 1.25]. Default: None
        strict (bool): When True, sigma_x is forced to be the larger of the
            two sigmas and the sampled parameters are returned as well.
            Default: False

    Returns:
        kernel (ndarray), or the tuple
        ``(kernel, sigma_x, sigma_y, rotation)`` when ``strict`` is True.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'

    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        drawn = [sigma_x, sigma_y]
        sigma_x, sigma_y = np.max(drawn), np.min(drawn)
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    kernel = bivariate_anisotropic_Gaussian(kernel_size, sigma_x, sigma_y,
                                            rotation)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    if not strict:
        return kernel
    return kernel, sigma_x, sigma_y, rotation
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def random_bivariate_isotropic_Gaussian(kernel_size,
                                        sigma_range,
                                        noise_range=None,
                                        strict=False):
    """Randomly generate a bivariate isotropic Gaussian kernel.

    Args:
        kernel_size (int): Side length of the kernel; must be odd.
        sigma_range (tuple): Bounds for sigma, e.g. [0.6, 5].
        noise_range (tuple, optional): Bounds of the multiplicative kernel
            noise, e.g. [0.75, 1.25]. Default: None
        strict (bool): When True, the sampled sigma is returned together with
            the kernel. Default: False

    Returns:
        kernel (ndarray), or the tuple ``(kernel, sigma)`` when ``strict``
        is True.

    Raises:
        AssertionError: If ``kernel_size`` is even or a range is not
            strictly increasing.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    # The message names the parameter that is actually being checked.
    assert sigma_range[0] < sigma_range[1], 'Wrong sigma_range.'
    sigma = np.random.uniform(sigma_range[0], sigma_range[1])

    kernel = bivariate_isotropic_Gaussian(kernel_size, sigma)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    if strict:
        return kernel, sigma
    return kernel
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def random_bivariate_generalized_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          beta_range,
                                          noise_range=None,
                                          strict=False):
    """Randomly generate a bivariate generalized Gaussian kernel.

    Args:
        kernel_size (int): Side length of the kernel; must be odd.
        sigma_x_range (tuple): Bounds for sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): Bounds for sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): Bounds for the angle, e.g. [-math.pi, math.pi].
        beta_range (tuple): Bounds for beta, e.g. [0.5, 8]. With equal
            probability beta is drawn from [beta_range[0], 1] or from
            [1, beta_range[1]].
        noise_range (tuple, optional): Bounds of the multiplicative kernel
            noise, e.g. [0.75, 1.25]. Default: None
        strict (bool): When True, sigma_x is forced to be the larger of the
            two sigmas and the sampled parameters are returned as well.
            Default: False

    Returns:
        kernel (ndarray), or the tuple
        ``(kernel, sigma_x, sigma_y, rotation, beta)`` when ``strict`` is True.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'

    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        drawn = [sigma_x, sigma_y]
        sigma_x, sigma_y = np.max(drawn), np.min(drawn)
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    # Coin flip first, then beta from the half of the range it selected.
    if np.random.uniform() < 0.5:
        beta_low, beta_high = beta_range[0], 1
    else:
        beta_low, beta_high = 1, beta_range[1]
    beta = np.random.uniform(beta_low, beta_high)

    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y,
                                            rotation, beta)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    if not strict:
        return kernel
    return kernel, sigma_x, sigma_y, rotation, beta
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def random_bivariate_plateau_type1(kernel_size,
                                   sigma_x_range,
                                   sigma_y_range,
                                   rotation_range,
                                   beta_range,
                                   noise_range=None,
                                   strict=False):
    """Randomly generate a bivariate plateau type1 kernel.

    Args:
        kernel_size (int): Side length of the kernel; must be odd.
        sigma_x_range (tuple): Bounds for sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): Bounds for sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): Bounds for the angle, e.g.
            [-math.pi/2, math.pi/2].
        beta_range (tuple): Bounds for beta, e.g. [1, 4]. With equal
            probability beta is drawn from [beta_range[0], 1] or from
            [1, beta_range[1]].
        noise_range (tuple, optional): Bounds of the multiplicative kernel
            noise, e.g. [0.75, 1.25]. Default: None
        strict (bool): When True, sigma_x is forced to be the larger of the
            two sigmas and the sampled parameters are returned as well.
            Default: False

    Returns:
        kernel (ndarray), or the tuple
        ``(kernel, sigma_x, sigma_y, rotation, beta)`` when ``strict`` is True.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'

    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        drawn = [sigma_x, sigma_y]
        sigma_x, sigma_y = np.max(drawn), np.min(drawn)
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    # Coin flip first, then beta from the half of the range it selected.
    if np.random.uniform() < 0.5:
        beta_low, beta_high = beta_range[0], 1
    else:
        beta_low, beta_high = 1, beta_range[1]
    beta = np.random.uniform(beta_low, beta_high)

    kernel = bivariate_plateau_type1(kernel_size, sigma_x, sigma_y, rotation,
                                     beta)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    if not strict:
        return kernel
    return kernel, sigma_x, sigma_y, rotation, beta
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def random_bivariate_plateau_type1_iso(kernel_size,
                                       sigma_range,
                                       beta_range,
                                       noise_range=None,
                                       strict=False):
    """Sample a random isotropic plateau (type 1) kernel.

    A sigma and a beta are drawn uniformly from the given ranges, the kernel
    is built with ``bivariate_plateau_type1_iso``, optional multiplicative
    noise is applied, and the result is divided by its sum.

    Args:
        kernel_size (int): Kernel size; must be odd.
        sigma_range (tuple): Lower/upper bound for sigma, e.g. [0.6, 5].
        beta_range (tuple): Lower/upper bound for beta, e.g. [1, 4].
        noise_range (tuple, optional): Lower/upper bound of the
            multiplicative per-element noise, e.g. [0.75, 1.25].
            Default: None (no noise).
        strict (bool): If True, the sampled parameters are returned together
            with the kernel. Default: False.

    Returns:
        ndarray | tuple: The kernel, or ``(kernel, sigma, beta)`` when
        ``strict`` is True.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    sigma_low, sigma_high = sigma_range[0], sigma_range[1]
    assert sigma_low < sigma_high, 'Wrong sigma_x_range.'

    # The draw order (sigma, beta, then noise) is kept so seeded runs match.
    sigma = np.random.uniform(sigma_low, sigma_high)
    beta = np.random.uniform(beta_range[0], beta_range[1])

    kernel = bivariate_plateau_type1_iso(kernel_size, sigma, beta)

    if noise_range is not None:
        noise_low, noise_high = noise_range[0], noise_range[1]
        assert noise_low < noise_high, 'Wrong noise range.'
        kernel = kernel * np.random.uniform(
            noise_low, noise_high, size=kernel.shape)

    # Divide by the sum so the kernel elements add up to one.
    kernel = kernel / np.sum(kernel)
    return (kernel, sigma, beta) if strict else kernel
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size=21,
                         sigma_x_range=(0.6, 5),
                         sigma_y_range=(0.6, 5),
                         rotation_range=(-math.pi, math.pi),
                         beta_range=(0.5, 8),
                         noise_range=None):
    """Randomly pick a kernel type and generate a kernel of that type.

    Args:
        kernel_list (tuple): Names of the candidate kernel types. Supported:
            ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
            'plateau_aniso'].
        kernel_prob (tuple): Sampling weight of each entry in kernel_list.
        kernel_size (int): Kernel size. Default: 21.
        sigma_x_range (tuple): Default: (0.6, 5). Also used as the sigma
            range of the isotropic kernel types.
        sigma_y_range (tuple): Default: (0.6, 5).
        rotation_range (tuple): Default: (-math.pi, math.pi).
        beta_range (tuple): Default: (0.5, 8).
        noise_range (tuple, optional): Multiplicative kernel noise, e.g.
            [0.75, 1.25]. Default: None.

    Returns:
        kernel (ndarray): The generated kernel, divided by its sum.

    Raises:
        ValueError: If the sampled kernel type is not supported.
    """
    kernel_type = random.choices(kernel_list, kernel_prob)[0]
    if kernel_type == 'iso':
        kernel = random_bivariate_isotropic_Gaussian(
            kernel_size, sigma_x_range, noise_range=noise_range)
    elif kernel_type == 'aniso':
        kernel = random_bivariate_anisotropic_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            noise_range=noise_range)
    elif kernel_type == 'skew':
        kernel = random_bivariate_skew_Gaussian_center(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            noise_range=noise_range)
    elif kernel_type == 'generalized':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            beta_range,
            noise_range=noise_range)
    elif kernel_type == 'plateau_iso':
        kernel = random_bivariate_plateau_type1_iso(
            kernel_size, sigma_x_range, beta_range, noise_range=noise_range)
    elif kernel_type == 'plateau_aniso':
        kernel = random_bivariate_plateau_type1(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            beta_range,
            noise_range=noise_range)
    else:
        # Without this branch an unknown name left `kernel` unbound and the
        # function failed later with a confusing UnboundLocalError.
        raise ValueError(
            f'Unsupported kernel type: {kernel_type!r}. Supported types are '
            "'iso', 'aniso', 'skew', 'generalized', 'plateau_iso' and "
            "'plateau_aniso'.")

    # add multiplicative noise
    # NOTE(review): noise_range is also forwarded to the per-type generators
    # above, and the plateau generators apply it themselves, so for those
    # types the noise is applied a second time here - confirm this is intended.
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
def show_one_kernel(kernel_type='generalized'):
    """Plot one example kernel in four views (debug/visualization helper).

    The figure contains a matshow heat map, a min-max scaled image, a labelled
    contour plot and a filled contour plot of the selected kernel. The
    kernel's mass-center shift is printed before plotting.

    Args:
        kernel_type (str): Which example kernel to show. One of 'skew',
            'aniso', 'iso' or 'generalized'. Default: 'generalized', which is
            the kernel that was displayed before this parameter existed.

    Raises:
        ValueError: If kernel_type is not one of the supported names.
    """
    kernel_size = 21

    # Builders are lazy so only the requested kernel is computed; previously
    # every kernel was computed and all but the last one were discarded.
    builders = {
        # bivariate skew Gaussian
        'skew': lambda: bivariate_skew_Gaussian_center(
            kernel_size, 2, 4, -math.pi / 4, [[3 / 4, 0], [0, 0.5]]),
        # bivariate anisotropic Gaussian
        'aniso': lambda: bivariate_anisotropic_Gaussian(
            kernel_size, 2, 4, -math.pi / 4),
        # bivariate isotropic Gaussian
        'iso': lambda: bivariate_isotropic_Gaussian(kernel_size, 1),
        # bivariate generalized Gaussian
        'generalized': lambda: bivariate_generalized_Gaussian(
            kernel_size, 2, 4, -math.pi / 4, beta=4),
    }
    if kernel_type not in builders:
        raise ValueError(
            f'Unsupported kernel_type: {kernel_type!r}. '
            f'Choose one of {sorted(builders)}.')

    # Imported lazily so the module does not require matplotlib at import time.
    import matplotlib.pyplot as plt

    kernel = builders[kernel_type]()

    delta_h, delta_w = mass_center_shift(kernel_size, kernel)
    print(delta_h, delta_w)

    fig, axs = plt.subplots(nrows=2, ncols=2)

    # heat map
    ax = axs[0][0]
    im = ax.matshow(kernel, cmap='jet', origin='upper')
    fig.colorbar(im, ax=ax)

    # image, min-max scaled to [0, 255]
    ax = axs[0][1]
    kernel_vis = kernel - np.min(kernel)
    kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
    ax.imshow(kernel_vis, interpolation='nearest')

    _, xx, yy = mesh_grid(kernel_size)
    # contour
    ax = axs[1][0]
    CS = ax.contour(xx, yy, kernel, origin='upper')
    ax.clabel(CS, inline=1, fontsize=3)

    # contourf, on the kernel scaled so that its maximum is one
    ax = axs[1][1]
    kernel = kernel / np.max(kernel)
    p = ax.contourf(
        xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
    fig.colorbar(p)

    plt.show()
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def show_plateau_kernel():
    """Plot an example plateau kernel in four views (debug helper).

    The figure contains a matshow heat map, a min-max scaled image, a labelled
    contour plot and a filled contour plot of the kernel. The kernel's
    mass-center shift is printed before plotting.
    """
    # Imported lazily so the module does not require matplotlib at import time.
    import matplotlib.pyplot as plt
    kernel_size = 21

    # NOTE(review): `plateau_type1` is not defined in the visible part of this
    # file (nearby code only references `bivariate_plateau_type1`) - confirm
    # that this name resolves.
    kernel = plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
    delta_h, delta_w = mass_center_shift(kernel_size, kernel)
    print(delta_h, delta_w)

    fig, axs = plt.subplots(nrows=2, ncols=2)

    # heat map
    ax = axs[0][0]
    im = ax.matshow(kernel, cmap='jet', origin='upper')
    fig.colorbar(im, ax=ax)

    # image, min-max scaled to [0, 255]
    ax = axs[0][1]
    kernel_vis = kernel - np.min(kernel)
    kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
    ax.imshow(kernel_vis, interpolation='nearest')

    _, xx, yy = mesh_grid(kernel_size)
    # contour
    ax = axs[1][0]
    CS = ax.contour(xx, yy, kernel, origin='upper')
    ax.clabel(CS, inline=1, fontsize=3)

    # contourf, on the kernel scaled so that its maximum is one
    ax = axs[1][1]
    kernel = kernel / np.max(kernel)
    p = ax.contourf(
        xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
    fig.colorbar(p)

    plt.show()
|
basicsr/data/paired_image_dataset.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils import data as data
|
| 2 |
+
from torchvision.transforms.functional import normalize
|
| 3 |
+
|
| 4 |
+
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
|
| 5 |
+
from basicsr.data.transforms import augment, paired_random_crop
|
| 6 |
+
from basicsr.utils import FileClient, imfrombytes, img2tensor
|
| 7 |
+
from basicsr.utils.registry import DATASET_REGISTRY
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
    GT image pairs.

    There are three modes:
    1. 'lmdb': Use lmdb files.
        If opt['io_backend']['type'] == 'lmdb'.
    2. 'meta_info_file': Use meta information file to generate paths.
        If the backend is not lmdb and opt['meta_info_file'] is not None.
    3. 'folder': Scan folders to generate paths.
        The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the
                template excludes the file extension. Default: '{}'.
            mean, std (optional): Passed to torchvision's ``normalize`` for
                both images when present.
            gt_size (int): Cropped patched size for gt patches.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (int): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # file client (io backend); created lazily on the first __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt.get('mean')
        self.std = opt.get('std')

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        self.filename_tmpl = opt.get('filename_tmpl', '{}')

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
            self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
                                                          self.opt['meta_info_file'], self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        """Load, optionally augment, and tensorize the pair at ``index``.

        Returns:
            dict: Keys 'lq' and 'gt' (tensors) plus 'lq_path' and 'gt_path'.
        """
        if self.file_client is None:
            # Pop 'type' from a copy: io_backend_opt is the caller's
            # opt['io_backend'] dict, and removing the key from it would make
            # a second dataset built from the same opt fail with KeyError.
            backend_opt = dict(self.io_backend_opt)
            backend_type = backend_opt.pop('type')
            self.file_client = FileClient(backend_type, **backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])

        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        # NOTE(review): if only one of mean/std is configured, normalize
        # receives None for the other - confirm configs always set both.
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        """Return the number of LQ/GT pairs."""
        return len(self.paths)
|
basicsr/data/prefetch_dataloader.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import queue as Queue
|
| 2 |
+
import threading
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    A daemon thread pulls items from ``generator`` and stores them in a
    bounded queue; iterating this object returns the queued items in order.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Number of prefetch queue.
    """

    # Private end-of-stream marker. A dedicated object (rather than None)
    # lets the wrapped generator yield None without ending the iteration.
    _END = object()

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        # Exception raised by the producer, re-raised on the consumer side.
        self._prefetch_error = None
        # Set once the end marker was consumed, so further next() calls
        # raise StopIteration instead of blocking on an empty queue.
        self._prefetch_done = False
        self.start()

    def run(self):
        """Producer loop: fill the queue, then always queue the end marker."""
        try:
            for item in self.generator:
                self.queue.put(item)
        except Exception as exc:  # handed over to the consumer in __next__
            self._prefetch_error = exc
        finally:
            # Without the marker a failing generator would leave the consumer
            # blocked in queue.get() forever.
            self.queue.put(self._END)

    def __next__(self):
        """Return the next prefetched item.

        Raises:
            StopIteration: When the wrapped generator is exhausted.
            Exception: Whatever exception the wrapped generator raised.
        """
        if self._prefetch_done:
            raise StopIteration
        next_item = self.queue.get()
        if next_item is self._END:
            self._prefetch_done = True
            if self._prefetch_error is not None:
                raise self._prefetch_error
            raise StopIteration
        return next_item

    def __iter__(self):
        return self
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class PrefetchDataLoader(DataLoader):
    """DataLoader whose iterator is wrapped in a PrefetchGenerator.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Number of prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        # Stored before the base initializer runs, as in the original order.
        self.num_prefetch_queue = num_prefetch_queue
        super().__init__(**kwargs)

    def __iter__(self):
        base_iter = super().__iter__()
        return PrefetchGenerator(base_iter, self.num_prefetch_queue)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class CPUPrefetcher:
    """Thin wrapper giving a dataloader a ``next()`` / ``reset()`` interface.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        """Return the next batch, or None once the loader is exhausted."""
        return next(self.loader, None)

    def reset(self):
        """Start a fresh pass over the original loader."""
        self.loader = iter(self.ori_loader)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class CUDAPrefetcher:
    """CUDA prefetcher.

    Loads the next batch ahead of time and moves its tensors to the target
    device on a side CUDA stream.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options. ``opt['num_gpu']`` selects the device: 'cuda'
            when it is non-zero, otherwise 'cpu'.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        # Only create a CUDA stream when CUDA is actually the target;
        # creating it unconditionally breaks the 'cpu' fallback on machines
        # without CUDA.
        self.stream = torch.cuda.Stream() if self.device.type == 'cuda' else None
        self.preload()

    def _move_batch_to_device(self):
        """Move every tensor value of the current batch to ``self.device``."""
        for k, v in self.batch.items():
            if torch.is_tensor(v):
                self.batch[k] = v.to(device=self.device, non_blocking=True)

    def preload(self):
        """Fetch the next batch (a dict) and start moving it to the device.

        Sets ``self.batch`` to None when the loader is exhausted.
        """
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        if self.stream is None:
            self._move_batch_to_device()
        else:
            with torch.cuda.stream(self.stream):
                self._move_batch_to_device()

    def next(self):
        """Return the preloaded batch (or None) and preload the following one."""
        if self.stream is not None:
            # Make the current stream wait for the pending transfer.
            torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        """Restart from the beginning of the original loader."""
        self.loader = iter(self.ori_loader)
        self.preload()
|
basicsr/data/transforms.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def mod_crop(img, scale):
    """Crop an image so its height and width are multiples of ``scale``.

    Used during testing. The input is copied first; the remainder rows and
    columns are dropped from the bottom and the right.

    Args:
        img (ndarray): Input image with 2 or 3 dimensions.
        scale (int): Scale factor.

    Returns:
        ndarray: Result image.

    Raises:
        ValueError: If the image is neither 2- nor 3-dimensional.
    """
    img = img.copy()
    if img.ndim not in (2, 3):
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    height, width = img.shape[:2]
    return img[:height - height % scale, :width - width % scale, ...]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
    """Paired random crop.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth. Only used in error messages.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.

    Raises:
        ValueError: If the GT size is not ``scale`` times the LQ size, or if
            the LQ image is smaller than the LQ patch size.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    # Only height and width are needed, so 2-D (single channel) images work.
    h_lq, w_lq = img_lqs[0].shape[:2]
    h_gt, w_gt = img_gts[0].shape[:2]
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        # A single message string: a comma between the two parts used to
        # pass two arguments, so the error was rendered as a tuple.
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Randomly flip and/or transpose images (and optional flows).

    Each of horizontal flip, vertical flip and h/w transpose is enabled with
    probability 0.5 (the latter two only when ``rotation`` is set). Rotation
    is implemented as vertical flip plus transpose. Every image and flow in
    the call receives the same augmentation. Flips are done in place on the
    input arrays.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]): Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Ignored when ``flows`` is given. Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If returned
            results only have one element, just return ndarray. When
            ``flows`` is None and ``return_status`` is set, the result is
            ``(imgs, (hflip, vflip, rot90))``.

    """
    # Decide once per call; the draws happen in this fixed order.
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rotation and random.random() < 0.5
    do_rot90 = rotation and random.random() < 0.5

    def _transform_img(arr):
        if do_hflip:  # horizontal, in place
            cv2.flip(arr, 1, arr)
        if do_vflip:  # vertical, in place
            cv2.flip(arr, 0, arr)
        return arr.transpose(1, 0, 2) if do_rot90 else arr

    def _transform_flow(arr):
        # Besides moving the pixels, the matching flow component is negated
        # (flip) or the two components are swapped (transpose).
        if do_hflip:
            cv2.flip(arr, 1, arr)
            arr[:, :, 0] *= -1
        if do_vflip:
            cv2.flip(arr, 0, arr)
            arr[:, :, 1] *= -1
        if do_rot90:
            arr = arr.transpose(1, 0, 2)[:, :, [1, 0]]
        return arr

    img_list = imgs if isinstance(imgs, list) else [imgs]
    out_imgs = [_transform_img(item) for item in img_list]
    if len(out_imgs) == 1:
        out_imgs = out_imgs[0]

    if flows is None:
        if return_status:
            return out_imgs, (do_hflip, do_vflip, do_rot90)
        return out_imgs

    flow_list = flows if isinstance(flows, list) else [flows]
    out_flows = [_transform_flow(item) for item in flow_list]
    if len(out_flows) == 1:
        out_flows = out_flows[0]
    return out_imgs, out_flows
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate an image around a center point, keeping the output size.

    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            initialize it as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.

    Returns:
        ndarray: Rotated image with the same width and height as the input.
    """
    height, width = img.shape[:2]
    pivot = (width // 2, height // 2) if center is None else center
    affine = cv2.getRotationMatrix2D(pivot, angle, scale)
    return cv2.warpAffine(img, affine, (width, height))
|
basicsr/losses/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
|
| 3 |
+
from basicsr.utils import get_root_logger
|
| 4 |
+
from basicsr.utils.registry import LOSS_REGISTRY
|
| 5 |
+
from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
|
| 6 |
+
gradient_penalty_loss, r1_penalty)
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
|
| 10 |
+
'r1_penalty', 'g_path_regularize'
|
| 11 |
+
]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def build_loss(opt):
    """Instantiate a loss from its configuration.

    The configuration is deep-copied, so the caller's dict is left untouched.

    Args:
        opt (dict): Configuration. It must contain:
            type (str): Name under which the loss is looked up in
                LOSS_REGISTRY. All remaining keys are passed to the loss
                constructor as keyword arguments.

    Returns:
        The created loss object.
    """
    cfg = deepcopy(opt)
    loss_cls = LOSS_REGISTRY.get(cfg.pop('type'))
    loss = loss_cls(**cfg)
    get_root_logger().info(f'Loss [{loss.__class__.__name__}] is created.')
    return loss
|
basicsr/losses/loss_util.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
from torch.nn import functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def reduce_loss(loss, reduction):
    """Apply the requested reduction to an element-wise loss.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    # Enum values: none -> 0, elementwise_mean -> 1, sum -> 2.
    mode = F._Reduction.get_enum(reduction)
    if mode == 0:
        return loss
    return loss.mean() if mode == 1 else loss.sum()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Weight an element-wise loss and reduce it.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights with the same number of
            dimensions as ``loss``; its size along dim 1 must be 1 or equal
            to that of ``loss``. Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values.
    """
    has_weight = weight is not None
    if has_weight:
        assert weight.dim() == loss.dim()
        assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
        loss = loss * weight

    # Without weights, or for a plain sum, the standard reduction applies.
    if not has_weight or reduction == 'sum':
        return reduce_loss(loss, reduction)

    if reduction == 'mean':
        # Average over the weighted region. A weight with a single entry
        # along dim 1 counts once for every entry of the loss along dim 1.
        denom = weight.sum() if weight.size(1) > 1 else weight.sum() * loss.size(1)
        return loss.sum() / denom

    # Any other reduction with weights: return the weighted loss unreduced.
    return loss
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def weighted_loss(loss_func):
    """Decorator that adds ``weight`` and ``reduction`` to a loss function.

    The decorated function must have a signature like
    `loss_func(pred, target, **kwargs)` and only compute the element-wise
    loss without any reduction. The resulting function has the signature
    `loss_func(pred, target, weight=None, reduction='mean', **kwargs)` and
    passes the element-wise loss through `weight_reduce_loss`.

    :Example:

    >>> import torch
    >>> @weighted_loss
    ... def l1_loss(pred, target):
    ...     return (pred - target).abs()

    >>> pred = torch.Tensor([[0, 2, 3]])
    >>> target = torch.Tensor([[1, 1, 1]])
    >>> weight = torch.Tensor([[1, 0, 1]])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, reduction='none')
    tensor([[1., 1., 2.]])
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
        # Element-wise loss first, then weighting and reduction.
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction)

    return wrapper
|
basicsr/losses/losses.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import lpips
|
| 3 |
+
import torch
|
| 4 |
+
from torch import autograd as autograd
|
| 5 |
+
from torch import nn as nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
|
| 8 |
+
from basicsr.archs.vgg_arch import VGGFeatureExtractor
|
| 9 |
+
from basicsr.utils.registry import LOSS_REGISTRY
|
| 10 |
+
from .loss_util import weighted_loss
|
| 11 |
+
|
| 12 |
+
_reduction_modes = ['none', 'mean', 'sum']
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@weighted_loss
def l1_loss(pred, target):
    """Return the element-wise absolute error between ``pred`` and ``target``.

    No reduction is applied here; the ``weighted_loss`` decorator supplies
    the ``weight`` and ``reduction`` arguments.
    """
    elementwise = F.l1_loss(pred, target, reduction='none')
    return elementwise
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@weighted_loss
def mse_loss(pred, target):
    """Return the element-wise squared error between ``pred`` and ``target``.

    No reduction is applied here; the ``weighted_loss`` decorator supplies
    the ``weight`` and ``reduction`` arguments.
    """
    elementwise = F.mse_loss(pred, target, reduction='none')
    return elementwise
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    """Return the element-wise Charbonnier loss ``sqrt((pred - target)^2 + eps)``.

    Args:
        pred (Tensor): Predicted tensor.
        target (Tensor): Ground truth tensor.
        eps (float): Constant added under the square root. Default: 1e-12.
    """
    squared_error = (pred - target)**2
    return torch.sqrt(squared_error + eps)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss module.

    Args:
        loss_weight (float): Multiplier applied to the computed loss.
            Default: 1.0.
        reduction (str): Reduction applied to the output; one of
            'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, **kwargs):
        """Compute the weighted L1 loss.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        loss = l1_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * loss
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@LOSS_REGISTRY.register()
class MSELoss(nn.Module):
    """MSE (L2) loss module.

    Args:
        loss_weight (float): Multiplier applied to the computed loss.
            Default: 1.0.
        reduction (str): Reduction applied to the output; one of
            'none' | 'mean' | 'sum'. Default: 'mean'.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super().__init__()
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, **kwargs):
        """Compute the weighted MSE loss.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        loss = mse_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * loss
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
    """Charbonnier loss (a differentiable, robust variant of the L1 loss).

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
    Super-Resolution".

    Args:
        loss_weight (float): Multiplier applied to the computed loss.
            Default: 1.0.
        reduction (str): Reduction applied to the output; one of
            'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): A value used to control the curvature near zero.
            Default: 1e-12.

    Raises:
        ValueError: If ``reduction`` is not one of the supported modes.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
        super().__init__()
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f'Unsupported reduction mode: {reduction}. '
                             f'Supported ones are: {_reduction_modes}')

        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self, pred, target, weight=None, **kwargs):
        """Compute the weighted Charbonnier loss.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        loss = charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
        return self.loss_weight * loss
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
    """Weighted total-variation (TV) loss.

    The loss is the sum of two L1 terms computed by ``L1Loss.forward``: one
    between vertically adjacent rows and one between horizontally adjacent
    columns of ``pred``.

    Args:
        loss_weight (float): Loss weight. Default: 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)

    def forward(self, pred, weight=None):
        """Compute the TV loss of ``pred``.

        Args:
            pred (Tensor): Tensor indexed as (N, C, H, W).
            weight (Tensor, optional): Element-wise weights indexed like
                ``pred``. Default: None (unweighted).

        Returns:
            Tensor: Sum of the horizontal and vertical difference losses.
        """
        # Only slice the weight map when one is given; slicing the default
        # ``None`` would raise a TypeError.
        if weight is None:
            y_weight = None
            x_weight = None
        else:
            y_weight = weight[:, :, :-1, :]
            x_weight = weight[:, :, :, :-1]

        y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
        x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)

        loss = x_diff + y_diff

        return loss
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@LOSS_REGISTRY.register()
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss. One of
            'l1' | 'l2' | 'mse' | 'fro'. Default: 'l1'.

    Raises:
        NotImplementedError: If ``criterion`` is not a supported name.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type == 'l2':
            # torch.nn has no ``L2loss`` class; the L2 criterion is MSELoss.
            self.criterion = torch.nn.MSELoss()
        elif self.criterion_type == 'mse':
            self.criterion = torch.nn.MSELoss(reduction='mean')
        elif self.criterion_type == 'fro':
            # The Frobenius norm is computed inline in ``forward``.
            self.criterion = None
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            tuple: ``(percep_loss, style_loss)``. Each entry is ``None`` when
            its corresponding weight is not greater than 0.
        """
        # extract vgg features; the ground truth is detached so that no
        # gradient flows through it
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss (same criterion applied to Gram matrices)
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(
                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
                        gt_features[k])) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix of shape (n, c, c), normalized by
            ``c * h * w``.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
@LOSS_REGISTRY.register()
class LPIPSLoss(nn.Module):
    """LPIPS loss computed with ``lpips.LPIPS(net="vgg", spatial=False)``.

    Args:
        loss_weight (float): Multiplier applied to the mean LPIPS value.
            Default: 1.0.
        use_input_norm (bool): If True, subtract the registered ``mean``
            buffer from the inputs and divide by the ``std`` buffer before
            scoring. Default: True.
        range_norm (bool): If True, map the inputs with ``(x + 1) / 2``
            before any other normalization. Default: False.
    """

    def __init__(self,
                 loss_weight=1.0,
                 use_input_norm=True,
                 range_norm=False,):
        super().__init__()
        self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
        self.loss_weight = loss_weight
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        if not self.use_input_norm:
            return
        # the mean is for image with range [0, 1]
        mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        self.register_buffer('mean', mean)
        # the std is for image with range [0, 1]
        std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        self.register_buffer('std', std)

    def forward(self, pred, target):
        """Return ``loss_weight`` times the mean LPIPS value of the pair.

        Args:
            pred (Tensor): Predicted images.
            target (Tensor): Reference images.
        """
        if self.range_norm:
            pred = (pred + 1) / 2
            target = (target + 1) / 2
        if self.use_input_norm:
            pred = (pred - self.mean) / self.std
            target = (target - self.mean) / self.std
        distance = self.perceptual(target.contiguous(), pred.contiguous())
        return self.loss_weight * distance.mean()
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
@LOSS_REGISTRY.register()
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'wgan_softplus',
            'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always 1.0
            for discriminators.

    Raises:
        NotImplementedError: If ``gan_type`` is not supported.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # Zero-argument factories so that only the selected criterion is built.
        loss_factories = {
            'vanilla': nn.BCEWithLogitsLoss,
            'lsgan': nn.MSELoss,
            'wgan': lambda: self._wgan_loss,
            'wgan_softplus': lambda: self._wgan_softplus_loss,
            'hinge': nn.ReLU,
        }
        if self.gan_type not in loss_factories:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
        self.loss = loss_factories[self.gan_type]()

    def _wgan_loss(self, input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss; the negated mean of ``input`` when ``target``
            is truthy, otherwise the mean itself.
        """
        mean_score = input.mean()
        if target:
            return -mean_score
        return mean_score

    def _wgan_softplus_loss(self, input, target):
        """wgan loss with soft plus. softplus is a smooth approximation to the
        ReLU function.

        In StyleGAN2, it is called:
            Logistic loss for discriminator;
            Non-saturating loss for generator.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        signed = -input if target else input
        return F.softplus(signed).mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
                return Tensor.
        """
        if self.gan_type in ['wgan', 'wgan_softplus']:
            return target_is_real
        if target_is_real:
            target_val = self.real_label_val
        else:
            target_val = self.fake_label_val
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        if self.gan_type != 'hinge':  # other gan types
            target_label = self.get_target_label(input, target_is_real)
            loss = self.loss(input, target_label)
        elif is_disc:  # for discriminators in hinge-gan
            signed = -input if target_is_real else input
            loss = self.loss(1 + signed).mean()
        else:  # for generators in hinge-gan
            loss = -input.mean()

        # loss_weight is always 1.0 for discriminators
        if is_disc:
            return loss
        return loss * self.loss_weight
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def r1_penalty(real_pred, real_img):
    """R1 regularization for discriminator.

    The core idea is to penalize the gradient on real data alone: when the
    generator distribution produces the true data distribution and the
    discriminator is equal to 0 on the data manifold, the gradient penalty
    ensures that the discriminator cannot create a non-zero gradient
    orthogonal to the data manifold without suffering a loss in the GAN game.

    Ref:
        Eq. 9 in Which training methods for GANs do actually converge.

    Args:
        real_pred (Tensor): Discriminator output computed from ``real_img``.
        real_img (Tensor): Real input; must be part of the autograd graph of
            ``real_pred``.

    Returns:
        Tensor: Batch mean of the per-sample squared gradient norm.
    """
    (grad_real,) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
    # Flatten everything but the batch axis, then sum squares per sample.
    squared_norms = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1)
    return squared_norms.mean()
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Compute a path-length penalty from the gradient w.r.t. ``latents``.

    Args:
        fake_img (Tensor): Generated images; axes 2 and 3 are used to scale
            the random noise.
        latents (Tensor): Latents in the autograd graph of ``fake_img``; the
            gradient is reduced with ``sum(2).mean(1)``, so it is indexed as a
            3-D tensor.
        mean_path_length (float | Tensor): Running mean of the path lengths.
        decay (float): Step used to move the running mean towards the current
            batch mean. Default: 0.01.

    Returns:
        tuple: ``(path_penalty, mean of path_lengths (detached),
        updated running mean (detached))``.
    """
    height, width = fake_img.shape[2], fake_img.shape[3]
    noise = torch.randn_like(fake_img) / math.sqrt(height * width)
    projection = (fake_img * noise).sum()
    (grad,) = autograd.grad(outputs=projection, inputs=latents, create_graph=True)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    # Move the running mean towards the current batch mean by ``decay``.
    batch_mean = path_lengths.mean()
    path_mean = mean_path_length + decay * (batch_mean - mean_path_length)

    deviation = path_lengths - path_mean
    path_penalty = deviation.pow(2).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
    """Calculate gradient penalty for wgan-gp.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data.
        fake_data (Tensor): Fake input data.
        weight (Tensor): Weight tensor. Default: None.

    Returns:
        Tensor: A tensor for gradient penalty.
    """

    batch_size = real_data.size(0)
    # Sample on the default generator, then match real_data's dtype/device.
    # ``.to`` avoids the copy-construct warning that ``new_tensor(<tensor>)``
    # emits.
    alpha = torch.rand(batch_size, 1, 1, 1).to(real_data)

    # interpolate between real_data and fake_data
    interpolates = alpha * real_data + (1. - alpha) * fake_data
    # Fresh leaf tensor that requires grad (replaces the deprecated
    # ``autograd.Variable`` wrapper).
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True)[0]

    if weight is not None:
        gradients = gradients * weight

    # NOTE(review): the 2-norm is taken over dim=1 only, not over the
    # flattened per-sample gradient - confirm this is intended.
    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    if weight is not None:
        gradients_penalty = gradients_penalty / torch.mean(weight)

    return gradients_penalty
|
basicsr/metrics/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
|
| 3 |
+
from basicsr.utils.registry import METRIC_REGISTRY
|
| 4 |
+
from .psnr_ssim import calculate_psnr, calculate_ssim
|
| 5 |
+
|
| 6 |
+
__all__ = ['calculate_psnr', 'calculate_ssim']
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def calculate_metric(data, opt):
    """Calculate metric from data and options.

    Args:
        data (dict): Keyword arguments forwarded to the metric function.
        opt (dict): Configuration. It must contain:
            type (str): Name of the metric in ``METRIC_REGISTRY``.
            All remaining entries are forwarded to the metric function as
            keyword arguments.

    Returns:
        The value returned by the selected metric function.
    """
    # Work on a copy so the caller's configuration is left untouched.
    options = deepcopy(opt)
    metric_fn = METRIC_REGISTRY.get(options.pop('type'))
    return metric_fn(**data, **options)
|
basicsr/metrics/metric_util.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from basicsr.utils.matlab_functions import bgr2ycbcr
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input_order is (h, w), return (h, w, 1);
    If the input_order is (c, h, w), return (h, w, c);
    If the input_order is (h, w, c), return as it is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order will not have
            effects. Default: 'HWC'.

    Returns:
        ndarray: reordered image.

    Raises:
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """

    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
    if len(img.shape) == 2:
        # A 2-D image only gains a trailing channel axis. Returning here keeps
        # input_order from transposing the freshly added axis.
        return img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    return img
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def to_y_channel(img):
    """Change to Y channel of YCbCr.

    Inputs that are not 3-D with exactly three channels in the last axis are
    only cast to float32 and returned with their values unchanged in range.

    Args:
        img (ndarray): Images with range [0, 255].

    Returns:
        (ndarray): Images with range [0, 255] (float type) without round.
    """
    scaled = img.astype(np.float32) / 255.
    has_three_channels = scaled.ndim == 3 and scaled.shape[2] == 3
    if has_three_channels:
        luma = bgr2ycbcr(scaled, y_only=True)
        # Restore a trailing channel axis on the luma image.
        scaled = luma[..., None]
    return scaled * 255.
|
basicsr/metrics/psnr_ssim.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from basicsr.metrics.metric_util import reorder_image, to_y_channel
|
| 5 |
+
from basicsr.utils.registry import METRIC_REGISTRY
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@METRIC_REGISTRY.register()
def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: psnr result; ``inf`` when the two images are identical.

    Raises:
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """

    assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')

    def _prepare(img):
        # HWC layout -> float64 -> optional border crop -> optional Y channel.
        img = reorder_image(img, input_order=input_order).astype(np.float64)
        if crop_border != 0:
            img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
        if test_y_channel:
            img = to_y_channel(img)
        return img

    img1 = _prepare(img1)
    img2 = _prepare(img2)

    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    return 20. * np.log10(255. / np.sqrt(mse))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _ssim(img1, img2):
    """Calculate SSIM (structural similarity) for one channel images.

    It is called by func:`calculate_ssim`.

    Args:
        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255] with order 'HWC'.

    Returns:
        float: ssim result.
    """

    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    def _windowed(values):
        # Filter with the 11x11 Gaussian window and drop the 5-pixel border.
        return cv2.filter2D(values, -1, window)[5:-5, 5:-5]

    mu1 = _windowed(img1)
    mu2 = _windowed(img2)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = _windowed(img1**2) - mu1_sq
    sigma2_sq = _windowed(img2**2) - mu2_sq
    sigma12 = _windowed(img1 * img2) - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator
    return ssim_map.mean()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@METRIC_REGISTRY.register()
def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """Calculate SSIM (structural similarity).

    Ref:
    Image quality assessment: From error visibility to structural similarity

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For three-channel images, SSIM is calculated for each channel and then
    averaged.

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the SSIM calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: ssim result.

    Raises:
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """

    assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')

    def _prepare(img):
        # HWC layout -> float64 -> optional border crop -> optional Y channel.
        img = reorder_image(img, input_order=input_order).astype(np.float64)
        if crop_border != 0:
            img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
        if test_y_channel:
            img = to_y_channel(img)
        return img

    img1 = _prepare(img1)
    img2 = _prepare(img2)

    # One SSIM value per channel of the last axis, then the average.
    ssims = [_ssim(img1[..., i], img2[..., i]) for i in range(img1.shape[2])]
    return np.array(ssims).mean()
|
basicsr/models/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
from os import path as osp
|
| 4 |
+
|
| 5 |
+
from basicsr.utils import get_root_logger, scandir
|
| 6 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
| 7 |
+
|
| 8 |
+
__all__ = ['build_model']
|
| 9 |
+
|
| 10 |
+
# Model modules are discovered and imported here so that importing this
# package runs them (the registry is presumably populated as a side effect -
# verify against the model modules).
# Every file under this 'models' folder whose name ends with '_model.py' is
# collected, without its extension.
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [
    osp.splitext(osp.basename(path))[0]
    for path in scandir(model_folder)
    if path.endswith('_model.py')
]
# import all the model modules
_model_modules = [importlib.import_module(f'basicsr.models.{module_name}') for module_name in model_filenames]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def build_model(opt):
    """Build model from options.

    Args:
        opt (dict): Configuration. It must contain:
            model_type (str): Name of the model class in ``MODEL_REGISTRY``.

    Returns:
        The model instance, constructed from a deep copy of ``opt``.
    """
    # The model receives its own copy so the caller's options stay unchanged.
    options = deepcopy(opt)
    model_cls = MODEL_REGISTRY.get(options['model_type'])
    model = model_cls(options)
    get_root_logger().info(f'Model [{model.__class__.__name__}] is created.')
    return model
|
basicsr/models/base_model.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
| 7 |
+
|
| 8 |
+
from basicsr.models import lr_scheduler as lr_scheduler
|
| 9 |
+
from basicsr.utils.dist_util import master_only
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger('basicsr')
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseModel():
    """Base model.

    Stores the option dict, the target device and the optimizer/scheduler
    lists, and provides shared helpers for device placement, learning-rate
    handling, checkpoint save/load, training-state resume and loss reduction.

    Several helpers read attributes or call methods that are NOT defined in
    this class (``net_g``, ``net_g_ema``, ``log_dict``, ``dist_validation``,
    ``nondist_validation``); subclasses must provide them.
    """

    def __init__(self, opt):
        """Store ``opt`` and derive the device and training flag from it.

        Args:
            opt (dict): Options. ``opt['num_gpu']`` selects 'cuda' when it is
                non-zero (otherwise 'cpu'); ``opt['is_train']`` is stored as
                ``self.is_train``.
        """
        self.opt = opt
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.is_train = opt['is_train']
        self.schedulers = []
        self.optimizers = []

    def feed_data(self, data):
        """Placeholder; does nothing in the base class."""
        pass

    def optimize_parameters(self):
        """Placeholder; does nothing in the base class."""
        pass

    def get_current_visuals(self):
        """Placeholder; does nothing in the base class."""
        pass

    def save(self, epoch, current_iter):
        """Save networks and training state (no-op in the base class)."""
        pass

    def validation(self, dataloader, current_iter, tb_logger, save_img=False):
        """Validation function.

        Dispatches to ``dist_validation`` when ``opt['dist']`` is truthy and
        to ``nondist_validation`` otherwise.

        Args:
            dataloader (torch.utils.data.DataLoader): Validation dataloader.
            current_iter (int): Current iteration.
            tb_logger (tensorboard logger): Tensorboard logger.
            save_img (bool): Whether to save images. Default: False.
        """
        if self.opt['dist']:
            self.dist_validation(dataloader, current_iter, tb_logger, save_img)
        else:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def model_ema(self, decay=0.999):
        """Update ``self.net_g_ema`` in place as a moving average of ``self.net_g``.

        For every parameter: ``ema = ema * decay + param * (1 - decay)``.
        With ``decay=0`` this copies the current weights.

        Args:
            decay (float): Weight kept from the previous EMA value.
                Default: 0.999.
        """
        net_g = self.get_bare_model(self.net_g)

        net_g_params = dict(net_g.named_parameters())
        net_g_ema_params = dict(self.net_g_ema.named_parameters())

        for k in net_g_ema_params.keys():
            net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)

    def get_current_log(self):
        """Return ``self.log_dict`` (must have been set by a subclass)."""
        return self.log_dict

    def model_to_device(self, net):
        """Model to device. It also wraps models with DistributedDataParallel
        or DataParallel.

        Args:
            net (nn.Module)

        Returns:
            The network moved to ``self.device``; wrapped in
            DistributedDataParallel when ``opt['dist']`` is truthy, else in
            DataParallel when ``opt['num_gpu'] > 1``, else unwrapped.
        """
        net = net.to(self.device)
        if self.opt['dist']:
            find_unused_parameters = self.opt.get('find_unused_parameters', False)
            net = DistributedDataParallel(
                net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
        elif self.opt['num_gpu'] > 1:
            net = DataParallel(net)
        return net

    def get_optimizer(self, optim_type, params, lr, **kwargs):
        """Create an optimizer by type name.

        Args:
            optim_type (str): Only 'Adam' is handled.
            params: Parameters passed to the optimizer.
            lr (float): Learning rate.
            **kwargs: Forwarded to the optimizer constructor.

        Raises:
            NotImplementedError: For any other ``optim_type``.
        """
        if optim_type == 'Adam':
            optimizer = torch.optim.Adam(params, lr, **kwargs)
        else:
            raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.')
        return optimizer

    def setup_schedulers(self):
        """Set up schedulers.

        Creates one scheduler per entry of ``self.optimizers`` from
        ``opt['train']['scheduler']``. NOTE: the 'type' key is popped from
        that dict, so the options are mutated.

        Raises:
            NotImplementedError: For an unknown scheduler type.
        """
        train_opt = self.opt['train']
        scheduler_type = train_opt['scheduler'].pop('type')
        if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
        elif scheduler_type == 'CosineAnnealingRestartLR':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
        else:
            raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')

    def get_bare_model(self, net):
        """Get bare model, especially under wrapping with
        DistributedDataParallel or DataParallel.
        """
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net = net.module
        return net

    @master_only
    def print_network(self, net):
        """Print the str and parameter number of a network.

        Args:
            net (nn.Module)
        """
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net_cls_str = (f'{net.__class__.__name__} - ' f'{net.module.__class__.__name__}')
        else:
            net_cls_str = f'{net.__class__.__name__}'

        net = self.get_bare_model(net)
        net_str = str(net)
        net_params = sum(map(lambda x: x.numel(), net.parameters()))

        logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
        logger.info(net_str)

    def _set_lr(self, lr_groups_l):
        """Set learning rate for warmup.

        Args:
            lr_groups_l (list): List for lr_groups, each for an optimizer.
        """
        for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
            for param_group, lr in zip(optimizer.param_groups, lr_groups):
                param_group['lr'] = lr

    def _get_init_lr(self):
        """Get the initial lr, which is set by the scheduler.

        Returns:
            list[list]: For each optimizer, the 'initial_lr' of each of its
            param groups.
        """
        init_lr_groups_l = []
        for optimizer in self.optimizers:
            init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
        return init_lr_groups_l

    def update_learning_rate(self, current_iter, warmup_iter=-1):
        """Update learning rate.

        Args:
            current_iter (int): Current iteration.
            warmup_iter (int): Warmup iter numbers. -1 for no warmup.
                Default: -1.
        """
        # Schedulers are stepped on every call except the first iteration.
        if current_iter > 1:
            for scheduler in self.schedulers:
                scheduler.step()
        # set up warm-up learning rate
        if current_iter < warmup_iter:
            # get initial lr for each group
            init_lr_g_l = self._get_init_lr()
            # modify warming-up learning rates
            # currently only support linearly warm up
            warm_up_lr_l = []
            for init_lr_g in init_lr_g_l:
                warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
            # set learning rate (overrides what the schedulers just set)
            self._set_lr(warm_up_lr_l)

    def get_current_learning_rate(self):
        """Return the 'lr' of every param group of the first optimizer."""
        return [param_group['lr'] for param_group in self.optimizers[0].param_groups]

    @master_only
    def save_network(self, net, net_label, current_iter, param_key='params'):
        """Save networks.

        Args:
            net (nn.Module | list[nn.Module]): Network(s) to be saved.
            net_label (str): Network label.
            current_iter (int): Current iter number. -1 is saved as 'latest'.
            param_key (str | list[str]): The parameter key(s) to save network.
                Default: 'params'.
        """
        if current_iter == -1:
            current_iter = 'latest'
        save_filename = f'{net_label}_{current_iter}.pth'
        save_path = os.path.join(self.opt['path']['models'], save_filename)

        net = net if isinstance(net, list) else [net]
        param_key = param_key if isinstance(param_key, list) else [param_key]
        assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'

        save_dict = {}
        for net_, param_key_ in zip(net, param_key):
            net_ = self.get_bare_model(net_)
            state_dict = net_.state_dict()
            # Iterate over a snapshot: renaming a key inserts into the dict,
            # which is not allowed while iterating the live items() view.
            for key, param in list(state_dict.items()):
                if key.startswith('module.'):  # remove unnecessary 'module.'
                    del state_dict[key]  # drop the prefixed entry being renamed
                    key = key[7:]
                state_dict[key] = param.cpu()
            save_dict[param_key_] = state_dict

        torch.save(save_dict, save_path)

    def _print_different_keys_loading(self, crt_net, load_net, strict=True):
        """Print keys with different name or different size when loading models.

        1. Print keys with different names.
        2. If strict=False, print the same key but with different tensor size.
            It also ignore these keys with different sizes (not load): such
            keys are renamed to '<key>.ignore' inside ``load_net``.

        Args:
            crt_net (torch model): Current network.
            load_net (dict): Loaded network. May be modified in place.
            strict (bool): Whether strictly loaded. Default: True.
        """
        crt_net = self.get_bare_model(crt_net)
        crt_net = crt_net.state_dict()
        crt_net_keys = set(crt_net.keys())
        load_net_keys = set(load_net.keys())

        if crt_net_keys != load_net_keys:
            logger.warning('Current net - loaded net:')
            for v in sorted(list(crt_net_keys - load_net_keys)):
                logger.warning(f'  {v}')
            logger.warning('Loaded net - current net:')
            for v in sorted(list(load_net_keys - crt_net_keys)):
                logger.warning(f'  {v}')

        # check the size for the same keys
        if not strict:
            common_keys = crt_net_keys & load_net_keys
            for k in common_keys:
                if crt_net[k].size() != load_net[k].size():
                    logger.warning(f'Size different, ignore [{k}]: crt_net: '
                                   f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
                    load_net[k + '.ignore'] = load_net.pop(k)

    def load_network(self, net, load_path, strict=True, param_key='params'):
        """Load network.

        Args:
            load_path (str): The path of networks to be loaded.
            net (nn.Module): Network.
            strict (bool): Whether strictly loaded.
            param_key (str): The parameter key of loaded network. If the key
                is missing but 'params' exists, 'params' is used instead. If
                set to None, the loaded object itself is used as state dict.
                Default: 'params'.
        """
        net = self.get_bare_model(net)
        logger.info(f'Loading {net.__class__.__name__} model from {load_path}.')
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
        if param_key is not None:
            if param_key not in load_net and 'params' in load_net:
                # Report the key that was actually requested.
                logger.info(f'Loading: {param_key} does not exist, use params.')
                param_key = 'params'
            load_net = load_net[param_key]
        # remove unnecessary 'module.'
        # A snapshot of the items is enough to allow renaming keys while
        # looping; there is no need to deep-copy every tensor.
        for k, v in list(load_net.items()):
            if k.startswith('module.'):
                load_net[k[7:]] = v
                load_net.pop(k)
        self._print_different_keys_loading(net, load_net, strict)
        net.load_state_dict(load_net, strict=strict)

    @master_only
    def save_training_state(self, epoch, current_iter):
        """Save training states during training, which will be used for
        resuming.

        Nothing is written when ``current_iter`` is -1.

        Args:
            epoch (int): Current epoch.
            current_iter (int): Current iteration.
        """
        if current_iter != -1:
            state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
            for o in self.optimizers:
                state['optimizers'].append(o.state_dict())
            for s in self.schedulers:
                state['schedulers'].append(s.state_dict())
            save_filename = f'{current_iter}.state'
            save_path = os.path.join(self.opt['path']['training_states'], save_filename)
            torch.save(state, save_path)

    def resume_training(self, resume_state):
        """Reload the optimizers and schedulers for resumed training.

        Args:
            resume_state (dict): Resume state with 'optimizers' and
                'schedulers' lists matching the current ones in length.
        """
        resume_optimizers = resume_state['optimizers']
        resume_schedulers = resume_state['schedulers']
        assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
        assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
        for i, o in enumerate(resume_optimizers):
            self.optimizers[i].load_state_dict(o)
        for i, s in enumerate(resume_schedulers):
            self.schedulers[i].load_state_dict(s)

    def reduce_loss_dict(self, loss_dict):
        """reduce loss dict.

        In distributed training, it averages the losses among different GPUs
        (the division by world size happens only on rank 0, the reduce
        destination).

        Args:
            loss_dict (OrderedDict): Loss dict.

        Returns:
            OrderedDict: Loss name -> Python float (``value.mean().item()``).
        """
        with torch.no_grad():
            if self.opt['dist']:
                keys = []
                losses = []
                for name, value in loss_dict.items():
                    keys.append(name)
                    losses.append(value)
                losses = torch.stack(losses, 0)
                torch.distributed.reduce(losses, dst=0)
                if self.opt['rank'] == 0:
                    losses /= self.opt['world_size']
                loss_dict = {key: loss for key, loss in zip(keys, losses)}

            log_dict = OrderedDict()
            for name, value in loss_dict.items():
                log_dict[name] = value.mean().item()

            return log_dict
|
basicsr/models/codeformer_idx_model.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from os import path as osp
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
from basicsr.archs import build_network
|
| 7 |
+
from basicsr.metrics import calculate_metric
|
| 8 |
+
from basicsr.utils import get_root_logger, imwrite, tensor2img
|
| 9 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from .sr_model import SRModel
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@MODEL_REGISTRY.register()
class CodeFormerIdxModel(SRModel):
    """Model that trains ``net_g`` to predict codebook indices.

    Training combines an optional feature loss against codebook features
    and an optional cross-entropy loss on the predicted indices (see
    ``optimize_parameters``).
    """

    def feed_data(self, data):
        """Move a batch to ``self.device`` and store it on the model.

        Args:
            data (dict): Must contain 'gt' and 'in'; may contain 'latent_gt',
                which is flattened to ``(batch, -1)`` and stored as
                ``self.idx_gt`` (``None`` when absent).
        """
        self.gt = data['gt'].to(self.device)
        self.input = data['in'].to(self.device)
        self.b = self.gt.shape[0]

        if 'latent_gt' in data:
            self.idx_gt = data['latent_gt'].to(self.device)
            self.idx_gt = self.idx_gt.view(self.b, -1)
        else:
            self.idx_gt = None

    def init_training_settings(self):
        """Read training options, build the helper networks and optimizers.

        Raises:
            NotImplementedError: If neither a 'latent_gt_path' for the train
                dataset nor a 'network_vqgan' config is given.
        """
        logger = get_root_logger()
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA)
            # net_g_ema is used only for testing on one GPU and saving
            # There is no need to wrap with DistributedDataParallel
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            # load pretrained model
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weight
            self.net_g_ema.eval()

        # Ground-truth indices either come with the data (pre-computed) or
        # are generated on the fly by a frozen VQGAN.
        if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
            self.generate_idx_gt = False
        elif self.opt.get('network_vqgan', None) is not None:
            self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
            self.hq_vqgan_fix.eval()
            self.generate_idx_gt = True
            for param in self.hq_vqgan_fix.parameters():
                param.requires_grad = False
        else:
            raise NotImplementedError(f'Shoule have network_vqgan config or pre-calculated latent code.')

        logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')

        self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
        self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
        self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
        self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)

        self.net_g.train()

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        """Create ``self.optimizer_g`` over the trainable params of ``net_g``.

        NOTE: pops 'type' from ``opt['train']['optim_g']`` (mutates options).
        """
        train_opt = self.opt['train']
        # optimizer g
        optim_params_g = []
        for k, v in self.net_g.named_parameters():
            if v.requires_grad:
                optim_params_g.append(v)
            else:
                logger = get_root_logger()
                logger.warning(f'Params {k} will not be optimized.')
        optim_type = train_opt['optim_g'].pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
        self.optimizers.append(self.optimizer_g)

    def optimize_parameters(self, current_iter):
        """Run one optimization step of ``net_g`` on the current batch.

        Args:
            current_iter (int): Current iteration (not used in this method).
        """
        # optimize net_g
        self.optimizer_g.zero_grad()

        if self.generate_idx_gt:
            x = self.hq_vqgan_fix.encoder(self.gt)
            _, _, quant_stats = self.hq_vqgan_fix.quantize(x)
            min_encoding_indices = quant_stats['min_encoding_indices']
            self.idx_gt = min_encoding_indices.view(self.b, -1)

        if self.hq_feat_loss:
            # quant_feats
            # Unwrap through get_bare_model so this also works when net_g is
            # not wrapped by DataParallel / DistributedDataParallel.
            # NOTE(review): the codebook feature shape is hard-coded here.
            net_g_bare = self.get_bare_model(self.net_g)
            quant_feat_gt = net_g_bare.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])

        logits, lq_feat = self.net_g(self.input, w=0, code_only=True)

        l_g_total = 0
        loss_dict = OrderedDict()
        # hq_feat_loss
        if self.hq_feat_loss:  # codebook loss
            l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
            l_g_total += l_feat_encoder
            loss_dict['l_feat_encoder'] = l_feat_encoder

        # cross_entropy_loss
        if self.cross_entropy_loss:
            # b(hw)n -> bn(hw)
            cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
            l_g_total += cross_entropy_loss
            loss_dict['cross_entropy_loss'] = cross_entropy_loss

        l_g_total.backward()
        self.optimizer_g.step()

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

        self.log_dict = self.reduce_loss_dict(loss_dict)

    def test(self):
        """Run inference on ``self.input`` and store it in ``self.output``.

        Uses ``net_g_ema`` when it exists; otherwise ``net_g`` is switched to
        eval mode for the forward pass and back to train mode afterwards.
        """
        with torch.no_grad():
            if hasattr(self, 'net_g_ema'):
                self.net_g_ema.eval()
                self.output, _, _ = self.net_g_ema(self.input, w=0)
            else:
                logger = get_root_logger()
                logger.warning('Do not have self.net_g_ema, use self.net_g.')
                self.net_g.eval()
                self.output, _, _ = self.net_g(self.input, w=0)
                self.net_g.train()

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Run the non-distributed validation on rank 0 only."""
        if self.opt['rank'] == 0:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Validate on every batch of ``dataloader``.

        Optionally saves the result images and accumulates the metrics
        configured in ``opt['val']['metrics']``, which are averaged over the
        number of batches and logged at the end.

        Args:
            dataloader: Validation dataloader; ``dataloader.dataset.opt['name']``
                is used as the dataset name.
            current_iter (int): Current iteration (used in file names/logs).
            tb_logger: Tensorboard logger or a falsy value.
            save_img (bool): Whether to write result images.
        """
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        if with_metrics:
            self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
        pbar = tqdm(total=len(dataloader), unit='image')

        num_imgs = 0  # number of processed batches, used for averaging
        for idx, val_data in enumerate(dataloader):
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            visuals = self.get_current_visuals()
            sr_img = tensor2img([visuals['result']])
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                del self.gt

            # tentative for out of GPU memory
            # feed_data() of this class stores the input as self.input;
            # self.lq is never set here, so deleting it would raise.
            del self.input
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                             f'{img_name}_{current_iter}.png')
                else:
                    if self.opt['val']['suffix']:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
                    else:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["name"]}.png')
                imwrite(sr_img, save_img_path)

            if with_metrics:
                # calculate metrics
                for name, opt_ in self.opt['val']['metrics'].items():
                    metric_data = dict(img1=sr_img, img2=gt_img)
                    self.metric_results[name] += calculate_metric(metric_data, opt_)
            num_imgs = idx + 1
            pbar.update(1)
            pbar.set_description(f'Test {img_name}')
        pbar.close()

        if with_metrics:
            # An empty dataloader leaves the metrics at 0 instead of failing
            # on an undefined loop index.
            if num_imgs > 0:
                for metric in self.metric_results.keys():
                    self.metric_results[metric] /= num_imgs

            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
        """Log ``self.metric_results`` to the root logger and tensorboard."""
        log_str = f'Validation {dataset_name}\n'
        for metric, value in self.metric_results.items():
            log_str += f'\t # {metric}: {value:.4f}\n'
        logger = get_root_logger()
        logger.info(log_str)
        if tb_logger:
            for metric, value in self.metric_results.items():
                tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)

    def get_current_visuals(self):
        """Return detached CPU copies of the ground truth and the output."""
        out_dict = OrderedDict()
        out_dict['gt'] = self.gt.detach().cpu()
        out_dict['result'] = self.output.detach().cpu()
        return out_dict

    def save(self, epoch, current_iter):
        """Save ``net_g`` (plus ``net_g_ema`` when EMA is on) and the training state."""
        if self.ema_decay > 0:
            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        else:
            self.save_network(self.net_g, 'net_g', current_iter)
        self.save_training_state(epoch, current_iter)
|
basicsr/models/codeformer_joint_model.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from os import path as osp
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from basicsr.archs import build_network
|
| 8 |
+
from basicsr.losses import build_loss
|
| 9 |
+
from basicsr.metrics import calculate_metric
|
| 10 |
+
from basicsr.utils import get_root_logger, imwrite, tensor2img
|
| 11 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from .sr_model import SRModel
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@MODEL_REGISTRY.register()
|
| 17 |
+
class CodeFormerJointModel(SRModel):
|
| 18 |
+
def feed_data(self, data):
    """Move a batch to ``self.device`` and store it on the model.

    Args:
        data (dict): Must contain 'gt', 'in' and 'in_large_de'; may contain
            'latent_gt', which is flattened to ``(batch, -1)`` and stored
            as ``self.idx_gt`` (``None`` when the key is absent).
    """
    target_device = self.device
    self.gt = data['gt'].to(target_device)
    self.input = data['in'].to(target_device)
    self.input_large_de = data['in_large_de'].to(target_device)
    # Batch size is taken from the ground-truth tensor.
    self.b = self.gt.shape[0]

    if 'latent_gt' not in data:
        self.idx_gt = None
        return
    self.idx_gt = data['latent_gt'].to(target_device).view(self.b, -1)
|
| 29 |
+
|
| 30 |
+
def init_training_settings(self):
    """Read training options and build networks, losses and optimizers.

    Sets up, in order: the optional EMA generator, the source of the
    ground-truth latent indices, loss weights, the discriminator, the
    loss criteria, the iteration schedule options, and finally the
    optimizers and schedulers.

    Raises:
        NotImplementedError: If neither a 'latent_gt_path' for the train
            dataset nor a 'network_vqgan' config is given.
    """
    logger = get_root_logger()
    train_opt = self.opt['train']

    # ---- optional EMA copy of net_g ----
    self.ema_decay = train_opt.get('ema_decay', 0)
    if self.ema_decay > 0:
        logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
        # net_g_ema is only used for testing on one GPU and for saving,
        # so it is not wrapped with DistributedDataParallel.
        self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
        pretrained_g = self.opt['path'].get('pretrain_network_g', None)
        if pretrained_g is None:
            self.model_ema(0)  # copy net_g weight
        else:
            self.load_network(self.net_g_ema, pretrained_g, self.opt['path'].get('strict_load_g', True), 'params_ema')
        self.net_g_ema.eval()

    # ---- where the ground-truth latent indices come from ----
    if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
        self.generate_idx_gt = False
    elif self.opt.get('network_vqgan', None) is not None:
        self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
        self.hq_vqgan_fix.eval()
        self.generate_idx_gt = True
        # Freeze the VQGAN.
        for frozen_param in self.hq_vqgan_fix.parameters():
            frozen_param.requires_grad = False
    else:
        raise NotImplementedError(f'Shoule have network_vqgan config or pre-calculated latent code.')

    logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')

    # ---- loss switches and weights ----
    self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
    self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
    self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
    self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
    self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8)

    # ---- discriminator ----
    self.net_d = self.model_to_device(build_network(self.opt['network_d']))
    self.print_network(self.net_d)

    pretrained_d = self.opt['path'].get('pretrain_network_d', None)
    if pretrained_d is not None:
        self.load_network(self.net_d, pretrained_d, self.opt['path'].get('strict_load_d', True))

    self.net_g.train()
    self.net_d.train()

    # ---- loss criteria ----
    pixel_opt = train_opt.get('pixel_opt')
    self.cri_pix = build_loss(pixel_opt).to(self.device) if pixel_opt else None

    perceptual_opt = train_opt.get('perceptual_opt')
    self.cri_perceptual = build_loss(perceptual_opt).to(self.device) if perceptual_opt else None

    # cri_gan is only assigned when a GAN loss is configured.
    gan_opt = train_opt.get('gan_opt')
    if gan_opt:
        self.cri_gan = build_loss(gan_opt).to(self.device)

    self.fix_generator = train_opt.get('fix_generator', True)
    logger.info(f'fix_generator: {self.fix_generator}')

    # ---- iteration schedule options ----
    self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
    self.net_d_iters = train_opt.get('net_d_iters', 1)
    self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)

    # ---- optimizers and schedulers ----
    self.setup_optimizers()
    self.setup_schedulers()
|
| 106 |
+
|
| 107 |
+
def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
    """Compute the adaptive weight applied to the generator's GAN loss.

    The weight is the ratio between the gradient norm of the reconstruction
    loss and the gradient norm of the GAN loss, both taken with respect to
    ``last_layer``.

    Args:
        recon_loss (Tensor): Reconstruction loss attached to the graph.
        g_loss (Tensor): Generator adversarial loss attached to the graph.
        last_layer (Tensor): Parameter the gradients are measured on.
        disc_weight_max (float): Upper clamp bound for the weight.

    Returns:
        Tensor: Detached scalar weight clamped to ``[0, disc_weight_max]``.
    """
    grad_norms = []
    for loss in (recon_loss, g_loss):
        # retain_graph=True keeps the graph usable for the later backward pass
        grad = torch.autograd.grad(loss, last_layer, retain_graph=True)[0]
        grad_norms.append(torch.norm(grad))
    recon_norm, gan_norm = grad_norms

    # the small constant avoids a division by zero when the GAN gradient vanishes
    ratio = recon_norm / (gan_norm + 1e-4)
    return torch.clamp(ratio, 0.0, disc_weight_max).detach()
|
| 114 |
+
|
| 115 |
+
def setup_optimizers(self):
    """Build the generator and discriminator optimizers.

    Reads ``opt['train']['optim_g']`` and ``opt['train']['optim_d']``. Each
    entry must contain a ``type`` key; the remaining keys are forwarded to
    ``self.get_optimizer`` as keyword arguments. Both optimizers are appended
    to ``self.optimizers``.
    """
    train_opt = self.opt['train']

    # optimizer g: only parameters that still require gradients are optimized
    optim_params_g = []
    frozen_names = []
    for k, v in self.net_g.named_parameters():
        if v.requires_grad:
            optim_params_g.append(v)
        else:
            frozen_names.append(k)
    if frozen_names:
        logger = get_root_logger()
        for k in frozen_names:
            logger.warning(f'Params {k} will not be optimized.')

    # Pop 'type' from a copy: popping from the shared config would strip the
    # key from self.opt and make a second call fail with KeyError.
    optim_g_opt = dict(train_opt['optim_g'])
    optim_type = optim_g_opt.pop('type')
    self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **optim_g_opt)
    self.optimizers.append(self.optimizer_g)

    # optimizer d
    optim_d_opt = dict(train_opt['optim_d'])
    optim_type = optim_d_opt.pop('type')
    self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **optim_d_opt)
    self.optimizers.append(self.optimizer_d)
|
| 132 |
+
|
| 133 |
+
def gray_resize_for_identity(self, out, size=128):
    """Convert a 3-channel batch to one gray channel and resize it.

    Channels 0, 1 and 2 along dim 1 are combined with the weights
    0.2989 / 0.5870 / 0.1140 (presumably an RGB channel order - confirm
    against the caller).

    Args:
        out (Tensor): Batch indexed as ``out[:, c, :, :]`` with c in 0..2.
        size (int): Output height and width. Default: 128.

    Returns:
        Tensor: Single-channel batch bilinearly resized to ``(size, size)``.
    """
    first, second, third = out[:, 0, :, :], out[:, 1, :, :], out[:, 2, :, :]
    weighted = 0.2989 * first + 0.5870 * second + 0.1140 * third
    # restore the channel dimension that the indexing above removed
    single_channel = weighted.unsqueeze(1)
    return F.interpolate(single_channel, (size, size), mode='bilinear', align_corners=False)
|
| 138 |
+
|
| 139 |
+
def optimize_parameters(self, current_iter):
    """Run one training iteration: a generator step, then a discriminator step.

    The degradation schedule depends on ``current_iter``: up to 40000 and up
    to 80000 every iteration uses ``self.input`` (w=1 and w=1.3); up to 120000
    almost every iteration uses ``self.input_large_de`` in code-only mode;
    afterwards every 15th iteration uses ``self.input`` and the rest use
    ``self.input_large_de``.

    Args:
        current_iter (int): Current training iteration.

    Raises:
        ValueError: If the GAN loss is active but neither a pixel nor a
            perceptual criterion is configured, since the adaptive GAN weight
            needs a reconstruction loss.
    """
    # optimize net_g: freeze the discriminator while the generator is updated
    for p in self.net_d.parameters():
        p.requires_grad = False

    self.optimizer_g.zero_grad()

    if self.generate_idx_gt:
        # derive the target code indices from the ground truth with the fixed VQGAN
        x = self.hq_vqgan_fix.encoder(self.gt)
        _, _, quant_stats = self.hq_vqgan_fix.quantize(x)
        min_encoding_indices = quant_stats['min_encoding_indices']
        self.idx_gt = min_encoding_indices.view(self.b, -1)

    if current_iter <= 40000:  # small degradation
        small_per_n = 1
        w = 1
    elif current_iter <= 80000:  # small degradation
        small_per_n = 1
        w = 1.3
    elif current_iter <= 120000:  # large degradation
        small_per_n = 120000
        w = 0
    else:  # mixed degradation
        small_per_n = 15
        w = 1.3

    if current_iter % small_per_n == 0:
        self.output, logits, lq_feat = self.net_g(self.input, w=w, detach_16=True)
        large_de = False
    else:
        logits, lq_feat = self.net_g(self.input_large_de, code_only=True)
        large_de = True

    if self.hq_feat_loss:
        # quant_feats
        quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b, 16, 16, 256])

    l_g_total = 0
    loss_dict = OrderedDict()
    if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
        # code-level losses are skipped when the transformer is frozen
        if 'transformer' not in self.opt['network_g']['fix_modules']:
            if self.hq_feat_loss:  # codebook loss
                l_feat_encoder = torch.mean((quant_feat_gt.detach() - lq_feat)**2) * self.feat_loss_weight
                l_g_total += l_feat_encoder
                loss_dict['l_feat_encoder'] = l_feat_encoder

            # cross_entropy_loss
            if self.cross_entropy_loss:
                # b(hw)n -> bn(hw)
                cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
                l_g_total += cross_entropy_loss
                loss_dict['cross_entropy_loss'] = cross_entropy_loss

        if not large_de:  # when large degradation don't need image-level loss
            # Collect the reconstruction terms as they are computed so the
            # GAN branch does not reference a loss that was never defined.
            recon_loss = None

            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, self.gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix
                recon_loss = l_g_pix

            # perceptual loss
            if self.cri_perceptual:
                l_g_percep = self.cri_perceptual(self.output, self.gt)
                l_g_total += l_g_percep
                loss_dict['l_g_percep'] = l_g_percep
                recon_loss = l_g_percep if recon_loss is None else recon_loss + l_g_percep

            # gan loss
            if current_iter > self.net_d_start_iter:
                if recon_loss is None:
                    raise ValueError('The adaptive GAN weight needs a reconstruction loss: '
                                     'configure pixel_opt and/or perceptual_opt.')
                fake_g_pred = self.net_d(self.output)
                l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
                if not self.fix_generator:
                    last_layer = self.net_g.module.generator.blocks[-1].weight
                else:
                    largest_fuse_size = self.opt['network_g']['connect_list'][-1]
                    last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight
                d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)

                d_weight *= self.scale_adaptive_gan_weight  # 0.8
                loss_dict['d_weight'] = d_weight
                l_g_total += d_weight * l_g_gan
                loss_dict['l_g_gan'] = d_weight * l_g_gan

        # l_g_total stays the int 0 when no loss term was enabled for this
        # iteration; there is nothing to back-propagate in that case.
        if torch.is_tensor(l_g_total):
            l_g_total.backward()
            self.optimizer_g.step()

    if self.ema_decay > 0:
        self.model_ema(decay=self.ema_decay)

    # optimize net_d (only when an image was generated in this iteration)
    if not large_de and current_iter > self.net_d_start_iter:
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()
        # real
        real_d_pred = self.net_d(self.gt)
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict['l_d_real'] = l_d_real
        loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
        l_d_real.backward()
        # fake: detach so that no gradient flows back into the generator
        fake_d_pred = self.net_d(self.output.detach())
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict['l_d_fake'] = l_d_fake
        loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
        l_d_fake.backward()

        self.optimizer_d.step()

    self.log_dict = self.reduce_loss_dict(loss_dict)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def test(self):
    """Run inference on ``self.input`` with w=1 and store it in ``self.output``.

    The EMA generator is used when it exists. Otherwise a warning is logged
    and ``net_g`` is switched to eval mode for the forward pass and back to
    train mode afterwards.
    """
    has_ema = hasattr(self, 'net_g_ema')
    with torch.no_grad():
        if not has_ema:
            logger = get_root_logger()
            logger.warning('Do not have self.net_g_ema, use self.net_g.')
            self.net_g.eval()
            self.output, _, _ = self.net_g(self.input, w=1)
            self.net_g.train()
            return

        self.net_g_ema.eval()
        self.output, _, _ = self.net_g_ema(self.input, w=1)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    """Validate on the process whose ``opt['rank']`` is 0; other ranks do nothing.

    All arguments are forwarded unchanged to ``nondist_validation``.
    """
    if self.opt['rank'] != 0:
        return
    self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
    """Validate on every batch of ``dataloader``, optionally saving images and metrics.

    Args:
        dataloader: Iterable of validation batches; each batch provides
            ``'lq_path'`` and whatever ``feed_data`` consumes.
        current_iter (int): Iteration used in file names and metric logging.
        tb_logger: Optional logger exposing ``add_scalar``; may be falsy.
        save_img (bool): Whether to write the restored images to disk.
    """
    dataset_name = dataloader.dataset.opt['name']
    with_metrics = self.opt['val'].get('metrics') is not None
    if with_metrics:
        self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
    pbar = tqdm(total=len(dataloader), unit='image')

    num_images = 0
    for val_data in dataloader:
        num_images += 1
        img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
        self.feed_data(val_data)
        self.test()

        visuals = self.get_current_visuals()
        sr_img = tensor2img([visuals['result']])
        if 'gt' in visuals:
            gt_img = tensor2img([visuals['gt']])
            del self.gt

        # tentative for out of GPU memory
        # test() reads self.input, and no method visible in this part of the
        # file assigns self.lq, so release it only when it actually exists.
        if hasattr(self, 'lq'):
            del self.lq
        del self.output
        torch.cuda.empty_cache()

        if save_img:
            if self.opt['is_train']:
                save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                         f'{img_name}_{current_iter}.png')
            else:
                if self.opt['val']['suffix']:
                    save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                             f'{img_name}_{self.opt["val"]["suffix"]}.png')
                else:
                    save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                             f'{img_name}_{self.opt["name"]}.png')
            imwrite(sr_img, save_img_path)

        if with_metrics:
            # calculate metrics
            for name, opt_ in self.opt['val']['metrics'].items():
                metric_data = dict(img1=sr_img, img2=gt_img)
                self.metric_results[name] += calculate_metric(metric_data, opt_)
        pbar.update(1)
        pbar.set_description(f'Test {img_name}')
    pbar.close()

    if with_metrics:
        # an empty dataloader leaves the sums at 0; skip the division instead
        # of referencing a loop variable that was never bound
        if num_images > 0:
            for metric in self.metric_results.keys():
                self.metric_results[metric] /= num_images

        self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
    """Report ``self.metric_results`` to the root logger and, if given, to ``tb_logger``.

    Args:
        current_iter (int): Step under which the scalars are recorded.
        dataset_name (str): Name printed in the log header.
        tb_logger: Optional object exposing ``add_scalar``; skipped when falsy.
    """
    parts = [f'Validation {dataset_name}\n']
    parts.extend(f'\t # {metric}: {value:.4f}\n' for metric, value in self.metric_results.items())
    get_root_logger().info(''.join(parts))

    if not tb_logger:
        return
    for metric, value in self.metric_results.items():
        tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def get_current_visuals(self):
    """Return detached CPU copies of the ground truth and the current output.

    Returns:
        OrderedDict: ``'gt'`` followed by ``'result'``.
    """
    sources = (('gt', self.gt), ('result', self.output))
    return OrderedDict((key, tensor.detach().cpu()) for key, tensor in sources)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def save(self, epoch, current_iter):
    """Save the generator, the discriminator and the training state.

    When ``ema_decay > 0`` the generator checkpoint holds both the live and
    the EMA weights under the keys ``'params'`` and ``'params_ema'``.
    """
    if self.ema_decay > 0:
        generators = [self.net_g, self.net_g_ema]
        keys = ['params', 'params_ema']
        self.save_network(generators, 'net_g', current_iter, param_key=keys)
    else:
        self.save_network(self.net_g, 'net_g', current_iter)

    self.save_network(self.net_d, 'net_d', current_iter)
    self.save_training_state(epoch, current_iter)
|
basicsr/models/codeformer_model.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from os import path as osp
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
from basicsr.archs import build_network
|
| 7 |
+
from basicsr.losses import build_loss
|
| 8 |
+
from basicsr.metrics import calculate_metric
|
| 9 |
+
from basicsr.utils import get_root_logger, imwrite, tensor2img
|
| 10 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from .sr_model import SRModel
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@MODEL_REGISTRY.register()
class CodeFormerModel(SRModel):
    """Training model for the CodeFormer generator with a fidelity weight.

    The generator ``net_g`` is trained with code-level losses (codebook
    feature and cross-entropy). When ``fidelity_weight > 0`` it also produces
    an image that is trained with pixel, perceptual and adversarial losses
    against a discriminator ``net_d``; with ``fidelity_weight == 0`` no
    discriminator is built and only the code-level losses are used.
    """

    def feed_data(self, data):
        """Move one batch to the device.

        Reads ``data['gt']`` and ``data['in']``. ``data['latent_gt']`` is
        optional; when present it is flattened to one row per sample,
        otherwise ``self.idx_gt`` is set to None.
        """
        self.gt = data['gt'].to(self.device)
        self.input = data['in'].to(self.device)
        self.b = self.gt.shape[0]

        if 'latent_gt' in data:
            self.idx_gt = data['latent_gt'].to(self.device)
            self.idx_gt = self.idx_gt.view(self.b, -1)
        else:
            self.idx_gt = None

    def init_training_settings(self):
        """Build the EMA generator, fixed VQGAN, discriminator, losses and optimizers."""
        logger = get_root_logger()
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA)
            # net_g_ema is used only for testing on one GPU and saving
            # There is no need to wrap with DistributedDataParallel
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            # load pretrained model
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weight
            self.net_g_ema.eval()

        # A frozen VQGAN generates the target code indices on the fly when no
        # pre-computed latent ground truth path is configured.
        if self.opt.get('network_vqgan', None) is not None and self.opt['datasets'].get('latent_gt_path') is None:
            self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
            self.hq_vqgan_fix.eval()
            self.generate_idx_gt = True
            for param in self.hq_vqgan_fix.parameters():
                param.requires_grad = False
        else:
            self.generate_idx_gt = False

        self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
        self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
        self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
        self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
        self.fidelity_weight = train_opt.get('fidelity_weight', 1.0)
        self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8)

        self.net_g.train()
        # define network net_d (only needed when image-level losses are used)
        if self.fidelity_weight > 0:
            self.net_d = build_network(self.opt['network_d'])
            self.net_d = self.model_to_device(self.net_d)
            self.print_network(self.net_d)

            # load pretrained models
            load_path = self.opt['path'].get('pretrain_network_d', None)
            if load_path is not None:
                self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))

            self.net_d.train()

        # define losses
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None

        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None

        if train_opt.get('gan_opt'):
            self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

        self.fix_generator = train_opt.get('fix_generator', True)
        logger.info(f'fix_generator: {self.fix_generator}')

        self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
        self.net_d_iters = train_opt.get('net_d_iters', 1)
        self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
        """Return the GAN loss weight from the ratio of gradient norms.

        Args:
            recon_loss (Tensor): Reconstruction loss attached to the graph.
            g_loss (Tensor): Generator adversarial loss attached to the graph.
            last_layer (Tensor): Parameter the gradients are measured on.
            disc_weight_max (float): Upper clamp bound for the weight.

        Returns:
            Tensor: Detached scalar clamped to ``[0, disc_weight_max]``.
        """
        recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        # the small constant avoids a division by zero
        d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
        return d_weight

    def setup_optimizers(self):
        """Create ``optimizer_g`` and, when ``fidelity_weight > 0``, ``optimizer_d``."""
        train_opt = self.opt['train']
        # optimizer g: only parameters that still require gradients are optimized
        optim_params_g = []
        frozen_names = []
        for k, v in self.net_g.named_parameters():
            if v.requires_grad:
                optim_params_g.append(v)
            else:
                frozen_names.append(k)
        if frozen_names:
            logger = get_root_logger()
            for k in frozen_names:
                logger.warning(f'Params {k} will not be optimized.')

        # Pop 'type' from a copy so the shared config keeps the key and this
        # method can be called more than once.
        optim_g_opt = dict(train_opt['optim_g'])
        optim_type = optim_g_opt.pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **optim_g_opt)
        self.optimizers.append(self.optimizer_g)
        # optimizer d
        if self.fidelity_weight > 0:
            optim_d_opt = dict(train_opt['optim_d'])
            optim_type = optim_d_opt.pop('type')
            self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **optim_d_opt)
            self.optimizers.append(self.optimizer_d)

    def gray_resize_for_identity(self, out, size=128):
        """Convert a 3-channel batch to one gray channel and resize to ``(size, size)``.

        Channels 0, 1 and 2 along dim 1 are weighted 0.2989 / 0.5870 / 0.1140
        (presumably RGB order - confirm against the caller).
        """
        out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
        out_gray = out_gray.unsqueeze(1)
        out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
        return out_gray

    def optimize_parameters(self, current_iter):
        """Run one training iteration: a generator step, then a discriminator step.

        Args:
            current_iter (int): Current training iteration.

        Raises:
            ValueError: If the GAN loss is active but neither a pixel nor a
                perceptual criterion is configured, since the adaptive GAN
                weight needs a reconstruction loss.
        """
        use_image_losses = self.fidelity_weight > 0

        # optimize net_g: freeze the discriminator during the generator step.
        # net_d is only built when fidelity_weight > 0, so it must not be
        # touched otherwise.
        if use_image_losses:
            for p in self.net_d.parameters():
                p.requires_grad = False

        self.optimizer_g.zero_grad()

        if self.generate_idx_gt:
            # derive the target code indices from the ground truth with the fixed VQGAN
            x = self.hq_vqgan_fix.encoder(self.gt)
            _, _, quant_stats = self.hq_vqgan_fix.quantize(x)
            min_encoding_indices = quant_stats['min_encoding_indices']
            self.idx_gt = min_encoding_indices.view(self.b, -1)

        if use_image_losses:
            self.output, logits, lq_feat = self.net_g(self.input, w=self.fidelity_weight, detach_16=True)
        else:
            logits, lq_feat = self.net_g(self.input, w=0, code_only=True)

        if self.hq_feat_loss:
            # quant_feats
            quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b, 16, 16, 256])

        l_g_total = 0
        loss_dict = OrderedDict()
        if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
            # hq_feat_loss
            if self.hq_feat_loss:  # codebook loss
                l_feat_encoder = torch.mean((quant_feat_gt.detach() - lq_feat)**2) * self.feat_loss_weight
                l_g_total += l_feat_encoder
                loss_dict['l_feat_encoder'] = l_feat_encoder

            # cross_entropy_loss
            if self.cross_entropy_loss:
                # b(hw)n -> bn(hw)
                cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
                l_g_total += cross_entropy_loss
                loss_dict['cross_entropy_loss'] = cross_entropy_loss

            if use_image_losses:  # when fidelity_weight == 0 don't need image-level loss
                # Collect the reconstruction terms as they are computed so the
                # GAN branch does not reference a loss that was never defined.
                recon_loss = None

                # pixel loss
                if self.cri_pix:
                    l_g_pix = self.cri_pix(self.output, self.gt)
                    l_g_total += l_g_pix
                    loss_dict['l_g_pix'] = l_g_pix
                    recon_loss = l_g_pix

                # perceptual loss
                if self.cri_perceptual:
                    l_g_percep = self.cri_perceptual(self.output, self.gt)
                    l_g_total += l_g_percep
                    loss_dict['l_g_percep'] = l_g_percep
                    recon_loss = l_g_percep if recon_loss is None else recon_loss + l_g_percep

                # gan loss
                if current_iter > self.net_d_start_iter:
                    if recon_loss is None:
                        raise ValueError('The adaptive GAN weight needs a reconstruction loss: '
                                         'configure pixel_opt and/or perceptual_opt.')
                    fake_g_pred = self.net_d(self.output)
                    l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
                    if not self.fix_generator:
                        last_layer = self.net_g.module.generator.blocks[-1].weight
                    else:
                        largest_fuse_size = self.opt['network_g']['connect_list'][-1]
                        last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight
                    d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)

                    d_weight *= self.scale_adaptive_gan_weight  # 0.8
                    loss_dict['d_weight'] = d_weight
                    l_g_total += d_weight * l_g_gan
                    loss_dict['l_g_gan'] = d_weight * l_g_gan

            # l_g_total stays the int 0 when every loss term is disabled;
            # there is nothing to back-propagate in that case.
            if torch.is_tensor(l_g_total):
                l_g_total.backward()
                self.optimizer_g.step()

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

        # optimize net_d
        if current_iter > self.net_d_start_iter and use_image_losses:
            for p in self.net_d.parameters():
                p.requires_grad = True

            self.optimizer_d.zero_grad()
            # real
            real_d_pred = self.net_d(self.gt)
            l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
            loss_dict['l_d_real'] = l_d_real
            loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
            l_d_real.backward()
            # fake: detach so that no gradient flows back into the generator
            fake_d_pred = self.net_d(self.output.detach())
            l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
            loss_dict['l_d_fake'] = l_d_fake
            loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
            l_d_fake.backward()

            self.optimizer_d.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

    def test(self):
        """Run inference on ``self.input`` with ``w=self.fidelity_weight``.

        The EMA generator is used when it exists. Otherwise a warning is
        logged and ``net_g`` is put in eval mode for the forward pass and
        back in train mode afterwards.
        """
        with torch.no_grad():
            if hasattr(self, 'net_g_ema'):
                self.net_g_ema.eval()
                self.output, _, _ = self.net_g_ema(self.input, w=self.fidelity_weight)
            else:
                logger = get_root_logger()
                logger.warning('Do not have self.net_g_ema, use self.net_g.')
                self.net_g.eval()
                self.output, _, _ = self.net_g(self.input, w=self.fidelity_weight)
                self.net_g.train()

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Validate on the process whose ``opt['rank']`` is 0; other ranks do nothing."""
        if self.opt['rank'] == 0:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Validate on every batch of ``dataloader``, optionally saving images and metrics.

        Args:
            dataloader: Iterable of validation batches; each batch provides
                ``'lq_path'`` plus the keys read by ``feed_data``.
            current_iter (int): Iteration used in file names and metric logging.
            tb_logger: Optional logger exposing ``add_scalar``; may be falsy.
            save_img (bool): Whether to write the restored images to disk.
        """
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        if with_metrics:
            self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
        pbar = tqdm(total=len(dataloader), unit='image')

        num_images = 0
        for val_data in dataloader:
            num_images += 1
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            visuals = self.get_current_visuals()
            sr_img = tensor2img([visuals['result']])
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                del self.gt

            # tentative for out of GPU memory
            # feed_data stores the degraded batch as self.input and never
            # assigns self.lq, so an unconditional `del self.lq` would raise
            # AttributeError. Release whichever of these buffers exist.
            for attr in ('lq', 'input', 'output'):
                if hasattr(self, attr):
                    delattr(self, attr)
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                             f'{img_name}_{current_iter}.png')
                else:
                    if self.opt['val']['suffix']:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
                    else:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["name"]}.png')
                imwrite(sr_img, save_img_path)

            if with_metrics:
                # calculate metrics
                for name, opt_ in self.opt['val']['metrics'].items():
                    metric_data = dict(img1=sr_img, img2=gt_img)
                    self.metric_results[name] += calculate_metric(metric_data, opt_)
            pbar.update(1)
            pbar.set_description(f'Test {img_name}')
        pbar.close()

        if with_metrics:
            # an empty dataloader leaves the sums at 0; skip the division
            # instead of referencing a loop variable that was never bound
            if num_images > 0:
                for metric in self.metric_results.keys():
                    self.metric_results[metric] /= num_images

            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
        """Report ``self.metric_results`` to the root logger and, if given, to ``tb_logger``."""
        log_str = f'Validation {dataset_name}\n'
        for metric, value in self.metric_results.items():
            log_str += f'\t # {metric}: {value:.4f}\n'
        logger = get_root_logger()
        logger.info(log_str)
        if tb_logger:
            for metric, value in self.metric_results.items():
                tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)

    def get_current_visuals(self):
        """Return detached CPU copies of ``self.gt`` and ``self.output`` as 'gt' and 'result'."""
        out_dict = OrderedDict()
        out_dict['gt'] = self.gt.detach().cpu()
        out_dict['result'] = self.output.detach().cpu()
        return out_dict

    def save(self, epoch, current_iter):
        """Save the generator, the discriminator (if one was built) and the training state."""
        if self.ema_decay > 0:
            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        else:
            self.save_network(self.net_g, 'net_g', current_iter)
        if self.fidelity_weight > 0:
            self.save_network(self.net_d, 'net_d', current_iter)
        self.save_training_state(epoch, current_iter)
|
basicsr/models/lr_scheduler.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from collections import Counter
|
| 3 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MultiStepRestartLR(_LRScheduler):
    """Multi-step learning rate decay with restarts.

    Args:
        optimizer (torch.nn.optimizer): Torch optimizer.
        milestones (list): Iterations at which the learning rate is decayed.
            An iteration listed several times is decayed that many times.
        gamma (float): Decay ratio. Default: 0.1.
        restarts (list): Iterations at which the rate is reset. Default: [0].
        restart_weights (list): Factor applied to the initial learning rate
            at the matching restart iteration. Default: [1].
        last_epoch (int): Used in _LRScheduler. Default: -1.
    """

    def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.restarts, self.restart_weights = restarts, restart_weights
        assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
        super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every parameter group for the current step."""
        step = self.last_epoch
        groups = self.optimizer.param_groups

        # a restart takes precedence over a milestone at the same step
        if step in self.restarts:
            scale = self.restart_weights[self.restarts.index(step)]
            return [group['initial_lr'] * scale for group in groups]

        if step in self.milestones:
            return [group['lr'] * self.gamma**self.milestones[step] for group in groups]

        return [group['lr'] for group in groups]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_position_from_periods(iteration, cumulative_period):
    """Get the position from a period list.

    It will return the index of the right-closest number in the period list.
    For example, the cumulative_period = [100, 200, 300, 400],
    if iteration == 50, return 0;
    if iteration == 210, return 2;
    if iteration == 300, return 2.

    Args:
        iteration (int): Current iteration.
        cumulative_period (list[int]): Cumulative period list.

    Returns:
        int: The position of the right-closest number in the period list.

    Raises:
        ValueError: If ``iteration`` is larger than every entry of
            ``cumulative_period`` (including when the list is empty).
    """
    for i, period in enumerate(cumulative_period):
        if iteration <= period:
            return i
    # Falling off the loop used to return None implicitly, which only
    # surfaced later as an opaque TypeError when used as a list index.
    raise ValueError(f'iteration {iteration} exceeds the last cumulative period in {list(cumulative_period)}.')
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class CosineAnnealingRestartLR(_LRScheduler):
    """Cosine annealing learning rate scheme with restarts.

    An example of config:
    periods = [10, 10, 10, 10]
    restart_weights = [1, 0.5, 0.5, 0.5]
    eta_min=1e-7

    This gives four cycles of 10 iterations each. At the 10th, 20th and 30th
    iteration the scheduler restarts with the weights in restart_weights.

    Args:
        optimizer (torch.nn.optimizer): Torch optimizer.
        periods (list): Period of each cosine annealing cycle.
        restart_weights (list): Restart weight of each cycle.
            Default: [1].
        eta_min (float): The minimum lr. Default: 0.
        last_epoch (int): Used in _LRScheduler. Default: -1.
    """

    def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
        self.periods = periods
        self.restart_weights = restart_weights
        self.eta_min = eta_min
        assert (len(self.periods) == len(
            self.restart_weights)), 'periods and restart_weights should have the same length.'

        # running totals of the periods, i.e. the last iteration of each cycle
        running_total = 0
        self.cumulative_period = []
        for period in self.periods:
            running_total += period
            self.cumulative_period.append(running_total)

        super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every base lr for the current step."""
        cycle = get_position_from_periods(self.last_epoch, self.cumulative_period)
        weight = self.restart_weights[cycle]
        cycle_start = self.cumulative_period[cycle - 1] if cycle != 0 else 0
        cycle_length = self.periods[cycle]

        # (1 + cos(pi * progress)) runs from 2 down to 0 across one cycle
        cosine_term = 1 + math.cos(math.pi * ((self.last_epoch - cycle_start) / cycle_length))
        return [
            self.eta_min + weight * 0.5 * (base_lr - self.eta_min) * cosine_term
            for base_lr in self.base_lrs
        ]
|
basicsr/models/sr_model.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from os import path as osp
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
from basicsr.archs import build_network
|
| 7 |
+
from basicsr.losses import build_loss
|
| 8 |
+
from basicsr.metrics import calculate_metric
|
| 9 |
+
from basicsr.utils import get_root_logger, imwrite, tensor2img
|
| 10 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
| 11 |
+
from .base_model import BaseModel
|
| 12 |
+
|
| 13 |
+
@MODEL_REGISTRY.register()
class SRModel(BaseModel):
    """Base SR model for single image super-resolution."""

    def __init__(self, opt):
        """Build the generator, optionally load pretrained weights, and set up training.

        Args:
            opt (dict): Options. This method reads ``opt['network_g']`` and
                ``opt['path']`` (keys ``pretrain_network_g``, ``param_key_g``,
                ``strict_load_g``).
        """
        super(SRModel, self).__init__(opt)

        # define network
        self.net_g = build_network(opt['network_g'])
        self.net_g = self.model_to_device(self.net_g)
        self.print_network(self.net_g)

        # load pretrained models
        load_path = self.opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            param_key = self.opt['path'].get('param_key_g', 'params')
            self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)

        if self.is_train:
            self.init_training_settings()

    def init_training_settings(self):
        """Put net_g in train mode and create the EMA copy, losses, optimizers and schedulers.

        Raises:
            ValueError: If neither ``pixel_opt`` nor ``perceptual_opt`` is
                configured in ``opt['train']``.
        """
        self.net_g.train()
        train_opt = self.opt['train']

        # ema_decay is always set here (0 when not configured); net_g_ema is
        # only created when the decay is positive.
        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA)
            # net_g_ema is used only for testing on one GPU and saving
            # There is no need to wrap with DistributedDataParallel
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            # load pretrained model
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weight
            self.net_g_ema.eval()

        # define losses
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None

        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None

        if self.cri_pix is None and self.cri_perceptual is None:
            raise ValueError('Both pixel and perceptual losses are None.')

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def setup_optimizers(self):
        """Create optimizer_g over the trainable parameters of net_g.

        Parameters with ``requires_grad == False`` are skipped with a warning.
        Note that the ``type`` key is popped from ``opt['train']['optim_g']``,
        so the options dict is modified in place.
        """
        train_opt = self.opt['train']
        optim_params = []
        for k, v in self.net_g.named_parameters():
            if v.requires_grad:
                optim_params.append(v)
            else:
                logger = get_root_logger()
                logger.warning(f'Params {k} will not be optimized.')

        optim_type = train_opt['optim_g'].pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
        self.optimizers.append(self.optimizer_g)

    def feed_data(self, data):
        """Move the batch to the model device.

        Args:
            data (dict): Must contain ``'lq'``; ``'gt'`` is optional.
        """
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)

    def optimize_parameters(self, current_iter):
        """Run one generator update on the batch stored by feed_data.

        Args:
            current_iter (int): Current iteration (not used by this method).
        """
        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_total = 0
        loss_dict = OrderedDict()
        # pixel loss
        if self.cri_pix:
            l_pix = self.cri_pix(self.output, self.gt)
            l_total += l_pix
            loss_dict['l_pix'] = l_pix
        # perceptual loss
        if self.cri_perceptual:
            # The perceptual criterion returns a (perceptual, style) pair;
            # either entry may be None.
            l_percep, l_style = self.cri_perceptual(self.output, self.gt)
            if l_percep is not None:
                l_total += l_percep
                loss_dict['l_percep'] = l_percep
            if l_style is not None:
                l_total += l_style
                loss_dict['l_style'] = l_style

        l_total.backward()
        self.optimizer_g.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

    def test(self):
        """Run inference on self.lq and store the result in self.output.

        The EMA network is used when it exists; otherwise net_g is switched to
        eval mode for the forward pass and back to train mode afterwards.
        """
        # Check for the network itself rather than 'ema_decay': ema_decay is
        # always set during training (default 0) while net_g_ema only exists
        # when ema_decay > 0, so gating on 'ema_decay' raised AttributeError.
        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            with torch.no_grad():
                self.output = self.net_g_ema(self.lq)
        else:
            self.net_g.eval()
            with torch.no_grad():
                self.output = self.net_g(self.lq)
            self.net_g.train()

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Run validation on rank 0 only; other ranks do nothing."""
        if self.opt['rank'] == 0:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Validate on every item of ``dataloader``, optionally saving images and metrics.

        Args:
            dataloader: Validation dataloader; ``dataloader.dataset.opt['name']``
                is used as the dataset name and each item must provide
                ``'lq_path'``.
            current_iter (int): Current iteration, used in file names and logs.
            tb_logger: Tensorboard-style logger or a falsy value.
            save_img (bool): Whether to write the restored images to disk.
        """
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        if with_metrics:
            self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
        pbar = tqdm(total=len(dataloader), unit='image')

        for idx, val_data in enumerate(dataloader):
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            visuals = self.get_current_visuals()
            sr_img = tensor2img([visuals['result']])
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                del self.gt

            # tentative for out of GPU memory
            del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                             f'{img_name}_{current_iter}.png')
                else:
                    if self.opt['val']['suffix']:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
                    else:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["name"]}.png')
                imwrite(sr_img, save_img_path)

            if with_metrics:
                # calculate metrics
                # NOTE(review): gt_img is only bound when the batch carried a
                # 'gt' entry; metrics require ground truth in the val set.
                for name, opt_ in self.opt['val']['metrics'].items():
                    metric_data = dict(img1=sr_img, img2=gt_img)
                    self.metric_results[name] += calculate_metric(metric_data, opt_)
            pbar.update(1)
            pbar.set_description(f'Test {img_name}')
        pbar.close()

        if with_metrics:
            # Average the accumulated values over the number of images seen.
            for metric in self.metric_results.keys():
                self.metric_results[metric] /= (idx + 1)

            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
        """Write the averaged metrics to the root logger and, if given, to tb_logger."""
        log_str = f'Validation {dataset_name}\n'
        for metric, value in self.metric_results.items():
            log_str += f'\t # {metric}: {value:.4f}\n'
        logger = get_root_logger()
        logger.info(log_str)
        if tb_logger:
            for metric, value in self.metric_results.items():
                tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)

    def get_current_visuals(self):
        """Return detached CPU copies of lq, the output and (if present) gt.

        Returns:
            OrderedDict: Keys ``'lq'``, ``'result'`` and optionally ``'gt'``.
        """
        out_dict = OrderedDict()
        out_dict['lq'] = self.lq.detach().cpu()
        out_dict['result'] = self.output.detach().cpu()
        if hasattr(self, 'gt'):
            out_dict['gt'] = self.gt.detach().cpu()
        return out_dict

    def save(self, epoch, current_iter):
        """Save the generator (plus its EMA copy when one exists) and the training state."""
        # Same reasoning as in test(): the presence of net_g_ema, not of
        # ema_decay, tells whether there is an EMA network to save.
        if hasattr(self, 'net_g_ema'):
            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        else:
            self.save_network(self.net_g, 'net_g', current_iter)
        self.save_training_state(epoch, current_iter)
|
basicsr/models/vqgan_model.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from os import path as osp
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
from basicsr.archs import build_network
|
| 7 |
+
from basicsr.losses import build_loss
|
| 8 |
+
from basicsr.metrics import calculate_metric
|
| 9 |
+
from basicsr.utils import get_root_logger, imwrite, tensor2img
|
| 10 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from .sr_model import SRModel
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@MODEL_REGISTRY.register()
class VQGANModel(SRModel):
    """VQGAN model: the generator reconstructs ``gt`` and is trained against a discriminator."""

    def feed_data(self, data):
        """Move the ground-truth batch to the model device.

        Only ``data['gt']`` is read; unlike the parent class no ``self.lq`` is set.
        """
        self.gt = data['gt'].to(self.device)
        self.b = self.gt.shape[0]

    def init_training_settings(self):
        """Create the EMA generator, the discriminator, losses, optimizers and schedulers."""
        logger = get_root_logger()
        train_opt = self.opt['train']

        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA)
            # net_g_ema is used only for testing on one GPU and saving
            # There is no need to wrap with DistributedDataParallel
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            # load pretrained model
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weight
            self.net_g_ema.eval()

        # define network net_d
        self.net_d = build_network(self.opt['network_d'])
        self.net_d = self.model_to_device(self.net_d)
        self.print_network(self.net_d)

        # load pretrained models
        load_path = self.opt['path'].get('pretrain_network_d', None)
        if load_path is not None:
            self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))

        self.net_g.train()
        self.net_d.train()

        # define losses
        if train_opt.get('pixel_opt'):
            self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
        else:
            self.cri_pix = None

        if train_opt.get('perceptual_opt'):
            self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
        else:
            self.cri_perceptual = None

        # NOTE(review): cri_gan is only defined when 'gan_opt' is configured,
        # but optimize_parameters uses it unconditionally once the
        # discriminator is active.
        if train_opt.get('gan_opt'):
            self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

        if train_opt.get('codebook_opt'):
            self.l_weight_codebook = train_opt['codebook_opt'].get('loss_weight', 1.0)
        else:
            self.l_weight_codebook = 1.0

        self.vqgan_quantizer = self.opt['network_g']['quantizer']
        logger.info(f'vqgan_quantizer: {self.vqgan_quantizer}')

        self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
        self.net_d_iters = train_opt.get('net_d_iters', 1)
        self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
        self.disc_weight = train_opt.get('disc_weight', 0.8)

        # set up optimizers and schedulers
        self.setup_optimizers()
        self.setup_schedulers()

    def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
        """Return a detached weight balancing the GAN loss against the reconstruction loss.

        The weight is the ratio of the gradient norms of ``recon_loss`` and
        ``g_loss`` with respect to ``last_layer``, clamped to
        ``[0, disc_weight_max]``.
        """
        recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        # 1e-4 keeps the ratio finite when the GAN gradient vanishes.
        d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
        return d_weight

    def adopt_weight(self, weight, global_step, threshold=0, value=0.):
        """Return ``value`` while ``global_step < threshold``, otherwise ``weight``."""
        if global_step < threshold:
            weight = value
        return weight

    def setup_optimizers(self):
        """Create optimizer_g (trainable net_g parameters) and optimizer_d (all net_d parameters).

        The ``type`` keys are popped from ``optim_g`` / ``optim_d``, so the
        options dict is modified in place.
        """
        train_opt = self.opt['train']
        # optimizer g
        optim_params_g = []
        for k, v in self.net_g.named_parameters():
            if v.requires_grad:
                optim_params_g.append(v)
            else:
                logger = get_root_logger()
                logger.warning(f'Params {k} will not be optimized.')
        optim_type = train_opt['optim_g'].pop('type')
        self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
        self.optimizers.append(self.optimizer_g)
        # optimizer d
        optim_type = train_opt['optim_d'].pop('type')
        self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
        self.optimizers.append(self.optimizer_d)

    def optimize_parameters(self, current_iter):
        """Run one generator step and, once active, one discriminator step.

        Args:
            current_iter (int): Current iteration; controls the gumbel
                temperature schedule and when each network starts training.
        """
        logger = get_root_logger()
        loss_dict = OrderedDict()
        # NOTE(review): net_g is accessed through `.module`, i.e. it is assumed
        # to be wrapped (DataParallel-style) by model_to_device - confirm.
        if self.opt['network_g']['quantizer'] == 'gumbel':
            # Linear decay from 1 with a floor of 1/16.
            self.net_g.module.quantize.temperature = max(1/16, ((-1/160000) * current_iter) + 1)
            if current_iter%1000 == 0:
                logger.info(f'temperature: {self.net_g.module.quantize.temperature}')

        # optimize net_g
        for p in self.net_d.parameters():
            p.requires_grad = False

        self.optimizer_g.zero_grad()
        self.output, l_codebook, quant_stats = self.net_g(self.gt)

        l_codebook = l_codebook*self.l_weight_codebook

        l_g_total = 0
        if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, self.gt)
                l_g_total += l_g_pix
                loss_dict['l_g_pix'] = l_g_pix
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep = self.cri_perceptual(self.output, self.gt)
                l_g_total += l_g_percep
                loss_dict['l_g_percep'] = l_g_percep

            # gan loss
            if current_iter > self.net_d_start_iter:
                # fake_g_pred = self.net_d(self.output_1024)
                fake_g_pred = self.net_d(self.output)
                l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
                # Reconstruction loss accumulated so far (pixel + perceptual).
                recon_loss = l_g_total
                last_layer = self.net_g.module.generator.blocks[-1].weight
                d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
                d_weight *= self.adopt_weight(1, current_iter, self.net_d_start_iter)
                d_weight *= self.disc_weight  # taming setting 0.8
                l_g_total += d_weight * l_g_gan
                loss_dict['l_g_gan'] = d_weight * l_g_gan

            l_g_total += l_codebook
            loss_dict['l_codebook'] = l_codebook

            l_g_total.backward()
            self.optimizer_g.step()

        # optimize net_d
        if current_iter > self.net_d_start_iter:
            for p in self.net_d.parameters():
                p.requires_grad = True

            self.optimizer_d.zero_grad()
            # real
            real_d_pred = self.net_d(self.gt)
            l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
            loss_dict['l_d_real'] = l_d_real
            loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
            l_d_real.backward()
            # fake
            # The generator output is detached so this pass does not
            # backpropagate into net_g.
            fake_d_pred = self.net_d(self.output.detach())
            l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
            loss_dict['l_d_fake'] = l_d_fake
            loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
            l_d_fake.backward()
            self.optimizer_d.step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

    def test(self):
        """Reconstruct self.gt without gradients and store the result in self.output.

        The EMA generator is used when it exists; otherwise net_g is used (with
        a warning) and returned to train mode afterwards.
        """
        with torch.no_grad():
            if hasattr(self, 'net_g_ema'):
                self.net_g_ema.eval()
                self.output, _, _ = self.net_g_ema(self.gt)
            else:
                logger = get_root_logger()
                logger.warning('Do not have self.net_g_ema, use self.net_g.')
                self.net_g.eval()
                self.output, _, _ = self.net_g(self.gt)
                self.net_g.train()

    def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Run validation on rank 0 only; other ranks do nothing."""
        if self.opt['rank'] == 0:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Validate on every item of ``dataloader``, optionally saving images and metrics.

        Args:
            dataloader: Validation dataloader; ``dataloader.dataset.opt['name']``
                is used as the dataset name and each item must provide
                ``'lq_path'`` and ``'gt'``.
            current_iter (int): Current iteration, used in file names and logs.
            tb_logger: Tensorboard-style logger or a falsy value.
            save_img (bool): Whether to write the reconstructed images to disk.
        """
        dataset_name = dataloader.dataset.opt['name']
        with_metrics = self.opt['val'].get('metrics') is not None
        if with_metrics:
            self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
        pbar = tqdm(total=len(dataloader), unit='image')

        for idx, val_data in enumerate(dataloader):
            img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
            self.feed_data(val_data)
            self.test()

            visuals = self.get_current_visuals()
            sr_img = tensor2img([visuals['result']])
            if 'gt' in visuals:
                gt_img = tensor2img([visuals['gt']])
                del self.gt

            # tentative for out of GPU memory
            # feed_data of this class never sets self.lq, so an unconditional
            # `del self.lq` raised AttributeError; only drop it if present.
            if hasattr(self, 'lq'):
                del self.lq
            del self.output
            torch.cuda.empty_cache()

            if save_img:
                if self.opt['is_train']:
                    save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                             f'{img_name}_{current_iter}.png')
                else:
                    if self.opt['val']['suffix']:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["val"]["suffix"]}.png')
                    else:
                        save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                                 f'{img_name}_{self.opt["name"]}.png')
                imwrite(sr_img, save_img_path)

            if with_metrics:
                # calculate metrics
                for name, opt_ in self.opt['val']['metrics'].items():
                    metric_data = dict(img1=sr_img, img2=gt_img)
                    self.metric_results[name] += calculate_metric(metric_data, opt_)
            pbar.update(1)
            pbar.set_description(f'Test {img_name}')
        pbar.close()

        if with_metrics:
            # Average the accumulated values over the number of images seen.
            for metric in self.metric_results.keys():
                self.metric_results[metric] /= (idx + 1)

            self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

    def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
        """Write the averaged metrics to the root logger and, if given, to tb_logger."""
        log_str = f'Validation {dataset_name}\n'
        for metric, value in self.metric_results.items():
            log_str += f'\t # {metric}: {value:.4f}\n'
        logger = get_root_logger()
        logger.info(log_str)
        if tb_logger:
            for metric, value in self.metric_results.items():
                tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)

    def get_current_visuals(self):
        """Return detached CPU copies of the ground truth and the reconstruction.

        Returns:
            OrderedDict: Keys ``'gt'`` and ``'result'``.
        """
        out_dict = OrderedDict()
        out_dict['gt'] = self.gt.detach().cpu()
        out_dict['result'] = self.output.detach().cpu()
        return out_dict

    def save(self, epoch, current_iter):
        """Save the generator (plus EMA copy when enabled), the discriminator and the training state."""
        if self.ema_decay > 0:
            self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
        else:
            self.save_network(self.net_g, 'net_g', current_iter)
        self.save_network(self.net_d, 'net_d', current_iter)
        self.save_training_state(epoch, current_iter)
|
basicsr/ops/__init__.py
ADDED
|
File without changes
|
basicsr/ops/dcn/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
|
| 2 |
+
modulated_deform_conv)
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
|
| 6 |
+
'modulated_deform_conv'
|
| 7 |
+
]
|
basicsr/ops/dcn/deform_conv.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn as nn
|
| 4 |
+
from torch.autograd import Function
|
| 5 |
+
from torch.autograd.function import once_differentiable
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
from torch.nn.modules.utils import _pair, _single
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from . import deform_conv_ext
|
| 11 |
+
except ImportError:
|
| 12 |
+
import os
|
| 13 |
+
BASICSR_JIT = os.getenv('BASICSR_JIT')
|
| 14 |
+
if BASICSR_JIT == 'True':
|
| 15 |
+
from torch.utils.cpp_extension import load
|
| 16 |
+
module_path = os.path.dirname(__file__)
|
| 17 |
+
deform_conv_ext = load(
|
| 18 |
+
'deform_conv',
|
| 19 |
+
sources=[
|
| 20 |
+
os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
|
| 21 |
+
os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
|
| 22 |
+
os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
|
| 23 |
+
],
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DeformConvFunction(Function):
    """Autograd function that dispatches deformable convolution to ``deform_conv_ext``.

    Only CUDA tensors are supported; CPU inputs raise ``NotImplementedError``.
    """

    @staticmethod
    def forward(ctx,
                input,
                offset,
                weight,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1,
                im2col_step=64):
        """Compute the deformable convolution of ``input`` with ``weight``.

        Args:
            ctx: Autograd context; hyper-parameters and scratch buffers are
                stored on it for the backward pass.
            input (Tensor): 4D input tensor.
            offset (Tensor): Sampling offsets passed through to the extension.
            weight (Tensor): Convolution weight; dims 2 and 3 are the kernel size.
            stride, padding, dilation (int | tuple): Expanded to pairs with ``_pair``.
            groups (int): Passed through to the extension.
            deformable_groups (int): Passed through to the extension.
            im2col_step (int): Upper bound on the step used by the extension;
                the effective step ``min(im2col_step, batch)`` must divide the
                batch size.

        Returns:
            Tensor: Output whose shape is given by ``_output_size``.

        Raises:
            ValueError: If ``input`` is not 4D or the output would be empty.
            NotImplementedError: If ``input`` is not a CUDA tensor.
        """
        if input is not None and input.dim() != 4:
            raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.')
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.im2col_step = im2col_step

        ctx.save_for_backward(input, offset, weight)

        output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))

        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones

        if not input.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
            # Spatial arguments are passed with index 1 / dim 3 before
            # index 0 / dim 2, as in the original call.
            deform_conv_ext.deform_conv_forward(input, weight,
                                                offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
                                                weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
                                                ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
                                                ctx.deformable_groups, cur_im2col_step)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients for ``input``, ``offset`` and ``weight``.

        Returns:
            tuple: One entry per ``forward`` input (9 in total); the entries
            for the non-tensor hyper-parameters are ``None``.

        Raises:
            NotImplementedError: If ``grad_output`` is not a CUDA tensor.
        """
        input, offset, weight = ctx.saved_tensors

        grad_input = grad_offset = grad_weight = None

        if not grad_output.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'

            # Input and offset gradients are produced by the same extension call.
            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                grad_input = torch.zeros_like(input)
                grad_offset = torch.zeros_like(offset)
                deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
                                                           grad_offset, weight, ctx.bufs_[0], weight.size(3),
                                                           weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
                                                           ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
                                                           ctx.deformable_groups, cur_im2col_step)

            if ctx.needs_input_grad[2]:
                grad_weight = torch.zeros_like(weight)
                deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
                                                                ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
                                                                weight.size(2), ctx.stride[1], ctx.stride[0],
                                                                ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                                                                ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
                                                                cur_im2col_step)

        # forward takes 9 inputs (3 tensors + 6 hyper-parameters), so 9 values
        # are returned. The previous 8-tuple failed whenever im2col_step was
        # passed explicitly to apply(); a surplus trailing None is accepted
        # when fewer arguments were passed.
        return (grad_input, grad_offset, grad_weight, None, None, None, None, None, None)

    @staticmethod
    def _output_size(input, weight, padding, dilation, stride):
        """Return the output shape ``(batch, out_channels, *spatial)`` of the convolution.

        Raises:
            ValueError: If any dimension of the resulting shape is not positive.
        """
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = padding[d]
            # Effective kernel extent once dilation is applied.
            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(s > 0 for s in output_size):
            raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})')
        return output_size
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class ModulatedDeformConvFunction(Function):
    """Autograd function for modulated deformable convolution.

    Both passes are delegated to ``deform_conv_ext`` and require CUDA tensors.
    """

    @staticmethod
    def forward(ctx,
                input,
                offset,
                mask,
                weight,
                bias=None,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1):
        """Run the forward pass through the extension.

        Args:
            ctx: Autograd context; hyper-parameters and work buffers are
                stored on it for the backward pass.
            input (Tensor): Input tensor (must be on CUDA).
            offset (Tensor): Sampling offsets.
            mask (Tensor): Modulation mask.
            weight (Tensor): Convolution weight.
            bias (Tensor | None): Optional bias. When ``None``, a one-element
                placeholder tensor is passed to the extension instead.
            stride, padding, dilation: Used as scalars here; each value is
                passed twice to the extension.
            groups (int): Forwarded to the extension.
            deformable_groups (int): Forwarded to the extension.

        Returns:
            Tensor: Output allocated with the shape from ``_infer_shape``.

        Raises:
            NotImplementedError: If ``input`` is not a CUDA tensor.
        """
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.with_bias = bias is not None
        if bias is None:
            bias = input.new_empty(1)  # fake tensor
        if not input.is_cuda:
            raise NotImplementedError

        # Tensors are only retained when some gradient may be requested.
        if any(t.requires_grad for t in (weight, mask, offset, input)):
            ctx.save_for_backward(input, offset, mask, weight, bias)

        out_shape = ModulatedDeformConvFunction._infer_shape(ctx, input, weight)
        output = input.new_empty(out_shape)
        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
        deform_conv_ext.modulated_deform_conv_forward(
            input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1],
            weight.shape[2], weight.shape[3],
            ctx.stride, ctx.stride,
            ctx.padding, ctx.padding,
            ctx.dilation, ctx.dilation,
            ctx.groups, ctx.deformable_groups, ctx.with_bias)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients for input, offset, mask, weight and bias.

        Returns:
            tuple: ``(grad_input, grad_offset, grad_mask, grad_weight,
            grad_bias)`` followed by five ``None`` placeholders. ``grad_bias``
            is ``None`` when the forward pass received no bias.

        Raises:
            NotImplementedError: If ``grad_output`` is not a CUDA tensor.
        """
        if not grad_output.is_cuda:
            raise NotImplementedError
        input, offset, mask, weight, bias = ctx.saved_tensors

        # Zero-initialised buffers that the extension fills in.
        grad_input = torch.zeros_like(input)
        grad_offset = torch.zeros_like(offset)
        grad_mask = torch.zeros_like(mask)
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)

        deform_conv_ext.modulated_deform_conv_backward(
            input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
            grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output,
            weight.shape[2], weight.shape[3],
            ctx.stride, ctx.stride,
            ctx.padding, ctx.padding,
            ctx.dilation, ctx.dilation,
            ctx.groups, ctx.deformable_groups, ctx.with_bias)

        # The placeholder bias used in forward has no meaningful gradient.
        if not ctx.with_bias:
            grad_bias = None

        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)

    @staticmethod
    def _infer_shape(ctx, input, weight):
        """Return ``(n, channels_out, height_out, width_out)`` for the output.

        Uses ``ctx.padding``, ``ctx.dilation`` and ``ctx.stride`` as scalars
        applied to both spatial dimensions.
        """
        batch = input.size(0)
        out_channels = weight.size(0)
        in_h, in_w = input.shape[2:4]
        k_h, k_w = weight.shape[2:4]

        def _extent(size, kernel):
            # Standard convolution output-size formula for one dimension.
            return (size + 2 * ctx.padding - (ctx.dilation * (kernel - 1) + 1)) // ctx.stride + 1

        return batch, out_channels, _extent(in_h, k_h), _extent(in_w, k_w)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# Functional entry point: calling deform_conv(...) invokes DeformConvFunction.apply(...).
deform_conv = DeformConvFunction.apply
|
| 183 |
+
# Functional entry point: calling modulated_deform_conv(...) invokes ModulatedDeformConvFunction.apply(...).
modulated_deform_conv = ModulatedDeformConvFunction.apply
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class DeformConv(nn.Module):
    """Deformable convolution layer whose offsets are supplied by the caller.

    Args:
        in_channels (int): Number of input channels; must be divisible by
            ``groups``.
        out_channels (int): Number of output channels; must be divisible by
            ``groups``.
        kernel_size (int or tuple[int]): Kernel size, normalised with ``_pair``.
        stride (int or tuple[int]): Stride, normalised with ``_pair``.
        padding (int or tuple[int]): Padding, normalised with ``_pair``.
        dilation (int or tuple[int]): Dilation, normalised with ``_pair``.
        groups (int): Number of convolution groups.
        deformable_groups (int): Number of deformable groups.
        bias (bool): Must be falsy; this layer has no bias term.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=False):
        super().__init__()

        assert not bias
        assert in_channels % groups == 0, f'in_channels {in_channels} is not divisible by groups {groups}'
        assert out_channels % groups == 0, f'out_channels {out_channels} is not divisible by groups {groups}'

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups
        # enable compatibility with nn.Conv2d
        self.transposed = False
        self.output_padding = _single(0)

        weight_shape = (out_channels, in_channels // self.groups) + tuple(self.kernel_size)
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))

        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weight uniformly in ``[-stdv, stdv]``.

        ``stdv`` is ``1 / sqrt(in_channels * prod(kernel_size))``.
        """
        fan_in = self.in_channels
        for extent in self.kernel_size:
            fan_in *= extent
        bound = 1. / math.sqrt(fan_in)
        self.weight.data.uniform_(-bound, bound)

    def forward(self, x, offset):
        """Apply the deformable convolution to ``x`` with the given ``offset``.

        When ``x`` is spatially smaller than the kernel, ``x`` and ``offset``
        are zero-padded on the bottom/right up to the kernel size, and the
        same number of rows/columns is cropped from the result.
        """
        kernel_h, kernel_w = self.kernel_size[0], self.kernel_size[1]

        # To fix an assert error in deform_conv_cuda.cpp:128
        # input image is smaller than kernel
        if not (x.size(2) < kernel_h or x.size(3) < kernel_w):
            return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
                               self.deformable_groups)

        pad_h = max(kernel_h - x.size(2), 0)
        pad_w = max(kernel_w - x.size(3), 0)
        pad_spec = (0, pad_w, 0, pad_h)
        x = F.pad(x, pad_spec, 'constant', 0).contiguous()
        offset = F.pad(offset, pad_spec, 'constant', 0).contiguous()

        out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
                          self.deformable_groups)
        return out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class DeformConvPack(DeformConv):
    """Deformable convolution that predicts its own offsets from the input.

    An internal ``nn.Conv2d`` (``conv_offset``) maps the input to an offset
    tensor with ``deformable_groups * 2 * kH * kW`` channels, so the layer
    can be called with a single tensor like a normal conv layer.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        deformable_groups (int): Number of deformable groups.
        bias (bool): Forwarded to ``DeformConv``, which requires it to be
            falsy. The internal offset convolution always has a bias.
    """

    _version = 2

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Two offset values per kernel location and deformable group.
        offset_channels = self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            offset_channels,
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)
        self.init_offset()

    def init_offset(self):
        """Zero the offset predictor's weight and bias."""
        offset_conv = self.conv_offset
        offset_conv.weight.data.zero_()
        offset_conv.bias.data.zero_()

    def forward(self, x):
        """Predict offsets from ``x`` and apply the deformable convolution."""
        predicted_offset = self.conv_offset(x)
        return deform_conv(x, predicted_offset, self.weight, self.stride, self.padding, self.dilation,
                           self.groups, self.deformable_groups)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class ModulatedDeformConv(nn.Module):
    """Modulated deformable convolution with caller-supplied offset and mask.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int or tuple[int]): Kernel size, normalised with ``_pair``.
        stride: Stored as given and forwarded to ``modulated_deform_conv``.
        padding: Stored as given and forwarded to ``modulated_deform_conv``.
        dilation: Stored as given and forwarded to ``modulated_deform_conv``.
        groups (int): Number of convolution groups.
        deformable_groups (int): Number of deformable groups.
        bias (bool): Whether to create a learnable bias of size
            ``out_channels``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias
        # enable compatibility with nn.Conv2d
        self.transposed = False
        self.output_padding = _single(0)

        weight_shape = (out_channels, in_channels // groups) + tuple(self.kernel_size)
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.init_weights()

    def init_weights(self):
        """Initialise the weight uniformly and zero the bias if present.

        The weight is drawn from ``[-stdv, stdv]`` where ``stdv`` is
        ``1 / sqrt(in_channels * prod(kernel_size))``.
        """
        fan_in = self.in_channels
        for extent in self.kernel_size:
            fan_in *= extent
        bound = 1. / math.sqrt(fan_in)
        self.weight.data.uniform_(-bound, bound)
        if self.bias is None:
            return
        self.bias.data.zero_()

    def forward(self, x, offset, mask):
        """Apply the modulated deformable convolution to ``x``."""
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                     self.dilation, self.groups, self.deformable_groups)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class ModulatedDeformConvPack(ModulatedDeformConv):
    """Modulated deformable convolution that predicts its own offsets and mask.

    An internal ``nn.Conv2d`` (``conv_offset``) produces
    ``deformable_groups * 3 * kH * kW`` channels. In ``forward`` these are
    split into three equal chunks: the first two form the offset and the
    third, after a sigmoid, forms the mask.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        deformable_groups (int): Number of deformable groups.
        bias (bool): Forwarded to ``ModulatedDeformConv``; controls the bias
            of the main convolution. The internal offset convolution always
            has a bias.
    """

    _version = 2

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Two offset values plus one mask value per kernel location and group.
        predictor_channels = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            predictor_channels,
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            dilation=_pair(self.dilation),
            bias=True)
        self.init_weights()

    def init_weights(self):
        """Initialise the main weights and zero the offset/mask predictor.

        The parent constructor calls this before ``conv_offset`` exists, so
        the predictor is only touched once the attribute is present.
        """
        super().init_weights()
        if not hasattr(self, 'conv_offset'):
            return
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        """Predict offset and mask from ``x`` and apply the convolution."""
        predicted = self.conv_offset(x)
        offset_a, offset_b, mask_logits = torch.chunk(predicted, 3, dim=1)
        offset = torch.cat((offset_a, offset_b), dim=1)
        mask = torch.sigmoid(mask_logits)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                     self.dilation, self.groups, self.deformable_groups)
|
basicsr/ops/dcn/src/deform_conv_cuda.cpp
ADDED
|
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// modified from
|
| 2 |
+
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
|
| 3 |
+
|
| 4 |
+
#include <torch/extension.h>
|
| 5 |
+
#include <ATen/DeviceGuard.h>
|
| 6 |
+
|
| 7 |
+
#include <cmath>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
|
| 11 |
+
const int channels, const int height, const int width,
|
| 12 |
+
const int ksize_h, const int ksize_w, const int pad_h,
|
| 13 |
+
const int pad_w, const int stride_h, const int stride_w,
|
| 14 |
+
const int dilation_h, const int dilation_w,
|
| 15 |
+
const int parallel_imgs, const int deformable_group,
|
| 16 |
+
at::Tensor data_col);
|
| 17 |
+
|
| 18 |
+
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
|
| 19 |
+
const int channels, const int height, const int width,
|
| 20 |
+
const int ksize_h, const int ksize_w, const int pad_h,
|
| 21 |
+
const int pad_w, const int stride_h, const int stride_w,
|
| 22 |
+
const int dilation_h, const int dilation_w,
|
| 23 |
+
const int parallel_imgs, const int deformable_group,
|
| 24 |
+
at::Tensor grad_im);
|
| 25 |
+
|
| 26 |
+
void deformable_col2im_coord(
|
| 27 |
+
const at::Tensor data_col, const at::Tensor data_im,
|
| 28 |
+
const at::Tensor data_offset, const int channels, const int height,
|
| 29 |
+
const int width, const int ksize_h, const int ksize_w, const int pad_h,
|
| 30 |
+
const int pad_w, const int stride_h, const int stride_w,
|
| 31 |
+
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
| 32 |
+
const int deformable_group, at::Tensor grad_offset);
|
| 33 |
+
|
| 34 |
+
void modulated_deformable_im2col_cuda(
|
| 35 |
+
const at::Tensor data_im, const at::Tensor data_offset,
|
| 36 |
+
const at::Tensor data_mask, const int batch_size, const int channels,
|
| 37 |
+
const int height_im, const int width_im, const int height_col,
|
| 38 |
+
const int width_col, const int kernel_h, const int kenerl_w,
|
| 39 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
| 40 |
+
const int dilation_h, const int dilation_w, const int deformable_group,
|
| 41 |
+
at::Tensor data_col);
|
| 42 |
+
|
| 43 |
+
void modulated_deformable_col2im_cuda(
|
| 44 |
+
const at::Tensor data_col, const at::Tensor data_offset,
|
| 45 |
+
const at::Tensor data_mask, const int batch_size, const int channels,
|
| 46 |
+
const int height_im, const int width_im, const int height_col,
|
| 47 |
+
const int width_col, const int kernel_h, const int kenerl_w,
|
| 48 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
| 49 |
+
const int dilation_h, const int dilation_w, const int deformable_group,
|
| 50 |
+
at::Tensor grad_im);
|
| 51 |
+
|
| 52 |
+
void modulated_deformable_col2im_coord_cuda(
|
| 53 |
+
const at::Tensor data_col, const at::Tensor data_im,
|
| 54 |
+
const at::Tensor data_offset, const at::Tensor data_mask,
|
| 55 |
+
const int batch_size, const int channels, const int height_im,
|
| 56 |
+
const int width_im, const int height_col, const int width_col,
|
| 57 |
+
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
|
| 58 |
+
const int stride_h, const int stride_w, const int dilation_h,
|
| 59 |
+
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
|
| 60 |
+
at::Tensor grad_mask);
|
| 61 |
+
|
| 62 |
+
// Validates the tensor shapes and hyper-parameters of a deformable
// convolution before any kernel is launched.
//
// `input` may be 3D (C, H, W) or 4D (N, C, H, W); `offset` must have the same
// rank. `gradOutput` may be NULL (forward pass); when given, its channel and
// spatial sizes are checked against the computed output size.
//
// Throws via TORCH_CHECK on the first violated constraint.
void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
                 at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
                 int padW, int dilationH, int dilationW, int group,
                 int deformable_group) {
  // TORCH_CHECK concatenates its trailing arguments into the error message;
  // it does not interpret printf-style placeholders, so every value is passed
  // as its own argument.
  TORCH_CHECK(weight.ndimension() == 4,
              "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
              "but got: ",
              weight.ndimension(), "D");

  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

  TORCH_CHECK(kW > 0 && kH > 0,
              "kernel size should be greater than zero, but got kH: ", kH,
              " kW: ", kW);

  TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
              "kernel size should be consistent with weight, but got kH: ", kH,
              " kW: ", kW, " weight.size(2): ", weight.size(2),
              ", weight.size(3): ", weight.size(3));

  TORCH_CHECK(dW > 0 && dH > 0,
              "stride should be greater than zero, but got dH: ", dH,
              " dW: ", dW);

  TORCH_CHECK(dilationW > 0 && dilationH > 0,
              "dilation should be greater than 0, but got dilationH: ",
              dilationH, " dilationW: ", dilationW);

  // Both values are used as divisors below.
  TORCH_CHECK(group > 0 && deformable_group > 0,
              "group and deformable_group should be greater than zero, but "
              "got group: ",
              group, " deformable_group: ", deformable_group);

  int ndim = input.ndimension();
  TORCH_CHECK(ndim == 3 || ndim == 4,
              "3D or 4D input tensor expected but got: ", ndim, "D");
  TORCH_CHECK(offset.ndimension() == ndim,
              "offset must have the same number of dimensions as input, "
              "expected: ",
              ndim, ", but got: ", offset.ndimension());

  // Indices of the channel / height / width dimensions. They shift by one
  // when a leading batch dimension is present.
  int dimf = 0;
  int dimh = 1;
  int dimw = 2;

  if (ndim == 4) {
    dimf++;
    dimh++;
    dimw++;
  }

  long nInputPlane = weight.size(1) * group;
  long inputHeight = input.size(dimh);
  long inputWidth = input.size(dimw);
  long nOutputPlane = weight.size(0);
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;

  TORCH_CHECK(nInputPlane % deformable_group == 0,
              "input channels must divide deformable group size");

  TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1, "Given input size: (",
              nInputPlane, " x ", inputHeight, " x ", inputWidth,
              "). Calculated output size: (", nOutputPlane, " x ",
              outputHeight, " x ", outputWidth,
              "). Output size is too small");

  // Use dimf/dimh/dimw (not fixed indices) so 3D inputs are checked against
  // the right dimensions.
  TORCH_CHECK(input.size(dimf) == nInputPlane,
              "invalid number of input planes, expected: ", nInputPlane,
              ", but got: ", input.size(dimf));

  TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
              "input image is smaller than kernel");

  TORCH_CHECK((offset.size(dimh) == outputHeight &&
               offset.size(dimw) == outputWidth),
              "invalid spatial size of offset, expected height: ",
              outputHeight, " width: ", outputWidth,
              ", but got height: ", offset.size(dimh),
              " width: ", offset.size(dimw));

  TORCH_CHECK((offset.size(dimf) == deformable_group * 2 * kH * kW),
              "invalid number of channels of offset");

  if (gradOutput != NULL) {
    TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
                "invalid number of gradOutput planes, expected: ",
                nOutputPlane, ", but got: ", gradOutput->size(dimf));

    TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
                 gradOutput->size(dimw) == outputWidth),
                "invalid size of gradOutput, expected height: ", outputHeight,
                " width: ", outputWidth,
                " , but got height: ", gradOutput->size(dimh),
                " width: ", gradOutput->size(dimw));
  }
}
|
| 151 |
+
|
| 152 |
+
// Forward pass of deformable convolution.
//
// The batch is processed in chunks of `im2col_step` images: each chunk is
// expanded with deformable_im2col and multiplied with the (grouped) weights.
// The result is written into `output` in place.
//
// `columns` and `ones` are accepted for interface compatibility. `columns`
// is replaced by a locally allocated buffer and `ones` is not used.
//
// Returns 1 on success; throws via TORCH_CHECK on invalid arguments.
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
                             at::Tensor offset, at::Tensor output,
                             at::Tensor columns, at::Tensor ones, int kW,
                             int kH, int dW, int dH, int padW, int padH,
                             int dilationW, int dilationH, int group,
                             int deformable_group, int im2col_step) {
  shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
              dilationH, dilationW, group, deformable_group);
  at::DeviceGuard guard(input.device());

  input = input.contiguous();
  offset = offset.contiguous();
  weight = weight.contiguous();

  if (input.ndimension() == 3) {
    // Force a batch dimension. Non-mutating unsqueeze is used so the
    // caller's tensors keep their original shape.
    input = input.unsqueeze(0);
    offset = offset.unsqueeze(0);
  }

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = weight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
  TORCH_CHECK(im2col_step > 0 && batchSize % im2col_step == 0,
              "im2col step must divide batchsize, but got batch size: ",
              batchSize, " im2col_step: ", im2col_step);

  const long numSteps = batchSize / im2col_step;

  output = output.view({numSteps, im2col_step, nOutputPlane, outputHeight,
                        outputWidth});
  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  input = input.view(
      {numSteps, im2col_step, nInputPlane, inputHeight, inputWidth});
  offset = offset.view({numSteps, im2col_step, deformable_group * 2 * kH * kW,
                        outputHeight, outputWidth});

  at::Tensor output_buffer = at::zeros(
      {numSteps, nOutputPlane, im2col_step * outputHeight, outputWidth},
      output.options());

  output_buffer = output_buffer.view(
      {output_buffer.size(0), group, output_buffer.size(1) / group,
       output_buffer.size(2), output_buffer.size(3)});

  // The grouped weight view is loop-invariant, so it is built once. Keeping
  // it separate from `weight` (and `columns_g` separate from `columns`)
  // leaves the originals in their expected shape for every chunk; re-viewing
  // them in place broke the second iteration when batchSize > im2col_step.
  at::Tensor weight_g =
      weight.view({group, weight.size(0) / group, weight.size(1),
                   weight.size(2), weight.size(3)});

  for (long elt = 0; elt < numSteps; elt++) {
    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, columns);

    at::Tensor columns_g =
        columns.view({group, columns.size(0) / group, columns.size(1)});

    for (int g = 0; g < group; g++) {
      output_buffer[elt][g] = output_buffer[elt][g]
                                  .flatten(1)
                                  .addmm_(weight_g[g].flatten(1), columns_g[g])
                                  .view_as(output_buffer[elt][g]);
    }
  }

  // Merge the group dimension back and split the chunk dimension out of the
  // height dimension, then move it in front of the channels to match
  // `output`.
  output_buffer = output_buffer.view(
      {numSteps, nOutputPlane, im2col_step, outputHeight, outputWidth});
  output_buffer.transpose_(1, 2);
  output.copy_(output_buffer);

  // The tensor arguments are by-value handles, so no shape restoration is
  // needed before returning.
  return 1;
}
|
| 261 |
+
|
| 262 |
+
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
|
| 263 |
+
at::Tensor gradOutput, at::Tensor gradInput,
|
| 264 |
+
at::Tensor gradOffset, at::Tensor weight,
|
| 265 |
+
at::Tensor columns, int kW, int kH, int dW,
|
| 266 |
+
int dH, int padW, int padH, int dilationW,
|
| 267 |
+
int dilationH, int group,
|
| 268 |
+
int deformable_group, int im2col_step) {
|
| 269 |
+
shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
|
| 270 |
+
dilationH, dilationW, group, deformable_group);
|
| 271 |
+
at::DeviceGuard guard(input.device());
|
| 272 |
+
|
| 273 |
+
input = input.contiguous();
|
| 274 |
+
offset = offset.contiguous();
|
| 275 |
+
gradOutput = gradOutput.contiguous();
|
| 276 |
+
weight = weight.contiguous();
|
| 277 |
+
|
| 278 |
+
int batch = 1;
|
| 279 |
+
|
| 280 |
+
if (input.ndimension() == 3) {
|
| 281 |
+
// Force batch
|
| 282 |
+
batch = 0;
|
| 283 |
+
input = input.view({1, input.size(0), input.size(1), input.size(2)});
|
| 284 |
+
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
|
| 285 |
+
gradOutput = gradOutput.view(
|
| 286 |
+
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
long batchSize = input.size(0);
|
| 290 |
+
long nInputPlane = input.size(1);
|
| 291 |
+
long inputHeight = input.size(2);
|
| 292 |
+
long inputWidth = input.size(3);
|
| 293 |
+
|
| 294 |
+
long nOutputPlane = weight.size(0);
|
| 295 |
+
|
| 296 |
+
long outputWidth =
|
| 297 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
| 298 |
+
long outputHeight =
|
| 299 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
| 300 |
+
|
| 301 |
+
TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
|
| 302 |
+
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
| 303 |
+
columns = at::zeros(
|
| 304 |
+
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
| 305 |
+
input.options());
|
| 306 |
+
|
| 307 |
+
// change order of grad output
|
| 308 |
+
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
| 309 |
+
nOutputPlane, outputHeight, outputWidth});
|
| 310 |
+
gradOutput.transpose_(1, 2);
|
| 311 |
+
|
| 312 |
+
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
| 313 |
+
inputHeight, inputWidth});
|
| 314 |
+
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
| 315 |
+
inputHeight, inputWidth});
|
| 316 |
+
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
|
| 317 |
+
deformable_group * 2 * kH * kW, outputHeight,
|
| 318 |
+
outputWidth});
|
| 319 |
+
offset =
|
| 320 |
+
offset.view({batchSize / im2col_step, im2col_step,
|
| 321 |
+
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
| 322 |
+
|
| 323 |
+
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
| 324 |
+
// divide into groups
|
| 325 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
| 326 |
+
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
| 327 |
+
weight.size(2), weight.size(3)});
|
| 328 |
+
gradOutput = gradOutput.view(
|
| 329 |
+
{gradOutput.size(0), group, gradOutput.size(1) / group,
|
| 330 |
+
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
|
| 331 |
+
|
| 332 |
+
for (int g = 0; g < group; g++) {
|
| 333 |
+
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
| 334 |
+
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
columns =
|
| 338 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
| 339 |
+
gradOutput = gradOutput.view(
|
| 340 |
+
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
|
| 341 |
+
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
|
| 342 |
+
|
| 343 |
+
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
|
| 344 |
+
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
|
| 345 |
+
dilationH, dilationW, im2col_step, deformable_group,
|
| 346 |
+
gradOffset[elt]);
|
| 347 |
+
|
| 348 |
+
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
|
| 349 |
+
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
| 350 |
+
dilationW, im2col_step, deformable_group, gradInput[elt]);
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
gradOutput.transpose_(1, 2);
|
| 354 |
+
gradOutput =
|
| 355 |
+
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
| 356 |
+
|
| 357 |
+
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
| 358 |
+
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
| 359 |
+
gradOffset = gradOffset.view(
|
| 360 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
| 361 |
+
offset = offset.view(
|
| 362 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
| 363 |
+
|
| 364 |
+
if (batch == 0) {
|
| 365 |
+
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
| 366 |
+
input = input.view({nInputPlane, inputHeight, inputWidth});
|
| 367 |
+
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
|
| 368 |
+
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
| 369 |
+
gradOffset =
|
| 370 |
+
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
return 1;
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
// Accumulates the weight gradient of a deformable convolution into
// `gradWeight`.
//
// For every chunk of `im2col_step` samples the input is unfolded with
// deformable_im2col() and, per group,
//   gradWeight[g] = 1.0 * gradWeight[g] + scale * gradOutput[g] x columns[g]^T
// is applied in place.
//
// Notes on the arguments, as far as this function shows:
//   * `input`, `offset` and `gradOutput` may be 3-D (unbatched); they are then
//     treated as a batch of one.
//   * `columns` is reallocated here, so its incoming contents are ignored.
//   * `ones` is not used by this function.
//   * `im2col_step` must be positive and divide the batch size.
//
// Always returns 1; failures are reported by throwing from TORCH_CHECK or
// from the tensor operations themselves.
int deform_conv_backward_parameters_cuda(
    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
    at::Tensor gradWeight,  // at::Tensor gradBias,
    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
    int padW, int padH, int dilationW, int dilationH, int group,
    int deformable_group, float scale, int im2col_step) {
  // todo: transpose and reshape outGrad
  // todo: reshape columns
  // todo: add im2col_step as input

  shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
              padW, dilationH, dilationW, group, deformable_group);
  at::DeviceGuard guard(input.device());

  input = input.contiguous();
  offset = offset.contiguous();
  gradOutput = gradOutput.contiguous();

  int batch = 1;

  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input = input.view({1, input.size(0), input.size(1), input.size(2)});
    // The offset must be promoted together with the input; otherwise the
    // batch-size check below compares offset's channel count against 1.
    if (offset.ndimension() == 3) {
      offset =
          offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
    }
    gradOutput = gradOutput.view(
        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
  }

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = gradWeight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
  // Fail with a readable message instead of dividing by zero or hitting an
  // opaque view() size mismatch further down.
  TORCH_CHECK(im2col_step > 0 && batchSize % im2col_step == 0,
              "im2col_step (", im2col_step,
              ") must be positive and divide the batch size (", batchSize,
              ")");

  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  // Regroup gradOutput as (chunks, nOutputPlane, im2col_step, H_out, W_out)
  // and materialise that layout in a separate buffer.
  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
                                nOutputPlane, outputHeight, outputWidth});
  gradOutput.transpose_(1, 2);

  at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
  gradOutputBuffer =
      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
                             outputHeight, outputWidth});
  gradOutputBuffer.copy_(gradOutput);
  gradOutputBuffer =
      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
                             im2col_step * outputHeight, outputWidth});

  // Undo the in-place transpose so gradOutput is back in its 4-D layout.
  gradOutput.transpose_(1, 2);
  gradOutput =
      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, columns);

    // divide into group
    gradOutputBuffer = gradOutputBuffer.view(
        {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
         gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    gradWeight =
        gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
                         gradWeight.size(2), gradWeight.size(3)});

    for (int g = 0; g < group; g++) {
      // beta = 1.0 keeps what was accumulated so far, alpha = scale weights
      // the new contribution.
      gradWeight[g] = gradWeight[g]
                          .flatten(1)
                          .addmm_(gradOutputBuffer[elt][g].flatten(1),
                                  columns[g].transpose(1, 0), 1.0, scale)
                          .view_as(gradWeight[g]);
    }

    // Merge the group dimension back for the next iteration.
    gradOutputBuffer = gradOutputBuffer.view(
        {gradOutputBuffer.size(0),
         gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
         gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
                                  gradWeight.size(2), gradWeight.size(3),
                                  gradWeight.size(4)});
  }

  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
  }

  return 1;
}
|
| 489 |
+
|
| 490 |
+
// Forward pass of the modulated deformable convolution.
//
// For each sample the input is unfolded with
// modulated_deformable_im2col_cuda() using that sample's offset and mask, and
// the per-group product weight[g] x columns[g] is accumulated into `output`.
// When `with_bias` is true, `bias` is broadcast-added over the output
// channels at the end.
//
// Notes on the arguments, as far as this function shows:
//   * `input` and `weight` must be contiguous (checked).
//   * `kernel_h` / `kernel_w` must equal weight.size(2) / weight.size(3), and
//     input channels must equal weight.size(1) * group (checked).
//   * `output` is viewed as (batch, channels_out, height_out, width_out) and
//     zeroed before accumulation.
//   * `columns` is reallocated here; `ones` may be reallocated but is not
//     otherwise read in this function.
void modulated_deform_conv_cuda_forward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
    const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int group, const int deformable_group,
    const bool with_bias) {
  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
  at::DeviceGuard guard(input.device());

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  // TORCH_CHECK concatenates its message arguments; printf-style "%d"
  // placeholders would be printed literally, so the values are streamed.
  TORCH_CHECK(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
              "Input shape and kernel shape wont match: (", kernel_h, " x ",
              kernel_w, " vs ", kernel_h_, " x ", kernel_w_, ").");
  TORCH_CHECK(channels == channels_kernel * group,
              "Input shape and kernel channels wont match: (", channels,
              " vs ", channels_kernel * group, ").");

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  // resize output
  output = output.view({batch, channels_out, height_out, width_out}).zero_();
  // resize temporary columns
  columns =
      at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
                input.options());

  // Expose the group dimension: (batch, group, channels_out / group, H, W).
  output = output.view({output.size(0), group, output.size(1) / group,
                        output.size(2), output.size(3)});

  for (int b = 0; b < batch; b++) {
    modulated_deformable_im2col_cuda(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    // divide into group
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});

    for (int g = 0; g < group; g++) {
      output[b][g] = output[b][g]
                         .flatten(1)
                         .addmm_(weight[g].flatten(1), columns[g])
                         .view_as(output[b][g]);
    }

    // Merge the group dimension back so the next im2col call sees 2-D
    // columns again.
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
  }

  output = output.view({output.size(0), output.size(1) * output.size(2),
                        output.size(3), output.size(4)});

  if (with_bias) {
    output += bias.view({1, bias.size(0), 1, 1});
  }
}
|
| 570 |
+
|
| 571 |
+
// Backward pass of the modulated deformable convolution.
//
// For each sample this function:
//   1. forms columns = weight^T x grad_output per group,
//   2. scatters them into grad_offset / grad_mask
//      (modulated_deformable_col2im_coord_cuda) and into grad_input
//      (modulated_deformable_col2im_cuda),
//   3. re-unfolds the input (modulated_deformable_im2col_cuda) and accumulates
//      grad_weight += grad_output x columns^T per group,
//   4. when `with_bias` is true, accumulates grad_bias += grad_output x ones.
//
// Notes on the arguments, as far as this function shows:
//   * `input` and `weight` must be contiguous (checked).
//   * `kernel_h` / `kernel_w` must equal weight.size(2) / weight.size(3), and
//     input channels must equal weight.size(1) * group (checked).
//   * grad_weight and grad_bias are accumulated into, not overwritten.
//   * `columns` is reallocated here; `bias` is not read by this function.
void modulated_deform_conv_cuda_backward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor columns,
    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias) {
  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
  at::DeviceGuard guard(input.device());

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  // TORCH_CHECK concatenates its message arguments; printf-style "%d"
  // placeholders would be printed literally, so the values are streamed.
  TORCH_CHECK(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
              "Input shape and kernel shape wont match: (", kernel_h, " x ",
              kernel_w, " vs ", kernel_h_, " x ", kernel_w_, ").");
  TORCH_CHECK(channels == channels_kernel * group,
              "Input shape and kernel channels wont match: (", channels,
              " vs ", channels_kernel * group, ").");

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  grad_input = grad_input.view({batch, channels, height, width});
  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
                      input.options());

  // Expose the group dimension: (batch, group, channels_out / group, H, W).
  grad_output =
      grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
                        grad_output.size(2), grad_output.size(3)});

  for (int b = 0; b < batch; b++) {
    // divide int group
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});

    for (int g = 0; g < group; g++) {
      // beta = 0 overwrites columns[g] with weight[g]^T x grad_output[b][g].
      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
                        grad_output[b][g].flatten(1), 0.0f, 1.0f);
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});

    // gradient w.r.t. input coordinate data
    modulated_deformable_col2im_coord_cuda(
        columns, input[b], offset[b], mask[b], 1, channels, height, width,
        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
        stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
        grad_mask[b]);
    // gradient w.r.t. input data
    modulated_deformable_col2im_cuda(
        columns, offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, grad_input[b]);

    // gradient w.r.t. weight, dWeight should accumulate across the batch and
    // group
    modulated_deformable_im2col_cuda(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
                                    grad_weight.size(1), grad_weight.size(2),
                                    grad_weight.size(3)});
    if (with_bias)
      grad_bias = grad_bias.view({group, grad_bias.size(0) / group});

    for (int g = 0; g < group; g++) {
      grad_weight[g] =
          grad_weight[g]
              .flatten(1)
              .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
              .view_as(grad_weight[g]);
      if (with_bias) {
        // Multiplying by a column of ones sums grad_output over the spatial
        // positions of each output channel.
        grad_bias[g] =
            grad_bias[g]
                .view({-1, 1})
                .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
                .view(-1);
      }
    }

    // Merge the group dimension back for the next sample.
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
                                    grad_weight.size(2), grad_weight.size(3),
                                    grad_weight.size(4)});
    if (with_bias)
      grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
  }
  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
                                  grad_output.size(2), grad_output.size(3),
                                  grad_output.size(4)});
}
|
basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
ADDED
|
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*!
|
| 2 |
+
******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
|
| 3 |
+
*
|
| 4 |
+
* COPYRIGHT
|
| 5 |
+
*
|
| 6 |
+
* All contributions by the University of California:
|
| 7 |
+
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
|
| 8 |
+
* All rights reserved.
|
| 9 |
+
*
|
| 10 |
+
* All other contributions:
|
| 11 |
+
* Copyright (c) 2014-2017, the respective contributors
|
| 12 |
+
* All rights reserved.
|
| 13 |
+
*
|
| 14 |
+
* Caffe uses a shared copyright model: each contributor holds copyright over
|
| 15 |
+
* their contributions to Caffe. The project versioning records all such
|
| 16 |
+
* contribution and copyright details. If a contributor wants to further mark
|
| 17 |
+
* their specific copyright on a particular contribution, they should indicate
|
| 18 |
+
* their copyright solely in the commit message of the change when it is
|
| 19 |
+
* committed.
|
| 20 |
+
*
|
| 21 |
+
* LICENSE
|
| 22 |
+
*
|
| 23 |
+
* Redistribution and use in source and binary forms, with or without
|
| 24 |
+
* modification, are permitted provided that the following conditions are met:
|
| 25 |
+
*
|
| 26 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 27 |
+
* list of conditions and the following disclaimer.
|
| 28 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 29 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 30 |
+
* and/or other materials provided with the distribution.
|
| 31 |
+
*
|
| 32 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
| 33 |
+
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
| 34 |
+
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 35 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
| 36 |
+
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
| 37 |
+
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
| 38 |
+
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
| 39 |
+
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 40 |
+
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
| 41 |
+
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 42 |
+
*
|
| 43 |
+
* CONTRIBUTION AGREEMENT
|
| 44 |
+
*
|
| 45 |
+
* By contributing to the BVLC/caffe repository through pull-request, comment,
|
| 46 |
+
* or otherwise, the contributor releases their content to the
|
| 47 |
+
* license and copyright terms herein.
|
| 48 |
+
*
|
| 49 |
+
***************** END Caffe Copyright Notice and Disclaimer ********************
|
| 50 |
+
*
|
| 51 |
+
* Copyright (c) 2018 Microsoft
|
| 52 |
+
* Licensed under The MIT License [see LICENSE for details]
|
| 53 |
+
* \file modulated_deformable_im2col.cuh
|
| 54 |
+
* \brief Function definitions of converting an image to
|
| 55 |
+
* column matrix based on kernel, padding, dilation, and offset.
|
| 56 |
+
* These functions are mainly used in deformable convolution operators.
|
| 57 |
+
* \ref: https://arxiv.org/abs/1703.06211
|
| 58 |
+
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
|
| 59 |
+
*/
|
| 60 |
+
|
| 61 |
+
// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
|
| 62 |
+
|
| 63 |
+
#include <ATen/ATen.h>
|
| 64 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 65 |
+
#include <THC/THCAtomics.cuh>
|
| 66 |
+
#include <stdio.h>
|
| 67 |
+
#include <math.h>
|
| 68 |
+
#include <float.h>
|
| 69 |
+
|
| 70 |
+
using namespace at;
|
| 71 |
+
|
| 72 |
+
// Grid-stride loop: each thread starts at its global index and advances by
// the total number of launched threads until `n` is reached, so a capped grid
// still visits every index in [0, n).
#define CUDA_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// Threads per block used when launching the kernels in this file.
const int CUDA_NUM_THREADS = 1024;
// Upper bound on the number of blocks returned by GET_BLOCKS.
const int kMaxGridNum = 65535;

// Returns the number of blocks needed to cover N work items with
// CUDA_NUM_THREADS threads per block, capped at kMaxGridNum.
inline int GET_BLOCKS(const int N)
{
  // Round up in 64-bit so that N close to INT_MAX does not overflow `int`
  // before the division.
  const long long blocks_needed =
      (static_cast<long long>(N) + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
  return static_cast<int>(std::min<long long>(kMaxGridNum, blocks_needed));
}
|
| 83 |
+
|
| 84 |
+
// Bilinearly interpolates `bottom_data` at the fractional position (h, w).
//
// `bottom_data` is indexed as row * data_width + column. Any of the four
// neighbouring samples whose row is outside [0, height - 1] or whose column
// is outside [0, width - 1] (on the side tested below) contributes zero.
template <typename scalar_t>
__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
                                               const int height, const int width, scalar_t h, scalar_t w)
{
  // Integer corners surrounding the sampling point.
  const int top = floor(h);
  const int left = floor(w);
  const int bottom = top + 1;
  const int right = left + 1;

  // Fractional distances from the top-left corner and their complements.
  const scalar_t frac_h = h - top;
  const scalar_t frac_w = w - left;
  const scalar_t rem_h = 1 - frac_h;
  const scalar_t rem_w = 1 - frac_w;

  const bool top_in = top >= 0;
  const bool left_in = left >= 0;
  const bool bottom_in = bottom <= height - 1;
  const bool right_in = right <= width - 1;

  // Corner samples; corners that fail their bounds test stay at zero.
  scalar_t top_left = 0;
  scalar_t top_right = 0;
  scalar_t bottom_left = 0;
  scalar_t bottom_right = 0;
  if (top_in && left_in)
    top_left = bottom_data[top * data_width + left];
  if (top_in && right_in)
    top_right = bottom_data[top * data_width + right];
  if (bottom_in && left_in)
    bottom_left = bottom_data[bottom * data_width + left];
  if (bottom_in && right_in)
    bottom_right = bottom_data[bottom * data_width + right];

  // Each corner is weighted by the area of the opposite sub-rectangle.
  return (rem_h * rem_w * top_left + rem_h * frac_w * top_right +
          frac_h * rem_w * bottom_left + frac_h * frac_w * bottom_right);
}
|
| 116 |
+
|
| 117 |
+
// Returns the bilinear-interpolation weight that the integer pixel (h, w)
// receives from a sample taken at the fractional position
// (argmax_h, argmax_w).
//
// The result is 0 when the sampling position lies at or beyond -1 / height /
// width, or when (h, w) is not one of the four integer corners surrounding
// the sampling position.
template <typename scalar_t>
__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
                                        const int h, const int w, const int height, const int width)
{
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
    return 0;

  const int row_lo = floor(argmax_h);
  const int col_lo = floor(argmax_w);
  const int row_hi = row_lo + 1;
  const int col_hi = col_lo + 1;

  // Which of the two neighbouring rows / columns (if any) the pixel sits on.
  const bool on_row_lo = (h == row_lo);
  const bool on_row_hi = (h == row_hi);
  const bool on_col_lo = (w == col_lo);
  const bool on_col_hi = (w == col_hi);

  // Not one of the four surrounding corners: no contribution.
  if (!(on_row_lo || on_row_hi) || !(on_col_lo || on_col_hi))
    return 0;

  // One linear factor per axis; their product is the corner's weight.
  const scalar_t row_factor = on_row_lo ? (h + 1 - argmax_h) : (argmax_h + 1 - h);
  const scalar_t col_factor = on_col_lo ? (w + 1 - argmax_w) : (argmax_w + 1 - w);
  return row_factor * col_factor;
}
|
| 144 |
+
|
| 145 |
+
// Returns the derivative of the bilinearly interpolated value of `im_data`
// at (argmax_h, argmax_w) with respect to one sampling coordinate:
//   bp_dir == 0 -> derivative with respect to the row coordinate (h),
//   bp_dir == 1 -> derivative with respect to the column coordinate (w).
// Any other bp_dir yields 0.
//
// `im_data` is indexed as row * data_width + column. The result is 0 when the
// sampling position lies at or beyond -1 / height / width; corners that fail
// their bounds test contribute nothing.
template <typename scalar_t>
__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
                                          const int height, const int width, const scalar_t *im_data,
                                          const int data_width, const int bp_dir)
{
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
    return 0;

  const int row_lo = floor(argmax_h);
  const int col_lo = floor(argmax_w);
  const int row_hi = row_lo + 1;
  const int col_hi = col_lo + 1;

  const bool row_lo_in = row_lo >= 0;
  const bool col_lo_in = col_lo >= 0;
  const bool row_hi_in = row_hi <= height - 1;
  const bool col_hi_in = col_hi <= width - 1;

  scalar_t grad = 0;

  switch (bp_dir)
  {
  case 0:
    // Low-row corners enter negatively, high-row corners positively; each is
    // scaled by its column interpolation weight.
    if (row_lo_in && col_lo_in)
      grad += -1 * (col_lo + 1 - argmax_w) * im_data[row_lo * data_width + col_lo];
    if (row_lo_in && col_hi_in)
      grad += -1 * (argmax_w - col_lo) * im_data[row_lo * data_width + col_hi];
    if (row_hi_in && col_lo_in)
      grad += (col_lo + 1 - argmax_w) * im_data[row_hi * data_width + col_lo];
    if (row_hi_in && col_hi_in)
      grad += (argmax_w - col_lo) * im_data[row_hi * data_width + col_hi];
    break;
  case 1:
    // Low-column corners enter negatively, high-column corners positively;
    // each is scaled by its row interpolation weight.
    if (row_lo_in && col_lo_in)
      grad += -1 * (row_lo + 1 - argmax_h) * im_data[row_lo * data_width + col_lo];
    if (row_lo_in && col_hi_in)
      grad += (row_lo + 1 - argmax_h) * im_data[row_lo * data_width + col_hi];
    if (row_hi_in && col_lo_in)
      grad += -1 * (argmax_h - row_lo) * im_data[row_hi * data_width + col_lo];
    if (row_hi_in && col_hi_in)
      grad += (argmax_h - row_lo) * im_data[row_hi * data_width + col_hi];
    break;
  default:
    break;
  }

  return grad;
}
|
| 189 |
+
|
| 190 |
+
// Deformable im2col: gathers bilinearly sampled input values into the column
// buffer.
//
// Every value of `index` in [0, n) handles one (input channel, image, output
// row, output column) combination and writes kernel_h * kernel_w column
// entries, one per kernel tap.
//
// Flat-buffer addressing used below:
//   data_im     : ((b * num_channels + c) * height + y) * width + x
//   data_offset : per (image, deformable group) a block of
//                 2 * kernel_h * kernel_w planes of height_col * width_col;
//                 plane 2 * tap is the vertical offset, plane 2 * tap + 1 the
//                 horizontal offset
//   data_col    : (((c * kernel_h * kernel_w + tap) * batch_size + b)
//                 * height_col + h_col) * width_col + w_col
//
// CUDA_KERNEL_LOOP and deformable_im2col_bilinear are defined outside this
// block.
template <typename scalar_t>
__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
                                             const int height, const int width, const int kernel_h, const int kernel_w,
                                             const int pad_h, const int pad_w, const int stride_h, const int stride_w,
                                             const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
                                             const int batch_size, const int num_channels, const int deformable_group,
                                             const int height_col, const int width_col,
                                             scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Peel the flat index apart as (c_im, b_col, h_col, w_col), w_col fastest.
    const int w_col = index % width_col;
    const int above_w = index / width_col;
    const int h_col = above_w % height_col;
    const int above_h = above_w / height_col;
    const int b_col = above_h % batch_size;
    const int c_im = above_h / batch_size;

    // First column-buffer channel belonging to this input channel.
    const int c_col = c_im * kernel_h * kernel_w;
    const int deformable_group_index = c_im / channel_per_deformable_group;

    // Top-left corner of the undeformed receptive field (may be negative
    // because of padding).
    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;

    // Distance between two consecutive kernel taps in the column buffer.
    const int tap_stride = batch_size * height_col * width_col;

    scalar_t *col_out = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    const scalar_t *im_plane = data_im + (b_col * num_channels + c_im) * height * width;
    const scalar_t *offset_block = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;

    for (int ky = 0; ky < kernel_h; ++ky)
    {
      for (int kx = 0; kx < kernel_w; ++kx)
      {
        const int offset_h_idx = ((2 * (ky * kernel_w + kx)) * height_col + h_col) * width_col + w_col;
        const int offset_w_idx = ((2 * (ky * kernel_w + kx) + 1) * height_col + h_col) * width_col + w_col;
        const scalar_t offset_h = offset_block[offset_h_idx];
        const scalar_t offset_w = offset_block[offset_w_idx];

        // Fractional sampling position in input coordinates.
        const scalar_t h_im = h_in + ky * dilation_h + offset_h;
        const scalar_t w_im = w_in + kx * dilation_w + offset_w;

        // Positions at or beyond one pixel outside the image contribute zero.
        scalar_t sampled = static_cast<scalar_t>(0);
        const bool in_reach = h_im > -1 && w_im > -1 && h_im < height && w_im < width;
        if (in_reach)
        {
          sampled = deformable_im2col_bilinear(im_plane, width, height, width, h_im, w_im);
        }

        *col_out = sampled;
        col_out += tap_stride;
      }
    }
  }
}
|
| 244 |
+
|
| 245 |
+
// Host-side launcher for deformable_im2col_gpu_kernel.
//
// Derives the output spatial size from the convolution geometry, dispatches
// on the scalar type of data_im (floating types and half) and launches the
// kernel on the current CUDA stream with one work item per
// (channel, image, output row, output column).
//
// A failed launch is only reported on stdout; nothing is thrown or returned.
// TODO: check that parallel_imgs is correctly passed in by callers.
void deformable_im2col(
    const at::Tensor data_im, const at::Tensor data_offset, const int channels,
    const int height, const int width, const int ksize_h, const int ksize_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int parallel_imgs,
    const int deformable_group, at::Tensor data_col)
{
  // Standard convolution output-size formula with dilation.
  const int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  const int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;

  const int num_kernels = channels * height_col * width_col * parallel_imgs;
  const int channel_per_deformable_group = channels / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
        const scalar_t *im_ptr = data_im.data_ptr<scalar_t>();
        const scalar_t *offset_ptr = data_offset.data_ptr<scalar_t>();
        scalar_t *col_ptr = data_col.data_ptr<scalar_t>();

        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, im_ptr, offset_ptr, height, width, ksize_h, ksize_w,
            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
            channel_per_deformable_group, parallel_imgs, channels, deformable_group,
            height_col, width_col, col_ptr);
      }));

  // Report (but do not propagate) any error raised by the launch.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess)
  {
    printf("error in deformable_im2col: %s\n", cudaGetErrorString(launch_status));
  }
}
|
| 278 |
+
|
| 279 |
+
// Deformable col2im: scatters column-buffer gradients back onto the input
// gradient image.
//
// Every value of `index` in [0, n) is one column-buffer element, decomposed
// as (channel c, kernel row i, kernel column j, image b, output row, output
// column). Its gradient is distributed over the input pixels lying less than
// one pixel away from the fractional sampling position, weighted by
// get_gradient_weight (defined outside this block), and accumulated into
// grad_im with atomicAdd.
template <typename scalar_t>
__global__ void deformable_col2im_gpu_kernel(
    const int n, const scalar_t *data_col, const scalar_t *data_offset,
    const int channels, const int height, const int width,
    const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group,
    const int batch_size, const int deformable_group,
    const int height_col, const int width_col,
    scalar_t *grad_im)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Peel the flat index apart, fastest dimension first.
    const int w_out = index % width_col;
    const int above_w = index / width_col;
    const int h_out = above_w % height_col;
    const int above_h = above_w / height_col;
    const int b = above_h % batch_size;
    const int above_b = above_h / batch_size;
    const int j = above_b % kernel_w;
    const int above_j = above_b / kernel_w;
    const int i = above_j % kernel_h;
    const int c = above_j / kernel_h;

    const int deformable_group_index = c / channel_per_deformable_group;

    // Top-left corner of the undeformed receptive field.
    const int w_in = w_out * stride_w - pad_w;
    const int h_in = h_out * stride_h - pad_h;

    const scalar_t *offset_block = data_offset + (b * deformable_group + deformable_group_index) *
                                                     2 * kernel_h * kernel_w * height_col * width_col;
    const int offset_h_idx = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
    const int offset_w_idx = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
    const scalar_t offset_h = offset_block[offset_h_idx];
    const scalar_t offset_w = offset_block[offset_w_idx];

    // Fractional position that the forward pass sampled for this element.
    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;

    const scalar_t cur_top_grad = data_col[index];

    // The cast truncates toward zero; the +/-2 window below is wide enough to
    // still reach every pixel within one unit of the sampling position.
    const int cur_h = (int)cur_inv_h_data;
    const int cur_w = (int)cur_inv_w_data;

    for (int dy = -2; dy <= 2; dy++)
    {
      const int y = cur_h + dy;
      for (int dx = -2; dx <= 2; dx++)
      {
        const int x = cur_w + dx;

        // Skip candidates that fall outside the image.
        if (y < 0 || y >= height || x < 0 || x >= width)
        {
          continue;
        }
        // Skip candidates one pixel or more away from the sampling position.
        if (!(abs(cur_inv_h_data - y) < 1) || !(abs(cur_inv_w_data - x) < 1))
        {
          continue;
        }

        const int cur_bottom_grad_pos = ((b * channels + c) * height + y) * width + x;
        const scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, y, x, height, width);
        atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
      }
    }
  }
}
|
| 336 |
+
|
| 337 |
+
// Host-side launcher for deformable_col2im_gpu_kernel.
//
// Derives the output spatial size, dispatches on the scalar type of data_col
// (floating types and half) and launches the kernel on the current CUDA
// stream with one work item per column-buffer element
// (channel x kernel tap x image x output position).
//
// A failed launch is only reported on stdout; nothing is thrown or returned.
// TODO: make sure parallel_imgs is passed in correctly by callers.
void deformable_col2im(
    const at::Tensor data_col, const at::Tensor data_offset, const int channels,
    const int height, const int width, const int ksize_h,
    const int ksize_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int parallel_imgs, const int deformable_group,
    at::Tensor grad_im)
{
  // Standard convolution output-size formula with dilation.
  const int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  const int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;

  const int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
  const int channel_per_deformable_group = channels / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
        const scalar_t *col_ptr = data_col.data_ptr<scalar_t>();
        const scalar_t *offset_ptr = data_offset.data_ptr<scalar_t>();
        scalar_t *grad_im_ptr = grad_im.data_ptr<scalar_t>();

        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, col_ptr, offset_ptr, channels, height, width, ksize_h,
            ksize_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channel_per_deformable_group,
            parallel_imgs, deformable_group, height_col, width_col, grad_im_ptr);
      }));

  // Report (but do not propagate) any error raised by the launch.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess)
  {
    printf("error in deformable_col2im: %s\n", cudaGetErrorString(launch_status));
  }
}
|
| 372 |
+
|
| 373 |
+
// Deformable col2im for the offsets: computes the gradient with respect to
// every offset value.
//
// Every value of `index` in [0, n) is one element of grad_offset, decomposed
// as (image b, offset channel c, output row h, output column w). Offset
// channel c belongs to deformable group c / (2 * kernel_h * kernel_w); inside
// the group, even channels are vertical offsets (bp_dir == 0) and odd channels
// horizontal offsets (bp_dir == 1) of kernel tap offset_c / 2.
//
// The result is the sum, over the column-buffer channels of that group that
// use this tap, of get_coordinate_weight(...) * data_col. Here
// channel_per_deformable_group counts column-buffer channels (the launcher in
// this file passes channels * ksize_h * ksize_w / deformable_group), which is
// why it is divided by kernel_h and kernel_w when addressing data_im.
//
// CUDA_KERNEL_LOOP and get_coordinate_weight are defined outside this block.
template <typename scalar_t>
__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
                                                   const scalar_t *data_im, const scalar_t *data_offset,
                                                   const int channels, const int height, const int width,
                                                   const int kernel_h, const int kernel_w,
                                                   const int pad_h, const int pad_w,
                                                   const int stride_h, const int stride_w,
                                                   const int dilation_h, const int dilation_w,
                                                   const int channel_per_deformable_group,
                                                   const int batch_size, const int offset_channels, const int deformable_group,
                                                   const int height_col, const int width_col, scalar_t *grad_offset)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Peel the flat index apart as (b, c, h, w), w fastest.
    const int w = index % width_col;
    const int above_w = index / width_col;
    const int h = above_w % height_col;
    const int above_h = above_w / height_col;
    const int c = above_h % offset_channels;
    const int b = above_h / offset_channels;

    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
    // Column channels that share one kernel tap are kernel_h * kernel_w apart.
    const int col_step = kernel_h * kernel_w;

    const scalar_t *col_group = data_col + deformable_group_index * channel_per_deformable_group *
                                               batch_size * width_col * height_col;
    const scalar_t *im_group = data_im + (b * deformable_group + deformable_group_index) *
                                             channel_per_deformable_group / kernel_h / kernel_w * height * width;
    const scalar_t *offset_block = data_offset + (b * deformable_group + deformable_group_index) * 2 *
                                                     kernel_h * kernel_w * height_col * width_col;

    // Offset channel relative to its group, and the direction it encodes.
    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
    const int bp_dir = offset_c % 2;

    scalar_t val = 0;
    // cnt walks the input channels of the group in step with col_c.
    for (int col_c = (offset_c / 2), cnt = 0; col_c < channel_per_deformable_group; col_c += col_step, ++cnt)
    {
      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;

      // Recover the tap and output position from the column position.
      const int w_out = col_pos % width_col;
      const int col_above_w = col_pos / width_col;
      const int h_out = col_above_w % height_col;
      const int col_above_b = col_above_w / height_col / batch_size;
      const int j = col_above_b % kernel_w;
      const int i = (col_above_b / kernel_w) % kernel_h;

      const int w_in = w_out * stride_w - pad_w;
      const int h_in = h_out * stride_h - pad_h;

      const int offset_h_idx = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
      const int offset_w_idx = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
      const scalar_t offset_h = offset_block[offset_h_idx];
      const scalar_t offset_w = offset_block[offset_w_idx];

      scalar_t inv_h = h_in + i * dilation_h + offset_h;
      scalar_t inv_w = w_in + j * dilation_w + offset_w;

      // Replace out-of-reach sampling positions with a fixed sentinel before
      // asking for the weight.
      const bool out_of_reach = inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width;
      if (out_of_reach)
      {
        inv_h = inv_w = -2;
      }

      const scalar_t weight = get_coordinate_weight(
          inv_h, inv_w,
          height, width, im_group + cnt * height * width, width, bp_dir);
      val += weight * col_group[col_pos];
    }

    grad_offset[index] = val;
  }
}
|
| 437 |
+
|
| 438 |
+
// Host-side launcher for deformable_col2im_coord_gpu_kernel.
//
// Derives the output spatial size, dispatches on the scalar type of data_col
// (floating types and half) and launches the kernel on the current CUDA
// stream with one work item per element of grad_offset:
// height_col * width_col * (2 * ksize_h * ksize_w * deformable_group)
// * parallel_imgs.
//
// Note that channel_per_deformable_group is expressed here in column-buffer
// channels (channels * ksize_h * ksize_w / deformable_group), unlike in the
// im2col / col2im launchers where it counts input channels.
//
// A failed launch is reported on stdout, matching deformable_im2col and
// deformable_col2im; nothing is thrown or returned.
void deformable_col2im_coord(
    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
    const int channels, const int height, const int width, const int ksize_h,
    const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
    const int stride_w, const int dilation_h, const int dilation_w,
    const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
{
  // Standard convolution output-size formula with dilation.
  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
  int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();

        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
            ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channel_per_deformable_group,
            parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
            height_col, width_col, grad_offset_);
      }));

  // Previously this launcher was the only one that never looked at the launch
  // status, so a bad launch configuration went unreported here.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in deformable_col2im_coord: %s\n", cudaGetErrorString(err));
  }
}
|
| 466 |
+
|
| 467 |
+
// Bilinear interpolation of a single image plane at the fractional position
// (h, w).
//
// bottom_data points at the first element of the plane and data_width is its
// row stride in elements. Neighbours whose row index is below 0 or above
// height - 1, or whose column index is below 0 or above width - 1, are
// treated as zero and never read.
template <typename scalar_t>
__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
                                         const int height, const int width, scalar_t h, scalar_t w)
{
  // Integer corners surrounding the sampling position.
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  // Fractional distance to the low corner (lh, lw) and to the high one.
  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh;
  const scalar_t hw = 1 - lw;

  // Which of the four neighbouring rows / columns are readable.
  const bool top_ok = h_low >= 0;
  const bool bottom_ok = h_high <= height - 1;
  const bool left_ok = w_low >= 0;
  const bool right_ok = w_high <= width - 1;

  const scalar_t v1 = (top_ok && left_ok) ? bottom_data[h_low * data_width + w_low] : static_cast<scalar_t>(0);
  const scalar_t v2 = (top_ok && right_ok) ? bottom_data[h_low * data_width + w_high] : static_cast<scalar_t>(0);
  const scalar_t v3 = (bottom_ok && left_ok) ? bottom_data[h_high * data_width + w_low] : static_cast<scalar_t>(0);
  const scalar_t v4 = (bottom_ok && right_ok) ? bottom_data[h_high * data_width + w_high] : static_cast<scalar_t>(0);

  // Each corner is weighted by the area of the opposite sub-rectangle.
  const scalar_t w1 = hh * hw;
  const scalar_t w2 = hh * lw;
  const scalar_t w3 = lh * hw;
  const scalar_t w4 = lh * lw;

  return (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
}
|
| 498 |
+
|
| 499 |
+
// Bilinear weight with which the integer pixel (h, w) contributes to a sample
// taken at the fractional position (argmax_h, argmax_w).
//
// Returns 0 when the sampling position is at or beyond one pixel outside the
// height x width image, or when (h, w) is not one of the four integer corners
// surrounding the sampling position.
template <typename scalar_t>
__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
                                             const int h, const int w, const int height, const int width)
{
  const bool out_of_reach = argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width;
  if (out_of_reach)
  {
    return 0;
  }

  const int argmax_h_low = floor(argmax_h);
  const int argmax_w_low = floor(argmax_w);
  const int argmax_h_high = argmax_h_low + 1;
  const int argmax_w_high = argmax_w_low + 1;

  // low and high differ by one, so at most one of the four corners matches.
  scalar_t weight = 0;
  if (h == argmax_h_low)
  {
    if (w == argmax_w_low)
    {
      weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
    }
    else if (w == argmax_w_high)
    {
      weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
    }
  }
  else if (h == argmax_h_high)
  {
    if (w == argmax_w_low)
    {
      weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
    }
    else if (w == argmax_w_high)
    {
      weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
    }
  }
  return weight;
}
|
| 525 |
+
|
| 526 |
+
// Derivative of the bilinearly interpolated image value with respect to the
// sampling coordinate.
//
// bp_dir == 0 gives the derivative along the vertical coordinate (argmax_h),
// bp_dir == 1 along the horizontal coordinate (argmax_w); any other bp_dir
// yields 0. im_data points at the first element of the image plane and
// data_width is its row stride in elements. Neighbours outside the
// height x width image are skipped, and a sampling position at or beyond one
// pixel outside the image yields 0.
template <typename scalar_t>
__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
                                               const int height, const int width, const scalar_t *im_data,
                                               const int data_width, const int bp_dir)
{
  const bool out_of_reach = argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width;
  if (out_of_reach)
  {
    return 0;
  }

  const int argmax_h_low = floor(argmax_h);
  const int argmax_w_low = floor(argmax_w);
  const int argmax_h_high = argmax_h_low + 1;
  const int argmax_w_high = argmax_w_low + 1;

  // Which of the four neighbouring rows / columns are readable.
  const bool top_ok = argmax_h_low >= 0;
  const bool bottom_ok = argmax_h_high <= height - 1;
  const bool left_ok = argmax_w_low >= 0;
  const bool right_ok = argmax_w_high <= width - 1;

  // Flat positions of the four corners inside the plane.
  const int top_left = argmax_h_low * data_width + argmax_w_low;
  const int top_right = argmax_h_low * data_width + argmax_w_high;
  const int bottom_left = argmax_h_high * data_width + argmax_w_low;
  const int bottom_right = argmax_h_high * data_width + argmax_w_high;

  scalar_t weight = 0;

  if (bp_dir == 0)
  {
    // Vertical derivative: bottom row minus top row, blended horizontally.
    if (top_ok && left_ok)
      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[top_left];
    if (top_ok && right_ok)
      weight += -1 * (argmax_w - argmax_w_low) * im_data[top_right];
    if (bottom_ok && left_ok)
      weight += (argmax_w_low + 1 - argmax_w) * im_data[bottom_left];
    if (bottom_ok && right_ok)
      weight += (argmax_w - argmax_w_low) * im_data[bottom_right];
  }
  else if (bp_dir == 1)
  {
    // Horizontal derivative: right column minus left column, blended
    // vertically.
    if (top_ok && left_ok)
      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[top_left];
    if (top_ok && right_ok)
      weight += (argmax_h_low + 1 - argmax_h) * im_data[top_right];
    if (bottom_ok && left_ok)
      weight += -1 * (argmax_h - argmax_h_low) * im_data[bottom_left];
    if (bottom_ok && right_ok)
      weight += (argmax_h - argmax_h_low) * im_data[bottom_right];
  }

  return weight;
}
|
| 569 |
+
|
| 570 |
+
// Modulated deformable im2col: gathers bilinearly sampled input values,
// scaled by a per-tap mask, into the column buffer.
//
// Every value of `index` in [0, n) handles one (input channel, image, output
// row, output column) combination and writes kernel_h * kernel_w column
// entries, one per kernel tap.
//
// Flat-buffer addressing used below:
//   data_im     : ((b * num_channels + c) * height + y) * width + x
//   data_offset : per (image, deformable group) a block of
//                 2 * kernel_h * kernel_w planes of height_col * width_col;
//                 plane 2 * tap is the vertical offset, plane 2 * tap + 1 the
//                 horizontal offset
//   data_mask   : per (image, deformable group) a block of
//                 kernel_h * kernel_w planes of height_col * width_col
//   data_col    : (((c * kernel_h * kernel_w + tap) * batch_size + b)
//                 * height_col + h_col) * width_col + w_col
//
// CUDA_KERNEL_LOOP is defined outside this block.
template <typename scalar_t>
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
                                                       const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
                                                       const int height, const int width, const int kernel_h, const int kernel_w,
                                                       const int pad_h, const int pad_w,
                                                       const int stride_h, const int stride_w,
                                                       const int dilation_h, const int dilation_w,
                                                       const int channel_per_deformable_group,
                                                       const int batch_size, const int num_channels, const int deformable_group,
                                                       const int height_col, const int width_col,
                                                       scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Peel the flat index apart as (c_im, b_col, h_col, w_col), w_col fastest.
    const int w_col = index % width_col;
    const int above_w = index / width_col;
    const int h_col = above_w % height_col;
    const int above_h = above_w / height_col;
    const int b_col = above_h % batch_size;
    const int c_im = above_h / batch_size;

    // First column-buffer channel belonging to this input channel.
    const int c_col = c_im * kernel_h * kernel_w;
    const int deformable_group_index = c_im / channel_per_deformable_group;

    // Top-left corner of the undeformed receptive field (may be negative
    // because of padding).
    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;

    // Distance between two consecutive kernel taps in the column buffer.
    const int tap_stride = batch_size * height_col * width_col;

    scalar_t *col_out = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    const scalar_t *im_plane = data_im + (b_col * num_channels + c_im) * height * width;
    const scalar_t *offset_block = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const scalar_t *mask_block = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;

    for (int ky = 0; ky < kernel_h; ++ky)
    {
      for (int kx = 0; kx < kernel_w; ++kx)
      {
        const int offset_h_idx = ((2 * (ky * kernel_w + kx)) * height_col + h_col) * width_col + w_col;
        const int offset_w_idx = ((2 * (ky * kernel_w + kx) + 1) * height_col + h_col) * width_col + w_col;
        const int mask_idx = ((ky * kernel_w + kx) * height_col + h_col) * width_col + w_col;
        const scalar_t offset_h = offset_block[offset_h_idx];
        const scalar_t offset_w = offset_block[offset_w_idx];
        const scalar_t mask = mask_block[mask_idx];

        // Fractional sampling position in input coordinates.
        const scalar_t h_im = h_in + ky * dilation_h + offset_h;
        const scalar_t w_im = w_in + kx * dilation_w + offset_w;

        // Positions at or beyond one pixel outside the image contribute zero.
        scalar_t sampled = static_cast<scalar_t>(0);
        const bool in_reach = h_im > -1 && w_im > -1 && h_im < height && w_im < width;
        if (in_reach)
        {
          sampled = dmcn_im2col_bilinear(im_plane, width, height, width, h_im, w_im);
        }

        *col_out = sampled * mask;
        col_out += tap_stride;
      }
    }
  }
}
|
| 634 |
+
|
| 635 |
+
// Modulated deformable col2im: scatters mask-scaled column-buffer gradients
// back onto the input gradient image.
//
// Every value of `index` in [0, n) is one column-buffer element, decomposed
// as (channel c, kernel row i, kernel column j, image b, output row, output
// column). Its gradient, multiplied by the tap's mask value, is distributed
// over the input pixels lying less than one pixel away from the fractional
// sampling position, weighted by dmcn_get_gradient_weight, and accumulated
// into grad_im with atomicAdd.
template <typename scalar_t>
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
                                                       const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
                                                       const int channels, const int height, const int width,
                                                       const int kernel_h, const int kernel_w,
                                                       const int pad_h, const int pad_w,
                                                       const int stride_h, const int stride_w,
                                                       const int dilation_h, const int dilation_w,
                                                       const int channel_per_deformable_group,
                                                       const int batch_size, const int deformable_group,
                                                       const int height_col, const int width_col,
                                                       scalar_t *grad_im)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Peel the flat index apart, fastest dimension first.
    const int w_out = index % width_col;
    const int above_w = index / width_col;
    const int h_out = above_w % height_col;
    const int above_h = above_w / height_col;
    const int b = above_h % batch_size;
    const int above_b = above_h / batch_size;
    const int j = above_b % kernel_w;
    const int above_j = above_b / kernel_w;
    const int i = above_j % kernel_h;
    const int c = above_j / kernel_h;

    const int deformable_group_index = c / channel_per_deformable_group;

    // Top-left corner of the undeformed receptive field.
    const int w_in = w_out * stride_w - pad_w;
    const int h_in = h_out * stride_h - pad_h;

    const scalar_t *offset_block = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const scalar_t *mask_block = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
    const int offset_h_idx = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
    const int offset_w_idx = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
    const int mask_idx = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
    const scalar_t offset_h = offset_block[offset_h_idx];
    const scalar_t offset_w = offset_block[offset_w_idx];
    const scalar_t mask = mask_block[mask_idx];

    // Fractional position that the forward pass sampled for this element.
    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;

    const scalar_t cur_top_grad = data_col[index] * mask;

    // The cast truncates toward zero; the +/-2 window below is wide enough to
    // still reach every pixel within one unit of the sampling position.
    const int cur_h = (int)cur_inv_h_data;
    const int cur_w = (int)cur_inv_w_data;

    for (int dy = -2; dy <= 2; dy++)
    {
      const int y = cur_h + dy;
      for (int dx = -2; dx <= 2; dx++)
      {
        const int x = cur_w + dx;

        // Skip candidates that fall outside the image.
        if (y < 0 || y >= height || x < 0 || x >= width)
        {
          continue;
        }
        // Skip candidates one pixel or more away from the sampling position.
        if (!(abs(cur_inv_h_data - y) < 1) || !(abs(cur_inv_w_data - x) < 1))
        {
          continue;
        }

        const int cur_bottom_grad_pos = ((b * channels + c) * height + y) * width + x;
        const scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, y, x, height, width);
        atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
      }
    }
  }
}
|
| 694 |
+
|
| 695 |
+
// Modulated deformable col2im for offsets and mask: computes the gradient
// with respect to every offset value and, as a by-product, every mask value.
//
// Every value of `index` in [0, n) is one element of grad_offset, decomposed
// as (image b, offset channel c, output row h, output column w). Offset
// channel c belongs to deformable group c / (2 * kernel_h * kernel_w); inside
// the group, even channels are vertical offsets (bp_dir == 0) and odd channels
// horizontal offsets (bp_dir == 1) of kernel tap offset_c / 2.
//
// For each column-buffer channel of the group that uses this tap:
//   val  += dmcn_get_coordinate_weight(...) * data_col * mask
//   mval += data_col * bilinear sample   (only when the sample is in reach)
// grad_offset always receives val; grad_mask receives mval only from the
// even (vertical) offset channel, so each mask element is written once.
//
// channel_per_deformable_group is divided by kernel_h and kernel_w when
// addressing data_im, so it presumably counts column-buffer channels here;
// the launcher is not visible in this block - confirm against the caller.
template <typename scalar_t>
__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
                                                             const scalar_t *data_col, const scalar_t *data_im,
                                                             const scalar_t *data_offset, const scalar_t *data_mask,
                                                             const int channels, const int height, const int width,
                                                             const int kernel_h, const int kernel_w,
                                                             const int pad_h, const int pad_w,
                                                             const int stride_h, const int stride_w,
                                                             const int dilation_h, const int dilation_w,
                                                             const int channel_per_deformable_group,
                                                             const int batch_size, const int offset_channels, const int deformable_group,
                                                             const int height_col, const int width_col,
                                                             scalar_t *grad_offset, scalar_t *grad_mask)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // Peel the flat index apart as (b, c, h, w), w fastest.
    const int w = index % width_col;
    const int above_w = index / width_col;
    const int h = above_w % height_col;
    const int above_h = above_w / height_col;
    const int c = above_h % offset_channels;
    const int b = above_h / offset_channels;

    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
    // Column channels that share one kernel tap are kernel_h * kernel_w apart.
    const int col_step = kernel_h * kernel_w;

    const scalar_t *col_group = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
    const scalar_t *im_group = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
    const scalar_t *offset_block = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const scalar_t *mask_block = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;

    // Offset channel relative to its group, and the direction it encodes.
    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
    const int bp_dir = offset_c % 2;

    scalar_t val = 0, mval = 0;
    // cnt walks the input channels of the group in step with col_c.
    for (int col_c = (offset_c / 2), cnt = 0; col_c < channel_per_deformable_group; col_c += col_step, ++cnt)
    {
      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;

      // Recover the tap and output position from the column position.
      const int w_out = col_pos % width_col;
      const int col_above_w = col_pos / width_col;
      const int h_out = col_above_w % height_col;
      const int col_above_b = col_above_w / height_col / batch_size;
      const int j = col_above_b % kernel_w;
      const int i = (col_above_b / kernel_w) % kernel_h;

      const int w_in = w_out * stride_w - pad_w;
      const int h_in = h_out * stride_h - pad_h;

      const int offset_h_idx = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
      const int offset_w_idx = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
      const int mask_idx = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
      const scalar_t offset_h = offset_block[offset_h_idx];
      const scalar_t offset_w = offset_block[offset_w_idx];
      const scalar_t mask = mask_block[mask_idx];

      scalar_t inv_h = h_in + i * dilation_h + offset_h;
      scalar_t inv_w = w_in + j * dilation_w + offset_w;

      const scalar_t *im_plane = im_group + cnt * height * width;

      const bool out_of_reach = inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width;
      if (out_of_reach)
      {
        // Fixed sentinel for which the coordinate weight below is 0.
        inv_h = inv_w = -2;
      }
      else
      {
        // The mask gradient is the upstream gradient times the unmasked sample.
        mval += col_group[col_pos] * dmcn_im2col_bilinear(im_plane, width, height, width, inv_h, inv_w);
      }

      const scalar_t weight = dmcn_get_coordinate_weight(
          inv_h, inv_w,
          height, width, im_plane, width, bp_dir);
      val += weight * col_group[col_pos] * mask;
    }

    grad_offset[index] = val;

    // Only the even offset channel of each tap writes the mask gradient.
    if (bp_dir == 0)
    {
      const int mask_grad_pos = (((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w;
      grad_mask[mask_grad_pos] = mval;
    }
  }
}
|
| 768 |
+
|
| 769 |
+
// Host-side launcher for modulated_deformable_im2col_gpu_kernel.
//
// Unlike the non-modulated launchers, the output size (height_col,
// width_col) is supplied by the caller rather than derived here. Dispatches
// on the scalar type of data_im (floating types and half) and launches the
// kernel on the current CUDA stream with one work item per
// (channel, image, output row, output column).
//
// The parameter name `kenerl_w` (sic) is kept as in the original signature;
// it is the kernel width.
//
// A failed launch is only reported on stdout; nothing is thrown or returned.
void modulated_deformable_im2col_cuda(
    const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group, at::Tensor data_col)
{
  const int channel_per_deformable_group = channels / deformable_group;
  const int num_kernels = channels * batch_size * height_col * width_col;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
        const scalar_t *im_ptr = data_im.data_ptr<scalar_t>();
        const scalar_t *offset_ptr = data_offset.data_ptr<scalar_t>();
        const scalar_t *mask_ptr = data_mask.data_ptr<scalar_t>();
        scalar_t *col_ptr = data_col.data_ptr<scalar_t>();

        modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            num_kernels, im_ptr, offset_ptr, mask_ptr, height_im, width_im, kernel_h, kenerl_w,
            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
            batch_size, channels, deformable_group, height_col, width_col, col_ptr);
      }));

  // Report (but do not propagate) any error raised by the launch.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess)
  {
    printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(launch_status));
  }
}
|
| 800 |
+
|
| 801 |
+
// Host-side launcher for modulated_deformable_col2im_gpu_kernel.
//
// Passes the device pointers of data_col, data_offset and data_mask plus the
// convolution geometry to the kernel on the current CUDA stream; the kernel
// writes into grad_im. The element count passed to the kernel is
// channels * kernel_h * kernel_w * batch_size * height_col * width_col.
// A launch error is only printed to stdout; nothing is thrown or returned.
void modulated_deformable_col2im_cuda(
    const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group, at::Tensor grad_im)
{
  const int total_elems = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
  const int channels_per_group = channels / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
        const scalar_t *col_ptr = data_col.data_ptr<scalar_t>();
        const scalar_t *offset_ptr = data_offset.data_ptr<scalar_t>();
        const scalar_t *mask_ptr = data_mask.data_ptr<scalar_t>();
        scalar_t *grad_im_ptr = grad_im.data_ptr<scalar_t>();

        modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(total_elems), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            total_elems, col_ptr, offset_ptr, mask_ptr, channels, height_im, width_im,
            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channels_per_group,
            batch_size, deformable_group, height_col, width_col, grad_im_ptr);
      }));

  // Report (but do not propagate) any error raised by the launch above.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess)
  {
    printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(launch_status));
  }
}
|
| 833 |
+
|
| 834 |
+
// Host-side launcher for modulated_deformable_col2im_coord_gpu_kernel.
//
// Passes the device pointers of data_col, data_im, data_offset and data_mask
// plus the convolution geometry to the kernel on the current CUDA stream; the
// kernel writes into grad_offset and grad_mask.
// The element count passed to the kernel is
// batch_size * height_col * width_col * (2 * kernel_h * kernel_w * deformable_group).
// A launch error is only printed to stdout; nothing is thrown or returned.
void modulated_deformable_col2im_coord_cuda(
    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group,
    at::Tensor grad_offset, at::Tensor grad_mask)
{
  const int total_elems = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
  // Unlike the sibling launchers, the per-group count here also folds in the
  // kernel window size (channels * kernel_h * kernel_w, then integer-divided).
  const int channels_per_group = channels * kernel_h * kernel_w / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
        const scalar_t *col_ptr = data_col.data_ptr<scalar_t>();
        const scalar_t *im_ptr = data_im.data_ptr<scalar_t>();
        const scalar_t *offset_ptr = data_offset.data_ptr<scalar_t>();
        const scalar_t *mask_ptr = data_mask.data_ptr<scalar_t>();
        scalar_t *grad_offset_ptr = grad_offset.data_ptr<scalar_t>();
        scalar_t *grad_mask_ptr = grad_mask.data_ptr<scalar_t>();

        modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(total_elems), CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
            total_elems, col_ptr, im_ptr, offset_ptr, mask_ptr, channels, height_im, width_im,
            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channels_per_group,
            batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
            grad_offset_ptr, grad_mask_ptr);
      }));

  // Report (but do not propagate) any error raised by the launch above.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess)
  {
    printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(launch_status));
  }
}
|
basicsr/ops/dcn/src/deform_conv_ext.cpp
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// modified from
|
| 2 |
+
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
|
| 3 |
+
|
| 4 |
+
#include <torch/extension.h>
|
| 5 |
+
#include <ATen/DeviceGuard.h>
|
| 6 |
+
|
| 7 |
+
#include <cmath>
|
| 8 |
+
#include <vector>
|
| 9 |
+
|
| 10 |
+
#define WITH_CUDA // always use cuda
|
| 11 |
+
#ifdef WITH_CUDA
|
| 12 |
+
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
|
| 13 |
+
at::Tensor offset, at::Tensor output,
|
| 14 |
+
at::Tensor columns, at::Tensor ones, int kW,
|
| 15 |
+
int kH, int dW, int dH, int padW, int padH,
|
| 16 |
+
int dilationW, int dilationH, int group,
|
| 17 |
+
int deformable_group, int im2col_step);
|
| 18 |
+
|
| 19 |
+
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
|
| 20 |
+
at::Tensor gradOutput, at::Tensor gradInput,
|
| 21 |
+
at::Tensor gradOffset, at::Tensor weight,
|
| 22 |
+
at::Tensor columns, int kW, int kH, int dW,
|
| 23 |
+
int dH, int padW, int padH, int dilationW,
|
| 24 |
+
int dilationH, int group,
|
| 25 |
+
int deformable_group, int im2col_step);
|
| 26 |
+
|
| 27 |
+
int deform_conv_backward_parameters_cuda(
|
| 28 |
+
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
|
| 29 |
+
at::Tensor gradWeight, // at::Tensor gradBias,
|
| 30 |
+
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
|
| 31 |
+
int padW, int padH, int dilationW, int dilationH, int group,
|
| 32 |
+
int deformable_group, float scale, int im2col_step);
|
| 33 |
+
|
| 34 |
+
void modulated_deform_conv_cuda_forward(
|
| 35 |
+
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
|
| 36 |
+
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
|
| 37 |
+
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
|
| 38 |
+
const int pad_h, const int pad_w, const int dilation_h,
|
| 39 |
+
const int dilation_w, const int group, const int deformable_group,
|
| 40 |
+
const bool with_bias);
|
| 41 |
+
|
| 42 |
+
void modulated_deform_conv_cuda_backward(
|
| 43 |
+
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
|
| 44 |
+
at::Tensor offset, at::Tensor mask, at::Tensor columns,
|
| 45 |
+
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
|
| 46 |
+
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
|
| 47 |
+
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
|
| 48 |
+
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
|
| 49 |
+
const bool with_bias);
|
| 50 |
+
#endif
|
| 51 |
+
|
| 52 |
+
// Device dispatcher for the deformable convolution forward pass.
// Forwards every argument unchanged to deform_conv_forward_cuda when `input`
// lives on a CUDA device and returns its result; raises via AT_ERROR for CPU
// tensors or when the file is built without WITH_CUDA.
int deform_conv_forward(at::Tensor input, at::Tensor weight,
                        at::Tensor offset, at::Tensor output,
                        at::Tensor columns, at::Tensor ones, int kW,
                        int kH, int dW, int dH, int padW, int padH,
                        int dilationW, int dilationH, int group,
                        int deformable_group, int im2col_step) {
  // Reject non-CUDA tensors first; there is no CPU implementation.
  if (!input.device().is_cuda()) {
    AT_ERROR("deform conv is not implemented on CPU");
  }
#ifdef WITH_CUDA
  return deform_conv_forward_cuda(
      input, weight, offset, output, columns, ones,
      kW, kH, dW, dH, padW, padH, dilationW, dilationH,
      group, deformable_group, im2col_step);
#else
  AT_ERROR("deform conv is not compiled with GPU support");
#endif
}
|
| 69 |
+
|
| 70 |
+
// Device dispatcher for the deformable convolution input/offset backward pass.
// Forwards every argument unchanged to deform_conv_backward_input_cuda when
// `input` lives on a CUDA device and returns its result; raises via AT_ERROR
// for CPU tensors or when the file is built without WITH_CUDA.
int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
                               at::Tensor gradOutput, at::Tensor gradInput,
                               at::Tensor gradOffset, at::Tensor weight,
                               at::Tensor columns, int kW, int kH, int dW,
                               int dH, int padW, int padH, int dilationW,
                               int dilationH, int group,
                               int deformable_group, int im2col_step) {
  // Reject non-CUDA tensors first; there is no CPU implementation.
  if (!input.device().is_cuda()) {
    AT_ERROR("deform conv is not implemented on CPU");
  }
#ifdef WITH_CUDA
  return deform_conv_backward_input_cuda(
      input, offset, gradOutput, gradInput, gradOffset, weight, columns,
      kW, kH, dW, dH, padW, padH, dilationW, dilationH,
      group, deformable_group, im2col_step);
#else
  AT_ERROR("deform conv is not compiled with GPU support");
#endif
}
|
| 88 |
+
|
| 89 |
+
// Device dispatcher for the deformable convolution weight backward pass.
// Forwards every argument unchanged to deform_conv_backward_parameters_cuda
// when `input` lives on a CUDA device and returns its result; raises via
// AT_ERROR for CPU tensors or when the file is built without WITH_CUDA.
int deform_conv_backward_parameters(
    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
    at::Tensor gradWeight,  // at::Tensor gradBias,
    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
    int padW, int padH, int dilationW, int dilationH, int group,
    int deformable_group, float scale, int im2col_step) {
  // Reject non-CUDA tensors first; there is no CPU implementation.
  if (!input.device().is_cuda()) {
    AT_ERROR("deform conv is not implemented on CPU");
  }
#ifdef WITH_CUDA
  return deform_conv_backward_parameters_cuda(
      input, offset, gradOutput, gradWeight, columns, ones,
      kW, kH, dW, dH, padW, padH, dilationW, dilationH,
      group, deformable_group, scale, im2col_step);
#else
  AT_ERROR("deform conv is not compiled with GPU support");
#endif
}
|
| 106 |
+
|
| 107 |
+
// Device dispatcher for the modulated deformable convolution forward pass.
// Forwards every argument unchanged to modulated_deform_conv_cuda_forward when
// `input` lives on a CUDA device; raises via AT_ERROR for CPU tensors or when
// the file is built without WITH_CUDA.
void modulated_deform_conv_forward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
    const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int group, const int deformable_group,
    const bool with_bias) {
  // Reject non-CUDA tensors first; there is no CPU implementation.
  if (!input.device().is_cuda()) {
    AT_ERROR("modulated deform conv is not implemented on CPU");
  }
#ifdef WITH_CUDA
  modulated_deform_conv_cuda_forward(
      input, weight, bias, ones, offset, mask, output, columns,
      kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
      dilation_h, dilation_w, group, deformable_group, with_bias);
#else
  AT_ERROR("modulated deform conv is not compiled with GPU support");
#endif
}
|
| 126 |
+
|
| 127 |
+
// Device dispatcher for the modulated deformable convolution backward pass.
// Forwards every argument unchanged to modulated_deform_conv_cuda_backward
// when `input` lives on a CUDA device; raises via AT_ERROR for CPU tensors or
// when the file is built without WITH_CUDA.
void modulated_deform_conv_backward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor columns,
    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias) {
  // Reject non-CUDA tensors first; there is no CPU implementation.
  if (!input.device().is_cuda()) {
    AT_ERROR("modulated deform conv is not implemented on CPU");
  }
#ifdef WITH_CUDA
  modulated_deform_conv_cuda_backward(
      input, weight, bias, ones, offset, mask, columns,
      grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output,
      kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
      dilation_h, dilation_w, group, deformable_group, with_bias);
#else
  AT_ERROR("modulated deform conv is not compiled with GPU support");
#endif
}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
// Python bindings: exposes the five dispatch wrappers defined above under
// their C++ names, each with a short docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Plain deformable convolution.
  m.def("deform_conv_forward", &deform_conv_forward, "deform forward");
  m.def("deform_conv_backward_input", &deform_conv_backward_input,
        "deform_conv_backward_input");
  m.def("deform_conv_backward_parameters", &deform_conv_backward_parameters,
        "deform_conv_backward_parameters");

  // Modulated deformable convolution.
  m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward,
        "modulated deform conv forward");
  m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward,
        "modulated deform conv backward");
}
|
basicsr/ops/fused_act/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .fused_act import FusedLeakyReLU, fused_leaky_relu
|
| 2 |
+
|
| 3 |
+
__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
|
basicsr/ops/fused_act/fused_act.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.autograd import Function
|
| 6 |
+
|
| 7 |
+
# Prefer the pre-built extension that ships inside this package. If it cannot
# be imported, JIT-compile it from the bundled sources, but only when the
# BASICSR_JIT environment variable is exactly 'True'. Otherwise
# `fused_act_ext` stays undefined and the ops below fail on first use.
try:
    from . import fused_act_ext
except ImportError:
    import os
    BASICSR_JIT = os.getenv('BASICSR_JIT')
    if BASICSR_JIT == 'True':
        from torch.utils.cpp_extension import load
        module_path = os.path.dirname(__file__)
        fused_act_ext = load(
            'fused',
            sources=[
                os.path.join(module_path, 'src', source_name)
                for source_name in ('fused_bias_act.cpp', 'fused_bias_act_kernel.cu')
            ],
        )
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class FusedLeakyReLUFunctionBackward(Function):
    """Autograd function that computes the gradients of the fused bias + leaky ReLU op.

    It is a ``Function`` itself (with its own ``backward``), so the gradient
    computation can be differentiated again.
    """

    @staticmethod
    def forward(ctx, grad_output, out, negative_slope, scale):
        """Return ``(grad_input, grad_bias)`` for the given upstream gradient.

        ``out`` is passed to the extension as its ``refer`` argument and is
        stashed on ``ctx`` for the double-backward pass.
        """
        # Extension call with act=3, grad=1 and an empty tensor in the bias slot.
        no_bias = grad_output.new_empty(0)
        grad_input = fused_act_ext.fused_bias_act(grad_output, no_bias, out, 3, 1, negative_slope, scale)

        # Sum over every axis except dim 1 to obtain the bias gradient.
        reduce_dims = [0, *range(2, grad_input.ndim)]
        grad_bias = grad_input.sum(reduce_dims).detach()

        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        """Return the gradient w.r.t. ``grad_output``; the other inputs get ``None``."""
        (out,) = ctx.saved_tensors
        gradgrad_out = fused_act_ext.fused_bias_act(
            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale)

        return gradgrad_out, None, None, None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class FusedLeakyReLUFunction(Function):
    """Autograd function wrapping the extension's fused bias + leaky ReLU op."""

    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        """Run the extension op on ``input`` and ``bias`` and return its output.

        The output (not the input) is saved on ``ctx`` for the backward pass.
        """
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        # Extension call with act=3, grad=0 and an empty tensor in the refer slot.
        no_refer = input.new_empty(0)
        out = fused_act_ext.fused_bias_act(input, bias, no_refer, 3, 0, negative_slope, scale)
        ctx.save_for_backward(out)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``input`` and ``bias``; the two scalars get ``None``."""
        (out,) = ctx.saved_tensors
        grads = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
        grad_input, grad_bias = grads

        return grad_input, grad_bias, None, None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class FusedLeakyReLU(nn.Module):
    """Module form of :func:`fused_leaky_relu` with a learnable bias.

    Args:
        channel: Length of the bias vector, which is initialised to zeros.
        negative_slope: Forwarded unchanged to ``fused_leaky_relu``. Default: 0.2.
        scale: Forwarded unchanged to ``fused_leaky_relu``. Default: ``2**0.5``.
    """

    def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
        super().__init__()

        self.negative_slope = negative_slope
        self.scale = scale
        # One learnable bias value per channel, starting at zero.
        self.bias = nn.Parameter(torch.zeros(channel))

    def forward(self, input):
        """Apply the fused op to ``input`` using this module's bias, slope and scale."""
        activated = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
        return activated
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
    """Functional entry point: run ``FusedLeakyReLUFunction`` on the arguments.

    Args:
        input: Tensor handed to the fused op.
        bias: Bias tensor handed to the fused op.
        negative_slope: Passed through unchanged. Default: 0.2.
        scale: Passed through unchanged. Default: ``2**0.5``.

    Returns:
        The result of ``FusedLeakyReLUFunction.apply``.
    """
    result = FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
    return result
|
basicsr/ops/fused_act/src/fused_bias_act.cpp
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
|
| 2 |
+
#include <torch/extension.h>
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
// Forward declaration of the actual operator; it is not defined in this file
// (presumably in fused_bias_act_kernel.cu — confirm against the build sources).
torch::Tensor fused_bias_act_op(const torch::Tensor& input,
                                const torch::Tensor& bias,
                                const torch::Tensor& refer,
                                int act, int grad, float alpha, float scale);

// Argument validation helpers. Each raises through TORCH_CHECK with a message
// naming the offending argument.
// Tensor::is_cuda() is queried directly; the former x.type().is_cuda() went
// through the deprecated Tensor::type() accessor.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
// Wrapped in do/while(0) so the two checks stay together when the macro is the
// sole statement of an unbraced if/else.
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
|
| 13 |
+
|
| 14 |
+
// Python-facing wrapper around fused_bias_act_op.
// Verifies that `input` and `bias` are CUDA tensors (`refer` is not checked)
// and then forwards all arguments unchanged, returning the operator's result.
torch::Tensor fused_bias_act(const torch::Tensor& input,
                             const torch::Tensor& bias,
                             const torch::Tensor& refer,
                             int act, int grad, float alpha, float scale) {
  // Validate device placement before calling into the operator.
  CHECK_CUDA(input);
  CHECK_CUDA(bias);

  torch::Tensor result =
      fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
  return result;
}
|
| 23 |
+
|
| 24 |
+
// Python binding: exposes the validating wrapper as `fused_bias_act`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fused_bias_act",
        &fused_bias_act,
        "fused bias act (CUDA)");
}
|