Image Enhancement Tool
+Automatic face photo enhancement and restoration
diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..85d5c8f1cc33f4e889d1f296a039756728a198f3
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,15 @@
+.git
+.gitignore
+__pycache__
+*.pyc
+*.pyo
+*.pyd
+.DS_Store
+weights/
+results/
+inputs/cropped_faces/
+inputs/gray_faces/
+inputs/masked_faces/
+inputs/whole_imgs/
+output/
+web-demos/
diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..c980de128bec1cac424dc5e09b5884c21298978e 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,35 +1,5 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+weights/facelib/tmpz5esw78c filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..371bc0ca9a08c693fb52aa56c930ce2202f1763a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,131 @@
+.vscode
+
+# ignored files
+version.py
+
+# ignored files with suffix
+*.html
+ *.png
+ *.jpeg
+ *.jpg
+*.pt
+*.gif
+*.pth
+*.dat
+*.zip
+
+# template
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# project
+results/
+experiments/
+tb_logger/
+run.sh
+*debug*
+*_old*
+
diff --git a/API_DOCUMENTATION.md b/API_DOCUMENTATION.md
new file mode 100644
index 0000000000000000000000000000000000000000..c2cbb315444b392b63c14d6747eb0c60f1a316c1
--- /dev/null
+++ b/API_DOCUMENTATION.md
@@ -0,0 +1,118 @@
+# CodeFormer API Documentation
+
+This document describes the programmatic interface for the CodeFormer Face Restoration service.
+
+## Base URL
+The API is accessible at:
+`https://esmailx50-job.hf.space` (or your specific Hugging Face Space URL)
+
+---
+
+## 1. Process Images
+Processes one or more images for face restoration and enhancement.
+
+- **Endpoint:** `/api/process`
+- **Method:** `POST`
+- **Consumes:** `multipart/form-data` OR `application/json`
+
+### Parameters
+| Parameter | Type | Default | Description |
+| :--- | :--- | :--- | :--- |
+| `fidelity` | float | `0.5` | Fidelity weight ($w$). Range [0, 1]. Lower gives more "hallucinated" detail; higher preserves more of the original identity. |
+| `upscale` | int | `2` | Final upscaling factor. Supported: `1`, `2`, `4`. |
+| `background_enhance` | bool | `false` | Enhance the background using Real-ESRGAN. |
+| `face_upsample` | bool | `false` | Upsample restored faces using Real-ESRGAN. |
+| `return_base64` | bool | `false` | If true, includes the processed image as a base64 string in the JSON response. |
+
+### Input Formats
+
+#### A. Multipart Form Data (`multipart/form-data`)
+Useful for uploading files directly.
+- `image`: One or more image files (as a list).
+- Other parameters as form fields.
+
+**Example (curl):**
+```bash
+curl -X POST \
+  -F "image=@my_photo.jpg" \
+  -F "fidelity=0.7" \
+  -F "background_enhance=true" \
+  https://esmailx50-job.hf.space/api/process
+```
+
+#### B. JSON (`application/json`)
+Useful for sending base64-encoded image data.
+- `image_base64`: A single base64 string (with or without data URI prefix).
+- `images_base64`: (Optional) A list of base64 strings for batch processing.
+- Other parameters as JSON keys.
+
+**Example (curl):**
+```bash
+curl -X POST \
+  -H "Content-Type: application/json" \
+  -d '{
+        "image_base64": "data:image/png;base64,iVBORw0KG...",
+        "fidelity": 0.5,
+        "return_base64": true
+      }' \
+  https://esmailx50-job.hf.space/api/process
+```
+
+### Success Response
+```json
+{
+    "status": "success",
+    "count": 1,
+    "results": [
+        {
+            "original_name": "image.png",
+            "filename": "api_result_uuid.png",
+            "image_url": "https://.../static/results/api_result_uuid.png",
+            "image_base64": "iVBORw0KG..." // Only if return_base64 was true
+        }
+    ]
+}
```
+
+### Error Response
+```json
+{
+    "status": "error",
+    "message": "Detailed error message here"
+}
+```
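+
+**Example (Python):** a minimal client sketch, assuming the third-party `requests` package (not bundled with this project; any HTTP client works):
+```python
+import requests
+
+BASE_URL = "https://esmailx50-job.hf.space"  # replace with your Space URL
+
+# Multipart upload of a local file
+with open("my_photo.jpg", "rb") as f:
+    resp = requests.post(
+        f"{BASE_URL}/api/process",
+        files={"image": f},
+        data={"fidelity": "0.7", "background_enhance": "true"},
+    )
+resp.raise_for_status()
+for item in resp.json()["results"]:
+    print(item["image_url"])
+```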
+
+---
+
+## 2. Health Check
+Checks if the service is online and returns the compute device being used.
+
+- **Endpoint:** `/api/health`
+- **Method:** `GET`
+
+**Success Response:**
+```json
+{
+    "status": "online",
+    "device": "cuda" // or "cpu"
+}
+```
+
+---
+
+## CORS & Integration
+Cross-Origin Resource Sharing (CORS) is enabled for all routes. This allows you to call the API directly from browser-based applications (React, Vue, etc.) without hitting "Same-Origin Policy" blocks.
+
+**JavaScript Example (Fetch):**
+```javascript
+const formData = new FormData();
+formData.append('image', fileInput.files[0]);
+formData.append('fidelity', '0.5');
+
+const response = await fetch('https://esmailx50-job.hf.space/api/process', {
+    method: 'POST',
+    body: formData
+});
+const data = await response.json();
+console.log(data.results[0].image_url);
+```
diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md
new file mode 100644
index 0000000000000000000000000000000000000000..eae49b15c95f273a8d9a28687dc77d9398840d24
--- /dev/null
+++ b/DOCUMENTATION.md
@@ -0,0 +1,158 @@
+# CodeFormer Face Restoration - Project Documentation
+
+## 1. Introduction
+
+**CodeFormer** is a robust blind face restoration algorithm designed to restore old, degraded, or AI-generated face images. It utilizes a **Codebook Lookup Transformer** (VQGAN-based) to predict high-quality facial features even from severe degradation, ensuring that the restored faces look natural and faithful to the original identity.
+
+This project wraps the core CodeFormer research code into a deployable, user-friendly **Flask Web Application**, containerized with **Docker** for easy deployment on platforms like Hugging Face Spaces.
+
+### Key Features
+* **Blind Face Restoration:** Restores faces from low-quality inputs without knowing the specific degradation details.
+* **Background Enhancement:** Uses **Real-ESRGAN** to upscale and enhance the non-face background regions of the image.
+* **Face Alignment & Paste-back:** Automatically detects faces, aligns them for processing, and seamlessly blends them back into the original image.
+* **Adjustable Fidelity:** Users can balance between restoration quality (hallucinating details) and identity fidelity (keeping the original look).
+
+---
+
+## 2. System Architecture
+
+The application is built on a Python/PyTorch backend served via Flask.
+
+### 2.1 Technology Stack
+* **Framework:** Flask (Python Web Server)
+* **Deep Learning:** PyTorch, TorchVision
+* **Image Processing:** OpenCV, NumPy, Pillow
+* **Core Libraries:** `basicsr` (Basic Super-Restoration), `facelib` (Face detection/utils)
+* **Frontend:** HTML5, Bootstrap 5, Jinja2 Templates
+* **Containerization:** Docker (CUDA-enabled)
+
+### 2.2 Directory Structure
+```
+CodeFormer/
+├── app.py              # Main Flask application entry point
+├── Dockerfile          # Container configuration
+├── requirements.txt    # Python dependencies
+├── basicsr/            # Core AI framework (Super-Resolution tools)
+├── facelib/            # Face detection and alignment utilities
+├── templates/          # HTML Frontend
+│   ├── index.html      # Upload interface
+│   └── result.html     # Results display
+├── static/             # Static assets (css, js, uploads)
+│   ├── uploads/        # Temporary storage for input images
+│   └── results/        # Temporary storage for processed output
+└── weights/            # Pre-trained model weights (downloaded on startup)
+    ├── CodeFormer/     # CodeFormer model (.pth)
+    ├── facelib/        # Detection (RetinaFace) and Parsing models
+    └── realesrgan/     # Background upscaler (Real-ESRGAN)
+```
+
+### 2.3 Logic Flow
+1. **Input:** User uploads an image via the Web UI.
+2. **Pre-processing (`app.py`):**
+    * Image is saved to `static/uploads`.
+    * Parameters (fidelity, upscale factor) are parsed.
+3. **Inference Pipeline:**
+    * **Detection:** `facelib` detects faces in the image using RetinaFace.
+    * **Alignment:** Faces are cropped and aligned to a standard 512x512 resolution.
+    * **Restoration:** The **CodeFormer** model processes the aligned faces.
+    * **Upscaling (Optional):** The background is upscaled using **Real-ESRGAN**.
+    * **Paste-back:** Restored faces are warped back to their original positions and blended.
+4. **Output:** The final image is saved to `static/results` and displayed to the user.
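+
+Condensed, the same flow in code (a simplified sketch of `process_image()` in `app.py`; error handling, the aligned-face branch, and the memory clamps are omitted, and variable names follow `app.py`):
+
+```python
+face_helper = FaceRestoreHelper(upscale, face_size=512, crop_ratio=(1, 1),
+                                det_model="retinaface_resnet50", use_parse=True, device=device)
+face_helper.read_image(img)                    # 1. load BGR image
+face_helper.get_face_landmarks_5(resize=640)   # 2. detect faces + 5 landmarks
+face_helper.align_warp_face()                  #    crop/align each face to 512x512
+for face in face_helper.cropped_faces:         # 3. restore each face with CodeFormer
+    t = img2tensor(face / 255.0, bgr2rgb=True, float32=True)
+    normalize(t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+    with torch.no_grad():
+        out = codeformer_net(t.unsqueeze(0).to(device), w=fidelity, adain=True)[0]
+    face_helper.add_restored_face(tensor2img(out, rgb2bgr=True, min_max=(-1, 1)).astype("uint8"))
+bg = upsampler.enhance(img, outscale=upscale)[0] if background_enhance else None  # optional
+face_helper.get_inverse_affine(None)
+result = face_helper.paste_faces_to_input_image(upsample_img=bg)  # 4. blend faces back
+```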
+
+---
+
+## 3. Installation & Deployment
+
+### 3.1 Docker Deployment (Recommended)
+The project is optimized for Docker.
+
+**Prerequisites:** Docker, NVIDIA GPU (optional, but recommended).
+
+1. **Build the Image:**
+   ```bash
+   docker build -t codeformer-app .
+   ```
+
+2. **Run the Container:**
+   ```bash
+   # Run on port 7860 (standard for HF Spaces)
+   docker run -it -p 7860:7860 codeformer-app
+   ```
+   *Note: To use a GPU, add the `--gpus all` flag to the run command.*
+
+### 3.2 Hugging Face Spaces Deployment
+This repository is configured for direct deployment to Hugging Face.
+
+1. Create a **Docker** Space on Hugging Face.
+2. Push this entire repository to the Space's Git remote.
+   ```bash
+   git remote add hf git@hf.co:spaces/USERNAME/SPACE_NAME
+   git push hf main
+   ```
+3. The Space will build (approx. 5-10 mins) and launch automatically.
+
+### 3.3 Local Development
+1. **Install Environment:**
+   ```bash
+   conda create -n codeformer python=3.8
+   conda activate codeformer
+   pip install -r requirements.txt
+   ```
+2. **Install basicsr:**
+   ```bash
+   python basicsr/setup.py install
+   ```
+3. **Run App:**
+   ```bash
+   python app.py
+   ```
+
+---
+
+## 4. User Guide (Web Interface)
+
+### 4.1 Interface Controls
+
+* **Input Image:** Supports standard formats (JPG, PNG, WEBP). Drag and drop is supported.
+* **Fidelity Weight (w):**
+    * **Range:** 0.0 to 1.0.
+    * **0.0 (Better Quality):** The model "hallucinates" more details. Results look very sharp and high-quality but may slightly alter the person's identity (look less like the original).
+    * **1.0 (Better Identity):** The model sticks strictly to the original features. Results are faithful to the original photo but might be blurrier or contain more artifacts.
+    * **Recommended:** 0.5 is a balanced default.
+* **Upscale Factor:**
+    * Scales the final output resolution (1x, 2x, or 4x).
+    * *Note: Higher scaling requires more VRAM.*
+* **Enhance Background:**
+    * If checked, runs Real-ESRGAN on the non-face areas.
+    * *Recommendation:* Keep checked for full-photo restoration. Uncheck if you only care about the face or are running on limited hardware.
+* **Upsample Face:**
+    * If checked, the restored face is also upsampled to match the background resolution.
+
+### 4.2 Viewing Results
+The result page features an interactive **Before/After Slider**. Drag the handle left and right to compare the original and restored images pixel by pixel.
+
+---
+
+## 5. Technical Details
+
+### 5.1 Model Weights
+The application automatically checks for and downloads the following weights to the `weights/` directory on startup:
+
+| Model | Path | Description |
+| :--- | :--- | :--- |
+| **CodeFormer** | `weights/CodeFormer/codeformer.pth` | Main restoration model. |
+| **RetinaFace** | `weights/facelib/detection_Resnet50_Final.pth` | Face detection. |
+| **ParseNet** | `weights/facelib/parsing_parsenet.pth` | Face parsing (segmentation). |
+| **Real-ESRGAN** | `weights/realesrgan/RealESRGAN_x2plus.pth` | Background upscaler (x2). |
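+
+The startup check-and-download logic is essentially the following (condensed from `download_weights()` in `app.py`; one weight shown, the others follow the same pattern):
+
+```python
+from basicsr.utils.download_util import load_file_from_url
+
+# Only download a weight file if it is not already present locally
+if not os.path.exists('weights/CodeFormer/codeformer.pth'):
+    load_file_from_url(url='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
+                       model_dir='weights/CodeFormer', progress=True, file_name=None)
+```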
+
+### 5.2 Performance Notes
+* **Memory:** The full pipeline (CodeFormer + Real-ESRGAN) requires significant RAM/VRAM. On CPU-only environments (like basic HF Spaces), processing a single image may take 30-60 seconds.
+* **Git LFS:** Image assets in this repository are tracked with Git LFS to keep the repo size manageable.
+
+---
+
+## 6. Credits & References
+
+* **Original Paper:** [Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022)](https://arxiv.org/abs/2206.11253)
+* **Authors:** Shangchen Zhou, Kelvin C.K. Chan, Chongyi Li, Chen Change Loy (S-Lab, Nanyang Technological University).
+* **Original Repository:** [sczhou/CodeFormer](https://github.com/sczhou/CodeFormer)
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..a7e5517030181d0e0af84a46072af70d7aaf6ce3
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,39 @@
+FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-devel
+
+WORKDIR /code
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    libgl1 \
+    libglib2.0-0 \
+    git \
+    ninja-build \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements
+COPY requirements.txt .
+
+# Install python dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy application code
+COPY . .
+
+# Create necessary directories and set permissions
+RUN mkdir -p weights inputs output static && \
+    chmod 777 weights inputs output static
+
+# Install basicsr (build extensions in-place)
+RUN python basicsr/setup.py build_ext --inplace
+
+# Create a non-root user and switch to it
+RUN useradd -m -u 1000 user
+USER user
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+WORKDIR /code
+
+EXPOSE 7860
+
+CMD ["python", "app.py"]
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..c76482a059180848c1ba0e0b66d1de3f5e4dd689
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,35 @@
+S-Lab License 1.0
+
+Copyright 2022 S-Lab
+
+Redistribution and use for non-commercial purpose in source and
+binary forms, with or without modification, are permitted provided
+that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in
+   the documentation and/or other materials provided with the
+   distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived
+   from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+In the event that redistribution and/or use for commercial purpose in
+source or binary forms, with or without modification is required,
+please contact the contributor(s) of the work.
\ No newline at end of file
diff --git a/README.md b/README.md
index e9c69748b094bea7812c8920c1a772838b3a82dc..10e24535cd8eade86c41d5ce3a96b900c6216f31 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,229 @@
----
-title: Codeformer
-emoji: 🦀
-colorFrom: purple
-colorTo: yellow
-sdk: docker
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
+title: CodeFormer
+emoji: 👤
+colorFrom: blue
+colorTo: purple
+sdk: docker
+app_file: app.py
+pinned: false
+---
+
+
+
+
+
+:star: If CodeFormer is helpful to your images or projects, please help star this repo. Thanks! :hugs:
+
+
+### Update
+- **2023.07.20**: Integrated to :panda_face: [OpenXLab](https://openxlab.org.cn/apps). Try out online demo! [](https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer)
+- **2023.04.19**: :whale: Training codes and config files are publicly available now.
+- **2023.04.09**: Add features of inpainting and colorization for cropped and aligned face images.
+- **2023.02.10**: Include `dlib` as a new face detector option; it produces more accurate face identity.
+- **2022.10.05**: Support video input `--input_path [YOUR_VIDEO.mp4]`. Try it to enhance your videos! :clapper:
+- **2022.09.14**: Integrated to :hugs: [Hugging Face](https://huggingface.co/spaces). Try out online demo! [](https://huggingface.co/spaces/sczhou/CodeFormer)
+- **2022.09.09**: Integrated to :rocket: [Replicate](https://replicate.com/explore). Try out online demo! [](https://replicate.com/sczhou/codeformer)
+- [**More**](docs/history_changelog.md)
+
+### TODO
+- [x] Add training code and config files
+- [x] Add checkpoint and script for face inpainting
+- [x] Add checkpoint and script for face colorization
+- [x] ~~Add background image enhancement~~
+
+#### :panda_face: Try Enhancing Old Photos / Fixing AI-arts
+Before/after comparisons on imgsli: [example 1](https://imgsli.com/MTI3NTE2) **|** [example 2](https://imgsli.com/MTI3NTE1) **|** [example 3](https://imgsli.com/MTI3NTIw)
+
+#### Face Restoration
+
+
+
+
+#### Face Color Enhancement and Restoration
+
+
+
+#### Face Inpainting
+
+
+
+
+
+### Dependencies and Installation
+
+- PyTorch >= 1.7.1
+- CUDA >= 10.1
+- Other required packages in `requirements.txt`
+```
+# git clone this repository
+git clone https://github.com/sczhou/CodeFormer
+cd CodeFormer
+
+# create new anaconda env
+conda create -n codeformer python=3.8 -y
+conda activate codeformer
+
+# install python dependencies
+pip3 install -r requirements.txt
+python basicsr/setup.py develop
+conda install -c conda-forge dlib  # only for face detection or cropping with dlib
+```
+
+
+### Quick Inference
+
+#### Download Pre-trained Models:
+Download the facelib and dlib pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0) | [Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder. You can manually download the pretrained models OR download by running the following command:
+```
+python scripts/download_pretrained_models.py facelib
+python scripts/download_pretrained_models.py dlib  # only for the dlib face detector
+```
+
+Download the CodeFormer pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0) | [Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder. You can manually download the pretrained models OR download by running the following command:
+```
+python scripts/download_pretrained_models.py CodeFormer
+```
+
+#### Prepare Testing Data:
+You can put the testing images in the `inputs/TestWhole` folder. If you would like to test on cropped and aligned faces, you can put them in the `inputs/cropped_faces` folder. You can get the cropped and aligned faces by running the following command:
+```
+# you may need to install dlib via: conda install -c conda-forge dlib
+python scripts/crop_align_face.py -i [input folder] -o [output folder]
+```
+
+
+#### Testing:
+[Note] If you want to compare with CodeFormer in your paper, please run the following command with `--has_aligned` (for cropped and aligned faces), as the whole-image command involves a face-background fusion step that may damage hair texture on the boundary and lead to an unfair comparison.
+
+Fidelity weight *w* lies in [0, 1]. Generally, a smaller *w* tends to produce a higher-quality result, while a larger *w* yields a higher-fidelity result. The results will be saved in the `results` folder.
+
+
+🧑🏻 Face Restoration (cropped and aligned face)
+```
+# For cropped and aligned faces (512x512)
+python inference_codeformer.py -w 0.5 --has_aligned --input_path [image folder]|[image path]
+```
+
+:framed_picture: Whole Image Enhancement
+```
+# For whole image
+# Add '--bg_upsampler realesrgan' to enhance the background regions with Real-ESRGAN
+# Add '--face_upsample' to further upsample the restored faces with Real-ESRGAN
+python inference_codeformer.py -w 0.7 --input_path [image folder]|[image path]
+```
+
+:clapper: Video Enhancement
+```
+# For Windows/Mac users, please install ffmpeg first
+conda install -c conda-forge ffmpeg
+```
+```
+# For video clips
+# Video path should end with '.mp4'|'.mov'|'.avi'
+python inference_codeformer.py --bg_upsampler realesrgan --face_upsample -w 1.0 --input_path [video path]
+```
+
+🌈 Face Colorization (cropped and aligned face)
+```
+# For cropped and aligned faces (512x512)
+# Colorize black and white or faded photo
+python inference_colorization.py --input_path [image folder]|[image path]
+```
+
+🎨 Face Inpainting (cropped and aligned face)
+```
+# For cropped and aligned faces (512x512)
+# Inputs can be masked with a white brush in an image editing app (e.g., Photoshop)
+# (check out the examples in inputs/masked_faces)
+python inference_inpainting.py --input_path [image folder]|[image path]
+```
+### Training:
+The training commands can be found in the documents: [English](docs/train.md) **|** [简体中文](docs/train_CN.md).
+
+### License
+
+This project is licensed under NTU S-Lab License 1.0. Redistribution and use should follow this license.
+
+---
+### 🐼 Ecosystem Applications & Deployments
+
+CodeFormer has been widely adopted and deployed across a broad range (>20) of online applications, platforms, API services, and independent websites, and has also been integrated into many open-source projects and toolkits.
+
+> Only demos on **Hugging Face Space**, **Replicate**, and **OpenXLab** are official deployments **maintained by the authors**. All other demos, APIs, apps, websites, and integrations listed below are **third-party (non-official)** and are not affiliated with the CodeFormer authors. Please verify their legitimacy to avoid potential financial loss.
+
+
+#### Websites (Non-official)
+
+⚠️⚠️⚠️ The following websites are **not official and are not operated by us**. They use our models without any license or authorization. Please verify their legitimacy to avoid potential financial loss.
+
+
+| Website | Link | Notes |
+|---------|------|--------|
+| CodeFormer.net | https://codeformer.net/ | Non-official website |
+| CodeFormer.cn | https://www.codeformer.cn/ | Non-official website |
+| CodeFormerAI.com | https://codeformerai.com/ | Non-official website |
+
+#### Online Demos / API Platforms
+
+| Platform | Link | Notes |
+|----------|------|--------|
+| Hugging Face | https://huggingface.co/spaces/sczhou/CodeFormer | Maintained by Authors |
+| Replicate | https://replicate.com/sczhou/codeformer | Maintained by Authors |
+| OpenXLab | https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer | Maintained by Authors |
+| Segmind | https://www.segmind.com/models/codeformer | Non-official |
+| Sieve | https://www.sievedata.com/functions/sieve/codeformer | Non-official |
+| Fal.ai | https://fal.ai/models/fal-ai/codeformer | Non-official |
+| VaikerAI | https://vaikerai.com/sczhou/codeformer | Non-official |
+| Scade.pro | https://www.scade.pro/processors/lucataco-codeformer | Non-official |
+| Grandline | https://www.grandline.ai/model/codeformer | Non-official |
+| AI Demos | https://aidemos.com/tools/codeformer | Non-official |
+| Synexa | https://synexa.ai/explore/sczhou/codeformer | Non-official |
+| RentPrompts | https://rentprompts.ai/models/Codeformer | Non-official |
+| ElevaticsAI | https://elevatics.ai/models/super-resolution/codeformer | Non-official |
+| Anakin.ai | https://anakin.ai/apps/codeformer-online-face-restoration-by-codeformer-19343 | Non-official |
+| Relayto | https://relayto.com/explore/codeformer-yf9rj8kwc7zsr | Non-official |
+
+
+#### Open-Source Projects & Toolkits
+
+| Project / Toolkit | Link | Notes |
+|-------------------|------|--------|
+| Stable Diffusion GUI | https://nmkd.itch.io/t2i-gui | Integration |
+| Stable Diffusion WebUI | https://github.com/AUTOMATIC1111/stable-diffusion-webui | Integration |
+| ChaiNNer | https://github.com/chaiNNer-org/chaiNNer | Integration |
+| PyPI | https://pypi.org/project/codeformer/ ; https://pypi.org/project/codeformer-pip/ | Python packages |
+| ComfyUI | https://stable-diffusion-art.com/codeformer/ | Integration |
+
+---
+### Acknowledgement
+
+This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Some codes are brought from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks for their awesome works.
+
+### Citation
+If our work is useful for your research, please consider citing:
+
+ @inproceedings{zhou2022codeformer,
+ author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
+ title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
+ booktitle = {NeurIPS},
+ year = {2022}
+ }
+
+
+### Contact
+If you have any questions, please feel free to reach out to me at `shangchenzhou@gmail.com`.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..f37c56ff5e2a0424b47f6d3c9433520eb13b91f4
--- /dev/null
+++ b/app.py
@@ -0,0 +1,358 @@
+"""
+CodeFormer Flask Application
+Deployment on Hugging Face Spaces
+"""
+
+import os
+import cv2
+import torch
+import uuid
+import numpy as np
+import zipfile
+import base64
+from flask import Flask, render_template, request, send_file, url_for, jsonify, send_from_directory
+from flask_cors import CORS
+from werkzeug.utils import secure_filename
+
+from torchvision.transforms.functional import normalize
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.utils import imwrite, img2tensor, tensor2img
+from basicsr.utils.download_util import load_file_from_url
+from basicsr.utils.misc import gpu_is_available, get_device
+from basicsr.utils.realesrgan_utils import RealESRGANer
+from basicsr.utils.registry import ARCH_REGISTRY
+
+from facelib.utils.face_restoration_helper import FaceRestoreHelper
+from facelib.utils.misc import is_gray
+
+# --- Initialization ---
+app = Flask(__name__)
+CORS(app) # Enable CORS for all routes
+app.config['UPLOAD_FOLDER'] = 'static/uploads'
+app.config['RESULT_FOLDER'] = 'static/results'
+app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024 # 100MB limit
+
+# Ensure directories exist
+os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
+os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)
+os.makedirs('weights/CodeFormer', exist_ok=True)
+os.makedirs('weights/facelib', exist_ok=True)
+os.makedirs('weights/realesrgan', exist_ok=True)
+
+# Pretrained model URLs
+pretrain_model_url = {
+ 'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
+ 'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
+ 'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
+ 'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth'
+}
+
+def download_weights():
+ if not os.path.exists('weights/CodeFormer/codeformer.pth'):
+ load_file_from_url(url=pretrain_model_url['codeformer'], model_dir='weights/CodeFormer', progress=True, file_name=None)
+ if not os.path.exists('weights/facelib/detection_Resnet50_Final.pth'):
+ load_file_from_url(url=pretrain_model_url['detection'], model_dir='weights/facelib', progress=True, file_name=None)
+ if not os.path.exists('weights/facelib/parsing_parsenet.pth'):
+ load_file_from_url(url=pretrain_model_url['parsing'], model_dir='weights/facelib', progress=True, file_name=None)
+ if not os.path.exists('weights/realesrgan/RealESRGAN_x2plus.pth'):
+ load_file_from_url(url=pretrain_model_url['realesrgan'], model_dir='weights/realesrgan', progress=True, file_name=None)
+
+# Download weights on startup
+print("Checking weights...")
+download_weights()
+
+# Global models
+device = get_device()
+upsampler = None
+codeformer_net = None
+
+def init_models():
+ global upsampler, codeformer_net
+
+ # RealESRGAN
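+    # fp16 inference on GPU roughly halves VRAM use; CPU inference falls back to fp32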
+ half = True if gpu_is_available() else False
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+ upsampler = RealESRGANer(
+ scale=2,
+ model_path="weights/realesrgan/RealESRGAN_x2plus.pth",
+ model=model,
+ tile=400,
+ tile_pad=40,
+ pre_pad=0,
+ half=half,
+ )
+
+ # CodeFormer
+ codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
+ dim_embd=512,
+ codebook_size=1024,
+ n_head=8,
+ n_layers=9,
+ connect_list=["32", "64", "128", "256"],
+ ).to(device)
+
+ ckpt_path = "weights/CodeFormer/codeformer.pth"
+ checkpoint = torch.load(ckpt_path)["params_ema"]
+ codeformer_net.load_state_dict(checkpoint)
+ codeformer_net.eval()
+ print("Models loaded successfully.")
+
+init_models()
+
+def process_image(img_path, background_enhance, face_upsample, upscale, codeformer_fidelity):
+ """Core inference logic"""
+ try:
+ # Defaults
+ has_aligned = False
+ only_center_face = False
+ draw_box = False
+ detection_model = "retinaface_resnet50"
+
+        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+        if img is None:
+            print(f"Could not read input image: {img_path}")
+            return None
+
+ # Memory safety checks
+ upscale = int(upscale)
+ if upscale > 4: upscale = 4
+ if upscale > 2 and max(img.shape[:2]) > 1000: upscale = 2
+ if max(img.shape[:2]) > 1500:
+ upscale = 1
+ background_enhance = False
+ face_upsample = False
+
+ face_helper = FaceRestoreHelper(
+ upscale,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model=detection_model,
+ save_ext="png",
+ use_parse=True,
+ device=device,
+ )
+
+ bg_upsampler = upsampler if background_enhance else None
+ face_upsampler = upsampler if face_upsample else None
+
+ if has_aligned:
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+ face_helper.is_gray = is_gray(img, threshold=5)
+ face_helper.cropped_faces = [img]
+ else:
+ face_helper.read_image(img)
+ face_helper.get_face_landmarks_5(only_center_face=only_center_face, resize=640, eye_dist_threshold=5)
+ face_helper.align_warp_face()
+
+ # Face restoration
+ for idx, cropped_face in enumerate(face_helper.cropped_faces):
+ cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+
+ try:
+ with torch.no_grad():
+ output = codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0]
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+ except Exception as e:
+ print(f"Inference error: {e}")
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+
+ restored_face = restored_face.astype("uint8")
+ face_helper.add_restored_face(restored_face)
+
+ # Paste back
+ if not has_aligned:
+ bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] if bg_upsampler else None
+ face_helper.get_inverse_affine(None)
+
+ if face_upsample and face_upsampler:
+ restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box, face_upsampler=face_upsampler)
+ else:
+ restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box)
+ else:
+ restored_img = face_helper.restored_faces[0]
+
+ return restored_img
+
+ except Exception as e:
+ print(f"Global processing error: {e}")
+ return None
+
+# --- Routes ---
+
+@app.route('/', methods=['GET'])
+def index():
+ return render_template('index.html')
+
+@app.route('/process', methods=['POST'])
+def process():
+ if 'image' not in request.files:
+ return "No image uploaded", 400
+
+ files = request.files.getlist('image')
+ if not files or files[0].filename == '':
+ return "No selected file", 400
+
+ results = []
+
+ # Get params (same for all images)
+ try:
+ fidelity = float(request.form.get('fidelity', 0.5))
+ upscale = 4 # Enforce 4x upscale
+ background_enhance = 'background_enhance' in request.form
+ face_upsample = 'face_upsample' in request.form
+ except ValueError:
+ return "Invalid parameters", 400
+
+ for file in files:
+ if file.filename == '': continue
+
+ # Save input
+ filename = str(uuid.uuid4()) + "_" + secure_filename(file.filename)
+ input_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+ file.save(input_path)
+
+ # Process
+ result_img = process_image(input_path, background_enhance, face_upsample, upscale, fidelity)
+
+ if result_img is None:
+ continue # Skip failed images or handle error appropriately
+
+ # Save output
+ output_filename = "result_" + filename.rsplit('.', 1)[0] + ".png"
+ output_path = os.path.join(app.config['RESULT_FOLDER'], output_filename)
+ imwrite(result_img, output_path)
+
+ # Generate preview (max 1000px width/height)
+ preview_filename = "preview_" + output_filename
+ preview_path = os.path.join(app.config['RESULT_FOLDER'], preview_filename)
+
+ h, w = result_img.shape[:2]
+ if max(h, w) > 1000:
+ scale = 1000 / max(h, w)
+ preview_img = cv2.resize(result_img, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
+ imwrite(preview_img, preview_path)
+ else:
+ preview_filename = output_filename
+
+ results.append({
+ 'original': filename,
+ 'preview': preview_filename,
+ 'download': output_filename
+ })
+
+ if not results:
+ return "Processing failed for all images", 500
+
+ # Create ZIP of all results
+ zip_filename = f"batch_{uuid.uuid4()}.zip"
+ zip_path = os.path.join(app.config['RESULT_FOLDER'], zip_filename)
+
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
+ for item in results:
+ file_path = os.path.join(app.config['RESULT_FOLDER'], item['download'])
+ zipf.write(file_path, item['download'])
+
+ return render_template('result.html', results=results, zip_filename=zip_filename)
+
+# --- API Routes ---
+
+@app.route('/api/process', methods=['POST'])
+def api_process():
+ """
+ API endpoint for image processing.
+ Accepts:
+ - multipart/form-data with one or more 'image' files.
+ - application/json with 'image_base64' string (single image) or 'images_base64' list.
+ Parameters (form or JSON):
+ - fidelity: (float) 0-1, default 0.5.
+ - background_enhance: (bool) default False.
+ - face_upsample: (bool) default False.
+ - upscale: (int) 1-4, default 2.
+ - return_base64: (bool) default False.
+ """
+ try:
+ is_json = request.is_json
+ data = request.get_json() if is_json else request.form
+
+ fidelity = float(data.get('fidelity', 0.5))
+ background_enhance = (str(data.get('background_enhance', 'false')).lower() == 'true') if not is_json else data.get('background_enhance', False)
+ face_upsample = (str(data.get('face_upsample', 'false')).lower() == 'true') if not is_json else data.get('face_upsample', False)
+ upscale = int(data.get('upscale', 2))
+ return_base64 = (str(data.get('return_base64', 'false')).lower() == 'true') if not is_json else data.get('return_base64', False)
+
+ processed_images = []
+ inputs = []
+
+ # Handle JSON input
+ if is_json:
+ if 'image_base64' in data:
+ inputs.append({'data': data['image_base64'], 'name': 'image.png'})
+ elif 'images_base64' in data:
+ for idx, img_b64 in enumerate(data['images_base64']):
+ inputs.append({'data': img_b64, 'name': f'image_{idx}.png'})
+
+ for inp in inputs:
+ temp_filename = str(uuid.uuid4())
+ image_data = base64.b64decode(inp['data'].split(',')[-1])
+ input_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{temp_filename}.png")
+ with open(input_path, 'wb') as f:
+ f.write(image_data)
+ inp['path'] = input_path
+ inp['temp_id'] = temp_filename
+
+ # Handle Multipart input
+ elif 'image' in request.files:
+ files = request.files.getlist('image')
+ for file in files:
+ if file.filename != '':
+ temp_filename = str(uuid.uuid4())
+ filename = f"{temp_filename}_{secure_filename(file.filename)}"
+ input_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+ file.save(input_path)
+ inputs.append({'path': input_path, 'name': file.filename, 'temp_id': temp_filename})
+
+ if not inputs:
+ return jsonify({"status": "error", "message": "No images provided"}), 400
+
+ for inp in inputs:
+ # Process image
+ result_img = process_image(inp['path'], background_enhance, face_upsample, upscale, fidelity)
+ if result_img is not None:
+ # Save result
+ output_filename = f"api_result_{inp['temp_id']}.png"
+ output_path = os.path.join(app.config['RESULT_FOLDER'], output_filename)
+ imwrite(result_img, output_path)
+
+ res = {
+ "original_name": inp['name'],
+ "image_url": url_for('static', filename=f'results/{output_filename}', _external=True),
+ "filename": output_filename
+ }
+
+ if return_base64:
+ _, buffer = cv2.imencode('.png', result_img)
+ img_base64 = base64.b64encode(buffer).decode('utf-8')
+ res["image_base64"] = img_base64
+
+ processed_images.append(res)
+
+ if not processed_images:
+ return jsonify({"status": "error", "message": "Processing failed for all images"}), 500
+
+ return jsonify({
+ "status": "success",
+ "count": len(processed_images),
+ "results": processed_images
+ })
+
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ return jsonify({"status": "error", "message": str(e)}), 500
+
+@app.route('/api/health', methods=['GET'])
+def health_check():
+ return jsonify({"status": "online", "device": str(device)})
+
+if __name__ == '__main__':
+ # Docker/HF Spaces entry point
+ app.run(host='0.0.0.0', port=7860)
\ No newline at end of file
diff --git a/basicsr/VERSION b/basicsr/VERSION
new file mode 100644
index 0000000000000000000000000000000000000000..b85bccc7d7631d9d65de5514baac020cfbee6545
--- /dev/null
+++ b/basicsr/VERSION
@@ -0,0 +1 @@
+1.3.2
diff --git a/basicsr/__init__.py b/basicsr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a06af02c563a37ccd339a4427707bd928c266c9
--- /dev/null
+++ b/basicsr/__init__.py
@@ -0,0 +1,11 @@
+# https://github.com/xinntao/BasicSR
+# flake8: noqa
+from .archs import *
+from .data import *
+from .losses import *
+from .metrics import *
+from .models import *
+from .ops import *
+from .train import *
+from .utils import *
+from .version import __gitsha__, __version__
diff --git a/basicsr/archs/__init__.py b/basicsr/archs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcec89c0e9ef2ea698068573123df7f407e8f5c2
--- /dev/null
+++ b/basicsr/archs/__init__.py
@@ -0,0 +1,25 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import ARCH_REGISTRY
+
+__all__ = ['build_network']
+
+# automatically scan and import arch modules for registry
+# scan all the files under the 'archs' folder and collect files ending with
+# '_arch.py'
+arch_folder = osp.dirname(osp.abspath(__file__))
+arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
+# import all the arch modules
+_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
+
+
+def build_network(opt):
+ opt = deepcopy(opt)
+ network_type = opt.pop('type')
+ net = ARCH_REGISTRY.get(network_type)(**opt)
+ logger = get_root_logger()
+ logger.info(f'Network [{net.__class__.__name__}] is created.')
+ return net
diff --git a/basicsr/archs/arcface_arch.py b/basicsr/archs/arcface_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..91d511c3ffdfd436426ed872edbd373ace04ed0b
--- /dev/null
+++ b/basicsr/archs/arcface_arch.py
@@ -0,0 +1,245 @@
+import torch.nn as nn
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+def conv3x3(inplanes, outplanes, stride=1):
+ """A simple wrapper for 3x3 convolution with padding.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ outplanes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ """
+ return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ """Basic residual block used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+ """
+ expansion = 1 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class IRBlock(nn.Module):
+ """Improved residual block (IR Block) used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+ use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
+ """
+ expansion = 1 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
+ super(IRBlock, self).__init__()
+ self.bn0 = nn.BatchNorm2d(inplanes)
+ self.conv1 = conv3x3(inplanes, inplanes)
+ self.bn1 = nn.BatchNorm2d(inplanes)
+ self.prelu = nn.PReLU()
+ self.conv2 = conv3x3(inplanes, planes, stride)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+ self.use_se = use_se
+ if self.use_se:
+ self.se = SEBlock(planes)
+
+ def forward(self, x):
+ residual = x
+ out = self.bn0(x)
+ out = self.conv1(out)
+ out = self.bn1(out)
+ out = self.prelu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ if self.use_se:
+ out = self.se(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.prelu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ """Bottleneck block used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+ """
+ expansion = 4 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class SEBlock(nn.Module):
+ """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
+
+ Args:
+ channel (int): Channel number of inputs.
+        reduction (int): Channel reduction ratio. Default: 16.
+ """
+
+ def __init__(self, channel, reduction=16):
+ super(SEBlock, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
+ nn.Sigmoid())
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y
+
+
+@ARCH_REGISTRY.register()
+class ResNetArcFace(nn.Module):
+ """ArcFace with ResNet architectures.
+
+ Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
+
+ Args:
+ block (str): Block used in the ArcFace architecture.
+ layers (tuple(int)): Block numbers in each layer.
+ use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
+ """
+
+ def __init__(self, block, layers, use_se=True):
+ if block == 'IRBlock':
+ block = IRBlock
+ self.inplanes = 64
+ self.use_se = use_se
+ super(ResNetArcFace, self).__init__()
+
+ self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.prelu = nn.PReLU()
+ self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.bn4 = nn.BatchNorm2d(512)
+ self.dropout = nn.Dropout()
+ self.fc5 = nn.Linear(512 * 8 * 8, 512)
+ self.bn5 = nn.BatchNorm1d(512)
+
+ # initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.xavier_normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_normal_(m.weight)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, planes, num_blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
+ self.inplanes = planes
+ for _ in range(1, num_blocks):
+ layers.append(block(self.inplanes, planes, use_se=self.use_se))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.bn4(x)
+ x = self.dropout(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc5(x)
+ x = self.bn5(x)
+
+ return x
\ No newline at end of file
diff --git a/basicsr/archs/arch_util.py b/basicsr/archs/arch_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5e44efca072048606b7b065c212ac8fa639f385
--- /dev/null
+++ b/basicsr/archs/arch_util.py
@@ -0,0 +1,318 @@
+import collections.abc
+import math
+import torch
+import torchvision
+import warnings
+from distutils.version import LooseVersion
+from itertools import repeat
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
+from basicsr.utils import get_root_logger
+
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+ """Initialize network weights.
+
+ Args:
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+ scale (float): Scale initialized weights, especially for residual
+ blocks. Default: 1.
+ bias_fill (float): The value to fill bias. Default: 0
+ kwargs (dict): Other arguments for initialization function.
+ """
+ if not isinstance(module_list, list):
+ module_list = [module_list]
+ for module in module_list:
+ for m in module.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, nn.Linear):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, _BatchNorm):
+ init.constant_(m.weight, 1)
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+ """Make layers by stacking the same blocks.
+
+ Args:
+ basic_block (nn.module): nn.module class for basic block.
+ num_basic_block (int): number of blocks.
+
+ Returns:
+ nn.Sequential: Stacked blocks in nn.Sequential.
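+
+    Example:
+        >>> # stack three default residual blocks with 64 feature channels
+        >>> body = make_layer(ResidualBlockNoBN, 3, num_feat=64)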
+ """
+ layers = []
+ for _ in range(num_basic_block):
+ layers.append(basic_block(**kwarg))
+ return nn.Sequential(*layers)
+
+
+class ResidualBlockNoBN(nn.Module):
+ """Residual block without BN.
+
+ It has a style of:
+ ---Conv-ReLU-Conv-+-
+ |________________|
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ Default: 64.
+ res_scale (float): Residual scale. Default: 1.
+ pytorch_init (bool): If set to True, use pytorch default init,
+ otherwise, use default_init_weights. Default: False.
+ """
+
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+ super(ResidualBlockNoBN, self).__init__()
+ self.res_scale = res_scale
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+ self.relu = nn.ReLU(inplace=True)
+
+ if not pytorch_init:
+ default_init_weights([self.conv1, self.conv2], 0.1)
+
+ def forward(self, x):
+ identity = x
+ out = self.conv2(self.relu(self.conv1(x)))
+ return identity + out * self.res_scale
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
+
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
+ """Warp an image or feature map with optical flow.
+
+ Args:
+ x (Tensor): Tensor with size (n, c, h, w).
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
+ Default: 'zeros'.
+ align_corners (bool): Before pytorch 1.3, the default value is
+ align_corners=True. After pytorch 1.3, the default value is
+            align_corners=False. Here, we use True as the default.
+
+ Returns:
+ Tensor: Warped image or feature map.
+ """
+ assert x.size()[-2:] == flow.size()[1:3]
+ _, _, h, w = x.size()
+ # create mesh grid
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
+ grid.requires_grad = False
+
+ vgrid = grid + flow
+ # scale grid to [-1,1]
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+ # TODO, what if align_corners=False
+ return output
+
+
+def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
+ """Resize a flow according to ratio or shape.
+
+ Args:
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+ size_type (str): 'ratio' or 'shape'.
+ sizes (list[int | float]): the ratio for resizing or the final output
+ shape.
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
+ ratio > 1.0).
+ 2) The order of output_size should be [out_h, out_w].
+ interp_mode (str): The mode of interpolation for resizing.
+ Default: 'bilinear'.
+ align_corners (bool): Whether align corners. Default: False.
+
+ Returns:
+ Tensor: Resized flow.
+ """
+ _, _, flow_h, flow_w = flow.size()
+ if size_type == 'ratio':
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+ elif size_type == 'shape':
+ output_h, output_w = sizes[0], sizes[1]
+ else:
+ raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
+
+ input_flow = flow.clone()
+ ratio_h = output_h / flow_h
+ ratio_w = output_w / flow_w
+ input_flow[:, 0, :, :] *= ratio_w
+ input_flow[:, 1, :, :] *= ratio_h
+ resized_flow = F.interpolate(
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
+ return resized_flow
+
+
+# TODO: may write a cpp file
+def pixel_unshuffle(x, scale):
+ """ Pixel unshuffle.
+
+ Args:
+ x (Tensor): Input feature with shape (b, c, hh, hw).
+ scale (int): Downsample ratio.
+
+ Returns:
+ Tensor: the pixel unshuffled feature.
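+
+    Example:
+        >>> x = torch.rand(1, 3, 8, 8)
+        >>> pixel_unshuffle(x, scale=2).shape   # channels x4, spatial /2
+        torch.Size([1, 12, 4, 4])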
+ """
+ b, c, hh, hw = x.size()
+ out_channel = c * (scale**2)
+ assert hh % scale == 0 and hw % scale == 0
+ h = hh // scale
+ w = hw // scale
+ x_view = x.view(b, c, h, scale, w, scale)
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+class DCNv2Pack(ModulatedDeformConvPack):
+ """Modulated deformable conv for deformable alignment.
+
+ Different from the official DCNv2Pack, which generates offsets and masks
+ from the preceding features, this DCNv2Pack takes another different
+ features to generate offsets and masks.
+
+ Ref:
+ Delving Deep into Deformable Alignment in Video Super-Resolution.
+ """
+
+ def forward(self, x, feat):
+ out = self.conv_offset(feat)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+
+ offset_absmean = torch.mean(torch.abs(offset))
+ if offset_absmean > 50:
+ logger = get_root_logger()
+ logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
+
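+ # torchvision >= 0.9.0 ships ops.deform_conv2d with mask support, so it is
+ # preferred; otherwise we fall back to the bundled modulated_deform_conv.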
+ if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, mask)
+ else:
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups, self.deformable_groups)
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ low = norm_cdf((a - mean) / std)
+ up = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [low, up], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution.
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+
+ The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+# From PyTorch
+def _ntuple(n):
+
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
\ No newline at end of file
diff --git a/basicsr/archs/codeformer_arch.py b/basicsr/archs/codeformer_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a91cb1e87d2e6d944de6ebcfbaecf6c55a38b0c
--- /dev/null
+++ b/basicsr/archs/codeformer_arch.py
@@ -0,0 +1,280 @@
+import math
+import numpy as np
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+from typing import Optional, List
+
+from basicsr.archs.vqgan_arch import *
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import ARCH_REGISTRY
+
+def calc_mean_std(feat, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ b, c = size[:2]
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+ return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+ """Adaptive instance normalization.
+
+ Adjust the reference features to have color and illumination similar to
+ those of the degraded features.
+
+ Args:
+ content_feat (Tensor): The reference feature.
+ style_feat (Tensor): The degraded features.
+ """
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
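+ # AdaIN in one line: out = sigma_style * (x - mu_content) / sigma_content + mu_style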
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask=None):
+ if mask is None:
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
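+ # pos has shape (b, 2 * num_pos_feats, h, w): channel-wise concatenation of
+ # sine/cosine embeddings of the y and x positions.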
+ return pos
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+
+class TransformerSALayer(nn.Module):
+ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
+ # Implementation of Feedforward model - MLP
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
+
+ self.norm1 = nn.LayerNorm(embed_dim)
+ self.norm2 = nn.LayerNorm(embed_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward(self, tgt,
+ tgt_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+
+ # self attention
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+
+ # ffn
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout2(tgt2)
+ return tgt
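+
+# TransformerSALayer is a pre-norm block (LayerNorm before attention/FFN, with
+# residuals added after). A hedged usage sketch, using the (seq, batch, dim)
+# layout that nn.MultiheadAttention expects by default:
+#   layer = TransformerSALayer(embed_dim=512)
+#   tokens = torch.randn(256, 1, 512)  # e.g. 16x16 latent tokens, batch of 1
+#   tokens = layer(tokens)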
+
+class Fuse_sft_block(nn.Module):
+ def __init__(self, in_ch, out_ch):
+ super().__init__()
+ self.encode_enc = ResBlock(2*in_ch, out_ch)
+
+ self.scale = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+ self.shift = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+ def forward(self, enc_feat, dec_feat, w=1):
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
+ scale = self.scale(enc_feat)
+ shift = self.shift(enc_feat)
+ residual = w * (dec_feat * scale + shift)
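+ # i.e. out = dec_feat + w * (dec_feat * scale + shift); w=0 bypasses the skip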
+ out = dec_feat + residual
+ return out
+
+
+@ARCH_REGISTRY.register()
+class CodeFormer(VQAutoEncoder):
+ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
+ codebook_size=1024, latent_size=256,
+ connect_list=['32', '64', '128', '256'],
+ fix_modules=['quantize','generator'], vqgan_path=None):
+ super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], codebook_size)
+
+ if vqgan_path is not None:
+ self.load_state_dict(
+ torch.load(vqgan_path, map_location='cpu')['params_ema'])
+
+ if fix_modules is not None:
+ for module in fix_modules:
+ for param in getattr(self, module).parameters():
+ param.requires_grad = False
+
+ self.connect_list = connect_list
+ self.n_layers = n_layers
+ self.dim_embd = dim_embd
+ self.dim_mlp = dim_embd*2
+
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
+ self.feat_emb = nn.Linear(256, self.dim_embd)
+
+ # transformer
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
+ for _ in range(self.n_layers)])
+
+ # logits_predict head
+ self.idx_pred_layer = nn.Sequential(
+ nn.LayerNorm(dim_embd),
+ nn.Linear(dim_embd, codebook_size, bias=False))
+
+ self.channels = {
+ '16': 512,
+ '32': 256,
+ '64': 256,
+ '128': 128,
+ '256': 128,
+ '512': 64,
+ }
+
+ # after second residual block for > 16, before attn layer for ==16
+ self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
+ # after first residual block for > 16, before attn layer for ==16
+ self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
+
+ # fuse_convs_dict
+ self.fuse_convs_dict = nn.ModuleDict()
+ for f_size in self.connect_list:
+ in_ch = self.channels[f_size]
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
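+ # Hedged usage sketch (512x512 training setting assumed):
+ #   out, logits, lq_feat = CodeFormer()(img_512, w=0.5)
+ # w in [0, 1] trades encoder-feature fidelity (high w) against codebook
+ # quality (low w).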
+ # ################### Encoder #####################
+ enc_feat_dict = {}
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
+ for i, block in enumerate(self.encoder.blocks):
+ x = block(x)
+ if i in out_list:
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
+
+ lq_feat = x
+ # ################# Transformer ###################
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
+ # BCHW -> BC(HW) -> (HW)BC
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
+ query_emb = feat_emb
+ # Transformer encoder
+ for layer in self.ft_layers:
+ query_emb = layer(query_emb, query_pos=pos_emb)
+
+ # output logits
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
+ logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
+
+ if code_only: # for training stage II
+ # logits don't need softmax before the cross-entropy loss
+ return logits, lq_feat
+
+ # ################# Quantization ###################
+ # if self.training:
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
+ # # b(hw)c -> bc(hw) -> bchw
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
+ # ------------
+ soft_one_hot = F.softmax(logits, dim=2)
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
+ quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
+ # preserve gradients
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
+
+ if detach_16:
+ quant_feat = quant_feat.detach() # for training stage III
+ if adain:
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
+
+ # ################## Generator ####################
+ x = quant_feat
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
+
+ for i, block in enumerate(self.generator.blocks):
+ x = block(x)
+ if i in fuse_list: # fuse after i-th block
+ f_size = str(x.shape[-1])
+ if w>0:
+ x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
+ out = x
+ # logits don't need softmax before the cross-entropy loss
+ return out, logits, lq_feat
\ No newline at end of file
diff --git a/basicsr/archs/rrdbnet_arch.py b/basicsr/archs/rrdbnet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..93007293a936eb4fc11074244fc8144cf8e2c641
--- /dev/null
+++ b/basicsr/archs/rrdbnet_arch.py
@@ -0,0 +1,119 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import default_init_weights, make_layer, pixel_unshuffle
+
+
+class ResidualDenseBlock(nn.Module):
+ """Residual Dense Block.
+
+ Used in RRDB block in ESRGAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_grow_ch (int): Channels for each growth.
+ """
+
+ def __init__(self, num_feat=64, num_grow_ch=32):
+ super(ResidualDenseBlock, self).__init__()
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ # initialization
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+ def forward(self, x):
+ x1 = self.lrelu(self.conv1(x))
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ # Empirically, we use 0.2 to scale the residual for better performance
+ return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+ """Residual in Residual Dense Block.
+
+ Used in RRDB-Net in ESRGAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_grow_ch (int): Channels for each growth.
+ """
+
+ def __init__(self, num_feat, num_grow_ch=32):
+ super(RRDB, self).__init__()
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+ def forward(self, x):
+ out = self.rdb1(x)
+ out = self.rdb2(out)
+ out = self.rdb3(out)
+ # Empirically, we use 0.2 to scale the residual for better performance
+ return out * 0.2 + x
+
+
+@ARCH_REGISTRY.register()
+class RRDBNet(nn.Module):
+ """Networks consisting of Residual in Residual Dense Block, which is used
+ in ESRGAN.
+
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+ We extend ESRGAN for scale x2 and scale x1.
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
+ We first employ pixel-unshuffle (an inverse operation of pixel-shuffle) to
+ reduce the spatial size and enlarge the channel size before feeding inputs
+ into the main ESRGAN architecture.
+
+ Args:
+ num_in_ch (int): Channel number of inputs.
+ num_out_ch (int): Channel number of outputs.
+ num_feat (int): Channel number of intermediate features.
+ Default: 64.
+ num_block (int): Block number in the trunk network. Default: 23.
+ num_grow_ch (int): Channels for each growth. Default: 32.
+ """
+
+ def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+ super(RRDBNet, self).__init__()
+ self.scale = scale
+ if scale == 2:
+ num_in_ch = num_in_ch * 4
+ elif scale == 1:
+ num_in_ch = num_in_ch * 16
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ # upsample
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ if self.scale == 2:
+ feat = pixel_unshuffle(x, scale=2)
+ elif self.scale == 1:
+ feat = pixel_unshuffle(x, scale=4)
+ else:
+ feat = x
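+ # pixel-unshuffling first keeps the two fixed x2 upsamples below consistent
+ # across scales; e.g. scale=2: (1,3,128,128) -> (1,12,64,64) -> (1,3,256,256)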
+ feat = self.conv_first(feat)
+ body_feat = self.conv_body(self.body(feat))
+ feat = feat + body_feat
+ # upsample
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+ return out
\ No newline at end of file
diff --git a/basicsr/archs/vgg_arch.py b/basicsr/archs/vgg_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..69c84bc2e63893388243ce631ca79c5eadf69d24
--- /dev/null
+++ b/basicsr/archs/vgg_arch.py
@@ -0,0 +1,161 @@
+import os
+import torch
+from collections import OrderedDict
+from torch import nn as nn
+from torchvision.models import vgg as vgg
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
+NAMES = {
+ 'vgg11': [
+ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
+ 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
+ 'pool5'
+ ],
+ 'vgg13': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
+ ],
+ 'vgg16': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
+ 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
+ 'pool5'
+ ],
+ 'vgg19': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
+ 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
+ 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
+ ]
+}
+
+
+def insert_bn(names):
+ """Insert bn layer after each conv.
+
+ Args:
+ names (list): The list of layer names.
+
+ Returns:
+ list: The list of layer names with bn layers.
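+ Example: insert_bn(['conv1_1', 'relu1_1']) -> ['conv1_1', 'bn1_1', 'relu1_1'].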
+ """
+ names_bn = []
+ for name in names:
+ names_bn.append(name)
+ if 'conv' in name:
+ position = name.replace('conv', '')
+ names_bn.append('bn' + position)
+ return names_bn
+
+
+@ARCH_REGISTRY.register()
+class VGGFeatureExtractor(nn.Module):
+ """VGG network for feature extraction.
+
+ In this implementation, we allow users to choose whether to normalize the
+ input feature and which type of vgg network to use. Note that the pretrained
+ path must fit the vgg type.
+
+ Args:
+ layer_name_list (list[str]): Forward function returns the corresponding
+ features according to the layer_name_list.
+ Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image. Importantly,
+ the input feature must be in the range [0, 1]. Default: True.
+ range_norm (bool): If True, normalize images from range [-1, 1] to [0, 1].
+ Default: False.
+ requires_grad (bool): If true, the parameters of VGG network will be
+ optimized. Default: False.
+ remove_pooling (bool): If true, the max pooling operations in VGG net
+ will be removed. Default: False.
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
+ """
+
+ def __init__(self,
+ layer_name_list,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ requires_grad=False,
+ remove_pooling=False,
+ pooling_stride=2):
+ super(VGGFeatureExtractor, self).__init__()
+
+ self.layer_name_list = layer_name_list
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ self.names = NAMES[vgg_type.replace('_bn', '')]
+ if 'bn' in vgg_type:
+ self.names = insert_bn(self.names)
+
+ # only borrow layers that will be used to avoid unused params
+ max_idx = 0
+ for v in layer_name_list:
+ idx = self.names.index(v)
+ if idx > max_idx:
+ max_idx = idx
+
+ if os.path.exists(VGG_PRETRAIN_PATH):
+ vgg_net = getattr(vgg, vgg_type)(pretrained=False)
+ state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
+ vgg_net.load_state_dict(state_dict)
+ else:
+ vgg_net = getattr(vgg, vgg_type)(pretrained=True)
+
+ features = vgg_net.features[:max_idx + 1]
+
+ modified_net = OrderedDict()
+ for k, v in zip(self.names, features):
+ if 'pool' in k:
+ # if remove_pooling is true, pooling operation will be removed
+ if remove_pooling:
+ continue
+ else:
+ # in some cases, we may want to change the default stride
+ modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
+ else:
+ modified_net[k] = v
+
+ self.vgg_net = nn.Sequential(modified_net)
+
+ if not requires_grad:
+ self.vgg_net.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+ else:
+ self.vgg_net.train()
+ for param in self.parameters():
+ param.requires_grad = True
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, x):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ if self.range_norm:
+ x = (x + 1) / 2
+ if self.use_input_norm:
+ x = (x - self.mean) / self.std
+ output = {}
+
+ for key, layer in self.vgg_net._modules.items():
+ x = layer(x)
+ if key in self.layer_name_list:
+ output[key] = x.clone()
+
+ return output
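+
+# A hedged usage sketch (the layer names must exist in the chosen vgg type):
+#   extractor = VGGFeatureExtractor(['relu1_1', 'relu3_1'], vgg_type='vgg19')
+#   feats = extractor(torch.rand(1, 3, 224, 224))  # dict of layer name -> Tensor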
diff --git a/basicsr/archs/vqgan_arch.py b/basicsr/archs/vqgan_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..283a77739caedd25a754e6cf4632a757ce6bd7f3
--- /dev/null
+++ b/basicsr/archs/vqgan_arch.py
@@ -0,0 +1,434 @@
+'''
+VQGAN code, adapted from the original created by the Unleashing Transformers authors:
+https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
+
+'''
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import ARCH_REGISTRY
+
+def normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+@torch.jit.script
+def swish(x):
+ return x*torch.sigmoid(x)
+
+
+# Define VQVAE classes
+class VectorQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.emb_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
+ 2 * torch.matmul(z_flattened, self.embedding.weight.t())
+
+ mean_distance = torch.mean(d)
+ # find closest encodings
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+ # min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
+ # [0-1], higher score, higher confidence
+ # min_encoding_scores = torch.exp(-min_encoding_scores/10)
+
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, {
+ "perplexity": perplexity,
+ "min_encodings": min_encodings,
+ "min_encoding_indices": min_encoding_indices,
+ "mean_distance": mean_distance
+ }
+
+ def get_codebook_feat(self, indices, shape):
+ # input indices: batch*token_num -> (batch*token_num)*1
+ # shape: batch, height, width, channel
+ indices = indices.view(-1,1)
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
+ min_encodings.scatter_(1, indices, 1)
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None: # reshape back to match original input shape
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
+ super().__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+ self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
+ self.embed = nn.Embedding(codebook_size, emb_dim)
+
+ def forward(self, z):
+ hard = self.straight_through if self.training else True
+
+ logits = self.proj(z)
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
+
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
+
+ return z_q, diff, {
+ "min_encoding_indices": min_encoding_indices
+ }
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = self.conv(x)
+
+ return x
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channels, out_channels=None):
+ super(ResBlock, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.norm1 = normalize(in_channels)
+ # use self.out_channels so the out_channels=None default works as intended
+ self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = normalize(self.out_channels)
+ self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ self.conv_out = nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x_in):
+ x = x_in
+ x = self.norm1(x)
+ x = swish(x)
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = swish(x)
+ x = self.conv2(x)
+ if self.in_channels != self.out_channels:
+ x_in = self.conv_out(x_in)
+
+ return x + x_in
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h*w)
+ q = q.permute(0, 2, 1)
+ k = k.reshape(b, c, h*w)
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = F.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h*w)
+ w_ = w_.permute(0, 2, 1)
+ h_ = torch.bmm(v, w_)
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.attn_resolutions = attn_resolutions
+
+ curr_res = self.resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+
+ blocks = []
+ # initial convolution
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
+
+ # residual and downsampling blocks, with attention on smaller res (16x16)
+ for i in range(self.num_resolutions):
+ block_in_ch = nf * in_ch_mult[i]
+ block_out_ch = nf * ch_mult[i]
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+ if curr_res in attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != self.num_resolutions - 1:
+ blocks.append(Downsample(block_in_ch))
+ curr_res = curr_res // 2
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ # normalise and convert to latent size
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+class Generator(nn.Module):
+ def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.ch_mult = ch_mult
+ self.num_resolutions = len(self.ch_mult)
+ self.num_res_blocks = res_blocks
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.in_channels = emb_dim
+ self.out_channels = 3
+ block_in_ch = self.nf * self.ch_mult[-1]
+ curr_res = self.resolution // 2 ** (self.num_resolutions-1)
+
+ blocks = []
+ # initial conv
+ blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ for i in reversed(range(self.num_resolutions)):
+ block_out_ch = self.nf * self.ch_mult[i]
+
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+
+ if curr_res in self.attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != 0:
+ blocks.append(Upsample(block_in_ch))
+ curr_res = curr_res * 2
+
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
+
+ self.blocks = nn.ModuleList(blocks)
+
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+@ARCH_REGISTRY.register()
+class VQAutoEncoder(nn.Module):
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
+ super().__init__()
+ logger = get_root_logger()
+ self.in_channels = 3
+ self.nf = nf
+ self.n_blocks = res_blocks
+ self.codebook_size = codebook_size
+ self.embed_dim = emb_dim
+ self.ch_mult = ch_mult
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.quantizer_type = quantizer
+ self.encoder = Encoder(
+ self.in_channels,
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+ if self.quantizer_type == "nearest":
+ self.beta = beta #0.25
+ self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
+ elif self.quantizer_type == "gumbel":
+ self.gumbel_num_hiddens = emb_dim
+ self.straight_through = gumbel_straight_through
+ self.kl_weight = gumbel_kl_weight
+ self.quantize = GumbelQuantizer(
+ self.codebook_size,
+ self.embed_dim,
+ self.gumbel_num_hiddens,
+ self.straight_through,
+ self.kl_weight
+ )
+ self.generator = Generator(
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+
+ if model_path is not None:
+ chkpt = torch.load(model_path, map_location='cpu')
+ if 'params_ema' in chkpt:
+ self.load_state_dict(chkpt['params_ema'])
+ logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
+ elif 'params' in chkpt:
+ self.load_state_dict(chkpt['params'])
+ logger.info(f'vqgan is loaded from: {model_path} [params]')
+ else:
+ raise ValueError('Wrong params!')
+
+
+ def forward(self, x):
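+ # Hedged usage sketch (512-face config): VQAutoEncoder(512, 64,
+ # [1, 2, 2, 4, 4, 8]) maps a (1, 3, 512, 512) image to a reconstruction of
+ # the same shape, plus the codebook loss and quantizer statistics.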
+ x = self.encoder(x)
+ quant, codebook_loss, quant_stats = self.quantize(x)
+ x = self.generator(quant)
+ return x, codebook_loss, quant_stats
+
+
+
+# patch based discriminator
+@ARCH_REGISTRY.register()
+class VQGANDiscriminator(nn.Module):
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
+ super().__init__()
+
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
+ ndf_mult = 1
+ ndf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2 ** n, 8)
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2 ** n_layers, 8)
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
+ self.main = nn.Sequential(*layers)
+
+ if model_path is not None:
+ chkpt = torch.load(model_path, map_location='cpu')
+ if 'params_d' in chkpt:
+ self.load_state_dict(chkpt['params_d'])
+ elif 'params' in chkpt:
+ self.load_state_dict(chkpt['params'])
+ else:
+ raise ValueError('Wrong params!')
+
+ def forward(self, x):
+ return self.main(x)
\ No newline at end of file
diff --git a/basicsr/data/__init__.py b/basicsr/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..39adf665d4218c05deca9f8e6981fd7ee42d8d9e
--- /dev/null
+++ b/basicsr/data/__init__.py
@@ -0,0 +1,100 @@
+import importlib
+import numpy as np
+import random
+import torch
+import torch.utils.data
+from copy import deepcopy
+from functools import partial
+from os import path as osp
+
+from basicsr.data.prefetch_dataloader import PrefetchDataLoader
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.dist_util import get_dist_info
+from basicsr.utils.registry import DATASET_REGISTRY
+
+__all__ = ['build_dataset', 'build_dataloader']
+
+# automatically scan and import dataset modules for registry
+# scan all the files under the data folder with '_dataset' in file names
+data_folder = osp.dirname(osp.abspath(__file__))
+dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
+# import all the dataset modules
+_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
+
+
+def build_dataset(dataset_opt):
+ """Build dataset from options.
+
+ Args:
+ dataset_opt (dict): Configuration for dataset. It must contain:
+ name (str): Dataset name.
+ type (str): Dataset type.
+ """
+ dataset_opt = deepcopy(dataset_opt)
+ dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
+ logger = get_root_logger()
+ logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
+ return dataset
+
+
+def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
+ """Build dataloader.
+
+ Args:
+ dataset (torch.utils.data.Dataset): Dataset.
+ dataset_opt (dict): Dataset options. It contains the following keys:
+ phase (str): 'train' or 'val'.
+ num_worker_per_gpu (int): Number of workers for each GPU.
+ batch_size_per_gpu (int): Training batch size for each GPU.
+ num_gpu (int): Number of GPUs. Used only in the train phase.
+ Default: 1.
+ dist (bool): Whether in distributed training. Used only in the train
+ phase. Default: False.
+ sampler (torch.utils.data.sampler): Data sampler. Default: None.
+ seed (int | None): Seed. Default: None
+ """
+ phase = dataset_opt['phase']
+ rank, _ = get_dist_info()
+ if phase == 'train':
+ if dist: # distributed training
+ batch_size = dataset_opt['batch_size_per_gpu']
+ num_workers = dataset_opt['num_worker_per_gpu']
+ else: # non-distributed training
+ multiplier = 1 if num_gpu == 0 else num_gpu
+ batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
+ num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
+ dataloader_args = dict(
+ dataset=dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers,
+ sampler=sampler,
+ drop_last=True)
+ if sampler is None:
+ dataloader_args['shuffle'] = True
+ dataloader_args['worker_init_fn'] = partial(
+ worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
+ elif phase in ['val', 'test']: # validation
+ dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+ else:
+ raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
+
+ dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
+
+ prefetch_mode = dataset_opt.get('prefetch_mode')
+ if prefetch_mode == 'cpu': # CPUPrefetcher
+ num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
+ logger = get_root_logger()
+ logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
+ return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
+ else:
+ # prefetch_mode=None: Normal dataloader
+ # prefetch_mode='cuda': dataloader for CUDAPrefetcher
+ return torch.utils.data.DataLoader(**dataloader_args)
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ # Set the worker seed to num_workers * rank + worker_id + seed
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
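+
+
+# A hedged usage sketch (keys follow the docstrings above; values illustrative):
+#   train_set = build_dataset({'name': 'demo', 'type': 'FFHQBlindDataset', ...})
+#   loader = build_dataloader(train_set, {'phase': 'train',
+#                                         'batch_size_per_gpu': 4,
+#                                         'num_worker_per_gpu': 2})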
diff --git a/basicsr/data/data_sampler.py b/basicsr/data/data_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..5135c7f83a0698c1980354b65ffa68f98a3c6cc0
--- /dev/null
+++ b/basicsr/data/data_sampler.py
@@ -0,0 +1,48 @@
+import math
+import torch
+from torch.utils.data.sampler import Sampler
+
+
+class EnlargedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+
+ Modified from torch.utils.data.distributed.DistributedSampler
+ Supports enlarging the dataset for iteration-based training, saving
+ time when restarting the dataloader after each epoch.
+
+ Args:
+ dataset (torch.utils.data.Dataset): Dataset used for sampling.
+ num_replicas (int | None): Number of processes participating in
+ the training. It is usually the world_size.
+ rank (int | None): Rank of the current process within num_replicas.
+ ratio (int): Enlarging ratio. Default: 1.
+ """
+
+ def __init__(self, dataset, num_replicas, rank, ratio=1):
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(self.total_size, generator=g).tolist()
+
+ dataset_size = len(self.dataset)
+ indices = [v % dataset_size for v in indices]
+
+ # subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
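+
+
+# Example: with ratio=2 and a 100-sample dataset, each epoch draws ~200 indices
+# (wrapped modulo the dataset size), split evenly across num_replicas ranks.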
diff --git a/basicsr/data/data_util.py b/basicsr/data/data_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..864805bbce6604357b7886150304ade36f7d00b6
--- /dev/null
+++ b/basicsr/data/data_util.py
@@ -0,0 +1,392 @@
+import cv2
+import math
+import numpy as np
+import torch
+from os import path as osp
+from PIL import Image, ImageDraw
+from torch.nn import functional as F
+
+from basicsr.data.transforms import mod_crop
+from basicsr.utils import img2tensor, scandir
+
+
+def read_img_seq(path, require_mod_crop=False, scale=1):
+ """Read a sequence of images from a given folder path.
+
+ Args:
+ path (list[str] | str): List of image paths or image folder path.
+ require_mod_crop (bool): Require mod crop for each image.
+ Default: False.
+ scale (int): Scale factor for mod_crop. Default: 1.
+
+ Returns:
+ Tensor: size (t, c, h, w), RGB, [0, 1].
+ """
+ if isinstance(path, list):
+ img_paths = path
+ else:
+ img_paths = sorted(list(scandir(path, full_path=True)))
+ imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
+ if require_mod_crop:
+ imgs = [mod_crop(img, scale) for img in imgs]
+ imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+ imgs = torch.stack(imgs, dim=0)
+ return imgs
+
+
+def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
+ """Generate an index list for reading `num_frames` frames from a sequence
+ of images.
+
+ Args:
+ crt_idx (int): Current center index.
+ max_frame_num (int): Max number of the sequence of images (from 1).
+ num_frames (int): Reading num_frames frames.
+ padding (str): Padding mode, one of
+ 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
+ Examples: current_idx = 0, num_frames = 5
+ The generated frame indices under different padding mode:
+ replicate: [0, 0, 0, 1, 2]
+ reflection: [2, 1, 0, 1, 2]
+ reflection_circle: [4, 3, 0, 1, 2]
+ circle: [3, 4, 0, 1, 2]
+
+ Returns:
+ list[int]: A list of indices.
+ """
+ assert num_frames % 2 == 1, 'num_frames should be an odd number.'
+ assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
+
+ max_frame_num = max_frame_num - 1 # start from 0
+ num_pad = num_frames // 2
+
+ indices = []
+ for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
+ if i < 0:
+ if padding == 'replicate':
+ pad_idx = 0
+ elif padding == 'reflection':
+ pad_idx = -i
+ elif padding == 'reflection_circle':
+ pad_idx = crt_idx + num_pad - i
+ else:
+ pad_idx = num_frames + i
+ elif i > max_frame_num:
+ if padding == 'replicate':
+ pad_idx = max_frame_num
+ elif padding == 'reflection':
+ pad_idx = max_frame_num * 2 - i
+ elif padding == 'reflection_circle':
+ pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
+ else:
+ pad_idx = i - num_frames
+ else:
+ pad_idx = i
+ indices.append(pad_idx)
+ return indices
+
+
+def paired_paths_from_lmdb(folders, keys):
+ """Generate paired paths from lmdb files.
+
+ Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
+
+ lq.lmdb
+ ├── data.mdb
+ ├── lock.mdb
+ ├── meta_info.txt
+
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
+ https://lmdb.readthedocs.io/en/release/ for more details.
+
+ The meta_info.txt is a specified txt file to record the meta information
+ of our datasets. It will be automatically created when preparing
+ datasets by our provided dataset tools.
+ Each line in the txt file records
+ 1)image name (with extension),
+ 2)image shape,
+ 3)compression level, separated by a white space.
+ Example: `baboon.png (120,125,3) 1`
+
+ We use the image name without extension as the lmdb key.
+ Note that we use the same key for the corresponding lq and gt images.
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+ be consistent with folders, e.g., ['lq', 'gt'].
+ Note that this key is different from lmdb keys.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
+ raise ValueError(f'{input_key} folder and {gt_key} folder should both be in '
+ f'lmdb format. But received {input_key}: {input_folder}; '
+ f'{gt_key}: {gt_folder}')
+ # ensure that the two meta_info files are the same
+ with open(osp.join(input_folder, 'meta_info.txt')) as fin:
+ input_lmdb_keys = [line.split('.')[0] for line in fin]
+ with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
+ gt_lmdb_keys = [line.split('.')[0] for line in fin]
+ if set(input_lmdb_keys) != set(gt_lmdb_keys):
+ raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
+ else:
+ paths = []
+ for lmdb_key in sorted(input_lmdb_keys):
+ paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
+ return paths
+
+
+def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
+ """Generate paired paths from an meta information file.
+
+ Each line in the meta information file contains the image names and
+ image shape (usually for gt), separated by a white space.
+
+ Example of a meta information file:
+ ```
+ 0001_s001.png (480,480,3)
+ 0001_s002.png (480,480,3)
+ ```
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+ be consistent with folders, e.g., ['lq', 'gt'].
+ meta_info_file (str): Path to the meta information file.
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ with open(meta_info_file, 'r') as fin:
+ gt_names = [line.split(' ')[0] for line in fin]
+
+ paths = []
+ for gt_name in gt_names:
+ basename, ext = osp.splitext(osp.basename(gt_name))
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
+ input_path = osp.join(input_folder, input_name)
+ gt_path = osp.join(gt_folder, gt_name)
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+ return paths
+
+
+def paired_paths_from_folder(folders, keys, filename_tmpl):
+ """Generate paired paths from folders.
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+ be consistent with folders, e.g., ['lq', 'gt'].
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ input_paths = list(scandir(input_folder))
+ gt_paths = list(scandir(gt_folder))
+ assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
+ f'{len(input_paths)}, {len(gt_paths)}.')
+ paths = []
+ for gt_path in gt_paths:
+ basename, ext = osp.splitext(osp.basename(gt_path))
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
+ input_path = osp.join(input_folder, input_name)
+ assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
+ gt_path = osp.join(gt_folder, gt_path)
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+ return paths
+
+
+def paths_from_folder(folder):
+ """Generate paths from folder.
+
+ Args:
+ folder (str): Folder path.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+
+ paths = list(scandir(folder))
+ paths = [osp.join(folder, path) for path in paths]
+ return paths
+
+
+def paths_from_lmdb(folder):
+ """Generate paths from lmdb.
+
+ Args:
+ folder (str): Folder path.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ if not folder.endswith('.lmdb'):
+ raise ValueError(f'Folder {folder} should be in lmdb format.')
+ with open(osp.join(folder, 'meta_info.txt')) as fin:
+ paths = [line.split('.')[0] for line in fin]
+ return paths
+
+
+def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
+ """Generate Gaussian kernel used in `duf_downsample`.
+
+ Args:
+ kernel_size (int): Kernel size. Default: 13.
+ sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
+
+ Returns:
+ np.array: The Gaussian kernel.
+ """
+ # scipy.ndimage.filters is deprecated; import gaussian_filter directly
+ from scipy.ndimage import gaussian_filter
+ kernel = np.zeros((kernel_size, kernel_size))
+ # set element at the middle to one, a dirac delta
+ kernel[kernel_size // 2, kernel_size // 2] = 1
+ # gaussian-smooth the dirac, resulting in a gaussian filter
+ return gaussian_filter(kernel, sigma)
+
+
+def duf_downsample(x, kernel_size=13, scale=4):
+ """Downsamping with Gaussian kernel used in the DUF official code.
+
+ Args:
+ x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
+ kernel_size (int): Kernel size. Default: 13.
+ scale (int): Downsampling factor. Supported scale: (2, 3, 4).
+ Default: 4.
+
+ Returns:
+ Tensor: DUF downsampled frames.
+ """
+ assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
+
+ squeeze_flag = False
+ if x.ndim == 4:
+ squeeze_flag = True
+ x = x.unsqueeze(0)
+ b, t, c, h, w = x.size()
+ x = x.view(-1, 1, h, w)
+ pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
+ x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
+
+ gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
+ gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
+ x = F.conv2d(x, gaussian_filter, stride=scale)
+ x = x[:, :, 2:-2, 2:-2]
+ x = x.view(b, t, c, x.size(2), x.size(3))
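+ # e.g. scale=4: (1, 5, 3, 256, 256) input frames come back as (1, 5, 3, 64, 64)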
+ if squeeze_flag:
+ x = x.squeeze(0)
+ return x
+
+
+def brush_stroke_mask(img, color=(255,255,255)):
+ min_num_vertex = 8
+ max_num_vertex = 28
+ mean_angle = 2*math.pi / 5
+ angle_range = 2*math.pi / 12
+ # training large mask ratio (training setting)
+ min_width = 30
+ max_width = 70
+ # very large mask ratio (test setting and refine after 200k)
+ # min_width = 80
+ # max_width = 120
+ def generate_mask(H, W, img=None):
+ average_radius = math.sqrt(H*H+W*W) / 8
+ mask = Image.new('RGB', (W, H), 0)
+ if img is not None: mask = img # Image.fromarray(img)
+
+ for _ in range(np.random.randint(1, 4)):
+ num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
+ angle_min = mean_angle - np.random.uniform(0, angle_range)
+ angle_max = mean_angle + np.random.uniform(0, angle_range)
+ angles = []
+ vertex = []
+ for i in range(num_vertex):
+ if i % 2 == 0:
+ angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
+ else:
+ angles.append(np.random.uniform(angle_min, angle_max))
+
+ h, w = mask.size
+ vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
+ for i in range(num_vertex):
+ r = np.clip(
+ np.random.normal(loc=average_radius, scale=average_radius//2),
+ 0, 2*average_radius)
+ new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
+ new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
+ vertex.append((int(new_x), int(new_y)))
+
+ draw = ImageDraw.Draw(mask)
+ width = int(np.random.uniform(min_width, max_width))
+ draw.line(vertex, fill=color, width=width)
+ for v in vertex:
+ draw.ellipse((v[0] - width//2,
+ v[1] - width//2,
+ v[0] + width//2,
+ v[1] + width//2),
+ fill=color)
+
+ return mask
+
+ width, height = img.size
+ mask = generate_mask(height, width, img)
+ return mask
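+
+# Minimal usage sketch (PIL image in, PIL image out with white strokes drawn):
+#   masked = brush_stroke_mask(Image.new('RGB', (512, 512)))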
+
+
+def random_ff_mask(shape, max_angle=10, max_len=100, max_width=70, times=10):
+ """Generate a random free-form mask.
+ Args:
+ shape (tuple): Mask shape as (height, width).
+ max_angle (int): Maximum stroke angle. Default: 10.
+ max_len (int): Maximum stroke length. Default: 100.
+ max_width (int): Maximum brush width. Default: 70.
+ times (int): Upper bound on the number of strokes. Default: 10.
+ Returns:
+ np.ndarray: A (height, width) float32 mask where strokes are 1.0.
+ Link:
+ https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py
+ """
+ height = shape[0]
+ width = shape[1]
+ mask = np.zeros((height, width), np.float32)
+ times = np.random.randint(times-5, times)
+ for i in range(times):
+ start_x = np.random.randint(width)
+ start_y = np.random.randint(height)
+ for j in range(1 + np.random.randint(5)):
+ angle = 0.01 + np.random.randint(max_angle)
+ if i % 2 == 0:
+ angle = 2 * 3.1415926 - angle
+ length = 10 + np.random.randint(max_len-20, max_len)
+ brush_w = 5 + np.random.randint(max_width-30, max_width)
+ end_x = (start_x + length * np.sin(angle)).astype(np.int32)
+ end_y = (start_y + length * np.cos(angle)).astype(np.int32)
+ cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
+ start_x, start_y = end_x, end_y
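+ # e.g. random_ff_mask((512, 512)) -> float32 (512, 512) array, strokes == 1.0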
+ return mask.astype(np.float32)
\ No newline at end of file
diff --git a/basicsr/data/ffhq_blind_dataset.py b/basicsr/data/ffhq_blind_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..fec22bea54e3b90f7d2918a15ce640a629e32992
--- /dev/null
+++ b/basicsr/data/ffhq_blind_dataset.py
@@ -0,0 +1,299 @@
+import cv2
+import math
+import random
+import numpy as np
+import os.path as osp
+from scipy.io import loadmat
+from PIL import Image
+import torch
+import torch.utils.data as data
+from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
+ adjust_hue, adjust_saturation, normalize)
+from basicsr.data import gaussian_kernels
+from basicsr.data.transforms import augment
+from basicsr.data.data_util import paths_from_folder, brush_stroke_mask, random_ff_mask
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+@DATASET_REGISTRY.register()
+class FFHQBlindDataset(data.Dataset):
+
+ def __init__(self, opt):
+ super(FFHQBlindDataset, self).__init__()
+ logger = get_root_logger()
+ self.opt = opt
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+
+ self.gt_folder = opt['dataroot_gt']
+ self.gt_size = opt.get('gt_size', 512)
+ self.in_size = opt.get('in_size', 512)
+ assert self.gt_size >= self.in_size, 'gt_size must be no smaller than in_size.'
+
+ self.mean = opt.get('mean', [0.5, 0.5, 0.5])
+ self.std = opt.get('std', [0.5, 0.5, 0.5])
+
+ self.component_path = opt.get('component_path', None)
+ self.latent_gt_path = opt.get('latent_gt_path', None)
+
+ if self.component_path is not None:
+ self.crop_components = True
+ self.components_dict = torch.load(self.component_path)
+ self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
+ self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
+ self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
+ else:
+ self.crop_components = False
+
+ if self.latent_gt_path is not None:
+ self.load_latent_gt = True
+ self.latent_gt_dict = torch.load(self.latent_gt_path)
+ else:
+ self.load_latent_gt = False
+
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = self.gt_folder
+ if not self.gt_folder.endswith('.lmdb'):
+ raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+ self.paths = [line.split('.')[0] for line in fin]
+ else:
+ self.paths = paths_from_folder(self.gt_folder)
+
+ # inpainting mask
+ self.gen_inpaint_mask = opt.get('gen_inpaint_mask', False)
+ if self.gen_inpaint_mask:
+ logger.info(f'generate mask ...')
+ # self.mask_max_angle = opt.get('mask_max_angle', 10)
+ # self.mask_max_len = opt.get('mask_max_len', 150)
+ # self.mask_max_width = opt.get('mask_max_width', 50)
+ # self.mask_draw_times = opt.get('mask_draw_times', 10)
+ # # print
+ # logger.info(f'mask_max_angle: {self.mask_max_angle}')
+ # logger.info(f'mask_max_len: {self.mask_max_len}')
+ # logger.info(f'mask_max_width: {self.mask_max_width}')
+ # logger.info(f'mask_draw_times: {self.mask_draw_times}')
+
+ # perform corrupt
+ self.use_corrupt = opt.get('use_corrupt', True)
+ self.use_motion_kernel = False
+ # self.use_motion_kernel = opt.get('use_motion_kernel', True)
+
+ if self.use_motion_kernel:
+ self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
+ motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
+ self.motion_kernels = torch.load(motion_kernel_path)
+
+ if self.use_corrupt and not self.gen_inpaint_mask:
+ # degradation configurations
+ self.blur_kernel_size = opt['blur_kernel_size']
+ self.blur_sigma = opt['blur_sigma']
+ self.kernel_list = opt['kernel_list']
+ self.kernel_prob = opt['kernel_prob']
+ self.downsample_range = opt['downsample_range']
+ self.noise_range = opt['noise_range']
+ self.jpeg_range = opt['jpeg_range']
+ # print
+ logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
+ logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
+ logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
+ logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
+
+ # color jitter
+ self.color_jitter_prob = opt.get('color_jitter_prob', None)
+ self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
+ self.color_jitter_shift = opt.get('color_jitter_shift', 20)
+ if self.color_jitter_prob is not None:
+ logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
+
+ # to gray
+ self.gray_prob = opt.get('gray_prob', 0.0)
+ if self.gray_prob is not None:
+ logger.info(f'Use random gray. Prob: {self.gray_prob}')
+ self.color_jitter_shift /= 255.
+
+ @staticmethod
+ def color_jitter(img, shift):
+ """jitter color: randomly jitter the RGB values, in numpy formats"""
+ jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
+ img = img + jitter_val
+ img = np.clip(img, 0, 1)
+ return img
+
+ @staticmethod
+ def color_jitter_pt(img, brightness, contrast, saturation, hue):
+ """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
+ fn_idx = torch.randperm(4)
+ for fn_id in fn_idx:
+ if fn_id == 0 and brightness is not None:
+ brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
+ img = adjust_brightness(img, brightness_factor)
+
+ if fn_id == 1 and contrast is not None:
+ contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
+ img = adjust_contrast(img, contrast_factor)
+
+ if fn_id == 2 and saturation is not None:
+ saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
+ img = adjust_saturation(img, saturation_factor)
+
+ if fn_id == 3 and hue is not None:
+ hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
+ img = adjust_hue(img, hue_factor)
+ return img
+
+
+ def get_component_locations(self, name, status):
+ components_bbox = self.components_dict[name]
+ if status[0]: # hflip
+ # exchange right and left eye
+ tmp = components_bbox['left_eye']
+ components_bbox['left_eye'] = components_bbox['right_eye']
+ components_bbox['right_eye'] = tmp
+ # modify the width coordinate
+ components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0]
+ components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0]
+ components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0]
+ components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0]
+
+ locations_gt = {}
+ locations_in = {}
+ for part in ['left_eye', 'right_eye', 'nose', 'mouth']:
+ mean = components_bbox[part][0:2]
+ half_len = components_bbox[part][2]
+ if 'eye' in part:
+ half_len *= self.eye_enlarge_ratio
+ elif part == 'nose':
+ half_len *= self.nose_enlarge_ratio
+ elif part == 'mouth':
+ half_len *= self.mouth_enlarge_ratio
+ loc = np.hstack((mean - half_len + 1, mean + half_len))
+ loc = torch.from_numpy(loc).float()
+ locations_gt[part] = loc
+ loc_in = loc/(self.gt_size//self.in_size)
+ locations_in[part] = loc_in
+ return locations_gt, locations_in
+
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # load gt image
+ gt_path = self.paths[index]
+ name = osp.basename(gt_path)[:-4]
+ img_bytes = self.file_client.get(gt_path)
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ # random horizontal flip
+ img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
+
+ if self.load_latent_gt:
+ if status[0]:
+ latent_gt = self.latent_gt_dict['hflip'][name]
+ else:
+ latent_gt = self.latent_gt_dict['orig'][name]
+
+ if self.crop_components:
+ locations_gt, locations_in = self.get_component_locations(name, status)
+
+ # generate in image
+ img_in = img_gt
+ if self.use_corrupt and not self.gen_inpaint_mask:
+ # motion blur
+ if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
+ m_i = random.randint(0,31)
+ k = self.motion_kernels[f'{m_i:02d}']
+ img_in = cv2.filter2D(img_in,-1,k)
+
+ # gaussian blur
+ kernel = gaussian_kernels.random_mixed_kernels(
+ self.kernel_list,
+ self.kernel_prob,
+ self.blur_kernel_size,
+ self.blur_sigma,
+ self.blur_sigma,
+ [-math.pi, math.pi],
+ noise_range=None)
+ img_in = cv2.filter2D(img_in, -1, kernel)
+
+ # downsample
+ scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
+ img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
+
+ # noise
+ if self.noise_range is not None:
+ noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
+ noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
+ img_in = img_in + noise
+ img_in = np.clip(img_in, 0, 1)
+
+ # jpeg
+ if self.jpeg_range is not None:
+ jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
+ _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
+ img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.
+
+ # resize to in_size
+ img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
+
+ # if self.gen_inpaint_mask:
+ # inpaint_mask = random_ff_mask(shape=(self.gt_size,self.gt_size),
+ # max_angle = self.mask_max_angle, max_len = self.mask_max_len,
+ # max_width = self.mask_max_width, times = self.mask_draw_times)
+ # img_in = img_in * (1 - inpaint_mask.reshape(self.gt_size,self.gt_size,1)) + \
+ # 1.0 * inpaint_mask.reshape(self.gt_size,self.gt_size,1)
+
+ # inpaint_mask = torch.from_numpy(inpaint_mask).view(1,self.gt_size,self.gt_size)
+
+ if self.gen_inpaint_mask:
+ img_in = (img_in*255).astype('uint8')
+ img_in = brush_stroke_mask(Image.fromarray(img_in))
+ img_in = np.array(img_in) / 255.
+
+ # random color jitter (only for lq)
+ if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
+ img_in = self.color_jitter(img_in, self.color_jitter_shift)
+ # random to gray (only for lq)
+ if self.gray_prob and np.random.uniform() < self.gray_prob:
+ img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
+ img_in = np.tile(img_in[:, :, None], [1, 1, 3])
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True)
+
+ # random color jitter (pytorch version) (only for lq)
+ if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
+ brightness = self.opt.get('brightness', (0.5, 1.5))
+ contrast = self.opt.get('contrast', (0.5, 1.5))
+ saturation = self.opt.get('saturation', (0, 1.5))
+ hue = self.opt.get('hue', (-0.1, 0.1))
+ img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
+
+ # round and clip
+ img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
+
+ # Set vgg range_norm=True if using the normalization here
+ # normalize
+ normalize(img_in, self.mean, self.std, inplace=True)
+ normalize(img_gt, self.mean, self.std, inplace=True)
+
+ return_dict = {'in': img_in, 'gt': img_gt, 'gt_path': gt_path}
+
+ if self.crop_components:
+ return_dict['locations_in'] = locations_in
+ return_dict['locations_gt'] = locations_gt
+
+ if self.load_latent_gt:
+ return_dict['latent_gt'] = latent_gt
+
+ # if self.gen_inpaint_mask:
+ # return_dict['inpaint_mask'] = inpaint_mask
+
+ return return_dict
+
+
+ def __len__(self):
+ return len(self.paths)
\ No newline at end of file
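A hypothetical minimal options dict for this dataset (the key names follow the `opt` lookups in `__init__` above; the dataroot path and degradation values are placeholders, and the folder must contain 512x512 GT faces):

```python
from basicsr.data.ffhq_blind_dataset import FFHQBlindDataset

opt = {
    'dataroot_gt': 'datasets/ffhq_512',   # placeholder folder of GT faces
    'io_backend': {'type': 'disk'},
    'use_hflip': True,
    # degradation keys are required because use_corrupt defaults to True
    'blur_kernel_size': 41,
    'kernel_list': ['iso', 'aniso'],
    'kernel_prob': [0.5, 0.5],
    'blur_sigma': [0.1, 10],
    'downsample_range': [0.8, 8],
    'noise_range': [0, 20],
    'jpeg_range': [60, 100],
}

dataset = FFHQBlindDataset(opt)
sample = dataset[0]
print(sample['in'].shape, sample['gt'].shape)  # torch.Size([3, 512, 512]) twice
```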
diff --git a/basicsr/data/ffhq_blind_joint_dataset.py b/basicsr/data/ffhq_blind_joint_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6aeaeb1b40d5521228edb4518f3280eb6332b74
--- /dev/null
+++ b/basicsr/data/ffhq_blind_joint_dataset.py
@@ -0,0 +1,324 @@
+import cv2
+import math
+import random
+import numpy as np
+import os.path as osp
+from scipy.io import loadmat
+import torch
+import torch.utils.data as data
+from torchvision.transforms.functional import (adjust_brightness, adjust_contrast,
+ adjust_hue, adjust_saturation, normalize)
+from basicsr.data import gaussian_kernels as gaussian_kernels
+from basicsr.data.transforms import augment
+from basicsr.data.data_util import paths_from_folder
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+@DATASET_REGISTRY.register()
+class FFHQBlindJointDataset(data.Dataset):
+
+ def __init__(self, opt):
+ super(FFHQBlindJointDataset, self).__init__()
+ logger = get_root_logger()
+ self.opt = opt
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+
+ self.gt_folder = opt['dataroot_gt']
+ self.gt_size = opt.get('gt_size', 512)
+ self.in_size = opt.get('in_size', 512)
+ assert self.gt_size >= self.in_size, 'gt_size must be no smaller than in_size.'
+
+ self.mean = opt.get('mean', [0.5, 0.5, 0.5])
+ self.std = opt.get('std', [0.5, 0.5, 0.5])
+
+ self.component_path = opt.get('component_path', None)
+ self.latent_gt_path = opt.get('latent_gt_path', None)
+
+ if self.component_path is not None:
+ self.crop_components = True
+ self.components_dict = torch.load(self.component_path)
+ self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
+ self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
+ self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
+ else:
+ self.crop_components = False
+
+ if self.latent_gt_path is not None:
+ self.load_latent_gt = True
+ self.latent_gt_dict = torch.load(self.latent_gt_path)
+ else:
+ self.load_latent_gt = False
+
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = self.gt_folder
+ if not self.gt_folder.endswith('.lmdb'):
+ raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+ self.paths = [line.split('.')[0] for line in fin]
+ else:
+ self.paths = paths_from_folder(self.gt_folder)
+
+ # perform corrupt
+ self.use_corrupt = opt.get('use_corrupt', True)
+ self.use_motion_kernel = False
+ # self.use_motion_kernel = opt.get('use_motion_kernel', True)
+
+ if self.use_motion_kernel:
+ self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
+ motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
+ self.motion_kernels = torch.load(motion_kernel_path)
+
+ if self.use_corrupt:
+ # degradation configurations
+ self.blur_kernel_size = self.opt['blur_kernel_size']
+ self.kernel_list = self.opt['kernel_list']
+ self.kernel_prob = self.opt['kernel_prob']
+ # Small degradation
+ self.blur_sigma = self.opt['blur_sigma']
+ self.downsample_range = self.opt['downsample_range']
+ self.noise_range = self.opt['noise_range']
+ self.jpeg_range = self.opt['jpeg_range']
+ # Large degradation
+ self.blur_sigma_large = self.opt['blur_sigma_large']
+ self.downsample_range_large = self.opt['downsample_range_large']
+ self.noise_range_large = self.opt['noise_range_large']
+ self.jpeg_range_large = self.opt['jpeg_range_large']
+
+ # print
+ logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
+ logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
+ logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
+ logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
+
+ # color jitter
+ self.color_jitter_prob = opt.get('color_jitter_prob', None)
+ self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
+ self.color_jitter_shift = opt.get('color_jitter_shift', 20)
+ if self.color_jitter_prob is not None:
+ logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
+
+ # to gray
+ self.gray_prob = opt.get('gray_prob', 0.0)
+ if self.gray_prob is not None:
+ logger.info(f'Use random gray. Prob: {self.gray_prob}')
+ self.color_jitter_shift /= 255.
+
+ @staticmethod
+ def color_jitter(img, shift):
+ """jitter color: randomly jitter the RGB values, in numpy formats"""
+ jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
+ img = img + jitter_val
+ img = np.clip(img, 0, 1)
+ return img
+
+ @staticmethod
+ def color_jitter_pt(img, brightness, contrast, saturation, hue):
+ """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
+ fn_idx = torch.randperm(4)
+ for fn_id in fn_idx:
+ if fn_id == 0 and brightness is not None:
+ brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
+ img = adjust_brightness(img, brightness_factor)
+
+ if fn_id == 1 and contrast is not None:
+ contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
+ img = adjust_contrast(img, contrast_factor)
+
+ if fn_id == 2 and saturation is not None:
+ saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
+ img = adjust_saturation(img, saturation_factor)
+
+ if fn_id == 3 and hue is not None:
+ hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
+ img = adjust_hue(img, hue_factor)
+ return img
+
+
+ def get_component_locations(self, name, status):
+ components_bbox = self.components_dict[name]
+ if status[0]: # hflip
+ # exchange right and left eye
+ tmp = components_bbox['left_eye']
+ components_bbox['left_eye'] = components_bbox['right_eye']
+ components_bbox['right_eye'] = tmp
+ # modify the width coordinate
+ components_bbox['left_eye'][0] = self.gt_size - components_bbox['left_eye'][0]
+ components_bbox['right_eye'][0] = self.gt_size - components_bbox['right_eye'][0]
+ components_bbox['nose'][0] = self.gt_size - components_bbox['nose'][0]
+ components_bbox['mouth'][0] = self.gt_size - components_bbox['mouth'][0]
+
+ locations_gt = {}
+ locations_in = {}
+ for part in ['left_eye', 'right_eye', 'nose', 'mouth']:
+ mean = components_bbox[part][0:2]
+ half_len = components_bbox[part][2]
+ if 'eye' in part:
+ half_len *= self.eye_enlarge_ratio
+ elif part == 'nose':
+ half_len *= self.nose_enlarge_ratio
+ elif part == 'mouth':
+ half_len *= self.mouth_enlarge_ratio
+ loc = np.hstack((mean - half_len + 1, mean + half_len))
+ loc = torch.from_numpy(loc).float()
+ locations_gt[part] = loc
+ loc_in = loc/(self.gt_size//self.in_size)
+ locations_in[part] = loc_in
+ return locations_gt, locations_in
+
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # load gt image
+ gt_path = self.paths[index]
+ name = osp.basename(gt_path)[:-4]
+ img_bytes = self.file_client.get(gt_path)
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ # random horizontal flip
+ img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
+
+ if self.load_latent_gt:
+ if status[0]:
+ latent_gt = self.latent_gt_dict['hflip'][name]
+ else:
+ latent_gt = self.latent_gt_dict['orig'][name]
+
+ if self.crop_components:
+ locations_gt, locations_in = self.get_component_locations(name, status)
+
+ # generate in image
+ img_in = img_gt
+ if self.use_corrupt:
+ # motion blur
+ if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
+ m_i = random.randint(0,31)
+ k = self.motion_kernels[f'{m_i:02d}']
+ img_in = cv2.filter2D(img_in,-1,k)
+
+ # gaussian blur
+ kernel = gaussian_kernels.random_mixed_kernels(
+ self.kernel_list,
+ self.kernel_prob,
+ self.blur_kernel_size,
+ self.blur_sigma,
+ self.blur_sigma,
+ [-math.pi, math.pi],
+ noise_range=None)
+ img_in = cv2.filter2D(img_in, -1, kernel)
+
+ # downsample
+ scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
+ img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
+
+ # noise
+ if self.noise_range is not None:
+ noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
+ noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
+ img_in = img_in + noise
+ img_in = np.clip(img_in, 0, 1)
+
+ # jpeg
+ if self.jpeg_range is not None:
+ jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
+ _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
+ img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.
+
+ # resize to in_size
+ img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
+
+
+ # generate in_large with large degradation
+ img_in_large = img_gt
+
+ if self.use_corrupt:
+ # motion blur
+ if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
+ m_i = random.randint(0,31)
+ k = self.motion_kernels[f'{m_i:02d}']
+ img_in_large = cv2.filter2D(img_in_large,-1,k)
+
+ # gaussian blur
+ kernel = gaussian_kernels.random_mixed_kernels(
+ self.kernel_list,
+ self.kernel_prob,
+ self.blur_kernel_size,
+ self.blur_sigma_large,
+ self.blur_sigma_large,
+ [-math.pi, math.pi],
+ noise_range=None)
+ img_in_large = cv2.filter2D(img_in_large, -1, kernel)
+
+ # downsample
+ scale = np.random.uniform(self.downsample_range_large[0], self.downsample_range_large[1])
+ img_in_large = cv2.resize(img_in_large, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)
+
+ # noise
+ if self.noise_range_large is not None:
+ noise_sigma = np.random.uniform(self.noise_range_large[0] / 255., self.noise_range_large[1] / 255.)
+ noise = np.float32(np.random.randn(*(img_in_large.shape))) * noise_sigma
+ img_in_large = img_in_large + noise
+ img_in_large = np.clip(img_in_large, 0, 1)
+
+ # jpeg
+ if self.jpeg_range_large is not None:
+ jpeg_p = np.random.uniform(self.jpeg_range_large[0], self.jpeg_range_large[1])
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
+ _, encimg = cv2.imencode('.jpg', img_in_large * 255., encode_param)
+ img_in_large = np.float32(cv2.imdecode(encimg, 1)) / 255.
+
+ # resize to in_size
+ img_in_large = cv2.resize(img_in_large, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)
+
+
+ # random color jitter (only for lq)
+ if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
+ img_in = self.color_jitter(img_in, self.color_jitter_shift)
+ img_in_large = self.color_jitter(img_in_large, self.color_jitter_shift)
+ # random to gray (only for lq)
+ if self.gray_prob and np.random.uniform() < self.gray_prob:
+ img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
+ img_in = np.tile(img_in[:, :, None], [1, 1, 3])
+ img_in_large = cv2.cvtColor(img_in_large, cv2.COLOR_BGR2GRAY)
+ img_in_large = np.tile(img_in_large[:, :, None], [1, 1, 3])
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_in, img_in_large, img_gt = img2tensor([img_in, img_in_large, img_gt], bgr2rgb=True, float32=True)
+
+ # random color jitter (pytorch version) (only for lq)
+ if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
+ brightness = self.opt.get('brightness', (0.5, 1.5))
+ contrast = self.opt.get('contrast', (0.5, 1.5))
+ saturation = self.opt.get('saturation', (0, 1.5))
+ hue = self.opt.get('hue', (-0.1, 0.1))
+ img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
+ img_in_large = self.color_jitter_pt(img_in_large, brightness, contrast, saturation, hue)
+
+ # round and clip
+ img_in = np.clip((img_in * 255.0).round(), 0, 255) / 255.
+ img_in_large = np.clip((img_in_large * 255.0).round(), 0, 255) / 255.
+
+ # Set vgg range_norm=True if using the normalization here
+ # normalize
+ normalize(img_in, self.mean, self.std, inplace=True)
+ normalize(img_in_large, self.mean, self.std, inplace=True)
+ normalize(img_gt, self.mean, self.std, inplace=True)
+
+ return_dict = {'in': img_in, 'in_large_de': img_in_large, 'gt': img_gt, 'gt_path': gt_path}
+
+ if self.crop_components:
+ return_dict['locations_in'] = locations_in
+ return_dict['locations_gt'] = locations_gt
+
+ if self.load_latent_gt:
+ return_dict['latent_gt'] = latent_gt
+
+ return return_dict
+
+
+ def __len__(self):
+ return len(self.paths)
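The joint variant samples two degraded views of the same face per item. A sketch of the extra options it reads on top of the single-band dict shown after `ffhq_blind_dataset.py` above (the `*_large` values are illustrative placeholders):

```python
from basicsr.data.ffhq_blind_joint_dataset import FFHQBlindJointDataset

opt_joint = dict(opt)  # reuse the FFHQBlindDataset options sketched earlier
opt_joint.update({
    # second, stronger degradation band
    'blur_sigma_large': [10, 15],
    'downsample_range_large': [8, 32],
    'noise_range_large': [20, 40],
    'jpeg_range_large': [30, 60],
})

sample = FFHQBlindJointDataset(opt_joint)[0]
# 'in' uses the small band, 'in_large_de' the large band, 'gt' is clean
print(sample['in'].shape, sample['in_large_de'].shape, sample['gt'].shape)
```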
diff --git a/basicsr/data/gaussian_kernels.py b/basicsr/data/gaussian_kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..201b3dfb4f72df477b12f830691fd2976986f137
--- /dev/null
+++ b/basicsr/data/gaussian_kernels.py
@@ -0,0 +1,690 @@
+import math
+import numpy as np
+import random
+from scipy.ndimage import shift  # the deprecated 'scipy.ndimage.interpolation' path is removed in recent SciPy
+from scipy.stats import multivariate_normal
+
+
+def sigma_matrix2(sig_x, sig_y, theta):
+ """Calculate the rotated sigma matrix (two dimensional matrix).
+ Args:
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ Returns:
+ ndarray: Rotated sigma matrix.
+ """
+ D = np.array([[sig_x**2, 0], [0, sig_y**2]])
+ U = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ return np.dot(U, np.dot(D, U.T))
+
+
+def mesh_grid(kernel_size):
+ """Generate the mesh grid, centering at zero.
+ Args:
+ kernel_size (int):
+ Returns:
+ xy (ndarray): with the shape (kernel_size, kernel_size, 2)
+ xx (ndarray): with the shape (kernel_size, kernel_size)
+ yy (ndarray): with the shape (kernel_size, kernel_size)
+ """
+ ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
+ xx, yy = np.meshgrid(ax, ax)
+ xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)),
+ yy.reshape(kernel_size * kernel_size,
+ 1))).reshape(kernel_size, kernel_size, 2)
+ return xy, xx, yy
+
+
+def pdf2(sigma_matrix, grid):
+ """Calculate PDF of the bivariate Gaussian distribution.
+ Args:
+ sigma_matrix (ndarray): with the shape (2, 2)
+ grid (ndarray): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size.
+ Returns:
+ kernel (ndarray): un-normalized kernel.
+ """
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
+ return kernel
+
+
+def cdf2(D, grid):
+ """Calculate the CDF of the standard bivariate Gaussian distribution.
+ Used in skewed Gaussian distribution.
+ Args:
+ D (ndarray): skew matrix.
+ grid (ndarray): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size.
+ Returns:
+ cdf (ndarray): skewed cdf.
+ """
+ rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
+ grid = np.dot(grid, D)
+ cdf = rv.cdf(grid)
+ return cdf
+
+
+def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=None):
+ """Generate a bivariate skew Gaussian kernel.
+ Described in `A multivariate skew normal distribution`_ by Shi et al. (2004).
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ D (ndarray): skew matrix.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ Returns:
+ kernel (ndarray): normalized kernel.
+ .. _A multivariate skew normal distribution:
+ https://www.sciencedirect.com/science/article/pii/S0047259X03001313
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ pdf = pdf2(sigma_matrix, grid)
+ cdf = cdf2(D, grid)
+ kernel = pdf * cdf
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def mass_center_shift(kernel_size, kernel):
+ """Calculate the shift of the mass center of a kenrel.
+ Args:
+ kernel_size (int):
+ kernel (ndarray): normalized kernel.
+ Returns:
+ delta_h (float):
+ delta_w (float):
+ """
+ ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
+ col_sum, row_sum = np.sum(kernel, axis=0), np.sum(kernel, axis=1)
+ delta_h = np.dot(row_sum, ax)
+ delta_w = np.dot(col_sum, ax)
+ return delta_h, delta_w
+
+
+def bivariate_skew_Gaussian_center(kernel_size,
+ sig_x,
+ sig_y,
+ theta,
+ D,
+ grid=None):
+ """Generate a bivariate skew Gaussian kernel at center. Shift with nearest padding.
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ D (ndarray): skew matrix.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ Returns:
+ kernel (ndarray): centered and normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ kernel = bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid)
+ delta_h, delta_w = mass_center_shift(kernel_size, kernel)
+ kernel = shift(kernel, [-delta_h, -delta_w], mode='nearest')
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def bivariate_anisotropic_Gaussian(kernel_size,
+ sig_x,
+ sig_y,
+ theta,
+ grid=None):
+ """Generate a bivariate anisotropic Gaussian kernel.
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ kernel = pdf2(sigma_matrix, grid)
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None):
+ """Generate a bivariate isotropic Gaussian kernel.
+ Args:
+ kernel_size (int):
+ sig (float):
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
+ kernel = pdf2(sigma_matrix, grid)
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def bivariate_generalized_Gaussian(kernel_size,
+ sig_x,
+ sig_y,
+ theta,
+ beta,
+ grid=None):
+ """Generate a bivariate generalized Gaussian kernel.
+ Described in `Parameter Estimation For Multivariate Generalized Gaussian Distributions`_
+ by Pascal et al. (2013).
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ beta (float): shape parameter, beta = 1 is the normal distribution.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ Returns:
+ kernel (ndarray): normalized kernel.
+ .. _Parameter Estimation For Multivariate Generalized Gaussian Distributions:
+ https://arxiv.org/abs/1302.6498
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.exp(
+ -0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid=None):
+ """Generate a plateau-like anisotropic kernel.
+ 1 / (1+x^(beta))
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ beta (float): shape parameter, beta = 1 is the normal distribution.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.reciprocal(
+ np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None):
+ """Generate a plateau-like isotropic kernel.
+ 1 / (1+x^(beta))
+ Args:
+ kernel_size (int):
+ sig (float):
+ beta (float): shape parameter, beta = 1 is the normal distribution.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.reciprocal(
+ np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def random_bivariate_skew_Gaussian_center(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ noise_range=None,
+ strict=False):
+ """Randomly generate bivariate skew Gaussian kernels at center.
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+ rotation_range (tuple): [-math.pi, math.pi]
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ if strict:
+ sigma_max = np.max([sigma_x, sigma_y])
+ sigma_min = np.min([sigma_x, sigma_y])
+ sigma_x, sigma_y = sigma_max, sigma_min
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+
+ sigma_max = np.max([sigma_x, sigma_y])
+ thres = 3 / sigma_max
+ D = [[np.random.uniform(-thres, thres),
+ np.random.uniform(-thres, thres)],
+ [np.random.uniform(-thres, thres),
+ np.random.uniform(-thres, thres)]]
+
+ kernel = bivariate_skew_Gaussian_center(kernel_size, sigma_x, sigma_y,
+ rotation, D)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(
+ noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ if strict:
+ return kernel, sigma_x, sigma_y, rotation, D
+ else:
+ return kernel
+
+
+def random_bivariate_anisotropic_Gaussian(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ noise_range=None,
+ strict=False):
+ """Randomly generate bivariate anisotropic Gaussian kernels.
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+ rotation_range (tuple): [-math.pi, math.pi]
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ if strict:
+ sigma_max = np.max([sigma_x, sigma_y])
+ sigma_min = np.min([sigma_x, sigma_y])
+ sigma_x, sigma_y = sigma_max, sigma_min
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+
+ kernel = bivariate_anisotropic_Gaussian(kernel_size, sigma_x, sigma_y,
+ rotation)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(
+ noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ if strict:
+ return kernel, sigma_x, sigma_y, rotation
+ else:
+ return kernel
+
+
+def random_bivariate_isotropic_Gaussian(kernel_size,
+ sigma_range,
+ noise_range=None,
+ strict=False):
+ """Randomly generate bivariate isotropic Gaussian kernels.
+ Args:
+ kernel_size (int):
+ sigma_range (tuple): [0.6, 5]
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_range[0] < sigma_range[1], 'Wrong sigma_range.'
+ sigma = np.random.uniform(sigma_range[0], sigma_range[1])
+
+ kernel = bivariate_isotropic_Gaussian(kernel_size, sigma)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(
+ noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ if strict:
+ return kernel, sigma
+ else:
+ return kernel
+
+
+def random_bivariate_generalized_Gaussian(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ beta_range,
+ noise_range=None,
+ strict=False):
+ """Randomly generate bivariate generalized Gaussian kernels.
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+ rotation_range (tuple): [-math.pi, math.pi]
+ beta_range (tuple): [0.5, 8]
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ if strict:
+ sigma_max = np.max([sigma_x, sigma_y])
+ sigma_min = np.min([sigma_x, sigma_y])
+ sigma_x, sigma_y = sigma_max, sigma_min
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+ if np.random.uniform() < 0.5:
+ beta = np.random.uniform(beta_range[0], 1)
+ else:
+ beta = np.random.uniform(1, beta_range[1])
+
+ kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y,
+ rotation, beta)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(
+ noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ if strict:
+ return kernel, sigma_x, sigma_y, rotation, beta
+ else:
+ return kernel
+
+
+def random_bivariate_plateau_type1(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ beta_range,
+ noise_range=None,
+ strict=False):
+ """Randomly generate bivariate plateau type1 kernels.
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+ rotation_range (tuple): [-math.pi/2, math.pi/2]
+ beta_range (tuple): [1, 4]
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ if strict:
+ sigma_max = np.max([sigma_x, sigma_y])
+ sigma_min = np.min([sigma_x, sigma_y])
+ sigma_x, sigma_y = sigma_max, sigma_min
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+ if np.random.uniform() < 0.5:
+ beta = np.random.uniform(beta_range[0], 1)
+ else:
+ beta = np.random.uniform(1, beta_range[1])
+
+ kernel = bivariate_plateau_type1(kernel_size, sigma_x, sigma_y, rotation,
+ beta)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(
+ noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ if strict:
+ return kernel, sigma_x, sigma_y, rotation, beta
+ else:
+ return kernel
+
+
+def random_bivariate_plateau_type1_iso(kernel_size,
+ sigma_range,
+ beta_range,
+ noise_range=None,
+ strict=False):
+ """Randomly generate bivariate plateau type1 kernels (iso).
+ Args:
+ kernel_size (int):
+ sigma_range (tuple): [0.6, 5]
+ beta_range (tuple): [1, 4]
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_range[0] < sigma_range[1], 'Wrong sigma_range.'
+ sigma = np.random.uniform(sigma_range[0], sigma_range[1])
+ beta = np.random.uniform(beta_range[0], beta_range[1])
+
+ kernel = bivariate_plateau_type1_iso(kernel_size, sigma, beta)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(
+ noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ if strict:
+ return kernel, sigma, beta
+ else:
+ return kernel
+
+
+def random_mixed_kernels(kernel_list,
+ kernel_prob,
+ kernel_size=21,
+ sigma_x_range=[0.6, 5],
+ sigma_y_range=[0.6, 5],
+ rotation_range=[-math.pi, math.pi],
+ beta_range=[0.5, 8],
+ noise_range=None):
+ """Randomly generate mixed kernels.
+ Args:
+ kernel_list (tuple): a list of kernel type names,
+ support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', 'plateau_aniso']
+ kernel_prob (tuple): corresponding kernel probability for each kernel type
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+ rotation_range (tuple): [-math.pi, math.pi]
+ beta_range (tuple): [0.5, 8]
+ noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
+ Returns:
+ kernel (ndarray):
+ """
+ kernel_type = random.choices(kernel_list, kernel_prob)[0]
+ if kernel_type == 'iso':
+ kernel = random_bivariate_isotropic_Gaussian(
+ kernel_size, sigma_x_range, noise_range=noise_range)
+ elif kernel_type == 'aniso':
+ kernel = random_bivariate_anisotropic_Gaussian(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ noise_range=noise_range)
+ elif kernel_type == 'skew':
+ kernel = random_bivariate_skew_Gaussian_center(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ noise_range=noise_range)
+ elif kernel_type == 'generalized':
+ kernel = random_bivariate_generalized_Gaussian(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ beta_range,
+ noise_range=noise_range)
+ elif kernel_type == 'plateau_iso':
+ kernel = random_bivariate_plateau_type1_iso(
+ kernel_size, sigma_x_range, beta_range, noise_range=noise_range)
+ elif kernel_type == 'plateau_aniso':
+ kernel = random_bivariate_plateau_type1(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ beta_range,
+ noise_range=noise_range)
+ else:
+ raise ValueError(f'Kernel type {kernel_type} is not supported.')
+ # add multiplicative noise (note: noise_range was already forwarded to the
+ # samplers above, so a non-None range applies noise a second time here)
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(
+ noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def show_one_kernel():
+ import matplotlib.pyplot as plt
+ kernel_size = 21
+
+ # each assignment below overwrites `kernel`; only the last one is plotted
+ # bivariate skew Gaussian
+ D = [[0, 0], [0, 0]]
+ D = [[3 / 4, 0], [0, 0.5]]
+ kernel = bivariate_skew_Gaussian_center(kernel_size, 2, 4, -math.pi / 4, D)
+ # bivariate anisotropic Gaussian
+ kernel = bivariate_anisotropic_Gaussian(kernel_size, 2, 4, -math.pi / 4)
+ # bivariate isotropic Gaussian
+ kernel = bivariate_isotropic_Gaussian(kernel_size, 1)
+ # bivariate generalized Gaussian
+ kernel = bivariate_generalized_Gaussian(
+ kernel_size, 2, 4, -math.pi / 4, beta=4)
+
+ delta_h, delta_w = mass_center_shift(kernel_size, kernel)
+ print(delta_h, delta_w)
+
+ fig, axs = plt.subplots(nrows=2, ncols=2)
+ # axs.set_axis_off()
+ ax = axs[0][0]
+ im = ax.matshow(kernel, cmap='jet', origin='upper')
+ fig.colorbar(im, ax=ax)
+
+ # image
+ ax = axs[0][1]
+ kernel_vis = kernel - np.min(kernel)
+ kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
+ ax.imshow(kernel_vis, interpolation='nearest')
+
+ _, xx, yy = mesh_grid(kernel_size)
+ # contour
+ ax = axs[1][0]
+ CS = ax.contour(xx, yy, kernel, origin='upper')
+ ax.clabel(CS, inline=1, fontsize=3)
+
+ # contourf
+ ax = axs[1][1]
+ kernel = kernel / np.max(kernel)
+ p = ax.contourf(
+ xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
+ fig.colorbar(p)
+
+ plt.show()
+
+
+def show_plateau_kernel():
+ import matplotlib.pyplot as plt
+ kernel_size = 21
+
+ kernel = bivariate_plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
+ kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5)
+ kernel_gau = bivariate_generalized_Gaussian(
+ kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
+ delta_h, delta_w = mass_center_shift(kernel_size, kernel)
+ print(delta_h, delta_w)
+
+ # kernel_slice = kernel[10, :]
+ # kernel_gau_slice = kernel_gau[10, :]
+ # kernel_norm_slice = kernel_norm[10, :]
+ # fig, ax = plt.subplots()
+ # t = list(range(1, 22))
+
+ # ax.plot(t, kernel_gau_slice)
+ # ax.plot(t, kernel_slice)
+ # ax.plot(t, kernel_norm_slice)
+
+ # t = np.arange(0, 10, 0.1)
+ # y = np.exp(-0.5 * t)
+ # y2 = np.reciprocal(1 + t)
+ # print(t.shape)
+ # print(y.shape)
+ # ax.plot(t, y)
+ # ax.plot(t, y2)
+ # plt.show()
+
+ fig, axs = plt.subplots(nrows=2, ncols=2)
+ # axs.set_axis_off()
+ ax = axs[0][0]
+ im = ax.matshow(kernel, cmap='jet', origin='upper')
+ fig.colorbar(im, ax=ax)
+
+ # image
+ ax = axs[0][1]
+ kernel_vis = kernel - np.min(kernel)
+ kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
+ ax.imshow(kernel_vis, interpolation='nearest')
+
+ _, xx, yy = mesh_grid(kernel_size)
+ # contour
+ ax = axs[1][0]
+ CS = ax.contour(xx, yy, kernel, origin='upper')
+ ax.clabel(CS, inline=1, fontsize=3)
+
+ # contourf
+ ax = axs[1][1]
+ kernel = kernel / np.max(kernel)
+ p = ax.contourf(
+ xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
+ fig.colorbar(p)
+
+ plt.show()
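A short sketch of sampling one mixed kernel and applying it, mirroring the `cv2.filter2D` calls in the datasets above (the parameter values here are illustrative, not the training config):

```python
import math

import cv2
import numpy as np

from basicsr.data.gaussian_kernels import random_mixed_kernels

kernel = random_mixed_kernels(
    kernel_list=['iso', 'aniso', 'generalized'],
    kernel_prob=[0.4, 0.4, 0.2],
    kernel_size=41,                        # must be odd
    sigma_x_range=[0.1, 10],
    sigma_y_range=[0.1, 10],
    rotation_range=[-math.pi, math.pi],
    beta_range=[0.5, 8],
    noise_range=None)

img = np.random.rand(512, 512, 3).astype(np.float32)
blurred = cv2.filter2D(img, -1, kernel)    # same call as in __getitem__
print(kernel.shape, round(float(kernel.sum()), 6))  # (41, 41) 1.0
```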
diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..72e7d02650c0ce9c6c949314440d1d02dae9ec39
--- /dev/null
+++ b/basicsr/data/paired_image_dataset.py
@@ -0,0 +1,101 @@
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class PairedImageDataset(data.Dataset):
+ """Paired image dataset for image restoration.
+
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
+ GT image pairs.
+
+ There are three modes:
+ 1. 'lmdb': Use lmdb files.
+ If opt['io_backend']['type'] == 'lmdb'.
+ 2. 'meta_info_file': Use meta information file to generate paths.
+ If opt['io_backend']['type'] != 'lmdb' and opt['meta_info_file'] is not None.
+ 3. 'folder': Scan folders to generate paths.
+ The rest.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ meta_info_file (str): Path for meta information file.
+ io_backend (dict): IO backend type and other kwarg.
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Default: '{}'.
+ gt_size (int): Cropped patch size for gt patches.
+ use_flip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h
+ and w for implementation).
+
+ scale (int): Scale, which will be added automatically.
+ phase (str): 'train' or 'val'.
+ """
+
+ def __init__(self, opt):
+ super(PairedImageDataset, self).__init__()
+ self.opt = opt
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.mean = opt['mean'] if 'mean' in opt else None
+ self.std = opt['std'] if 'std' in opt else None
+
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
+ if 'filename_tmpl' in opt:
+ self.filename_tmpl = opt['filename_tmpl']
+ else:
+ self.filename_tmpl = '{}'
+
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+ self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
+ elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
+ self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
+ self.opt['meta_info_file'], self.filename_tmpl)
+ else:
+ self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ scale = self.opt['scale']
+
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+ # image range: [0, 1], float32.
+ gt_path = self.paths[index]['gt_path']
+ img_bytes = self.file_client.get(gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+ lq_path = self.paths[index]['lq_path']
+ img_bytes = self.file_client.get(lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+
+ # augmentation for training
+ if self.opt['phase'] == 'train':
+ gt_size = self.opt['gt_size']
+ # random crop
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
+ # flip, rotation
+ img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])
+
+ # TODO: color space transform
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+ # normalize
+ if self.mean is not None or self.std is not None:
+ normalize(img_lq, self.mean, self.std, inplace=True)
+ normalize(img_gt, self.mean, self.std, inplace=True)
+
+ return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
+
+ def __len__(self):
+ return len(self.paths)
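A hypothetical 'folder'-mode configuration for this dataset (the dataroot paths are placeholders; key names follow the lookups above):

```python
from basicsr.data.paired_image_dataset import PairedImageDataset

opt = {
    'dataroot_gt': 'datasets/toy/HR',      # placeholder
    'dataroot_lq': 'datasets/toy/LR_x4',   # placeholder
    'io_backend': {'type': 'disk'},
    'scale': 4,
    'phase': 'train',
    'gt_size': 128,                        # LQ patch becomes 128 // 4 = 32
    'use_flip': True,
    'use_rot': True,
}

sample = PairedImageDataset(opt)[0]
print(sample['lq'].shape, sample['gt'].shape)  # [3, 32, 32], [3, 128, 128]
```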
diff --git a/basicsr/data/prefetch_dataloader.py b/basicsr/data/prefetch_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce127790dc9e4f7014604ba818fa764c3f695a15
--- /dev/null
+++ b/basicsr/data/prefetch_dataloader.py
@@ -0,0 +1,125 @@
+import queue as Queue
+import threading
+import torch
+from torch.utils.data import DataLoader
+
+
+class PrefetchGenerator(threading.Thread):
+ """A general prefetch generator.
+
+ Ref:
+ https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
+
+ Args:
+ generator: Python generator.
+ num_prefetch_queue (int): Number of prefetch queue.
+ """
+
+ def __init__(self, generator, num_prefetch_queue):
+ threading.Thread.__init__(self)
+ self.queue = Queue.Queue(num_prefetch_queue)
+ self.generator = generator
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def __next__(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __iter__(self):
+ return self
+
+
+class PrefetchDataLoader(DataLoader):
+ """Prefetch version of dataloader.
+
+ Ref:
+ https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
+
+ TODO:
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
+ ddp.
+
+ Args:
+ num_prefetch_queue (int): Number of prefetch queue.
+ kwargs (dict): Other arguments for dataloader.
+ """
+
+ def __init__(self, num_prefetch_queue, **kwargs):
+ self.num_prefetch_queue = num_prefetch_queue
+ super(PrefetchDataLoader, self).__init__(**kwargs)
+
+ def __iter__(self):
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
+
+
+class CPUPrefetcher():
+ """CPU prefetcher.
+
+ Args:
+ loader: Dataloader.
+ """
+
+ def __init__(self, loader):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+
+ def next(self):
+ try:
+ return next(self.loader)
+ except StopIteration:
+ return None
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+
+
+class CUDAPrefetcher():
+ """CUDA prefetcher.
+
+ Ref:
+ https://github.com/NVIDIA/apex/issues/304#
+
+ It may consume more GPU memory.
+
+ Args:
+ loader: Dataloader.
+ opt (dict): Options.
+ """
+
+ def __init__(self, loader, opt):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+ self.opt = opt
+ self.stream = torch.cuda.Stream()
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+ self.preload()
+
+ def preload(self):
+ try:
+ self.batch = next(self.loader) # self.batch is a dict
+ except StopIteration:
+ self.batch = None
+ return None
+ # put tensors to gpu
+ with torch.cuda.stream(self.stream):
+ for k, v in self.batch.items():
+ if torch.is_tensor(v):
+ self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
+
+ def next(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ self.preload()
+ return batch
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+ self.preload()
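A sketch of the consumption pattern the CUDA prefetcher expects: batches must be dicts of tensors (it iterates `batch.items()`), and a CUDA device must be available for `torch.cuda.Stream()`. The toy dataset and the loop body are stand-ins:

```python
import torch
from torch.utils.data import DataLoader, Dataset

from basicsr.data.prefetch_dataloader import CUDAPrefetcher


class ToyDictDataset(Dataset):
    """Stand-in dataset that yields dict samples, as CUDAPrefetcher requires."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return {'lq': torch.randn(3, 8, 8)}


loader = DataLoader(ToyDictDataset(), batch_size=16)
prefetcher = CUDAPrefetcher(loader, opt={'num_gpu': 1})

prefetcher.reset()
batch = prefetcher.next()
while batch is not None:
    _ = batch['lq'] * 2   # stand-in for the training step
    batch = prefetcher.next()
```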
diff --git a/basicsr/data/transforms.py b/basicsr/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccf983540d4bca99fece9e0b41f7c1b7d3513db0
--- /dev/null
+++ b/basicsr/data/transforms.py
@@ -0,0 +1,165 @@
+import cv2
+import random
+
+
+def mod_crop(img, scale):
+ """Mod crop images, used during testing.
+
+ Args:
+ img (ndarray): Input image.
+ scale (int): Scale factor.
+
+ Returns:
+ ndarray: Result image.
+ """
+ img = img.copy()
+ if img.ndim in (2, 3):
+ h, w = img.shape[0], img.shape[1]
+ h_remainder, w_remainder = h % scale, w % scale
+ img = img[:h - h_remainder, :w - w_remainder, ...]
+ else:
+ raise ValueError(f'Wrong img ndim: {img.ndim}.')
+ return img
+
+
+def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
+ """Paired random crop.
+
+ It crops lists of lq and gt images with corresponding locations.
+
+ Args:
+ img_gts (list[ndarray] | ndarray): GT images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ gt_patch_size (int): GT patch size.
+ scale (int): Scale factor.
+ gt_path (str): Path to ground-truth.
+
+ Returns:
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
+ only have one element, just return ndarray.
+ """
+
+ if not isinstance(img_gts, list):
+ img_gts = [img_gts]
+ if not isinstance(img_lqs, list):
+ img_lqs = [img_lqs]
+
+ h_lq, w_lq, _ = img_lqs[0].shape
+ h_gt, w_gt, _ = img_gts[0].shape
+ lq_patch_size = gt_patch_size // scale
+
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
+ raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+ f'multiplication of LQ ({h_lq}, {w_lq}).')
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+ f'({lq_patch_size}, {lq_patch_size}). '
+ f'Please remove {gt_path}.')
+
+ # randomly choose top and left coordinates for lq patch
+ top = random.randint(0, h_lq - lq_patch_size)
+ left = random.randint(0, w_lq - lq_patch_size)
+
+ # crop lq patch
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
+
+ # crop corresponding gt patch
+ top_gt, left_gt = int(top * scale), int(left * scale)
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
+ if len(img_gts) == 1:
+ img_gts = img_gts[0]
+ if len(img_lqs) == 1:
+ img_lqs = img_lqs[0]
+ return img_gts, img_lqs
+
+
+def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+
+ We use vertical flip and transpose for rotation implementation.
+ All the images in the list use the same augmentation.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
+ is an ndarray, it will be transformed to a list.
+ hflip (bool): Horizontal flip. Default: True.
+ rotation (bool): Rotation. Default: True.
+ flows (list[ndarray]): Flows to be augmented. If the input is an
+ ndarray, it will be transformed to a list.
+ Dimension is (h, w, 2). Default: None.
+ return_status (bool): Return the status of flip and rotation.
+ Default: False.
+
+ Returns:
+ list[ndarray] | ndarray: Augmented images and flows. If returned
+ results only have one element, just return ndarray.
+
+ """
+ hflip = hflip and random.random() < 0.5
+ vflip = rotation and random.random() < 0.5
+ rot90 = rotation and random.random() < 0.5
+
+ def _augment(img):
+ if hflip: # horizontal
+ cv2.flip(img, 1, img)
+ if vflip: # vertical
+ cv2.flip(img, 0, img)
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ def _augment_flow(flow):
+ if hflip: # horizontal
+ cv2.flip(flow, 1, flow)
+ flow[:, :, 0] *= -1
+ if vflip: # vertical
+ cv2.flip(flow, 0, flow)
+ flow[:, :, 1] *= -1
+ if rot90:
+ flow = flow.transpose(1, 0, 2)
+ flow = flow[:, :, [1, 0]]
+ return flow
+
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ imgs = [_augment(img) for img in imgs]
+ if len(imgs) == 1:
+ imgs = imgs[0]
+
+ if flows is not None:
+ if not isinstance(flows, list):
+ flows = [flows]
+ flows = [_augment_flow(flow) for flow in flows]
+ if len(flows) == 1:
+ flows = flows[0]
+ return imgs, flows
+ else:
+ if return_status:
+ return imgs, (hflip, vflip, rot90)
+ else:
+ return imgs
+
+
+def img_rotate(img, angle, center=None, scale=1.0):
+ """Rotate image.
+
+ Args:
+ img (ndarray): Image to be rotated.
+ angle (float): Rotation angle in degrees. Positive values mean
+ counter-clockwise rotation.
+ center (tuple[int]): Rotation center. If the center is None,
+ initialize it as the center of the image. Default: None.
+ scale (float): Isotropic scale factor. Default: 1.0.
+ """
+ (h, w) = img.shape[:2]
+
+ if center is None:
+ center = (w // 2, h // 2)
+
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
+ rotated_img = cv2.warpAffine(img, matrix, (w, h))
+ return rotated_img
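+
+
+# A minimal usage sketch: the output canvas keeps the input (w, h), so content
+# rotated past the borders is clipped; a scale < 1 can reduce that clipping:
+#
+#     rotated = img_rotate(img, angle=30.0)            # 30 deg counter-clockwise
+#     rotated_small = img_rotate(img, 30.0, scale=0.8)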
diff --git a/basicsr/losses/__init__.py b/basicsr/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5432575fb782a1e059938a5068c9526183ce5853
--- /dev/null
+++ b/basicsr/losses/__init__.py
@@ -0,0 +1,26 @@
+from copy import deepcopy
+
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import LOSS_REGISTRY
+from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
+ gradient_penalty_loss, r1_penalty)
+
+__all__ = [
+ 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
+ 'r1_penalty', 'g_path_regularize'
+]
+
+
+def build_loss(opt):
+ """Build loss from options.
+
+ Args:
+        opt (dict): Configuration. It must contain:
+            type (str): Loss type.
+ """
+ opt = deepcopy(opt)
+ loss_type = opt.pop('type')
+ loss = LOSS_REGISTRY.get(loss_type)(**opt)
+ logger = get_root_logger()
+ logger.info(f'Loss [{loss.__class__.__name__}] is created.')
+ return loss
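+
+
+# A minimal usage sketch of build_loss (the L1Loss arguments come from its
+# constructor in basicsr/losses/losses.py):
+#
+#     cri_pix = build_loss({'type': 'L1Loss', 'loss_weight': 1.0,
+#                           'reduction': 'mean'})
+#     l_pix = cri_pix(pred, target)  # pred/target: (N, C, H, W) tensors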
diff --git a/basicsr/losses/loss_util.py b/basicsr/losses/loss_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dcfc85f4479ea7d0773ce33c91870cf36b392f5
--- /dev/null
+++ b/basicsr/losses/loss_util.py
@@ -0,0 +1,95 @@
+import functools
+from torch.nn import functional as F
+
+
+def reduce_loss(loss, reduction):
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are 'none', 'mean' and 'sum'.
+
+ Returns:
+ Tensor: Reduced loss tensor.
+ """
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean:1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ else:
+ return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean'):
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Tensor): Element-wise weights. Default: None.
+ reduction (str): Same as built-in losses of PyTorch. Options are
+ 'none', 'mean' and 'sum'. Default: 'mean'.
+
+ Returns:
+ Tensor: Loss values.
+ """
+ # if weight is specified, apply element-wise weight
+ if weight is not None:
+ assert weight.dim() == loss.dim()
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+ loss = loss * weight
+
+ # if weight is not specified or reduction is sum, just reduce the loss
+ if weight is None or reduction == 'sum':
+ loss = reduce_loss(loss, reduction)
+ # if reduction is mean, then compute mean over weight region
+ elif reduction == 'mean':
+ if weight.size(1) > 1:
+ weight = weight.sum()
+ else:
+ weight = weight.sum() * loss.size(1)
+ loss = loss.sum() / weight
+
+ return loss
+
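+# Worked example of the 'mean' reduction over the weighted region: with
+# loss = [[[[1.]], [[3.]]]] and weight = [[[[1.]], [[0.]]]] (both of shape
+# (1, 2, 1, 1)), the element-wise product is [1., 0.] and the divisor is
+# weight.sum() = 1, so weight_reduce_loss returns 1.0 instead of the plain
+# mean 2.0.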
+
+def weighted_loss(loss_func):
+ """Create a weighted version of a given loss function.
+
+ To use this decorator, the loss function must have the signature like
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
+ element-wise loss without any reduction. This decorator will add weight
+ and reduction arguments to the function. The decorated function will have
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
+ **kwargs)`.
+
+ :Example:
+
+ >>> import torch
+ >>> @weighted_loss
+ >>> def l1_loss(pred, target):
+ >>> return (pred - target).abs()
+
+ >>> pred = torch.Tensor([0, 2, 3])
+ >>> target = torch.Tensor([1, 1, 1])
+ >>> weight = torch.Tensor([1, 0, 1])
+
+ >>> l1_loss(pred, target)
+ tensor(1.3333)
+ >>> l1_loss(pred, target, weight)
+ tensor(1.5000)
+ >>> l1_loss(pred, target, reduction='none')
+ tensor([1., 1., 2.])
+ >>> l1_loss(pred, target, weight, reduction='sum')
+ tensor(3.)
+ """
+
+ @functools.wraps(loss_func)
+ def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
+ # get element-wise loss
+ loss = loss_func(pred, target, **kwargs)
+ loss = weight_reduce_loss(loss, weight, reduction)
+ return loss
+
+ return wrapper
diff --git a/basicsr/losses/losses.py b/basicsr/losses/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..efb965afe533e6df4245c2d4ec8787926424d4b6
--- /dev/null
+++ b/basicsr/losses/losses.py
@@ -0,0 +1,455 @@
+import math
+import lpips
+import torch
+from torch import autograd as autograd
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.archs.vgg_arch import VGGFeatureExtractor
+from basicsr.utils.registry import LOSS_REGISTRY
+from .loss_util import weighted_loss
+
+_reduction_modes = ['none', 'mean', 'sum']
+
+
+@weighted_loss
+def l1_loss(pred, target):
+ return F.l1_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def mse_loss(pred, target):
+ return F.mse_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def charbonnier_loss(pred, target, eps=1e-12):
+ return torch.sqrt((pred - target)**2 + eps)
+
+
+@LOSS_REGISTRY.register()
+class L1Loss(nn.Module):
+ """L1 (mean absolute error, MAE) loss.
+
+ Args:
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean'):
+ super(L1Loss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class MSELoss(nn.Module):
+ """MSE (L2) loss.
+
+ Args:
+ loss_weight (float): Loss weight for MSE loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean'):
+ super(MSELoss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class CharbonnierLoss(nn.Module):
+ """Charbonnier loss (one variant of Robust L1Loss, a differentiable
+ variant of L1Loss).
+
+ Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
+ Super-Resolution".
+
+ Args:
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ eps (float): A value used to control the curvature near zero.
+ Default: 1e-12.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
+ super(CharbonnierLoss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+ self.eps = eps
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class WeightedTVLoss(L1Loss):
+ """Weighted TV loss.
+
+ Args:
+ loss_weight (float): Loss weight. Default: 1.0.
+ """
+
+ def __init__(self, loss_weight=1.0):
+ super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
+
+    def forward(self, pred, weight=None):
+        # guard against weight=None; slicing None would raise a TypeError
+        if weight is None:
+            y_weight, x_weight = None, None
+        else:
+            y_weight = weight[:, :, :-1, :]
+            x_weight = weight[:, :, :, :-1]
+
+        y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
+        x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
+
+        loss = x_diff + y_diff
+
+        return loss
+
+
+@LOSS_REGISTRY.register()
+class PerceptualLoss(nn.Module):
+ """Perceptual loss with commonly used style loss.
+
+ Args:
+ layer_weights (dict): The weight for each layer of vgg feature.
+ Here is an example: {'conv5_4': 1.}, which means the conv5_4
+ feature layer (before relu5_4) will be extracted with weight
+            1.0 in calculating losses.
+ vgg_type (str): The type of vgg network used as feature extractor.
+ Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image in vgg.
+ Default: True.
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+ Default: False.
+        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
+            loss will be calculated and multiplied by this weight.
+            Default: 1.0.
+        style_weight (float): If `style_weight > 0`, the style loss will be
+            calculated and multiplied by this weight.
+            Default: 0.
+ criterion (str): Criterion used for perceptual loss. Default: 'l1'.
+ """
+
+ def __init__(self,
+ layer_weights,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ perceptual_weight=1.0,
+ style_weight=0.,
+ criterion='l1'):
+ super(PerceptualLoss, self).__init__()
+ self.perceptual_weight = perceptual_weight
+ self.style_weight = style_weight
+ self.layer_weights = layer_weights
+ self.vgg = VGGFeatureExtractor(
+ layer_name_list=list(layer_weights.keys()),
+ vgg_type=vgg_type,
+ use_input_norm=use_input_norm,
+ range_norm=range_norm)
+
+ self.criterion_type = criterion
+ if self.criterion_type == 'l1':
+ self.criterion = torch.nn.L1Loss()
+ elif self.criterion_type == 'l2':
+            # torch.nn has no L2loss; MSELoss is the L2 criterion
+            self.criterion = torch.nn.MSELoss()
+ elif self.criterion_type == 'mse':
+ self.criterion = torch.nn.MSELoss(reduction='mean')
+ elif self.criterion_type == 'fro':
+ self.criterion = None
+ else:
+ raise NotImplementedError(f'{criterion} criterion has not been supported.')
+
+ def forward(self, x, gt):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+ gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ # extract vgg features
+ x_features = self.vgg(x)
+ gt_features = self.vgg(gt.detach())
+
+ # calculate perceptual loss
+ if self.perceptual_weight > 0:
+ percep_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == 'fro':
+ percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
+ else:
+ percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
+ percep_loss *= self.perceptual_weight
+ else:
+ percep_loss = None
+
+ # calculate style loss
+ if self.style_weight > 0:
+ style_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == 'fro':
+ style_loss += torch.norm(
+ self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
+ else:
+ style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
+ gt_features[k])) * self.layer_weights[k]
+ style_loss *= self.style_weight
+ else:
+ style_loss = None
+
+ return percep_loss, style_loss
+
+ def _gram_mat(self, x):
+ """Calculate Gram matrix.
+
+ Args:
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
+
+ Returns:
+ torch.Tensor: Gram matrix.
+ """
+ n, c, h, w = x.size()
+ features = x.view(n, c, w * h)
+ features_t = features.transpose(1, 2)
+ gram = features.bmm(features_t) / (c * h * w)
+ return gram
+
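+# A minimal usage sketch (assumes the pretrained VGG19 weights are available to
+# VGGFeatureExtractor; x and gt are (N, 3, H, W) tensors in [0, 1]):
+#
+#     percep = PerceptualLoss(layer_weights={'conv5_4': 1.0}, vgg_type='vgg19',
+#                             perceptual_weight=1.0, style_weight=0.)
+#     p_loss, s_loss = percep(x, gt)  # s_loss is None since style_weight == 0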
+
+@LOSS_REGISTRY.register()
+class LPIPSLoss(nn.Module):
+ def __init__(self,
+ loss_weight=1.0,
+ use_input_norm=True,
+ range_norm=False,):
+ super(LPIPSLoss, self).__init__()
+ self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
+ self.loss_weight = loss_weight
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, pred, target):
+ if self.range_norm:
+ pred = (pred + 1) / 2
+ target = (target + 1) / 2
+ if self.use_input_norm:
+ pred = (pred - self.mean) / self.std
+ target = (target - self.mean) / self.std
+ lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
+ return self.loss_weight * lpips_loss.mean()
+
+
+@LOSS_REGISTRY.register()
+class GANLoss(nn.Module):
+ """Define GAN loss.
+
+ Args:
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
+ real_label_val (float): The value for real label. Default: 1.0.
+ fake_label_val (float): The value for fake label. Default: 0.0.
+ loss_weight (float): Loss weight. Default: 1.0.
+            Note that loss_weight is only for generators; it is always 1.0
+            for discriminators.
+ """
+
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
+ super(GANLoss, self).__init__()
+ self.gan_type = gan_type
+ self.loss_weight = loss_weight
+ self.real_label_val = real_label_val
+ self.fake_label_val = fake_label_val
+
+ if self.gan_type == 'vanilla':
+ self.loss = nn.BCEWithLogitsLoss()
+ elif self.gan_type == 'lsgan':
+ self.loss = nn.MSELoss()
+ elif self.gan_type == 'wgan':
+ self.loss = self._wgan_loss
+ elif self.gan_type == 'wgan_softplus':
+ self.loss = self._wgan_softplus_loss
+ elif self.gan_type == 'hinge':
+ self.loss = nn.ReLU()
+ else:
+ raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
+
+ def _wgan_loss(self, input, target):
+ """wgan loss.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+
+ Returns:
+ Tensor: wgan loss.
+ """
+ return -input.mean() if target else input.mean()
+
+ def _wgan_softplus_loss(self, input, target):
+ """wgan loss with soft plus. softplus is a smooth approximation to the
+ ReLU function.
+
+ In StyleGAN2, it is called:
+ Logistic loss for discriminator;
+ Non-saturating loss for generator.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+
+ Returns:
+ Tensor: wgan loss.
+ """
+ return F.softplus(-input).mean() if target else F.softplus(input).mean()
+
+ def get_target_label(self, input, target_is_real):
+ """Get target label.
+
+ Args:
+ input (Tensor): Input tensor.
+ target_is_real (bool): Whether the target is real or fake.
+
+ Returns:
+ (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
+ return Tensor.
+ """
+
+ if self.gan_type in ['wgan', 'wgan_softplus']:
+ return target_is_real
+ target_val = (self.real_label_val if target_is_real else self.fake_label_val)
+ return input.new_ones(input.size()) * target_val
+
+ def forward(self, input, target_is_real, is_disc=False):
+ """
+ Args:
+ input (Tensor): The input for the loss module, i.e., the network
+ prediction.
+            target_is_real (bool): Whether the target is real or fake.
+            is_disc (bool): Whether the loss is computed for a discriminator.
+ Default: False.
+
+ Returns:
+ Tensor: GAN loss value.
+ """
+ if self.gan_type == 'hinge':
+ if is_disc: # for discriminators in hinge-gan
+ input = -input if target_is_real else input
+ loss = self.loss(1 + input).mean()
+ else: # for generators in hinge-gan
+ loss = -input.mean()
+ else: # other gan types
+ target_label = self.get_target_label(input, target_is_real)
+ loss = self.loss(input, target_label)
+
+ # loss_weight is always 1.0 for discriminators
+ return loss if is_disc else loss * self.loss_weight
+
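+# A minimal usage sketch of GANLoss (tensor names are illustrative):
+#
+#     cri_gan = GANLoss(gan_type='vanilla', loss_weight=0.1)
+#     l_g = cri_gan(fake_pred, True, is_disc=False)      # generator: scaled by loss_weight
+#     l_d_real = cri_gan(real_pred, True, is_disc=True)  # discriminator: weight ignored
+#     l_d_fake = cri_gan(fake_pred, False, is_disc=True)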
+
+def r1_penalty(real_pred, real_img):
+ """R1 regularization for discriminator. The core idea is to
+ penalize the gradient on real data alone: when the
+ generator distribution produces the true data distribution
+ and the discriminator is equal to 0 on the data manifold, the
+ gradient penalty ensures that the discriminator cannot create
+ a non-zero gradient orthogonal to the data manifold without
+ suffering a loss in the GAN game.
+
+ Ref:
+ Eq. 9 in Which training methods for GANs do actually converge.
+ """
+ grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
+ grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
+ return grad_penalty
+
+
+def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
+ noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
+ grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
+ path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
+
+ path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
+
+ path_penalty = (path_lengths - path_mean).pow(2).mean()
+
+ return path_penalty, path_lengths.detach().mean(), path_mean.detach()
+
+
+def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
+ """Calculate gradient penalty for wgan-gp.
+
+ Args:
+ discriminator (nn.Module): Network for the discriminator.
+ real_data (Tensor): Real input data.
+ fake_data (Tensor): Fake input data.
+ weight (Tensor): Weight tensor. Default: None.
+
+ Returns:
+ Tensor: A tensor for gradient penalty.
+ """
+
+ batch_size = real_data.size(0)
+ alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
+
+ # interpolate between real_data and fake_data
+ interpolates = alpha * real_data + (1. - alpha) * fake_data
+ interpolates = autograd.Variable(interpolates, requires_grad=True)
+
+ disc_interpolates = discriminator(interpolates)
+ gradients = autograd.grad(
+ outputs=disc_interpolates,
+ inputs=interpolates,
+ grad_outputs=torch.ones_like(disc_interpolates),
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True)[0]
+
+ if weight is not None:
+ gradients = gradients * weight
+
+ gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
+ if weight is not None:
+ gradients_penalty /= torch.mean(weight)
+
+ return gradients_penalty
diff --git a/basicsr/metrics/__init__.py b/basicsr/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fd5f6f6875cf2cdfdb44a26ee3610e31cf7bc68
--- /dev/null
+++ b/basicsr/metrics/__init__.py
@@ -0,0 +1,19 @@
+from copy import deepcopy
+
+from basicsr.utils.registry import METRIC_REGISTRY
+from .psnr_ssim import calculate_psnr, calculate_ssim
+
+__all__ = ['calculate_psnr', 'calculate_ssim']
+
+
+def calculate_metric(data, opt):
+ """Calculate metric from data and options.
+
+    Args:
+        data (dict): Input data for the metric (e.g. img1/img2 arrays).
+        opt (dict): Configuration. It must contain:
+            type (str): Metric type.
+ """
+ opt = deepcopy(opt)
+ metric_type = opt.pop('type')
+ metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
+ return metric
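+
+
+# A minimal usage sketch (matches how the models call it, with img1/img2 as
+# HWC arrays in [0, 255]):
+#
+#     psnr = calculate_metric(dict(img1=sr_img, img2=gt_img),
+#                             {'type': 'calculate_psnr', 'crop_border': 0})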
diff --git a/basicsr/metrics/metric_util.py b/basicsr/metrics/metric_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..865158efbca8f2dea4667c61c4739bdc39d9ecda
--- /dev/null
+++ b/basicsr/metrics/metric_util.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+from basicsr.utils.matlab_functions import bgr2ycbcr
+
+
+def reorder_image(img, input_order='HWC'):
+ """Reorder images to 'HWC' order.
+
+ If the input_order is (h, w), return (h, w, 1);
+ If the input_order is (c, h, w), return (h, w, c);
+ If the input_order is (h, w, c), return as it is.
+
+ Args:
+ img (ndarray): Input image.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ If the input image shape is (h, w), input_order will not have
+ effects. Default: 'HWC'.
+
+ Returns:
+ ndarray: reordered image.
+ """
+
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
+ if len(img.shape) == 2:
+ img = img[..., None]
+ if input_order == 'CHW':
+ img = img.transpose(1, 2, 0)
+ return img
+
+
+def to_y_channel(img):
+ """Change to Y channel of YCbCr.
+
+ Args:
+ img (ndarray): Images with range [0, 255].
+
+ Returns:
+        (ndarray): Images with range [0, 255] (float type) without rounding.
+ """
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 3 and img.shape[2] == 3:
+ img = bgr2ycbcr(img, y_only=True)
+ img = img[..., None]
+ return img * 255.
diff --git a/basicsr/metrics/psnr_ssim.py b/basicsr/metrics/psnr_ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..325558f132c1712e95e79ec8b86faec1cd1bf063
--- /dev/null
+++ b/basicsr/metrics/psnr_ssim.py
@@ -0,0 +1,128 @@
+import cv2
+import numpy as np
+
+from basicsr.metrics.metric_util import reorder_image, to_y_channel
+from basicsr.utils.registry import METRIC_REGISTRY
+
+
+@METRIC_REGISTRY.register()
+def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+ Args:
+ img1 (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the PSNR calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: psnr result.
+ """
+
+    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
+ img1 = reorder_image(img1, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ if crop_border != 0:
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img1 = to_y_channel(img1)
+ img2 = to_y_channel(img2)
+
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20. * np.log10(255. / np.sqrt(mse))
+
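+# Sanity checks for calculate_psnr: identical inputs return float('inf'); an MSE
+# of 1.0 gives 20 * log10(255) ≈ 48.13 dB, and an MSE of 100.0 gives ≈ 28.13 dB.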
+
+def _ssim(img1, img2):
+ """Calculate SSIM (structural similarity) for one channel images.
+
+ It is called by func:`calculate_ssim`.
+
+ Args:
+ img1 (ndarray): Images with range [0, 255] with order 'HWC'.
+ img2 (ndarray): Images with range [0, 255] with order 'HWC'.
+
+ Returns:
+ float: ssim result.
+ """
+
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+@METRIC_REGISTRY.register()
+def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
+ """Calculate SSIM (structural similarity).
+
+ Ref:
+ Image quality assessment: From error visibility to structural similarity
+
+ The results are the same as that of the official released MATLAB code in
+ https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+ For three-channel images, SSIM is calculated for each channel and then
+ averaged.
+
+ Args:
+ img1 (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the SSIM calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: ssim result.
+ """
+
+    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
+ img1 = reorder_image(img1, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ if crop_border != 0:
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img1 = to_y_channel(img1)
+ img2 = to_y_channel(img2)
+
+ ssims = []
+ for i in range(img1.shape[2]):
+ ssims.append(_ssim(img1[..., i], img2[..., i]))
+ return np.array(ssims).mean()
diff --git a/basicsr/models/__init__.py b/basicsr/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8e66cd3a726fc18cb54e2fd5eb024d0dcc6f2ff
--- /dev/null
+++ b/basicsr/models/__init__.py
@@ -0,0 +1,30 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import MODEL_REGISTRY
+
+__all__ = ['build_model']
+
+# automatically scan and import model modules for registry
+# scan all the files under the 'models' folder and collect files ending with
+# '_model.py'
+model_folder = osp.dirname(osp.abspath(__file__))
+model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
+# import all the model modules
+_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
+
+
+def build_model(opt):
+ """Build model from options.
+
+ Args:
+        opt (dict): Configuration. It must contain:
+ model_type (str): Model type.
+ """
+ opt = deepcopy(opt)
+ model = MODEL_REGISTRY.get(opt['model_type'])(opt)
+ logger = get_root_logger()
+ logger.info(f'Model [{model.__class__.__name__}] is created.')
+ return model
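+
+
+# A minimal, hypothetical sketch: a real opt comes from the YAML config and
+# carries many more keys (e.g. num_gpu, is_train and dist used by BaseModel):
+#
+#     opt = {'model_type': 'CodeFormerModel', 'num_gpu': 1, 'is_train': False}
+#     model = build_model(opt)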
diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4de5ea3806dd239ff57cc267a51880dfc373836
--- /dev/null
+++ b/basicsr/models/base_model.py
@@ -0,0 +1,322 @@
+import logging
+import os
+import torch
+from collections import OrderedDict
+from copy import deepcopy
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from basicsr.models import lr_scheduler as lr_scheduler
+from basicsr.utils.dist_util import master_only
+
+logger = logging.getLogger('basicsr')
+
+
+class BaseModel():
+ """Base model."""
+
+ def __init__(self, opt):
+ self.opt = opt
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+ self.is_train = opt['is_train']
+ self.schedulers = []
+ self.optimizers = []
+
+ def feed_data(self, data):
+ pass
+
+ def optimize_parameters(self):
+ pass
+
+ def get_current_visuals(self):
+ pass
+
+ def save(self, epoch, current_iter):
+ """Save networks and training state."""
+ pass
+
+ def validation(self, dataloader, current_iter, tb_logger, save_img=False):
+ """Validation function.
+
+ Args:
+ dataloader (torch.utils.data.DataLoader): Validation dataloader.
+ current_iter (int): Current iteration.
+ tb_logger (tensorboard logger): Tensorboard logger.
+ save_img (bool): Whether to save images. Default: False.
+ """
+ if self.opt['dist']:
+ self.dist_validation(dataloader, current_iter, tb_logger, save_img)
+ else:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def model_ema(self, decay=0.999):
+ net_g = self.get_bare_model(self.net_g)
+
+ net_g_params = dict(net_g.named_parameters())
+ net_g_ema_params = dict(self.net_g_ema.named_parameters())
+
+ for k in net_g_ema_params.keys():
+ net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
+
+ def get_current_log(self):
+ return self.log_dict
+
+ def model_to_device(self, net):
+ """Model to device. It also warps models with DistributedDataParallel
+ or DataParallel.
+
+ Args:
+ net (nn.Module)
+ """
+ net = net.to(self.device)
+ if self.opt['dist']:
+ find_unused_parameters = self.opt.get('find_unused_parameters', False)
+ net = DistributedDataParallel(
+ net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
+ elif self.opt['num_gpu'] > 1:
+ net = DataParallel(net)
+ return net
+
+ def get_optimizer(self, optim_type, params, lr, **kwargs):
+ if optim_type == 'Adam':
+ optimizer = torch.optim.Adam(params, lr, **kwargs)
+ else:
+            raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
+ return optimizer
+
+ def setup_schedulers(self):
+ """Set up schedulers."""
+ train_opt = self.opt['train']
+ scheduler_type = train_opt['scheduler'].pop('type')
+ if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
+ for optimizer in self.optimizers:
+ self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
+ elif scheduler_type == 'CosineAnnealingRestartLR':
+ for optimizer in self.optimizers:
+ self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
+ else:
+ raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
+
+ def get_bare_model(self, net):
+ """Get bare model, especially under wrapping with
+ DistributedDataParallel or DataParallel.
+ """
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
+ net = net.module
+ return net
+
+ @master_only
+ def print_network(self, net):
+ """Print the str and parameter number of a network.
+
+ Args:
+ net (nn.Module)
+ """
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
+ net_cls_str = (f'{net.__class__.__name__} - ' f'{net.module.__class__.__name__}')
+ else:
+ net_cls_str = f'{net.__class__.__name__}'
+
+ net = self.get_bare_model(net)
+ net_str = str(net)
+ net_params = sum(map(lambda x: x.numel(), net.parameters()))
+
+ logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
+ logger.info(net_str)
+
+ def _set_lr(self, lr_groups_l):
+ """Set learning rate for warmup.
+
+ Args:
+ lr_groups_l (list): List for lr_groups, each for an optimizer.
+ """
+ for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
+ for param_group, lr in zip(optimizer.param_groups, lr_groups):
+ param_group['lr'] = lr
+
+ def _get_init_lr(self):
+ """Get the initial lr, which is set by the scheduler.
+ """
+ init_lr_groups_l = []
+ for optimizer in self.optimizers:
+ init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
+ return init_lr_groups_l
+
+ def update_learning_rate(self, current_iter, warmup_iter=-1):
+ """Update learning rate.
+
+ Args:
+ current_iter (int): Current iteration.
+ warmup_iter (int): Warmup iter numbers. -1 for no warmup.
+ Default: -1.
+ """
+ if current_iter > 1:
+ for scheduler in self.schedulers:
+ scheduler.step()
+ # set up warm-up learning rate
+ if current_iter < warmup_iter:
+ # get initial lr for each group
+ init_lr_g_l = self._get_init_lr()
+ # modify warming-up learning rates
+ # currently only support linearly warm up
+ warm_up_lr_l = []
+ for init_lr_g in init_lr_g_l:
+ warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
+ # set learning rate
+ self._set_lr(warm_up_lr_l)
+
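+    # Example with illustrative numbers: with warmup_iter=1000 and an initial lr
+    # of 1e-4, the lr at current_iter=250 is 1e-4 * 250 / 1000 = 2.5e-5; once
+    # current_iter reaches warmup_iter, the schedulers alone set the lr.
+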
+ def get_current_learning_rate(self):
+ return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
+
+ @master_only
+ def save_network(self, net, net_label, current_iter, param_key='params'):
+ """Save networks.
+
+ Args:
+ net (nn.Module | list[nn.Module]): Network(s) to be saved.
+ net_label (str): Network label.
+ current_iter (int): Current iter number.
+ param_key (str | list[str]): The parameter key(s) to save network.
+ Default: 'params'.
+ """
+ if current_iter == -1:
+ current_iter = 'latest'
+ save_filename = f'{net_label}_{current_iter}.pth'
+ save_path = os.path.join(self.opt['path']['models'], save_filename)
+
+ net = net if isinstance(net, list) else [net]
+ param_key = param_key if isinstance(param_key, list) else [param_key]
+ assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
+
+ save_dict = {}
+ for net_, param_key_ in zip(net, param_key):
+ net_ = self.get_bare_model(net_)
+ state_dict = net_.state_dict()
+ for key, param in state_dict.items():
+ if key.startswith('module.'): # remove unnecessary 'module.'
+ key = key[7:]
+ state_dict[key] = param.cpu()
+ save_dict[param_key_] = state_dict
+
+ torch.save(save_dict, save_path)
+
+ def _print_different_keys_loading(self, crt_net, load_net, strict=True):
+ """Print keys with differnet name or different size when loading models.
+
+ 1. Print keys with differnet names.
+ 2. If strict=False, print the same key but with different tensor size.
+ It also ignore these keys with different sizes (not load).
+
+ Args:
+ crt_net (torch model): Current network.
+ load_net (dict): Loaded network.
+ strict (bool): Whether strictly loaded. Default: True.
+ """
+ crt_net = self.get_bare_model(crt_net)
+ crt_net = crt_net.state_dict()
+ crt_net_keys = set(crt_net.keys())
+ load_net_keys = set(load_net.keys())
+
+ if crt_net_keys != load_net_keys:
+ logger.warning('Current net - loaded net:')
+ for v in sorted(list(crt_net_keys - load_net_keys)):
+ logger.warning(f' {v}')
+ logger.warning('Loaded net - current net:')
+ for v in sorted(list(load_net_keys - crt_net_keys)):
+ logger.warning(f' {v}')
+
+ # check the size for the same keys
+ if not strict:
+ common_keys = crt_net_keys & load_net_keys
+ for k in common_keys:
+ if crt_net[k].size() != load_net[k].size():
+ logger.warning(f'Size different, ignore [{k}]: crt_net: '
+ f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
+ load_net[k + '.ignore'] = load_net.pop(k)
+
+ def load_network(self, net, load_path, strict=True, param_key='params'):
+ """Load network.
+
+ Args:
+ load_path (str): The path of networks to be loaded.
+ net (nn.Module): Network.
+ strict (bool): Whether strictly loaded.
+ param_key (str): The parameter key of loaded network. If set to
+ None, use the root 'path'.
+ Default: 'params'.
+ """
+ net = self.get_bare_model(net)
+ logger.info(f'Loading {net.__class__.__name__} model from {load_path}.')
+ load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
+ if param_key is not None:
+ if param_key not in load_net and 'params' in load_net:
+ param_key = 'params'
+ logger.info('Loading: params_ema does not exist, use params.')
+ load_net = load_net[param_key]
+ # remove unnecessary 'module.'
+ for k, v in deepcopy(load_net).items():
+ if k.startswith('module.'):
+ load_net[k[7:]] = v
+ load_net.pop(k)
+ self._print_different_keys_loading(net, load_net, strict)
+ net.load_state_dict(load_net, strict=strict)
+
+ @master_only
+ def save_training_state(self, epoch, current_iter):
+ """Save training states during training, which will be used for
+ resuming.
+
+ Args:
+ epoch (int): Current epoch.
+ current_iter (int): Current iteration.
+ """
+ if current_iter != -1:
+ state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
+ for o in self.optimizers:
+ state['optimizers'].append(o.state_dict())
+ for s in self.schedulers:
+ state['schedulers'].append(s.state_dict())
+ save_filename = f'{current_iter}.state'
+ save_path = os.path.join(self.opt['path']['training_states'], save_filename)
+ torch.save(state, save_path)
+
+ def resume_training(self, resume_state):
+ """Reload the optimizers and schedulers for resumed training.
+
+ Args:
+ resume_state (dict): Resume state.
+ """
+ resume_optimizers = resume_state['optimizers']
+ resume_schedulers = resume_state['schedulers']
+ assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
+ assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
+ for i, o in enumerate(resume_optimizers):
+ self.optimizers[i].load_state_dict(o)
+ for i, s in enumerate(resume_schedulers):
+ self.schedulers[i].load_state_dict(s)
+
+ def reduce_loss_dict(self, loss_dict):
+ """reduce loss dict.
+
+ In distributed training, it averages the losses among different GPUs .
+
+ Args:
+ loss_dict (OrderedDict): Loss dict.
+ """
+ with torch.no_grad():
+ if self.opt['dist']:
+ keys = []
+ losses = []
+ for name, value in loss_dict.items():
+ keys.append(name)
+ losses.append(value)
+ losses = torch.stack(losses, 0)
+ torch.distributed.reduce(losses, dst=0)
+ if self.opt['rank'] == 0:
+ losses /= self.opt['world_size']
+ loss_dict = {key: loss for key, loss in zip(keys, losses)}
+
+ log_dict = OrderedDict()
+ for name, value in loss_dict.items():
+ log_dict[name] = value.mean().item()
+
+ return log_dict
diff --git a/basicsr/models/codeformer_idx_model.py b/basicsr/models/codeformer_idx_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..581d7c88c7aac7f177537cec366744b0569054c4
--- /dev/null
+++ b/basicsr/models/codeformer_idx_model.py
@@ -0,0 +1,220 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from basicsr.archs import build_network
+from basicsr.metrics import calculate_metric
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
+import torch.nn.functional as F
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class CodeFormerIdxModel(SRModel):
+ def feed_data(self, data):
+ self.gt = data['gt'].to(self.device)
+ self.input = data['in'].to(self.device)
+ self.b = self.gt.shape[0]
+
+ if 'latent_gt' in data:
+ self.idx_gt = data['latent_gt'].to(self.device)
+ self.idx_gt = self.idx_gt.view(self.b, -1)
+ else:
+ self.idx_gt = None
+
+ def init_training_settings(self):
+ logger = get_root_logger()
+ train_opt = self.opt['train']
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema is used only for testing on one GPU and saving
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
+ self.generate_idx_gt = False
+ elif self.opt.get('network_vqgan', None) is not None:
+ self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
+ self.hq_vqgan_fix.eval()
+ self.generate_idx_gt = True
+ for param in self.hq_vqgan_fix.parameters():
+ param.requires_grad = False
+ else:
+            raise NotImplementedError('Should have network_vqgan config or pre-calculated latent code.')
+
+ logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')
+
+ self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
+ self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
+ self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
+ self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
+
+ self.net_g.train()
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ # optimizer g
+ optim_params_g = []
+ for k, v in self.net_g.named_parameters():
+ if v.requires_grad:
+ optim_params_g.append(v)
+ else:
+ logger = get_root_logger()
+ logger.warning(f'Params {k} will not be optimized.')
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+
+
+ def optimize_parameters(self, current_iter):
+ logger = get_root_logger()
+ # optimize net_g
+ self.optimizer_g.zero_grad()
+
+ if self.generate_idx_gt:
+ x = self.hq_vqgan_fix.encoder(self.gt)
+ _, _, quant_stats = self.hq_vqgan_fix.quantize(x)
+ min_encoding_indices = quant_stats['min_encoding_indices']
+ self.idx_gt = min_encoding_indices.view(self.b, -1)
+
+ if self.hq_feat_loss:
+ # quant_feats
+ quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
+
+ logits, lq_feat = self.net_g(self.input, w=0, code_only=True)
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ # hq_feat_loss
+ if self.hq_feat_loss: # codebook loss
+ l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
+ l_g_total += l_feat_encoder
+ loss_dict['l_feat_encoder'] = l_feat_encoder
+
+ # cross_entropy_loss
+ if self.cross_entropy_loss:
+ # b(hw)n -> bn(hw)
+ cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
+ l_g_total += cross_entropy_loss
+ loss_dict['cross_entropy_loss'] = cross_entropy_loss
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+
+ def test(self):
+ with torch.no_grad():
+ if hasattr(self, 'net_g_ema'):
+ self.net_g_ema.eval()
+ self.output, _, _ = self.net_g_ema(self.input, w=0)
+ else:
+ logger = get_root_logger()
+ logger.warning('Do not have self.net_g_ema, use self.net_g.')
+ self.net_g.eval()
+ self.output, _, _ = self.net_g(self.input, w=0)
+ self.net_g.train()
+
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ if self.opt['rank'] == 0:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ dataset_name = dataloader.dataset.opt['name']
+ with_metrics = self.opt['val'].get('metrics') is not None
+ if with_metrics:
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ pbar = tqdm(total=len(dataloader), unit='image')
+
+ for idx, val_data in enumerate(dataloader):
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+ self.feed_data(val_data)
+ self.test()
+
+ visuals = self.get_current_visuals()
+ sr_img = tensor2img([visuals['result']])
+ if 'gt' in visuals:
+ gt_img = tensor2img([visuals['gt']])
+ del self.gt
+
+ # tentative for out of GPU memory
+            del self.input  # feed_data stores the input as self.input, not self.lq
+ del self.output
+ torch.cuda.empty_cache()
+
+ if save_img:
+ if self.opt['is_train']:
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+ f'{img_name}_{current_iter}.png')
+ else:
+ if self.opt['val']['suffix']:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["name"]}.png')
+ imwrite(sr_img, save_img_path)
+
+ if with_metrics:
+ # calculate metrics
+ for name, opt_ in self.opt['val']['metrics'].items():
+ metric_data = dict(img1=sr_img, img2=gt_img)
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
+ pbar.update(1)
+ pbar.set_description(f'Test {img_name}')
+ pbar.close()
+
+ if with_metrics:
+ for metric in self.metric_results.keys():
+ self.metric_results[metric] /= (idx + 1)
+
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+ log_str = f'Validation {dataset_name}\n'
+ for metric, value in self.metric_results.items():
+ log_str += f'\t # {metric}: {value:.4f}\n'
+ logger = get_root_logger()
+ logger.info(log_str)
+ if tb_logger:
+ for metric, value in self.metric_results.items():
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+
+
+ def get_current_visuals(self):
+ out_dict = OrderedDict()
+ out_dict['gt'] = self.gt.detach().cpu()
+ out_dict['result'] = self.output.detach().cpu()
+ return out_dict
+
+
+ def save(self, epoch, current_iter):
+ if self.ema_decay > 0:
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_training_state(epoch, current_iter)
diff --git a/basicsr/models/codeformer_joint_model.py b/basicsr/models/codeformer_joint_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ca9bb0c40d57c1721646122878d8bf617c4b7e4
--- /dev/null
+++ b/basicsr/models/codeformer_joint_model.py
@@ -0,0 +1,350 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.metrics import calculate_metric
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
+import torch.nn.functional as F
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class CodeFormerJointModel(SRModel):
+ def feed_data(self, data):
+ self.gt = data['gt'].to(self.device)
+ self.input = data['in'].to(self.device)
+ self.input_large_de = data['in_large_de'].to(self.device)
+ self.b = self.gt.shape[0]
+
+ if 'latent_gt' in data:
+ self.idx_gt = data['latent_gt'].to(self.device)
+ self.idx_gt = self.idx_gt.view(self.b, -1)
+ else:
+ self.idx_gt = None
+
+ def init_training_settings(self):
+ logger = get_root_logger()
+ train_opt = self.opt['train']
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema is used only for testing on one GPU and saving
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ if self.opt['datasets']['train'].get('latent_gt_path', None) is not None:
+ self.generate_idx_gt = False
+ elif self.opt.get('network_vqgan', None) is not None:
+ self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
+ self.hq_vqgan_fix.eval()
+ self.generate_idx_gt = True
+ for param in self.hq_vqgan_fix.parameters():
+ param.requires_grad = False
+ else:
+            raise NotImplementedError('Should have network_vqgan config or pre-calculated latent code.')
+
+ logger.info(f'Need to generate latent GT code: {self.generate_idx_gt}')
+
+ self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
+ self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
+ self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
+ self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
+ self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8)
+
+ # define network net_d
+ self.net_d = build_network(self.opt['network_d'])
+ self.net_d = self.model_to_device(self.net_d)
+ self.print_network(self.net_d)
+
+ # load pretrained models
+ load_path = self.opt['path'].get('pretrain_network_d', None)
+ if load_path is not None:
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
+
+ self.net_g.train()
+ self.net_d.train()
+
+ # define losses
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if train_opt.get('gan_opt'):
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+
+ self.fix_generator = train_opt.get('fix_generator', True)
+ logger.info(f'fix_generator: {self.fix_generator}')
+
+ self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
+ self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
+ recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+
+ d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
+ return d_weight
+
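+    # calculate_adaptive_weight mirrors the adaptive discriminator weight used in
+    # VQGAN-style training: d_weight = ||grad(recon_loss)|| / (||grad(g_loss)|| + 1e-4)
+    # at the generator's last layer, clamped to disc_weight_max and detached, so
+    # the GAN term cannot overwhelm the reconstruction gradient.
+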
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ # optimizer g
+ optim_params_g = []
+ for k, v in self.net_g.named_parameters():
+ if v.requires_grad:
+ optim_params_g.append(v)
+ else:
+ logger = get_root_logger()
+ logger.warning(f'Params {k} will not be optimized.')
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+ # optimizer d
+ optim_type = train_opt['optim_d'].pop('type')
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+ self.optimizers.append(self.optimizer_d)
+
+ def gray_resize_for_identity(self, out, size=128):
+ out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
+ out_gray = out_gray.unsqueeze(1)
+ out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
+ return out_gray
+
+ def optimize_parameters(self, current_iter):
+ logger = get_root_logger()
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+
+ if self.generate_idx_gt:
+ x = self.hq_vqgan_fix.encoder(self.gt)
+ output, _, quant_stats = self.hq_vqgan_fix.quantize(x)
+ min_encoding_indices = quant_stats['min_encoding_indices']
+ self.idx_gt = min_encoding_indices.view(self.b, -1)
+
+ if current_iter <= 40000: # small degradation
+ small_per_n = 1
+ w = 1
+ elif current_iter <= 80000: # small degradation
+ small_per_n = 1
+ w = 1.3
+ elif current_iter <= 120000: # large degradation
+ small_per_n = 120000
+ w = 0
+ else: # mixed degradation
+ small_per_n = 15
+ w = 1.3
+
+ if current_iter % small_per_n == 0:
+ self.output, logits, lq_feat = self.net_g(self.input, w=w, detach_16=True)
+ large_de = False
+ else:
+ logits, lq_feat = self.net_g(self.input_large_de, code_only=True)
+ large_de = True
+
+ if self.hq_feat_loss:
+ # quant_feats
+ quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
+ # hq_feat_loss
+            if 'transformer' not in self.opt['network_g']['fix_modules']:
+ if self.hq_feat_loss: # codebook loss
+ l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
+ l_g_total += l_feat_encoder
+ loss_dict['l_feat_encoder'] = l_feat_encoder
+
+ # cross_entropy_loss
+ if self.cross_entropy_loss:
+ # b(hw)n -> bn(hw)
+ cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
+ l_g_total += cross_entropy_loss
+ loss_dict['cross_entropy_loss'] = cross_entropy_loss
+
+ # pixel loss
+ if not large_de: # when large degradation don't need image-level loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep = self.cri_perceptual(self.output, self.gt)
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+
+ # gan loss
+ if current_iter > self.net_d_start_iter:
+ fake_g_pred = self.net_d(self.output)
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+ recon_loss = l_g_pix + l_g_percep
+ if not self.fix_generator:
+ last_layer = self.net_g.module.generator.blocks[-1].weight
+ d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
+ else:
+ largest_fuse_size = self.opt['network_g']['connect_list'][-1]
+ last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight
+ d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
+
+ d_weight *= self.scale_adaptive_gan_weight # 0.8
+ loss_dict['d_weight'] = d_weight
+ l_g_total += d_weight * l_g_gan
+ loss_dict['l_g_gan'] = d_weight * l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ # optimize net_d
+ if not large_de:
+ if current_iter > self.net_d_start_iter:
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # real
+ real_d_pred = self.net_d(self.gt)
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ l_d_real.backward()
+ # fake
+ fake_d_pred = self.net_d(self.output.detach())
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+ l_d_fake.backward()
+
+ self.optimizer_d.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+
+ def test(self):
+ with torch.no_grad():
+ if hasattr(self, 'net_g_ema'):
+ self.net_g_ema.eval()
+ self.output, _, _ = self.net_g_ema(self.input, w=1)
+ else:
+ logger = get_root_logger()
+ logger.warning('Do not have self.net_g_ema, use self.net_g.')
+ self.net_g.eval()
+ self.output, _, _ = self.net_g(self.input, w=1)
+ self.net_g.train()
+
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ if self.opt['rank'] == 0:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ dataset_name = dataloader.dataset.opt['name']
+ with_metrics = self.opt['val'].get('metrics') is not None
+ if with_metrics:
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ pbar = tqdm(total=len(dataloader), unit='image')
+
+ for idx, val_data in enumerate(dataloader):
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+ self.feed_data(val_data)
+ self.test()
+
+ visuals = self.get_current_visuals()
+ sr_img = tensor2img([visuals['result']])
+ if 'gt' in visuals:
+ gt_img = tensor2img([visuals['gt']])
+ del self.gt
+
+ # tentative for out of GPU memory
+            del self.input  # feed_data stores the input as self.input, not self.lq
+ del self.output
+ torch.cuda.empty_cache()
+
+ if save_img:
+ if self.opt['is_train']:
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+ f'{img_name}_{current_iter}.png')
+ else:
+ if self.opt['val']['suffix']:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["name"]}.png')
+ imwrite(sr_img, save_img_path)
+
+ if with_metrics:
+ # calculate metrics
+ for name, opt_ in self.opt['val']['metrics'].items():
+ metric_data = dict(img1=sr_img, img2=gt_img)
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
+ pbar.update(1)
+ pbar.set_description(f'Test {img_name}')
+ pbar.close()
+
+ if with_metrics:
+ for metric in self.metric_results.keys():
+ self.metric_results[metric] /= (idx + 1)
+
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+ log_str = f'Validation {dataset_name}\n'
+ for metric, value in self.metric_results.items():
+ log_str += f'\t # {metric}: {value:.4f}\n'
+ logger = get_root_logger()
+ logger.info(log_str)
+ if tb_logger:
+ for metric, value in self.metric_results.items():
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+
+
+ def get_current_visuals(self):
+ out_dict = OrderedDict()
+ out_dict['gt'] = self.gt.detach().cpu()
+ out_dict['result'] = self.output.detach().cpu()
+ return out_dict
+
+
+ def save(self, epoch, current_iter):
+ if self.ema_decay > 0:
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_network(self.net_d, 'net_d', current_iter)
+ self.save_training_state(epoch, current_iter)
diff --git a/basicsr/models/codeformer_model.py b/basicsr/models/codeformer_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4db04aaba6f7fb0eb288160fd944ec6d667ec31
--- /dev/null
+++ b/basicsr/models/codeformer_model.py
@@ -0,0 +1,332 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.metrics import calculate_metric
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
+import torch.nn.functional as F
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class CodeFormerModel(SRModel):
+ def feed_data(self, data):
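+        # 'gt' is the HQ target and 'in' the degraded input; 'latent_gt'
+        # optionally carries precomputed ground-truth codebook indices.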
+ self.gt = data['gt'].to(self.device)
+ self.input = data['in'].to(self.device)
+ self.b = self.gt.shape[0]
+
+ if 'latent_gt' in data:
+ self.idx_gt = data['latent_gt'].to(self.device)
+ self.idx_gt = self.idx_gt.view(self.b, -1)
+ else:
+ self.idx_gt = None
+
+ def init_training_settings(self):
+ logger = get_root_logger()
+ train_opt = self.opt['train']
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema is used only for testing on one GPU and saving
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ if self.opt.get('network_vqgan', None) is not None and self.opt['datasets'].get('latent_gt_path') is None:
+ self.hq_vqgan_fix = build_network(self.opt['network_vqgan']).to(self.device)
+ self.hq_vqgan_fix.eval()
+ self.generate_idx_gt = True
+ for param in self.hq_vqgan_fix.parameters():
+ param.requires_grad = False
+ else:
+ self.generate_idx_gt = False
+
+ self.hq_feat_loss = train_opt.get('use_hq_feat_loss', True)
+ self.feat_loss_weight = train_opt.get('feat_loss_weight', 1.0)
+ self.cross_entropy_loss = train_opt.get('cross_entropy_loss', True)
+ self.entropy_loss_weight = train_opt.get('entropy_loss_weight', 0.5)
+ self.fidelity_weight = train_opt.get('fidelity_weight', 1.0)
+ self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8)
+
+
+ self.net_g.train()
+ # define network net_d
+ if self.fidelity_weight > 0:
+ self.net_d = build_network(self.opt['network_d'])
+ self.net_d = self.model_to_device(self.net_d)
+ self.print_network(self.net_d)
+
+ # load pretrained models
+ load_path = self.opt['path'].get('pretrain_network_d', None)
+ if load_path is not None:
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
+
+ self.net_d.train()
+
+ # define losses
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if train_opt.get('gan_opt'):
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+
+ self.fix_generator = train_opt.get('fix_generator', True)
+ logger.info(f'fix_generator: {self.fix_generator}')
+
+ self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
+ self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
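+        """Balance the GAN term against the reconstruction term.
+
+        Following the adaptive weighting of taming-transformers (VQGAN), the
+        discriminator weight is the ratio of the gradient norms of the two
+        losses w.r.t. the generator's last layer, clamped to disc_weight_max.
+        """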
+ recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+
+ d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
+ return d_weight
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ # optimizer g
+ optim_params_g = []
+ for k, v in self.net_g.named_parameters():
+ if v.requires_grad:
+ optim_params_g.append(v)
+ else:
+ logger = get_root_logger()
+ logger.warning(f'Params {k} will not be optimized.')
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+ # optimizer d
+ if self.fidelity_weight > 0:
+ optim_type = train_opt['optim_d'].pop('type')
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+ self.optimizers.append(self.optimizer_d)
+
+ def gray_resize_for_identity(self, out, size=128):
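+        # Convert an RGB batch to grayscale with ITU-R BT.601 luma weights
+        # and bilinearly resize to (size, size).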
+ out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
+ out_gray = out_gray.unsqueeze(1)
+ out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
+ return out_gray
+
+ def optimize_parameters(self, current_iter):
+ logger = get_root_logger()
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+
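+        # When no precomputed latent_gt is provided, derive the ground-truth
+        # code indices by encoding the HQ image with the frozen HQ VQGAN.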
+ if self.generate_idx_gt:
+ x = self.hq_vqgan_fix.encoder(self.gt)
+ output, _, quant_stats = self.hq_vqgan_fix.quantize(x)
+ min_encoding_indices = quant_stats['min_encoding_indices']
+ self.idx_gt = min_encoding_indices.view(self.b, -1)
+
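+        # w > 0: full forward pass that also decodes an image;
+        # w == 0: predict code logits only (code_only), so no image-level losses.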
+ if self.fidelity_weight > 0:
+ self.output, logits, lq_feat = self.net_g(self.input, w=self.fidelity_weight, detach_16=True)
+ else:
+ logits, lq_feat = self.net_g(self.input, w=0, code_only=True)
+
+ if self.hq_feat_loss:
+            # look up frozen codebook features at the GT indices; they serve as the regression target for lq_feat
+ quant_feat_gt = self.net_g.module.quantize.get_codebook_feat(self.idx_gt, shape=[self.b,16,16,256])
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
+ # hq_feat_loss
+ if self.hq_feat_loss: # codebook loss
+ l_feat_encoder = torch.mean((quant_feat_gt.detach()-lq_feat)**2) * self.feat_loss_weight
+ l_g_total += l_feat_encoder
+ loss_dict['l_feat_encoder'] = l_feat_encoder
+
+ # cross_entropy_loss
+ if self.cross_entropy_loss:
+                # permute logits from (b, hw, n) to (b, n, hw), as F.cross_entropy expects the class dim second
+ cross_entropy_loss = F.cross_entropy(logits.permute(0, 2, 1), self.idx_gt) * self.entropy_loss_weight
+ l_g_total += cross_entropy_loss
+ loss_dict['cross_entropy_loss'] = cross_entropy_loss
+
+ if self.fidelity_weight > 0: # when fidelity_weight == 0 don't need image-level loss
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep = self.cri_perceptual(self.output, self.gt)
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+
+ # gan loss
+ if current_iter > self.net_d_start_iter:
+ fake_g_pred = self.net_d(self.output)
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+ recon_loss = l_g_pix + l_g_percep
+ if not self.fix_generator:
+ last_layer = self.net_g.module.generator.blocks[-1].weight
+ d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
+ else:
+ largest_fuse_size = self.opt['network_g']['connect_list'][-1]
+ last_layer = self.net_g.module.fuse_convs_dict[largest_fuse_size].shift[-1].weight
+ d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
+
+ d_weight *= self.scale_adaptive_gan_weight # 0.8
+ loss_dict['d_weight'] = d_weight
+ l_g_total += d_weight * l_g_gan
+ loss_dict['l_g_gan'] = d_weight * l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ # optimize net_d
+ if current_iter > self.net_d_start_iter and self.fidelity_weight > 0:
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # real
+ real_d_pred = self.net_d(self.gt)
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ l_d_real.backward()
+ # fake
+ fake_d_pred = self.net_d(self.output.detach())
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+ l_d_fake.backward()
+
+ self.optimizer_d.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+
+ def test(self):
+ with torch.no_grad():
+ if hasattr(self, 'net_g_ema'):
+ self.net_g_ema.eval()
+ self.output, _, _ = self.net_g_ema(self.input, w=self.fidelity_weight)
+ else:
+ logger = get_root_logger()
+                logger.warning('net_g_ema not found; testing with net_g instead.')
+ self.net_g.eval()
+ self.output, _, _ = self.net_g(self.input, w=self.fidelity_weight)
+ self.net_g.train()
+
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ if self.opt['rank'] == 0:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ dataset_name = dataloader.dataset.opt['name']
+ with_metrics = self.opt['val'].get('metrics') is not None
+ if with_metrics:
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ pbar = tqdm(total=len(dataloader), unit='image')
+
+ for idx, val_data in enumerate(dataloader):
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+ self.feed_data(val_data)
+ self.test()
+
+ visuals = self.get_current_visuals()
+ sr_img = tensor2img([visuals['result']])
+ if 'gt' in visuals:
+ gt_img = tensor2img([visuals['gt']])
+ del self.gt
+
+ # tentative for out of GPU memory
+            del self.input
+ del self.output
+ torch.cuda.empty_cache()
+
+ if save_img:
+ if self.opt['is_train']:
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+ f'{img_name}_{current_iter}.png')
+ else:
+ if self.opt['val']['suffix']:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["name"]}.png')
+ imwrite(sr_img, save_img_path)
+
+ if with_metrics:
+ # calculate metrics
+ for name, opt_ in self.opt['val']['metrics'].items():
+ metric_data = dict(img1=sr_img, img2=gt_img)
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
+ pbar.update(1)
+ pbar.set_description(f'Test {img_name}')
+ pbar.close()
+
+ if with_metrics:
+ for metric in self.metric_results.keys():
+ self.metric_results[metric] /= (idx + 1)
+
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+ log_str = f'Validation {dataset_name}\n'
+ for metric, value in self.metric_results.items():
+ log_str += f'\t # {metric}: {value:.4f}\n'
+ logger = get_root_logger()
+ logger.info(log_str)
+ if tb_logger:
+ for metric, value in self.metric_results.items():
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+
+
+ def get_current_visuals(self):
+ out_dict = OrderedDict()
+ out_dict['gt'] = self.gt.detach().cpu()
+ out_dict['result'] = self.output.detach().cpu()
+ return out_dict
+
+
+ def save(self, epoch, current_iter):
+ if self.ema_decay > 0:
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ if self.fidelity_weight > 0:
+ self.save_network(self.net_d, 'net_d', current_iter)
+ self.save_training_state(epoch, current_iter)
diff --git a/basicsr/models/lr_scheduler.py b/basicsr/models/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a7d21bd31146d08ba6c4545814cd6df61c22f58
--- /dev/null
+++ b/basicsr/models/lr_scheduler.py
@@ -0,0 +1,96 @@
+import math
+from collections import Counter
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class MultiStepRestartLR(_LRScheduler):
+ """ MultiStep with restarts learning rate scheme.
+
+ Args:
+        optimizer (torch.optim.Optimizer): Torch optimizer.
+ milestones (list): Iterations that will decrease learning rate.
+ gamma (float): Decrease ratio. Default: 0.1.
+ restarts (list): Restart iterations. Default: [0].
+ restart_weights (list): Restart weights at each restart iteration.
+ Default: [1].
+ last_epoch (int): Used in _LRScheduler. Default: -1.
+ """
+
+ def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
+ self.milestones = Counter(milestones)
+ self.gamma = gamma
+ self.restarts = restarts
+ self.restart_weights = restart_weights
+ assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
+ super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if self.last_epoch in self.restarts:
+ weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+ return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
+ if self.last_epoch not in self.milestones:
+ return [group['lr'] for group in self.optimizer.param_groups]
+ return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
+
+
+def get_position_from_periods(iteration, cumulative_period):
+ """Get the position from a period list.
+
+ It will return the index of the right-closest number in the period list.
+ For example, the cumulative_period = [100, 200, 300, 400],
+ if iteration == 50, return 0;
+ if iteration == 210, return 2;
+ if iteration == 300, return 2.
+
+ Args:
+ iteration (int): Current iteration.
+ cumulative_period (list[int]): Cumulative period list.
+
+ Returns:
+ int: The position of the right-closest number in the period list.
+ """
+ for i, period in enumerate(cumulative_period):
+ if iteration <= period:
+ return i
+
+
+class CosineAnnealingRestartLR(_LRScheduler):
+ """ Cosine annealing with restarts learning rate scheme.
+
+ An example of config:
+ periods = [10, 10, 10, 10]
+ restart_weights = [1, 0.5, 0.5, 0.5]
+ eta_min=1e-7
+
+    It has four cycles, each with 10 iterations. At the 10th, 20th, and 30th
+    iterations, the scheduler restarts with the corresponding weight in
+    restart_weights.
+
+ Args:
+        optimizer (torch.optim.Optimizer): Torch optimizer.
+        periods (list): Period for each cosine annealing cycle.
+        restart_weights (list): Restart weights at each restart iteration.
+            Default: [1].
+        eta_min (float): The minimum lr. Default: 0.
+ last_epoch (int): Used in _LRScheduler. Default: -1.
+ """
+
+ def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
+ self.periods = periods
+ self.restart_weights = restart_weights
+ self.eta_min = eta_min
+ assert (len(self.periods) == len(
+ self.restart_weights)), 'periods and restart_weights should have the same length.'
+ self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
+ super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
+ current_weight = self.restart_weights[idx]
+ nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
+ current_period = self.periods[idx]
+
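+        # lr = eta_min + w_i * 0.5 * (base_lr - eta_min)
+        #      * (1 + cos(pi * (t - t_restart) / T_i))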
+ return [
+ self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
+ (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
+ for base_lr in self.base_lrs
+ ]
diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..136fdcde7ef6af45add81de0f19c7bd8a7cf95bc
--- /dev/null
+++ b/basicsr/models/sr_model.py
@@ -0,0 +1,209 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.metrics import calculate_metric
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
+from .base_model import BaseModel
+
+@MODEL_REGISTRY.register()
+class SRModel(BaseModel):
+ """Base SR model for single image super-resolution."""
+
+ def __init__(self, opt):
+ super(SRModel, self).__init__(opt)
+
+ # define network
+ self.net_g = build_network(opt['network_g'])
+ self.net_g = self.model_to_device(self.net_g)
+ self.print_network(self.net_g)
+
+ # load pretrained models
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ param_key = self.opt['path'].get('param_key_g', 'params')
+ self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
+
+ if self.is_train:
+ self.init_training_settings()
+
+ def init_training_settings(self):
+ self.net_g.train()
+ train_opt = self.opt['train']
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger = get_root_logger()
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema is used only for testing on one GPU and saving
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ # define losses
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if self.cri_pix is None and self.cri_perceptual is None:
+ raise ValueError('Both pixel and perceptual losses are None.')
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ optim_params = []
+ for k, v in self.net_g.named_parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ else:
+ logger = get_root_logger()
+ logger.warning(f'Params {k} will not be optimized.')
+
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+
+ def feed_data(self, data):
+ self.lq = data['lq'].to(self.device)
+ if 'gt' in data:
+ self.gt = data['gt'].to(self.device)
+
+ def optimize_parameters(self, current_iter):
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ l_total = 0
+ loss_dict = OrderedDict()
+ # pixel loss
+ if self.cri_pix:
+ l_pix = self.cri_pix(self.output, self.gt)
+ l_total += l_pix
+ loss_dict['l_pix'] = l_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_percep, l_style = self.cri_perceptual(self.output, self.gt)
+ if l_percep is not None:
+ l_total += l_percep
+ loss_dict['l_percep'] = l_percep
+ if l_style is not None:
+ l_total += l_style
+ loss_dict['l_style'] = l_style
+
+ l_total.backward()
+ self.optimizer_g.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ def test(self):
+        if hasattr(self, 'net_g_ema'):
+ self.net_g_ema.eval()
+ with torch.no_grad():
+ self.output = self.net_g_ema(self.lq)
+ else:
+ self.net_g.eval()
+ with torch.no_grad():
+ self.output = self.net_g(self.lq)
+ self.net_g.train()
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ if self.opt['rank'] == 0:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ dataset_name = dataloader.dataset.opt['name']
+ with_metrics = self.opt['val'].get('metrics') is not None
+ if with_metrics:
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ pbar = tqdm(total=len(dataloader), unit='image')
+
+ for idx, val_data in enumerate(dataloader):
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+ self.feed_data(val_data)
+ self.test()
+
+ visuals = self.get_current_visuals()
+ sr_img = tensor2img([visuals['result']])
+ if 'gt' in visuals:
+ gt_img = tensor2img([visuals['gt']])
+ del self.gt
+
+ # tentative for out of GPU memory
+ del self.lq
+ del self.output
+ torch.cuda.empty_cache()
+
+ if save_img:
+ if self.opt['is_train']:
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+ f'{img_name}_{current_iter}.png')
+ else:
+ if self.opt['val']['suffix']:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["name"]}.png')
+ imwrite(sr_img, save_img_path)
+
+ if with_metrics:
+ # calculate metrics
+ for name, opt_ in self.opt['val']['metrics'].items():
+ metric_data = dict(img1=sr_img, img2=gt_img)
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
+ pbar.update(1)
+ pbar.set_description(f'Test {img_name}')
+ pbar.close()
+
+ if with_metrics:
+ for metric in self.metric_results.keys():
+ self.metric_results[metric] /= (idx + 1)
+
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+ log_str = f'Validation {dataset_name}\n'
+ for metric, value in self.metric_results.items():
+ log_str += f'\t # {metric}: {value:.4f}\n'
+ logger = get_root_logger()
+ logger.info(log_str)
+ if tb_logger:
+ for metric, value in self.metric_results.items():
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+
+ def get_current_visuals(self):
+ out_dict = OrderedDict()
+ out_dict['lq'] = self.lq.detach().cpu()
+ out_dict['result'] = self.output.detach().cpu()
+ if hasattr(self, 'gt'):
+ out_dict['gt'] = self.gt.detach().cpu()
+ return out_dict
+
+ def save(self, epoch, current_iter):
+        if hasattr(self, 'net_g_ema'):
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_training_state(epoch, current_iter)
diff --git a/basicsr/models/vqgan_model.py b/basicsr/models/vqgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..05bcbb594b962a1271aa2df84749c647ec3bf8a2
--- /dev/null
+++ b/basicsr/models/vqgan_model.py
@@ -0,0 +1,285 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.metrics import calculate_metric
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
+import torch.nn.functional as F
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class VQGANModel(SRModel):
+ def feed_data(self, data):
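+        # VQGAN trains by self-reconstruction, so only HQ (gt) images are fed.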
+ self.gt = data['gt'].to(self.device)
+ self.b = self.gt.shape[0]
+
+
+ def init_training_settings(self):
+ logger = get_root_logger()
+ train_opt = self.opt['train']
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema is used only for testing on one GPU and saving
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ # define network net_d
+ self.net_d = build_network(self.opt['network_d'])
+ self.net_d = self.model_to_device(self.net_d)
+ self.print_network(self.net_d)
+
+ # load pretrained models
+ load_path = self.opt['path'].get('pretrain_network_d', None)
+ if load_path is not None:
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
+
+ self.net_g.train()
+ self.net_d.train()
+
+ # define losses
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if train_opt.get('gan_opt'):
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+ if train_opt.get('codebook_opt'):
+ self.l_weight_codebook = train_opt['codebook_opt'].get('loss_weight', 1.0)
+ else:
+ self.l_weight_codebook = 1.0
+
+ self.vqgan_quantizer = self.opt['network_g']['quantizer']
+ logger.info(f'vqgan_quantizer: {self.vqgan_quantizer}')
+
+ self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
+ self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
+ self.disc_weight = train_opt.get('disc_weight', 0.8)
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
+ recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+
+ d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
+ return d_weight
+
+ def adopt_weight(self, weight, global_step, threshold=0, value=0.):
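+        # Return `value` (i.e. disable the term) until global_step reaches the threshold iteration.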
+ if global_step < threshold:
+ weight = value
+ return weight
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ # optimizer g
+ optim_params_g = []
+ for k, v in self.net_g.named_parameters():
+ if v.requires_grad:
+ optim_params_g.append(v)
+ else:
+ logger = get_root_logger()
+ logger.warning(f'Params {k} will not be optimized.')
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+ # optimizer d
+ optim_type = train_opt['optim_d'].pop('type')
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+ self.optimizers.append(self.optimizer_d)
+
+
+ def optimize_parameters(self, current_iter):
+ logger = get_root_logger()
+ loss_dict = OrderedDict()
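+        # Gumbel quantizer: linearly anneal the softmax temperature from 1
+        # down to a floor of 1/16, which is reached at iteration 150k.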
+ if self.opt['network_g']['quantizer'] == 'gumbel':
+ self.net_g.module.quantize.temperature = max(1/16, ((-1/160000) * current_iter) + 1)
+ if current_iter%1000 == 0:
+ logger.info(f'temperature: {self.net_g.module.quantize.temperature}')
+
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+ self.output, l_codebook, quant_stats = self.net_g(self.gt)
+
+ l_codebook = l_codebook*self.l_weight_codebook
+
+ l_g_total = 0
+ if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep = self.cri_perceptual(self.output, self.gt)
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+
+ # gan loss
+ if current_iter > self.net_d_start_iter:
+ fake_g_pred = self.net_d(self.output)
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+ recon_loss = l_g_total
+ last_layer = self.net_g.module.generator.blocks[-1].weight
+ d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
+ d_weight *= self.adopt_weight(1, current_iter, self.net_d_start_iter)
+                d_weight *= self.disc_weight # taming-transformers setting: 0.8
+ l_g_total += d_weight * l_g_gan
+ loss_dict['l_g_gan'] = d_weight * l_g_gan
+
+ l_g_total += l_codebook
+ loss_dict['l_codebook'] = l_codebook
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ if current_iter > self.net_d_start_iter:
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # real
+ real_d_pred = self.net_d(self.gt)
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ l_d_real.backward()
+ # fake
+ fake_d_pred = self.net_d(self.output.detach())
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+ l_d_fake.backward()
+ self.optimizer_d.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+
+ def test(self):
+ with torch.no_grad():
+ if hasattr(self, 'net_g_ema'):
+ self.net_g_ema.eval()
+ self.output, _, _ = self.net_g_ema(self.gt)
+ else:
+ logger = get_root_logger()
+                logger.warning('net_g_ema not found; testing with net_g instead.')
+ self.net_g.eval()
+ self.output, _, _ = self.net_g(self.gt)
+ self.net_g.train()
+
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ if self.opt['rank'] == 0:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ dataset_name = dataloader.dataset.opt['name']
+ with_metrics = self.opt['val'].get('metrics') is not None
+ if with_metrics:
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ pbar = tqdm(total=len(dataloader), unit='image')
+
+ for idx, val_data in enumerate(dataloader):
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+ self.feed_data(val_data)
+ self.test()
+
+ visuals = self.get_current_visuals()
+ sr_img = tensor2img([visuals['result']])
+ if 'gt' in visuals:
+ gt_img = tensor2img([visuals['gt']])
+ del self.gt
+
+ # tentative for out of GPU memory
+ del self.output
+ torch.cuda.empty_cache()
+
+ if save_img:
+ if self.opt['is_train']:
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+ f'{img_name}_{current_iter}.png')
+ else:
+ if self.opt['val']['suffix']:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["name"]}.png')
+ imwrite(sr_img, save_img_path)
+
+ if with_metrics:
+ # calculate metrics
+ for name, opt_ in self.opt['val']['metrics'].items():
+ metric_data = dict(img1=sr_img, img2=gt_img)
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
+ pbar.update(1)
+ pbar.set_description(f'Test {img_name}')
+ pbar.close()
+
+ if with_metrics:
+ for metric in self.metric_results.keys():
+ self.metric_results[metric] /= (idx + 1)
+
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+ log_str = f'Validation {dataset_name}\n'
+ for metric, value in self.metric_results.items():
+ log_str += f'\t # {metric}: {value:.4f}\n'
+ logger = get_root_logger()
+ logger.info(log_str)
+ if tb_logger:
+ for metric, value in self.metric_results.items():
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+
+
+ def get_current_visuals(self):
+ out_dict = OrderedDict()
+ out_dict['gt'] = self.gt.detach().cpu()
+ out_dict['result'] = self.output.detach().cpu()
+ return out_dict
+
+ def save(self, epoch, current_iter):
+ if self.ema_decay > 0:
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_network(self.net_d, 'net_d', current_iter)
+ self.save_training_state(epoch, current_iter)
diff --git a/basicsr/ops/__init__.py b/basicsr/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/basicsr/ops/dcn/__init__.py b/basicsr/ops/dcn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b534fc667eefc85cff7b025dcdfd2d0057c6fe35
--- /dev/null
+++ b/basicsr/ops/dcn/__init__.py
@@ -0,0 +1,7 @@
+from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
+ modulated_deform_conv)
+
+__all__ = [
+ 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
+ 'modulated_deform_conv'
+]
diff --git a/basicsr/ops/dcn/deform_conv.py b/basicsr/ops/dcn/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cd07fe0e9ddf19e39785c20c8bbcf85fced1f63
--- /dev/null
+++ b/basicsr/ops/dcn/deform_conv.py
@@ -0,0 +1,377 @@
+import math
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn import functional as F
+from torch.nn.modules.utils import _pair, _single
+
+try:
+ from . import deform_conv_ext
+except ImportError:
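+    # Prebuilt extension not found: optionally JIT-compile the CUDA sources
+    # when the environment variable BASICSR_JIT is set to 'True'.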
+ import os
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
+ if BASICSR_JIT == 'True':
+ from torch.utils.cpp_extension import load
+ module_path = os.path.dirname(__file__)
+ deform_conv_ext = load(
+ 'deform_conv',
+ sources=[
+ os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
+ os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
+ os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
+ ],
+ )
+
+
+class DeformConvFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ weight,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ im2col_step=64):
+ if input is not None and input.dim() != 4:
+            raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.')
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.im2col_step = im2col_step
+
+ ctx.save_for_backward(input, offset, weight)
+
+ output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
+
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
+
+ if not input.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+ deform_conv_ext.deform_conv_forward(input, weight,
+ offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+ ctx.deformable_groups, cur_im2col_step)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, offset, weight = ctx.saved_tensors
+
+ grad_input = grad_offset = grad_weight = None
+
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
+ grad_offset, weight, ctx.bufs_[0], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+ ctx.deformable_groups, cur_im2col_step)
+
+ if ctx.needs_input_grad[2]:
+ grad_weight = torch.zeros_like(weight)
+ deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
+ ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0],
+ ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+ ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
+ cur_im2col_step)
+
+ return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
+
+ @staticmethod
+ def _output_size(input, weight, padding, dilation, stride):
+ channels = weight.size(0)
+ output_size = (input.size(0), channels)
+ for d in range(input.dim() - 2):
+ in_size = input.size(d + 2)
+ pad = padding[d]
+ kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+ stride_ = stride[d]
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+ if not all(map(lambda s: s > 0, output_size)):
+ raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, output_size))})')
+ return output_size
+
+
+class ModulatedDeformConvFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ mask,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1):
+ ctx.stride = stride
+ ctx.padding = padding
+ ctx.dilation = dilation
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.with_bias = bias is not None
+ if not ctx.with_bias:
+ bias = input.new_empty(1) # fake tensor
+ if not input.is_cuda:
+ raise NotImplementedError
+ if weight.requires_grad or mask.requires_grad or offset.requires_grad \
+ or input.requires_grad:
+ ctx.save_for_backward(input, offset, mask, weight, bias)
+ output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+ ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+ deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
+ ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ input, offset, mask, weight, bias = ctx.saved_tensors
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ grad_mask = torch.zeros_like(mask)
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(bias)
+ deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
+ grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
+ grad_output, weight.shape[2], weight.shape[3], ctx.stride,
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ if not ctx.with_bias:
+ grad_bias = None
+
+ return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
+
+ @staticmethod
+ def _infer_shape(ctx, input, weight):
+ n = input.size(0)
+ channels_out = weight.size(0)
+ height, width = input.shape[2:4]
+ kernel_h, kernel_w = weight.shape[2:4]
+ height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
+ width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
+ return n, channels_out, height_out, width_out
+
+
+deform_conv = DeformConvFunction.apply
+modulated_deform_conv = ModulatedDeformConvFunction.apply
+
+
+class DeformConv(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=False):
+ super(DeformConv, self).__init__()
+
+ assert not bias
+ assert in_channels % groups == 0, \
+ f'in_channels {in_channels} is not divisible by groups {groups}'
+ assert out_channels % groups == 0, \
+ f'out_channels {out_channels} is not divisible ' \
+ f'by groups {groups}'
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+
+ def forward(self, x, offset):
+ # To fix an assert error in deform_conv_cuda.cpp:128
+ # input image is smaller than kernel
+ input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
+ if input_pad:
+ pad_h = max(self.kernel_size[0] - x.size(2), 0)
+ pad_w = max(self.kernel_size[1] - x.size(3), 0)
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+ offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+ out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+ self.deformable_groups)
+ if input_pad:
+ out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
+ return out
+
+
+class DeformConvPack(DeformConv):
+ """A Deformable Conv Encapsulation that acts as normal Conv layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(DeformConvPack, self).__init__(*args, **kwargs)
+
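+        # conv_offset predicts 2 (x, y) offset channels per kernel sampling
+        # location and deformable group; init_offset zeroes it so the layer
+        # starts out as a regular convolution.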
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ offset = self.conv_offset(x)
+ return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+ self.deformable_groups)
+
+
+class ModulatedDeformConv(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=True):
+ super(ModulatedDeformConv, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ self.with_bias = bias
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.init_weights()
+
+ def init_weights(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.bias is not None:
+ self.bias.data.zero_()
+
+ def forward(self, x, offset, mask):
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
+
+
+class ModulatedDeformConvPack(ModulatedDeformConv):
+ """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True)
+ self.init_weights()
+
+ def init_weights(self):
+ super(ModulatedDeformConvPack, self).init_weights()
+ if hasattr(self, 'conv_offset'):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
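+        # A single conv predicts the x/y offsets plus a modulation mask;
+        # the mask is bounded to (0, 1) with a sigmoid.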
+ out = self.conv_offset(x)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
diff --git a/basicsr/ops/dcn/src/deform_conv_cuda.cpp b/basicsr/ops/dcn/src/deform_conv_cuda.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6fbef833f96bd0e7060bbe69685b4f1ab8011a62
--- /dev/null
+++ b/basicsr/ops/dcn/src/deform_conv_cuda.cpp
@@ -0,0 +1,685 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <torch/extension.h>
+