diff --git a/diffusion-posterior-sampling/bkse/LICENSE b/diffusion-posterior-sampling/bkse/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..370961d1e704e321bf630bf26cc53babc051551b --- /dev/null +++ b/diffusion-posterior-sampling/bkse/LICENSE @@ -0,0 +1,203 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +'+ + diff --git a/diffusion-posterior-sampling/bkse/README.md b/diffusion-posterior-sampling/bkse/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e61b7d3c775026420a0e9d3a5c92b99e10b60681 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/README.md @@ -0,0 +1,181 @@ +# Exploring Image Deblurring via Encoded Blur Kernel Space + +## About the project + +We introduce a method to encode the blur operators of an arbitrary dataset of sharp-blur image pairs into a blur kernel space. Assuming the encoded kernel space is close enough to in-the-wild blur operators, we propose an alternating optimization algorithm for blind image deblurring. It approximates an unseen blur operator by a kernel in the encoded space and searches for the corresponding sharp image. Due to the method's design, the encoded kernel space is fully differentiable, thus can be easily adopted in deep neural network models. + +![Blur kernel space](imgs/teaser.jpg) + +Detail of the method and experimental results can be found in [our following paper](https://arxiv.org/abs/2104.00317): +``` +@inproceedings{m_Tran-etal-CVPR21, +  author = {Phong Tran and Anh Tran and Quynh Phung and Minh Hoai}, +  title = {Explore Image Deblurring via Encoded Blur Kernel Space}, +  year = {2021}, +  booktitle = {Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition (CVPR)} +} +``` +Please CITE our paper whenever this repository is used to help produce published results or incorporated into other software. + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1GDvbr4WQUibaEhQVzYPPObV4STn9NAot?usp=sharing) + +## Table of Content + +* [About the Project](#about-the-project) +* [Getting Started](#getting-started) + * [Prerequisites](#prerequisites) + * [Installation](#installation) + * [Using the pretrained model](#Using-the-pretrained-model) +* [Training and evaluation](#Training-and-evaluation) +* [Model Zoo](#Model-zoo) + +## Getting started + +### Prerequisites + +* Python >= 3.7 +* Pytorch >= 1.4.0 +* CUDA >= 10.0 + +### Installation + +``` sh +git clone https://github.com/VinAIResearch/blur-kernel-space-exploring.git +cd blur-kernel-space-exploring + + +conda create -n BlurKernelSpace -y python=3.7 +conda activate BlurKernelSpace +conda install --file requirements.txt +``` + +## Training and evaluation +### Preparing datasets +You can download the datasets in the [model zoo section](#model-zoo). + +To use your customized dataset, your dataset must be organized as follow: +``` +root +├── blur_imgs + ├── 000 + ├──── 00000000.png + ├──── 00000001.png + ├──── ... + ├── 001 + ├──── 00000000.png + ├──── 00000001.png + ├──── ... +├── sharp_imgs + ├── 000 + ├──── 00000000.png + ├──── 00000001.png + ├──── ... + ├── 001 + ├──── 00000000.png + ├──── 00000001.png + ├──── ... +``` +where `root`, `blur_imgs`, and `sharp_imgs` folders can have arbitrary names. For example, let `root, blur_imgs, sharp_imgs` be `REDS, train_blur, train_sharp` respectively (That is, you are using the REDS training set), then use the following scripts to create the lmdb dataset: +```sh +python create_lmdb.py --H 720 --W 1280 --C 3 --img_folder REDS/train_sharp --name train_sharp_wval --save_path ../datasets/REDS/train_sharp_wval.lmdb +python create_lmdb.py --H 720 --W 1280 --C 3 --img_folder REDS/train_blur --name train_blur_wval --save_path ../datasets/REDS/train_blur_wval.lmdb +``` +where `(H, C, W)` is the shape of the images (note that all images in the dataset must have the same shape), `img_folder` is the folder that contains the images, `name` is the name of the dataset, and `save_path` is the save destination (`save_path` must end with `.lmdb`). + +When the script is finished, two folders `train_sharp_wval.lmdb` and `train_blur_wval.lmdb` will be created in `./REDS`. + + +### Training +To do image deblurring, data augmentation, and blur generation, you first need to train the blur encoding network (The F function in the paper). This is the only network that you need to train. After creating the dataset, change the value of `dataroot_HQ` and `dataroot_LQ` in `options/kernel_encoding/REDS/woVAE.yml` to the paths of the sharp and blur lmdb datasets that were created before, then use the following script to train the model: +``` +python train.py -opt options/kernel_encoding/REDS/woVAE.yml +``` + +where `opt` is the path to yaml file that contains training configurations. You can find some default configurations in the `options` folder. Checkpoints, training states, and logs will be saved in `experiments/modelName`. You can change the configurations (learning rate, hyper-parameters, network structure, etc) in the yaml file. + +### Testing +#### Data augmentation +To augment a given dataset, first, create an lmdb dataset using `scripts/create_lmdb.py` as before. Then use the following script: +``` +python data_augmentation.py --target_H=720 --target_W=1280 \ + --source_H=720 --source_W=1280\ + --augmented_H=256 --augmented_W=256\ + --source_LQ_root=datasets/REDS/train_blur_wval.lmdb \ + --source_HQ_root=datasets/REDS/train_sharp_wval.lmdb \ + --target_HQ_root=datasets/REDS/test_sharp_wval.lmdb \ + --save_path=results/GOPRO_augmented \ + --num_images=10 \ + --yml_path=options/data_augmentation/default.yml +``` +`(target_H, target_W)`, `(source_H, source_W)`, and `(augmented_H, augmented_W)` are the desired shapes of the target images, source images, and augmented images respectively. `source_LQ_root`, `source_HQ_root`, and `target_HQ_root` are the paths of the lmdb datasets for the reference blur-sharp pairs and the input sharp images that were created before. `num_images` is the size of the augmented dataset. `model_path` is the path of the trained model. `yml_path` is the path to the model configuration file. Results will be saved in `save_path`. + +![Data augmentation examples](imgs/results/augmentation.jpg) + +#### Generate novel blur kernels +To generate a blur image given a sharp image, use the following command: +```sh +python generate_blur.py --yml_path=options/generate_blur/default.yml \ + --image_path=imgs/sharp_imgs/mushishi.png \ + --num_samples=10 + --save_path=./res.png +``` +where `model_path` is the path of the pre-trained model, `yml_path` is the path of the configuration file. `image_path` is the path of the sharp image. After running the script, a blur image corresponding to the sharp image will be saved in `save_path`. Here is some expected output: +![kernel generating examples](imgs/results/generate_blur.jpg) +**Note**: This only works with models that were trained with `--VAE` flag. The size of input images must be divisible by 128. + +#### Generic Deblurring +To deblur a blurry image, use the following command: +```sh +python generic_deblur.py --image_path imgs/blur_imgs/blur1.png --yml_path options/generic_deblur/default.yml --save_path ./res.png +``` +where `image_path` is the path of the blurry image. `yml_path` is the path of the configuration file. The deblurred image will be saved to `save_path`. + +![Image deblurring examples](imgs/results/general_deblurring.jpg) + +#### Deblurring using sharp image prior +[mapping]: https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k +[synthesis]: https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8 +[pretrained model]: https://drive.google.com/file/d/1PQutd-JboOCOZqmd95XWxWrO8gGEvRcO/view +First, you need to download the pre-trained styleGAN or styleGAN2 networks. If you want to use styleGAN, download the [mapping] and [synthesis] networks, then rename and copy them to `experiments/pretrained/stylegan_mapping.pt` and `experiments/pretrained/stylegan_synthesis.pt` respectively. If you want to use styleGAN2 instead, download the [pretrained model], then rename and copy it to `experiments/pretrained/stylegan2.pt`. + +To deblur a blurry image using styleGAN latent space as the sharp image prior, you can use one of the following commands: +```sh +python domain_specific_deblur.py --input_dir imgs/blur_faces \ + --output_dir experiments/domain_specific_deblur/results \ + --yml_path options/domain_specific_deblur/stylegan.yml # Use latent space of stylegan +python domain_specific_deblur.py --input_dir imgs/blur_faces \ + --output_dir experiments/domain_specific_deblur/results \ + --yml_path options/domain_specific_deblur/stylegan2.yml # Use latent space of stylegan2 +``` +Results will be saved in `experiments/domain_specific_deblur/results`. +**Note**: Generally, the code still works with images that have the size divisible by 128. However, since our blur kernels are not uniform, the size of the kernel increases as the size of the image increases. + +![PULSE-like Deblurring examples](imgs/results/domain_specific_deblur.jpg) + +## Model Zoo +Pretrained models and corresponding datasets are provided in the below table. After downloading the datasets and models, follow the instructions in the [testing section](#testing) to do data augmentation, generating blur images, or image deblurring. + +[REDS]: https://seungjunnah.github.io/Datasets/reds.html +[GOPRO]: https://seungjunnah.github.io/Datasets/gopro + +[REDS woVAE]: https://drive.google.com/file/d/12ZhjXWcYhAZjBnMtF0ai0R5PQydZct61/view?usp=sharing +[GOPRO woVAE]: https://drive.google.com/file/d/1WrVALP-woJgtiZyvQ7NOkaZssHbHwKYn/view?usp=sharing +[GOPRO wVAE]: https://drive.google.com/file/d/1QMUY8mxUMgEJty2Gk7UY0WYmyyYRY7vS/view?usp=sharing +[GOPRO + REDS woVAE]: https://drive.google.com/file/d/169R0hEs3rNeloj-m1rGS4YjW38pu-LFD/view?usp=sharing + +|Model name | dataset(s) | status | +|:-----------------------|:---------------:|-------------------------:| +|[REDS woVAE] | [REDS] | :heavy_check_mark: | +|[GOPRO woVAE] | [GOPRO] | :heavy_check_mark: | +|[GOPRO wVAE] | [GOPRO] | :heavy_check_mark: | +|[GOPRO + REDS woVAE] | [GOPRO], [REDS] | :heavy_check_mark: | + + +## Notes and references +The training code is borrowed from the EDVR project: https://github.com/xinntao/EDVR + +The backbone code is borrowed from the DeblurGAN project: https://github.com/KupynOrest/DeblurGAN + +The styleGAN code is borrowed from the PULSE project: https://github.com/adamian98/pulse + +The stylegan2 code is borrowed from https://github.com/rosinality/stylegan2-pytorch diff --git a/diffusion-posterior-sampling/bkse/data/GOPRO_dataset.py b/diffusion-posterior-sampling/bkse/data/GOPRO_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0c816b37aabd901b570509ce831020e29274f61c --- /dev/null +++ b/diffusion-posterior-sampling/bkse/data/GOPRO_dataset.py @@ -0,0 +1,135 @@ +""" +GOPRO dataset +support reading images from lmdb, image folder and memcached +""" +import logging +import os.path as osp +import pickle +import random + +import cv2 +import data.util as util +import lmdb +import numpy as np +import torch +import torch.utils.data as data + + +try: + import mc # import memcached +except ImportError: + pass + +logger = logging.getLogger("base") + + +class GOPRODataset(data.Dataset): + """ + Reading the training GOPRO dataset + key example: 000_00000000 + HQ: Ground-Truth; + LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames + support reading N LQ frames, N = 1, 3, 5, 7 + """ + + def __init__(self, opt): + super(GOPRODataset, self).__init__() + self.opt = opt + # temporal augmentation + + self.HQ_root, self.LQ_root = opt["dataroot_HQ"], opt["dataroot_LQ"] + self.N_frames = opt["N_frames"] + self.data_type = self.opt["data_type"] + # directly load image keys + if self.data_type == "lmdb": + self.paths_HQ, _ = util.get_image_paths(self.data_type, opt["dataroot_HQ"]) + logger.info("Using lmdb meta info for cache keys.") + elif opt["cache_keys"]: + logger.info("Using cache keys: {}".format(opt["cache_keys"])) + self.paths_HQ = pickle.load(open(opt["cache_keys"], "rb"))["keys"] + else: + raise ValueError( + "Need to create cache keys (meta_info.pkl) \ + by running [create_lmdb.py]" + ) + + assert self.paths_HQ, "Error: HQ path is empty." + + if self.data_type == "lmdb": + self.HQ_env, self.LQ_env = None, None + elif self.data_type == "mc": # memcached + self.mclient = None + elif self.data_type == "img": + pass + else: + raise ValueError("Wrong data type: {}".format(self.data_type)) + + def _init_lmdb(self): + # https://github.com/chainer/chainermn/issues/129 + self.HQ_env = lmdb.open(self.opt["dataroot_HQ"], readonly=True, lock=False, readahead=False, meminit=False) + self.LQ_env = lmdb.open(self.opt["dataroot_LQ"], readonly=True, lock=False, readahead=False, meminit=False) + + def _ensure_memcached(self): + if self.mclient is None: + # specify the config files + server_list_config_file = None + client_config_file = None + self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, client_config_file) + + def _read_img_mc(self, path): + """ Return BGR, HWC, [0, 255], uint8""" + value = mc.pyvector() + self.mclient.Get(path, value) + value_buf = mc.ConvertBuffer(value) + img_array = np.frombuffer(value_buf, np.uint8) + img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED) + return img + + def _read_img_mc_BGR(self, path, name_a, name_b): + """ + Read BGR channels separately and then combine for 1M limits in cluster + """ + img_B = self._read_img_mc(osp.join(path + "_B", name_a, name_b + ".png")) + img_G = self._read_img_mc(osp.join(path + "_G", name_a, name_b + ".png")) + img_R = self._read_img_mc(osp.join(path + "_R", name_a, name_b + ".png")) + img = cv2.merge((img_B, img_G, img_R)) + return img + + def __getitem__(self, index): + if self.data_type == "mc": + self._ensure_memcached() + elif self.data_type == "lmdb" and (self.HQ_env is None or self.LQ_env is None): + self._init_lmdb() + + HQ_size = self.opt["HQ_size"] + key = self.paths_HQ[index] + + # get the HQ image (as the center frame) + img_HQ = util.read_img(self.HQ_env, key, (3, 720, 1280)) + + # get LQ images + img_LQ = util.read_img(self.LQ_env, key, (3, 720, 1280)) + + if self.opt["phase"] == "train": + _, H, W = 3, 720, 1280 # LQ size + # randomly crop + rnd_h = random.randint(0, max(0, H - HQ_size)) + rnd_w = random.randint(0, max(0, W - HQ_size)) + img_LQ = img_LQ[rnd_h : rnd_h + HQ_size, rnd_w : rnd_w + HQ_size, :] + img_HQ = img_HQ[rnd_h : rnd_h + HQ_size, rnd_w : rnd_w + HQ_size, :] + + # augmentation - flip, rotate + imgs = [img_HQ, img_LQ] + rlt = util.augment(imgs, self.opt["use_flip"], self.opt["use_rot"]) + img_HQ = rlt[0] + img_LQ = rlt[1] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_LQ = img_LQ[:, :, [2, 1, 0]] + img_HQ = img_HQ[:, :, [2, 1, 0]] + img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() + img_HQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HQ, (2, 0, 1)))).float() + return {"LQ": img_LQ, "HQ": img_HQ, "key": key} + + def __len__(self): + return len(self.paths_HQ) diff --git a/diffusion-posterior-sampling/bkse/data/REDS_dataset.py b/diffusion-posterior-sampling/bkse/data/REDS_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d0b6166df447003fd4c9533eb18b9ed58b5fa2ef --- /dev/null +++ b/diffusion-posterior-sampling/bkse/data/REDS_dataset.py @@ -0,0 +1,139 @@ +""" +REDS dataset +support reading images from lmdb, image folder and memcached +""" +import logging +import os.path as osp +import pickle +import random + +import cv2 +import data.util as util +import lmdb +import numpy as np +import torch +import torch.utils.data as data + + +try: + import mc # import memcached +except ImportError: + pass + +logger = logging.getLogger("base") + + +class REDSDataset(data.Dataset): + """ + Reading the training REDS dataset + key example: 000_00000000 + HQ: Ground-Truth; + LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames + support reading N LQ frames, N = 1, 3, 5, 7 + """ + + def __init__(self, opt): + super(REDSDataset, self).__init__() + self.opt = opt + # temporal augmentation + + self.HQ_root, self.LQ_root = opt["dataroot_HQ"], opt["dataroot_LQ"] + self.N_frames = opt["N_frames"] + self.data_type = self.opt["data_type"] + # directly load image keys + if self.data_type == "lmdb": + self.paths_HQ, _ = util.get_image_paths(self.data_type, opt["dataroot_HQ"]) + logger.info("Using lmdb meta info for cache keys.") + elif opt["cache_keys"]: + logger.info("Using cache keys: {}".format(opt["cache_keys"])) + self.paths_HQ = pickle.load(open(opt["cache_keys"], "rb"))["keys"] + else: + raise ValueError( + "Need to create cache keys (meta_info.pkl) \ + by running [create_lmdb.py]" + ) + + # remove the REDS4 for testing + self.paths_HQ = [v for v in self.paths_HQ if v.split("_")[0] not in ["000", "011", "015", "020"]] + assert self.paths_HQ, "Error: HQ path is empty." + + if self.data_type == "lmdb": + self.HQ_env, self.LQ_env = None, None + elif self.data_type == "mc": # memcached + self.mclient = None + elif self.data_type == "img": + pass + else: + raise ValueError("Wrong data type: {}".format(self.data_type)) + + def _init_lmdb(self): + # https://github.com/chainer/chainermn/issues/129 + self.HQ_env = lmdb.open(self.opt["dataroot_HQ"], readonly=True, lock=False, readahead=False, meminit=False) + self.LQ_env = lmdb.open(self.opt["dataroot_LQ"], readonly=True, lock=False, readahead=False, meminit=False) + + def _ensure_memcached(self): + if self.mclient is None: + # specify the config files + server_list_config_file = None + client_config_file = None + self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, client_config_file) + + def _read_img_mc(self, path): + """ Return BGR, HWC, [0, 255], uint8""" + value = mc.pyvector() + self.mclient.Get(path, value) + value_buf = mc.ConvertBuffer(value) + img_array = np.frombuffer(value_buf, np.uint8) + img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED) + return img + + def _read_img_mc_BGR(self, path, name_a, name_b): + """ + Read BGR channels separately and then combine for 1M limits in cluster + """ + img_B = self._read_img_mc(osp.join(path + "_B", name_a, name_b + ".png")) + img_G = self._read_img_mc(osp.join(path + "_G", name_a, name_b + ".png")) + img_R = self._read_img_mc(osp.join(path + "_R", name_a, name_b + ".png")) + img = cv2.merge((img_B, img_G, img_R)) + return img + + def __getitem__(self, index): + if self.data_type == "mc": + self._ensure_memcached() + elif self.data_type == "lmdb" and (self.HQ_env is None or self.LQ_env is None): + self._init_lmdb() + + HQ_size = self.opt["HQ_size"] + key = self.paths_HQ[index] + name_a, name_b = key.split("_") + + # get the HQ image + img_HQ = util.read_img(self.HQ_env, key, (3, 720, 1280)) + + # get the LQ image + img_LQ = util.read_img(self.LQ_env, key, (3, 720, 1280)) + + if self.opt["phase"] == "train": + _, H, W = 3, 720, 1280 # LQ size + # randomly crop + rnd_h = random.randint(0, max(0, H - HQ_size)) + rnd_w = random.randint(0, max(0, W - HQ_size)) + img_LQ = img_LQ[rnd_h : rnd_h + HQ_size, rnd_w : rnd_w + HQ_size, :] + img_HQ = img_HQ[rnd_h : rnd_h + HQ_size, rnd_w : rnd_w + HQ_size, :] + + # augmentation - flip, rotate + imgs = [img_HQ, img_LQ] + rlt = util.augment(imgs, self.opt["use_flip"], self.opt["use_rot"]) + img_HQ = rlt[0] + img_LQ = rlt[1] + + # BGR to RGB, HWC to CHW, numpy to tensor + img_LQ = img_LQ[:, :, [2, 1, 0]] + img_HQ = img_HQ[:, :, [2, 1, 0]] + img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() + img_HQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HQ, (2, 0, 1)))).float() + + return {"LQ": img_LQ, "HQ": img_HQ} + + def __len__(self): + return len(self.paths_HQ) diff --git a/diffusion-posterior-sampling/bkse/data/__init__.py b/diffusion-posterior-sampling/bkse/data/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..f90b5f425efc34749e473d687381d013fa5784b8 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/data/__init__.py @@ -0,0 +1,53 @@ +"""create dataset and dataloader""" +import logging + +import torch +import torch.utils.data + + +def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): + phase = dataset_opt["phase"] + if phase == "train": + if opt["dist"]: + world_size = torch.distributed.get_world_size() + num_workers = dataset_opt["n_workers"] + assert dataset_opt["batch_size"] % world_size == 0 + batch_size = dataset_opt["batch_size"] // world_size + shuffle = False + else: + num_workers = dataset_opt["n_workers"] * len(opt["gpu_ids"]) + batch_size = dataset_opt["batch_size"] + shuffle = True + return torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + sampler=sampler, + drop_last=True, + pin_memory=False, + ) + else: + return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=False) + + +def create_dataset(dataset_opt): + mode = dataset_opt["mode"] + # datasets for image restoration + if mode == "REDS": + from data.REDS_dataset import REDSDataset as D + elif mode == "GOPRO": + from data.GOPRO_dataset import GOPRODataset as D + elif mode == "fewshot": + from data.fewshot_dataset import FewShotDataset as D + elif mode == "levin": + from data.levin_dataset import LevinDataset as D + elif mode == "mix": + from data.mix_dataset import MixDataset as D + else: + raise NotImplementedError(f"Dataset {mode} is not recognized.") + dataset = D(dataset_opt) + + logger = logging.getLogger("base") + logger.info("Dataset [{:s} - {:s}] is created.".format(dataset.__class__.__name__, dataset_opt["name"])) + return dataset diff --git a/diffusion-posterior-sampling/bkse/data/data_sampler.py b/diffusion-posterior-sampling/bkse/data/data_sampler.py new file mode 100755 index 0000000000000000000000000000000000000000..e49f141f080e954f597ef29d4ee287b49dad17b0 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/data/data_sampler.py @@ -0,0 +1,72 @@ +""" +Modified from torch.utils.data.distributed.DistributedSampler +Support enlarging the dataset for *iteration-oriented* training, +for saving time when restart the dataloader after each epoch +""" +import math + +import torch +import torch.distributed as dist +from torch.utils.data.sampler import Sampler + + +class DistIterSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError( + "Requires distributed \ + package to be available" + ) + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError( + "Requires distributed \ + package to be available" + ) + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dsize = len(self.dataset) + indices = [v % dsize for v in indices] + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/diffusion-posterior-sampling/bkse/data/mix_dataset.py b/diffusion-posterior-sampling/bkse/data/mix_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6146f431540ae1814fcb541a2e80a909907fb5 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/data/mix_dataset.py @@ -0,0 +1,104 @@ +""" +Mix dataset +support reading images from lmdb +""" +import logging +import random + +import data.util as util +import lmdb +import numpy as np +import torch +import torch.utils.data as data + + +logger = logging.getLogger("base") + + +class MixDataset(data.Dataset): + """ + Reading the training REDS dataset + key example: 000_00000000 + HQ: Ground-Truth; + LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames + support reading N LQ frames, N = 1, 3, 5, 7 + """ + + def __init__(self, opt): + super(MixDataset, self).__init__() + self.opt = opt + # temporal augmentation + + self.HQ_roots = opt["dataroots_HQ"] + self.LQ_roots = opt["dataroots_LQ"] + self.use_identical = opt["identical_loss"] + dataset_weights = opt["dataset_weights"] + self.data_type = "lmdb" + # directly load image keys + self.HQ_envs, self.LQ_envs = None, None + self.paths_HQ = [] + for idx, (HQ_root, LQ_root) in enumerate(zip(self.HQ_roots, self.LQ_roots)): + paths_HQ, _ = util.get_image_paths(self.data_type, HQ_root) + self.paths_HQ += list(zip([idx] * len(paths_HQ), paths_HQ)) * dataset_weights[idx] + random.shuffle(self.paths_HQ) + logger.info("Using lmdb meta info for cache keys.") + + def _init_lmdb(self): + self.HQ_envs, self.LQ_envs = [], [] + for HQ_root, LQ_root in zip(self.HQ_roots, self.LQ_roots): + self.HQ_envs.append(lmdb.open(HQ_root, readonly=True, lock=False, readahead=False, meminit=False)) + self.LQ_envs.append(lmdb.open(LQ_root, readonly=True, lock=False, readahead=False, meminit=False)) + + def __getitem__(self, index): + if self.HQ_envs is None: + self._init_lmdb() + + HQ_size = self.opt["HQ_size"] + env_idx, key = self.paths_HQ[index] + name_a, name_b = key.split("_") + target_frame_idx = int(name_b) + + # determine the neighbor frames + # ensure not exceeding the borders + neighbor_list = [target_frame_idx] + name_b = "{:08d}".format(neighbor_list[0]) + + # get the HQ image (as the center frame) + img_HQ_l = [] + for v in neighbor_list: + img_HQ = util.read_img(self.HQ_envs[env_idx], "{}_{:08d}".format(name_a, v), (3, 720, 1280)) + img_HQ_l.append(img_HQ) + + # get LQ images + img_LQ = util.read_img(self.LQ_envs[env_idx], "{}_{:08d}".format(name_a, neighbor_list[-1]), (3, 720, 1280)) + if self.opt["phase"] == "train": + _, H, W = 3, 720, 1280 # LQ size + # randomly crop + rnd_h = random.randint(0, max(0, H - HQ_size)) + rnd_w = random.randint(0, max(0, W - HQ_size)) + img_LQ = img_LQ[rnd_h : rnd_h + HQ_size, rnd_w : rnd_w + HQ_size, :] + img_HQ_l = [v[rnd_h : rnd_h + HQ_size, rnd_w : rnd_w + HQ_size, :] for v in img_HQ_l] + + # augmentation - flip, rotate + img_HQ_l.append(img_LQ) + rlt = util.augment(img_HQ_l, self.opt["use_flip"], self.opt["use_rot"]) + img_HQ_l = rlt[0:-1] + img_LQ = rlt[-1] + + # stack LQ images to NHWC, N is the frame number + img_HQs = np.stack(img_HQ_l, axis=0) + # BGR to RGB, HWC to CHW, numpy to tensor + img_LQ = img_LQ[:, :, [2, 1, 0]] + img_HQs = img_HQs[:, :, :, [2, 1, 0]] + img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() + img_HQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HQs, (0, 3, 1, 2)))).float() + # print(img_LQ.shape, img_HQs.shape) + + if self.use_identical and np.random.randint(0, 10) == 0: + img_LQ = img_HQs[-1, :, :, :] + return {"LQ": img_LQ, "HQs": img_HQs, "identical_w": 10} + + return {"LQ": img_LQ, "HQs": img_HQs, "identical_w": 0} + + def __len__(self): + return len(self.paths_HQ) diff --git a/diffusion-posterior-sampling/bkse/data/util.py b/diffusion-posterior-sampling/bkse/data/util.py new file mode 100755 index 0000000000000000000000000000000000000000..c65d7cfd51356e92e0d64510e91e32bc2c538150 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/data/util.py @@ -0,0 +1,574 @@ +import glob +import math +import os +import pickle +import random + +import cv2 +import numpy as np +import torch + + +#################### +# Files & IO +#################### + +# get image path list +IMG_EXTENSIONS = [".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP"] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def _get_paths_from_images(path): + """get image path list from image folder""" + assert os.path.isdir(path), "{:s} is not a valid directory".format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, "{:s} has no valid image file".format(path) + return images + + +def _get_paths_from_lmdb(dataroot): + """get image path list from lmdb meta info""" + meta_info = pickle.load(open(os.path.join(dataroot, "meta_info.pkl"), "rb")) + paths = meta_info["keys"] + sizes = meta_info["resolution"] + if len(sizes) == 1: + sizes = sizes * len(paths) + return paths, sizes + + +def get_image_paths(data_type, dataroot): + """get image path list + support lmdb or image files""" + paths, sizes = None, None + if dataroot is not None: + if data_type == "lmdb": + paths, sizes = _get_paths_from_lmdb(dataroot) + elif data_type == "img": + paths = sorted(_get_paths_from_images(dataroot)) + else: + raise NotImplementedError( + f"data_type {data_type} \ + is not recognized." + ) + return paths, sizes + + +def glob_file_list(root): + return sorted(glob.glob(os.path.join(root, "*"))) + + +# read images +def _read_img_lmdb(env, key, size): + """read image from lmdb with key (w/ and w/o fixed size) + size: (C, H, W) tuple""" + with env.begin(write=False) as txn: + buf = txn.get(key.encode("ascii")) + if buf is None: + print(key) + img_flat = np.frombuffer(buf, dtype=np.uint8) + C, H, W = size + img = img_flat.reshape(H, W, C) + return img + + +def read_img(env, path, size=None): + """read image by cv2 or from lmdb + return: Numpy float32, HWC, BGR, [0,1]""" + if env is None: # img + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + else: + img = _read_img_lmdb(env, path, size) + img = img.astype(np.float32) / 255.0 + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +def read_img_gray(env, path, size=None): + """read image by cv2 or from lmdb + return: Numpy float32, HWC, BGR, [0,1]""" + img = _read_img_lmdb(env, path, size) + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + img = img.astype(np.float32) / 255.0 + img = img[:, :, np.newaxis] + return img + + +def read_img_seq(path): + """Read a sequence of images from a given folder path + Args: + path (list/str): list of image paths/image folder path + + Returns: + imgs (Tensor): size (T, C, H, W), RGB, [0, 1] + """ + if type(path) is list: + img_path_l = path + else: + img_path_l = sorted(glob.glob(os.path.join(path, "*"))) + img_l = [read_img(None, v) for v in img_path_l] + # stack to Torch tensor + imgs = np.stack(img_l, axis=0) + imgs = imgs[:, :, :, [2, 1, 0]] + imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() + return imgs + + +def index_generation(crt_i, max_n, N, padding="reflection"): + """Generate an index list for reading N frames from a sequence of images + Args: + crt_i (int): current center index + max_n (int): max number of the sequence of images (calculated from 1) + N (int): reading N frames + padding (str): padding mode, one of + replicate | reflection | new_info | circle + Example: crt_i = 0, N = 5 + replicate: [0, 0, 0, 1, 2] + reflection: [2, 1, 0, 1, 2] + new_info: [4, 3, 0, 1, 2] + circle: [3, 4, 0, 1, 2] + + Returns: + return_l (list [int]): a list of indexes + """ + max_n = max_n - 1 + n_pad = N // 2 + return_l = [] + + for i in range(crt_i - n_pad, crt_i + n_pad + 1): + if i < 0: + if padding == "replicate": + add_idx = 0 + elif padding == "reflection": + add_idx = -i + elif padding == "new_info": + add_idx = (crt_i + n_pad) + (-i) + elif padding == "circle": + add_idx = N + i + else: + raise ValueError("Wrong padding mode") + elif i > max_n: + if padding == "replicate": + add_idx = max_n + elif padding == "reflection": + add_idx = max_n * 2 - i + elif padding == "new_info": + add_idx = (crt_i - n_pad) - (i - max_n) + elif padding == "circle": + add_idx = i - N + else: + raise ValueError("Wrong padding mode") + else: + add_idx = i + return_l.append(add_idx) + return return_l + + +#################### +# image processing +# process on numpy image +#################### + + +def augment(img_list, hflip=True, rot=True): + """horizontal flip OR rotate (0, 90, 180, 270 degrees)""" + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +def augment_flow(img_list, flow_list, hflip=True, rot=True): + """horizontal flip OR rotate (0, 90, 180, 270 degrees) with flows""" + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: + flow = flow[:, ::-1, :] + flow[:, :, 0] *= -1 + if vflip: + flow = flow[::-1, :, :] + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + rlt_img_list = [_augment(img) for img in img_list] + rlt_flow_list = [_augment_flow(flow) for flow in flow_list] + + return rlt_img_list, rlt_flow_list + + +def channel_convert(in_c, tar_type, img_list): + """conversion among BGR, gray and y""" + if in_c == 3 and tar_type == "gray": # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == "y": # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == "RGB": # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +def rgb2ycbcr(img, only_y=True): + """same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + """ + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255.0 + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]] + ) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255.0 + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + """bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + """ + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255.0 + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]] + ) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255.0 + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + """same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + """ + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255.0 + # convert + rlt = np.matmul( + img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], [0.00625893, -0.00318811, 0]] + ) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255.0 + return rlt.astype(in_img_type) + + +def modcrop(img_in, scale): + """img_in: Numpy, HWC or HW""" + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[: H - H_r, : W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[: H - H_r, : W - W_r, :] + else: + raise ValueError("Wrong img ndim: [{:d}].".format(img.ndim)) + return img + + +#################### +# Functions +#################### + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx ** 2 + absx3 = absx ** 3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + ( + -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2 + ) * (((absx > 1) * (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + """ + Use a modified kernel to simultaneously interpolate + and antialias- larger kernel width + """ + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand( + out_length, P + ) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. + # Only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: CHW RGB [0,1] + # output: CHW RGB [0,1] w/o round + + in_C, in_H, in_W = img.size() + _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = "cubic" + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing + ) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing + ) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + out_1[0, i, :] = img_aug[0, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[1, i, :] = img_aug[1, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + out_1[2, i, :] = img_aug[2, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + out_2[0, :, i] = out_1_aug[0, :, idx : idx + kernel_width].mv(weights_W[i]) + out_2[1, :, i] = out_1_aug[1, :, idx : idx + kernel_width].mv(weights_W[i]) + out_2[2, :, i] = out_1_aug[2, :, idx : idx + kernel_width].mv(weights_W[i]) + + return out_2 + + +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC BGR [0,1] + # output: HWC BGR [0,1] w/o round + img = torch.from_numpy(img) + + in_H, in_W, in_C = img.size() + _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = "cubic" + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing + ) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing + ) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + out_1[i, :, 0] = img_aug[idx : idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 1] = img_aug[idx : idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i]) + out_1[i, :, 2] = img_aug[idx : idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + out_2[:, i, 0] = out_1_aug[:, idx : idx + kernel_width, 0].mv(weights_W[i]) + out_2[:, i, 1] = out_1_aug[:, idx : idx + kernel_width, 1].mv(weights_W[i]) + out_2[:, i, 2] = out_1_aug[:, idx : idx + kernel_width, 2].mv(weights_W[i]) + + return out_2.numpy() + + +if __name__ == "__main__": + # test imresize function + # read images + img = cv2.imread("test.png") + img = img * 1.0 / 255 + img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() + # imresize + scale = 1 / 4 + import time + + total_time = 0 + for i in range(10): + start_time = time.time() + rlt = imresize(img, scale, antialiasing=True) + use_time = time.time() - start_time + total_time += use_time + print("average time: {}".format(total_time / 10)) + + import torchvision.utils + + torchvision.utils.save_image((rlt * 255).round() / 255, "rlt.png", nrow=1, padding=0, normalize=False) diff --git a/diffusion-posterior-sampling/bkse/data_augmentation.py b/diffusion-posterior-sampling/bkse/data_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..bb8c6d48d4ea1dbc23e5184706ac830db9699ef2 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/data_augmentation.py @@ -0,0 +1,145 @@ +import argparse +import logging +import os +import os.path as osp +import random + +import cv2 +import data.util as data_util +import lmdb +import numpy as np +import torch +import utils.util as util +import yaml +from models.kernel_encoding.kernel_wizard import KernelWizard + + +def read_image(env, key, x, y, h, w): + img = data_util.read_img(env, key, (3, 720, 1280)) + img = np.transpose(img[x : x + h, y : y + w, [2, 1, 0]], (2, 0, 1)) + return img + + +def main(): + device = torch.device("cuda") + + parser = argparse.ArgumentParser(description="Kernel extractor testing") + + parser.add_argument("--source_H", action="store", help="source image height", type=int, required=True) + parser.add_argument("--source_W", action="store", help="source image width", type=int, required=True) + parser.add_argument("--target_H", action="store", help="target image height", type=int, required=True) + parser.add_argument("--target_W", action="store", help="target image width", type=int, required=True) + parser.add_argument( + "--augmented_H", action="store", help="desired height of the augmented images", type=int, required=True + ) + parser.add_argument( + "--augmented_W", action="store", help="desired width of the augmented images", type=int, required=True + ) + + parser.add_argument( + "--source_LQ_root", action="store", help="source low-quality dataroot", type=str, required=True + ) + parser.add_argument( + "--source_HQ_root", action="store", help="source high-quality dataroot", type=str, required=True + ) + parser.add_argument( + "--target_HQ_root", action="store", help="target high-quality dataroot", type=str, required=True + ) + parser.add_argument("--save_path", action="store", help="save path", type=str, required=True) + parser.add_argument("--yml_path", action="store", help="yml path", type=str, required=True) + parser.add_argument( + "--num_images", action="store", help="number of desire augmented images", type=int, required=True + ) + + args = parser.parse_args() + + source_LQ_root = args.source_LQ_root + source_HQ_root = args.source_HQ_root + target_HQ_root = args.target_HQ_root + + save_path = args.save_path + source_H, source_W = args.source_H, args.source_W + target_H, target_W = args.target_H, args.target_W + augmented_H, augmented_W = args.augmented_H, args.augmented_W + yml_path = args.yml_path + num_images = args.num_images + + # Initializing logger + logger = logging.getLogger("base") + os.makedirs(save_path, exist_ok=True) + util.setup_logger("base", save_path, "test", level=logging.INFO, screen=True, tofile=True) + logger.info("source LQ root: {}".format(source_LQ_root)) + logger.info("source HQ root: {}".format(source_HQ_root)) + logger.info("target HQ root: {}".format(target_HQ_root)) + logger.info("augmented height: {}".format(augmented_H)) + logger.info("augmented width: {}".format(augmented_W)) + logger.info("Number of augmented images: {}".format(num_images)) + + # Initializing mode + logger.info("Loading model...") + with open(yml_path, "r") as f: + print(yml_path) + opt = yaml.load(f)["KernelWizard"] + model_path = opt["pretrained"] + model = KernelWizard(opt) + model.eval() + model.load_state_dict(torch.load(model_path)) + model = model.to(device) + logger.info("Done") + + # processing data + source_HQ_env = lmdb.open(source_HQ_root, readonly=True, lock=False, readahead=False, meminit=False) + source_LQ_env = lmdb.open(source_LQ_root, readonly=True, lock=False, readahead=False, meminit=False) + target_HQ_env = lmdb.open(target_HQ_root, readonly=True, lock=False, readahead=False, meminit=False) + paths_source_HQ, _ = data_util.get_image_paths("lmdb", source_HQ_root) + paths_target_HQ, _ = data_util.get_image_paths("lmdb", target_HQ_root) + + psnr_avg = 0 + + for i in range(num_images): + source_key = np.random.choice(paths_source_HQ) + target_key = np.random.choice(paths_target_HQ) + + source_rnd_h = random.randint(0, max(0, source_H - augmented_H)) + source_rnd_w = random.randint(0, max(0, source_W - augmented_W)) + target_rnd_h = random.randint(0, max(0, target_H - augmented_H)) + target_rnd_w = random.randint(0, max(0, target_W - augmented_W)) + + source_LQ = read_image(source_LQ_env, source_key, source_rnd_h, source_rnd_w, augmented_H, augmented_W) + source_HQ = read_image(source_HQ_env, source_key, source_rnd_h, source_rnd_w, augmented_H, augmented_W) + target_HQ = read_image(target_HQ_env, target_key, target_rnd_h, target_rnd_w, augmented_H, augmented_W) + + source_LQ = torch.Tensor(source_LQ).unsqueeze(0).to(device) + source_HQ = torch.Tensor(source_HQ).unsqueeze(0).to(device) + target_HQ = torch.Tensor(target_HQ).unsqueeze(0).to(device) + + with torch.no_grad(): + kernel_mean, kernel_sigma = model(source_HQ, source_LQ) + kernel = kernel_mean + kernel_sigma * torch.randn_like(kernel_mean) + fake_source_LQ = model.adaptKernel(source_HQ, kernel) + target_LQ = model.adaptKernel(target_HQ, kernel) + + LQ_img = util.tensor2img(source_LQ) + fake_LQ_img = util.tensor2img(fake_source_LQ) + target_LQ_img = util.tensor2img(target_LQ) + target_HQ_img = util.tensor2img(target_HQ) + + target_HQ_dst = osp.join(save_path, "sharp/{:03d}/{:08d}.png".format(i // 100, i % 100)) + target_LQ_dst = osp.join(save_path, "blur/{:03d}/{:08d}.png".format(i // 100, i % 100)) + + os.makedirs(osp.dirname(target_HQ_dst), exist_ok=True) + os.makedirs(osp.dirname(target_LQ_dst), exist_ok=True) + + cv2.imwrite(target_HQ_dst, target_HQ_img) + cv2.imwrite(target_LQ_dst, target_LQ_img) + # torch.save(kernel, osp.join(osp.dirname(target_LQ_dst), f'kernel{i:03d}.pth')) + + psnr = util.calculate_psnr(LQ_img, fake_LQ_img) + + logger.info("Reconstruction PSNR of image #{:03d}/{:03d}: {:.2f}db".format(i, num_images, psnr)) + psnr_avg += psnr + + logger.info("Average reconstruction PSNR: {:.2f}db".format(psnr_avg / num_images)) + + +main() diff --git a/diffusion-posterior-sampling/bkse/domain_specific_deblur.py b/diffusion-posterior-sampling/bkse/domain_specific_deblur.py new file mode 100644 index 0000000000000000000000000000000000000000..e45dcd256b61ad94f92bbdc886f08459ab31c8a6 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/domain_specific_deblur.py @@ -0,0 +1,89 @@ +import argparse +from math import ceil, log10 +from pathlib import Path + +import torchvision +import yaml +from PIL import Image +from torch.nn import DataParallel +from torch.utils.data import DataLoader, Dataset + + +class Images(Dataset): + def __init__(self, root_dir, duplicates): + self.root_path = Path(root_dir) + self.image_list = list(self.root_path.glob("*.png")) + self.duplicates = ( + duplicates # Number of times to duplicate the image in the dataset to produce multiple HR images + ) + + def __len__(self): + return self.duplicates * len(self.image_list) + + def __getitem__(self, idx): + img_path = self.image_list[idx // self.duplicates] + image = torchvision.transforms.ToTensor()(Image.open(img_path)) + if self.duplicates == 1: + return image, img_path.stem + else: + return image, img_path.stem + f"_{(idx % self.duplicates)+1}" + + +parser = argparse.ArgumentParser(description="PULSE") + +# I/O arguments +parser.add_argument("--input_dir", type=str, default="imgs/blur_faces", help="input data directory") +parser.add_argument( + "--output_dir", type=str, default="experiments/domain_specific_deblur/results", help="output data directory" +) +parser.add_argument( + "--cache_dir", + type=str, + default="experiments/domain_specific_deblur/cache", + help="cache directory for model weights", +) +parser.add_argument( + "--yml_path", type=str, default="options/domain_specific_deblur/stylegan2.yml", help="configuration file" +) + +kwargs = vars(parser.parse_args()) + +with open(kwargs["yml_path"], "rb") as f: + opt = yaml.safe_load(f) + +dataset = Images(kwargs["input_dir"], duplicates=opt["duplicates"]) +out_path = Path(kwargs["output_dir"]) +out_path.mkdir(parents=True, exist_ok=True) + +dataloader = DataLoader(dataset, batch_size=opt["batch_size"]) + +if opt["stylegan_ver"] == 1: + from models.dsd.dsd_stylegan import DSDStyleGAN + + model = DSDStyleGAN(opt=opt, cache_dir=kwargs["cache_dir"]) +else: + from models.dsd.dsd_stylegan2 import DSDStyleGAN2 + + model = DSDStyleGAN2(opt=opt, cache_dir=kwargs["cache_dir"]) + +model = DataParallel(model) + +toPIL = torchvision.transforms.ToPILImage() + +for ref_im, ref_im_name in dataloader: + if opt["save_intermediate"]: + padding = ceil(log10(100)) + for i in range(opt["batch_size"]): + int_path_HR = Path(out_path / ref_im_name[i] / "HR") + int_path_LR = Path(out_path / ref_im_name[i] / "LR") + int_path_HR.mkdir(parents=True, exist_ok=True) + int_path_LR.mkdir(parents=True, exist_ok=True) + for j, (HR, LR) in enumerate(model(ref_im)): + for i in range(opt["batch_size"]): + toPIL(HR[i].cpu().detach().clamp(0, 1)).save(int_path_HR / f"{ref_im_name[i]}_{j:0{padding}}.png") + toPIL(LR[i].cpu().detach().clamp(0, 1)).save(int_path_LR / f"{ref_im_name[i]}_{j:0{padding}}.png") + else: + # out_im = model(ref_im,**kwargs) + for j, (HR, LR) in enumerate(model(ref_im)): + for i in range(opt["batch_size"]): + toPIL(HR[i].cpu().detach().clamp(0, 1)).save(out_path / f"{ref_im_name[i]}.png") diff --git a/diffusion-posterior-sampling/bkse/experiments/pretrained/kernel.pth b/diffusion-posterior-sampling/bkse/experiments/pretrained/kernel.pth new file mode 100644 index 0000000000000000000000000000000000000000..243c5804bb5db2ac6548dcd8d638c3378de91971 Binary files /dev/null and b/diffusion-posterior-sampling/bkse/experiments/pretrained/kernel.pth differ diff --git a/diffusion-posterior-sampling/bkse/generate_blur.py b/diffusion-posterior-sampling/bkse/generate_blur.py new file mode 100644 index 0000000000000000000000000000000000000000..5e162bcee94ec049c5303fcabceebef705aacc9b --- /dev/null +++ b/diffusion-posterior-sampling/bkse/generate_blur.py @@ -0,0 +1,53 @@ +import argparse + +import cv2 +import numpy as np +import os.path as osp +import torch +import utils.util as util +import yaml +from models.kernel_encoding.kernel_wizard import KernelWizard + + +def main(): + device = torch.device("cuda") + + parser = argparse.ArgumentParser(description="Kernel extractor testing") + + parser.add_argument("--image_path", action="store", help="image path", type=str, required=True) + parser.add_argument("--yml_path", action="store", help="yml path", type=str, required=True) + parser.add_argument("--save_path", action="store", help="save path", type=str, default=".") + parser.add_argument("--num_samples", action="store", help="number of samples", type=int, default=1) + + args = parser.parse_args() + + image_path = args.image_path + yml_path = args.yml_path + num_samples = args.num_samples + + # Initializing mode + with open(yml_path, "r") as f: + opt = yaml.load(f)["KernelWizard"] + model_path = opt["pretrained"] + model = KernelWizard(opt) + model.eval() + model.load_state_dict(torch.load(model_path)) + model = model.to(device) + + HQ = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) / 255.0 + HQ = np.transpose(HQ, (2, 0, 1)) + HQ_tensor = torch.Tensor(HQ).unsqueeze(0).to(device).cuda() + + for i in range(num_samples): + print(f"Sample #{i}/{num_samples}") + with torch.no_grad(): + kernel = torch.randn((1, 512, 2, 2)).cuda() * 1.2 + LQ_tensor = model.adaptKernel(HQ_tensor, kernel) + + dst = osp.join(args.save_path, f"blur{i:03d}.png") + LQ_img = util.tensor2img(LQ_tensor) + + cv2.imwrite(dst, LQ_img) + + +main() diff --git a/diffusion-posterior-sampling/bkse/generic_deblur.py b/diffusion-posterior-sampling/bkse/generic_deblur.py new file mode 100644 index 0000000000000000000000000000000000000000..c384208de249ad7b001e9e580269ef090a81f86c --- /dev/null +++ b/diffusion-posterior-sampling/bkse/generic_deblur.py @@ -0,0 +1,28 @@ +import argparse + +import cv2 +import yaml +from models.deblurring.joint_deblur import JointDeblur + + +def main(): + parser = argparse.ArgumentParser(description="Kernel extractor testing") + + parser.add_argument("--image_path", action="store", help="image path", type=str, required=True) + parser.add_argument("--save_path", action="store", help="save path", type=str, default="res.png") + parser.add_argument("--yml_path", action="store", help="yml path", type=str, required=True) + + args = parser.parse_args() + + # Initializing mode + with open(args.yml_path, "rb") as f: + opt = yaml.safe_load(f) + model = JointDeblur(opt) + + blur_img = cv2.cvtColor(cv2.imread(args.image_path), cv2.COLOR_BGR2RGB) + sharp_img = model.deblur(blur_img) + + cv2.imwrite(args.save_path, sharp_img) + + +main() diff --git a/diffusion-posterior-sampling/bkse/imgs/blur_faces/face01.png b/diffusion-posterior-sampling/bkse/imgs/blur_faces/face01.png new file mode 100644 index 0000000000000000000000000000000000000000..071cdcdc54c1204d78f6d69300e614402686b2a4 Binary files /dev/null and b/diffusion-posterior-sampling/bkse/imgs/blur_faces/face01.png differ diff --git a/diffusion-posterior-sampling/bkse/imgs/blur_imgs/blur1.png b/diffusion-posterior-sampling/bkse/imgs/blur_imgs/blur1.png new file mode 100644 index 0000000000000000000000000000000000000000..810fed231f4977d62f8016c95ccf15b230b45a3e Binary files /dev/null and b/diffusion-posterior-sampling/bkse/imgs/blur_imgs/blur1.png differ diff --git a/diffusion-posterior-sampling/bkse/imgs/blur_imgs/blur2.png b/diffusion-posterior-sampling/bkse/imgs/blur_imgs/blur2.png new file mode 100644 index 0000000000000000000000000000000000000000..2e1e9319c4a26d9f92b055c8cdfdade751e4d793 Binary files /dev/null and b/diffusion-posterior-sampling/bkse/imgs/blur_imgs/blur2.png differ diff --git a/diffusion-posterior-sampling/bkse/imgs/results/augmentation.jpg b/diffusion-posterior-sampling/bkse/imgs/results/augmentation.jpg new file mode 100644 index 0000000000000000000000000000000000000000..913b11c8aa1c1d6096851e2c8e7a4fbb9d5474a5 Binary files /dev/null and b/diffusion-posterior-sampling/bkse/imgs/results/augmentation.jpg differ diff --git a/diffusion-posterior-sampling/bkse/imgs/results/domain_specific_deblur.jpg b/diffusion-posterior-sampling/bkse/imgs/results/domain_specific_deblur.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6d0a29bc07fe4e57c071d25a66157e5a96fb19a3 Binary files /dev/null and b/diffusion-posterior-sampling/bkse/imgs/results/domain_specific_deblur.jpg differ diff --git a/diffusion-posterior-sampling/bkse/imgs/results/general_deblurring.jpg b/diffusion-posterior-sampling/bkse/imgs/results/general_deblurring.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bb19e42376a5d342ccd8bcec9bc7ea3d495382bd Binary files /dev/null and b/diffusion-posterior-sampling/bkse/imgs/results/general_deblurring.jpg differ diff --git a/diffusion-posterior-sampling/bkse/imgs/results/generate_blur.jpg b/diffusion-posterior-sampling/bkse/imgs/results/generate_blur.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fe35bc45d81a0aa295e70b0f8fdfb0e7e558335a Binary files /dev/null and b/diffusion-posterior-sampling/bkse/imgs/results/generate_blur.jpg differ diff --git a/diffusion-posterior-sampling/bkse/imgs/results/kernel_encoding_wGT.png b/diffusion-posterior-sampling/bkse/imgs/results/kernel_encoding_wGT.png new file mode 100644 index 0000000000000000000000000000000000000000..66ea7be2c62dcdd0dc07443043f1fbec7bd2ef33 Binary files /dev/null and b/diffusion-posterior-sampling/bkse/imgs/results/kernel_encoding_wGT.png differ diff --git a/diffusion-posterior-sampling/bkse/imgs/sharp_imgs/mushishi.png b/diffusion-posterior-sampling/bkse/imgs/sharp_imgs/mushishi.png new file mode 100644 index 0000000000000000000000000000000000000000..4898640f9cdfd81721f80ba639e2870a6055f728 Binary files /dev/null and b/diffusion-posterior-sampling/bkse/imgs/sharp_imgs/mushishi.png differ diff --git a/diffusion-posterior-sampling/bkse/imgs/teaser.jpg b/diffusion-posterior-sampling/bkse/imgs/teaser.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2f53b05658c1963c1b38de4cd4c69feb19ae6795 Binary files /dev/null and b/diffusion-posterior-sampling/bkse/imgs/teaser.jpg differ diff --git a/diffusion-posterior-sampling/bkse/models/__init__.py b/diffusion-posterior-sampling/bkse/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81ac34183d164e666d42b5481e7f7e83ad15c183 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/__init__.py @@ -0,0 +1,15 @@ +import logging + + +logger = logging.getLogger("base") + + +def create_model(opt): + model = opt["model"] + if model == "image_base": + from models.kernel_encoding.image_base_model import ImageBaseModel as M + else: + raise NotImplementedError("Model [{:s}] not recognized.".format(model)) + m = M(opt) + logger.info("Model [{:s}] is created.".format(m.__class__.__name__)) + return m diff --git a/diffusion-posterior-sampling/bkse/models/arch_util.py b/diffusion-posterior-sampling/bkse/models/arch_util.py new file mode 100644 index 0000000000000000000000000000000000000000..88e1d60d658bf413887d4c34977e47917ff1d4e2 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/arch_util.py @@ -0,0 +1,58 @@ +import functools + +import torch.nn as nn +import torch.nn.init as init + + +class Identity(nn.Module): + def forward(self, x): + return x + + +def get_norm_layer(norm_type="instance"): + """Return a normalization layer + Parameters: + norm_type (str) -- the name of the normalization + layer: batch | instance | none + + For BatchNorm, we use learnable affine parameters and + track running statistics (mean/stddev). + + For InstanceNorm, we do not use learnable affine + parameters. We do not track running statistics. + """ + if norm_type == "batch": + norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) + elif norm_type == "instance": + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) + elif norm_type == "none": + + def norm_layer(x): + return Identity() + + else: + raise NotImplementedError( + f"normalization layer {norm_type}\ + is not found" + ) + return norm_layer + + +def initialize_weights(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, a=0, mode="fan_in") + m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, a=0, mode="fan_in") + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) diff --git a/diffusion-posterior-sampling/bkse/models/backbones/resnet.py b/diffusion-posterior-sampling/bkse/models/backbones/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b47a311a3314a148ebea702e93d19b46befdf7aa --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/backbones/resnet.py @@ -0,0 +1,89 @@ +import torch.nn as nn +import torch.nn.functional as F +from models.arch_util import initialize_weights + + +class ResnetBlock(nn.Module): + """Define a Resnet block""" + + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + """Initialize the Resnet block + A resnet block is a conv block with skip connections + We construct a conv block with build_conv_block function, + and implement skip connections in function. + Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf + """ + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + """Construct a convolutional block. + Parameters: + dim (int) -- the number of channels in the conv layer. + padding_type (str) -- the name of padding + layer: reflect | replicate | zero + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + use_bias (bool) -- if the conv layer uses bias or not + Returns a conv block (with a conv layer, a normalization layer, + and a non-linearity layer (ReLU)) + """ + conv_block = [] + p = 0 + if padding_type == "reflect": + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == "replicate": + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == "zero": + p = 1 + else: + raise NotImplementedError( + f"padding {padding_type} \ + is not implemented" + ) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == "reflect": + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == "replicate": + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == "zero": + p = 1 + else: + raise NotImplementedError( + f"padding {padding_type} \ + is not implemented" + ) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + """Forward function (with skip connections)""" + out = x + self.conv_block(x) # add skip connections + return out + + +class ResidualBlock_noBN(nn.Module): + """Residual block w/o BN + ---Conv-ReLU-Conv-+- + |________________| + """ + + def __init__(self, nf=64): + super(ResidualBlock_noBN, self).__init__() + self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + + # initialization + initialize_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = F.relu(self.conv1(x), inplace=False) + out = self.conv2(out) + return identity + out diff --git a/diffusion-posterior-sampling/bkse/models/backbones/skip/concat.py b/diffusion-posterior-sampling/bkse/models/backbones/skip/concat.py new file mode 100644 index 0000000000000000000000000000000000000000..8798bd42c6ab7d6dc978106b46a9ae615826190f --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/backbones/skip/concat.py @@ -0,0 +1,39 @@ +import numpy as np +import torch +import torch.nn as nn + + +class Concat(nn.Module): + def __init__(self, dim, *args): + super(Concat, self).__init__() + self.dim = dim + + for idx, module in enumerate(args): + self.add_module(str(idx), module) + + def forward(self, input): + inputs = [] + for module in self._modules.values(): + inputs.append(module(input)) + + inputs_shapes2 = [x.shape[2] for x in inputs] + inputs_shapes3 = [x.shape[3] for x in inputs] + + if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all( + np.array(inputs_shapes3) == min(inputs_shapes3) + ): + inputs_ = inputs + else: + target_shape2 = min(inputs_shapes2) + target_shape3 = min(inputs_shapes3) + + inputs_ = [] + for inp in inputs: + diff2 = (inp.size(2) - target_shape2) // 2 + diff3 = (inp.size(3) - target_shape3) // 2 + inputs_.append(inp[:, :, diff2 : diff2 + target_shape2, diff3 : diff3 + target_shape3]) + + return torch.cat(inputs_, dim=self.dim) + + def __len__(self): + return len(self._modules) diff --git a/diffusion-posterior-sampling/bkse/models/backbones/skip/downsampler.py b/diffusion-posterior-sampling/bkse/models/backbones/skip/downsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..f458b60c91f2d6dc7a346f8a895e33fc7ca69010 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/backbones/skip/downsampler.py @@ -0,0 +1,241 @@ +import numpy as np +import torch +import torch.nn as nn + + +class Downsampler(nn.Module): + """ + http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf + """ + + def __init__( + self, n_planes, factor, kernel_type, phase=0, kernel_width=None, support=None, sigma=None, preserve_size=False + ): + super(Downsampler, self).__init__() + + assert phase in [0, 0.5], "phase should be 0 or 0.5" + + if kernel_type == "lanczos2": + support = 2 + kernel_width = 4 * factor + 1 + kernel_type_ = "lanczos" + + elif kernel_type == "lanczos3": + support = 3 + kernel_width = 6 * factor + 1 + kernel_type_ = "lanczos" + + elif kernel_type == "gauss12": + kernel_width = 7 + sigma = 1 / 2 + kernel_type_ = "gauss" + + elif kernel_type == "gauss1sq2": + kernel_width = 9 + sigma = 1.0 / np.sqrt(2) + kernel_type_ = "gauss" + + elif kernel_type in ["lanczos", "gauss", "box"]: + kernel_type_ = kernel_type + + else: + assert False, "wrong name kernel" + + # note that `kernel width` will be different to actual size for phase = 1/2 + self.kernel = get_kernel(factor, kernel_type_, phase, kernel_width, support=support, sigma=sigma) + + downsampler = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0) + downsampler.weight.data[:] = 0 + downsampler.bias.data[:] = 0 + + kernel_torch = torch.from_numpy(self.kernel) + for i in range(n_planes): + downsampler.weight.data[i, i] = kernel_torch + + self.downsampler_ = downsampler + + if preserve_size: + + if self.kernel.shape[0] % 2 == 1: + pad = int((self.kernel.shape[0] - 1) / 2.0) + else: + pad = int((self.kernel.shape[0] - factor) / 2.0) + + self.padding = nn.ReplicationPad2d(pad) + + self.preserve_size = preserve_size + + def forward(self, input): + if self.preserve_size: + x = self.padding(input) + else: + x = input + self.x = x + return self.downsampler_(x) + + +class Blurconv(nn.Module): + """ + http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf + """ + + def __init__(self, n_planes=1, preserve_size=False): + super(Blurconv, self).__init__() + + # self.kernel = kernel + # blurconv = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=1, padding=0) + # blurconvr.weight.data = self.kernel + # blurconv.bias.data[:] = 0 + self.n_planes = n_planes + self.preserve_size = preserve_size + + # kernel_torch = torch.from_numpy(self.kernel) + # for i in range(n_planes): + # blurconv.weight.data[i, i] = kernel_torch + + # self.blurconv_ = blurconv + # + # if preserve_size: + # + # if self.kernel.shape[0] % 2 == 1: + # pad = int((self.kernel.shape[0] - 1) / 2.) + # else: + # pad = int((self.kernel.shape[0] - factor) / 2.) + # + # self.padding = nn.ReplicationPad2d(pad) + # + # self.preserve_size = preserve_size + + def forward(self, input, kernel): + if self.preserve_size: + if kernel.shape[0] % 2 == 1: + pad = int((kernel.shape[3] - 1) / 2.0) + else: + pad = int((kernel.shape[3] - 1.0) / 2.0) + padding = nn.ReplicationPad2d(pad) + x = padding(input) + else: + x = input + + blurconv = nn.Conv2d( + self.n_planes, self.n_planes, kernel_size=kernel.size(3), stride=1, padding=0, bias=False + ).cuda() + + blurconv.weight.data[:] = kernel + + return blurconv(x) + + +class Blurconv2(nn.Module): + """ + http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf + """ + + def __init__(self, n_planes=1, preserve_size=False, k_size=21): + super(Blurconv2, self).__init__() + + self.n_planes = n_planes + self.k_size = k_size + self.preserve_size = preserve_size + self.blurconv = nn.Conv2d(self.n_planes, self.n_planes, kernel_size=k_size, stride=1, padding=0, bias=False) + + # self.blurconv.weight.data[:] /= self.blurconv.weight.data.sum() + def forward(self, input): + if self.preserve_size: + pad = int((self.k_size - 1.0) / 2.0) + padding = nn.ReplicationPad2d(pad) + x = padding(input) + else: + x = input + # self.blurconv.weight.data[:] /= self.blurconv.weight.data.sum() + return self.blurconv(x) + + +def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None): + assert kernel_type in ["lanczos", "gauss", "box"] + + # factor = float(factor) + if phase == 0.5 and kernel_type != "box": + kernel = np.zeros([kernel_width - 1, kernel_width - 1]) + else: + kernel = np.zeros([kernel_width, kernel_width]) + + if kernel_type == "box": + assert phase == 0.5, "Box filter is always half-phased" + kernel[:] = 1.0 / (kernel_width * kernel_width) + + elif kernel_type == "gauss": + assert sigma, "sigma is not specified" + assert phase != 0.5, "phase 1/2 for gauss not implemented" + + center = (kernel_width + 1.0) / 2.0 + print(center, kernel_width) + sigma_sq = sigma * sigma + + for i in range(1, kernel.shape[0] + 1): + for j in range(1, kernel.shape[1] + 1): + di = (i - center) / 2.0 + dj = (j - center) / 2.0 + kernel[i - 1][j - 1] = np.exp(-(di * di + dj * dj) / (2 * sigma_sq)) + kernel[i - 1][j - 1] = kernel[i - 1][j - 1] / (2.0 * np.pi * sigma_sq) + elif kernel_type == "lanczos": + assert support, "support is not specified" + center = (kernel_width + 1) / 2.0 + + for i in range(1, kernel.shape[0] + 1): + for j in range(1, kernel.shape[1] + 1): + + if phase == 0.5: + di = abs(i + 0.5 - center) / factor + dj = abs(j + 0.5 - center) / factor + else: + di = abs(i - center) / factor + dj = abs(j - center) / factor + + val = 1 + if di != 0: + val = val * support * np.sin(np.pi * di) * np.sin(np.pi * di / support) + val = val / (np.pi * np.pi * di * di) + + if dj != 0: + val = val * support * np.sin(np.pi * dj) * np.sin(np.pi * dj / support) + val = val / (np.pi * np.pi * dj * dj) + + kernel[i - 1][j - 1] = val + + else: + assert False, "wrong method name" + + kernel /= kernel.sum() + + return kernel + + +# a = Downsampler(n_planes=3, factor=2, kernel_type='lanczos2', phase='1', preserve_size=True) + + +################# +# Learnable downsampler + +# KS = 32 +# dow = nn.Sequential(nn.ReplicationPad2d(int((KS - factor) / 2.)), nn.Conv2d(1,1,KS,factor)) + +# class Apply(nn.Module): +# def __init__(self, what, dim, *args): +# super(Apply, self).__init__() +# self.dim = dim + +# self.what = what + +# def forward(self, input): +# inputs = [] +# for i in range(input.size(self.dim)): +# inputs.append(self.what(input.narrow(self.dim, i, 1))) + +# return torch.cat(inputs, dim=self.dim) + +# def __len__(self): +# return len(self._modules) + +# downs = Apply(dow, 1) +# downs.type(dtype)(net_input.type(dtype)).size() diff --git a/diffusion-posterior-sampling/bkse/models/backbones/skip/non_local_dot_product.py b/diffusion-posterior-sampling/bkse/models/backbones/skip/non_local_dot_product.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebed5e5021b75b14c6fba39f27a23ccc94a22a0 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/backbones/skip/non_local_dot_product.py @@ -0,0 +1,130 @@ +import torch +from torch import nn + + +class _NonLocalBlockND(nn.Module): + def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): + super(_NonLocalBlockND, self).__init__() + + assert dimension in [1, 2, 3] + + self.dimension = dimension + self.sub_sample = sub_sample + + self.in_channels = in_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = nn.Conv3d + max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) + bn = nn.BatchNorm3d + elif dimension == 2: + conv_nd = nn.Conv2d + max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) + bn = nn.BatchNorm2d + else: + conv_nd = nn.Conv1d + max_pool_layer = nn.MaxPool1d(kernel_size=(2)) + bn = nn.BatchNorm1d + + self.g = conv_nd( + in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 + ) + + if bn_layer: + self.W = nn.Sequential( + conv_nd( + in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0 + ), + bn(self.in_channels), + ) + nn.init.constant_(self.W[1].weight, 0) + nn.init.constant_(self.W[1].bias, 0) + else: + self.W = conv_nd( + in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0 + ) + nn.init.constant_(self.W.weight, 0) + nn.init.constant_(self.W.bias, 0) + + self.theta = conv_nd( + in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 + ) + + self.phi = conv_nd( + in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0 + ) + + if sub_sample: + self.g = nn.Sequential(self.g, max_pool_layer) + self.phi = nn.Sequential(self.phi, max_pool_layer) + + def forward(self, x): + """ + :param x: (b, c, t, h, w) + :return: + """ + + batch_size = x.size(0) + + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + f = torch.matmul(theta_x, phi_x) + N = f.size(-1) + f_div_C = f / N + + y = torch.matmul(f_div_C, g_x) + y = y.permute(0, 2, 1).contiguous() + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + +class NONLocalBlock1D(_NonLocalBlockND): + def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): + super(NONLocalBlock1D, self).__init__( + in_channels, inter_channels=inter_channels, dimension=1, sub_sample=sub_sample, bn_layer=bn_layer + ) + + +class NONLocalBlock2D(_NonLocalBlockND): + def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): + super(NONLocalBlock2D, self).__init__( + in_channels, inter_channels=inter_channels, dimension=2, sub_sample=sub_sample, bn_layer=bn_layer + ) + + +class NONLocalBlock3D(_NonLocalBlockND): + def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): + super(NONLocalBlock3D, self).__init__( + in_channels, inter_channels=inter_channels, dimension=3, sub_sample=sub_sample, bn_layer=bn_layer + ) + + +if __name__ == "__main__": + for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]: + img = torch.zeros(2, 3, 20) + net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) + out = net(img) + print(out.size()) + + img = torch.zeros(2, 3, 20, 20) + net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) + out = net(img) + print(out.size()) + + img = torch.randn(2, 3, 8, 20, 20) + net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) + out = net(img) + print(out.size()) diff --git a/diffusion-posterior-sampling/bkse/models/backbones/skip/skip.py b/diffusion-posterior-sampling/bkse/models/backbones/skip/skip.py new file mode 100644 index 0000000000000000000000000000000000000000..186153d77d34a49ae7152ace3149b6324731560d --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/backbones/skip/skip.py @@ -0,0 +1,133 @@ +import torch +import torch.nn as nn + +from .concat import Concat +from .non_local_dot_product import NONLocalBlock2D +from .util import get_activation, get_conv + + +def add_module(self, module): + self.add_module(str(len(self) + 1), module) + + +torch.nn.Module.add = add_module + + +def skip( + num_input_channels=2, + num_output_channels=3, + num_channels_down=[16, 32, 64, 128, 128], + num_channels_up=[16, 32, 64, 128, 128], + num_channels_skip=[4, 4, 4, 4, 4], + filter_size_down=3, + filter_size_up=3, + filter_skip_size=1, + need_sigmoid=True, + need_bias=True, + pad="zero", + upsample_mode="nearest", + downsample_mode="stride", + act_fun="LeakyReLU", + need1x1_up=True, +): + """Assembles encoder-decoder with skip connections. + + Arguments: + act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU) + pad (string): zero|reflection (default: 'zero') + upsample_mode (string): 'nearest|bilinear' (default: 'nearest') + downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride') + + """ + assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip) + + n_scales = len(num_channels_down) + + if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)): + upsample_mode = [upsample_mode] * n_scales + + if not (isinstance(downsample_mode, list) or isinstance(downsample_mode, tuple)): + downsample_mode = [downsample_mode] * n_scales + + if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)): + filter_size_down = [filter_size_down] * n_scales + + if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)): + filter_size_up = [filter_size_up] * n_scales + + last_scale = n_scales - 1 + + model = nn.Sequential() + model_tmp = model + + input_depth = num_input_channels + for i in range(len(num_channels_down)): + + deeper = nn.Sequential() + skip = nn.Sequential() + + if num_channels_skip[i] != 0: + model_tmp.add(Concat(1, skip, deeper)) + else: + model_tmp.add(deeper) + + model_tmp.add( + nn.BatchNorm2d(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i])) + ) + + if num_channels_skip[i] != 0: + skip.add(get_conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad)) + skip.add(nn.BatchNorm2d(num_channels_skip[i])) + skip.add(get_activation(act_fun)) + + # skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part)) + + deeper.add( + get_conv( + input_depth, + num_channels_down[i], + filter_size_down[i], + 2, + bias=need_bias, + pad=pad, + downsample_mode=downsample_mode[i], + ) + ) + deeper.add(nn.BatchNorm2d(num_channels_down[i])) + deeper.add(get_activation(act_fun)) + if i > 1: + deeper.add(NONLocalBlock2D(in_channels=num_channels_down[i])) + deeper.add(get_conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad)) + deeper.add(nn.BatchNorm2d(num_channels_down[i])) + deeper.add(get_activation(act_fun)) + + deeper_main = nn.Sequential() + + if i == len(num_channels_down) - 1: + # The deepest + k = num_channels_down[i] + else: + deeper.add(deeper_main) + k = num_channels_up[i + 1] + + deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i])) + + model_tmp.add( + get_conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad) + ) + model_tmp.add(nn.BatchNorm2d(num_channels_up[i])) + model_tmp.add(get_activation(act_fun)) + + if need1x1_up: + model_tmp.add(get_conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad)) + model_tmp.add(nn.BatchNorm2d(num_channels_up[i])) + model_tmp.add(get_activation(act_fun)) + + input_depth = num_channels_down[i] + model_tmp = deeper_main + + model.add(get_conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad)) + if need_sigmoid: + model.add(nn.Sigmoid()) + + return model diff --git a/diffusion-posterior-sampling/bkse/models/backbones/skip/util.py b/diffusion-posterior-sampling/bkse/models/backbones/skip/util.py new file mode 100644 index 0000000000000000000000000000000000000000..840ad5ef89471d77ff526be4326d718e3d395dde --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/backbones/skip/util.py @@ -0,0 +1,65 @@ +import torch.nn as nn + +from .downsampler import Downsampler + + +class Swish(nn.Module): + """ + https://arxiv.org/abs/1710.05941 + The hype was so huge that I could not help but try it + """ + + def __init__(self): + super(Swish, self).__init__() + self.s = nn.Sigmoid() + + def forward(self, x): + return x * self.s(x) + + +def get_conv(in_f, out_f, kernel_size, stride=1, bias=True, pad="zero", downsample_mode="stride"): + downsampler = None + if stride != 1 and downsample_mode != "stride": + + if downsample_mode == "avg": + downsampler = nn.AvgPool2d(stride, stride) + elif downsample_mode == "max": + downsampler = nn.MaxPool2d(stride, stride) + elif downsample_mode in ["lanczos2", "lanczos3"]: + downsampler = Downsampler( + n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True + ) + else: + assert False + + stride = 1 + + padder = None + to_pad = int((kernel_size - 1) / 2) + if pad == "reflection": + padder = nn.ReflectionPad2d(to_pad) + to_pad = 0 + + convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias) + + layers = filter(lambda x: x is not None, [padder, convolver, downsampler]) + return nn.Sequential(*layers) + + +def get_activation(act_fun="LeakyReLU"): + """ + Either string defining an activation function or module (e.g. nn.ReLU) + """ + if isinstance(act_fun, str): + if act_fun == "LeakyReLU": + return nn.LeakyReLU(0.2, inplace=True) + elif act_fun == "Swish": + return Swish() + elif act_fun == "ELU": + return nn.ELU() + elif act_fun == "none": + return nn.Sequential() + else: + assert False + else: + return act_fun() diff --git a/diffusion-posterior-sampling/bkse/models/backbones/unet_parts.py b/diffusion-posterior-sampling/bkse/models/backbones/unet_parts.py new file mode 100644 index 0000000000000000000000000000000000000000..6396c7c24788c6e0133b144842313a47a90c55ec --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/backbones/unet_parts.py @@ -0,0 +1,109 @@ +""" Parts of the U-Net model """ + +import functools + +import torch +import torch.nn as nn + + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.double_conv(x) + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__( + self, + outer_nc, + inner_nc, + input_nc=None, + submodule=None, + outermost=False, + innermost=False, + norm_layer=nn.BatchNorm2d, + use_dropout=False, + ): + """Construct a Unet submodule with skip connections. + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) --previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + """ + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + self.innermost = innermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1) + # upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + # upconv = DoubleConv(inner_nc * 2, outer_nc) + up = [uprelu, upconv, nn.Tanh()] + down = [downconv] + self.down = nn.Sequential(*down) + self.submodule = submodule + self.up = nn.Sequential(*up) + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) + # upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + # upconv = DoubleConv(inner_nc * 2, outer_nc) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + self.down = nn.Sequential(*down) + self.up = nn.Sequential(*up) + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) + # upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + # upconv = DoubleConv(inner_nc * 2, outer_nc) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + if use_dropout: + up += [nn.Dropout(0.5)] + + self.down = nn.Sequential(*down) + self.submodule = submodule + self.up = nn.Sequential(*up) + + def forward(self, x, noise): + + if self.outermost: + return self.up(self.submodule(self.down(x), noise)) + elif self.innermost: # add skip connections + if noise is None: + noise = torch.randn((1, 512, 8, 8)).cuda() * 0.0007 + return torch.cat((self.up(torch.cat((self.down(x), noise), dim=1)), x), dim=1) + else: + return torch.cat((self.up(self.submodule(self.down(x), noise)), x), dim=1) diff --git a/diffusion-posterior-sampling/bkse/models/deblurring/image_deblur.py b/diffusion-posterior-sampling/bkse/models/deblurring/image_deblur.py new file mode 100644 index 0000000000000000000000000000000000000000..d03a635abd1262155047570de19dd59fb312533c --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/deblurring/image_deblur.py @@ -0,0 +1,71 @@ +import torch +import torch.nn as nn +import utils.util as util +from models.dips import ImageDIP, KernelDIP +from models.kernel_encoding.kernel_wizard import KernelWizard +from models.losses.hyper_laplacian_penalty import HyperLaplacianPenalty +from models.losses.perceptual_loss import PerceptualLoss +from models.losses.ssim_loss import SSIM +from torch.optim.lr_scheduler import StepLR +from tqdm import tqdm + + +class ImageDeblur: + def __init__(self, opt): + self.opt = opt + + # losses + self.ssim_loss = SSIM().cuda() + self.mse = nn.MSELoss().cuda() + self.perceptual_loss = PerceptualLoss().cuda() + self.laplace_penalty = HyperLaplacianPenalty(3, 0.66).cuda() + + self.kernel_wizard = KernelWizard(opt["KernelWizard"]).cuda() + self.kernel_wizard.load_state_dict(torch.load(opt["KernelWizard"]["pretrained"])) + + for k, v in self.kernel_wizard.named_parameters(): + v.requires_grad = False + + def reset_optimizers(self): + self.x_optimizer = torch.optim.Adam(self.x_dip.parameters(), lr=self.opt["x_lr"]) + self.k_optimizer = torch.optim.Adam(self.k_dip.parameters(), lr=self.opt["k_lr"]) + + self.x_scheduler = StepLR(self.x_optimizer, step_size=self.opt["num_iters"] // 5, gamma=0.7) + + self.k_scheduler = StepLR(self.k_optimizer, step_size=self.opt["num_iters"] // 5, gamma=0.7) + + def prepare_DIPs(self): + # x is stand for the sharp image, k is stand for the kernel + self.x_dip = ImageDIP(self.opt["ImageDIP"]).cuda() + self.k_dip = KernelDIP(self.opt["KernelDIP"]).cuda() + + # fixed input vectors of DIPs + # zk and zx are the length of the corresponding vectors + self.dip_zk = util.get_noise(64, "noise", (64, 64)).cuda() + self.dip_zx = util.get_noise(8, "noise", self.opt["img_size"]).cuda() + + def warmup(self, warmup_x, warmup_k): + # Input vector of DIPs is sampled from N(z, I) + reg_noise_std = self.opt["reg_noise_std"] + + for step in tqdm(range(self.opt["num_warmup_iters"])): + self.x_optimizer.zero_grad() + dip_zx_rand = self.dip_zx + reg_noise_std * torch.randn_like(self.dip_zx).cuda() + x = self.x_dip(dip_zx_rand) + + loss = self.mse(x, warmup_x) + loss.backward() + self.x_optimizer.step() + + print("Warming up k DIP") + for step in tqdm(range(self.opt["num_warmup_iters"])): + self.k_optimizer.zero_grad() + dip_zk_rand = self.dip_zk + reg_noise_std * torch.randn_like(self.dip_zk).cuda() + k = self.k_dip(dip_zk_rand) + + loss = self.mse(k, warmup_k) + loss.backward() + self.k_optimizer.step() + + def deblur(self, img): + pass diff --git a/diffusion-posterior-sampling/bkse/models/deblurring/joint_deblur.py b/diffusion-posterior-sampling/bkse/models/deblurring/joint_deblur.py new file mode 100644 index 0000000000000000000000000000000000000000..18839e67a9e27d7fe4eff19750d33d40bac197b3 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/deblurring/joint_deblur.py @@ -0,0 +1,63 @@ +import torch +import utils.util as util +from models.deblurring.image_deblur import ImageDeblur +from tqdm import tqdm + + +class JointDeblur(ImageDeblur): + def __init__(self, opt): + super(JointDeblur, self).__init__(opt) + + def deblur(self, y): + """Deblur image + Args: + y: Blur image + """ + y = util.img2tensor(y).unsqueeze(0).cuda() + + self.prepare_DIPs() + self.reset_optimizers() + + warmup_k = torch.load(self.opt["warmup_k_path"]).cuda() + self.warmup(y, warmup_k) + + # Input vector of DIPs is sampled from N(z, I) + + print("Deblurring") + reg_noise_std = self.opt["reg_noise_std"] + for step in tqdm(range(self.opt["num_iters"])): + dip_zx_rand = self.dip_zx + reg_noise_std * torch.randn_like(self.dip_zx).cuda() + dip_zk_rand = self.dip_zk + reg_noise_std * torch.randn_like(self.dip_zk).cuda() + + self.x_optimizer.zero_grad() + self.k_optimizer.zero_grad() + + self.x_scheduler.step() + self.k_scheduler.step() + + x = self.x_dip(dip_zx_rand) + k = self.k_dip(dip_zk_rand) + + fake_y = self.kernel_wizard.adaptKernel(x, k) + + if step < self.opt["num_iters"] // 2: + total_loss = 6e-1 * self.perceptual_loss(fake_y, y) + total_loss += 1 - self.ssim_loss(fake_y, y) + total_loss += 5e-5 * torch.norm(k) + total_loss += 2e-2 * self.laplace_penalty(x) + else: + total_loss = self.perceptual_loss(fake_y, y) + total_loss += 5e-2 * self.laplace_penalty(x) + total_loss += 5e-4 * torch.norm(k) + + total_loss.backward() + + self.x_optimizer.step() + self.k_optimizer.step() + + # debugging + # if step % 100 == 0: + # print(torch.norm(k)) + # print(f"{self.k_optimizer.param_groups[0]['lr']:.3e}") + + return util.tensor2img(x.detach()) diff --git a/diffusion-posterior-sampling/bkse/models/dips.py b/diffusion-posterior-sampling/bkse/models/dips.py new file mode 100644 index 0000000000000000000000000000000000000000..e669bca8d85c5349ef96c92e9f44e3f1fc224049 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/dips.py @@ -0,0 +1,83 @@ +import models.arch_util as arch_util +import torch.nn as nn +from models.backbones.resnet import ResnetBlock +from models.backbones.skip.skip import skip + + +class KernelDIP(nn.Module): + """ + DIP (Deep Image Prior) for blur kernel + """ + + def __init__(self, opt): + super(KernelDIP, self).__init__() + + norm_layer = arch_util.get_norm_layer("none") + n_blocks = opt["n_blocks"] + nf = opt["nf"] + padding_type = opt["padding_type"] + use_dropout = opt["use_dropout"] + kernel_dim = opt["kernel_dim"] + + input_nc = 64 + model = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, nf, kernel_size=7, padding=0, bias=True), + norm_layer(nf), + nn.ReLU(True), + ] + + n_downsampling = 5 + for i in range(n_downsampling): # add downsampling layers + mult = 2 ** i + input_nc = min(nf * mult, kernel_dim) + output_nc = min(nf * mult * 2, kernel_dim) + model += [ + nn.Conv2d(input_nc, output_nc, kernel_size=3, stride=2, padding=1, bias=True), + norm_layer(nf * mult * 2), + nn.ReLU(True), + ] + + for i in range(n_blocks): # add ResNet blocks + model += [ + ResnetBlock( + kernel_dim, + padding_type=padding_type, + norm_layer=norm_layer, + use_dropout=use_dropout, + use_bias=True, + ) + ] + + self.model = nn.Sequential(*model) + + def forward(self, noise): + return self.model(noise) + + +class ImageDIP(nn.Module): + """ + DIP (Deep Image Prior) for sharp image + """ + + def __init__(self, opt): + super(ImageDIP, self).__init__() + + input_nc = opt["input_nc"] + output_nc = opt["output_nc"] + + self.model = skip( + input_nc, + output_nc, + num_channels_down=[128, 128, 128, 128, 128], + num_channels_up=[128, 128, 128, 128, 128], + num_channels_skip=[16, 16, 16, 16, 16], + upsample_mode="bilinear", + need_sigmoid=True, + need_bias=True, + pad=opt["padding_type"], + act_fun="LeakyReLU", + ) + + def forward(self, img): + return self.model(img) diff --git a/diffusion-posterior-sampling/bkse/models/dsd/bicubic.py b/diffusion-posterior-sampling/bkse/models/dsd/bicubic.py new file mode 100644 index 0000000000000000000000000000000000000000..90f60f7f71e4d156824089d61cc0fca325d7afa6 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/dsd/bicubic.py @@ -0,0 +1,76 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class BicubicDownSample(nn.Module): + def bicubic_kernel(self, x, a=-0.50): + """ + This equation is exactly copied from the website below: + https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic + """ + abs_x = torch.abs(x) + if abs_x <= 1.0: + return (a + 2.0) * torch.pow(abs_x, 3.0) - (a + 3.0) * torch.pow(abs_x, 2.0) + 1 + elif 1.0 < abs_x < 2.0: + return a * torch.pow(abs_x, 3) - 5.0 * a * torch.pow(abs_x, 2.0) + 8.0 * a * abs_x - 4.0 * a + else: + return 0.0 + + def __init__(self, factor=4, cuda=True, padding="reflect"): + super().__init__() + self.factor = factor + size = factor * 4 + k = torch.tensor( + [self.bicubic_kernel((i - torch.floor(torch.tensor(size / 2)) + 0.5) / factor) for i in range(size)], + dtype=torch.float32, + ) + k = k / torch.sum(k) + # k = torch.einsum('i,j->ij', (k, k)) + k1 = torch.reshape(k, shape=(1, 1, size, 1)) + self.k1 = torch.cat([k1, k1, k1], dim=0) + k2 = torch.reshape(k, shape=(1, 1, 1, size)) + self.k2 = torch.cat([k2, k2, k2], dim=0) + self.cuda = ".cuda" if cuda else "" + self.padding = padding + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x, nhwc=False, clip_round=False, byte_output=False): + # x = torch.from_numpy(x).type('torch.FloatTensor') + filter_height = self.factor * 4 + filter_width = self.factor * 4 + stride = self.factor + + pad_along_height = max(filter_height - stride, 0) + pad_along_width = max(filter_width - stride, 0) + filters1 = self.k1.type("torch{}.FloatTensor".format(self.cuda)) + filters2 = self.k2.type("torch{}.FloatTensor".format(self.cuda)) + + # compute actual padding values for each side + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + + # apply mirror padding + if nhwc: + x = torch.transpose(torch.transpose(x, 2, 3), 1, 2) # NHWC to NCHW + + # downscaling performed by 1-d convolution + x = F.pad(x, (0, 0, pad_top, pad_bottom), self.padding) + x = F.conv2d(input=x, weight=filters1, stride=(stride, 1), groups=3) + if clip_round: + x = torch.clamp(torch.round(x), 0.0, 255.0) + + x = F.pad(x, (pad_left, pad_right, 0, 0), self.padding) + x = F.conv2d(input=x, weight=filters2, stride=(1, stride), groups=3) + if clip_round: + x = torch.clamp(torch.round(x), 0.0, 255.0) + + if nhwc: + x = torch.transpose(torch.transpose(x, 1, 3), 1, 2) + if byte_output: + return x.type("torch.{}.ByteTensor".format(self.cuda)) + else: + return x diff --git a/diffusion-posterior-sampling/bkse/models/dsd/dsd.py b/diffusion-posterior-sampling/bkse/models/dsd/dsd.py new file mode 100644 index 0000000000000000000000000000000000000000..a21eb5fdd34bc30b1adc52145f1d9eff50d301b2 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/dsd/dsd.py @@ -0,0 +1,194 @@ +from functools import partial +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import utils.util as util +from models.dips import KernelDIP +from models.dsd.spherical_optimizer import SphericalOptimizer +from torch.optim.lr_scheduler import StepLR +from tqdm import tqdm + + +class DSD(torch.nn.Module): + def __init__(self, opt, cache_dir): + super(DSD, self).__init__() + + self.opt = opt + + self.verbose = opt["verbose"] + cache_dir = Path(cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) + + # Initialize synthesis network + if self.verbose: + print("Loading Synthesis Network") + self.load_synthesis_network() + if self.verbose: + print("Synthesis Network loaded!") + + self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2) + + self.initialize_mapping_network() + + def initialize_dip(self): + self.dip_zk = util.get_noise(64, "noise", (64, 64)).cuda().detach() + self.k_dip = KernelDIP(self.opt["KernelDIP"]).cuda() + + def initialize_latent_space(self): + pass + + def initialize_optimizers(self): + # Optimizer for k + self.optimizer_k = torch.optim.Adam(self.k_dip.parameters(), lr=self.opt["k_lr"]) + self.scheduler_k = StepLR( + self.optimizer_k, step_size=self.opt["num_epochs"] * self.opt["num_k_iters"] // 5, gamma=0.7 + ) + + # Optimizer for x + optimizer_dict = { + "sgd": torch.optim.SGD, + "adam": torch.optim.Adam, + "sgdm": partial(torch.optim.SGD, momentum=0.9), + "adamax": torch.optim.Adamax, + } + optimizer_func = optimizer_dict[self.opt["optimizer_name"]] + self.optimizer_x = SphericalOptimizer(optimizer_func, self.latent_x_var_list, lr=self.opt["x_lr"]) + + steps = self.opt["num_epochs"] * self.opt["num_x_iters"] + schedule_dict = { + "fixed": lambda x: 1, + "linear1cycle": lambda x: (9 * (1 - np.abs(x / steps - 1 / 2) * 2) + 1) / 10, + "linear1cycledrop": lambda x: (9 * (1 - np.abs(x / (0.9 * steps) - 1 / 2) * 2) + 1) / 10 + if x < 0.9 * steps + else 1 / 10 + (x - 0.9 * steps) / (0.1 * steps) * (1 / 1000 - 1 / 10), + } + schedule_func = schedule_dict[self.opt["lr_schedule"]] + self.scheduler_x = torch.optim.lr_scheduler.LambdaLR(self.optimizer_x.opt, schedule_func) + + def warmup_dip(self): + self.reg_noise_std = self.opt["reg_noise_std"] + warmup_k = torch.load("experiments/pretrained/kernel.pth") + + mse = nn.MSELoss().cuda() + + print("Warming up k DIP") + for step in tqdm(range(self.opt["num_warmup_iters"])): + self.optimizer_k.zero_grad() + dip_zk_rand = self.dip_zk + self.reg_noise_std * torch.randn_like(self.dip_zk).cuda() + k = self.k_dip(dip_zk_rand) + + loss = mse(k, warmup_k) + loss.backward() + self.optimizer_k.step() + + def optimize_k_step(self, epoch): + # Optimize k + tq_k = tqdm(range(self.opt["num_k_iters"])) + for j in tq_k: + for p in self.k_dip.parameters(): + p.requires_grad = True + for p in self.latent_x_var_list: + p.requires_grad = False + + self.optimizer_k.zero_grad() + + # Duplicate latent in case tile_latent = True + if self.opt["tile_latent"]: + latent_in = self.latent.expand(-1, 14, -1) + else: + latent_in = self.latent + + dip_zk_rand = self.dip_zk + self.reg_noise_std * torch.randn_like(self.dip_zk).cuda() + # Apply learned linear mapping to match latent distribution to that of the mapping network + latent_in = self.lrelu(latent_in * self.gaussian_fit["std"] + self.gaussian_fit["mean"]) + + # Normalize image to [0,1] instead of [-1,1] + self.gen_im = self.get_gen_im(latent_in) + self.gen_ker = self.k_dip(dip_zk_rand) + + # Calculate Losses + loss, loss_dict = self.loss_builder(latent_in, self.gen_im, self.gen_ker, epoch) + self.cur_loss = loss.cpu().detach().numpy() + + loss.backward() + self.optimizer_k.step() + self.scheduler_k.step() + + msg = " | ".join("{}: {:.4f}".format(k, v) for k, v in loss_dict.items()) + tq_k.set_postfix(loss=msg) + + def optimize_x_step(self, epoch): + tq_x = tqdm(range(self.opt["num_x_iters"])) + for j in tq_x: + for p in self.k_dip.parameters(): + p.requires_grad = False + for p in self.latent_x_var_list: + p.requires_grad = True + + self.optimizer_x.opt.zero_grad() + + # Duplicate latent in case tile_latent = True + if self.opt["tile_latent"]: + latent_in = self.latent.expand(-1, 14, -1) + else: + latent_in = self.latent + + dip_zk_rand = self.dip_zk + self.reg_noise_std * torch.randn_like(self.dip_zk).cuda() + # Apply learned linear mapping to match latent distribution to that of the mapping network + latent_in = self.lrelu(latent_in * self.gaussian_fit["std"] + self.gaussian_fit["mean"]) + + # Normalize image to [0,1] instead of [-1,1] + self.gen_im = self.get_gen_im(latent_in) + self.gen_ker = self.k_dip(dip_zk_rand) + + # Calculate Losses + loss, loss_dict = self.loss_builder(latent_in, self.gen_im, self.gen_ker, epoch) + self.cur_loss = loss.cpu().detach().numpy() + + loss.backward() + self.optimizer_x.step() + self.scheduler_x.step() + + msg = " | ".join("{}: {:.4f}".format(k, v) for k, v in loss_dict.items()) + tq_x.set_postfix(loss=msg) + + def log(self): + if self.cur_loss < self.min_loss: + self.min_loss = self.cur_loss + self.best_im = self.gen_im.clone() + self.best_ker = self.gen_ker.clone() + + def forward(self, ref_im): + if self.opt["seed"]: + seed = self.opt["seed"] + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + self.initialize_dip() + self.initialize_latent_space() + self.initialize_optimizers() + self.warmup_dip() + + self.min_loss = np.inf + self.gen_im = None + self.initialize_loss(ref_im) + + if self.verbose: + print("Optimizing") + + for epoch in range(self.opt["num_epochs"]): + print("Step: {}".format(epoch + 1)) + + self.optimize_x_step(epoch) + self.log() + self.optimize_k_step(epoch) + self.log() + + if self.opt["save_intermediate"]: + yield ( + self.best_im.cpu().detach().clamp(0, 1), + self.loss_builder.get_blur_img(self.best_im, self.best_ker), + ) diff --git a/diffusion-posterior-sampling/bkse/models/dsd/dsd_stylegan.py b/diffusion-posterior-sampling/bkse/models/dsd/dsd_stylegan.py new file mode 100644 index 0000000000000000000000000000000000000000..2c2a4e263896fc8fd53ed2d61b0884faf474d9a0 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/dsd/dsd_stylegan.py @@ -0,0 +1,81 @@ +from pathlib import Path + +import torch +from models.dsd.dsd import DSD +from models.dsd.stylegan import G_mapping, G_synthesis +from models.losses.dsd_loss import LossBuilderStyleGAN + + +class DSDStyleGAN(DSD): + def __init__(self, opt, cache_dir): + super(DSDStyleGAN, self).__init__(opt, cache_dir) + + def load_synthesis_network(self): + self.synthesis = G_synthesis().cuda() + self.synthesis.load_state_dict(torch.load("experiments/pretrained/stylegan_synthesis.pt")) + for v in self.synthesis.parameters(): + v.requires_grad = False + + def initialize_mapping_network(self): + if Path("experiments/pretrained/gaussian_fit_stylegan.pt").exists(): + self.gaussian_fit = torch.load("experiments/pretrained/gaussian_fit_stylegan.pt") + else: + if self.verbose: + print("\tRunning Mapping Network") + + mapping = G_mapping().cuda() + mapping.load_state_dict(torch.load("experiments/pretrained/stylegan_mapping.pt")) + with torch.no_grad(): + torch.manual_seed(0) + latent = torch.randn((1000000, 512), dtype=torch.float32, device="cuda") + latent_out = torch.nn.LeakyReLU(5)(mapping(latent)) + self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)} + torch.save(self.gaussian_fit, "experiments/pretrained/gaussian_fit_stylegan.pt") + if self.verbose: + print('\tSaved "gaussian_fit_stylegan.pt"') + + def initialize_latent_space(self): + batch_size = self.opt["batch_size"] + + # Generate latent tensor + if self.opt["tile_latent"]: + self.latent = torch.randn((batch_size, 1, 512), dtype=torch.float, requires_grad=True, device="cuda") + else: + self.latent = torch.randn((batch_size, 18, 512), dtype=torch.float, requires_grad=True, device="cuda") + + # Generate list of noise tensors + noise = [] # stores all of the noise tensors + noise_vars = [] # stores the noise tensors that we want to optimize on + + noise_type = self.opt["noise_type"] + bad_noise_layers = self.opt["bad_noise_layers"] + for i in range(18): + # dimension of the ith noise tensor + res = (batch_size, 1, 2 ** (i // 2 + 2), 2 ** (i // 2 + 2)) + + if noise_type == "zero" or i in [int(layer) for layer in bad_noise_layers.split(".")]: + new_noise = torch.zeros(res, dtype=torch.float, device="cuda") + new_noise.requires_grad = False + elif noise_type == "fixed": + new_noise = torch.randn(res, dtype=torch.float, device="cuda") + new_noise.requires_grad = False + elif noise_type == "trainable": + new_noise = torch.randn(res, dtype=torch.float, device="cuda") + if i < self.opt["num_trainable_noise_layers"]: + new_noise.requires_grad = True + noise_vars.append(new_noise) + else: + new_noise.requires_grad = False + else: + raise Exception("unknown noise type") + + noise.append(new_noise) + + self.latent_x_var_list = [self.latent] + noise_vars + self.noise = noise + + def initialize_loss(self, ref_im): + self.loss_builder = LossBuilderStyleGAN(ref_im, self.opt).cuda() + + def get_gen_im(self, latent_in): + return (self.synthesis(latent_in, self.noise) + 1) / 2 diff --git a/diffusion-posterior-sampling/bkse/models/dsd/dsd_stylegan2.py b/diffusion-posterior-sampling/bkse/models/dsd/dsd_stylegan2.py new file mode 100644 index 0000000000000000000000000000000000000000..32e99ac1eeab41908ebfe61ac99913587c6d7149 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/dsd/dsd_stylegan2.py @@ -0,0 +1,78 @@ +from pathlib import Path + +import torch +from models.dsd.dsd import DSD +from models.dsd.stylegan2 import Generator +from models.losses.dsd_loss import LossBuilderStyleGAN2 + + +class DSDStyleGAN2(DSD): + def __init__(self, opt, cache_dir): + super(DSDStyleGAN2, self).__init__(opt, cache_dir) + + def load_synthesis_network(self): + self.synthesis = Generator(size=256, style_dim=512, n_mlp=8).cuda() + self.synthesis.load_state_dict(torch.load("experiments/pretrained/stylegan2.pt")["g_ema"], strict=False) + for v in self.synthesis.parameters(): + v.requires_grad = False + + def initialize_mapping_network(self): + if Path("experiments/pretrained/gaussian_fit_stylegan2.pt").exists(): + self.gaussian_fit = torch.load("experiments/pretrained/gaussian_fit_stylegan2.pt") + else: + if self.verbose: + print("\tRunning Mapping Network") + with torch.no_grad(): + torch.manual_seed(0) + latent = torch.randn((1000000, 512), dtype=torch.float32, device="cuda") + latent_out = torch.nn.LeakyReLU(5)(self.synthesis.get_latent(latent)) + self.gaussian_fit = {"mean": latent_out.mean(0), "std": latent_out.std(0)} + torch.save(self.gaussian_fit, "experiments/pretrained/gaussian_fit_stylegan2.pt") + if self.verbose: + print('\tSaved "gaussian_fit_stylegan2.pt"') + + def initialize_latent_space(self): + batch_size = self.opt["batch_size"] + + # Generate latent tensor + if self.opt["tile_latent"]: + self.latent = torch.randn((batch_size, 1, 512), dtype=torch.float, requires_grad=True, device="cuda") + else: + self.latent = torch.randn((batch_size, 14, 512), dtype=torch.float, requires_grad=True, device="cuda") + + # Generate list of noise tensors + noise = [] # stores all of the noise tensors + noise_vars = [] # stores the noise tensors that we want to optimize on + + for i in range(14): + res = (i + 5) // 2 + res = [1, 1, 2 ** res, 2 ** res] + + noise_type = self.opt["noise_type"] + bad_noise_layers = self.opt["bad_noise_layers"] + if noise_type == "zero" or i in [int(layer) for layer in bad_noise_layers.split(".")]: + new_noise = torch.zeros(res, dtype=torch.float, device="cuda") + new_noise.requires_grad = False + elif noise_type == "fixed": + new_noise = torch.randn(res, dtype=torch.float, device="cuda") + new_noise.requires_grad = False + elif noise_type == "trainable": + new_noise = torch.randn(res, dtype=torch.float, device="cuda") + if i < self.opt["num_trainable_noise_layers"]: + new_noise.requires_grad = True + noise_vars.append(new_noise) + else: + new_noise.requires_grad = False + else: + raise Exception("unknown noise type") + + noise.append(new_noise) + + self.latent_x_var_list = [self.latent] + noise_vars + self.noise = noise + + def initialize_loss(self, ref_im): + self.loss_builder = LossBuilderStyleGAN2(ref_im, self.opt).cuda() + + def get_gen_im(self, latent_in): + return (self.synthesis([latent_in], input_is_latent=True, noise=self.noise)[0] + 1) / 2 diff --git a/diffusion-posterior-sampling/bkse/models/dsd/op/__init__.py b/diffusion-posterior-sampling/bkse/models/dsd/op/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/diffusion-posterior-sampling/bkse/models/dsd/op/fused_act.py b/diffusion-posterior-sampling/bkse/models/dsd/op/fused_act.py new file mode 100755 index 0000000000000000000000000000000000000000..d5642f912ee7b488981dba83fba4876b3a27a954 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/dsd/op/fused_act.py @@ -0,0 +1,107 @@ +import os + +import torch +from torch import nn +from torch.autograd import Function +from torch.nn import functional as F +from torch.utils.cpp_extension import load + + +module_path = os.path.dirname(__file__) +fused = load( + "fused", + sources=[ + os.path.join(module_path, "fused_bias_act.cpp"), + os.path.join(module_path, "fused_bias_act_kernel.cu"), + ], +) + + +class FusedLeakyReLUFunctionBackward(Function): + @staticmethod + def forward(ctx, grad_output, out, bias, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + if bias: + grad_bias = grad_input.sum(dim).detach() + + else: + grad_bias = None + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + (out,) = ctx.saved_tensors + gradgrad_out = fused.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale) + + return gradgrad_out, None, None, None, None + + +class FusedLeakyReLUFunction(Function): + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + + if bias is None: + bias = empty + + ctx.bias = bias is not None + + out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + (out,) = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( + grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale + ) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + + if bias: + self.bias = nn.Parameter(torch.zeros(channel)) + + else: + self.bias = None + + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): + if input.device.type == "cpu": + if bias is not None: + rest_dim = [1] * (input.ndim - bias.ndim - 1) + return F.leaky_relu(input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2) * scale + + else: + return F.leaky_relu(input, negative_slope=0.2) * scale + + else: + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/diffusion-posterior-sampling/bkse/models/dsd/op/fused_bias_act.cpp b/diffusion-posterior-sampling/bkse/models/dsd/op/fused_bias_act.cpp new file mode 100755 index 0000000000000000000000000000000000000000..a054318781a20596d8f516ef86745e5572aad0f7 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/dsd/op/fused_bias_act.cpp @@ -0,0 +1,21 @@ +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} \ No newline at end of file diff --git a/diffusion-posterior-sampling/bkse/models/dsd/op/fused_bias_act_kernel.cu b/diffusion-posterior-sampling/bkse/models/dsd/op/fused_bias_act_kernel.cu new file mode 100755 index 0000000000000000000000000000000000000000..8d2f03c73605faee6723d002ba5de88cb465a80e --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/dsd/op/fused_bias_act_kernel.cu @@ -0,0 +1,99 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} \ No newline at end of file diff --git a/diffusion-posterior-sampling/bkse/models/dsd/op/upfirdn2d.cpp b/diffusion-posterior-sampling/bkse/models/dsd/op/upfirdn2d.cpp new file mode 100755 index 0000000000000000000000000000000000000000..b07aa2056864db83ff0aacbb1068e072ba9da4ad --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/dsd/op/upfirdn2d.cpp @@ -0,0 +1,23 @@ +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} \ No newline at end of file diff --git a/diffusion-posterior-sampling/bkse/models/dsd/op/upfirdn2d.py b/diffusion-posterior-sampling/bkse/models/dsd/op/upfirdn2d.py new file mode 100755 index 0000000000000000000000000000000000000000..99e07366b0ee5fb87b4d9b88f312ad93e282af6d --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/dsd/op/upfirdn2d.py @@ -0,0 +1,184 @@ +import os + +import torch +from torch.autograd import Function +from torch.nn import functional as F +from torch.utils.cpp_extension import load + + +module_path = os.path.dirname(__file__) +upfirdn2d_op = load( + "upfirdn2d", + sources=[ + os.path.join(module_path, "upfirdn2d.cpp"), + os.path.join(module_path, "upfirdn2d_kernel.cu"), + ], +) + + +class UpFirDn2dBackward(Function): + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_op.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + (kernel,) = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_op.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_op.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == "cpu": + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + else: + out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/diffusion-posterior-sampling/bkse/models/dsd/op/upfirdn2d_kernel.cu b/diffusion-posterior-sampling/bkse/models/dsd/op/upfirdn2d_kernel.cu new file mode 100755 index 0000000000000000000000000000000000000000..ed3eea30305d084a2ba93aff07c0e79c23b179c3 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/dsd/op/upfirdn2d_kernel.cu @@ -0,0 +1,369 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w > 0) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } else { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 4; + block_size = dim3(4, 32, 1); + grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, + (p.out_w - 1) / (p.loop_x * block_size.y) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 2: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 3: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 4: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 5: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 6: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + default: + upfirdn2d_kernel_large<<>>( + out.data_ptr(), x.data_ptr(), + k.data_ptr(), p); + } + }); + + return out; +} \ No newline at end of file diff --git a/diffusion-posterior-sampling/bkse/models/dsd/spherical_optimizer.py b/diffusion-posterior-sampling/bkse/models/dsd/spherical_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b24c18540e74bfd74f345c10df9494a902f37e64 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/dsd/spherical_optimizer.py @@ -0,0 +1,29 @@ +import torch +from torch.optim import Optimizer + + +# Spherical Optimizer Class +# Uses the first two dimensions as batch information +# Optimizes over the surface of a sphere using the initial radius throughout +# +# Example Usage: +# opt = SphericalOptimizer(torch.optim.SGD, [x], lr=0.01) + + +class SphericalOptimizer(Optimizer): + def __init__(self, optimizer, params, **kwargs): + self.opt = optimizer(params, **kwargs) + self.params = params + with torch.no_grad(): + self.radii = { + param: (param.pow(2).sum(tuple(range(2, param.ndim)), keepdim=True) + 1e-9).sqrt() for param in params + } + + @torch.no_grad() + def step(self, closure=None): + loss = self.opt.step(closure) + for param in self.params: + param.data.div_((param.pow(2).sum(tuple(range(2, param.ndim)), keepdim=True) + 1e-9).sqrt()) + param.mul_(self.radii[param]) + + return loss diff --git a/diffusion-posterior-sampling/bkse/models/dsd/stylegan.py b/diffusion-posterior-sampling/bkse/models/dsd/stylegan.py new file mode 100644 index 0000000000000000000000000000000000000000..8fd993a714b1d4299c78910037b8653bfedcf53f --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/dsd/stylegan.py @@ -0,0 +1,474 @@ +# Modified from https://github.com/lernapparat/lernapparat/ + +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MyLinear(nn.Module): + """Linear layer with equalized learning rate and custom learning rate multiplier.""" + + def __init__(self, input_size, output_size, gain=2 ** (0.5), use_wscale=False, lrmul=1, bias=True): + super().__init__() + he_std = gain * input_size ** (-0.5) # He init + # Equalized learning rate and custom learning rate multiplier. + if use_wscale: + init_std = 1.0 / lrmul + self.w_mul = he_std * lrmul + else: + init_std = he_std / lrmul + self.w_mul = lrmul + self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std) + if bias: + self.bias = torch.nn.Parameter(torch.zeros(output_size)) + self.b_mul = lrmul + else: + self.bias = None + + def forward(self, x): + bias = self.bias + if bias is not None: + bias = bias * self.b_mul + return F.linear(x, self.weight * self.w_mul, bias) + + +class MyConv2d(nn.Module): + """Conv layer with equalized learning rate and custom learning rate multiplier.""" + + def __init__( + self, + input_channels, + output_channels, + kernel_size, + gain=2 ** (0.5), + use_wscale=False, + lrmul=1, + bias=True, + intermediate=None, + upscale=False, + ): + super().__init__() + if upscale: + self.upscale = Upscale2d() + else: + self.upscale = None + he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5) # He init + self.kernel_size = kernel_size + if use_wscale: + init_std = 1.0 / lrmul + self.w_mul = he_std * lrmul + else: + init_std = he_std / lrmul + self.w_mul = lrmul + self.weight = torch.nn.Parameter( + torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std + ) + if bias: + self.bias = torch.nn.Parameter(torch.zeros(output_channels)) + self.b_mul = lrmul + else: + self.bias = None + self.intermediate = intermediate + + def forward(self, x): + bias = self.bias + if bias is not None: + bias = bias * self.b_mul + + have_convolution = False + if self.upscale is not None and min(x.shape[2:]) * 2 >= 128: + # this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way + # this really needs to be cleaned up and go into the conv... + w = self.weight * self.w_mul + w = w.permute(1, 0, 2, 3) + # probably applying a conv on w would be more efficient. also this quadruples the weight (average)?! + w = F.pad(w, (1, 1, 1, 1)) + w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1] + x = F.conv_transpose2d(x, w, stride=2, padding=int((w.size(-1) - 1) // 2)) + have_convolution = True + elif self.upscale is not None: + x = self.upscale(x) + + if not have_convolution and self.intermediate is None: + return F.conv2d(x, self.weight * self.w_mul, bias, padding=int(self.kernel_size // 2)) + elif not have_convolution: + x = F.conv2d(x, self.weight * self.w_mul, None, padding=int(self.kernel_size // 2)) + + if self.intermediate is not None: + x = self.intermediate(x) + if bias is not None: + x = x + bias.view(1, -1, 1, 1) + return x + + +class NoiseLayer(nn.Module): + """adds noise. noise is per pixel (constant over channels) with per-channel weight""" + + def __init__(self, channels): + super().__init__() + self.weight = nn.Parameter(torch.zeros(channels)) + self.noise = None + + def forward(self, x, noise=None): + if noise is None and self.noise is None: + noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype) + elif noise is None: + # here is a little trick: if you get all the noiselayers and set each + # modules .noise attribute, you can have pre-defined noise. + # Very useful for analysis + noise = self.noise + x = x + self.weight.view(1, -1, 1, 1) * noise + return x + + +class StyleMod(nn.Module): + def __init__(self, latent_size, channels, use_wscale): + super(StyleMod, self).__init__() + self.lin = MyLinear(latent_size, channels * 2, gain=1.0, use_wscale=use_wscale) + + def forward(self, x, latent): + style = self.lin(latent) # style => [batch_size, n_channels*2] + shape = [-1, 2, x.size(1)] + (x.dim() - 2) * [1] + style = style.view(shape) # [batch_size, 2, n_channels, ...] + x = x * (style[:, 0] + 1.0) + style[:, 1] + return x + + +class PixelNormLayer(nn.Module): + def __init__(self, epsilon=1e-8): + super().__init__() + self.epsilon = epsilon + + def forward(self, x): + return x * torch.rsqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon) + + +class BlurLayer(nn.Module): + def __init__(self, kernel=[1, 2, 1], normalize=True, flip=False, stride=1): + super(BlurLayer, self).__init__() + kernel = [1, 2, 1] + kernel = torch.tensor(kernel, dtype=torch.float32) + kernel = kernel[:, None] * kernel[None, :] + kernel = kernel[None, None] + if normalize: + kernel = kernel / kernel.sum() + if flip: + kernel = kernel[:, :, ::-1, ::-1] + self.register_buffer("kernel", kernel) + self.stride = stride + + def forward(self, x): + # expand kernel channels + kernel = self.kernel.expand(x.size(1), -1, -1, -1) + x = F.conv2d(x, kernel, stride=self.stride, padding=int((self.kernel.size(2) - 1) / 2), groups=x.size(1)) + return x + + +def upscale2d(x, factor=2, gain=1): + assert x.dim() == 4 + if gain != 1: + x = x * gain + if factor != 1: + shape = x.shape + x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor) + x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3]) + return x + + +class Upscale2d(nn.Module): + def __init__(self, factor=2, gain=1): + super().__init__() + assert isinstance(factor, int) and factor >= 1 + self.gain = gain + self.factor = factor + + def forward(self, x): + return upscale2d(x, factor=self.factor, gain=self.gain) + + +class G_mapping(nn.Sequential): + def __init__(self, nonlinearity="lrelu", use_wscale=True): + act, gain = {"relu": (torch.relu, np.sqrt(2)), "lrelu": (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[ + nonlinearity + ] + layers = [ + ("pixel_norm", PixelNormLayer()), + ("dense0", MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ("dense0_act", act), + ("dense1", MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ("dense1_act", act), + ("dense2", MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ("dense2_act", act), + ("dense3", MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ("dense3_act", act), + ("dense4", MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ("dense4_act", act), + ("dense5", MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ("dense5_act", act), + ("dense6", MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ("dense6_act", act), + ("dense7", MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)), + ("dense7_act", act), + ] + super().__init__(OrderedDict(layers)) + + def forward(self, x): + x = super().forward(x) + return x + + +class Truncation(nn.Module): + def __init__(self, avg_latent, max_layer=8, threshold=0.7): + super().__init__() + self.max_layer = max_layer + self.threshold = threshold + self.register_buffer("avg_latent", avg_latent) + + def forward(self, x): + assert x.dim() == 3 + interp = torch.lerp(self.avg_latent, x, self.threshold) + do_trunc = (torch.arange(x.size(1)) < self.max_layer).view(1, -1, 1) + return torch.where(do_trunc, interp, x) + + +class LayerEpilogue(nn.Module): + """Things to do at the end of each layer.""" + + def __init__( + self, + channels, + dlatent_size, + use_wscale, + use_noise, + use_pixel_norm, + use_instance_norm, + use_styles, + activation_layer, + ): + super().__init__() + layers = [] + if use_noise: + self.noise = NoiseLayer(channels) + else: + self.noise = None + layers.append(("activation", activation_layer)) + if use_pixel_norm: + layers.append(("pixel_norm", PixelNormLayer())) + if use_instance_norm: + layers.append(("instance_norm", nn.InstanceNorm2d(channels))) + + self.top_epi = nn.Sequential(OrderedDict(layers)) + if use_styles: + self.style_mod = StyleMod(dlatent_size, channels, use_wscale=use_wscale) + else: + self.style_mod = None + + def forward(self, x, dlatents_in_slice=None, noise_in_slice=None): + if self.noise is not None: + x = self.noise(x, noise=noise_in_slice) + x = self.top_epi(x) + if self.style_mod is not None: + x = self.style_mod(x, dlatents_in_slice) + else: + assert dlatents_in_slice is None + return x + + +class InputBlock(nn.Module): + def __init__( + self, + nf, + dlatent_size, + const_input_layer, + gain, + use_wscale, + use_noise, + use_pixel_norm, + use_instance_norm, + use_styles, + activation_layer, + ): + super().__init__() + self.const_input_layer = const_input_layer + self.nf = nf + if self.const_input_layer: + # called 'const' in tf + self.const = nn.Parameter(torch.ones(1, nf, 4, 4)) + self.bias = nn.Parameter(torch.ones(nf)) + else: + # tweak gain to match the official implementation of Progressing GAN + self.dense = MyLinear(dlatent_size, nf * 16, gain=gain / 4, use_wscale=use_wscale) + self.epi1 = LayerEpilogue( + nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer + ) + self.conv = MyConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale) + self.epi2 = LayerEpilogue( + nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer + ) + + def forward(self, dlatents_in_range, noise_in_range): + batch_size = dlatents_in_range.size(0) + if self.const_input_layer: + x = self.const.expand(batch_size, -1, -1, -1) + x = x + self.bias.view(1, -1, 1, 1) + else: + x = self.dense(dlatents_in_range[:, 0]).view(batch_size, self.nf, 4, 4) + x = self.epi1(x, dlatents_in_range[:, 0], noise_in_range[0]) + x = self.conv(x) + x = self.epi2(x, dlatents_in_range[:, 1], noise_in_range[1]) + return x + + +class GSynthesisBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + blur_filter, + dlatent_size, + gain, + use_wscale, + use_noise, + use_pixel_norm, + use_instance_norm, + use_styles, + activation_layer, + ): + # 2**res x 2**res # res = 3..resolution_log2 + super().__init__() + if blur_filter: + blur = BlurLayer(blur_filter) + else: + blur = None + self.conv0_up = MyConv2d( + in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale, intermediate=blur, upscale=True + ) + self.epi1 = LayerEpilogue( + out_channels, + dlatent_size, + use_wscale, + use_noise, + use_pixel_norm, + use_instance_norm, + use_styles, + activation_layer, + ) + self.conv1 = MyConv2d(out_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale) + self.epi2 = LayerEpilogue( + out_channels, + dlatent_size, + use_wscale, + use_noise, + use_pixel_norm, + use_instance_norm, + use_styles, + activation_layer, + ) + + def forward(self, x, dlatents_in_range, noise_in_range): + x = self.conv0_up(x) + x = self.epi1(x, dlatents_in_range[:, 0], noise_in_range[0]) + x = self.conv1(x) + x = self.epi2(x, dlatents_in_range[:, 1], noise_in_range[1]) + return x + + +class G_synthesis(nn.Module): + def __init__( + self, + # Disentangled latent (W) dimensionality. + dlatent_size=512, + num_channels=3, # Number of output color channels. + resolution=1024, # Output resolution. + # Overall multiplier for the number of feature maps. + fmap_base=8192, + # log2 feature map reduction when doubling the resolution. + fmap_decay=1.0, + # Maximum number of feature maps in any layer. + fmap_max=512, + use_styles=True, # Enable style inputs? + const_input_layer=True, # First layer is a learned constant? + use_noise=True, # Enable noise inputs? + # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables. + randomize_noise=True, + nonlinearity="lrelu", # Activation function: 'relu', 'lrelu' + use_wscale=True, # Enable equalized learning rate? + use_pixel_norm=False, # Enable pixelwise feature vector normalization? + use_instance_norm=True, # Enable instance normalization? + # Data type to use for activations and outputs. + dtype=torch.float32, + # Low-pass filter to apply when resampling activations. None = no filtering. + blur_filter=[1, 2, 1], + ): + + super().__init__() + + def nf(stage): + return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max) + + self.dlatent_size = dlatent_size + resolution_log2 = int(np.log2(resolution)) + assert resolution == 2 ** resolution_log2 and resolution >= 4 + + act, gain = {"relu": (torch.relu, np.sqrt(2)), "lrelu": (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[ + nonlinearity + ] + blocks = [] + for res in range(2, resolution_log2 + 1): + channels = nf(res - 1) + name = "{s}x{s}".format(s=2 ** res) + if res == 2: + blocks.append( + ( + name, + InputBlock( + channels, + dlatent_size, + const_input_layer, + gain, + use_wscale, + use_noise, + use_pixel_norm, + use_instance_norm, + use_styles, + act, + ), + ) + ) + + else: + blocks.append( + ( + name, + GSynthesisBlock( + last_channels, + channels, + blur_filter, + dlatent_size, + gain, + use_wscale, + use_noise, + use_pixel_norm, + use_instance_norm, + use_styles, + act, + ), + ) + ) + last_channels = channels + self.torgb = MyConv2d(channels, num_channels, 1, gain=1, use_wscale=use_wscale) + self.blocks = nn.ModuleDict(OrderedDict(blocks)) + + def forward(self, dlatents_in, noise_in): + # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size]. + # lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0), trainable=False), dtype) + for i, m in enumerate(self.blocks.values()): + if i == 0: + x = m(dlatents_in[:, 2 * i : 2 * i + 2], noise_in[2 * i : 2 * i + 2]) + else: + x = m(x, dlatents_in[:, 2 * i : 2 * i + 2], noise_in[2 * i : 2 * i + 2]) + rgb = self.torgb(x) + return rgb diff --git a/diffusion-posterior-sampling/bkse/models/dsd/stylegan2.py b/diffusion-posterior-sampling/bkse/models/dsd/stylegan2.py new file mode 100644 index 0000000000000000000000000000000000000000..ed1e26280f0ea16bd67adcbe0f9bf23ab66cc2d5 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/dsd/stylegan2.py @@ -0,0 +1,621 @@ +import math +import random + +import torch +from models.dsd.op.fused_act import FusedLeakyReLU, fused_leaky_relu +from models.dsd.op.upfirdn2d import upfirdn2d +from torch import nn +from torch.nn import functional as F + + +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + + if k.ndim == 1: + k = k[None, :] * k[:, None] + + k /= k.sum() + + return k + + +class Upsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) * (factor ** 2) + self.register_buffer("kernel", kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) + + return out + + +class Downsample(nn.Module): + def __init__(self, kernel, factor=2): + super().__init__() + + self.factor = factor + kernel = make_kernel(kernel) + self.register_buffer("kernel", kernel) + + p = kernel.shape[0] - factor + + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.pad = (pad0, pad1) + + def forward(self, input): + out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) + + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer("kernel", kernel) + + self.pad = pad + + def forward(self, input): + out = upfirdn2d(input, self.kernel, pad=self.pad) + + return out + + +class EqualConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + + else: + self.bias = None + + def forward(self, input): + out = F.conv2d( + input, + self.weight * self.scale, + bias=self.bias, + stride=self.stride, + padding=self.padding, + ) + + return out + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," + f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" + ) + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + + else: + out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" + + +class ModulatedConv2d(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + demodulate=True, + upsample=False, + downsample=False, + blur_kernel=[1, 3, 3, 1], + ): + super().__init__() + + self.eps = 1e-8 + self.kernel_size = kernel_size + self.in_channel = in_channel + self.out_channel = out_channel + self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size ** 2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter(torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " + f"upsample={self.upsample}, downsample={self.downsample})" + ) + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) + weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) + + weight = weight.view(batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size) + + if self.upsample: + input = input.view(1, batch * in_channel, height, width) + weight = weight.view(batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size) + weight = weight.transpose(1, 2).reshape( + batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size + ) + out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + _, _, height, width = input.shape + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + else: + input = input.view(1, batch * in_channel, height, width) + out = F.conv2d(input, weight, padding=self.padding, groups=batch) + _, _, height, width = out.shape + out = out.view(batch, self.out_channel, height, width) + + return out + + +class NoiseInjection(nn.Module): + def __init__(self): + super().__init__() + + self.weight = nn.Parameter(torch.zeros(1)) + + def forward(self, image, noise=None): + if noise is None: + batch, _, height, width = image.shape + noise = image.new_empty(batch, 1, height, width).normal_() + + return image + self.weight * noise + + +class ConstantInput(nn.Module): + def __init__(self, channel, size=4): + super().__init__() + + self.input = nn.Parameter(torch.randn(1, channel, size, size)) + + def forward(self, input): + batch = input.shape[0] + out = self.input.repeat(batch, 1, 1, 1) + + return out + + +class StyledConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=False, + blur_kernel=[1, 3, 3, 1], + demodulate=True, + ): + super().__init__() + + self.conv = ModulatedConv2d( + in_channel, + out_channel, + kernel_size, + style_dim, + upsample=upsample, + blur_kernel=blur_kernel, + demodulate=demodulate, + ) + + self.noise = NoiseInjection() + # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) + # self.activate = ScaledLeakyReLU(0.2) + self.activate = FusedLeakyReLU(out_channel) + + def forward(self, input, style, noise=None): + out = self.conv(input, style) + out = self.noise(out, noise=noise) + # out = out + self.bias + out = self.activate(out) + + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + if upsample: + self.upsample = Upsample(blur_kernel) + + self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, input, style, skip=None): + out = self.conv(input, style) + out = out + self.bias + + if skip is not None: + skip = self.upsample(skip) + + out = out + skip + + return out + + +class Generator(nn.Module): + def __init__( + self, + size, + style_dim, + n_mlp, + channel_multiplier=2, + blur_kernel=[1, 3, 3, 1], + lr_mlp=0.01, + ): + super().__init__() + + self.size = size + + self.style_dim = style_dim + + layers = [PixelNorm()] + + for i in range(n_mlp): + layers.append(EqualLinear(style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu")) + + self.style = nn.Sequential(*layers) + + self.channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + self.input = ConstantInput(self.channels[4]) + self.conv1 = StyledConv(self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel) + self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2 ** res, 2 ** res] + self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2 ** i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + ) + ) + + self.convs.append(StyledConv(out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel)) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn(n_latent, self.style_dim, device=self.input.input.device) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append(truncation_latent + truncation * (style - truncation_latent)) + + styles = style_t + + if len(styles) < 2: + inject_index = self.n_latent + + if styles[0].ndim < 3: + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + + else: + latent = styles[0] + + else: + if inject_index is None: + inject_index = random.randint(1, self.n_latent - 1) + + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) + + latent = torch.cat([latent, latent2], 1) + + out = self.input(latent) + out = self.conv1(out, latent[:, 0], noise=noise[0]) + + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip( + self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs + ): + out = conv1(out, latent[:, i], noise=noise1) + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) + + i += 2 + + image = skip + + if return_latents: + return image, latent + + else: + return image, None + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append( + EqualConv2d( + in_channel, + out_channel, + kernel_size, + padding=self.padding, + stride=stride, + bias=bias and not activate, + ) + ) + + if activate: + layers.append(FusedLeakyReLU(out_channel, bias=bias)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class Discriminator(nn.Module): + def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256 * channel_multiplier, + 128: 128 * channel_multiplier, + 256: 64 * channel_multiplier, + 512: 32 * channel_multiplier, + 1024: 16 * channel_multiplier, + } + + convs = [ConvLayer(3, channels[size], 1)] + + log_size = int(math.log(size, 2)) + + in_channel = channels[size] + + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + + convs.append(ResBlock(in_channel, out_channel, blur_kernel)) + + in_channel = out_channel + + self.convs = nn.Sequential(*convs) + + self.stddev_group = 4 + self.stddev_feat = 1 + + self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) + self.final_linear = nn.Sequential( + EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), + EqualLinear(channels[4], 1), + ) + + def forward(self, input): + out = self.convs(input) + + batch, channel, height, width = out.shape + group = min(batch, self.stddev_group) + stddev = out.view(group, -1, self.stddev_feat, channel // self.stddev_feat, height, width) + stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) + stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) + stddev = stddev.repeat(group, 1, height, width) + out = torch.cat([out, stddev], 1) + + out = self.final_conv(out) + + out = out.view(batch, -1) + out = self.final_linear(out) + + return out diff --git a/diffusion-posterior-sampling/bkse/models/kernel_encoding/base_model.py b/diffusion-posterior-sampling/bkse/models/kernel_encoding/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4b572b9d8ac373fc98191b28a591ee0ad2b53b8b --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/kernel_encoding/base_model.py @@ -0,0 +1,131 @@ +import os +from collections import OrderedDict + +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel + + +class BaseModel: + def __init__(self, opt): + self.opt = opt + self.device = torch.device("cuda" if opt["gpu_ids"] is not None else "cpu") + self.is_train = opt["is_train"] + self.schedulers = [] + self.optimizers = [] + + def feed_data(self, data): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + pass + + def get_current_losses(self): + pass + + def print_network(self): + pass + + def save(self, label): + pass + + def load(self): + pass + + def _set_lr(self, lr_groups_l): + """Set learning rate for warmup + lr_groups_l: list for lr_groups. each for a optimizer""" + for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): + for param_group, lr in zip(optimizer.param_groups, lr_groups): + param_group["lr"] = lr + + def _get_init_lr(self): + """Get the initial lr, which is set by the scheduler""" + init_lr_groups_l = [] + for optimizer in self.optimizers: + init_lr_groups_l.append([v["initial_lr"] for v in optimizer.param_groups]) + return init_lr_groups_l + + def update_learning_rate(self, cur_iter, warmup_iter=-1): + for scheduler in self.schedulers: + scheduler.step() + # set up warm-up learning rate + if cur_iter < warmup_iter: + # get initial lr for each group + init_lr_g_l = self._get_init_lr() + # modify warming-up learning rates + warm_up_lr_l = [] + for init_lr_g in init_lr_g_l: + warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) + # set learning rate + self._set_lr(warm_up_lr_l) + + def get_current_learning_rate(self): + return [param_group["lr"] for param_group in self.optimizers[0].param_groups] + + def get_network_description(self, network): + """Get the string and total parameters of the network""" + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + return str(network), sum(map(lambda x: x.numel(), network.parameters())) + + def save_network(self, network, network_label, iter_label): + save_filename = "{}_{}.pth".format(iter_label, network_label) + save_path = os.path.join(self.opt["path"]["models"], save_filename) + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + state_dict = network.state_dict() + for key, param in state_dict.items(): + state_dict[key] = param.cpu() + torch.save(state_dict, save_path) + + def load_network(self, load_path, network, strict=True, prefix=""): + if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): + network = network.module + load_net = torch.load(load_path) + load_net_clean = OrderedDict() # remove unnecessary 'module.' + for k, v in load_net.items(): + if k.startswith("module."): + load_net_clean[k[7:]] = v + else: + load_net_clean[k] = v + load_net.update(load_net_clean) + + model_dict = network.state_dict() + for k, v in load_net.items(): + k = prefix + k + if (k in model_dict) and (v.shape == model_dict[k].shape): + model_dict[k] = v + else: + print("Load failed:", k) + + network.load_state_dict(model_dict, strict=True) + + def save_training_state(self, epoch, iter_step): + """ + Save training state during training, + which will be used for resuming + """ + + state = {"epoch": epoch, "iter": iter_step, "schedulers": [], "optimizers": []} + for s in self.schedulers: + state["schedulers"].append(s.state_dict()) + for o in self.optimizers: + state["optimizers"].append(o.state_dict()) + save_filename = "{}.state".format(iter_step) + save_path = os.path.join(self.opt["path"]["training_state"], save_filename) + torch.save(state, save_path) + + def resume_training(self, resume_state): + """Resume the optimizers and schedulers for training""" + resume_optimizers = resume_state["optimizers"] + resume_schedulers = resume_state["schedulers"] + assert len(resume_optimizers) == len(self.optimizers), "Wrong lengths of optimizers" + assert len(resume_schedulers) == len(self.schedulers), "Wrong lengths of schedulers" + for i, o in enumerate(resume_optimizers): + self.optimizers[i].load_state_dict(o) + for i, s in enumerate(resume_schedulers): + self.schedulers[i].load_state_dict(s) diff --git a/diffusion-posterior-sampling/bkse/models/kernel_encoding/image_base_model.py b/diffusion-posterior-sampling/bkse/models/kernel_encoding/image_base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c5fa9543b66211b3fba6da7c3592e8bee01bacc4 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/kernel_encoding/image_base_model.py @@ -0,0 +1,194 @@ +import logging +from collections import OrderedDict + +import models.lr_scheduler as lr_scheduler +import torch +import torch.nn as nn +from models.kernel_encoding.base_model import BaseModel +from models.kernel_encoding.kernel_wizard import KernelWizard +from models.losses.charbonnier_loss import CharbonnierLoss +from torch.nn.parallel import DataParallel, DistributedDataParallel + + +logger = logging.getLogger("base") + + +class ImageBaseModel(BaseModel): + def __init__(self, opt): + super(ImageBaseModel, self).__init__(opt) + + if opt["dist"]: + self.rank = torch.distributed.get_rank() + else: + self.rank = -1 # non dist training + train_opt = opt["train"] + + # define network and load pretrained models + self.netG = KernelWizard(opt["KernelWizard"]).to(self.device) + self.use_vae = opt["KernelWizard"]["use_vae"] + if opt["dist"]: + self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) + else: + self.netG = DataParallel(self.netG) + # print network + self.print_network() + self.load() + + if self.is_train: + self.netG.train() + + # loss + loss_type = train_opt["pixel_criterion"] + if loss_type == "l1": + self.cri_pix = nn.L1Loss(reduction="sum").to(self.device) + elif loss_type == "l2": + self.cri_pix = nn.MSELoss(reduction="sum").to(self.device) + elif loss_type == "cb": + self.cri_pix = CharbonnierLoss().to(self.device) + else: + raise NotImplementedError( + "Loss type [{:s}] is not\ + recognized.".format( + loss_type + ) + ) + self.l_pix_w = train_opt["pixel_weight"] + self.l_kl_w = train_opt["kl_weight"] + + # optimizers + wd_G = train_opt["weight_decay_G"] if train_opt["weight_decay_G"] else 0 + params = [] + for k, v in self.netG.named_parameters(): + if v.requires_grad: + params.append(v) + else: + if self.rank <= 0: + logger.warning( + "Params [{:s}] will not\ + optimize.".format( + k + ) + ) + optim_params = [ + {"params": params, "lr": train_opt["lr_G"]}, + ] + + self.optimizer_G = torch.optim.Adam( + optim_params, lr=train_opt["lr_G"], weight_decay=wd_G, betas=(train_opt["beta1"], train_opt["beta2"]) + ) + self.optimizers.append(self.optimizer_G) + + # schedulers + if train_opt["lr_scheme"] == "MultiStepLR": + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.MultiStepLR_Restart( + optimizer, + train_opt["lr_steps"], + restarts=train_opt["restarts"], + weights=train_opt["restart_weights"], + gamma=train_opt["lr_gamma"], + clear_state=train_opt["clear_state"], + ) + ) + elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart": + for optimizer in self.optimizers: + self.schedulers.append( + lr_scheduler.CosineAnnealingLR_Restart( + optimizer, + train_opt["T_period"], + eta_min=train_opt["eta_min"], + restarts=train_opt["restarts"], + weights=train_opt["restart_weights"], + ) + ) + else: + raise NotImplementedError() + + self.log_dict = OrderedDict() + + def feed_data(self, data, need_GT=True): + self.LQ = data["LQ"].to(self.device) + self.HQ = data["HQ"].to(self.device) + + def set_params_lr_zero(self, groups): + # fix normal module + for group in groups: + self.optimizers[0].param_groups[group]["lr"] = 0 + + def optimize_parameters(self, step): + batchsz, _, _, _ = self.LQ.shape + + self.optimizer_G.zero_grad() + kernel_mean, kernel_sigma = self.netG(self.HQ, self.LQ) + + kernel = kernel_mean + kernel_sigma * torch.randn_like(kernel_mean) + self.fake_LQ = self.netG.module.adaptKernel(self.HQ, kernel) + + l_pix = self.l_pix_w * self.cri_pix(self.fake_LQ, self.LQ) + l_total = l_pix + + if self.use_vae: + KL_divergence = ( + self.l_kl_w + * torch.sum( + torch.pow(kernel_mean, 2) + + torch.pow(kernel_sigma, 2) + - torch.log(1e-8 + torch.pow(kernel_sigma, 2)) + - 1 + ).sum() + ) + l_total += KL_divergence + self.log_dict["l_KL"] = KL_divergence.item() / batchsz + + l_total.backward() + self.optimizer_G.step() + + # set log + self.log_dict["l_pix"] = l_pix.item() / batchsz + self.log_dict["l_total"] = l_total.item() / batchsz + + def test(self): + self.netG.eval() + with torch.no_grad(): + self.fake_H = self.netG(self.var_L) + self.netG.train() + + def get_current_log(self): + return self.log_dict + + def get_current_visuals(self, need_GT=True): + out_dict = OrderedDict() + out_dict["LQ"] = self.LQ.detach()[0].float().cpu() + out_dict["rlt"] = self.fake_LQ.detach()[0].float().cpu() + return out_dict + + def print_network(self): + s, n = self.get_network_description(self.netG) + if isinstance(self.netG, nn.DataParallel): + net_struc_str = "{} - {}".format(self.netG.__class__.__name__, self.netG.module.__class__.__name__) + else: + net_struc_str = "{}".format(self.netG.__class__.__name__) + if self.rank <= 0: + logger.info( + "Network G structure: {}, \ + with parameters: {:,d}".format( + net_struc_str, n + ) + ) + logger.info(s) + + def load(self): + if self.opt["path"]["pretrain_model_G"]: + load_path_G = self.opt["path"]["pretrain_model_G"] + if load_path_G is not None: + logger.info( + "Loading model for G [{:s}]\ + ...".format( + load_path_G + ) + ) + self.load_network(load_path_G, self.netG, self.opt["path"]["strict_load"]) + + def save(self, iter_label): + self.save_network(self.netG, "G", iter_label) diff --git a/diffusion-posterior-sampling/bkse/models/kernel_encoding/kernel_wizard.py b/diffusion-posterior-sampling/bkse/models/kernel_encoding/kernel_wizard.py new file mode 100644 index 0000000000000000000000000000000000000000..dbdd0a9e73843920229ce9d6c4b17ae5c1a0b096 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/kernel_encoding/kernel_wizard.py @@ -0,0 +1,168 @@ +import functools + +import models.arch_util as arch_util +import torch +import torch.nn as nn +from models.backbones.resnet import ResidualBlock_noBN, ResnetBlock +from models.backbones.unet_parts import UnetSkipConnectionBlock + + +# The function F in the paper +class KernelExtractor(nn.Module): + def __init__(self, opt): + super(KernelExtractor, self).__init__() + + nf = opt["nf"] + self.kernel_dim = opt["kernel_dim"] + self.use_sharp = opt["KernelExtractor"]["use_sharp"] + self.use_vae = opt["use_vae"] + + # Blur estimator + norm_layer = arch_util.get_norm_layer(opt["KernelExtractor"]["norm"]) + n_blocks = opt["KernelExtractor"]["n_blocks"] + padding_type = opt["KernelExtractor"]["padding_type"] + use_dropout = opt["KernelExtractor"]["use_dropout"] + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + input_nc = nf * 2 if self.use_sharp else nf + output_nc = self.kernel_dim * 2 if self.use_vae else self.kernel_dim + + model = [ + nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, nf, kernel_size=7, padding=0, bias=use_bias), + norm_layer(nf), + nn.ReLU(True), + ] + + n_downsampling = 5 + for i in range(n_downsampling): # add downsampling layers + mult = 2 ** i + inc = min(nf * mult, output_nc) + ouc = min(nf * mult * 2, output_nc) + model += [ + nn.Conv2d(inc, ouc, kernel_size=3, stride=2, padding=1, bias=use_bias), + norm_layer(nf * mult * 2), + nn.ReLU(True), + ] + + for i in range(n_blocks): # add ResNet blocks + model += [ + ResnetBlock( + output_nc, + padding_type=padding_type, + norm_layer=norm_layer, + use_dropout=use_dropout, + use_bias=use_bias, + ) + ] + + self.model = nn.Sequential(*model) + + def forward(self, sharp, blur): + output = self.model(torch.cat((sharp, blur), dim=1)) + if self.use_vae: + return output[:, : self.kernel_dim, :, :], output[:, self.kernel_dim :, :, :] + + return output, torch.zeros_like(output).cuda() + + +# The function G in the paper +class KernelAdapter(nn.Module): + def __init__(self, opt): + super(KernelAdapter, self).__init__() + input_nc = opt["nf"] + output_nc = opt["nf"] + ngf = opt["nf"] + norm_layer = arch_util.get_norm_layer(opt["Adapter"]["norm"]) + + # construct unet structure + unet_block = UnetSkipConnectionBlock( + ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True + ) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock( + ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + unet_block = UnetSkipConnectionBlock( + ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer + ) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + self.model = UnetSkipConnectionBlock( + output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer + ) + + def forward(self, x, k): + """Standard forward""" + return self.model(x, k) + + +class KernelWizard(nn.Module): + def __init__(self, opt): + super(KernelWizard, self).__init__() + lrelu = nn.LeakyReLU(negative_slope=0.1) + front_RBs = opt["front_RBs"] + back_RBs = opt["back_RBs"] + num_image_channels = opt["input_nc"] + nf = opt["nf"] + + # Features extraction + resBlock_noBN_f = functools.partial(ResidualBlock_noBN, nf=nf) + feature_extractor = [] + + feature_extractor.append(nn.Conv2d(num_image_channels, nf, 3, 1, 1, bias=True)) + feature_extractor.append(lrelu) + feature_extractor.append(nn.Conv2d(nf, nf, 3, 2, 1, bias=True)) + feature_extractor.append(lrelu) + feature_extractor.append(nn.Conv2d(nf, nf, 3, 2, 1, bias=True)) + feature_extractor.append(lrelu) + + for i in range(front_RBs): + feature_extractor.append(resBlock_noBN_f()) + + self.feature_extractor = nn.Sequential(*feature_extractor) + + # Kernel extractor + self.kernel_extractor = KernelExtractor(opt) + + # kernel adapter + self.adapter = KernelAdapter(opt) + + # Reconstruction + recon_trunk = [] + for i in range(back_RBs): + recon_trunk.append(resBlock_noBN_f()) + + # upsampling + recon_trunk.append(nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)) + recon_trunk.append(nn.PixelShuffle(2)) + recon_trunk.append(lrelu) + recon_trunk.append(nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)) + recon_trunk.append(nn.PixelShuffle(2)) + recon_trunk.append(lrelu) + recon_trunk.append(nn.Conv2d(64, 64, 3, 1, 1, bias=True)) + recon_trunk.append(lrelu) + recon_trunk.append(nn.Conv2d(64, num_image_channels, 3, 1, 1, bias=True)) + + self.recon_trunk = nn.Sequential(*recon_trunk) + + def adaptKernel(self, x_sharp, kernel): + B, C, H, W = x_sharp.shape + base = x_sharp + + x_sharp = self.feature_extractor(x_sharp) + + out = self.adapter(x_sharp, kernel) + out = self.recon_trunk(out) + out += base + + return out + + def forward(self, x_sharp, x_blur): + x_sharp = self.feature_extractor(x_sharp) + x_blur = self.feature_extractor(x_blur) + + output = self.kernel_extractor(x_sharp, x_blur) + return output diff --git a/diffusion-posterior-sampling/bkse/models/losses/charbonnier_loss.py b/diffusion-posterior-sampling/bkse/models/losses/charbonnier_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d051cb8772bb22fc2dbe3e246c9bdda48b522e18 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/losses/charbonnier_loss.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn + + +class CharbonnierLoss(nn.Module): + """Charbonnier Loss (L1)""" + + def __init__(self, eps=1e-6): + super(CharbonnierLoss, self).__init__() + self.eps = eps + + def forward(self, x, y): + diff = x - y + loss = torch.sum(torch.sqrt(diff * diff + self.eps)) + return loss diff --git a/diffusion-posterior-sampling/bkse/models/losses/dsd_loss.py b/diffusion-posterior-sampling/bkse/models/losses/dsd_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9cf4660dc5f3d088bcf926866914ca0790348c5e --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/losses/dsd_loss.py @@ -0,0 +1,129 @@ +import torch +from models.dsd.bicubic import BicubicDownSample +from models.kernel_encoding.kernel_wizard import KernelWizard +from models.losses.ssim_loss import SSIM + + +class LossBuilder(torch.nn.Module): + def __init__(self, ref_im, opt): + super(LossBuilder, self).__init__() + assert ref_im.shape[2] == ref_im.shape[3] + self.ref_im = ref_im + loss_str = opt["loss_str"] + self.parsed_loss = [loss_term.split("*") for loss_term in loss_str.split("+")] + self.eps = opt["eps"] + + self.ssim = SSIM().cuda() + + self.D = KernelWizard(opt["KernelWizard"]).cuda() + self.D.load_state_dict(torch.load(opt["KernelWizard"]["pretrained"])) + for v in self.D.parameters(): + v.requires_grad = False + + # Takes a list of tensors, flattens them, and concatenates them into a vector + # Used to calculate euclidian distance between lists of tensors + def flatcat(self, l): + l = l if (isinstance(l, list)) else [l] + return torch.cat([x.flatten() for x in l], dim=0) + + def _loss_l2(self, gen_im_lr, ref_im, **kwargs): + return (gen_im_lr - ref_im).pow(2).mean((1, 2, 3)).clamp(min=self.eps).sum() + + def _loss_l1(self, gen_im_lr, ref_im, **kwargs): + return 10 * ((gen_im_lr - ref_im).abs().mean((1, 2, 3)).clamp(min=self.eps).sum()) + + # Uses geodesic distance on sphere to sum pairwise distances of the 18 vectors + def _loss_geocross(self, latent, **kwargs): + pass + + +class LossBuilderStyleGAN(LossBuilder): + def __init__(self, ref_im, opt): + super(LossBuilderStyleGAN, self).__init__(ref_im, opt) + im_size = ref_im.shape[2] + factor = opt["output_size"] // im_size + assert im_size * factor == opt["output_size"] + self.bicub = BicubicDownSample(factor=factor) + + # Uses geodesic distance on sphere to sum pairwise distances of the 18 vectors + def _loss_geocross(self, latent, **kwargs): + if latent.shape[1] == 1: + return 0 + else: + X = latent.view(-1, 1, 18, 512) + Y = latent.view(-1, 18, 1, 512) + A = ((X - Y).pow(2).sum(-1) + 1e-9).sqrt() + B = ((X + Y).pow(2).sum(-1) + 1e-9).sqrt() + D = 2 * torch.atan2(A, B) + D = ((D.pow(2) * 512).mean((1, 2)) / 8.0).sum() + return D + + def forward(self, latent, gen_im, kernel, step): + var_dict = { + "latent": latent, + "gen_im_lr": self.D.adaptKernel(self.bicub(gen_im), kernel), + "ref_im": self.ref_im, + } + loss = 0 + loss_fun_dict = { + "L2": self._loss_l2, + "L1": self._loss_l1, + "GEOCROSS": self._loss_geocross, + } + losses = {} + + for weight, loss_type in self.parsed_loss: + tmp_loss = loss_fun_dict[loss_type](**var_dict) + losses[loss_type] = tmp_loss + loss += float(weight) * tmp_loss + loss += 5e-5 * torch.norm(kernel) + losses["Norm"] = torch.norm(kernel) + + return loss, losses + + def get_blur_img(self, sharp_img, kernel): + return self.D.adaptKernel(self.bicub(sharp_img), kernel).cpu().detach().clamp(0, 1) + + +class LossBuilderStyleGAN2(LossBuilder): + def __init__(self, ref_im, opt): + super(LossBuilderStyleGAN2, self).__init__(ref_im, opt) + + # Uses geodesic distance on sphere to sum pairwise distances of the 18 vectors + def _loss_geocross(self, latent, **kwargs): + if latent.shape[1] == 1: + return 0 + else: + X = latent.view(-1, 1, 14, 512) + Y = latent.view(-1, 14, 1, 512) + A = ((X - Y).pow(2).sum(-1) + 1e-9).sqrt() + B = ((X + Y).pow(2).sum(-1) + 1e-9).sqrt() + D = 2 * torch.atan2(A, B) + D = ((D.pow(2) * 512).mean((1, 2)) / 6.0).sum() + return D + + def forward(self, latent, gen_im, kernel, step): + var_dict = { + "latent": latent, + "gen_im_lr": self.D.adaptKernel(gen_im, kernel), + "ref_im": self.ref_im, + } + loss = 0 + loss_fun_dict = { + "L2": self._loss_l2, + "L1": self._loss_l1, + "GEOCROSS": self._loss_geocross, + } + losses = {} + + for weight, loss_type in self.parsed_loss: + tmp_loss = loss_fun_dict[loss_type](**var_dict) + losses[loss_type] = tmp_loss + loss += float(weight) * tmp_loss + loss += 1e-4 * torch.norm(kernel) + losses["Norm"] = torch.norm(kernel) + + return loss, losses + + def get_blur_img(self, sharp_img, kernel): + return self.D.adaptKernel(sharp_img, kernel).cpu().detach().clamp(0, 1) diff --git a/diffusion-posterior-sampling/bkse/models/losses/gan_loss.py b/diffusion-posterior-sampling/bkse/models/losses/gan_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..28416a187cf06de1002b397070278cce52ddcdb7 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/losses/gan_loss.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn + + +# Define GAN loss: [vanilla | lsgan | wgan-gp] +class GANLoss(nn.Module): + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type.lower() + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == "gan" or self.gan_type == "ragan": + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == "lsgan": + self.loss = nn.MSELoss() + elif self.gan_type == "wgan-gp": + + def wgan_loss(input, target): + # target is boolean + return -1 * input.mean() if target else input.mean() + + self.loss = wgan_loss + else: + raise NotImplementedError("GAN type [{:s}] is not found".format(self.gan_type)) + + def get_target_label(self, input, target_is_real): + if self.gan_type == "wgan-gp": + return target_is_real + if target_is_real: + return torch.empty_like(input).fill_(self.real_label_val) + else: + return torch.empty_like(input).fill_(self.fake_label_val) + + def forward(self, input, target_is_real): + target_label = self.get_target_label(input, target_is_real) + loss = self.loss(input, target_label) + return loss diff --git a/diffusion-posterior-sampling/bkse/models/losses/hyper_laplacian_penalty.py b/diffusion-posterior-sampling/bkse/models/losses/hyper_laplacian_penalty.py new file mode 100644 index 0000000000000000000000000000000000000000..87c42ddffb4a80c31517243c8b66763def65d3eb --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/losses/hyper_laplacian_penalty.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class HyperLaplacianPenalty(nn.Module): + def __init__(self, num_channels, alpha, eps=1e-6): + super(HyperLaplacianPenalty, self).__init__() + + self.alpha = alpha + self.eps = eps + + self.Kx = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).cuda() + self.Kx = self.Kx.expand(1, num_channels, 3, 3) + self.Kx.requires_grad = False + self.Ky = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).cuda() + self.Ky = self.Ky.expand(1, num_channels, 3, 3) + self.Ky.requires_grad = False + + def forward(self, x): + gradX = F.conv2d(x, self.Kx, stride=1, padding=1) + gradY = F.conv2d(x, self.Ky, stride=1, padding=1) + grad = torch.sqrt(gradX ** 2 + gradY ** 2 + self.eps) + + loss = (grad ** self.alpha).mean() + + return loss diff --git a/diffusion-posterior-sampling/bkse/models/losses/perceptual_loss.py b/diffusion-posterior-sampling/bkse/models/losses/perceptual_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..75e2f103c947dbaa96765825014d8593c6e16c94 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/losses/perceptual_loss.py @@ -0,0 +1,184 @@ +import torch +import torch.nn as nn +import torchvision.models as models + + +class StyleLoss(nn.Module): + r""" + Perceptual loss, VGG-based + https://arxiv.org/abs/1603.08155 + https://github.com/dxyang/StyleTransfer/blob/master/utils.py + """ + + def __init__(self): + super(StyleLoss, self).__init__() + self.add_module("vgg", VGG19()) + self.criterion = torch.nn.L1Loss() + + def compute_gram(self, x): + b, ch, h, w = x.size() + f = x.view(b, ch, w * h) + f_T = f.transpose(1, 2) + G = f.bmm(f_T) / (h * w * ch) + + return G + + def __call__(self, x, y): + # Compute features + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + + # Compute loss + style_loss = 0.0 + style_loss += self.criterion(self.compute_gram(x_vgg["relu2_2"]), self.compute_gram(y_vgg["relu2_2"])) + style_loss += self.criterion(self.compute_gram(x_vgg["relu3_4"]), self.compute_gram(y_vgg["relu3_4"])) + style_loss += self.criterion(self.compute_gram(x_vgg["relu4_4"]), self.compute_gram(y_vgg["relu4_4"])) + style_loss += self.criterion(self.compute_gram(x_vgg["relu5_2"]), self.compute_gram(y_vgg["relu5_2"])) + + return style_loss + + +class PerceptualLoss(nn.Module): + r""" + Perceptual loss, VGG-based + https://arxiv.org/abs/1603.08155 + https://github.com/dxyang/StyleTransfer/blob/master/utils.py + """ + + def __init__(self, weights=[0.2, 0.4, 0.8, 1.0, 3.0]): + super(PerceptualLoss, self).__init__() + self.add_module("vgg", VGG19()) + self.criterion = torch.nn.L1Loss() + self.weights = weights + + def __call__(self, x, y): + # Compute features + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + + content_loss = 0.0 + content_loss += self.weights[0] * self.criterion(x_vgg["relu1_1"], y_vgg["relu1_1"]) + content_loss += self.weights[1] * self.criterion(x_vgg["relu2_1"], y_vgg["relu2_1"]) + content_loss += self.weights[2] * self.criterion(x_vgg["relu3_1"], y_vgg["relu3_1"]) + content_loss += self.weights[3] * self.criterion(x_vgg["relu4_1"], y_vgg["relu4_1"]) + content_loss += self.weights[4] * self.criterion(x_vgg["relu5_1"], y_vgg["relu5_1"]) + + return content_loss + + +class VGG19(torch.nn.Module): + def __init__(self): + super(VGG19, self).__init__() + features = models.vgg19(pretrained=True).features + self.relu1_1 = torch.nn.Sequential() + self.relu1_2 = torch.nn.Sequential() + + self.relu2_1 = torch.nn.Sequential() + self.relu2_2 = torch.nn.Sequential() + + self.relu3_1 = torch.nn.Sequential() + self.relu3_2 = torch.nn.Sequential() + self.relu3_3 = torch.nn.Sequential() + self.relu3_4 = torch.nn.Sequential() + + self.relu4_1 = torch.nn.Sequential() + self.relu4_2 = torch.nn.Sequential() + self.relu4_3 = torch.nn.Sequential() + self.relu4_4 = torch.nn.Sequential() + + self.relu5_1 = torch.nn.Sequential() + self.relu5_2 = torch.nn.Sequential() + self.relu5_3 = torch.nn.Sequential() + self.relu5_4 = torch.nn.Sequential() + + for x in range(2): + self.relu1_1.add_module(str(x), features[x]) + + for x in range(2, 4): + self.relu1_2.add_module(str(x), features[x]) + + for x in range(4, 7): + self.relu2_1.add_module(str(x), features[x]) + + for x in range(7, 9): + self.relu2_2.add_module(str(x), features[x]) + + for x in range(9, 12): + self.relu3_1.add_module(str(x), features[x]) + + for x in range(12, 14): + self.relu3_2.add_module(str(x), features[x]) + + for x in range(14, 16): + self.relu3_2.add_module(str(x), features[x]) + + for x in range(16, 18): + self.relu3_4.add_module(str(x), features[x]) + + for x in range(18, 21): + self.relu4_1.add_module(str(x), features[x]) + + for x in range(21, 23): + self.relu4_2.add_module(str(x), features[x]) + + for x in range(23, 25): + self.relu4_3.add_module(str(x), features[x]) + + for x in range(25, 27): + self.relu4_4.add_module(str(x), features[x]) + + for x in range(27, 30): + self.relu5_1.add_module(str(x), features[x]) + + for x in range(30, 32): + self.relu5_2.add_module(str(x), features[x]) + + for x in range(32, 34): + self.relu5_3.add_module(str(x), features[x]) + + for x in range(34, 36): + self.relu5_4.add_module(str(x), features[x]) + + # don't need the gradients, just want the features + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + relu1_1 = self.relu1_1(x) + relu1_2 = self.relu1_2(relu1_1) + + relu2_1 = self.relu2_1(relu1_2) + relu2_2 = self.relu2_2(relu2_1) + + relu3_1 = self.relu3_1(relu2_2) + relu3_2 = self.relu3_2(relu3_1) + relu3_3 = self.relu3_3(relu3_2) + relu3_4 = self.relu3_4(relu3_3) + + relu4_1 = self.relu4_1(relu3_4) + relu4_2 = self.relu4_2(relu4_1) + relu4_3 = self.relu4_3(relu4_2) + relu4_4 = self.relu4_4(relu4_3) + + relu5_1 = self.relu5_1(relu4_4) + relu5_2 = self.relu5_2(relu5_1) + relu5_3 = self.relu5_3(relu5_2) + relu5_4 = self.relu5_4(relu5_3) + + out = { + "relu1_1": relu1_1, + "relu1_2": relu1_2, + "relu2_1": relu2_1, + "relu2_2": relu2_2, + "relu3_1": relu3_1, + "relu3_2": relu3_2, + "relu3_3": relu3_3, + "relu3_4": relu3_4, + "relu4_1": relu4_1, + "relu4_2": relu4_2, + "relu4_3": relu4_3, + "relu4_4": relu4_4, + "relu5_1": relu5_1, + "relu5_2": relu5_2, + "relu5_3": relu5_3, + "relu5_4": relu5_4, + } + return out diff --git a/diffusion-posterior-sampling/bkse/models/losses/ssim_loss.py b/diffusion-posterior-sampling/bkse/models/losses/ssim_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..dac51530ed1e6efb98021666a41c394233dbbf53 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/losses/ssim_loss.py @@ -0,0 +1,66 @@ +from math import exp + +import torch +import torch.nn.functional as F +from torch.autograd import Variable + + +class SSIM(torch.nn.Module): + @staticmethod + def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + + @staticmethod + def create_window(window_size, channel): + _1D_window = SSIM.gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + + @staticmethod + def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + def __init__(self, window_size=11, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = self.create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = self.create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return self._ssim(img1, img2, window, self.window_size, channel, self.size_average) diff --git a/diffusion-posterior-sampling/bkse/models/lr_scheduler.py b/diffusion-posterior-sampling/bkse/models/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..f40dd177b645981fb65eafb235c2d91f0d169f58 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/models/lr_scheduler.py @@ -0,0 +1,162 @@ +import math +from collections import Counter, defaultdict + +import torch +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepLR_Restart(_LRScheduler): + def __init__( + self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, clear_state=False, last_epoch=-1 + ): + self.milestones = Counter(milestones) + self.gamma = gamma + self.clear_state = clear_state + self.restarts = restarts if restarts else [0] + self.restarts = [v + 1 for v in self.restarts] + self.restart_weights = weights if weights else [1] + assert len(self.restarts) == len(self.restart_weights), "restarts and their weights do not match." + super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + if self.clear_state: + self.optimizer.state = defaultdict(dict) + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group["initial_lr"] * weight for group in self.optimizer.param_groups] + if self.last_epoch not in self.milestones: + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma ** self.milestones[self.last_epoch] for group in self.optimizer.param_groups] + + +class CosineAnnealingLR_Restart(_LRScheduler): + def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): + self.T_period = T_period + self.T_max = self.T_period[0] # current T period + self.eta_min = eta_min + self.restarts = restarts if restarts else [0] + self.restarts = [v + 1 for v in self.restarts] + self.restart_weights = weights if weights else [1] + self.last_restart = 0 + assert len(self.restarts) == len(self.restart_weights), "restarts and their weights do not match." + super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch == 0: + return self.base_lrs + elif self.last_epoch in self.restarts: + self.last_restart = self.last_epoch + self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [group["initial_lr"] * weight for group in self.optimizer.param_groups] + elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: + return [ + group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + return [ + (1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) + / (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) + * (group["lr"] - self.eta_min) + + self.eta_min + for group in self.optimizer.param_groups + ] + + +if __name__ == "__main__": + optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, betas=(0.9, 0.99)) + ############################## + # MultiStepLR_Restart + ############################## + # Original + lr_steps = [200000, 400000, 600000, 800000] + restarts = None + restart_weights = None + + # two + lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] + restarts = [500000] + restart_weights = [1] + + # four + lr_steps = [ + 50000, + 100000, + 150000, + 200000, + 240000, + 300000, + 350000, + 400000, + 450000, + 490000, + 550000, + 600000, + 650000, + 700000, + 740000, + 800000, + 850000, + 900000, + 950000, + 990000, + ] + restarts = [250000, 500000, 750000] + restart_weights = [1, 1, 1] + + scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, clear_state=False) + + ############################## + # Cosine Annealing Restart + ############################## + # two + T_period = [500000, 500000] + restarts = [500000] + restart_weights = [1] + + # four + T_period = [250000, 250000, 250000, 250000] + restarts = [250000, 500000, 750000] + restart_weights = [1, 1, 1] + + scheduler = CosineAnnealingLR_Restart( + optimizer, T_period, eta_min=1e-7, restarts=restarts, weights=restart_weights + ) + + ############################## + # Draw figure + ############################## + N_iter = 1000000 + lr_l = list(range(N_iter)) + for i in range(N_iter): + scheduler.step() + current_lr = optimizer.param_groups[0]["lr"] + lr_l[i] = current_lr + + import matplotlib as mpl + import matplotlib.ticker as mtick + from matplotlib import pyplot as plt + + mpl.style.use("default") + import seaborn + + seaborn.set(style="whitegrid") + seaborn.set_context("paper") + + plt.figure(1) + plt.subplot(111) + plt.ticklabel_format(style="sci", axis="x", scilimits=(0, 0)) + plt.title("Title", fontsize=16, color="k") + plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label="learning rate scheme") + legend = plt.legend(loc="upper right", shadow=False) + ax = plt.gca() + labels = ax.get_xticks().tolist() + for k, v in enumerate(labels): + labels[k] = str(int(v / 1000)) + "K" + ax.set_xticklabels(labels) + ax.yaxis.set_major_formatter(mtick.FormatStrFormatter("%.1e")) + + ax.set_ylabel("Learning rate") + ax.set_xlabel("Iteration") + fig = plt.gcf() + plt.show() diff --git a/diffusion-posterior-sampling/bkse/options/__init__.py b/diffusion-posterior-sampling/bkse/options/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/diffusion-posterior-sampling/bkse/options/data_augmentation/default.yml b/diffusion-posterior-sampling/bkse/options/data_augmentation/default.yml new file mode 100644 index 0000000000000000000000000000000000000000..0e93dcf0e329cc2b008d09484f714a4f8715b47c --- /dev/null +++ b/diffusion-posterior-sampling/bkse/options/data_augmentation/default.yml @@ -0,0 +1,22 @@ +#### general settings +gpu_ids: [0] + +#### network structures +KernelWizard: + pretrained: experiments/pretrained/GOPRO_woVAE.pth + input_nc: 3 + nf: 64 + front_RBs: 10 + back_RBs: 20 + N_frames: 1 + kernel_dim: 512 + use_vae: false + KernelExtractor: + norm: none + use_sharp: true + n_blocks: 4 + padding_type: reflect + use_dropout: false + Adapter: + norm: none + use_dropout: false diff --git a/diffusion-posterior-sampling/bkse/options/domain_specific_deblur/stylegan.yml b/diffusion-posterior-sampling/bkse/options/domain_specific_deblur/stylegan.yml new file mode 100644 index 0000000000000000000000000000000000000000..7799e98b473e81212d3c2190d08967017cc2fd84 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/options/domain_specific_deblur/stylegan.yml @@ -0,0 +1,52 @@ +stylegan_ver: 1 +img_size: &HQ_SIZE [256, 256] +output_size: 1024 +verbose: true +num_epochs: 25 +num_warmup_iters: 150 +num_x_iters: 300 +num_k_iters: 200 +x_lr: !!float 0.2 +k_lr: !!float 1e-4 +warmup_k_path: experiments/pretrained/kernel.pth +reg_noise_std: !!float 0.001 +duplicates: 1 +batch_size: 1 +loss_str: '100*L2+0.1*GEOCROSS' +eps: !!float 1e-15 +noise_type: trainable +num_trainable_noise_layers: 5 +bad_noise_layers: '17' +optimizer_name: adam +lr_schedule: linear1cycledrop +save_intermediate: true +tile_latent: ~ +seed: ~ + +KernelDIP: + nf: 64 + n_blocks: 6 + padding_type: reflect + use_dropout: false + kernel_dim: 512 + norm: none + +KernelWizard: + pretrained: experiments/pretrained/GOPRO_woVAE.pth + input_nc: 3 + nf: 64 + front_RBs: 10 + back_RBs: 20 + N_frames: 1 + kernel_dim: 512 + img_size: *HQ_SIZE + use_vae: false + KernelExtractor: + norm: none + use_sharp: true + n_blocks: 4 + padding_type: reflect + use_dropout: false + Adapter: + norm: none + use_dropout: false diff --git a/diffusion-posterior-sampling/bkse/options/domain_specific_deblur/stylegan2.yml b/diffusion-posterior-sampling/bkse/options/domain_specific_deblur/stylegan2.yml new file mode 100644 index 0000000000000000000000000000000000000000..1e1f81452bd48d51fd8151b8b89cafc962fde871 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/options/domain_specific_deblur/stylegan2.yml @@ -0,0 +1,59 @@ +stylegan_ver: 2 +img_size: &HQ_SIZE [256, 256] +output_size: 256 +verbose: true +num_epochs: 25 +num_warmup_iters: 150 +num_x_iters: 300 +num_k_iters: 200 +x_lr: !!float 0.2 +k_lr: !!float 5e-4 +warmup_k_path: experiments/pretrained/kernel.pth +reg_noise_std: !!float 0.001 +duplicates: 1 +batch_size: 1 +loss_str: '100*L2+0.1*GEOCROSS' +eps: !!float 1e-15 +noise_type: trainable +num_trainable_noise_layers: 5 +bad_noise_layers: '17' +optimizer_name: adam +lr_schedule: linear1cycledrop +save_intermediate: true +tile_latent: ~ +seed: ~ + +ImageDIP: + input_nc: 8 + output_nc: 3 + nf: 64 + norm: none + padding_type: reflect + +KernelDIP: + nf: 64 + n_blocks: 6 + padding_type: reflect + use_dropout: false + kernel_dim: 512 + norm: none + +KernelWizard: + pretrained: experiments/pretrained/GOPRO_woVAE.pth + input_nc: 3 + nf: 64 + front_RBs: 10 + back_RBs: 20 + N_frames: 1 + kernel_dim: 512 + img_size: *HQ_SIZE + use_vae: false + KernelExtractor: + norm: none + use_sharp: true + n_blocks: 4 + padding_type: reflect + use_dropout: false + Adapter: + norm: none + use_dropout: false diff --git a/diffusion-posterior-sampling/bkse/options/generate_blur/default.yml b/diffusion-posterior-sampling/bkse/options/generate_blur/default.yml new file mode 100644 index 0000000000000000000000000000000000000000..13f1712a1104627defe5609fc05bef23d7c70ec1 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/options/generate_blur/default.yml @@ -0,0 +1,22 @@ +#### general settings +gpu_ids: [0] + +#### network structures +KernelWizard: + pretrained: experiments/pretrained/GOPRO_wVAE.pth + input_nc: 3 + nf: 64 + front_RBs: 10 + back_RBs: 20 + N_frames: 1 + kernel_dim: 512 + use_vae: true + KernelExtractor: + norm: none + use_sharp: true + n_blocks: 4 + padding_type: reflect + use_dropout: false + Adapter: + norm: none + use_dropout: false diff --git a/diffusion-posterior-sampling/bkse/options/generic_deblur/default.yml b/diffusion-posterior-sampling/bkse/options/generic_deblur/default.yml new file mode 100644 index 0000000000000000000000000000000000000000..aadb06e33b495f806ef320fdeaffb46969eeda35 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/options/generic_deblur/default.yml @@ -0,0 +1,42 @@ +num_iters: 5000 +num_warmup_iters: 300 +x_lr: !!float 5e-4 +k_lr: !!float 5e-4 +img_size: &HQ_SIZE [256, 256] +warmup_k_path: experiments/pretrained/kernel.pth +reg_noise_std: !!float 0.001 + +ImageDIP: + input_nc: 8 + output_nc: 3 + nf: 64 + norm: none + padding_type: reflect + +KernelDIP: + nf: 64 + n_blocks: 6 + padding_type: reflect + use_dropout: false + kernel_dim: 512 + norm: none + +KernelWizard: + pretrained: experiments/pretrained/GOPRO_woVAE.pth + input_nc: 3 + nf: 64 + front_RBs: 10 + back_RBs: 20 + N_frames: 1 + kernel_dim: 512 + img_size: *HQ_SIZE + use_vae: false + KernelExtractor: + norm: none + use_sharp: true + n_blocks: 4 + padding_type: reflect + use_dropout: false + Adapter: + norm: none + use_dropout: false diff --git a/diffusion-posterior-sampling/bkse/options/kernel_encoding/GOPRO/wVAE.yml b/diffusion-posterior-sampling/bkse/options/kernel_encoding/GOPRO/wVAE.yml new file mode 100644 index 0000000000000000000000000000000000000000..7fadc1e52ec3f1017a078b00ac74f0690cc69779 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/options/kernel_encoding/GOPRO/wVAE.yml @@ -0,0 +1,77 @@ +#### general settings +name: GOPRO_VAE +use_tb_logger: true +model: image_base +distortion: deblur +scale: 1 +gpu_ids: [0] + +#### datasets +datasets: + train: + name: GOPRO + mode: GOPRO + interval_list: [1] + dataroot_HQ: datasets/GOPRO/train_sharp.lmdb + dataroot_LQ: datasets/GOPRO/train_blur_linear.lmdb + cache_keys: ~ + + use_shuffle: true + n_workers: 4 # per GPU + batch_size: 8 + HQ_size: &HQ_SIZE 256 + LQ_size: 256 + use_flip: true + use_rot: true + color: RGB + +#### network structures +KernelWizard: + input_nc: 3 + nf: 64 + front_RBs: 10 + back_RBs: 20 + N_frames: 1 + kernel_dim: 512 + img_size: *HQ_SIZE + use_vae: true + KernelExtractor: + norm: none + use_sharp: true + n_blocks: 4 + padding_type: reflect + use_dropout: false + Adapter: + norm: none + use_dropout: false + +#### path +path: + pretrain_model_G: experiments/pretrained/GOPRO_wsharp_woVAE.pth + strict_load: false + resume_state: ~ + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 1e-4 + lr_scheme: CosineAnnealingLR_Restart + beta1: 0.9 + beta2: 0.99 + niter: 600000 + warmup_iter: -1 # -1: no warm up + T_period: [50000, 100000, 150000, 150000, 150000] + restarts: [50000, 150000, 300000, 450000] + restart_weights: [1, 1, 1, 1] + eta_min: !!float 1e-8 + + pixel_criterion: cb + pixel_weight: !!float 1.0 + kl_weight: !!float 10.0 + val_freq: !!float 5e3 + + manual_seed: 0 + +#### logger +logger: + print_freq: 10 + save_checkpoint_freq: !!float 5e3 diff --git a/diffusion-posterior-sampling/bkse/options/kernel_encoding/GOPRO/woVAE.yml b/diffusion-posterior-sampling/bkse/options/kernel_encoding/GOPRO/woVAE.yml new file mode 100644 index 0000000000000000000000000000000000000000..256e094c11733a25ee0c664d19aea59285962b30 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/options/kernel_encoding/GOPRO/woVAE.yml @@ -0,0 +1,77 @@ +#### general settings +name: GOPRO_woVAE +use_tb_logger: true +model: image_base +distortion: deblur +scale: 1 +gpu_ids: [0] + +#### datasets +datasets: + train: + name: GOPRO + mode: GOPRO + interval_list: [1] + dataroot_HQ: datasets/GOPRO/train_sharp.lmdb + dataroot_LQ: datasets/GOPRO/train_blur_linear.lmdb + cache_keys: ~ + + use_shuffle: true + n_workers: 4 # per GPU + batch_size: 16 + HQ_size: &HQ_SIZE 256 + LQ_size: 256 + use_flip: true + use_rot: true + color: RGB + +#### network structures +KernelWizard: + input_nc: 3 + nf: 64 + front_RBs: 10 + back_RBs: 20 + N_frames: 1 + kernel_dim: 512 + img_size: *HQ_SIZE + use_vae: false + KernelExtractor: + norm: none + use_sharp: true + n_blocks: 4 + padding_type: reflect + use_dropout: false + Adapter: + norm: none + use_dropout: false + +#### path +path: + pretrain_model_G: ~ + strict_load: false + resume_state: ~ + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 1e-4 + lr_scheme: CosineAnnealingLR_Restart + beta1: 0.9 + beta2: 0.99 + niter: 600000 + warmup_iter: -1 # -1: no warm up + T_period: [50000, 100000, 150000, 150000, 150000] + restarts: [50000, 150000, 300000, 450000] + restart_weights: [1, 1, 1, 1] + eta_min: !!float 1e-8 + + pixel_criterion: cb + pixel_weight: 1.0 + kl_weight: 0.0 + val_freq: !!float 5e3 + + manual_seed: 0 + +#### logger +logger: + print_freq: 10 + save_checkpoint_freq: !!float 5e3 diff --git a/diffusion-posterior-sampling/bkse/options/kernel_encoding/REDS/woVAE.yml b/diffusion-posterior-sampling/bkse/options/kernel_encoding/REDS/woVAE.yml new file mode 100644 index 0000000000000000000000000000000000000000..5bf18915782bdf824bfeaa49ce57e15fdb39f041 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/options/kernel_encoding/REDS/woVAE.yml @@ -0,0 +1,77 @@ +#### general settings +name: REDS_woVAE +use_tb_logger: true +model: image_base +distortion: deblur +scale: 1 +gpu_ids: [3] + +#### datasets +datasets: + train: + name: REDS + mode: REDS + interval_list: [1] + dataroot_HQ: datasets/REDS/train_sharp_wval.lmdb + dataroot_LQ: datasets/REDS/train_blur_wval.lmdb + cache_keys: ~ + + use_shuffle: true + n_workers: 4 # per GPU + batch_size: 13 + HQ_size: &HQ_SIZE 256 + LQ_size: 256 + use_flip: true + use_rot: true + color: RGB + +#### network structures +KernelWizard: + input_nc: 3 + nf: 64 + front_RBs: 10 + back_RBs: 20 + N_frames: 1 + kernel_dim: 512 + img_size: *HQ_SIZE + use_vae: false + KernelExtractor: + norm: none + use_sharp: true + n_blocks: 4 + padding_type: reflect + use_dropout: false + Adapter: + norm: none + use_dropout: false + +#### path +path: + pretrain_model_G: ~ + strict_load: false + resume_state: ~ + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 1e-4 + lr_scheme: CosineAnnealingLR_Restart + beta1: 0.9 + beta2: 0.99 + niter: 600000 + warmup_iter: -1 # -1: no warm up + T_period: [50000, 100000, 150000, 150000, 150000] + restarts: [50000, 150000, 300000, 450000] + restart_weights: [1, 1, 1, 1] + eta_min: !!float 1e-6 + + pixel_criterion: cb + pixel_weight: 1.0 + kl_weight: 0.0 + val_freq: !!float 5e3 + + manual_seed: 0 + +#### logger +logger: + print_freq: 10 + save_checkpoint_freq: !!float 5e3 diff --git a/diffusion-posterior-sampling/bkse/options/kernel_encoding/mix/woVAE.yml b/diffusion-posterior-sampling/bkse/options/kernel_encoding/mix/woVAE.yml new file mode 100644 index 0000000000000000000000000000000000000000..6881af7eb8269dc18910ae073f0a028aeaa76298 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/options/kernel_encoding/mix/woVAE.yml @@ -0,0 +1,79 @@ +#### general settings +name: mix_wsharp +use_tb_logger: true +model: image_base +distortion: deblur +scale: 1 +gpu_ids: [0] + +#### datasets +datasets: + train: + name: mix + mode: mix + interval_list: [1] + dataroots_HQ: ['datasets/REDS/train_sharp_wval.lmdb', 'datasets/GOPRO/train_sharp.lmdb'] + dataroots_LQ: ['datasets/REDS/train_blur_wval.lmdb', 'datasets/GOPRO/train_blur_linear.lmdb'] + dataset_weights: [1, 10] + cache_keys: ~ + + N_frames: 1 + use_shuffle: true + n_workers: 3 # per GPU + batch_size: 16 + HQ_size: 256 + LQ_size: 256 + use_flip: true + use_rot: true + color: RGB + +#### network structures +KernelWizard: + input_nc: 3 + nf: 64 + front_RBs: 10 + back_RBs: 20 + N_frames: 1 + kernel_dim: 512 + use_vae: false + KernelExtractor: + norm: none + use_sharp: true + n_blocks: 4 + padding_type: reflect + use_dropout: false + Adapter: + norm: none + use_dropout: false + +#### path +path: + pretrain_model_G: ~ + strict_load: false + resume_state: ~ + +#### training settings: learning rate scheme, loss +train: + lr_G: !!float 1e-4 + lr_scheme: CosineAnnealingLR_Restart + beta1: 0.9 + beta2: 0.99 + niter: 600000 + warmup_iter: -1 # -1: no warm up + T_period: [50000, 100000, 150000, 150000, 150000] + restarts: [50000, 150000, 300000, 450000] + restart_weights: [1, 1, 1, 1] + eta_min: !!float 1e-8 + + pixel_criterion: cb + pixel_weight: 1.0 + kernel_weight: 0.1 + gradient_loss_weight: 0.3 + val_freq: !!float 5e3 + + manual_seed: 0 + +#### logger +logger: + print_freq: 10 + save_checkpoint_freq: !!float 5000 diff --git a/diffusion-posterior-sampling/bkse/options/options.py b/diffusion-posterior-sampling/bkse/options/options.py new file mode 100755 index 0000000000000000000000000000000000000000..1eeefa691e32857cd565e252e2049a11dde5e68b --- /dev/null +++ b/diffusion-posterior-sampling/bkse/options/options.py @@ -0,0 +1,122 @@ +import logging +import os +import os.path as osp + +import yaml +from utils.util import OrderedYaml + + +Loader, Dumper = OrderedYaml() + + +def parse(opt_path, is_train=True): + with open(opt_path, mode="r") as f: + opt = yaml.load(f, Loader=Loader) + # export CUDA_VISIBLE_DEVICES + gpu_list = ",".join(str(x) for x in opt["gpu_ids"]) + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list + print("export CUDA_VISIBLE_DEVICES=" + gpu_list) + + opt["is_train"] = is_train + if opt["distortion"] == "sr": + scale = opt["scale"] + + # datasets + for phase, dataset in opt["datasets"].items(): + phase = phase.split("_")[0] + dataset["phase"] = phase + if opt["distortion"] == "sr": + dataset["scale"] = scale + is_lmdb = False + if dataset.get("dataroot_GT", None) is not None: + dataset["dataroot_GT"] = osp.expanduser(dataset["dataroot_GT"]) + if dataset["dataroot_GT"].endswith("lmdb"): + is_lmdb = True + if dataset.get("dataroot_LQ", None) is not None: + dataset["dataroot_LQ"] = osp.expanduser(dataset["dataroot_LQ"]) + if dataset["dataroot_LQ"].endswith("lmdb"): + is_lmdb = True + dataset["data_type"] = "lmdb" if is_lmdb else "img" + if dataset["mode"].endswith("mc"): # for memcached + dataset["data_type"] = "mc" + dataset["mode"] = dataset["mode"].replace("_mc", "") + + # path + for key, path in opt["path"].items(): + if path and key in opt["path"] and key != "strict_load": + opt["path"][key] = osp.expanduser(path) + opt["path"]["root"] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + if is_train: + experiments_root = osp.join(opt["path"]["root"], "experiments", opt["name"]) + opt["path"]["experiments_root"] = experiments_root + opt["path"]["models"] = osp.join(experiments_root, "models") + opt["path"]["training_state"] = osp.join(experiments_root, "training_state") + opt["path"]["log"] = experiments_root + opt["path"]["val_images"] = osp.join(experiments_root, "val_images") + + # change some options for debug mode + if "debug" in opt["name"]: + opt["train"]["val_freq"] = 8 + opt["logger"]["print_freq"] = 1 + opt["logger"]["save_checkpoint_freq"] = 8 + else: # test + results_root = osp.join(opt["path"]["root"], "results", opt["name"]) + opt["path"]["results_root"] = results_root + opt["path"]["log"] = results_root + + # network + if opt["distortion"] == "sr": + opt["network_G"]["scale"] = scale + + return opt + + +def dict2str(opt, indent_l=1): + """dict to string for logger""" + msg = "" + for k, v in opt.items(): + if isinstance(v, dict): + msg += " " * (indent_l * 2) + k + ":[\n" + msg += dict2str(v, indent_l + 1) + msg += " " * (indent_l * 2) + "]\n" + else: + msg += " " * (indent_l * 2) + k + ": " + str(v) + "\n" + return msg + + +class NoneDict(dict): + def __missing__(self, key): + return None + + +# convert to NoneDict, which return None for missing key. +def dict_to_nonedict(opt): + if isinstance(opt, dict): + new_opt = dict() + for key, sub_opt in opt.items(): + new_opt[key] = dict_to_nonedict(sub_opt) + return NoneDict(**new_opt) + elif isinstance(opt, list): + return [dict_to_nonedict(sub_opt) for sub_opt in opt] + else: + return opt + + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_model paths""" + logger = logging.getLogger("base") + if opt["path"]["resume_state"]: + if ( + opt["path"].get("pretrain_model_G", None) is not None + or opt["path"].get("pretrain_model_D", None) is not None + ): + logger.warning( + "pretrain_model path will be ignored \ + when resuming training." + ) + + opt["path"]["pretrain_model_G"] = osp.join(opt["path"]["models"], "{}_G.pth".format(resume_iter)) + logger.info("Set [pretrain_model_G] to " + opt["path"]["pretrain_model_G"]) + if "gan" in opt["model"]: + opt["path"]["pretrain_model_D"] = osp.join(opt["path"]["models"], "{}_D.pth".format(resume_iter)) + logger.info("Set [pretrain_model_D] to " + opt["path"]["pretrain_model_D"]) diff --git a/diffusion-posterior-sampling/bkse/requirements.txt b/diffusion-posterior-sampling/bkse/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4712e2b293bc0ae4a3825d655d4888024dc70686 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/requirements.txt @@ -0,0 +1,9 @@ +torch >= 1.4.0 +torchvision >= 0.5.0 +pyyaml +opencv-python +numpy +lmdb +tqdm +tensorboard >= 1.15.0 +ninja diff --git a/diffusion-posterior-sampling/bkse/scripts/create_lmdb.py b/diffusion-posterior-sampling/bkse/scripts/create_lmdb.py new file mode 100644 index 0000000000000000000000000000000000000000..3dbbbd252fc3efd31b6e8dd60d077ba13da1332b --- /dev/null +++ b/diffusion-posterior-sampling/bkse/scripts/create_lmdb.py @@ -0,0 +1,117 @@ +import argparse +import os.path as osp +import pickle +import sys +from multiprocessing import Pool + +import cv2 +import lmdb + + +sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) +import data.util as data_util # noqa: E402 +import utils.util as util # noqa: E402 + + +def read_image_worker(path, key): + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + return (key, img) + + +def create_dataset(name, img_folder, lmdb_save_path, H_dst, W_dst, C_dst): + """Create lmdb for the dataset, each image with a fixed size + key pattern: folder_frameid + """ + # configurations + read_all_imgs = False # whether real all images to memory with multiprocessing + # Set False for use limited memory + BATCH = 5000 # After BATCH images, lmdb commits, if read_all_imgs = False + n_thread = 40 + ######################################################## + if not lmdb_save_path.endswith(".lmdb"): + raise ValueError("lmdb_save_path must end with 'lmdb'.") + if osp.exists(lmdb_save_path): + print("Folder [{:s}] already exists. Exit...".format(lmdb_save_path)) + sys.exit(1) + + # read all the image paths to a list + print("Reading image path list ...") + all_img_list = data_util._get_paths_from_images(img_folder) + keys = [] + for img_path in all_img_list: + split_rlt = img_path.split("/") + folder = split_rlt[-2] + img_name = split_rlt[-1].split(".png")[0] + keys.append(folder + "_" + img_name) + + if read_all_imgs: + # read all images to memory (multiprocessing) + dataset = {} # store all image data. list cannot keep the order, use dict + print("Read images with multiprocessing, #thread: {} ...".format(n_thread)) + pbar = util.ProgressBar(len(all_img_list)) + + def mycallback(arg): + """get the image data and update pbar""" + key = arg[0] + dataset[key] = arg[1] + pbar.update("Reading {}".format(key)) + + pool = Pool(n_thread) + for path, key in zip(all_img_list, keys): + pool.apply_async(read_image_worker, args=(path, key), callback=mycallback) + pool.close() + pool.join() + print("Finish reading {} images.\nWrite lmdb...".format(len(all_img_list))) + + # create lmdb environment + data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes + print("data size per image is: ", data_size_per_img) + data_size = data_size_per_img * len(all_img_list) + env = lmdb.open(lmdb_save_path, map_size=data_size * 10) + + # write data to lmdb + pbar = util.ProgressBar(len(all_img_list)) + txn = env.begin(write=True) + for idx, (path, key) in enumerate(zip(all_img_list, keys)): + pbar.update("Write {}".format(key)) + key_byte = key.encode("ascii") + data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED) + + assert len(data.shape) > 2 or C_dst == 1, "different shape" + + if C_dst == 1: + H, W = data.shape + assert H == H_dst and W == W_dst, "different shape." + else: + H, W, C = data.shape + assert H == H_dst and W == W_dst and C == 3, "different shape." + txn.put(key_byte, data) + if not read_all_imgs and idx % BATCH == 0: + txn.commit() + txn = env.begin(write=True) + txn.commit() + env.close() + print("Finish writing lmdb.") + + # create meta information + meta_info = {} + meta_info["name"] = name + channel = C_dst + meta_info["resolution"] = "{}_{}_{}".format(channel, H_dst, W_dst) + meta_info["keys"] = keys + pickle.dump(meta_info, open(osp.join(lmdb_save_path, "meta_info.pkl"), "wb")) + print("Finish creating lmdb meta info.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Kernel extractor testing") + + parser.add_argument("--H", action="store", help="source image height", type=int, required=True) + parser.add_argument("--W", action="store", help="source image height", type=int, required=True) + parser.add_argument("--C", action="store", help="source image height", type=int, required=True) + parser.add_argument("--img_folder", action="store", help="img folder", type=str, required=True) + parser.add_argument("--save_path", action="store", help="save path", type=str, default=".") + parser.add_argument("--name", action="store", help="dataset name", type=str, required=True) + + args = parser.parse_args() + create_dataset(args.name, args.img_folder, args.save_path, args.H, args.W, args.C) diff --git a/diffusion-posterior-sampling/bkse/scripts/download_dataset.py b/diffusion-posterior-sampling/bkse/scripts/download_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5d7086463abb2eeb777b1b4769604e2a79a67d88 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/scripts/download_dataset.py @@ -0,0 +1,72 @@ +import argparse +import os +import os.path as osp + +import requests + + +def download_file_from_google_drive(file_id, destination): + os.makedirs(osp.dirname(destination), exist_ok=True) + URL = "https://docs.google.com/uc?export=download" + + session = requests.Session() + + response = session.get(URL, params={"id": file_id}, stream=True) + token = get_confirm_token(response) + + if token: + params = {"id": file_id, "confirm": token} + response = session.get(URL, params=params, stream=True) + + save_response_content(response, destination) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith("download_warning"): + return value + + return None + + +def save_response_content(response, destination): + CHUNK_SIZE = 32768 + + with open(destination, "wb") as f: + for chunk in response.iter_content(CHUNK_SIZE): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + + +if __name__ == "__main__": + dataset_ids = { + "GOPRO_Large": "1H0PIXvJH4c40pk7ou6nAwoxuR4Qh_Sa2", + "train_sharp": "1YLksKtMhd2mWyVSkvhDaDLWSc1qYNCz-", + "train_blur": "1Be2cgzuuXibcqAuJekDgvHq4MLYkCgR8", + "val_sharp": "1MGeObVQ1-Z29f-myDP7-8c3u0_xECKXq", + "val_blur": "1N8z2yD0GDWmh6U4d4EADERtcUgDzGrHx", + "test_blur": "1dr0--ZBKqr4P1M8lek6JKD1Vd6bhhrZT", + } + + parser = argparse.ArgumentParser( + description="Download REDS dataset from google drive to current folder", allow_abbrev=False + ) + + parser.add_argument("--REDS_train_sharp", action="store_true", help="download REDS train_sharp.zip") + parser.add_argument("--REDS_train_blur", action="store_true", help="download REDS train_blur.zip") + parser.add_argument("--REDS_val_sharp", action="store_true", help="download REDS val_sharp.zip") + parser.add_argument("--REDS_val_blur", action="store_true", help="download REDS val_blur.zip") + parser.add_argument("--GOPRO", action="store_true", help="download GOPRO_Large.zip") + + args = parser.parse_args() + + if args.REDS_train_sharp: + download_file_from_google_drive(dataset_ids["train_sharp"], "REDS/train_sharp.zip") + if args.REDS_train_blur: + download_file_from_google_drive(dataset_ids["train_blur"], "REDS/train_blur.zip") + if args.REDS_val_sharp: + download_file_from_google_drive(dataset_ids["val_sharp"], "REDS/val_sharp.zip") + if args.REDS_val_blur: + download_file_from_google_drive(dataset_ids["val_blur"], "REDS/val_blur.zip") + if args.GOPRO: + download_file_from_google_drive(dataset_ids["GOPRO_Large"], "GOPRO/GOPRO.zip") diff --git a/diffusion-posterior-sampling/bkse/train.py b/diffusion-posterior-sampling/bkse/train.py new file mode 100644 index 0000000000000000000000000000000000000000..8d34cf05d9cb764bcf785c06f4da8083adf2a625 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/train.py @@ -0,0 +1,320 @@ +import argparse +import logging +import math +import os +import random + +import numpy as np +import options.options as option +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from data import create_dataloader, create_dataset +from data.data_sampler import DistIterSampler +from models import create_model +from utils import util + + +def init_dist(backend="nccl", **kwargs): + """initialization for distributed training""" + if mp.get_start_method(allow_none=True) != "spawn": + mp.set_start_method("spawn") + rank = int(os.environ["RANK"]) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def main(): + # options + parser = argparse.ArgumentParser() + parser.add_argument("-opt", type=str, help="Path to option YAML file.") + parser.add_argument("--launcher", choices=["none", "pytorch"], default="none", help="job launcher") + parser.add_argument("--local_rank", type=int, default=0) + args = parser.parse_args() + opt = option.parse(args.opt, is_train=True) + + # distributed training settings + if args.launcher == "none": # disabled distributed training + opt["dist"] = False + rank = -1 + print("Disabled distributed training.") + else: + opt["dist"] = True + init_dist() + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # loading resume state if exists + if opt["path"].get("resume_state", None): + # distributed resuming: all load into default GPU + device_id = torch.cuda.current_device() + resume_state = torch.load( + opt["path"]["resume_state"], map_location=lambda storage, loc: storage.cuda(device_id) + ) + option.check_resume(opt, resume_state["iter"]) # check resume options + else: + resume_state = None + + # mkdir and loggers + if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) + if resume_state is None: + util.mkdir_and_rename(opt["path"]["experiments_root"]) # rename experiment folder if exists + util.mkdirs( + ( + path + for key, path in opt["path"].items() + if not key == "experiments_root" and "pretrain_model" not in key and "resume" not in key + ) + ) + + # config loggers. Before it, the log will not work + util.setup_logger( + "base", opt["path"]["log"], "train_" + opt["name"], level=logging.INFO, screen=True, tofile=True + ) + logger = logging.getLogger("base") + logger.info(option.dict2str(opt)) + # tensorboard logger + if opt["use_tb_logger"] and "debug" not in opt["name"]: + version = float(torch.__version__[0:3]) + if version >= 1.1: # PyTorch 1.1 + from torch.utils.tensorboard import SummaryWriter + else: + logger.info( + "You are using PyTorch {}. \ + Tensorboard will use [tensorboardX]".format( + version + ) + ) + from tensorboardX import SummaryWriter + tb_logger = SummaryWriter(log_dir="../tb_logger/" + opt["name"]) + else: + util.setup_logger("base", opt["path"]["log"], "train", level=logging.INFO, screen=True) + logger = logging.getLogger("base") + + # convert to NoneDict, which returns None for missing keys + opt = option.dict_to_nonedict(opt) + + # random seed + seed = opt["train"]["manual_seed"] + if seed is None: + seed = random.randint(1, 10000) + if rank <= 0: + logger.info("Random seed: {}".format(seed)) + util.set_random_seed(seed) + + torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # create train and val dataloader + dataset_ratio = 200 # enlarge the size of each epoch + for phase, dataset_opt in opt["datasets"].items(): + if phase == "train": + train_set = create_dataset(dataset_opt) + train_size = int(math.ceil(len(train_set) / dataset_opt["batch_size"])) + total_iters = int(opt["train"]["niter"]) + total_epochs = int(math.ceil(total_iters / train_size)) + if opt["dist"]: + train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) + total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) + else: + train_sampler = None + train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) + if rank <= 0: + logger.info("Number of train images: {:,d}, iters: {:,d}".format(len(train_set), train_size)) + logger.info("Total epochs needed: {:d} for iters {:,d}".format(total_epochs, total_iters)) + elif phase == "val": + val_set = create_dataset(dataset_opt) + val_loader = create_dataloader(val_set, dataset_opt, opt, None) + if rank <= 0: + logger.info("Number of val images in [{:s}]: {:d}".format(dataset_opt["name"], len(val_set))) + else: + raise NotImplementedError("Phase [{:s}] is not recognized.".format(phase)) + assert train_loader is not None + + # create model + model = create_model(opt) + print("Model created!") + + # resume training + if resume_state: + logger.info("Resuming training from epoch: {}, iter: {}.".format(resume_state["epoch"], resume_state["iter"])) + + start_epoch = resume_state["epoch"] + current_step = resume_state["iter"] + model.resume_training(resume_state) # handle optimizers and schedulers + else: + current_step = 0 + start_epoch = 0 + + # training + logger.info("Start training from epoch: {:d}, iter: {:d}".format(start_epoch, current_step)) + for epoch in range(start_epoch, total_epochs + 1): + if opt["dist"]: + train_sampler.set_epoch(epoch) + for _, train_data in enumerate(train_loader): + current_step += 1 + if current_step > total_iters: + break + # update learning rate + model.update_learning_rate(current_step, warmup_iter=opt["train"]["warmup_iter"]) + + # training + model.feed_data(train_data) + model.optimize_parameters(current_step) + + # log + if current_step % opt["logger"]["print_freq"] == 0: + logs = model.get_current_log() + message = "[epoch:{:3d}, iter:{:8,d}, lr:(".format(epoch, current_step) + for v in model.get_current_learning_rate(): + message += "{:.3e},".format(v) + message += ")] " + for k, v in logs.items(): + message += "{:s}: {:.4e} ".format(k, v) + # tensorboard logger + if opt["use_tb_logger"] and "debug" not in opt["name"]: + if rank <= 0: + tb_logger.add_scalar(k, v, current_step) + if rank <= 0: + logger.info(message) + # validation + if opt["datasets"].get("val", None) and current_step % opt["train"]["val_freq"] == 0: + # image restoration validation + if opt["model"] in ["sr", "srgan"] and rank <= 0: + # does not support multi-GPU validation + pbar = util.ProgressBar(len(val_loader)) + avg_psnr = 0.0 + idx = 0 + for val_data in val_loader: + idx += 1 + img_name = os.path.splitext(os.path.basename(val_data["LQ_path"][0]))[0] + img_dir = os.path.join(opt["path"]["val_images"], img_name) + util.mkdir(img_dir) + + model.feed_data(val_data) + model.test() + + visuals = model.get_current_visuals() + sr_img = util.tensor2img(visuals["rlt"]) # uint8 + gt_img = util.tensor2img(visuals["GT"]) # uint8 + + # Save SR images for reference + save_img_path = os.path.join(img_dir, "{:s}_{:d}.png".format(img_name, current_step)) + util.save_img(sr_img, save_img_path) + + # calculate PSNR + sr_img, gt_img = util.crop_border([sr_img, gt_img], opt["scale"]) + avg_psnr += util.calculate_psnr(sr_img, gt_img) + pbar.update("Test {}".format(img_name)) + + avg_psnr = avg_psnr / idx + + # log + logger.info("# Validation # PSNR: {:.4e}".format(avg_psnr)) + # tensorboard logger + if opt["use_tb_logger"] and "debug" not in opt["name"]: + tb_logger.add_scalar("psnr", avg_psnr, current_step) + else: # video restoration validation + if opt["dist"]: + # multi-GPU testing + psnr_rlt = {} # with border and center frames + if rank == 0: + pbar = util.ProgressBar(len(val_set)) + for idx in range(rank, len(val_set), world_size): + val_data = val_set[idx] + val_data["LQs"].unsqueeze_(0) + val_data["GT"].unsqueeze_(0) + folder = val_data["folder"] + idx_d, max_idx = val_data["idx"].split("/") + idx_d, max_idx = int(idx_d), int(max_idx) + if psnr_rlt.get(folder, None) is None: + psnr_rlt[folder] = torch.zeros(max_idx, dtype=torch.float32, device="cuda") + model.feed_data(val_data) + model.test() + visuals = model.get_current_visuals() + rlt_img = util.tensor2img(visuals["rlt"]) # uint8 + gt_img = util.tensor2img(visuals["GT"]) # uint8 + # calculate PSNR + psnr_rlt[folder][idx_d] = util.calculate_psnr(rlt_img, gt_img) + + if rank == 0: + for _ in range(world_size): + pbar.update("Test {} - {}/{}".format(folder, idx_d, max_idx)) + # collect data + for _, v in psnr_rlt.items(): + dist.reduce(v, 0) + dist.barrier() + + if rank == 0: + psnr_rlt_avg = {} + psnr_total_avg = 0.0 + for k, v in psnr_rlt.items(): + psnr_rlt_avg[k] = torch.mean(v).cpu().item() + psnr_total_avg += psnr_rlt_avg[k] + psnr_total_avg /= len(psnr_rlt) + log_s = "# Validation # PSNR: {:.4e}:".format(psnr_total_avg) + for k, v in psnr_rlt_avg.items(): + log_s += " {}: {:.4e}".format(k, v) + logger.info(log_s) + if opt["use_tb_logger"] and "debug" not in opt["name"]: + tb_logger.add_scalar("psnr_avg", psnr_total_avg, current_step) + for k, v in psnr_rlt_avg.items(): + tb_logger.add_scalar(k, v, current_step) + else: + pbar = util.ProgressBar(len(val_loader)) + psnr_rlt = {} # with border and center frames + psnr_rlt_avg = {} + psnr_total_avg = 0.0 + for val_data in val_loader: + folder = val_data["folder"][0] + idx_d, max_id = val_data["idx"][0].split("/") + # border = val_data['border'].item() + if psnr_rlt.get(folder, None) is None: + psnr_rlt[folder] = [] + + model.feed_data(val_data) + model.test() + visuals = model.get_current_visuals() + rlt_img = util.tensor2img(visuals["rlt"]) # uint8 + gt_img = util.tensor2img(visuals["GT"]) # uint8 + lq_img = util.tensor2img(visuals["LQ"][2]) # uint8 + + img_dir = opt["path"]["val_images"] + util.mkdir(img_dir) + save_img_path = os.path.join(img_dir, "{}.png".format(idx_d)) + util.save_img(np.hstack((lq_img, rlt_img, gt_img)), save_img_path) + + # calculate PSNR + psnr = util.calculate_psnr(rlt_img, gt_img) + psnr_rlt[folder].append(psnr) + pbar.update("Test {} - {}".format(folder, idx_d)) + for k, v in psnr_rlt.items(): + psnr_rlt_avg[k] = sum(v) / len(v) + psnr_total_avg += psnr_rlt_avg[k] + psnr_total_avg /= len(psnr_rlt) + log_s = "# Validation # PSNR: {:.4e}:".format(psnr_total_avg) + for k, v in psnr_rlt_avg.items(): + log_s += " {}: {:.4e}".format(k, v) + logger.info(log_s) + if opt["use_tb_logger"] and "debug" not in opt["name"]: + tb_logger.add_scalar("psnr_avg", psnr_total_avg, current_step) + for k, v in psnr_rlt_avg.items(): + tb_logger.add_scalar(k, v, current_step) + + # save models and training states + if current_step % opt["logger"]["save_checkpoint_freq"] == 0: + if rank <= 0: + logger.info("Saving models and training states.") + model.save(current_step) + model.save_training_state(epoch, current_step) + + if rank <= 0: + logger.info("Saving the final model.") + model.save("latest") + logger.info("End of training.") + tb_logger.close() + + +if __name__ == "__main__": + main() diff --git a/diffusion-posterior-sampling/bkse/train_script.sh b/diffusion-posterior-sampling/bkse/train_script.sh new file mode 100755 index 0000000000000000000000000000000000000000..b12b24941c9580cd0f4c043d6ed680a992cc19d7 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/train_script.sh @@ -0,0 +1 @@ +python3.7 train.py -opt options/REDS/wsharp_woVAE.yml diff --git a/diffusion-posterior-sampling/bkse/utils/__init__.py b/diffusion-posterior-sampling/bkse/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/diffusion-posterior-sampling/bkse/utils/util.py b/diffusion-posterior-sampling/bkse/utils/util.py new file mode 100755 index 0000000000000000000000000000000000000000..127232482bc969d61a49c0d30228c2e554ed8c02 --- /dev/null +++ b/diffusion-posterior-sampling/bkse/utils/util.py @@ -0,0 +1,323 @@ +import logging +import math +import os +import random +import sys +import time +from collections import OrderedDict +from datetime import datetime +from shutil import get_terminal_size + +import cv2 +import numpy as np +import torch +import yaml +from torchvision.utils import make_grid + + +try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader +except ImportError: + from yaml import Dumper, Loader + + +def OrderedYaml(): + """yaml orderedDict support""" + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +#################### +# miscellaneous +#################### + + +def get_timestamp(): + return datetime.now().strftime("%y%m%d-%H%M%S") + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + "_archived_" + get_timestamp() + print("Path already exists. Rename it to [{:s}]".format(new_name)) + logger = logging.getLogger("base") + logger.info(f"Path already exists. Rename it to {new_name}") + os.rename(path, new_name) + os.makedirs(path) + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): + """set up logger""" + lg = logging.getLogger(logger_name) + formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S") + lg.setLevel(level) + if tofile: + log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp())) + fh = logging.FileHandler(log_file, mode="w") + fh.setFormatter(formatter) + lg.addHandler(fh) + if screen: + sh = logging.StreamHandler() + sh.setFormatter(formatter) + lg.addHandler(sh) + + +#################### +# image convert +#################### +def crop_border(img_list, crop_border): + """Crop borders of images + Args: + img_list (list [Numpy]): HWC + crop_border (int): crop border for each end of height and weight + + Returns: + (list [Numpy]): cropped image list + """ + if crop_border == 0: + return img_list + else: + return [v[crop_border:-crop_border, crop_border:-crop_border] for v in img_list] + + +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + """ + Converts a torch Tensor into an image Numpy array + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + """ + + # clamp + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) + + # to range [0,1] + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + f"Only support 4D, 3D and 2D tensor. But received with dimension:\ + {n_dim}" + ) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +def save_img(img, img_path, mode="RGB"): + cv2.imwrite(img_path, img) + + +#################### +# metric +#################### + + +def calculate_psnr(img1, img2): + # img1 and img2 have range [0, 255] + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2) ** 2) + if mse == 0: + return float("inf") + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +def ssim(img1, img2): + C1 = (0.01 * 255) ** 2 + C2 = (0.03 * 255) ** 2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1 ** 2 + mu2_sq = mu2 ** 2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +def calculate_ssim(img1, img2): + """calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + """ + if not img1.shape == img2.shape: + raise ValueError("Input images must have the same dimensions.") + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1, img2)) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError("Wrong input image dimensions.") + + +class ProgressBar(object): + """A progress bar which can print the progress + modified from + https://github.com/hellock/cvbase/blob/master/cvbase/progress.py + """ + + def __init__(self, task_num=0, bar_width=50, start=True): + self.task_num = task_num + max_bar_width = self._get_max_bar_width() + self.bar_width = bar_width if bar_width <= max_bar_width else max_bar_width + self.completed = 0 + if start: + self.start() + + def _get_max_bar_width(self): + terminal_width, _ = get_terminal_size() + max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) + if max_bar_width < 10: + print( + "terminal width is too small ({}), \ + please consider widen the terminal for better " + "progressbar visualization".format(terminal_width) + ) + max_bar_width = 10 + return max_bar_width + + def start(self): + if self.task_num > 0: + sys.stdout.write( + "[{}] 0/{}, elapsed: 0s, ETA:\n{}\n".format(" " * self.bar_width, self.task_num, "Start...") + ) + else: + sys.stdout.write("completed: 0, elapsed: 0s") + sys.stdout.flush() + self.start_time = time.time() + + def update(self, msg="In progress..."): + self.completed += 1 + elapsed = time.time() - self.start_time + fps = self.completed / elapsed + if self.task_num > 0: + percentage = self.completed / float(self.task_num) + eta = int(elapsed * (1 - percentage) / percentage + 0.5) + mark_width = int(self.bar_width * percentage) + bar_chars = ">" * mark_width + "-" * (self.bar_width - mark_width) + sys.stdout.write("\033[2F") # cursor up 2 lines + + # clean the output (remove extra chars since last display) + sys.stdout.write("\033[J") + sys.stdout.write( + "[{}] {}/{}, {:.1f} task/s, \ + elapsed: {}s, ETA: {:5}s\n{}\n".format( + bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg + ) + ) + else: + sys.stdout.write( + "completed: {}, elapsed: \ + {}s, {:.1f} tasks/s".format( + self.completed, int(elapsed + 0.5), fps + ) + ) + sys.stdout.flush() + + +def img2tensor(img): + return torch.from_numpy(np.ascontiguousarray(np.transpose(img / 255.0, (2, 0, 1)))).float() + + +def fill_noise(x, noise_type): + """Fills tensor `x` with noise of type `noise_type`.""" + if noise_type == "u": + x.uniform_() + elif noise_type == "n": + x.normal_() + else: + assert False + + +def np_to_torch(img_np): + """Converts image in numpy.array to torch.Tensor. + From C x W x H [0..1] to C x W x H [0..1] + """ + return torch.from_numpy(img_np)[None, :] + + +def get_noise(input_depth, method, spatial_size, noise_type="u", var=1.0 / 10): + """Returns a pytorch.Tensor of size (1 x `input_depth` x `spatial_size[0]` x `spatial_size[1]`) + initialized in a specific way. + Args: + input_depth: number of channels in the tensor + method: `noise` for fillting tensor with noise; `meshgrid` for np.meshgrid + spatial_size: spatial size of the tensor to initialize + noise_type: 'u' for uniform; 'n' for normal + var: a factor, a noise will be multiplicated by. Basically it is standard deviation scaler. + """ + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + if method == "noise": + shape = [1, input_depth, spatial_size[0], spatial_size[1]] + net_input = torch.zeros(shape) + + fill_noise(net_input, noise_type) + net_input *= var + elif method == "meshgrid": + assert input_depth == 2 + X, Y = np.meshgrid( + np.arange(0, spatial_size[1]) / float(spatial_size[1] - 1), + np.arange(0, spatial_size[0]) / float(spatial_size[0] - 1), + ) + meshgrid = np.concatenate([X[None, :], Y[None, :]]) + net_input = np_to_torch(meshgrid) + else: + assert False + + return net_input