diff --git a/DiffIR2VR_fps_10.mp4 b/DiffIR2VR_fps_10.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ae2fa7467b0144880ceeab3152f96eda48f27810 Binary files /dev/null and b/DiffIR2VR_fps_10.mp4 differ diff --git a/GMFlow/LICENSE b/GMFlow/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..8ba17c78e378819527e65ef7d1a767f035a792ac --- /dev/null +++ b/GMFlow/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2022, Haofei Xu + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/GMFlow/README.md b/GMFlow/README.md new file mode 100644 index 0000000000000000000000000000000000000000..17449970f8861dec7fef8d8835fc7e92abeb2332 --- /dev/null +++ b/GMFlow/README.md @@ -0,0 +1,239 @@ +# GMFlow + + +Official PyTorch implementation of paper: + +[**GMFlow: Learning Optical Flow via Global Matching**](https://arxiv.org/abs/2111.13680), **CVPR 2022, Oral** + +Authors: [Haofei Xu](https://haofeixu.github.io/), [Jing Zhang](https://scholar.google.com.hk/citations?user=9jH5v74AAAAJ), [Jianfei Cai](https://jianfei-cai.github.io/), [Hamid Rezatofighi](https://scholar.google.com/citations?user=VxAuxMwAAAAJ), [Dacheng Tao](https://scholar.google.com/citations?user=RwlJNLcAAAAJ) + + +**11/15/2022 Update: Check out our new work: [Unifying Flow, Stereo and Depth Estimation](https://haofeixu.github.io/unimatch/) and code: [unimatch](https://github.com/autonomousvision/unimatch) for extending GMFlow to stereo and depth tasks. [More pretrained GMFlow models](https://github.com/autonomousvision/unimatch/blob/master/MODEL_ZOO.md) with different speed-accuracy trade-offs are also released. Check out our [Colab](https://colab.research.google.com/drive/1r5m-xVy3Kw60U-m5VB-aQ98oqqg_6cab?usp=sharing) and [HuggingFace](https://huggingface.co/spaces/haofeixu/unimatch) demo to play with GMFlow in your browser!** + + + +**A [video introduction](https://www.bilibili.com/video/BV18A4y1R7PL) (in Chinese) of GMFlow is available at bilibili!** + + + +https://user-images.githubusercontent.com/19343475/174446408-520b8a6c-9714-4ff3-978c-98e23ab29c1f.mp4 + + + + + +We streamline the optical flow estimation pipeline by reformulating optical flow as a **global matching** problem. + + + + +

+ + + + + +## Highlights + +- **Flexible & Modular design** + + We decompose the end-to-end optical flow framework into five components: + + feature extraction, feature enhancement, feature matching, flow propagation and flow refinement. + + One can easily construct a customized optical flow model by combining different components. + +- **High accuracy** + + With only one refinement, GMFlow outperforms 31-refinements RAFT on the challenging Sintel benchmark. + +- **High efficiency** + + A basic GMFlow model (without refinement) runs at 57ms (V100) or 26ms (A100) for Sintel data (436x1024). + + GMFlow gains more speedup than RAFT on high-end GPUs (e.g., A100) since GMFlow doesn't require a large number of sequential computation. + + GMFlow also simplifies backward flow computation without requiring to forward the network twice. The bidirectional flow can be used for occlusion detection with forward-backward consistency check. + +

+ + + + +## Installation + +Our code is based on pytorch 1.9.0, CUDA 10.2 and python 3.8. Higher version pytorch should also work well. + +We recommend using [conda](https://www.anaconda.com/distribution/) for installation: + +``` +conda env create -f environment.yml +conda activate gmflow +``` + +## Demos + +All pretrained models can be downloaded from [google drive](https://drive.google.com/file/d/1d5C5cgHIxWGsFR1vYs5XrQbbUiZl9TX2/view?usp=sharing). + + + +You can run a trained model on a sequence of images and visualize the results: + +``` +CUDA_VISIBLE_DEVICES=0 python main.py \ +--inference_dir demo/sintel_market_1 \ +--output_path output/gmflow-norefine-sintel_market_1 \ +--resume pretrained/gmflow_sintel-0c07dcb3.pth +``` + +You can also predict bidirectional flow with `--pred_bidir_flow` enabled and use `--fwd_bwd_consistency_check` for forward-backward consistency check. More examples can be found in [scripts/demo.sh](scripts/demo.sh). + + + +## Datasets + +The datasets used to train and evaluate GMFlow are as follows: + +* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) +* [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) +* [Sintel](http://sintel.is.tue.mpg.de/) +* [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) +* [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) + +By default the dataloader [datasets.py](data/datasets.py) assumes the datasets are located in folder `datasets` and are organized as follows: + +``` +datasets +├── FlyingChairs_release +│   └── data +├── FlyingThings3D +│   ├── frames_cleanpass +│   ├── frames_finalpass +│   └── optical_flow +├── HD1K +│   ├── hd1k_challenge +│   ├── hd1k_flow_gt +│   ├── hd1k_flow_uncertainty +│   └── hd1k_input +├── KITTI +│   ├── testing +│   └── training +├── Sintel +│   ├── test +│   └── training +``` + +It is recommended to symlink your dataset root to `datasets`: + +```shell +ln -s $YOUR_DATASET_ROOT datasets +``` + +Otherwise, you may need to change the corresponding paths in [datasets.py](data/datasets.py). + + + +## Evaluation + +You can evaluate a trained GMFlow model by running: + +``` +CUDA_VISIBLE_DEVICES=0 python main.py --eval --val_dataset things sintel --resume pretrained/gmflow_things-e9887eda.pth +``` + +More evaluation scripts can be found in [scripts/evaluate.sh](scripts/evaluate.sh). + + + +For submission to Sintel and KITTI online test sets, you can run [scripts/submission.sh](scripts/submission.sh). + + + +## Training + +All training scripts on FlyingChairs, FlyingThings3D, Sintel and KITTI datasets can be found in [scripts/train_gmflow.sh](scripts/train_gmflow.sh) and [scripts/train_gmflow_with_refine.sh](scripts/train_gmflow_with_refine.sh). + +Note that the basic GMFlow model (without refinement) can be trained on 4x 16GB V100 GPUs. For training GMFlow with refinement, 8x 16GB V100 or 4x 32GB V100 or 4x 40GB A100 GPUs are required by default. You may need to tune the batch size and training iterations according to your hardware. + + + +We support using tensorboard to monitor and visualize the training process. You can first start a tensorboard session with + +```shell +tensorboard --logdir checkpoints +``` + +and then access [http://localhost:6006](http://localhost:6006) in your browser. + + + +## Citation + +If you find our work useful in your research, please consider citing our paper: + +``` +@inproceedings{xu2022gmflow, + title={GMFlow: Learning Optical Flow via Global Matching}, + author={Xu, Haofei and Zhang, Jing and Cai, Jianfei and Rezatofighi, Hamid and Tao, Dacheng}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={8121-8130}, + year={2022} +} +``` + + + +## Acknowledgements + +This project would not have been possible without relying on some awesome repos : [RAFT](https://github.com/princeton-vl/RAFT), [LoFTR](https://github.com/zju3dv/LoFTR), [DETR](https://github.com/facebookresearch/detr), [Swin](https://github.com/microsoft/Swin-Transformer), [mmdetection](https://github.com/open-mmlab/mmdetection) and [Detectron2](https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py). We thank the original authors for their excellent work. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/GMFlow/data/__init__.py b/GMFlow/data/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..895b3281e7af148f74ecbc13a16d329863aeb49a --- /dev/null +++ b/GMFlow/data/__init__.py @@ -0,0 +1,7 @@ +from .datasets import build_train_dataset +from .datasets import (FlyingChairs, + FlyingThings3D, + MpiSintel, + KITTI, + HD1K, + ) diff --git a/GMFlow/data/chairs_split.txt b/GMFlow/data/chairs_split.txt new file mode 100755 index 0000000000000000000000000000000000000000..6ae8f0b72a22fc061552604c94664e3a0287914e --- /dev/null +++ b/GMFlow/data/chairs_split.txt @@ -0,0 +1,22872 @@ +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +2 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +2 +2 +1 +1 +1 +1 +1 +1 +1 +2 +1 +1 +1 +1 +1 \ No newline at end of file diff --git a/GMFlow/data/datasets.py b/GMFlow/data/datasets.py new file mode 100755 index 0000000000000000000000000000000000000000..6e2f1584f9c013fb0e4d4ac331d856da363e0c9b --- /dev/null +++ b/GMFlow/data/datasets.py @@ -0,0 +1,312 @@ +# Data loading based on https://github.com/NVIDIA/flownet2-pytorch + +import numpy as np +import torch +import torch.utils.data as data + +import os +import random +from glob import glob +import os.path as osp + +from utils import frame_utils +from data.transforms import FlowAugmentor, SparseFlowAugmentor + + +class FlowDataset(data.Dataset): + def __init__(self, aug_params=None, sparse=False, + load_occlusion=False, + ): + self.augmentor = None + self.sparse = sparse + + if aug_params is not None: + if sparse: + self.augmentor = SparseFlowAugmentor(**aug_params) + else: + self.augmentor = FlowAugmentor(**aug_params) + + self.is_test = False + self.init_seed = False + self.flow_list = [] + self.image_list = [] + self.extra_info = [] + + self.load_occlusion = load_occlusion + self.occ_list = [] + + def __getitem__(self, index): + + if self.is_test: + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + img1 = np.array(img1).astype(np.uint8)[..., :3] + img2 = np.array(img2).astype(np.uint8)[..., :3] + + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + + return img1, img2, self.extra_info[index] + + if not self.init_seed: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + torch.manual_seed(worker_info.id) + np.random.seed(worker_info.id) + random.seed(worker_info.id) + self.init_seed = True + + index = index % len(self.image_list) + valid = None + + if self.sparse: + flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) # [H, W, 2], [H, W] + else: + flow = frame_utils.read_gen(self.flow_list[index]) + + if self.load_occlusion: + occlusion = frame_utils.read_gen(self.occ_list[index]) # [H, W], 0 or 255 (occluded) + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = np.array(flow).astype(np.float32) + img1 = np.array(img1).astype(np.uint8) + img2 = np.array(img2).astype(np.uint8) + + if self.load_occlusion: + occlusion = np.array(occlusion).astype(np.float32) + + # grayscale images + if len(img1.shape) == 2: + img1 = np.tile(img1[..., None], (1, 1, 3)) + img2 = np.tile(img2[..., None], (1, 1, 3)) + else: + img1 = img1[..., :3] + img2 = img2[..., :3] + + if self.augmentor is not None: + if self.sparse: + img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) + else: + if self.load_occlusion: + img1, img2, flow, occlusion = self.augmentor(img1, img2, flow, occlusion=occlusion) + else: + img1, img2, flow = self.augmentor(img1, img2, flow) + + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + flow = torch.from_numpy(flow).permute(2, 0, 1).float() + + if self.load_occlusion: + occlusion = torch.from_numpy(occlusion) # [H, W] + + if valid is not None: + valid = torch.from_numpy(valid) + else: + valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) + + # mask out occluded pixels + if self.load_occlusion: + # non-occlusion: 0, occlusion: 255 + noc_valid = 1 - occlusion / 255. # 0 or 1 + + return img1, img2, flow, valid.float(), noc_valid.float() + + return img1, img2, flow, valid.float() + + def __rmul__(self, v): + self.flow_list = v * self.flow_list + self.image_list = v * self.image_list + + return self + + def __len__(self): + return len(self.image_list) + + +class MpiSintel(FlowDataset): + def __init__(self, aug_params=None, split='training', + root='datasets/Sintel', + dstype='clean', + load_occlusion=False, + ): + super(MpiSintel, self).__init__(aug_params, + load_occlusion=load_occlusion, + ) + + flow_root = osp.join(root, split, 'flow') + image_root = osp.join(root, split, dstype) + + if load_occlusion: + occlusion_root = osp.join(root, split, 'occlusions') + + if split == 'test': + self.is_test = True + + for scene in os.listdir(image_root): + image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) + for i in range(len(image_list) - 1): + self.image_list += [[image_list[i], image_list[i + 1]]] + self.extra_info += [(scene, i)] # scene and frame_id + + if split != 'test': + self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) + + if load_occlusion: + self.occ_list += sorted(glob(osp.join(occlusion_root, scene, '*.png'))) + + +class FlyingChairs(FlowDataset): + def __init__(self, aug_params=None, split='train', + root='datasets/FlyingChairs_release/data', + ): + super(FlyingChairs, self).__init__(aug_params) + + images = sorted(glob(osp.join(root, '*.ppm'))) + flows = sorted(glob(osp.join(root, '*.flo'))) + assert (len(images) // 2 == len(flows)) + + split_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'chairs_split.txt') + split_list = np.loadtxt(split_file, dtype=np.int32) + for i in range(len(flows)): + xid = split_list[i] + if (split == 'training' and xid == 1) or (split == 'validation' and xid == 2): + self.flow_list += [flows[i]] + self.image_list += [[images[2 * i], images[2 * i + 1]]] + + +class FlyingThings3D(FlowDataset): + def __init__(self, aug_params=None, + root='datasets/FlyingThings3D', + dstype='frames_cleanpass', + test_set=False, + validate_subset=True, + ): + super(FlyingThings3D, self).__init__(aug_params) + + img_dir = root + flow_dir = root + + for cam in ['left']: + for direction in ['into_future', 'into_past']: + if test_set: + image_dirs = sorted(glob(osp.join(img_dir, dstype, 'TEST/*/*'))) + else: + image_dirs = sorted(glob(osp.join(img_dir, dstype, 'TRAIN/*/*'))) + image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) + + if test_set: + flow_dirs = sorted(glob(osp.join(flow_dir, 'optical_flow/TEST/*/*'))) + else: + flow_dirs = sorted(glob(osp.join(flow_dir, 'optical_flow/TRAIN/*/*'))) + flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) + + for idir, fdir in zip(image_dirs, flow_dirs): + images = sorted(glob(osp.join(idir, '*.png'))) + flows = sorted(glob(osp.join(fdir, '*.pfm'))) + for i in range(len(flows) - 1): + if direction == 'into_future': + self.image_list += [[images[i], images[i + 1]]] + self.flow_list += [flows[i]] + elif direction == 'into_past': + self.image_list += [[images[i + 1], images[i]]] + self.flow_list += [flows[i + 1]] + + # validate on 1024 subset of test set for fast speed + if test_set and validate_subset: + num_val_samples = 1024 + all_test_samples = len(self.image_list) # 7866 + + stride = all_test_samples // num_val_samples + remove = all_test_samples % num_val_samples + + # uniformly sample a subset + self.image_list = self.image_list[:-remove][::stride] + self.flow_list = self.flow_list[:-remove][::stride] + + +class KITTI(FlowDataset): + def __init__(self, aug_params=None, split='training', + root='datasets/KITTI', + ): + super(KITTI, self).__init__(aug_params, sparse=True, + ) + if split == 'testing': + self.is_test = True + + root = osp.join(root, split) + images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) + images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) + + for img1, img2 in zip(images1, images2): + frame_id = img1.split('/')[-1] + self.extra_info += [[frame_id]] + self.image_list += [[img1, img2]] + + if split == 'training': + self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) + + +class HD1K(FlowDataset): + def __init__(self, aug_params=None, root='datasets/HD1K'): + super(HD1K, self).__init__(aug_params, sparse=True) + + seq_ix = 0 + while 1: + flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) + images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) + + if len(flows) == 0: + break + + for i in range(len(flows) - 1): + self.flow_list += [flows[i]] + self.image_list += [[images[i], images[i + 1]]] + + seq_ix += 1 + + +def build_train_dataset(args): + """ Create the data loader for the corresponding training set """ + if args.stage == 'chairs': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} + + train_dataset = FlyingChairs(aug_params, split='training') + + elif args.stage == 'things': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} + + clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') + final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') + train_dataset = clean_dataset + final_dataset + + elif args.stage == 'sintel': + # 1041 pairs for clean and final each + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} + + things = FlyingThings3D(aug_params, dstype='frames_cleanpass') # 40302 + + sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') + sintel_final = MpiSintel(aug_params, split='training', dstype='final') + + aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True} + + kitti = KITTI(aug_params=aug_params) # 200 + + aug_params = {'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True} + + hd1k = HD1K(aug_params=aug_params) # 1047 + + train_dataset = 100 * sintel_clean + 100 * sintel_final + 200 * kitti + 5 * hd1k + things + + elif args.stage == 'kitti': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} + + train_dataset = KITTI(aug_params, split='training', + ) + else: + raise ValueError(f'stage {args.stage} is not supported') + + return train_dataset diff --git a/GMFlow/data/transforms.py b/GMFlow/data/transforms.py new file mode 100755 index 0000000000000000000000000000000000000000..5b1188f3833c97c50429dd5c9644fb5dab3166d7 --- /dev/null +++ b/GMFlow/data/transforms.py @@ -0,0 +1,284 @@ +import numpy as np +import cv2 +from PIL import Image +from torchvision.transforms import ColorJitter + + +class FlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True, + no_eraser_aug=True, + ): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14) + + self.asymmetric_color_aug_prob = 0.2 + + if no_eraser_aug: + # we disable eraser aug since no obvious improvement is observed in our experiments + self.eraser_aug_prob = -1 + else: + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + """ Photometric augmentation """ + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def eraser_transform(self, img1, img2, bounds=[50, 100]): + """ Occlusion augmentation """ + + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(bounds[0], bounds[1]) + dy = np.random.randint(bounds[0], bounds[1]) + img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color + + return img1, img2 + + def spatial_transform(self, img1, img2, flow, occlusion=None): + # randomly sample scale + ht, wd = img1.shape[:2] + + min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + + scale_x = np.clip(scale_x, min_scale, None) + scale_y = np.clip(scale_y, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = flow * [scale_x, scale_y] + + if occlusion is not None: + occlusion = cv2.resize(occlusion, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + + if self.do_flip: + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if occlusion is not None: + occlusion = occlusion[:, ::-1] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + if occlusion is not None: + occlusion = occlusion[::-1, :] + + # In case no cropping + if img1.shape[0] - self.crop_size[0] > 0: + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + else: + y0 = 0 + if img1.shape[1] - self.crop_size[1] > 0: + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + else: + x0 = 0 + + img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + + if occlusion is not None: + occlusion = occlusion[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + return img1, img2, flow, occlusion + + return img1, img2, flow + + def __call__(self, img1, img2, flow, occlusion=None): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + + if occlusion is not None: + img1, img2, flow, occlusion = self.spatial_transform( + img1, img2, flow, occlusion) + else: + img1, img2, flow = self.spatial_transform(img1, img2, flow) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + + if occlusion is not None: + occlusion = np.ascontiguousarray(occlusion) + return img1, img2, flow, occlusion + + return img1, img2, flow + + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False, + no_eraser_aug=True, + ): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14) + self.asymmetric_color_aug_prob = 0.2 + + if no_eraser_aug: + # we disable eraser aug since no obvious improvement is observed in our experiments + self.eraser_aug_prob = -1 + else: + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid >= 1] + flow0 = flow[valid >= 1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:, 0]).astype(np.int32) + yy = np.round(coords1[:, 1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), + (self.crop_size[1] + 1) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + + flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + valid = valid[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] + return img1, img2, flow, valid + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + + img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid diff --git a/GMFlow/environment.yml b/GMFlow/environment.yml new file mode 100755 index 0000000000000000000000000000000000000000..f7e6fd86e66d7b5fad3a38aeb8c6ae02528ca439 --- /dev/null +++ b/GMFlow/environment.yml @@ -0,0 +1,162 @@ +name: gmflow +channels: + - pytorch + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=4.5=1_gnu + - blas=1.0=mkl + - bottleneck=1.3.2=py38heb32a55_1 + - brotli=1.0.9=he6710b0_2 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2021.10.26=h06a4308_2 + - certifi=2021.10.8=py38h06a4308_2 + - cudatoolkit=10.2.89=hfd86e86_1 + - cycler=0.10.0=py38_0 + - dbus=1.13.18=hb2f20db_0 + - expat=2.4.1=h2531618_2 + - ffmpeg=4.3=hf484d3e_0 + - fontconfig=2.13.1=h6c09931_0 + - fonttools=4.25.0=pyhd3eb1b0_0 + - freetype=2.10.4=h5ab3b9f_0 + - glib=2.69.0=h5202010_0 + - gmp=6.2.1=h2531618_2 + - gnutls=3.6.15=he1e5248_0 + - gst-plugins-base=1.14.0=h8213a91_2 + - gstreamer=1.14.0=h28cd5cc_2 + - icu=58.2=he6710b0_3 + - imageio=2.9.0=pyhd3eb1b0_0 + - intel-openmp=2021.3.0=h06a4308_3350 + - jpeg=9b=h024ee3a_2 + - kiwisolver=1.3.1=py38h2531618_0 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.35.1=h7274673_9 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.3.0=h5101ec6_17 + - libgfortran-ng=7.5.0=ha8ba4b0_17 + - libgfortran4=7.5.0=ha8ba4b0_17 + - libgomp=9.3.0=h5101ec6_17 + - libiconv=1.15=h63c8f33_5 + - libidn2=2.3.2=h7f8727e_0 + - libpng=1.6.37=hbc83047_0 + - libstdcxx-ng=9.3.0=hd4cf53a_17 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.2.0=h85742a9_0 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=1.0.3=h1bed415_2 + - libuv=1.40.0=h7b6447c_0 + - libwebp-base=1.2.0=h27cfd23_0 + - libxcb=1.14=h7b6447c_0 + - libxml2=2.9.12=h03d6c58_0 + - lz4-c=1.9.3=h2531618_0 + - matplotlib=3.4.2=py38h06a4308_0 + - matplotlib-base=3.4.2=py38hab158f2_0 + - mkl=2021.3.0=h06a4308_520 + - mkl-service=2.4.0=py38h7f8727e_0 + - mkl_fft=1.3.0=py38h42c9631_2 + - mkl_random=1.2.2=py38h51133e4_0 + - munkres=1.1.4=py_0 + - ncurses=6.2=he6710b0_1 + - nettle=3.7.3=hbbd107a_1 + - ninja=1.10.2=hff7bd54_1 + - numexpr=2.7.3=py38h22e1b3c_1 + - numpy=1.20.3=py38hf144106_0 + - numpy-base=1.20.3=py38h74d4b33_0 + - olefile=0.46=py_0 + - openh264=2.1.0=hd408876_0 + - openjpeg=2.3.0=h05c96fa_1 + - openssl=1.1.1m=h7f8727e_0 + - pandas=1.3.2=py38h8c16a72_0 + - pcre=8.45=h295c915_0 + - pillow=8.3.1=py38h2c7a002_0 + - pip=21.2.2=py38h06a4308_0 + - pyparsing=2.4.7=pyhd3eb1b0_0 + - pyqt=5.9.2=py38h05f1152_4 + - python=3.8.11=h12debd9_0_cpython + - python-dateutil=2.8.2=pyhd3eb1b0_0 + - pytorch=1.9.0=py3.8_cuda10.2_cudnn7.6.5_0 + - pytz=2021.1=pyhd3eb1b0_0 + - qt=5.9.7=h5867ecd_1 + - readline=8.1=h27cfd23_0 + - scipy=1.6.2=py38had2a1c9_1 + - seaborn=0.11.2=pyhd3eb1b0_0 + - setuptools=52.0.0=py38h06a4308_0 + - sip=4.19.13=py38he6710b0_0 + - six=1.16.0=pyhd3eb1b0_0 + - sqlite=3.36.0=hc218d9a_0 + - tk=8.6.10=hbc83047_0 + - torchaudio=0.9.0=py38 + - torchvision=0.10.0=py38_cu102 + - tornado=6.1=py38h27cfd23_0 + - typing_extensions=3.10.0.0=pyh06a4308_0 + - wheel=0.36.2=pyhd3eb1b0_0 + - xz=5.2.5=h7b6447c_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.9=haebb681_0 + - pip: + - absl-py==0.13.0 + - argon2-cffi==21.1.0 + - attrs==21.2.0 + - backcall==0.2.0 + - bleach==4.1.0 + - cachetools==4.2.2 + - cffi==1.14.6 + - charset-normalizer==2.0.4 + - debugpy==1.4.3 + - decorator==5.1.0 + - defusedxml==0.7.1 + - einops==0.3.2 + - entrypoints==0.3 + - google-auth==1.34.0 + - google-auth-oauthlib==0.4.5 + - grpcio==1.39.0 + - idna==3.2 + - ipykernel==6.4.1 + - ipython==7.27.0 + - ipython-genutils==0.2.0 + - jedi==0.18.0 + - jinja2==3.0.1 + - jsonschema==3.2.0 + - jupyter-client==7.0.3 + - jupyter-core==4.8.1 + - jupyterlab-pygments==0.1.2 + - markdown==3.3.4 + - markupsafe==2.0.1 + - matplotlib-inline==0.1.3 + - mistune==0.8.4 + - nbclient==0.5.4 + - nbconvert==6.1.0 + - nbformat==5.1.3 + - nest-asyncio==1.5.1 + - oauthlib==3.1.1 + - opencv-python==4.5.3.56 + - packaging==21.0 + - pandocfilters==1.5.0 + - parso==0.8.2 + - pexpect==4.8.0 + - pickleshare==0.7.5 + - prometheus-client==0.11.0 + - prompt-toolkit==3.0.20 + - protobuf==3.17.3 + - ptyprocess==0.7.0 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pycparser==2.20 + - pygments==2.10.0 + - pyrsistent==0.18.0 + - pyzmq==22.3.0 + - requests==2.26.0 + - requests-oauthlib==1.3.0 + - rsa==4.7.2 + - send2trash==1.8.0 + - tensorboard==2.5.0 + - tensorboard-data-server==0.6.1 + - tensorboard-plugin-wit==1.8.0 + - terminado==0.12.1 + - testpath==0.5.0 + - traitlets==5.1.0 + - urllib3==1.26.6 + - wcwidth==0.2.5 + - webencodings==0.5.1 + - werkzeug==2.0.1 diff --git a/GMFlow/evaluate.py b/GMFlow/evaluate.py new file mode 100755 index 0000000000000000000000000000000000000000..be6a3f53d009c843e05a6a5dd75ec9034788b29b --- /dev/null +++ b/GMFlow/evaluate.py @@ -0,0 +1,689 @@ +from PIL import Image +import os +import time +import numpy as np +import torch +import torch.nn.functional as F + +import data +from utils import frame_utils +from utils.flow_viz import save_vis_flow_tofile + +from utils.utils import InputPadder, compute_out_of_boundary_mask +from glob import glob +from gmflow.geometry import forward_backward_consistency_check + + +@torch.no_grad() +def create_sintel_submission(model, + output_path='sintel_submission', + padding_factor=8, + save_vis_flow=False, + no_save_flo=False, + attn_splits_list=None, + corr_radius_list=None, + prop_radius_list=None, + ): + """ Create submission for the Sintel leaderboard """ + model.eval() + for dstype in ['clean', 'final']: + test_dataset = data.MpiSintel(split='test', aug_params=None, dstype=dstype) + + flow_prev, sequence_prev = None, None + for test_id in range(len(test_dataset)): + image1, image2, (sequence, frame) = test_dataset[test_id] + if sequence != sequence_prev: + flow_prev = None + + padder = InputPadder(image1.shape, padding_factor=padding_factor) + image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) + + results_dict = model(image1, image2, + attn_splits_list=attn_splits_list, + corr_radius_list=corr_radius_list, + prop_radius_list=prop_radius_list, + ) + + flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W] + + flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() + + output_dir = os.path.join(output_path, dstype, sequence) + output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame + 1)) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + if not no_save_flo: + frame_utils.writeFlow(output_file, flow) + sequence_prev = sequence + + # Save vis flow + if save_vis_flow: + vis_flow_file = output_file.replace('.flo', '.png') + save_vis_flow_tofile(flow, vis_flow_file) + + +@torch.no_grad() +def create_kitti_submission(model, + output_path='kitti_submission', + padding_factor=8, + save_vis_flow=False, + attn_splits_list=None, + corr_radius_list=None, + prop_radius_list=None, + ): + """ Create submission for the Sintel leaderboard """ + model.eval() + test_dataset = data.KITTI(split='testing', aug_params=None) + + if not os.path.exists(output_path): + os.makedirs(output_path) + + for test_id in range(len(test_dataset)): + image1, image2, (frame_id,) = test_dataset[test_id] + padder = InputPadder(image1.shape, mode='kitti', padding_factor=padding_factor) + image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) + + results_dict = model(image1, image2, + attn_splits_list=attn_splits_list, + corr_radius_list=corr_radius_list, + prop_radius_list=prop_radius_list, + ) + + flow_pr = results_dict['flow_preds'][-1] + + flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() + + output_filename = os.path.join(output_path, frame_id) + + if save_vis_flow: + vis_flow_file = output_filename + save_vis_flow_tofile(flow, vis_flow_file) + else: + frame_utils.writeFlowKITTI(output_filename, flow) + + +@torch.no_grad() +def validate_chairs(model, + with_speed_metric=False, + attn_splits_list=False, + corr_radius_list=False, + prop_radius_list=False, + ): + """ Perform evaluation on the FlyingChairs (test) split """ + model.eval() + epe_list = [] + results = {} + + if with_speed_metric: + s0_10_list = [] + s10_40_list = [] + s40plus_list = [] + + val_dataset = data.FlyingChairs(split='validation') + + print('Number of validation image pairs: %d' % len(val_dataset)) + + for val_id in range(len(val_dataset)): + image1, image2, flow_gt, _ = val_dataset[val_id] + + image1 = image1[None].cuda() + image2 = image2[None].cuda() + + results_dict = model(image1, image2, + attn_splits_list=attn_splits_list, + corr_radius_list=corr_radius_list, + prop_radius_list=prop_radius_list, + ) + + flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W] + + assert flow_pr.size()[-2:] == flow_gt.size()[-2:] + + epe = torch.sum((flow_pr[0].cpu() - flow_gt) ** 2, dim=0).sqrt() + epe_list.append(epe.view(-1).numpy()) + + if with_speed_metric: + flow_gt_speed = torch.sum(flow_gt ** 2, dim=0).sqrt() + valid_mask = (flow_gt_speed < 10) + if valid_mask.max() > 0: + s0_10_list.append(epe[valid_mask].cpu().numpy()) + + valid_mask = (flow_gt_speed >= 10) * (flow_gt_speed <= 40) + if valid_mask.max() > 0: + s10_40_list.append(epe[valid_mask].cpu().numpy()) + + valid_mask = (flow_gt_speed > 40) + if valid_mask.max() > 0: + s40plus_list.append(epe[valid_mask].cpu().numpy()) + + epe_all = np.concatenate(epe_list) + epe = np.mean(epe_all) + px1 = np.mean(epe_all > 1) + px3 = np.mean(epe_all > 3) + px5 = np.mean(epe_all > 5) + print("Validation Chairs EPE: %.3f, 1px: %.3f, 3px: %.3f, 5px: %.3f" % (epe, px1, px3, px5)) + results['chairs_epe'] = epe + results['chairs_1px'] = px1 + results['chairs_3px'] = px3 + results['chairs_5px'] = px5 + + if with_speed_metric: + s0_10 = np.mean(np.concatenate(s0_10_list)) + s10_40 = np.mean(np.concatenate(s10_40_list)) + s40plus = np.mean(np.concatenate(s40plus_list)) + + print("Validation Chairs s0_10: %.3f, s10_40: %.3f, s40+: %.3f" % ( + s0_10, + s10_40, + s40plus)) + + results['chairs_s0_10'] = s0_10 + results['chairs_s10_40'] = s10_40 + results['chairs_s40+'] = s40plus + + return results + + +@torch.no_grad() +def validate_things(model, + padding_factor=8, + with_speed_metric=False, + max_val_flow=400, + val_things_clean_only=True, + attn_splits_list=False, + corr_radius_list=False, + prop_radius_list=False, + ): + """ Peform validation using the Things (test) split """ + model.eval() + results = {} + + for dstype in ['frames_cleanpass', 'frames_finalpass']: + if val_things_clean_only: + if dstype == 'frames_finalpass': + continue + + val_dataset = data.FlyingThings3D(dstype=dstype, test_set=True, validate_subset=True, + ) + print('Number of validation image pairs: %d' % len(val_dataset)) + epe_list = [] + + if with_speed_metric: + s0_10_list = [] + s10_40_list = [] + s40plus_list = [] + + for val_id in range(len(val_dataset)): + image1, image2, flow_gt, valid_gt = val_dataset[val_id] + image1 = image1[None].cuda() + image2 = image2[None].cuda() + + padder = InputPadder(image1.shape, padding_factor=padding_factor) + image1, image2 = padder.pad(image1, image2) + + results_dict = model(image1, image2, + attn_splits_list=attn_splits_list, + corr_radius_list=corr_radius_list, + prop_radius_list=prop_radius_list, + ) + flow_pr = results_dict['flow_preds'][-1] + + flow = padder.unpad(flow_pr[0]).cpu() + + # Evaluation on flow <= max_val_flow + flow_gt_speed = torch.sum(flow_gt ** 2, dim=0).sqrt() + valid_gt = valid_gt * (flow_gt_speed < max_val_flow) + valid_gt = valid_gt.contiguous() + + epe = torch.sum((flow - flow_gt) ** 2, dim=0).sqrt() + val = valid_gt >= 0.5 + epe_list.append(epe[val].cpu().numpy()) + + if with_speed_metric: + valid_mask = (flow_gt_speed < 10) * (valid_gt >= 0.5) + if valid_mask.max() > 0: + s0_10_list.append(epe[valid_mask].cpu().numpy()) + + valid_mask = (flow_gt_speed >= 10) * (flow_gt_speed <= 40) * (valid_gt >= 0.5) + if valid_mask.max() > 0: + s10_40_list.append(epe[valid_mask].cpu().numpy()) + + valid_mask = (flow_gt_speed > 40) * (valid_gt >= 0.5) + if valid_mask.max() > 0: + s40plus_list.append(epe[valid_mask].cpu().numpy()) + + epe_list = np.mean(np.concatenate(epe_list)) + + epe = np.mean(epe_list) + + if dstype == 'frames_cleanpass': + dstype = 'things_clean' + if dstype == 'frames_finalpass': + dstype = 'things_final' + + print("Validation Things test set (%s) EPE: %.3f" % (dstype, epe)) + results[dstype + '_epe'] = epe + + if with_speed_metric: + s0_10 = np.mean(np.concatenate(s0_10_list)) + s10_40 = np.mean(np.concatenate(s10_40_list)) + s40plus = np.mean(np.concatenate(s40plus_list)) + + print("Validation Things test (%s) s0_10: %.3f, s10_40: %.3f, s40+: %.3f" % ( + dstype, s0_10, + s10_40, + s40plus)) + + results[dstype + '_s0_10'] = s0_10 + results[dstype + '_s10_40'] = s10_40 + results[dstype + '_s40+'] = s40plus + + return results + + +@torch.no_grad() +def validate_sintel(model, + count_time=False, + padding_factor=8, + with_speed_metric=False, + evaluate_matched_unmatched=False, + attn_splits_list=False, + corr_radius_list=False, + prop_radius_list=False, + ): + """ Peform validation using the Sintel (train) split """ + model.eval() + results = {} + + if count_time: + total_time = 0 + num_runs = 100 + + for dstype in ['clean', 'final']: + val_dataset = data.MpiSintel(split='training', dstype=dstype, + load_occlusion=evaluate_matched_unmatched, + ) + + print('Number of validation image pairs: %d' % len(val_dataset)) + epe_list = [] + + if evaluate_matched_unmatched: + matched_epe_list = [] + unmatched_epe_list = [] + + if with_speed_metric: + s0_10_list = [] + s10_40_list = [] + s40plus_list = [] + + for val_id in range(len(val_dataset)): + if evaluate_matched_unmatched: + image1, image2, flow_gt, valid, noc_valid = val_dataset[val_id] + + # compuate in-image-plane valid mask + in_image_valid = compute_out_of_boundary_mask(flow_gt.unsqueeze(0)).squeeze(0) # [H, W] + + else: + image1, image2, flow_gt, _ = val_dataset[val_id] + + image1 = image1[None].cuda() + image2 = image2[None].cuda() + + padder = InputPadder(image1.shape, padding_factor=padding_factor) + image1, image2 = padder.pad(image1, image2) + + if count_time and val_id >= 5: # 5 warmup + torch.cuda.synchronize() + time_start = time.perf_counter() + + results_dict = model(image1, image2, + attn_splits_list=attn_splits_list, + corr_radius_list=corr_radius_list, + prop_radius_list=prop_radius_list, + ) + + # useful when using parallel branches + flow_pr = results_dict['flow_preds'][-1] + + if count_time and val_id >= 5: + torch.cuda.synchronize() + total_time += time.perf_counter() - time_start + + if val_id >= num_runs + 4: + break + + flow = padder.unpad(flow_pr[0]).cpu() + + epe = torch.sum((flow - flow_gt) ** 2, dim=0).sqrt() + epe_list.append(epe.view(-1).numpy()) + + if evaluate_matched_unmatched: + matched_valid_mask = (noc_valid > 0.5) & (in_image_valid > 0.5) + + if matched_valid_mask.max() > 0: + matched_epe_list.append(epe[matched_valid_mask].cpu().numpy()) + unmatched_epe_list.append(epe[~matched_valid_mask].cpu().numpy()) + + if with_speed_metric: + flow_gt_speed = torch.sum(flow_gt ** 2, dim=0).sqrt() + valid_mask = (flow_gt_speed < 10) + if valid_mask.max() > 0: + s0_10_list.append(epe[valid_mask].cpu().numpy()) + + valid_mask = (flow_gt_speed >= 10) * (flow_gt_speed <= 40) + if valid_mask.max() > 0: + s10_40_list.append(epe[valid_mask].cpu().numpy()) + + valid_mask = (flow_gt_speed > 40) + if valid_mask.max() > 0: + s40plus_list.append(epe[valid_mask].cpu().numpy()) + + epe_all = np.concatenate(epe_list) + epe = np.mean(epe_all) + px1 = np.mean(epe_all > 1) + px3 = np.mean(epe_all > 3) + px5 = np.mean(epe_all > 5) + + dstype_ori = dstype + + print("Validation Sintel (%s) EPE: %.3f, 1px: %.3f, 3px: %.3f, 5px: %.3f" % (dstype_ori, epe, px1, px3, px5)) + + dstype = 'sintel_' + dstype + + results[dstype + '_epe'] = np.mean(epe_list) + results[dstype + '_1px'] = px1 + results[dstype + '_3px'] = px3 + results[dstype + '_5px'] = px5 + + if with_speed_metric: + s0_10 = np.mean(np.concatenate(s0_10_list)) + s10_40 = np.mean(np.concatenate(s10_40_list)) + s40plus = np.mean(np.concatenate(s40plus_list)) + + print("Validation Sintel (%s) s0_10: %.3f, s10_40: %.3f, s40+: %.3f" % ( + dstype_ori, s0_10, + s10_40, + s40plus)) + + results[dstype + '_s0_10'] = s0_10 + results[dstype + '_s10_40'] = s10_40 + results[dstype + '_s40+'] = s40plus + + if count_time: + print('Time: %.6fs' % (total_time / num_runs)) + break # only the clean pass when counting time + + if evaluate_matched_unmatched: + matched_epe = np.mean(np.concatenate(matched_epe_list)) + unmatched_epe = np.mean(np.concatenate(unmatched_epe_list)) + + print('Validatation Sintel (%s) matched epe: %.3f, unmatched epe: %.3f' % ( + dstype_ori, matched_epe, unmatched_epe)) + + results[dstype + '_matched'] = matched_epe + results[dstype + '_unmatched'] = unmatched_epe + + return results + + +@torch.no_grad() +def validate_kitti(model, + padding_factor=8, + with_speed_metric=False, + average_over_pixels=True, + attn_splits_list=False, + corr_radius_list=False, + prop_radius_list=False, + ): + """ Peform validation using the KITTI-2015 (train) split """ + model.eval() + + val_dataset = data.KITTI(split='training') + print('Number of validation image pairs: %d' % len(val_dataset)) + + out_list, epe_list = [], [] + results = {} + + if with_speed_metric: + if average_over_pixels: + s0_10_list = [] + s10_40_list = [] + s40plus_list = [] + else: + s0_10_epe_sum = 0 + s0_10_valid_samples = 0 + s10_40_epe_sum = 0 + s10_40_valid_samples = 0 + s40plus_epe_sum = 0 + s40plus_valid_samples = 0 + + for val_id in range(len(val_dataset)): + image1, image2, flow_gt, valid_gt = val_dataset[val_id] + image1 = image1[None].cuda() + image2 = image2[None].cuda() + + padder = InputPadder(image1.shape, mode='kitti', padding_factor=padding_factor) + image1, image2 = padder.pad(image1, image2) + + results_dict = model(image1, image2, + attn_splits_list=attn_splits_list, + corr_radius_list=corr_radius_list, + prop_radius_list=prop_radius_list, + ) + + # useful when using parallel branches + flow_pr = results_dict['flow_preds'][-1] + + flow = padder.unpad(flow_pr[0]).cpu() + + epe = torch.sum((flow - flow_gt) ** 2, dim=0).sqrt() + mag = torch.sum(flow_gt ** 2, dim=0).sqrt() + + if with_speed_metric: + # flow_gt_speed = torch.sum(flow_gt ** 2, dim=0).sqrt() + flow_gt_speed = mag + + if average_over_pixels: + valid_mask = (flow_gt_speed < 10) * (valid_gt >= 0.5) # note KITTI GT is sparse + if valid_mask.max() > 0: + s0_10_list.append(epe[valid_mask].cpu().numpy()) + + valid_mask = (flow_gt_speed >= 10) * (flow_gt_speed <= 40) * (valid_gt >= 0.5) + if valid_mask.max() > 0: + s10_40_list.append(epe[valid_mask].cpu().numpy()) + + valid_mask = (flow_gt_speed > 40) * (valid_gt >= 0.5) + if valid_mask.max() > 0: + s40plus_list.append(epe[valid_mask].cpu().numpy()) + + else: + valid_mask = (flow_gt_speed < 10) * (valid_gt >= 0.5) # note KITTI GT is sparse + if valid_mask.max() > 0: + s0_10_epe_sum += (epe * valid_mask).sum() / valid_mask.sum() + s0_10_valid_samples += 1 + + valid_mask = (flow_gt_speed >= 10) * (flow_gt_speed <= 40) * (valid_gt >= 0.5) + if valid_mask.max() > 0: + s10_40_epe_sum += (epe * valid_mask).sum() / valid_mask.sum() + s10_40_valid_samples += 1 + + valid_mask = (flow_gt_speed > 40) * (valid_gt >= 0.5) + if valid_mask.max() > 0: + s40plus_epe_sum += (epe * valid_mask).sum() / valid_mask.sum() + s40plus_valid_samples += 1 + + epe = epe.view(-1) + mag = mag.view(-1) + val = valid_gt.view(-1) >= 0.5 + + out = ((epe > 3.0) & ((epe / mag) > 0.05)).float() + + if average_over_pixels: + epe_list.append(epe[val].cpu().numpy()) + else: + epe_list.append(epe[val].mean().item()) + + out_list.append(out[val].cpu().numpy()) + + if average_over_pixels: + epe_list = np.concatenate(epe_list) + else: + epe_list = np.array(epe_list) + out_list = np.concatenate(out_list) + + epe = np.mean(epe_list) + f1 = 100 * np.mean(out_list) + + print("Validation KITTI EPE: %.3f, F1-all: %.3f" % (epe, f1)) + results['kitti_epe'] = epe + results['kitti_f1'] = f1 + + if with_speed_metric: + if average_over_pixels: + s0_10 = np.mean(np.concatenate(s0_10_list)) + s10_40 = np.mean(np.concatenate(s10_40_list)) + s40plus = np.mean(np.concatenate(s40plus_list)) + else: + s0_10 = s0_10_epe_sum / s0_10_valid_samples + s10_40 = s10_40_epe_sum / s10_40_valid_samples + s40plus = s40plus_epe_sum / s40plus_valid_samples + + print("Validation KITTI s0_10: %.3f, s10_40: %.3f, s40+: %.3f" % ( + s0_10, + s10_40, + s40plus)) + + results['kitti_s0_10'] = s0_10 + results['kitti_s10_40'] = s10_40 + results['kitti_s40+'] = s40plus + + return results + + +@torch.no_grad() +def inference_on_dir(model, + inference_dir, + output_path='output', + padding_factor=8, + inference_size=None, + paired_data=False, # dir of paired testdata instead of a sequence + save_flo_flow=False, # save as .flo for quantative evaluation + attn_splits_list=None, + corr_radius_list=None, + prop_radius_list=None, + pred_bidir_flow=False, + fwd_bwd_consistency_check=False, + ): + """ Inference on a directory """ + model.eval() + + if fwd_bwd_consistency_check: + assert pred_bidir_flow + + if not os.path.exists(output_path): + os.makedirs(output_path) + + filenames = sorted(glob(inference_dir + '/*')) + print('%d images found' % len(filenames)) + + stride = 2 if paired_data else 1 + + if paired_data: + assert len(filenames) % 2 == 0 + + for test_id in range(0, len(filenames) - 1, stride): + + image1 = frame_utils.read_gen(filenames[test_id]) + image2 = frame_utils.read_gen(filenames[test_id + 1]) + + image1 = np.array(image1).astype(np.uint8) + image2 = np.array(image2).astype(np.uint8) + + if len(image1.shape) == 2: # gray image, for example, HD1K + image1 = np.tile(image1[..., None], (1, 1, 3)) + image2 = np.tile(image2[..., None], (1, 1, 3)) + else: + image1 = image1[..., :3] + image2 = image2[..., :3] + + image1 = torch.from_numpy(image1).permute(2, 0, 1).float() + image2 = torch.from_numpy(image2).permute(2, 0, 1).float() + + if inference_size is None: + padder = InputPadder(image1.shape, padding_factor=padding_factor) + image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) + else: + image1, image2 = image1[None].cuda(), image2[None].cuda() + + # resize before inference + if inference_size is not None: + assert isinstance(inference_size, list) or isinstance(inference_size, tuple) + ori_size = image1.shape[-2:] + image1 = F.interpolate(image1, size=inference_size, mode='bilinear', + align_corners=True) + image2 = F.interpolate(image2, size=inference_size, mode='bilinear', + align_corners=True) + + results_dict = model(image1, image2, + attn_splits_list=attn_splits_list, + corr_radius_list=corr_radius_list, + prop_radius_list=prop_radius_list, + pred_bidir_flow=pred_bidir_flow, + ) + + flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W] + + # resize back + if inference_size is not None: + flow_pr = F.interpolate(flow_pr, size=ori_size, mode='bilinear', + align_corners=True) + flow_pr[:, 0] = flow_pr[:, 0] * ori_size[-1] / inference_size[-1] + flow_pr[:, 1] = flow_pr[:, 1] * ori_size[-2] / inference_size[-2] + + if inference_size is None: + flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy() # [H, W, 2] + else: + flow = flow_pr[0].permute(1, 2, 0).cpu().numpy() # [H, W, 2] + + output_file = os.path.join(output_path, os.path.basename(filenames[test_id])[:-4] + '_flow.png') + + # save vis flow + save_vis_flow_tofile(flow, output_file) + + # also predict backward flow + if pred_bidir_flow: + assert flow_pr.size(0) == 2 # [2, H, W, 2] + + if inference_size is None: + flow_bwd = padder.unpad(flow_pr[1]).permute(1, 2, 0).cpu().numpy() # [H, W, 2] + else: + flow_bwd = flow_pr[1].permute(1, 2, 0).cpu().numpy() # [H, W, 2] + + output_file = os.path.join(output_path, os.path.basename(filenames[test_id])[:-4] + '_flow_bwd.png') + + # save vis flow + save_vis_flow_tofile(flow_bwd, output_file) + + # forward-backward consistency check + # occlusion is 1 + if fwd_bwd_consistency_check: + if inference_size is None: + fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0) # [1, 2, H, W] + bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0) # [1, 2, H, W] + else: + fwd_flow = flow_pr[0].unsqueeze(0) + bwd_flow = flow_pr[1].unsqueeze(0) + + fwd_occ, bwd_occ = forward_backward_consistency_check(fwd_flow, bwd_flow) # [1, H, W] float + + fwd_occ_file = os.path.join(output_path, os.path.basename(filenames[test_id])[:-4] + '_occ.png') + bwd_occ_file = os.path.join(output_path, os.path.basename(filenames[test_id])[:-4] + '_occ_bwd.png') + + Image.fromarray((fwd_occ[0].cpu().numpy() * 255.).astype(np.uint8)).save(fwd_occ_file) + Image.fromarray((bwd_occ[0].cpu().numpy() * 255.).astype(np.uint8)).save(bwd_occ_file) + + if save_flo_flow: + output_file = os.path.join(output_path, os.path.basename(filenames[test_id])[:-4] + '_pred.flo') + frame_utils.writeFlow(output_file, flow) diff --git a/GMFlow/gmflow/__init__.py b/GMFlow/gmflow/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/GMFlow/gmflow/__pycache__/__init__.cpython-310.pyc b/GMFlow/gmflow/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e2c0447e0a009abcf60240f1d1e93cb8ba60c21 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/__init__.cpython-310.pyc differ diff --git a/GMFlow/gmflow/__pycache__/__init__.cpython-38.pyc b/GMFlow/gmflow/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55e7230bf7f289a37146f01b18dd92ddbaa45d44 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/__init__.cpython-38.pyc differ diff --git a/GMFlow/gmflow/__pycache__/__init__.cpython-39.pyc b/GMFlow/gmflow/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d709828469629d7b023ad13d9964a55f53266bb1 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/__init__.cpython-39.pyc differ diff --git a/GMFlow/gmflow/__pycache__/backbone.cpython-310.pyc b/GMFlow/gmflow/__pycache__/backbone.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c97074ab83bd528d114930fdd38b2774b1d22ddf Binary files /dev/null and b/GMFlow/gmflow/__pycache__/backbone.cpython-310.pyc differ diff --git a/GMFlow/gmflow/__pycache__/backbone.cpython-38.pyc b/GMFlow/gmflow/__pycache__/backbone.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd393f1e19b5dbc6656ea93079a0d882e8e7ff80 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/backbone.cpython-38.pyc differ diff --git a/GMFlow/gmflow/__pycache__/backbone.cpython-39.pyc b/GMFlow/gmflow/__pycache__/backbone.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97f0a3ce7838e9ec61b0c2b8af04c67d1cb8aef1 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/backbone.cpython-39.pyc differ diff --git a/GMFlow/gmflow/__pycache__/geometry.cpython-310.pyc b/GMFlow/gmflow/__pycache__/geometry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b48b20412a527305bfd9f25a43c42d8abf710c5 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/geometry.cpython-310.pyc differ diff --git a/GMFlow/gmflow/__pycache__/geometry.cpython-38.pyc b/GMFlow/gmflow/__pycache__/geometry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..944abfe9fe9fd1cf13af905af71773acf3cb15b3 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/geometry.cpython-38.pyc differ diff --git a/GMFlow/gmflow/__pycache__/geometry.cpython-39.pyc b/GMFlow/gmflow/__pycache__/geometry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5288e71586db6bf17fb953e0a51134e0db99dec Binary files /dev/null and b/GMFlow/gmflow/__pycache__/geometry.cpython-39.pyc differ diff --git a/GMFlow/gmflow/__pycache__/gmflow.cpython-310.pyc b/GMFlow/gmflow/__pycache__/gmflow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f488bbd9b95d56387864a06e9c3a2cd83dc7513 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/gmflow.cpython-310.pyc differ diff --git a/GMFlow/gmflow/__pycache__/gmflow.cpython-38.pyc b/GMFlow/gmflow/__pycache__/gmflow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47a9023804c0b1f08b5f6a546023da4988b0b357 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/gmflow.cpython-38.pyc differ diff --git a/GMFlow/gmflow/__pycache__/gmflow.cpython-39.pyc b/GMFlow/gmflow/__pycache__/gmflow.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..115886a97c9c986e39d16d9bb8bd7f295bb350c5 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/gmflow.cpython-39.pyc differ diff --git a/GMFlow/gmflow/__pycache__/matching.cpython-310.pyc b/GMFlow/gmflow/__pycache__/matching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..862b1c8259396bec67bdeeac72136f38a0819f6a Binary files /dev/null and b/GMFlow/gmflow/__pycache__/matching.cpython-310.pyc differ diff --git a/GMFlow/gmflow/__pycache__/matching.cpython-38.pyc b/GMFlow/gmflow/__pycache__/matching.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca828908e35aa44d5a382c3170c5d733084b770b Binary files /dev/null and b/GMFlow/gmflow/__pycache__/matching.cpython-38.pyc differ diff --git a/GMFlow/gmflow/__pycache__/matching.cpython-39.pyc b/GMFlow/gmflow/__pycache__/matching.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f07bf70d96efddcd56fd159bf9e8ff69cdd5bf7f Binary files /dev/null and b/GMFlow/gmflow/__pycache__/matching.cpython-39.pyc differ diff --git a/GMFlow/gmflow/__pycache__/position.cpython-310.pyc b/GMFlow/gmflow/__pycache__/position.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30c4400f36e48a3dabe5a9183d40220251a5aebc Binary files /dev/null and b/GMFlow/gmflow/__pycache__/position.cpython-310.pyc differ diff --git a/GMFlow/gmflow/__pycache__/position.cpython-38.pyc b/GMFlow/gmflow/__pycache__/position.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c24cd2ab47fe67adf67ed8f803f5ae988d12ee7 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/position.cpython-38.pyc differ diff --git a/GMFlow/gmflow/__pycache__/position.cpython-39.pyc b/GMFlow/gmflow/__pycache__/position.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b93c69a8b9e92728bc69b62ed9d6383e942cd834 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/position.cpython-39.pyc differ diff --git a/GMFlow/gmflow/__pycache__/transformer.cpython-310.pyc b/GMFlow/gmflow/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..906d3c7d695eb81db7f90a267e93c0ba59392143 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/transformer.cpython-310.pyc differ diff --git a/GMFlow/gmflow/__pycache__/transformer.cpython-38.pyc b/GMFlow/gmflow/__pycache__/transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a62f5271ed84c2c20c7331bf718662c9aa2e5f0 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/transformer.cpython-38.pyc differ diff --git a/GMFlow/gmflow/__pycache__/transformer.cpython-39.pyc b/GMFlow/gmflow/__pycache__/transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e29b1454bfc1ed5193cbaa67ef4030f70dc22b5 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/transformer.cpython-39.pyc differ diff --git a/GMFlow/gmflow/__pycache__/trident_conv.cpython-310.pyc b/GMFlow/gmflow/__pycache__/trident_conv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2a72b2dcfc908d4429335d90eeeb0bedd3e385d Binary files /dev/null and b/GMFlow/gmflow/__pycache__/trident_conv.cpython-310.pyc differ diff --git a/GMFlow/gmflow/__pycache__/trident_conv.cpython-38.pyc b/GMFlow/gmflow/__pycache__/trident_conv.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..895087256abcb8e32e92c6bc3867a228fbceea57 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/trident_conv.cpython-38.pyc differ diff --git a/GMFlow/gmflow/__pycache__/trident_conv.cpython-39.pyc b/GMFlow/gmflow/__pycache__/trident_conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..614e59267de602eef03b0acca520ef46cddec860 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/trident_conv.cpython-39.pyc differ diff --git a/GMFlow/gmflow/__pycache__/utils.cpython-310.pyc b/GMFlow/gmflow/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f443ca7fc92a7436f69939d27d4135af88398f9 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/utils.cpython-310.pyc differ diff --git a/GMFlow/gmflow/__pycache__/utils.cpython-38.pyc b/GMFlow/gmflow/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f9eddd259e8924429672c8037a9b79e35bb4251 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/utils.cpython-38.pyc differ diff --git a/GMFlow/gmflow/__pycache__/utils.cpython-39.pyc b/GMFlow/gmflow/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..290bde8f1b9771807cdf186e9f30a65ecf27b7e0 Binary files /dev/null and b/GMFlow/gmflow/__pycache__/utils.cpython-39.pyc differ diff --git a/GMFlow/gmflow/backbone.py b/GMFlow/gmflow/backbone.py new file mode 100755 index 0000000000000000000000000000000000000000..d5c92b7d8698a41d11b29f084b3ab4953dd2a7bd --- /dev/null +++ b/GMFlow/gmflow/backbone.py @@ -0,0 +1,117 @@ +import torch.nn as nn + +from .trident_conv import MultiScaleTridentConv + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1, + ): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, stride=stride, bias=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, bias=False) + self.relu = nn.ReLU(inplace=True) + + self.norm1 = norm_layer(planes) + self.norm2 = norm_layer(planes) + if not stride == 1 or in_planes != planes: + self.norm3 = norm_layer(planes) + + if stride == 1 and in_planes == planes: + self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class CNNEncoder(nn.Module): + def __init__(self, output_dim=128, + norm_layer=nn.InstanceNorm2d, + num_output_scales=1, + **kwargs, + ): + super(CNNEncoder, self).__init__() + self.num_branch = num_output_scales + + feature_dims = [64, 96, 128] + + self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 + self.norm1 = norm_layer(feature_dims[0]) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = feature_dims[0] + self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 + self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 + + # highest resolution 1/4 or 1/8 + stride = 2 if num_output_scales == 1 else 1 + self.layer3 = self._make_layer(feature_dims[2], stride=stride, + norm_layer=norm_layer, + ) # 1/4 or 1/8 + + self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) + + if self.num_branch > 1: + if self.num_branch == 4: + strides = (1, 2, 4, 8) + elif self.num_branch == 3: + strides = (1, 2, 4) + elif self.num_branch == 2: + strides = (1, 2) + else: + raise ValueError + + self.trident_conv = MultiScaleTridentConv(output_dim, output_dim, + kernel_size=3, + strides=strides, + paddings=1, + num_branch=self.num_branch, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): + layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) + layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) + + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) # 1/2 + x = self.layer2(x) # 1/4 + x = self.layer3(x) # 1/8 or 1/4 + + x = self.conv2(x) + + if self.num_branch > 1: + out = self.trident_conv([x] * self.num_branch) # high to low res + else: + out = [x] + + return out diff --git a/GMFlow/gmflow/geometry.py b/GMFlow/gmflow/geometry.py new file mode 100755 index 0000000000000000000000000000000000000000..207e98fded56c0e7e63d63626ddace65b910bf9c --- /dev/null +++ b/GMFlow/gmflow/geometry.py @@ -0,0 +1,96 @@ +import torch +import torch.nn.functional as F + + +def coords_grid(b, h, w, homogeneous=False, device=None): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] + + stacks = [x, y] + + if homogeneous: + ones = torch.ones_like(x) # [H, W] + stacks.append(ones) + + grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] + + grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] + + if device is not None: + grid = grid.to(device) + + return grid + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), + torch.linspace(h_min, h_max, len_h, device=device)], + ) + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) + return (coords - c) / c # [-1, 1] + + +def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): + # img: [B, C, H, W] + # sample_coords: [B, 2, H, W] in image scale + if sample_coords.size(1) != 2: # [B, H, W, 2] + sample_coords = sample_coords.permute(0, 3, 1, 2) + + b, _, h, w = sample_coords.shape + + # Normalize to [-1, 1] + x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] + + img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) + + if return_mask: + mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] + + return img, mask + + return img + + +def flow_warp(feature, flow, mask=False, padding_mode='zeros'): + b, c, h, w = feature.size() + assert flow.size(1) == 2 + + grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] + + return bilinear_sample(feature, grid, padding_mode=padding_mode, + return_mask=mask) + + +def forward_backward_consistency_check(fwd_flow, bwd_flow, + alpha=0.01, + beta=0.5 + ): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + threshold = alpha * flow_mag + beta + + fwd_occ = (diff_fwd > threshold).float() # [B, H, W] + bwd_occ = (diff_bwd > threshold).float() + + return fwd_occ, bwd_occ diff --git a/GMFlow/gmflow/gmflow.py b/GMFlow/gmflow/gmflow.py new file mode 100755 index 0000000000000000000000000000000000000000..6bd0ff933ccbc04680002110171e05102e2805b1 --- /dev/null +++ b/GMFlow/gmflow/gmflow.py @@ -0,0 +1,171 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbone import CNNEncoder +from .transformer import FeatureTransformer, FeatureFlowAttention +from .matching import global_correlation_softmax, local_correlation_softmax +from .geometry import flow_warp +from .utils import normalize_img, feature_add_position + + +class GMFlow(nn.Module): + def __init__(self, + num_scales=1, + upsample_factor=8, + feature_channels=128, + attention_type='swin', + num_transformer_layers=6, + ffn_dim_expansion=4, + num_head=1, + **kwargs, + ): + super(GMFlow, self).__init__() + + self.num_scales = num_scales + self.feature_channels = feature_channels + self.upsample_factor = upsample_factor + self.attention_type = attention_type + self.num_transformer_layers = num_transformer_layers + + # CNN backbone + self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales) + + # Transformer + self.transformer = FeatureTransformer(num_layers=num_transformer_layers, + d_model=feature_channels, + nhead=num_head, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + ) + + # flow propagation with self-attn + self.feature_flow_attn = FeatureFlowAttention(in_channels=feature_channels) + + # convex upsampling: concat feature0 and flow as input + self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0)) + + def extract_feature(self, img0, img1): + concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W] + features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low + + # reverse: resolution from low to high + features = features[::-1] + + feature0, feature1 = [], [] + + for i in range(len(features)): + feature = features[i] + chunks = torch.chunk(feature, 2, 0) # tuple + feature0.append(chunks[0]) + feature1.append(chunks[1]) + + return feature0, feature1 + + def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, + ): + if bilinear: + up_flow = F.interpolate(flow, scale_factor=upsample_factor, + mode='bilinear', align_corners=True) * upsample_factor + + else: + # convex upsampling + concat = torch.cat((flow, feature), dim=1) + + mask = self.upsampler(concat) + b, flow_channel, h, w = flow.shape + mask = mask.view(b, 1, 9, self.upsample_factor, self.upsample_factor, h, w) # [B, 1, 9, K, K, H, W] + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(self.upsample_factor * flow, [3, 3], padding=1) + up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W] + + up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W] + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W] + up_flow = up_flow.reshape(b, flow_channel, self.upsample_factor * h, + self.upsample_factor * w) # [B, 2, K*H, K*W] + + return up_flow + + def forward(self, img0, img1, + attn_splits_list=None, + corr_radius_list=None, + prop_radius_list=None, + pred_bidir_flow=False, + **kwargs, + ): + + results_dict = {} + flow_preds = [] + + img0, img1 = normalize_img(img0, img1) # [B, 3, H, W] + + # import ipdb; ipdb.set_trace() + # resolution low to high + feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features + + flow = None + + assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales + + for scale_idx in range(self.num_scales): + feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx] + + if pred_bidir_flow and scale_idx > 0: + # predicting bidirectional flow with refinement + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) + + upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx)) + + if scale_idx > 0: + flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2 + + if flow is not None: + flow = flow.detach() + feature1 = flow_warp(feature1, flow) # [B, C, H, W] + + attn_splits = attn_splits_list[scale_idx] + corr_radius = corr_radius_list[scale_idx] + prop_radius = prop_radius_list[scale_idx] + + # add position to features + feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels) + + # Transformer + feature0, feature1 = self.transformer(feature0, feature1, attn_num_splits=attn_splits) + + # correlation and softmax + if corr_radius == -1: # global matching + flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0] + else: # local matching + flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0] + + # flow or residual flow + flow = flow + flow_pred if flow is not None else flow_pred + + # upsample to the original resolution for supervison + if self.training: # only need to upsample intermediate flow predictions at training time + flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor) + flow_preds.append(flow_bilinear) + + # flow propagation with self-attn + if pred_bidir_flow and scale_idx == 0: + feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation + flow = self.feature_flow_attn(feature0, flow.detach(), + local_window_attn=prop_radius > 0, + local_window_radius=prop_radius) + + # bilinear upsampling at training time except the last one + if self.training and scale_idx < self.num_scales - 1: + flow_up = self.upsample_flow(flow, feature0, bilinear=True, upsample_factor=upsample_factor) + flow_preds.append(flow_up) + + if scale_idx == self.num_scales - 1: + flow_up = self.upsample_flow(flow, feature0) + flow_preds.append(flow_up) + + results_dict.update({'flow_preds': flow_preds}) + + return results_dict diff --git a/GMFlow/gmflow/matching.py b/GMFlow/gmflow/matching.py new file mode 100755 index 0000000000000000000000000000000000000000..e920081552c3040c95b6a7b55779249f76cbad4b --- /dev/null +++ b/GMFlow/gmflow/matching.py @@ -0,0 +1,83 @@ +import torch +import torch.nn.functional as F + +from .geometry import coords_grid, generate_window_grid, normalize_coords + + +def global_correlation_softmax(feature0, feature1, + pred_bidir_flow=False, + ): + # global correlation + b, c, h, w = feature0.shape + feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.view(b, c, -1) # [B, C, H*W] + + correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W] + + # flow from softmax + init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W] + grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W] + + if pred_bidir_flow: + correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W] + init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W] + grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2] + b = b * 2 + + prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W] + + correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow + flow = correspondence - init_grid + + return flow, prob + + +def local_correlation_softmax(feature0, feature1, local_radius, + padding_mode='zeros', + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(-local_radius, local_radius, + -local_radius, local_radius, + local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] + + valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode=padding_mode, align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] + + # mask invalid locations + corr[~valid] = -1e9 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2] + + correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view( + b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + flow = correspondence - coords_init + match_prob = prob + + return flow, match_prob diff --git a/GMFlow/gmflow/position.py b/GMFlow/gmflow/position.py new file mode 100755 index 0000000000000000000000000000000000000000..42435d0fef24737d3cae7463ca411a635979cf33 --- /dev/null +++ b/GMFlow/gmflow/position.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py + +import torch +import torch.nn as nn +import math + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x): + # x = tensor_list.tensors # [B, C, H, W] + # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 + b, c, h, w = x.size() + mask = torch.ones((b, h, w), device=x.device) # [B, H, W] + y_embed = mask.cumsum(1, dtype=torch.float32) + x_embed = mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos diff --git a/GMFlow/gmflow/transformer.py b/GMFlow/gmflow/transformer.py new file mode 100755 index 0000000000000000000000000000000000000000..dcf657c86959c2b4528c12f698cd6a26874e432f --- /dev/null +++ b/GMFlow/gmflow/transformer.py @@ -0,0 +1,409 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import split_feature, merge_splits + + +def single_head_full_attention(q, k, v): + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L] + attn = torch.softmax(scores, dim=2) # [B, L, L] + out = torch.matmul(attn, v) # [B, L, C] + + return out + + +def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w, + shift_size_h, shift_size_w, device=torch.device('cuda')): + # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for SW-MSA + h, w = input_resolution + img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1 + h_slices = (slice(0, -window_size_h), + slice(-window_size_h, -shift_size_h), + slice(-shift_size_h, None)) + w_slices = (slice(0, -window_size_w), + slice(-window_size_w, -shift_size_w), + slice(-shift_size_w, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True) + + mask_windows = mask_windows.view(-1, window_size_h * window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + +def single_head_split_window_attention(q, k, v, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, + ): + # Ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * num_splits + + window_size_h = h // num_splits + window_size_w = w // num_splits + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c ** 0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_h = window_size_h // 2 + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + + q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C] + k = split_feature(k, num_splits=num_splits, channel_last=True) + v = split_feature(v, num_splits=num_splits, channel_last=True) + + scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) + ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K] + + if with_shift: + scores += attn_mask.repeat(b, 1, 1) + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C] + + out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c), + num_splits=num_splits, channel_last=True) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) + + out = out.view(b, -1, c) + + return out + + +class TransformerLayer(nn.Module): + def __init__(self, + d_model=256, + nhead=1, + attention_type='swin', + no_ffn=False, + ffn_dim_expansion=4, + with_shift=False, + **kwargs, + ): + super(TransformerLayer, self).__init__() + + self.dim = d_model + self.nhead = nhead + self.attention_type = attention_type + self.no_ffn = no_ffn + + self.with_shift = with_shift + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + + self.merge = nn.Linear(d_model, d_model, bias=False) + + self.norm1 = nn.LayerNorm(d_model) + + # no ffn after self-attn, with ffn after cross-attn + if not self.no_ffn: + in_channels = d_model * 2 + self.mlp = nn.Sequential( + nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False), + nn.GELU(), + nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False), + ) + + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, source, target, + height=None, + width=None, + shifted_window_attn_mask=None, + attn_num_splits=None, + **kwargs, + ): + # source, target: [B, L, C] + query, key, value = source, target, target + + # single-head attention + query = self.q_proj(query) # [B, L, C] + key = self.k_proj(key) # [B, L, C] + value = self.v_proj(value) # [B, L, C] + + if self.attention_type == 'swin' and attn_num_splits > 1: + if self.nhead > 1: + # we observe that multihead attention slows down the speed and increases the memory consumption + # without bringing obvious performance gains and thus the implementation is removed + raise NotImplementedError + else: + message = single_head_split_window_attention(query, key, value, + num_splits=attn_num_splits, + with_shift=self.with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + else: + message = single_head_full_attention(query, key, value) # [B, L, C] + + message = self.merge(message) # [B, L, C] + message = self.norm1(message) + + if not self.no_ffn: + message = self.mlp(torch.cat([source, message], dim=-1)) + message = self.norm2(message) + + return source + message + + +class TransformerBlock(nn.Module): + """self attention + cross attention + FFN""" + + def __init__(self, + d_model=256, + nhead=1, + attention_type='swin', + ffn_dim_expansion=4, + with_shift=False, + **kwargs, + ): + super(TransformerBlock, self).__init__() + + self.self_attn = TransformerLayer(d_model=d_model, + nhead=nhead, + attention_type=attention_type, + no_ffn=True, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=with_shift, + ) + + self.cross_attn_ffn = TransformerLayer(d_model=d_model, + nhead=nhead, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=with_shift, + ) + + def forward(self, source, target, + height=None, + width=None, + shifted_window_attn_mask=None, + attn_num_splits=None, + **kwargs, + ): + # source, target: [B, L, C] + + # self attention + source = self.self_attn(source, source, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + # cross attention and ffn + source = self.cross_attn_ffn(source, target, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + return source + + +class FeatureTransformer(nn.Module): + def __init__(self, + num_layers=6, + d_model=128, + nhead=1, + attention_type='swin', + ffn_dim_expansion=4, + **kwargs, + ): + super(FeatureTransformer, self).__init__() + + self.attention_type = attention_type + + self.d_model = d_model + self.nhead = nhead + + self.layers = nn.ModuleList([ + TransformerBlock(d_model=d_model, + nhead=nhead, + attention_type=attention_type, + ffn_dim_expansion=ffn_dim_expansion, + with_shift=True if attention_type == 'swin' and i % 2 == 1 else False, + ) + for i in range(num_layers)]) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feature0, feature1, + attn_num_splits=None, + **kwargs, + ): + + b, c, h, w = feature0.shape + assert self.d_model == c + + feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + + if self.attention_type == 'swin' and attn_num_splits > 1: + # global and refine use different number of splits + window_size_h = h // attn_num_splits + window_size_w = w // attn_num_splits + + # compute attn mask once + shifted_window_attn_mask = generate_shift_window_attn_mask( + input_resolution=(h, w), + window_size_h=window_size_h, + window_size_w=window_size_w, + shift_size_h=window_size_h // 2, + shift_size_w=window_size_w // 2, + device=feature0.device, + ) # [K*K, H/K*W/K, H/K*W/K] + else: + shifted_window_attn_mask = None + + # concat feature0 and feature1 in batch dimension to compute in parallel + concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C] + concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C] + + for layer in self.layers: + concat0 = layer(concat0, concat1, + height=h, + width=w, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_num_splits=attn_num_splits, + ) + + # update feature1 + concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0) + + feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C] + + # reshape back + feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + + return feature0, feature1 + + +class FeatureFlowAttention(nn.Module): + """ + flow propagation with self-attention on feature + query: feature0, key: feature0, value: flow + """ + + def __init__(self, in_channels, + **kwargs, + ): + super(FeatureFlowAttention, self).__init__() + + self.q_proj = nn.Linear(in_channels, in_channels) + self.k_proj = nn.Linear(in_channels, in_channels) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feature0, flow, + local_window_attn=False, + local_window_radius=1, + **kwargs, + ): + # q, k: feature [B, C, H, W], v: flow [B, 2, H, W] + if local_window_attn: + return self.forward_local_window_attn(feature0, flow, + local_window_radius=local_window_radius) + + b, c, h, w = feature0.size() + + query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C] + + # a note: the ``correct'' implementation should be: + # ``query = self.q_proj(query), key = self.k_proj(query)'' + # this problem is observed while cleaning up the code + # however, this doesn't affect the performance since the projection is a linear operation, + # thus the two projection matrices for key can be merged + # so I just leave it as is in order to not re-train all models :) + query = self.q_proj(query) # [B, H*W, C] + key = self.k_proj(query) # [B, H*W, C] + + value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2] + + scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W] + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, value) # [B, H*W, 2] + out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W] + + return out + + def forward_local_window_attn(self, feature0, flow, + local_window_radius=1, + ): + assert flow.size(1) == 2 + assert local_window_radius > 0 + + b, c, h, w = feature0.size() + + feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1) + ).reshape(b * h * w, 1, c) # [B*H*W, 1, C] + + kernel_size = 2 * local_window_radius + 1 + + feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w) + + feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size, + padding=local_window_radius) # [B, C*(2R+1)^2), H*W] + + feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute( + 0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2] + + flow_window = F.unfold(flow, kernel_size=kernel_size, + padding=local_window_radius) # [B, 2*(2R+1)^2), H*W] + + flow_window = flow_window.view(b, 2, kernel_size ** 2, h, w).permute( + 0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, 2) # [B*H*W, (2R+1)^2, 2] + + scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2] + + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, flow_window).view(b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W] + + return out diff --git a/GMFlow/gmflow/trident_conv.py b/GMFlow/gmflow/trident_conv.py new file mode 100755 index 0000000000000000000000000000000000000000..445663c2d1065e10899f728ad2628e313f218024 --- /dev/null +++ b/GMFlow/gmflow/trident_conv.py @@ -0,0 +1,90 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.utils import _pair + + +class MultiScaleTridentConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + strides=1, + paddings=0, + dilations=1, + dilation=1, + groups=1, + num_branch=1, + test_branch_idx=-1, + bias=False, + norm=None, + activation=None, + ): + super(MultiScaleTridentConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.num_branch = num_branch + self.stride = _pair(stride) + self.groups = groups + self.with_bias = bias + self.dilation = dilation + if isinstance(paddings, int): + paddings = [paddings] * self.num_branch + if isinstance(dilations, int): + dilations = [dilations] * self.num_branch + if isinstance(strides, int): + strides = [strides] * self.num_branch + self.paddings = [_pair(padding) for padding in paddings] + self.dilations = [_pair(dilation) for dilation in dilations] + self.strides = [_pair(stride) for stride in strides] + self.test_branch_idx = test_branch_idx + self.norm = norm + self.activation = activation + + assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) + ) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.bias = None + + nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") + if self.bias is not None: + nn.init.constant_(self.bias, 0) + + def forward(self, inputs): + num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 + assert len(inputs) == num_branch + + if self.training or self.test_branch_idx == -1: + outputs = [ + F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) + for input, stride, padding in zip(inputs, self.strides, self.paddings) + ] + else: + outputs = [ + F.conv2d( + inputs[0], + self.weight, + self.bias, + self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], + self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], + self.dilation, + self.groups, + ) + ] + + if self.norm is not None: + outputs = [self.norm(x) for x in outputs] + if self.activation is not None: + outputs = [self.activation(x) for x in outputs] + return outputs diff --git a/GMFlow/gmflow/utils.py b/GMFlow/gmflow/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..60e7d8994bfc2f66e333ce4d60bbbd4488d50553 --- /dev/null +++ b/GMFlow/gmflow/utils.py @@ -0,0 +1,97 @@ +import torch +from .position import PositionEmbeddingSine + + +def split_feature(feature, + num_splits=2, + channel_last=False, + ): + if channel_last: # [B, H, W, C] + b, h, w, c = feature.size() + # if h % num_splits: + # feature = feature[:, :, :-1, :] + # if w % num_splits: + # feature = feature[:, :, :, :-1] + # b, h, w, c = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c + ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C] + else: # [B, C, H, W] + b, c, h, w = feature.size() + # if h % num_splits: + # feature = feature[:, :, :-1, :] + # if w % num_splits: + # feature = feature[:, :, :, :-1] + # b, c, h, w = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits + ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K] + + return feature + + +def merge_splits(splits, + num_splits=2, + channel_last=False, + ): + if channel_last: # [B*K*K, H/K, W/K, C] + b, h, w, c = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, h, w, c) + merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view( + new_b, num_splits * h, num_splits * w, c) # [B, H, W, C] + else: # [B*K*K, C, H/K, W/K] + b, c, h, w = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, c, h, w) + merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view( + new_b, c, num_splits * h, num_splits * w) # [B, C, H, W] + + return merge + + +def normalize_img(img0, img1): + # loaded images are in [0, 255] + # normalize by ImageNet mean and std + mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device) + std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device) + img0 = (img0 / 255. - mean) / std + img1 = (img1 / 255. - mean) / std + + return img0, img1 + + +def feature_add_position(feature0, feature1, attn_splits, feature_channels): + pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) + + if attn_splits > 1: # add position in splited window + # import ipdb; ipdb.set_trace() + feature0_splits = split_feature(feature0, num_splits=attn_splits) + feature1_splits = split_feature(feature1, num_splits=attn_splits) + + position = pos_enc(feature0_splits) + + feature0_splits = feature0_splits + position + feature1_splits = feature1_splits + position + + feature0 = merge_splits(feature0_splits, num_splits=attn_splits) + feature1 = merge_splits(feature1_splits, num_splits=attn_splits) + else: + position = pos_enc(feature0) + + feature0 = feature0 + position + feature1 = feature1 + position + + return feature0, feature1 diff --git a/GMFlow/loss.py b/GMFlow/loss.py new file mode 100755 index 0000000000000000000000000000000000000000..f9f0b0216b2a277f422daab92d5c5b4f53458ae3 --- /dev/null +++ b/GMFlow/loss.py @@ -0,0 +1,37 @@ +import torch + + +def flow_loss_func(flow_preds, flow_gt, valid, + gamma=0.9, + max_flow=400, + **kwargs, + ): + n_predictions = len(flow_preds) + flow_loss = 0.0 + + # exlude invalid pixels and extremely large diplacements + mag = torch.sum(flow_gt ** 2, dim=1).sqrt() # [B, H, W] + valid = (valid >= 0.5) & (mag < max_flow) + + for i in range(n_predictions): + i_weight = gamma ** (n_predictions - i - 1) + + i_loss = (flow_preds[i] - flow_gt).abs() + + flow_loss += i_weight * (valid[:, None] * i_loss).mean() + + epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt() + + if valid.max() < 0.5: + pass + + epe = epe.view(-1)[valid.view(-1)] + + metrics = { + 'epe': epe.mean().item(), + '1px': (epe > 1).float().mean().item(), + '3px': (epe > 3).float().mean().item(), + '5px': (epe > 5).float().mean().item(), + } + + return flow_loss, metrics diff --git a/GMFlow/main.py b/GMFlow/main.py new file mode 100755 index 0000000000000000000000000000000000000000..281b402e2a9032dd992dd8dc126e9bd897d86f3d --- /dev/null +++ b/GMFlow/main.py @@ -0,0 +1,557 @@ +import torch +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +import argparse +import numpy as np +import os + +from data import build_train_dataset +from gmflow.gmflow import GMFlow +from loss import flow_loss_func +from evaluate import (validate_chairs, validate_things, validate_sintel, validate_kitti, + create_sintel_submission, create_kitti_submission, inference_on_dir) + +from utils.logger import Logger +from utils import misc +from utils.dist_utils import get_dist_info, init_dist, setup_for_distributed + + +def get_args_parser(): + parser = argparse.ArgumentParser() + + # dataset + parser.add_argument('--checkpoint_dir', default='tmp', type=str, + help='where to save the training log and models') + parser.add_argument('--stage', default='chairs', type=str, + help='training stage') + parser.add_argument('--image_size', default=[384, 512], type=int, nargs='+', + help='image size for training') + parser.add_argument('--padding_factor', default=16, type=int, + help='the input should be divisible by padding_factor, otherwise do padding') + + parser.add_argument('--max_flow', default=400, type=int, + help='exclude very large motions during training') + parser.add_argument('--val_dataset', default=['chairs'], type=str, nargs='+', + help='validation dataset') + parser.add_argument('--with_speed_metric', action='store_true', + help='with speed metric when evaluation') + + # training + parser.add_argument('--lr', default=4e-4, type=float) + parser.add_argument('--batch_size', default=12, type=int) + parser.add_argument('--num_workers', default=4, type=int) + parser.add_argument('--weight_decay', default=1e-4, type=float) + parser.add_argument('--grad_clip', default=1.0, type=float) + parser.add_argument('--num_steps', default=100000, type=int) + parser.add_argument('--seed', default=326, type=int) + parser.add_argument('--summary_freq', default=100, type=int) + parser.add_argument('--val_freq', default=10000, type=int) + parser.add_argument('--save_ckpt_freq', default=10000, type=int) + parser.add_argument('--save_latest_ckpt_freq', default=1000, type=int) + + # resume pretrained model or resume training + parser.add_argument('--resume', default=None, type=str, + help='resume from pretrain model for finetuing or resume from terminated training') + parser.add_argument('--strict_resume', action='store_true') + parser.add_argument('--no_resume_optimizer', action='store_true') + + # GMFlow model + parser.add_argument('--num_scales', default=1, type=int, + help='basic gmflow model uses a single 1/8 feature, the refinement uses 1/4 feature') + parser.add_argument('--feature_channels', default=128, type=int) + parser.add_argument('--upsample_factor', default=8, type=int) + parser.add_argument('--num_transformer_layers', default=6, type=int) + parser.add_argument('--num_head', default=1, type=int) + parser.add_argument('--attention_type', default='swin', type=str) + parser.add_argument('--ffn_dim_expansion', default=4, type=int) + + parser.add_argument('--attn_splits_list', default=[2], type=int, nargs='+', + help='number of splits in attention') + parser.add_argument('--corr_radius_list', default=[-1], type=int, nargs='+', + help='correlation radius for matching, -1 indicates global matching') + parser.add_argument('--prop_radius_list', default=[-1], type=int, nargs='+', + help='self-attention radius for flow propagation, -1 indicates global attention') + + # loss + parser.add_argument('--gamma', default=0.9, type=float, + help='loss weight') + + # evaluation + parser.add_argument('--eval', action='store_true') + parser.add_argument('--save_eval_to_file', action='store_true') + parser.add_argument('--evaluate_matched_unmatched', action='store_true') + + # inference on a directory + parser.add_argument('--inference_dir', default=None, type=str) + parser.add_argument('--inference_size', default=None, type=int, nargs='+', + help='can specify the inference size') + parser.add_argument('--dir_paired_data', action='store_true', + help='Paired data in a dir instead of a sequence') + parser.add_argument('--save_flo_flow', action='store_true') + parser.add_argument('--pred_bidir_flow', action='store_true', + help='predict bidirectional flow') + parser.add_argument('--fwd_bwd_consistency_check', action='store_true', + help='forward backward consistency check with bidirection flow') + + # predict on sintel and kitti test set for submission + parser.add_argument('--submission', action='store_true', + help='submission to sintel or kitti test sets') + parser.add_argument('--output_path', default='output', type=str, + help='where to save the prediction results') + parser.add_argument('--save_vis_flow', action='store_true', + help='visualize flow prediction as .png image') + parser.add_argument('--no_save_flo', action='store_true', + help='not save flow as .flo') + + # distributed training + parser.add_argument('--local_rank', default=0, type=int) + parser.add_argument('--distributed', action='store_true') + parser.add_argument('--launcher', default='none', type=str, choices=['none', 'pytorch']) + parser.add_argument('--gpu_ids', default=0, type=int, nargs='+') + + parser.add_argument('--count_time', action='store_true', + help='measure the inference time on sintel') + + return parser + + +def main(args): + if not args.eval and not args.submission and args.inference_dir is None: + if args.local_rank == 0: + print('pytorch version:', torch.__version__) + print(args) + misc.save_args(args) + misc.check_path(args.checkpoint_dir) + misc.save_command(args.checkpoint_dir) + + seed = args.seed + torch.manual_seed(seed) + np.random.seed(seed) + + torch.backends.cudnn.benchmark = True + + if args.launcher == 'none': + args.distributed = False + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + args.distributed = True + + # adjust batch size for each gpu + assert args.batch_size % torch.cuda.device_count() == 0 + args.batch_size = args.batch_size // torch.cuda.device_count() + + dist_params = dict(backend='nccl') + init_dist(args.launcher, **dist_params) + # re-set gpu_ids with distributed training mode + _, world_size = get_dist_info() + args.gpu_ids = range(world_size) + device = torch.device('cuda:{}'.format(args.local_rank)) + + setup_for_distributed(args.local_rank == 0) + + # model + model = GMFlow(feature_channels=args.feature_channels, + num_scales=args.num_scales, + upsample_factor=args.upsample_factor, + num_head=args.num_head, + attention_type=args.attention_type, + ffn_dim_expansion=args.ffn_dim_expansion, + num_transformer_layers=args.num_transformer_layers, + ).to(device) + + if not args.eval and not args.submission and not args.inference_dir: + print('Model definition:') + print(model) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model.to(device), + device_ids=[args.local_rank], + output_device=args.local_rank) + model_without_ddp = model.module + else: + if torch.cuda.device_count() > 1: + print('Use %d GPUs' % torch.cuda.device_count()) + model = torch.nn.DataParallel(model) + + model_without_ddp = model.module + else: + model_without_ddp = model + + num_params = sum(p.numel() for p in model.parameters()) + print('Number of params:', num_params) + if not args.eval and not args.submission and args.inference_dir is None: + save_name = '%d_parameters' % num_params + open(os.path.join(args.checkpoint_dir, save_name), 'a').close() + + optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr, + weight_decay=args.weight_decay) + + start_epoch = 0 + start_step = 0 + # resume checkpoints + if args.resume: + print('Load checkpoint: %s' % args.resume) + + loc = 'cuda:{}'.format(args.local_rank) + checkpoint = torch.load(args.resume, map_location=loc) + + weights = checkpoint['model'] if 'model' in checkpoint else checkpoint + + model_without_ddp.load_state_dict(weights, strict=args.strict_resume) + + if 'optimizer' in checkpoint and 'step' in checkpoint and 'epoch' in checkpoint and not \ + args.no_resume_optimizer: + print('Load optimizer') + optimizer.load_state_dict(checkpoint['optimizer']) + start_epoch = checkpoint['epoch'] + start_step = checkpoint['step'] + + print('start_epoch: %d, start_step: %d' % (start_epoch, start_step)) + + # evaluate + if args.eval: + val_results = {} + + if 'chairs' in args.val_dataset: + results_dict = validate_chairs(model_without_ddp, + with_speed_metric=args.with_speed_metric, + attn_splits_list=args.attn_splits_list, + corr_radius_list=args.corr_radius_list, + prop_radius_list=args.prop_radius_list, + ) + + val_results.update(results_dict) + + if 'things' in args.val_dataset: + results_dict = validate_things(model_without_ddp, + padding_factor=args.padding_factor, + with_speed_metric=args.with_speed_metric, + attn_splits_list=args.attn_splits_list, + corr_radius_list=args.corr_radius_list, + prop_radius_list=args.prop_radius_list, + ) + val_results.update(results_dict) + + if 'sintel' in args.val_dataset: + results_dict = validate_sintel(model_without_ddp, + count_time=args.count_time, + padding_factor=args.padding_factor, + with_speed_metric=args.with_speed_metric, + evaluate_matched_unmatched=args.evaluate_matched_unmatched, + attn_splits_list=args.attn_splits_list, + corr_radius_list=args.corr_radius_list, + prop_radius_list=args.prop_radius_list, + ) + val_results.update(results_dict) + + if 'kitti' in args.val_dataset: + results_dict = validate_kitti(model_without_ddp, + padding_factor=args.padding_factor, + with_speed_metric=args.with_speed_metric, + attn_splits_list=args.attn_splits_list, + corr_radius_list=args.corr_radius_list, + prop_radius_list=args.prop_radius_list, + ) + val_results.update(results_dict) + + if args.save_eval_to_file: + misc.check_path(args.checkpoint_dir) + val_file = os.path.join(args.checkpoint_dir, 'val_results.txt') + with open(val_file, 'a') as f: + f.write('\neval results after training done\n\n') + metrics = ['chairs_epe', 'chairs_s0_10', 'chairs_s10_40', 'chairs_s40+', + 'things_clean_epe', 'things_clean_s0_10', 'things_clean_s10_40', 'things_clean_s40+', + 'things_final_epe', 'things_final_s0_10', 'things_final_s10_40', 'things_final_s40+', + 'sintel_clean_epe', 'sintel_clean_s0_10', 'sintel_clean_s10_40', 'sintel_clean_s40+', + 'sintel_final_epe', 'sintel_final_s0_10', 'sintel_final_s10_40', 'sintel_final_s40+', + 'kitti_epe', 'kitti_f1', 'kitti_s0_10', 'kitti_s10_40', 'kitti_s40+', + ] + eval_metrics = [] + for metric in metrics: + if metric in val_results.keys(): + eval_metrics.append(metric) + + metrics_values = [val_results[metric] for metric in eval_metrics] + + num_metrics = len(eval_metrics) + + # save as markdown format + f.write(("| {:>20} " * num_metrics + '\n').format(*eval_metrics)) + f.write(("| {:20.3f} " * num_metrics).format(*metrics_values)) + + f.write('\n\n') + + return + + # Sintel and KITTI submission + if args.submission: + # NOTE: args.val_dataset is a list + if args.val_dataset[0] == 'sintel': + create_sintel_submission(model_without_ddp, + output_path=args.output_path, + padding_factor=args.padding_factor, + save_vis_flow=args.save_vis_flow, + no_save_flo=args.no_save_flo, + attn_splits_list=args.attn_splits_list, + corr_radius_list=args.corr_radius_list, + prop_radius_list=args.prop_radius_list, + ) + elif args.val_dataset[0] == 'kitti': + create_kitti_submission(model_without_ddp, + output_path=args.output_path, + padding_factor=args.padding_factor, + save_vis_flow=args.save_vis_flow, + attn_splits_list=args.attn_splits_list, + corr_radius_list=args.corr_radius_list, + prop_radius_list=args.prop_radius_list, + ) + else: + raise ValueError(f'Not supported dataset for submission') + + return + + # inferece on a dir + if args.inference_dir is not None: + inference_on_dir(model_without_ddp, + inference_dir=args.inference_dir, + output_path=args.output_path, + padding_factor=args.padding_factor, + inference_size=args.inference_size, + paired_data=args.dir_paired_data, + save_flo_flow=args.save_flo_flow, + attn_splits_list=args.attn_splits_list, + corr_radius_list=args.corr_radius_list, + prop_radius_list=args.prop_radius_list, + pred_bidir_flow=args.pred_bidir_flow, + fwd_bwd_consistency_check=args.fwd_bwd_consistency_check, + ) + + return + + # training datset + train_dataset = build_train_dataset(args) + print('Number of training images:', len(train_dataset)) + + # Multi-processing + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, + num_replicas=torch.cuda.device_count(), + rank=args.local_rank) + else: + train_sampler = None + + shuffle = False if args.distributed else True + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, + shuffle=shuffle, num_workers=args.num_workers, + pin_memory=True, drop_last=True, + sampler=train_sampler) + + last_epoch = start_step if args.resume and start_step > 0 else -1 + lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, args.lr, + args.num_steps + 10, + pct_start=0.05, + cycle_momentum=False, + anneal_strategy='cos', + last_epoch=last_epoch, + ) + + if args.local_rank == 0: + summary_writer = SummaryWriter(args.checkpoint_dir) + logger = Logger(lr_scheduler, summary_writer, args.summary_freq, + start_step=start_step) + + total_steps = start_step + epoch = start_epoch + print('Start training') + + while total_steps < args.num_steps: + model.train() + + # mannual change random seed for shuffling every epoch + if args.distributed: + train_sampler.set_epoch(epoch) + + for i, sample in enumerate(train_loader): + img1, img2, flow_gt, valid = [x.to(device) for x in sample] + + results_dict = model(img1, img2, + attn_splits_list=args.attn_splits_list, + corr_radius_list=args.corr_radius_list, + prop_radius_list=args.prop_radius_list, + ) + + flow_preds = results_dict['flow_preds'] + + loss, metrics = flow_loss_func(flow_preds, flow_gt, valid, + gamma=args.gamma, + max_flow=args.max_flow, + ) + + if isinstance(loss, float): + continue + + if torch.isnan(loss): + continue + + metrics.update({'total_loss': loss.item()}) + + # more efficient zero_grad + for param in model_without_ddp.parameters(): + param.grad = None + + loss.backward() + + # Gradient clipping + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) + + optimizer.step() + + lr_scheduler.step() + + if args.local_rank == 0: + logger.push(metrics) + + logger.add_image_summary(img1, img2, flow_preds, flow_gt) + + total_steps += 1 + + if total_steps % args.save_ckpt_freq == 0 or total_steps == args.num_steps: + if args.local_rank == 0: + checkpoint_path = os.path.join(args.checkpoint_dir, 'step_%06d.pth' % total_steps) + torch.save({ + 'model': model_without_ddp.state_dict() + }, checkpoint_path) + + if total_steps % args.save_latest_ckpt_freq == 0: + checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint_latest.pth') + + if args.local_rank == 0: + torch.save({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'step': total_steps, + 'epoch': epoch, + }, checkpoint_path) + + if total_steps % args.val_freq == 0: + print('Start validation') + + val_results = {} + # support validation on multiple datasets + if 'chairs' in args.val_dataset: + results_dict = validate_chairs(model_without_ddp, + with_speed_metric=args.with_speed_metric, + attn_splits_list=args.attn_splits_list, + corr_radius_list=args.corr_radius_list, + prop_radius_list=args.prop_radius_list, + ) + if args.local_rank == 0: + val_results.update(results_dict) + + if 'things' in args.val_dataset: + results_dict = validate_things(model_without_ddp, + padding_factor=args.padding_factor, + with_speed_metric=args.with_speed_metric, + attn_splits_list=args.attn_splits_list, + corr_radius_list=args.corr_radius_list, + prop_radius_list=args.prop_radius_list, + ) + if args.local_rank == 0: + val_results.update(results_dict) + + if 'sintel' in args.val_dataset: + results_dict = validate_sintel(model_without_ddp, + count_time=args.count_time, + padding_factor=args.padding_factor, + with_speed_metric=args.with_speed_metric, + evaluate_matched_unmatched=args.evaluate_matched_unmatched, + attn_splits_list=args.attn_splits_list, + corr_radius_list=args.corr_radius_list, + prop_radius_list=args.prop_radius_list, + ) + if args.local_rank == 0: + val_results.update(results_dict) + + if 'kitti' in args.val_dataset: + results_dict = validate_kitti(model_without_ddp, + padding_factor=args.padding_factor, + with_speed_metric=args.with_speed_metric, + attn_splits_list=args.attn_splits_list, + corr_radius_list=args.corr_radius_list, + prop_radius_list=args.prop_radius_list, + ) + if args.local_rank == 0: + val_results.update(results_dict) + + if args.local_rank == 0: + logger.write_dict(val_results) + + # Save validation results + val_file = os.path.join(args.checkpoint_dir, 'val_results.txt') + with open(val_file, 'a') as f: + f.write('step: %06d\n' % total_steps) + if args.evaluate_matched_unmatched: + metrics = ['chairs_epe', + 'chairs_s0_10', 'chairs_s10_40', 'chairs_s40+', + 'things_clean_epe', 'things_clean_s0_10', 'things_clean_s10_40', + 'things_clean_s40+', + 'sintel_clean_epe', 'sintel_clean_matched', 'sintel_clean_unmatched', + 'sintel_clean_s0_10', 'sintel_clean_s10_40', + 'sintel_clean_s40+', + 'sintel_final_epe', 'sintel_final_matched', 'sintel_final_unmatched', + 'sintel_final_s0_10', 'sintel_final_s10_40', + 'sintel_final_s40+', + 'kitti_epe', 'kitti_f1', 'kitti_s0_10', 'kitti_s10_40', 'kitti_s40+', + ] + else: + metrics = ['chairs_epe', 'chairs_s0_10', 'chairs_s10_40', 'chairs_s40+', + 'things_clean_epe', 'things_clean_s0_10', 'things_clean_s10_40', + 'things_clean_s40+', + 'sintel_clean_epe', 'sintel_clean_s0_10', 'sintel_clean_s10_40', + 'sintel_clean_s40+', + 'sintel_final_epe', 'sintel_final_s0_10', 'sintel_final_s10_40', + 'sintel_final_s40+', + 'kitti_epe', 'kitti_f1', 'kitti_s0_10', 'kitti_s10_40', 'kitti_s40+', + ] + + eval_metrics = [] + for metric in metrics: + if metric in val_results.keys(): + eval_metrics.append(metric) + + metrics_values = [val_results[metric] for metric in eval_metrics] + + num_metrics = len(eval_metrics) + + # save as markdown format + if args.evaluate_matched_unmatched: + f.write(("| {:>25} " * num_metrics + '\n').format(*eval_metrics)) + f.write(("| {:25.3f} " * num_metrics).format(*metrics_values)) + else: + f.write(("| {:>20} " * num_metrics + '\n').format(*eval_metrics)) + f.write(("| {:20.3f} " * num_metrics).format(*metrics_values)) + + f.write('\n\n') + + model.train() + + if total_steps >= args.num_steps: + print('Training done') + + return + + epoch += 1 + + +if __name__ == '__main__': + parser = get_args_parser() + args = parser.parse_args() + + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + main(args) diff --git a/GMFlow/scripts/demo.sh b/GMFlow/scripts/demo.sh new file mode 100755 index 0000000000000000000000000000000000000000..3aa5d2675781286d81512446af5865ff222491be --- /dev/null +++ b/GMFlow/scripts/demo.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash + +# inference GMFlow without refinement + +# sintel + +# only predict forward flow +CUDA_VISIBLE_DEVICES=0 python main.py \ +--inference_dir demo/sintel_market_1 \ +--output_path output/gmflow-norefine-sintel_market_1 \ +--resume pretrained/gmflow_sintel-0c07dcb3.pth + +# predict forward & backward flow +CUDA_VISIBLE_DEVICES=0 python main.py \ +--inference_dir demo/sintel_market_1 \ +--output_path output/gmflow-norefine-sintel_market_1 \ +--pred_bidir_flow \ +--resume pretrained/gmflow_sintel-0c07dcb3.pth + + +# predict forward & backward flow with forward-backward consistency check +CUDA_VISIBLE_DEVICES=0 python main.py \ +--inference_dir demo/sintel_market_1 \ +--output_path output/gmflow-norefine-sintel_market_1 \ +--pred_bidir_flow \ +--fwd_bwd_consistency_check \ +--resume pretrained/gmflow_sintel-0c07dcb3.pth + + +# davis + +CUDA_VISIBLE_DEVICES=0 python main.py \ +--inference_dir demo/davis_breakdance-flare \ +--output_path output/gmflow-norefine-davis_breakdance-flare \ +--resume pretrained/gmflow_sintel-0c07dcb3.pth + + + + +# inference GMFlow with refinement + +CUDA_VISIBLE_DEVICES=0 python main.py \ +--inference_dir demo/davis_breakdance-flare \ +--output_path output/gmflow-withrefine-davis_breakdance-flare \ +--resume pretrained/gmflow_with_refine_sintel-3ed1cf48.pth \ +--padding_factor 32 \ +--upsample_factor 4 \ +--num_scales 2 \ +--attn_splits_list 2 8 \ +--corr_radius_list -1 4 \ +--prop_radius_list -1 1 + + + + +CUDA_VISIBLE_DEVICES=0 python main.py \ +--inference_dir demo/sintel_test_clean_market_1 \ +--output_path output/gmflow-norefine-sintel_test_clean_market_1 \ +--pred_bidir_flow \ +--fwd_bwd_consistency_check \ +--resume pretrained/gmflow_sintel-0c07dcb3.pth + + diff --git a/GMFlow/scripts/evaluate.sh b/GMFlow/scripts/evaluate.sh new file mode 100755 index 0000000000000000000000000000000000000000..fa6dbefeddd2292a7fe5bfc277501080ccdd007a --- /dev/null +++ b/GMFlow/scripts/evaluate.sh @@ -0,0 +1,83 @@ +#!/usr/bin/env bash + +# evaluate GMFlow without refinement + +# evaluate chairs & things trained model on things and sintel (Table 3 of GMFlow paper) +# the output should be: +# Number of validation image pairs: 1024 +# Validation Things test set (things_clean) EPE: 3.475 +# Validation Things test (things_clean) s0_10: 0.666, s10_40: 1.310, s40+: 8.968 +# Number of validation image pairs: 1041 +# Validation Sintel (clean) EPE: 1.495, 1px: 0.161, 3px: 0.059, 5px: 0.040 +# Validation Sintel (clean) s0_10: 0.457, s10_40: 1.770, s40+: 8.257 +# Number of validation image pairs: 1041 +# Validation Sintel (final) EPE: 2.955, 1px: 0.209, 3px: 0.098, 5px: 0.071 +# Validation Sintel (final) s0_10: 0.725, s10_40: 3.446, s40+: 17.701 + +CUDA_VISIBLE_DEVICES=0 python main.py \ +--eval \ +--resume pretrained/gmflow_things-e9887eda.pth \ +--val_dataset things sintel \ +--with_speed_metric + + + +# evaluate GMFlow with refinement + +# evaluate chairs & things trained model on things and sintel (Table 3 of GMFlow paper) +# the output should be: +# Validation Things test set (things_clean) EPE: 2.804 +# Validation Things test (things_clean) s0_10: 0.527, s10_40: 1.009, s40+: 7.314 +# Number of validation image pairs: 1041 +# Validation Sintel (clean) EPE: 1.084, 1px: 0.092, 3px: 0.040, 5px: 0.028 +# Validation Sintel (clean) s0_10: 0.303, s10_40: 1.252, s40+: 6.261 +# Number of validation image pairs: 1041 +# Validation Sintel (final) EPE: 2.475, 1px: 0.147, 3px: 0.077, 5px: 0.058 +# Validation Sintel (final) s0_10: 0.511, s10_40: 2.810, s40+: 15.669 + +CUDA_VISIBLE_DEVICES=0 python main.py \ +--eval \ +--resume pretrained/gmflow_with_refine_things-36579974.pth \ +--val_dataset things sintel \ +--with_speed_metric \ +--padding_factor 32 \ +--upsample_factor 4 \ +--num_scales 2 \ +--attn_splits_list 2 8 \ +--corr_radius_list -1 4 \ +--prop_radius_list -1 1 + + + +# evaluate matched & matched on sintel + +# evaluate GMFlow without refinement + +CUDA_VISIBLE_DEVICES=0 python main.py \ +--eval \ +--evaluate_matched_unmatched \ +--resume pretrained/gmflow_things-e9887eda.pth \ +--val_dataset sintel + +# evaluate GMFlow with refinement + +CUDA_VISIBLE_DEVICES=0 python main.py \ +--eval \ +--evaluate_matched_unmatched \ +--resume pretrained/gmflow_with_refine_things-36579974.pth \ +--val_dataset sintel \ +--with_speed_metric \ +--padding_factor 32 \ +--upsample_factor 4 \ +--num_scales 2 \ +--attn_splits_list 2 8 \ +--corr_radius_list -1 4 \ +--prop_radius_list -1 1 + + + + + + + + diff --git a/GMFlow/scripts/submission.sh b/GMFlow/scripts/submission.sh new file mode 100755 index 0000000000000000000000000000000000000000..c19223eafc3bb379a528cb16c7ff19f467a1c17a --- /dev/null +++ b/GMFlow/scripts/submission.sh @@ -0,0 +1,67 @@ +#!/usr/bin/env bash + + +# generate prediction results for submission on sintel and kitti online servers + + +# GMFlow without refinement + +# submission to sintel +CUDA_VISIBLE_DEVICES=0 python main.py \ +--submission \ +--output_path submission/sintel-gmflow-norefine \ +--val_dataset sintel \ +--resume pretrained/gmflow_sintel-0c07dcb3.pth + +# submission to kitti +CUDA_VISIBLE_DEVICES=0 python main.py \ +--submission \ +--output_path submission/kitti-gmflow-norefine \ +--val_dataset kitti \ +--resume pretrained/gmflow_kitti-285701a8.pth + + +# you can also visualize the predictions before submission +# CUDA_VISIBLE_DEVICES=0 python main.py \ +# --submission \ +# --output_path submission/sintel-gmflow-norefine-vis \ +# --save_vis_flow \ +# --no_save_flo \ +# --val_dataset sintel \ +# --resume pretrained/gmflow_sintel.pth + + + + +# GMFlow with refinement + +# submission to sintel +CUDA_VISIBLE_DEVICES=0 python main.py \ +--submission \ +--output_path submission/sintel-gmflow-withrefine \ +--val_dataset sintel \ +--resume pretrained/gmflow_with_refine_sintel-3ed1cf48.pth \ +--padding_factor 32 \ +--upsample_factor 4 \ +--num_scales 2 \ +--attn_splits_list 2 8 \ +--corr_radius_list -1 4 \ +--prop_radius_list -1 1 + +# submission to kitti +CUDA_VISIBLE_DEVICES=0 python main.py \ +--submission \ +--output_path submission/kitti-gmflow-withrefine \ +--val_dataset kitti \ +--resume pretrained/gmflow_with_refine_kitti-8d3b9786.pth \ +--padding_factor 32 \ +--upsample_factor 4 \ +--num_scales 2 \ +--attn_splits_list 2 8 \ +--corr_radius_list -1 4 \ +--prop_radius_list -1 1 + + + + + diff --git a/GMFlow/scripts/train_gmflow.sh b/GMFlow/scripts/train_gmflow.sh new file mode 100755 index 0000000000000000000000000000000000000000..a04fc583393c8ea35cb9298a1f814a75052f2a96 --- /dev/null +++ b/GMFlow/scripts/train_gmflow.sh @@ -0,0 +1,108 @@ +#!/usr/bin/env bash + +# GMFlow without refinement + +# number of gpus for training, please set according to your hardware +# by default use all gpus on a machine +# can be trained on 4x 16GB V100 or 2x 32GB V100 or 2x 40GB A100 gpus +NUM_GPUS=4 + +# chairs +CHECKPOINT_DIR=checkpoints/chairs-gmflow && \ +mkdir -p ${CHECKPOINT_DIR} && \ +python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \ +--launcher pytorch \ +--checkpoint_dir ${CHECKPOINT_DIR} \ +--batch_size 16 \ +--val_dataset chairs sintel kitti \ +--lr 4e-4 \ +--image_size 384 512 \ +--padding_factor 16 \ +--upsample_factor 8 \ +--with_speed_metric \ +--val_freq 10000 \ +--save_ckpt_freq 10000 \ +--num_steps 100000 \ +2>&1 | tee -a ${CHECKPOINT_DIR}/train.log + +# things (our final model is trained for 800K iterations, for ablation study, you can train for 200K) +CHECKPOINT_DIR=checkpoints/things-gmflow && \ +mkdir -p ${CHECKPOINT_DIR} && \ +python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \ +--launcher pytorch \ +--checkpoint_dir ${CHECKPOINT_DIR} \ +--resume checkpoints/chairs-gmflow/step_100000.pth \ +--stage things \ +--batch_size 8 \ +--val_dataset things sintel kitti \ +--lr 2e-4 \ +--image_size 384 768 \ +--padding_factor 16 \ +--upsample_factor 8 \ +--with_speed_metric \ +--val_freq 40000 \ +--save_ckpt_freq 50000 \ +--num_steps 800000 \ +2>&1 | tee -a ${CHECKPOINT_DIR}/train.log + +# sintel +CHECKPOINT_DIR=checkpoints/sintel-gmflow && \ +mkdir -p ${CHECKPOINT_DIR} && \ +python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \ +--launcher pytorch \ +--checkpoint_dir ${CHECKPOINT_DIR} \ +--resume checkpoints/things-gmflow/step_800000.pth \ +--stage sintel \ +--batch_size 8 \ +--val_dataset sintel kitti \ +--lr 2e-4 \ +--image_size 320 896 \ +--padding_factor 16 \ +--upsample_factor 8 \ +--with_speed_metric \ +--val_freq 20000 \ +--save_ckpt_freq 20000 \ +--num_steps 200000 \ +2>&1 | tee -a ${CHECKPOINT_DIR}/train.log + +# kitti +CHECKPOINT_DIR=checkpoints/kitti-gmflow && \ +mkdir -p ${CHECKPOINT_DIR} && \ +python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \ +--launcher pytorch \ +--checkpoint_dir ${CHECKPOINT_DIR} \ +--resume checkpoints/sintel-gmflow/step_200000.pth \ +--stage kitti \ +--batch_size 8 \ +--val_dataset kitti \ +--lr 2e-4 \ +--image_size 320 1152 \ +--padding_factor 16 \ +--upsample_factor 8 \ +--with_speed_metric \ +--val_freq 10000 \ +--save_ckpt_freq 10000 \ +--num_steps 100000 \ +2>&1 | tee -a ${CHECKPOINT_DIR}/train.log + + +# a final note: if your training is terminated unexpectedly, you can resume from the latest checkpoint +# an example: resume chairs training +# CHECKPOINT_DIR=checkpoints/chairs-gmflow && \ +# mkdir -p ${CHECKPOINT_DIR} && \ +# python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \ +# --launcher pytorch \ +# --checkpoint_dir ${CHECKPOINT_DIR} \ +# --resume checkpoints/chairs-gmflow/checkpoint_latest.pth \ +# --batch_size 16 \ +# --val_dataset chairs sintel kitti \ +# --lr 4e-4 \ +# --image_size 384 512 \ +# --padding_factor 16 \ +# --upsample_factor 8 \ +# --with_speed_metric \ +# --val_freq 10000 \ +# --save_ckpt_freq 10000 \ +# --num_steps 100000 \ +# 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log + diff --git a/GMFlow/scripts/train_gmflow_with_refine.sh b/GMFlow/scripts/train_gmflow_with_refine.sh new file mode 100755 index 0000000000000000000000000000000000000000..db8ed3d423fd2993cb3c25a171fc52abe2c4c792 --- /dev/null +++ b/GMFlow/scripts/train_gmflow_with_refine.sh @@ -0,0 +1,128 @@ +#!/usr/bin/env bash + +# GMFlow with refinement + +# number of gpus for training, please set according to your hardware +# by default use all gpus on a machine +# can be trained on 4x 32G V100 or 4x 40GB A100 or 8x 16G V100 gpus +NUM_GPUS=4 + +# chairs +CHECKPOINT_DIR=checkpoints/chairs-gmflow_with_refine && \ +mkdir -p ${CHECKPOINT_DIR} && \ +python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \ +--launcher pytorch \ +--checkpoint_dir ${CHECKPOINT_DIR} \ +--batch_size 16 \ +--val_dataset chairs sintel kitti \ +--lr 4e-4 \ +--image_size 384 512 \ +--padding_factor 32 \ +--upsample_factor 4 \ +--num_scales 2 \ +--attn_splits_list 2 8 \ +--corr_radius_list -1 4 \ +--prop_radius_list -1 1 \ +--with_speed_metric \ +--val_freq 10000 \ +--save_ckpt_freq 10000 \ +--num_steps 100000 \ +2>&1 | tee -a ${CHECKPOINT_DIR}/train.log + +# things (our final model is trained for 800K iterations, for ablation study, you can train for 200K) +CHECKPOINT_DIR=checkpoints/things-gmflow_with_refine && \ +mkdir -p ${CHECKPOINT_DIR} && \ +python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \ +--launcher pytorch \ +--checkpoint_dir ${CHECKPOINT_DIR} \ +--resume checkpoints/chairs-gmflow_with_refine/step_100000.pth \ +--stage things \ +--batch_size 8 \ +--val_dataset things sintel kitti \ +--lr 2e-4 \ +--image_size 384 768 \ +--padding_factor 32 \ +--upsample_factor 4 \ +--num_scales 2 \ +--attn_splits_list 2 8 \ +--corr_radius_list -1 4 \ +--prop_radius_list -1 1 \ +--with_speed_metric \ +--val_freq 40000 \ +--save_ckpt_freq 50000 \ +--num_steps 800000 \ +2>&1 | tee -a ${CHECKPOINT_DIR}/train.log + +# sintel +CHECKPOINT_DIR=checkpoints/sintel-gmflow_with_refine && \ +mkdir -p ${CHECKPOINT_DIR} && \ +python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \ +--launcher pytorch \ +--checkpoint_dir ${CHECKPOINT_DIR} \ +--resume checkpoints/things-gmflow_with_refine/step_800000.pth \ +--stage sintel \ +--batch_size 8 \ +--val_dataset sintel kitti \ +--lr 2e-4 \ +--image_size 320 896 \ +--padding_factor 32 \ +--upsample_factor 4 \ +--num_scales 2 \ +--attn_splits_list 2 8 \ +--corr_radius_list -1 4 \ +--prop_radius_list -1 1 \ +--with_speed_metric \ +--val_freq 20000 \ +--save_ckpt_freq 20000 \ +--num_steps 200000 \ +2>&1 | tee -a ${CHECKPOINT_DIR}/train.log + +# kitti +CHECKPOINT_DIR=checkpoints/kitti-gmflow_with_refine && \ +mkdir -p ${CHECKPOINT_DIR} && \ +python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \ +--launcher pytorch \ +--checkpoint_dir ${CHECKPOINT_DIR} \ +--resume checkpoints/sintel-gmflow_with_refine/step_200000.pth \ +--stage kitti \ +--batch_size 8 \ +--val_dataset kitti \ +--lr 2e-4 \ +--image_size 320 1152 \ +--padding_factor 32 \ +--upsample_factor 4 \ +--num_scales 2 \ +--attn_splits_list 2 8 \ +--corr_radius_list -1 4 \ +--prop_radius_list -1 1 \ +--with_speed_metric \ +--val_freq 10000 \ +--save_ckpt_freq 10000 \ +--num_steps 100000 \ +2>&1 | tee -a ${CHECKPOINT_DIR}/train.log + + + +# a final note: if your training is terminated unexpectedly, you can resume from the latest checkpoint +# an example: resume chairs training +# CHECKPOINT_DIR=checkpoints/chairs-gmflow_with_refine && \ +# mkdir -p ${CHECKPOINT_DIR} && \ +# python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main.py \ +# --launcher pytorch \ +# --checkpoint_dir ${CHECKPOINT_DIR} \ +# --resume checkpoints/chairs-gmflow_with_refine/checkpoint_latest.pth \ +# --batch_size 16 \ +# --val_dataset chairs sintel kitti \ +# --lr 4e-4 \ +# --image_size 384 512 \ +# --padding_factor 32 \ +# --upsample_factor 4 \ +# --num_scales 2 \ +# --attn_splits_list 2 8 \ +# --corr_radius_list -1 4 \ +# --prop_radius_list -1 1 \ +# --with_speed_metric \ +# --val_freq 10000 \ +# --save_ckpt_freq 10000 \ +# --num_steps 100000 \ +# 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log diff --git a/GMFlow/utils/__pycache__/utils.cpython-310.pyc b/GMFlow/utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a734c1baf6d7f971eb31240c8192391723673b59 Binary files /dev/null and b/GMFlow/utils/__pycache__/utils.cpython-310.pyc differ diff --git a/GMFlow/utils/__pycache__/utils.cpython-38.pyc b/GMFlow/utils/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9860f55af63fffa410b2818f68de28007c8631ad Binary files /dev/null and b/GMFlow/utils/__pycache__/utils.cpython-38.pyc differ diff --git a/GMFlow/utils/__pycache__/utils.cpython-39.pyc b/GMFlow/utils/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4abe98b4e957a324b2e2455264f76b022f64909c Binary files /dev/null and b/GMFlow/utils/__pycache__/utils.cpython-39.pyc differ diff --git a/GMFlow/utils/dist_utils.py b/GMFlow/utils/dist_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..3c70f089225ad8cfb741f71809f4018c11711a72 --- /dev/null +++ b/GMFlow/utils/dist_utils.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# https://github.com/open-mmlab/mmcv/blob/7540cf73ac7e5d1e14d0ffbd9b6759e83929ecfc/mmcv/runner/dist_utils.py + +import os +import subprocess + +import torch +import torch.multiprocessing as mp +from torch import distributed as dist + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'mpi': + _init_dist_mpi(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + # TODO: use local_rank instead of rank % num_gpus + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_mpi(backend, **kwargs): + rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + # use MASTER_ADDR in the environment variable if it already exists + if 'MASTER_ADDR' not in os.environ: + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print diff --git a/GMFlow/utils/flow_viz.py b/GMFlow/utils/flow_viz.py new file mode 100755 index 0000000000000000000000000000000000000000..9b782c07841b27526ef8c9fa070b480a01545c31 --- /dev/null +++ b/GMFlow/utils/flow_viz.py @@ -0,0 +1,291 @@ +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + + +def make_colorwheel(): + ''' + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + ''' + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) + col = col + RY + # YG + colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) + colorwheel[col:col + YG, 1] = 255 + col = col + YG + # GC + colorwheel[col:col + GC, 1] = 255 + colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) + col = col + GC + # CB + colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) + colorwheel[col:col + CB, 2] = 255 + col = col + CB + # BM + colorwheel[col:col + BM, 2] = 255 + colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) + col = col + BM + # MR + colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) + colorwheel[col:col + MR, 0] = 255 + return colorwheel + + +def flow_compute_color(u, v, convert_to_bgr=False): + ''' + Applies the flow color wheel to (possibly clipped) flow components u and v. + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + :param u: np.ndarray, input horizontal flow + :param v: np.ndarray, input vertical flow + :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB + :return: + ''' + + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u) / np.pi + + fk = (a + 1) / 2 * (ncols - 1) + 1 + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 1 + f = fk - k0 + + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:, i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1 - f) * col0 + f * col1 + + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1 - col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range? + + # Note the 2-i => BGR instead of RGB + ch_idx = 2 - i if convert_to_bgr else i + flow_image[:, :, ch_idx] = np.floor(255 * col) + + return flow_image + + +def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): + ''' + Expects a two dimensional flow image of shape [H,W,2] + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + :param flow_uv: np.ndarray of shape [H,W,2] + :param clip_flow: float, maximum clipping value for flow + :return: + ''' + + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + + u = flow_uv[:, :, 0] + v = flow_uv[:, :, 1] + + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + + return flow_compute_color(u, v, convert_to_bgr) + + +UNKNOWN_FLOW_THRESH = 1e7 +SMALLFLOW = 0.0 +LARGEFLOW = 1e8 + + +def make_color_wheel(): + """ + Generate color wheel according Middlebury color code + :return: Color wheel + """ + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + + colorwheel = np.zeros([ncols, 3]) + + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY)) + col += RY + + # YG + colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG)) + colorwheel[col:col + YG, 1] = 255 + col += YG + + # GC + colorwheel[col:col + GC, 1] = 255 + colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC)) + col += GC + + # CB + colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB)) + colorwheel[col:col + CB, 2] = 255 + col += CB + + # BM + colorwheel[col:col + BM, 2] = 255 + colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM)) + col += + BM + + # MR + colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) + colorwheel[col:col + MR, 0] = 255 + + return colorwheel + + +def compute_color(u, v): + """ + compute optical flow color map + :param u: optical flow horizontal map + :param v: optical flow vertical map + :return: optical flow in color code + """ + [h, w] = u.shape + img = np.zeros([h, w, 3]) + nanIdx = np.isnan(u) | np.isnan(v) + u[nanIdx] = 0 + v[nanIdx] = 0 + + colorwheel = make_color_wheel() + ncols = np.size(colorwheel, 0) + + rad = np.sqrt(u ** 2 + v ** 2) + + a = np.arctan2(-v, -u) / np.pi + + fk = (a + 1) / 2 * (ncols - 1) + 1 + + k0 = np.floor(fk).astype(int) + + k1 = k0 + 1 + k1[k1 == ncols + 1] = 1 + f = fk - k0 + + for i in range(0, np.size(colorwheel, 1)): + tmp = colorwheel[:, i] + col0 = tmp[k0 - 1] / 255 + col1 = tmp[k1 - 1] / 255 + col = (1 - f) * col0 + f * col1 + + idx = rad <= 1 + col[idx] = 1 - rad[idx] * (1 - col[idx]) + notidx = np.logical_not(idx) + + col[notidx] *= 0.75 + img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx))) + + return img + + +# from https://github.com/gengshan-y/VCN +def flow_to_image(flow): + """ + Convert flow into middlebury color code image + :param flow: optical flow map + :return: optical flow image in middlebury color + """ + u = flow[:, :, 0] + v = flow[:, :, 1] + + maxu = -999. + maxv = -999. + minu = 999. + minv = 999. + + idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) + u[idxUnknow] = 0 + v[idxUnknow] = 0 + + maxu = max(maxu, np.max(u)) + minu = min(minu, np.min(u)) + + maxv = max(maxv, np.max(v)) + minv = min(minv, np.min(v)) + + rad = np.sqrt(u ** 2 + v ** 2) + maxrad = max(-1, np.max(rad)) + + u = u / (maxrad + np.finfo(float).eps) + v = v / (maxrad + np.finfo(float).eps) + + img = compute_color(u, v) + + idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) + img[idx] = 0 + + return np.uint8(img) + + +def save_vis_flow_tofile(flow, output_path): + vis_flow = flow_to_image(flow) + from PIL import Image + img = Image.fromarray(vis_flow) + img.save(output_path) + + +def flow_tensor_to_image(flow): + """Used for tensorboard visualization""" + flow = flow.permute(1, 2, 0) # [H, W, 2] + flow = flow.detach().cpu().numpy() + flow = flow_to_image(flow) # [H, W, 3] + flow = np.transpose(flow, (2, 0, 1)) # [3, H, W] + + return flow diff --git a/GMFlow/utils/frame_utils.py b/GMFlow/utils/frame_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..9005ed1d2005d25456e620c467a5d688e8c0a783 --- /dev/null +++ b/GMFlow/utils/frame_utils.py @@ -0,0 +1,131 @@ +import numpy as np +from PIL import Image +from os.path import * +import re +import cv2 + +TAG_CHAR = np.array([202021.25], np.float32) + + +def readFlow(fn): + """ Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) + # Reshape testdata into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + + +def writeFlow(filename, uv, v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + nBands = 2 + + if v is None: + assert (uv.ndim == 3) + assert (uv.shape[2] == 2) + u = uv[:, :, 0] + v = uv[:, :, 1] + else: + u = uv + + assert (u.shape == v.shape) + height, width = u.shape + f = open(filename, 'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width * nBands)) + tmp[:, np.arange(width) * 2] = u + tmp[:, np.arange(width) * 2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) + flow = flow[:, :, ::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2 ** 15) / 64.0 + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2 ** 15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + return Image.open(file_name) + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return readFlow(file_name).astype(np.float32) + elif ext == '.pfm': + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] diff --git a/GMFlow/utils/logger.py b/GMFlow/utils/logger.py new file mode 100755 index 0000000000000000000000000000000000000000..07ab133fb0d4bbc9e84918fd276eb429f06d730a --- /dev/null +++ b/GMFlow/utils/logger.py @@ -0,0 +1,68 @@ +import torch + +from utils.flow_viz import flow_tensor_to_image + + +class Logger: + def __init__(self, lr_scheduler, + summary_writer, + summary_freq=100, + start_step=0, + ): + self.lr_scheduler = lr_scheduler + self.total_steps = start_step + self.running_loss = {} + self.summary_writer = summary_writer + self.summary_freq = summary_freq + + def print_training_status(self, mode='train'): + + print('step: %06d \t epe: %.3f' % (self.total_steps, self.running_loss['epe'] / self.summary_freq)) + + for k in self.running_loss: + self.summary_writer.add_scalar(mode + '/' + k, + self.running_loss[k] / self.summary_freq, self.total_steps) + self.running_loss[k] = 0.0 + + def lr_summary(self): + lr = self.lr_scheduler.get_last_lr()[0] + self.summary_writer.add_scalar('lr', lr, self.total_steps) + + def add_image_summary(self, img1, img2, flow_preds, flow_gt, mode='train', + ): + if self.total_steps % self.summary_freq == 0: + img_concat = torch.cat((img1[0].detach().cpu(), img2[0].detach().cpu()), dim=-1) + img_concat = img_concat.type(torch.uint8) # convert to uint8 to visualize in tensorboard + + flow_pred = flow_tensor_to_image(flow_preds[-1][0]) + forward_flow_gt = flow_tensor_to_image(flow_gt[0]) + flow_concat = torch.cat((torch.from_numpy(flow_pred), + torch.from_numpy(forward_flow_gt)), dim=-1) + + concat = torch.cat((img_concat, flow_concat), dim=-2) + + self.summary_writer.add_image(mode + '/img_pred_gt', concat, self.total_steps) + + def push(self, metrics, mode='train'): + self.total_steps += 1 + + self.lr_summary() + + for key in metrics: + if key not in self.running_loss: + self.running_loss[key] = 0.0 + + self.running_loss[key] += metrics[key] + + if self.total_steps % self.summary_freq == 0: + self.print_training_status(mode) + self.running_loss = {} + + def write_dict(self, results): + for key in results: + tag = key.split('_')[0] + tag = tag + '/' + key + self.summary_writer.add_scalar(tag, results[key], self.total_steps) + + def close(self): + self.summary_writer.close() diff --git a/GMFlow/utils/misc.py b/GMFlow/utils/misc.py new file mode 100755 index 0000000000000000000000000000000000000000..c2de906d8181e9e24d2f51e0be03a19c04960d06 --- /dev/null +++ b/GMFlow/utils/misc.py @@ -0,0 +1,42 @@ +import os +import numpy as np +import sys +import json + + +def read_text_lines(filepath): + with open(filepath, 'r') as f: + lines = f.readlines() + lines = [l.rstrip() for l in lines] + return lines + + +def check_path(path): + if not os.path.exists(path): + os.makedirs(path, exist_ok=True) # explicitly set exist_ok when multi-processing + + +def save_command(save_path, filename='command_train.txt'): + check_path(save_path) + command = sys.argv + save_file = os.path.join(save_path, filename) + # Save all training commands when resuming training + with open(save_file, 'a') as f: + f.write(' '.join(command)) + f.write('\n\n') + + +def save_args(args, filename='args.json'): + args_dict = vars(args) + check_path(args.checkpoint_dir) + save_path = os.path.join(args.checkpoint_dir, filename) + + # Save all training args when resuming training + with open(save_path, 'a') as f: + json.dump(args_dict, f, indent=4, sort_keys=False) + f.write('\n\n') + + +def int_list(s): + """Convert string to int list""" + return [int(x) for x in s.split(',')] diff --git a/GMFlow/utils/utils.py b/GMFlow/utils/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..76f5518b7e5b769527907b31a1c1c00ba6cfe4f1 --- /dev/null +++ b/GMFlow/utils/utils.py @@ -0,0 +1,58 @@ +import torch +import torch.nn.functional as F + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + + def __init__(self, dims, mode='sintel', padding_factor=8): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor + pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor + if mode == 'sintel': + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] + else: + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + + +def coords_grid(batch, ht, wd, normalize=False): + if normalize: # [-1, 1] + coords = torch.meshgrid(2 * torch.arange(ht) / (ht - 1) - 1, + 2 * torch.arange(wd) / (wd - 1) - 1) + else: + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) # [B, 2, H, W] + + +def compute_out_of_boundary_mask(flow): + # flow: [B, 2, H, W] + assert flow.dim() == 4 and flow.size(1) == 2 + b, _, h, w = flow.shape + init_coords = coords_grid(b, h, w).to(flow.device) + corres = init_coords + flow # [B, 2, H, W] + + max_w = w - 1 + max_h = h - 1 + + valid_mask = (corres[:, 0] >= 0) & (corres[:, 0] <= max_w) & (corres[:, 1] >= 0) & (corres[:, 1] <= max_h) + + # in case very large flow + flow_mask = (flow[:, 0].abs() <= max_w) & (flow[:, 1].abs() <= max_h) + + valid_mask = valid_mask & flow_mask + + return valid_mask # [B, H, W] + + +def count_parameters(model): + num = sum(p.numel() for p in model.parameters() if p.requires_grad) + return num diff --git a/README.md b/README.md index 70cf0b807b476bdb1766920a0a5ee3a5db31c884..214d5315b5827d967882bb162366f5ce4901969b 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,18 @@ --- title: DiffIR2VR -emoji: 📈 -colorFrom: green -colorTo: yellow +emoji: 👌🏻 +colorFrom: purple +colorTo: pink sdk: gradio -sdk_version: 4.37.1 +sdk_version: 4.36.1 app_file: app.py pinned: false +models: + - weights/BSRNet.pth + - weights/gmflow_sintel-0c07dcb3.pth + - weights/scunet_color_real_psnr.pth + - weights/v2-1_512-ema-pruned.ckpt + - weights/v2.pth --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/__pycache__/NaRCan_model.cpython-39.pyc b/__pycache__/NaRCan_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c7437e493c007e4600f21ec1bde3818a1956655 Binary files /dev/null and b/__pycache__/NaRCan_model.cpython-39.pyc differ diff --git a/__pycache__/util.cpython-39.pyc b/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cbc740e01b804aa3d2cc8ee8684769eb033252f Binary files /dev/null and b/__pycache__/util.cpython-39.pyc differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..5c743b63c38dbbbf8361a75a66c150c6479c9e85 --- /dev/null +++ b/app.py @@ -0,0 +1,312 @@ +import os +import cv2 +import torch +import spaces +import imageio +import numpy as np +import gradio as gr +torch.jit.script = lambda f: f + +import argparse +from utils.batch_inference import ( + BSRInferenceLoop, BIDInferenceLoop +) + +# import subprocess +# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) + +device = 'cuda' if torch.cuda.is_available() else 'cpu' +def get_example(task): + case = { + "dn": [ + ['examples/bus.mp4',], + ['examples/koala.mp4',], + ['examples/flamingo.mp4',], + ['examples/rhino.mp4',], + ['examples/elephant.mp4',], + ['examples/sheep.mp4',], + ['examples/dog-agility.mp4',], + # ['examples/dog-gooses.mp4',], + ], + "sr": [ + ['examples/bus_sr.mp4',], + ['examples/koala_sr.mp4',], + ['examples/flamingo_sr.mp4',], + ['examples/rhino_sr.mp4',], + ['examples/elephant_sr.mp4',], + ['examples/sheep_sr.mp4',], + ['examples/dog-agility_sr.mp4',], + # ['examples/dog-gooses_sr.mp4',], + ] + + } + return case[task] + + + +def update_prompt(input_video): + video_name = input_video.split('/')[-1] + return set_default_prompt(video_name) + + +# Map videos to corresponding images +video_to_image = { + 'bus.mp4': ['examples_frames/bus'], + 'koala.mp4': ['examples_frames/koala'], + 'dog-gooses.mp4': ['examples_frames/dog-gooses'], + 'flamingo.mp4': ['examples_frames/flamingo'], + 'rhino.mp4': ['examples_frames/rhino'], + 'elephant.mp4': ['examples_frames/elephant'], + 'sheep.mp4': ['examples_frames/sheep'], + 'dog-agility.mp4': ['examples_frames/dog-agility'], + + 'bus_sr.mp4': ['examples_frames/bus_sr'], + 'koala_sr.mp4': ['examples_frames/koala_sr'], + 'dog-gooses_sr.mp4': ['examples_frames/dog_gooses_sr'], + 'flamingo_sr.mp4': ['examples_frames/flamingo_sr'], + 'rhino_sr.mp4': ['examples_frames/rhino_sr'], + 'elephant_sr.mp4': ['examples_frames/elephant_sr'], + 'sheep_sr.mp4': ['examples_frames/sheep_sr'], + 'dog-agility_sr.mp4': ['examples_frames/dog-agility_sr'], +} + + +def images_to_video(image_list, output_path, fps=10): + # Convert PIL Images to numpy arrays + frames = [np.array(img).astype(np.uint8) for img in image_list] + frames = frames[:20] + + # Create video writer + writer = imageio.get_writer(output_path, fps=fps, codec='libx264') + + for frame in frames: + writer.append_data(frame) + + writer.close() + + + +@spaces.GPU(duration=120) +def DiffBIR_restore(input_video, prompt, sr_ratio, n_frames, n_steps, guidance_scale, seed, n_prompt, task): + + video_name = input_video.split('/')[-1] + if video_name in video_to_image: + frames_path = video_to_image[video_name][0] + else: + return None + print(f"[INFO] input_video: {input_video}") + print(f"[INFO] Frames path: {frames_path}") + args = argparse.Namespace() + + # args.task = True, choices=["sr", "dn", "fr", "fr_bg"] + args.task = task + args.upscale = sr_ratio + + ### sampling parameters + args.steps = n_steps + args.better_start = True + args.tiled = False + args.tile_size = 512 + args.tile_stride = 256 + args.pos_prompt = prompt + args.neg_prompt = n_prompt + args.cfg_scale = guidance_scale + ### input parameters + args.input = frames_path + args.n_samples = 1 + args.batch_size = 10 + args.final_size = (480, 854) + args.config = "configs/inference/my_cldm.yaml" + ### guidance parameters + args.guidance = False + args.g_loss = "w_mse" + args.g_scale = 0.0 + args.g_start = 1001 + args.g_stop = -1 + args.g_space = "latent" + args.g_repeat = 1 + ### output parameters + args.output = " " + ### common parameters + args.seed = seed + args.device = "cuda" + + args.n_frames = n_frames + ### latent control parameters + args.warp_period = [0, 0.1] + args.merge_period = [0, 0] + args.ToMe_period = [0, 1] + args.merge_ratio = [0.6, 0] + + if args.task == "sr": + restored_vid_path = BSRInferenceLoop(args).run() + elif args.task == "dn": + restored_vid_path = BIDInferenceLoop(args).run() + + torch.cuda.empty_cache() + return restored_vid_path + +######## +# demo # +######## + + +intro = """ +
+

+ DiffIR2VR - Zero-Shot Video Restoration +

+[Project page] [arXiv] +
Note that this page is a limited demo of DiffIR2VR. For more configurations, please visit our GitHub page. The code will be released soon!
+
+""" + + +with gr.Blocks(css="style.css") as demo: + + gr.HTML(intro) + + + with gr.Tab(label="Super-resolution with DiffBIR"): + with gr.Row(): + input_video = gr.Video(label="Input Video") + output_video = gr.Video(label="Restored Video", interactive=False) + + with gr.Row(): + run_button = gr.Button("Restore your video !", visible=True) + + with gr.Accordion('Advanced options', open=False): + prompt = gr.Textbox( + label="Prompt", + max_lines=1, + placeholder="describe your video content" + # value="bear, Van Gogh Style" + ) + sr_ratio = gr.Slider(label='SR ratio', + minimum=1, + maximum=16, + value=4, + step=1) + n_frames = gr.Slider(label='Frames', + minimum=1, + maximum=60, + value=10, + step=1) + n_steps = gr.Slider(label='Steps', + minimum=1, + maximum=100, + value=10, + step=1) + guidance_scale = gr.Slider(label='Guidance Scale', + minimum=0.1, + maximum=30.0, + value=4.0, + step=0.1) + seed = gr.Slider(label='Seed', + minimum=-1, + maximum=1000, + step=1, + randomize=True) + n_prompt = gr.Textbox( + label='Negative Prompt', + value="low quality, blurry, low-resolution, noisy, unsharp, weird textures" + ) + task = gr.Textbox(value="sr", visible=False) + # input_video.change( + # fn = update_prompt, + # inputs = [input_video], + # outputs = [prompt], + # queue = False) + + run_button.click(fn = DiffBIR_restore, + inputs = [input_video, + prompt, + sr_ratio, + n_frames, + n_steps, + guidance_scale, + seed, + n_prompt, + task + ], + outputs = [output_video] + ) + gr.Examples( + examples=get_example("sr"), + label='Examples', + inputs=[input_video], + outputs=[output_video], + examples_per_page=7 + ) + + with gr.Tab(label="Denoise with DiffBIR"): + with gr.Row(): + input_video = gr.Video(label="Input Video") + output_video = gr.Video(label="Restored Video", interactive=False) + + with gr.Row(): + run_button = gr.Button("Restore your video !", visible=True) + + with gr.Accordion('Advanced options', open=False): + prompt = gr.Textbox( + label="Prompt", + max_lines=1, + placeholder="describe your video content" + # value="bear, Van Gogh Style" + ) + n_frames = gr.Slider(label='Frames', + minimum=1, + maximum=60, + value=10, + step=1) + n_steps = gr.Slider(label='Steps', + minimum=1, + maximum=100, + value=10, + step=1) + guidance_scale = gr.Slider(label='Guidance Scale', + minimum=0.1, + maximum=30.0, + value=4.0, + step=0.1) + seed = gr.Slider(label='Seed', + minimum=-1, + maximum=1000, + step=1, + randomize=True) + n_prompt = gr.Textbox( + label='Negative Prompt', + value="low quality, blurry, low-resolution, noisy, unsharp, weird textures" + ) + task = gr.Textbox(value="dn", visible=False) + sr_ratio = gr.Number(value=1, visible=False) + + # input_video.change( + # fn = update_prompt, + # inputs = [input_video], + # outputs = [prompt], + # queue = False) + run_button.click(fn = DiffBIR_restore, + inputs = [input_video, + prompt, + sr_ratio, + n_frames, + n_steps, + guidance_scale, + seed, + n_prompt, + task + ], + outputs = [output_video] + ) + gr.Examples( + examples=get_example("dn"), + label='Examples', + inputs=[input_video], + outputs=[output_video], + examples_per_page=7 + ) + +demo.queue() + +demo.launch(share=True) \ No newline at end of file diff --git a/configs/inference/bsrnet.yaml b/configs/inference/bsrnet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eac1177b6d69b8f2b0724944768df06da2367c86 --- /dev/null +++ b/configs/inference/bsrnet.yaml @@ -0,0 +1,8 @@ +target: model.RRDBNet +params: + in_nc: 3 + out_nc: 3 + nf: 64 + nb: 23 + gc: 32 + sf: 4 diff --git a/configs/inference/cldm.yaml b/configs/inference/cldm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..48f2592ce82f1ddddc549f0dda3c1326d9ce7e60 --- /dev/null +++ b/configs/inference/cldm.yaml @@ -0,0 +1,65 @@ +target: model.ControlLDM +params: + latent_scale_factor: 0.18215 + unet_cfg: + use_checkpoint: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + vae_cfg: + embed_dim: 4 + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + clip_cfg: + embed_dim: 1024 + vision_cfg: + image_size: 224 + layers: 32 + width: 1280 + head_width: 80 + patch_size: 14 + text_cfg: + context_length: 77 + vocab_size: 49408 + width: 1024 + heads: 16 + layers: 24 + layer: "penultimate" + controlnet_cfg: + use_checkpoint: True + image_size: 32 # unused + in_channels: 4 + hint_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False diff --git a/configs/inference/diffusion.yaml b/configs/inference/diffusion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5541cfd5b47eba89fa15ab3c8fac596b7d47570d --- /dev/null +++ b/configs/inference/diffusion.yaml @@ -0,0 +1,5 @@ +target: model.Diffusion +params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 1000 diff --git a/configs/inference/my_cldm.yaml b/configs/inference/my_cldm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a07b8ecda8f3e681ace2833cfca9491c4f99515d --- /dev/null +++ b/configs/inference/my_cldm.yaml @@ -0,0 +1,88 @@ + + +target: model.ControlLDM +params: + latent_warp_cfg: + latent_control: True + # interval: 5 + # x0_strength: 1 + warp_period: [0, 0.1] + merge_period: [0, 0] + cross_period: [0, 0] + mask_period: [0, 0] + ada_period: [0, 0] + + VidToMe_cfg: + flow_merge: False + ToMe_period: [0, 1] + merge_ratio: [0.9, 0] + merge_global: False + global_merge_ratio: 0.3 + seed: 123 + batch_size: 1 + align_batch: False + global_rand: 0.1 + + latent_scale_factor: 0.18215 + unet_cfg: + use_checkpoint: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + vae_cfg: + embed_dim: 4 + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + clip_cfg: + embed_dim: 1024 + vision_cfg: + image_size: 224 + layers: 32 + width: 1280 + head_width: 80 + patch_size: 14 + text_cfg: + context_length: 77 + vocab_size: 49408 + width: 1024 + heads: 16 + layers: 24 + layer: "penultimate" + controlnet_cfg: + use_checkpoint: True + image_size: 32 # unused + in_channels: 4 + hint_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False diff --git a/configs/inference/scunet.yaml b/configs/inference/scunet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a3baa69f0357e51f483cb1b1dd213b1cda85f456 --- /dev/null +++ b/configs/inference/scunet.yaml @@ -0,0 +1,5 @@ +target: model.SCUNet +params: + in_nc: 3 + config: [4,4,4,4,4,4,4] + dim: 64 diff --git a/configs/inference/swinir.yaml b/configs/inference/swinir.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0d4a9354c06e20a7c863c04110d892d78fc15aaf --- /dev/null +++ b/configs/inference/swinir.yaml @@ -0,0 +1,16 @@ +target: model.SwinIR +params: + img_size: 64 + patch_size: 1 + in_chans: 3 + embed_dim: 180 + depths: [6, 6, 6, 6, 6, 6, 6, 6] + num_heads: [6, 6, 6, 6, 6, 6, 6, 6] + window_size: 8 + mlp_ratio: 2 + sf: 8 + img_range: 1.0 + upsampler: "nearest+conv" + resi_connection: "1conv" + unshuffle: True + unshuffle_scale: 8 diff --git a/controller/__pycache__/controller.cpython-310.pyc b/controller/__pycache__/controller.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ede9bc88f894de2b5d1289624b04b1579fcec1e8 Binary files /dev/null and b/controller/__pycache__/controller.cpython-310.pyc differ diff --git a/controller/__pycache__/controller.cpython-39.pyc b/controller/__pycache__/controller.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66d3a3185a49ff871561d89bff9675ff3daf63a3 Binary files /dev/null and b/controller/__pycache__/controller.cpython-39.pyc differ diff --git a/controller/controller.py b/controller/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..80ebb9f2f70741a17c682f58c8396ed84cbe2889 --- /dev/null +++ b/controller/controller.py @@ -0,0 +1,415 @@ +import gc + +import torch +import torch.nn.functional as F + +from einops import repeat, rearrange +from vidtome import merge +from utils.flow_utils import flow_warp, coords_grid + +# AdaIn + + +def calc_mean_std(feat, eps=1e-5): + # eps is a small value added to the variance to avoid divide-by-zero. + size = feat.size() + assert (len(size) == 4) + N, C = size[:2] + feat_var = feat.view(N, C, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(N, C, 1, 1) + feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) + return feat_mean, feat_std + + +class AttentionControl(): + + def __init__(self, + warp_period=(0.0, 0.0), + merge_period=(0.0, 0.0), + merge_ratio=(0.3, 0.3), + ToMe_period=(0.0, 1.0), + mask_period=(0.0, 0.0), + cross_period=(0.0, 0.0), + ada_period=(0.0, 0.0), + inner_strength=1.0, + loose_cfatnn=False, + flow_merge=True, + ): + + self.cur_frame_idx = 0 + + self.step_store = self.get_empty_store() + self.cur_step = 0 + self.total_step = 0 + self.cur_index = 0 + self.init_store = False + self.restore = False + self.update = False + self.flow = None + self.mask = None + self.cldm = None + self.decoded_imgs = [] + self.restorex0 = True + self.updatex0 = False + self.inner_strength = inner_strength + self.cross_period = cross_period + self.mask_period = mask_period + self.ada_period = ada_period + self.warp_period = warp_period + self.ToMe_period = ToMe_period + self.merge_period = merge_period + self.merge_ratio = merge_ratio + self.keyframe_idx = 0 + self.flow_merge = flow_merge + self.distances = {} + self.flow_correspondence = {} + self.non_pad_ratio = (1.0, 1.0) + self.up_resolution = 1280 if loose_cfatnn else 1281 + + @staticmethod + def get_empty_store(): + return { + 'first': [], + 'previous': [], + 'x0_previous': [], + 'first_ada': [], + 'pre_x0': [], + "pre_keyframe_lq": None, + "flows": None, + "occ_masks": None, + "flow_confids": None, + "merge": None, + "unmerge": None, + "corres_scores": None, + "flows2": None, + "flow_confids2": None, + } + + def forward(self, context, is_cross: bool, place_in_unet: str): + cross_period = (self.total_step * self.cross_period[0], + self.total_step * self.cross_period[1]) + if not is_cross and place_in_unet == 'up' and context.shape[ + 2] < self.up_resolution: + if self.init_store: + self.step_store['first'].append(context.detach()) + self.step_store['previous'].append(context.detach()) + if self.update: + tmp = context.clone().detach() + if self.restore and self.cur_step >= cross_period[0] and \ + self.cur_step <= cross_period[1]: + # context = torch.cat( + # (self.step_store['first'][self.cur_index], + # self.step_store['previous'][self.cur_index]), + # dim=1).clone() + context = self.step_store['previous'][self.cur_index].clone() + if self.update: + self.step_store['previous'][self.cur_index] = tmp + self.cur_index += 1 + # print(is_cross, place_in_unet, context.shape[2]) + # import ipdb; ipdb.set_trace() + return context + + def update_x0(self, x0, cur_frame=0): + # if self.init_store: + # self.step_store['x0_previous'].append(x0.detach()) + # style_mean, style_std = calc_mean_std(x0.detach()) + # self.step_store['first_ada'].append(style_mean.detach()) + # self.step_store['first_ada'].append(style_std.detach()) + # if self.updatex0: + # tmp = x0.clone().detach() + if self.restorex0: + # if self.cur_step >= self.total_step * self.ada_period[ + # 0] and self.cur_step <= self.total_step * self.ada_period[ + # 1]: + # x0 = F.instance_norm(x0) * self.step_store['first_ada'][ + # 2 * self.cur_step + + # 1] + self.step_store['first_ada'][2 * self.cur_step] + if self.cur_step >= self.total_step * self.warp_period[ + 0] and self.cur_step < int(self.total_step * self.warp_period[1]): + + # mid_x = repeat(x[mid][None], 'b c h w -> (repeat b) c h w', repeat=x.shape[0]) + mid = x0.shape[0] // 2 + if len(self.step_store["pre_x0"]) == int(self.total_step * self.warp_period[1]): + print(f"[INFO] keyframe latent warping @ step {self.cur_step}...") + x0[mid] = (1 - self.step_store["occ_masks"][mid]) * x0[mid] + \ + flow_warp(self.step_store["pre_x0"][self.cur_step][None], self.step_store["flows"][mid], mode='nearest')[0] * self.step_store["occ_masks"][mid] + + print(f"[INFO] local latent warping @ step {self.cur_step}...") + for i in range(x0.shape[0]): + if i == mid: + continue + x0[i] = (1 - self.step_store["occ_masks"][i]) * x0[i] + \ + flow_warp(x0[mid][None], self.step_store["flows"][i], mode='nearest')[0] * self.step_store["occ_masks"][i] + # x = rearrange(x, 'b c h w -> b (h w) c', h=64) + # self.step_store['x0_previous'][self.cur_step] = tmp + # print(f"[INFO] storeing {self.cur_frame_idx} th frame x0 for step {self.cur_step}...") + if len(self.step_store["pre_x0"]) < int(self.total_step * self.warp_period[1]): + self.step_store['pre_x0'].append(x0[mid]) + else: + self.step_store['pre_x0'][self.cur_step] = x0[mid] + + return x0 + + def merge_x0(self, x0, merge_ratio): + # print(f"[INFO] {self.total_step * self.merge_period[0]} {self.cur_step} {int(self.total_step * self.merge_period[1])} ...") + if self.cur_step >= self.total_step * self.merge_period[0] and \ + self.cur_step < int(self.total_step * self.merge_period[1]): + print(f"[INFO] latent merging @ step {self.cur_step}...") + + B, C, H, W = x0.shape + non_pad_ratio_h, non_pad_ratio_w = self.non_pad_ratio + padding_size_w = W - int(W * non_pad_ratio_w) + padding_size_h = H - int(H * non_pad_ratio_h) + non_pad_w = W - padding_size_w + non_pad_h = H - padding_size_h + padding_mask = torch.zeros((H, W), device=x0.device, dtype=torch.bool) + if padding_size_w: + padding_mask[:, -padding_size_w:] = 1 + if padding_size_h: + padding_mask[-padding_size_h:, :] = 1 + padding_mask = rearrange(padding_mask, 'h w -> (h w)') + + idx_buffer = torch.arange(H*W, device=x0.device, dtype=torch.int64) + non_pad_idx = idx_buffer[None, ~padding_mask, None] + del idx_buffer, padding_mask + x0 = rearrange(x0, 'b c h w -> b (h w) c', h=H) + x_non_pad = torch.gather(x0, dim=1, index=non_pad_idx.expand(B, -1, C)) + # import ipdb; ipdb.set_trace() + # merge.visualize_correspondence(x_non_pad[0][None], x_non_pad[B//2][None], ratio=0.3, H=H, out="latent_correspondence.png") + + # m, u, ret_dict = merge.bipartite_soft_matching_randframe( + # x_non_pad, B, merge_ratio, 0, target_stride=B) + import copy + flows = copy.deepcopy(self.step_store["flows"]) + for i in range(B): + if flows[i] is not None: + flows[i] = flows[i][:, :, :non_pad_h, :non_pad_w] + # merge.visualize_flow_correspondence(x_non_pad[1][None], x_non_pad[B // 2][None], flow=flows[1], flow_confid=self.step_store["flow_confids"][1], \ + # ratio=0.8, H=H, out=f"flow_correspondence_08.png") + # import ipdb; ipdb.set_trace() + x_non_pad = rearrange(x_non_pad, 'b a c -> 1 (b a) c') + m, u, ret_dict = merge.bipartite_soft_matching_randframe( + x_non_pad, B, merge_ratio, 0, target_stride=B, + H=H, + flow=flows, + flow_confid=self.step_store["flow_confids"], + ) + x_non_pad = u(m(x_non_pad)) + # x_non_pad = self.step_store["unmerge"](self.step_store["merge"](x_non_pad)) + x_non_pad = rearrange(x_non_pad, '1 (b a) c -> b a c', b=B) + # print(torch.mean(x0[0]).item(), torch.mean(x0[1]).item(), torch.mean(x0[2]).item(), torch.mean(x0[3]).item(), torch.mean(x0[4]).item()) + # print(torch.std(x0[0]).item(), torch.std(x0[1]).item(), torch.std(x0[2]).item(), torch.std(x0[3]).item(), torch.std(x0[4]).item()) + # import ipdb; ipdb.set_trace() + x0.scatter_(dim=1, index=non_pad_idx.expand(B, -1, C), src=x_non_pad) + x0 = rearrange(x0, 'b (h w) c -> b c h w ', h=H) + # import ipdb; ipdb.set_trace() + + return x0 + + def merge_x0_scores(self, x0, merge_ratio, merge_mode="replace"): + # print(f"[INFO] {self.total_step * self.merge_period[0]} {self.cur_step} {int(self.total_step * self.merge_period[1])} ...") + # import ipdb; ipdb.set_trace() + if self.cur_step >= self.total_step * self.merge_period[0] and \ + self.cur_step < int(self.total_step * self.merge_period[1]): + print(f"[INFO] latent merging @ step {self.cur_step}...") + + B, C, H, W = x0.shape + non_pad_ratio_h, non_pad_ratio_w = self.non_pad_ratio + padding_size_w = W - int(W * non_pad_ratio_w) + padding_size_h = H - int(H * non_pad_ratio_h) + padding_mask = torch.zeros((H, W), device=x0.device, dtype=torch.bool) + if padding_size_w: + padding_mask[:, -padding_size_w:] = 1 + if padding_size_h: + padding_mask[-padding_size_h:, :] = 1 + padding_mask = rearrange(padding_mask, 'h w -> (h w)') + + idx_buffer = torch.arange(H*W, device=x0.device, dtype=torch.int64) + non_pad_idx = idx_buffer[None, ~padding_mask, None] + x0 = rearrange(x0, 'b c h w -> b (h w) c', h=H) + x_non_pad = torch.gather(x0, dim=1, index=non_pad_idx.expand(B, -1, C)) + x_non_pad_A, x_non_pad_N = x_non_pad.shape[1], x_non_pad.shape[1] * B + mid = B // 2 + + x_non_pad_ = x_non_pad.clone() + x_non_pad = rearrange(x_non_pad, 'b a c -> 1 (b a) c') + # import ipdb; ipdb.set_trace() + + idx_buffer = torch.arange(x_non_pad_N, device=x0.device, dtype=torch.int64) + randf = torch.tensor(B // 2, dtype=torch.int).to(x0.device) + # print(f"[INFO] {randf.item()} th frame as target") + dst_select = ((torch.div(idx_buffer, x_non_pad_A, rounding_mode='floor')) % B == randf).to(torch.bool) + # a_idx: src index. b_idx: dst index + a_idx = idx_buffer[None, ~dst_select, None] + b_idx = idx_buffer[None, dst_select, None] + del idx_buffer, padding_mask + num_dst = b_idx.shape[1] + # b, _, _ = x_non_pad.shape + b = 1 + src = torch.gather(x_non_pad, dim=1, index=a_idx.expand(b, x_non_pad_N - num_dst, C)) + tar = torch.gather(x_non_pad, dim=1, index=b_idx.expand(b, num_dst, C)) + # tar = x_non_pad[mid][None] + # src = torch.cat((x_non_pad[:mid], x_non_pad[mid+1:]), dim=0) + # src = rearrange(src, 'b n c -> 1 (b n) c') + # print(f"[INFO] {x_non_pad.shape} {src.shape} {tar.shape} ...") + # print(f"[INFO] maximum score {torch.max(self.step_store['corres_scores'])} ...") + flow_src_idx = self.flow_correspondence[H][0] + flow_tar_idx = self.flow_correspondence[H][1] + flow_confid = self.step_store["flow_confids"][:mid] + self.step_store["flow_confids"][mid+1:] + flow_confid = torch.cat(flow_confid, dim=0) + flow_confid = rearrange(flow_confid, 'b h w -> 1 (b h w)') + scores = F.normalize(self.step_store["corres_scores"], p=2, dim=-1) + + flow_confid -= (torch.max(flow_confid) - torch.max(scores)) + + # merge.visualize_correspondence_score(x_non_pad_[0][None], x_non_pad_[mid][None], + # score=scores[:,:x_non_pad_A], + # ratio=0.2, H=H-padding_size_h, out="latent_correspondence.png") + # import ipdb; ipdb.set_trace() + scores[:, flow_src_idx[0, :, 0], flow_tar_idx[0, :, 0]] += (flow_confid[:, flow_src_idx[0, :, 0]] * 0.3) + # merge.visualize_correspondence_score(x_non_pad_[0][None], x_non_pad_[mid][None], + # score=scores[:,:x_non_pad_A], + # ratio=0.2, H=H-padding_size_h, out="latent_correspondence_flow.png") + + # import ipdb; ipdb.set_trace() + r = min(src.shape[1], int(src.shape[1] * merge_ratio)) + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + tar_idx = torch.gather(node_idx[..., None], dim=-2, index=src_idx) + unm = torch.gather(src, dim=-2, index=unm_idx.expand(-1, -1, C)) + if merge_mode != "replace": + src = torch.gather(src, dim=-2, index=src_idx.expand(-1, -1, C)) + # In other mode such as mean, combine matched src and dst tokens. + tar = tar.scatter_reduce(-2, tar_idx.expand(-1, -1, C), + src, reduce=merge_mode, include_self=True) + # In replace mode, just cat unmerged tokens and tar tokens. Ignore src tokens. + # token = torch.cat([unm, tar], dim=1) + + # unm_len = unm_idx.shape[1] + # unm, tar = token[..., :unm_len, :], token[..., unm_len:, :] + src = torch.gather(tar, dim=-2, index=tar_idx.expand(-1, -1, C)) + # Combine back to the original shape + # x_non_pad = torch.zeros(b, x_non_pad_N, C, device=x0.device, dtype=x0.dtype) + # Scatter dst tokens + x_non_pad.scatter_(dim=-2, index=b_idx.expand(b, -1, C), src=tar) + # Scatter unmerged tokens + x_non_pad.scatter_(dim=-2, index=torch.gather(a_idx.expand(b, -1, 1), + dim=1, index=unm_idx).expand(-1, -1, C), src=unm) + # Scatter src tokens + x_non_pad.scatter_(dim=-2, index=torch.gather(a_idx.expand(b, -1, 1), + dim=1, index=src_idx).expand(-1, -1, C), src=src) + + x_non_pad = rearrange(x_non_pad, '1 (b a) c -> b a c', a=x_non_pad_A) + x0.scatter_(dim=1, index=non_pad_idx.expand(B, -1, C), src=x_non_pad) + x0 = rearrange(x0, 'b (h w) c -> b c h w ', h=H) + + return x0 + + def set_distance(self, B, H, W, radius, device): + y, x = torch.meshgrid(torch.arange(H), torch.arange(W)) + coords = torch.stack((y, x), dim=-1).float().to(device) + coords = rearrange(coords, 'h w c -> (h w) c') + + # Calculate the Euclidean distance between all pixels + distances = torch.cdist(coords, coords) + # radius = W // 30 + radius = 1 if radius == 0 else radius + # print(f"[INFO] W: {W} Radius: {radius} ") + distances //= radius + distances = torch.exp(-distances) + # distances += torch.diag_embed(torch.ones(A)).to(metric.device) + distances = repeat(distances, 'h a -> 1 (b h) a', b=B) + self.distances[H] = distances + + def set_flow_correspondence(self, B, H, W, key_idx, flow_confid, flow): + + if len(flow) != B - 1: + flow_confid = flow_confid[:key_idx] + flow_confid[key_idx+1:] + flow = flow[:key_idx] + flow[key_idx+1:] + + flow_confid = torch.cat(flow_confid, dim=0) + flow = torch.cat(flow, dim=0) + flow_confid = rearrange(flow_confid, 'b h w -> 1 (b h w)') + + edge_idx = flow_confid.argsort(dim=-1, descending=True)[..., None] + + src_idx = edge_idx[..., :, :] # Merged Tokens + + A = H * W + src_idx_tensor = src_idx[0, : ,0] + f = src_idx_tensor // A + id = src_idx_tensor % A + x = id % W + y = id // W + + # Stack the results into a 2D tensor + src_fxy = torch.stack((f, x, y), dim=1) + # import ipdb; ipdb.set_trace() + grid = coords_grid(B-1, H, W).to(flow.device) + flow # [F-1, 2, H, W] + + x = grid[src_fxy[:, 0], 0, src_fxy[:, 2], src_fxy[:, 1]].clamp(0, W-1).long() + y = grid[src_fxy[:, 0], 1, src_fxy[:, 2], src_fxy[:, 1]].clamp(0, H-1).long() + tar_xy = torch.stack((x, y), dim=1) + tar_idx = y * W + x + tar_idx = rearrange(tar_idx, ' d -> 1 d 1') + + self.flow_correspondence[H] = (src_idx, tar_idx) + + def set_merge(self, merge, unmerge): + self.step_store["merge"] = merge + self.step_store["unmerge"] = unmerge + + def set_warp(self, flows, masks, flow_confids=None): + self.step_store["flows"] = flows + self.step_store["occ_masks"] = masks + if flow_confids is not None: + self.step_store["flow_confids"] = flow_confids + + def set_warp2(self, flows, flow_confids): + self.step_store["flows2"] = flows + self.step_store["flow_confids2"] = flow_confids + + def set_pre_keyframe_lq(self, pre_keyframe_lq): + self.step_store["pre_keyframe_lq"] = pre_keyframe_lq + + def __call__(self, context, is_cross: bool, place_in_unet: str): + context = self.forward(context, is_cross, place_in_unet) + return context + + def set_cur_frame_idx(self, frame_idx): + self.cur_frame_idx = frame_idx + + def set_step(self, step): + self.cur_step = step + + def set_total_step(self, total_step): + self.total_step = total_step + self.cur_index = 0 + + def clear_store(self): + del self.step_store + torch.cuda.empty_cache() + gc.collect() + self.step_store = self.get_empty_store() + + def set_task(self, task, restore_step=1.0): + self.init_store = False + self.restore = False + self.update = False + self.cur_index = 0 + self.restore_step = restore_step + self.updatex0 = False + self.restorex0 = False + if 'initfirst' in task: + self.init_store = True + self.clear_store() + if 'updatestyle' in task: + self.update = True + if 'keepstyle' in task: + self.restore = True + if 'updatex0' in task: + self.updatex0 = True + if 'keepx0' in task: + self.restorex0 = True diff --git a/examples/bus.mp4 b/examples/bus.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..84b23a48226abdf16a6c9efdab2437cf26fa6ff8 Binary files /dev/null and b/examples/bus.mp4 differ diff --git a/examples/bus_sr.mp4 b/examples/bus_sr.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ca787959919ef946bb4fe240dc8dfc3627ccb973 Binary files /dev/null and b/examples/bus_sr.mp4 differ diff --git a/examples/dog-agility.mp4 b/examples/dog-agility.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..9377fb2a320c7587612d87e382701ee417fcded1 Binary files /dev/null and b/examples/dog-agility.mp4 differ diff --git a/examples/dog-agility_sr.mp4 b/examples/dog-agility_sr.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..42ed21bc663effd8d5b292644054939cc59cf0dc Binary files /dev/null and b/examples/dog-agility_sr.mp4 differ diff --git a/examples/dog-gooses.mp4 b/examples/dog-gooses.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a5c1aba1efb1531713f323da6c41e6d17a3b20e0 Binary files /dev/null and b/examples/dog-gooses.mp4 differ diff --git a/examples/dog-gooses_sr.mp4 b/examples/dog-gooses_sr.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..158b1ee95871085db8449da6d456495df016d0df Binary files /dev/null and b/examples/dog-gooses_sr.mp4 differ diff --git a/examples/elephant.mp4 b/examples/elephant.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5316dc3c9f64ff959a7c92f8b82646e55d4f8302 Binary files /dev/null and b/examples/elephant.mp4 differ diff --git a/examples/elephant_sr.mp4 b/examples/elephant_sr.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a2465449d41683946fcf631adb4ad020777eaa27 Binary files /dev/null and b/examples/elephant_sr.mp4 differ diff --git a/examples/flamingo.mp4 b/examples/flamingo.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..14374d87e19dd89d9379947298abd4dbd308b54b Binary files /dev/null and b/examples/flamingo.mp4 differ diff --git a/examples/flamingo_sr.mp4 b/examples/flamingo_sr.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..68d9f2b5670c77ee6b5c9e012293bbcf5bc6d4d2 Binary files /dev/null and b/examples/flamingo_sr.mp4 differ diff --git a/examples/koala.mp4 b/examples/koala.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..10c0a56fa294dd3078db6c26c0d351db46aefead Binary files /dev/null and b/examples/koala.mp4 differ diff --git a/examples/koala_sr.mp4 b/examples/koala_sr.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..5ff1a81cc5929932510dacb1f22c32de37292450 Binary files /dev/null and b/examples/koala_sr.mp4 differ diff --git a/examples/rhino.mp4 b/examples/rhino.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3c522a5e15d9d5af0b9eb23cad7a61d5a4f0678a Binary files /dev/null and b/examples/rhino.mp4 differ diff --git a/examples/rhino_sr.mp4 b/examples/rhino_sr.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..487379a56a34b352876134768fbf852b4abfeb6e Binary files /dev/null and b/examples/rhino_sr.mp4 differ diff --git a/examples_frames/bus/00000.jpg b/examples_frames/bus/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7bf9321ee93376c51483d36d43d6980b8b5f2ee5 Binary files /dev/null and b/examples_frames/bus/00000.jpg differ diff --git a/examples_frames/bus/00001.jpg b/examples_frames/bus/00001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ea3a563ac8ec4e2894de3a7408d34efd5a4d7135 Binary files /dev/null and b/examples_frames/bus/00001.jpg differ diff --git a/examples_frames/bus/00002.jpg b/examples_frames/bus/00002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fd0416c07c1854e59b28304d11bff79bcf526ddc Binary files /dev/null and b/examples_frames/bus/00002.jpg differ diff --git a/examples_frames/bus/00003.jpg b/examples_frames/bus/00003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..94f8db4de56e5f49357baff501a68ff855af9e46 Binary files /dev/null and b/examples_frames/bus/00003.jpg differ diff --git a/examples_frames/bus/00004.jpg b/examples_frames/bus/00004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5d2442b04ecb899ea841ed9c45af46ff68f03fdb Binary files /dev/null and b/examples_frames/bus/00004.jpg differ diff --git a/examples_frames/bus/00005.jpg b/examples_frames/bus/00005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e7382d8c0f4ecaa375b0099aacaf54c454ab078d Binary files /dev/null and b/examples_frames/bus/00005.jpg differ diff --git a/examples_frames/bus/00006.jpg b/examples_frames/bus/00006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b810303d9397e95520b15124798d0549607aeacc Binary files /dev/null and b/examples_frames/bus/00006.jpg differ diff --git a/examples_frames/bus/00007.jpg b/examples_frames/bus/00007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b233351197765c43742d55cdb5eac08fad521328 Binary files /dev/null and b/examples_frames/bus/00007.jpg differ diff --git a/examples_frames/bus/00008.jpg b/examples_frames/bus/00008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..23f9db309af12f9d61a3b2d900041678c3b85201 Binary files /dev/null and b/examples_frames/bus/00008.jpg differ diff --git a/examples_frames/bus/00009.jpg b/examples_frames/bus/00009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ab2369a23fec81431d9deef1fac8f60248acb84a Binary files /dev/null and b/examples_frames/bus/00009.jpg differ diff --git a/examples_frames/bus/00010.jpg b/examples_frames/bus/00010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..63f6004b12b258f260fe20d56bb41fd5fa0061c0 Binary files /dev/null and b/examples_frames/bus/00010.jpg differ diff --git a/examples_frames/bus/00011.jpg b/examples_frames/bus/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c1ff2bf124063908ed379ff0bd8740b60a5c2ac2 Binary files /dev/null and b/examples_frames/bus/00011.jpg differ diff --git a/examples_frames/bus/00012.jpg b/examples_frames/bus/00012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f4ab05be031df1f2c14eb7aab17d9b4e9515efd0 Binary files /dev/null and b/examples_frames/bus/00012.jpg differ diff --git a/examples_frames/bus/00013.jpg b/examples_frames/bus/00013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e564995592bac408d250d7858970a3a10a33c80e Binary files /dev/null and b/examples_frames/bus/00013.jpg differ diff --git a/examples_frames/bus/00014.jpg b/examples_frames/bus/00014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..68d56be0afb717763c3d3dfacda936510fdf89b0 Binary files /dev/null and b/examples_frames/bus/00014.jpg differ diff --git a/examples_frames/bus/00015.jpg b/examples_frames/bus/00015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f6011ad4d8b3cc61b265e1dab366ca920bf6eecd Binary files /dev/null and b/examples_frames/bus/00015.jpg differ diff --git a/examples_frames/bus/00016.jpg b/examples_frames/bus/00016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..aa7211c88169992a9f70b5eb2c88548c3d155f09 Binary files /dev/null and b/examples_frames/bus/00016.jpg differ diff --git a/examples_frames/bus/00017.jpg b/examples_frames/bus/00017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0ce63e6bcb8cdca9a511294c2e2d96dad730e408 Binary files /dev/null and b/examples_frames/bus/00017.jpg differ diff --git a/examples_frames/bus/00018.jpg b/examples_frames/bus/00018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..60fba7aba88e89abe6d68530c25a12a385360ec3 Binary files /dev/null and b/examples_frames/bus/00018.jpg differ diff --git a/examples_frames/bus/00019.jpg b/examples_frames/bus/00019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3ad15ce697569d9c0a18238ba7fe790915edaeb0 Binary files /dev/null and b/examples_frames/bus/00019.jpg differ diff --git a/examples_frames/bus_sr/00000.jpg b/examples_frames/bus_sr/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dfade87b77f8096a9da18efddedf57ddb827a7e6 Binary files /dev/null and b/examples_frames/bus_sr/00000.jpg differ diff --git a/examples_frames/bus_sr/00001.jpg b/examples_frames/bus_sr/00001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fbc2b7dbc1ed845c92b1abff1a3bdebf6534ac81 Binary files /dev/null and b/examples_frames/bus_sr/00001.jpg differ diff --git a/examples_frames/bus_sr/00002.jpg b/examples_frames/bus_sr/00002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6c5c3c3fa25808fffa2cb04a11be5c939fa5aa2e Binary files /dev/null and b/examples_frames/bus_sr/00002.jpg differ diff --git a/examples_frames/bus_sr/00003.jpg b/examples_frames/bus_sr/00003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..82c2857bf24836ea9f7cddf6b74c4bb4a0ee792a Binary files /dev/null and b/examples_frames/bus_sr/00003.jpg differ diff --git a/examples_frames/bus_sr/00004.jpg b/examples_frames/bus_sr/00004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ec6ab6e2723d2605d1d3fa20670e5cf8e8cae437 Binary files /dev/null and b/examples_frames/bus_sr/00004.jpg differ diff --git a/examples_frames/bus_sr/00005.jpg b/examples_frames/bus_sr/00005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6436ab6d28fec97afbaa2090c9707165adbe1c50 Binary files /dev/null and b/examples_frames/bus_sr/00005.jpg differ diff --git a/examples_frames/bus_sr/00006.jpg b/examples_frames/bus_sr/00006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e4b9a7acf7c6c6536c90be00bb5d578dfecb2b7a Binary files /dev/null and b/examples_frames/bus_sr/00006.jpg differ diff --git a/examples_frames/bus_sr/00007.jpg b/examples_frames/bus_sr/00007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b76a1d1033f39f3950b95b325330d30c92d94435 Binary files /dev/null and b/examples_frames/bus_sr/00007.jpg differ diff --git a/examples_frames/bus_sr/00008.jpg b/examples_frames/bus_sr/00008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c8535ebd8d5738d9aa2ffc208809a3402214597d Binary files /dev/null and b/examples_frames/bus_sr/00008.jpg differ diff --git a/examples_frames/bus_sr/00009.jpg b/examples_frames/bus_sr/00009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4c2f4657abaf5522ccb1a4a2a73250a55ab6031e Binary files /dev/null and b/examples_frames/bus_sr/00009.jpg differ diff --git a/examples_frames/bus_sr/00010.jpg b/examples_frames/bus_sr/00010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c37824f02fa6a34b8063ae0c6a7e5fd71eaff82c Binary files /dev/null and b/examples_frames/bus_sr/00010.jpg differ diff --git a/examples_frames/bus_sr/00011.jpg b/examples_frames/bus_sr/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..69a760b1adaf697c2bebe0d51ebaa40a1c33e111 Binary files /dev/null and b/examples_frames/bus_sr/00011.jpg differ diff --git a/examples_frames/bus_sr/00012.jpg b/examples_frames/bus_sr/00012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a6593db42a4053c44181e3613b6955bc0d44307b Binary files /dev/null and b/examples_frames/bus_sr/00012.jpg differ diff --git a/examples_frames/bus_sr/00013.jpg b/examples_frames/bus_sr/00013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..53606a3a4195d7499aa0296da9e005d9feb363a7 Binary files /dev/null and b/examples_frames/bus_sr/00013.jpg differ diff --git a/examples_frames/bus_sr/00014.jpg b/examples_frames/bus_sr/00014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f330ce17af2ad53ef297f98115f198ea888bfa5b Binary files /dev/null and b/examples_frames/bus_sr/00014.jpg differ diff --git a/examples_frames/bus_sr/00015.jpg b/examples_frames/bus_sr/00015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ce7f9a5313134a9289c246c076a98d52b5752c1b Binary files /dev/null and b/examples_frames/bus_sr/00015.jpg differ diff --git a/examples_frames/bus_sr/00016.jpg b/examples_frames/bus_sr/00016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b5f481fd9621b212c47dfa298bb53843e906c0a1 Binary files /dev/null and b/examples_frames/bus_sr/00016.jpg differ diff --git a/examples_frames/bus_sr/00017.jpg b/examples_frames/bus_sr/00017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9921f45339ba3ecd8b6413da5e772089aa4473d0 Binary files /dev/null and b/examples_frames/bus_sr/00017.jpg differ diff --git a/examples_frames/bus_sr/00018.jpg b/examples_frames/bus_sr/00018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3028fb0fb3a5451e3ed20f762aeb40d5e6b56ef1 Binary files /dev/null and b/examples_frames/bus_sr/00018.jpg differ diff --git a/examples_frames/bus_sr/00019.jpg b/examples_frames/bus_sr/00019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..67f075f671434152a5f43fe21f6a4b7be4bfbc6b Binary files /dev/null and b/examples_frames/bus_sr/00019.jpg differ diff --git a/examples_frames/dog-agility/00000.jpg b/examples_frames/dog-agility/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f42ada3735bbe3e5ca7cdb6027c9c3b64fa90f15 Binary files /dev/null and b/examples_frames/dog-agility/00000.jpg differ diff --git a/examples_frames/dog-agility/00001.jpg b/examples_frames/dog-agility/00001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6fc1c7bf675d71adb28c061edcd62bad96bc08a7 Binary files /dev/null and b/examples_frames/dog-agility/00001.jpg differ diff --git a/examples_frames/dog-agility/00002.jpg b/examples_frames/dog-agility/00002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7cbdc9e891283664f803c224d30261bf808e1557 Binary files /dev/null and b/examples_frames/dog-agility/00002.jpg differ diff --git a/examples_frames/dog-agility/00003.jpg b/examples_frames/dog-agility/00003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4a1c586183e138440e5dc3d13a564c14f379f61b Binary files /dev/null and b/examples_frames/dog-agility/00003.jpg differ diff --git a/examples_frames/dog-agility/00004.jpg b/examples_frames/dog-agility/00004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..28a741f9a019a88ee8872d1ded1cec2397bbde8b Binary files /dev/null and b/examples_frames/dog-agility/00004.jpg differ diff --git a/examples_frames/dog-agility/00005.jpg b/examples_frames/dog-agility/00005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dc4c63c9047dbf6f33192c1ea182afb5c2e3e748 Binary files /dev/null and b/examples_frames/dog-agility/00005.jpg differ diff --git a/examples_frames/dog-agility/00006.jpg b/examples_frames/dog-agility/00006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..230ab333b9105750d8cad568343c3541134d76e6 Binary files /dev/null and b/examples_frames/dog-agility/00006.jpg differ diff --git a/examples_frames/dog-agility/00007.jpg b/examples_frames/dog-agility/00007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b77ff6c92074b57dd71897e33f7d207a467678c1 Binary files /dev/null and b/examples_frames/dog-agility/00007.jpg differ diff --git a/examples_frames/dog-agility/00008.jpg b/examples_frames/dog-agility/00008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d995939b43146a00a98425126f44f2f78fffe7d5 Binary files /dev/null and b/examples_frames/dog-agility/00008.jpg differ diff --git a/examples_frames/dog-agility/00009.jpg b/examples_frames/dog-agility/00009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..507f2952113b7ab65f41d7cdf7a304c36353b498 Binary files /dev/null and b/examples_frames/dog-agility/00009.jpg differ diff --git a/examples_frames/dog-agility/00010.jpg b/examples_frames/dog-agility/00010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0fa25d9805989b5224818dec98e2a5dba47e3d62 Binary files /dev/null and b/examples_frames/dog-agility/00010.jpg differ diff --git a/examples_frames/dog-agility/00011.jpg b/examples_frames/dog-agility/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..31f9671d6b511ba0246e0c4f055674b090375beb Binary files /dev/null and b/examples_frames/dog-agility/00011.jpg differ diff --git a/examples_frames/dog-agility/00012.jpg b/examples_frames/dog-agility/00012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1441d5230b8f97bd7ded21f3730a94806a41dcce Binary files /dev/null and b/examples_frames/dog-agility/00012.jpg differ diff --git a/examples_frames/dog-agility/00013.jpg b/examples_frames/dog-agility/00013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..51b50ac2a76fce73c377b4d33457035e6cd56bc6 Binary files /dev/null and b/examples_frames/dog-agility/00013.jpg differ diff --git a/examples_frames/dog-agility/00014.jpg b/examples_frames/dog-agility/00014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a7445afddb86bfdff082b260ef8801238ef3ac4f Binary files /dev/null and b/examples_frames/dog-agility/00014.jpg differ diff --git a/examples_frames/dog-agility/00015.jpg b/examples_frames/dog-agility/00015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e7043d2341ca0eda1c6ba9b91272504592ebc48e Binary files /dev/null and b/examples_frames/dog-agility/00015.jpg differ diff --git a/examples_frames/dog-agility/00016.jpg b/examples_frames/dog-agility/00016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..be9eac4df6b0f82f46b74313f08996ce6b302fee Binary files /dev/null and b/examples_frames/dog-agility/00016.jpg differ diff --git a/examples_frames/dog-agility/00017.jpg b/examples_frames/dog-agility/00017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1d8184b17ad58f7c384413d2255f6fb0abd3f544 Binary files /dev/null and b/examples_frames/dog-agility/00017.jpg differ diff --git a/examples_frames/dog-agility/00018.jpg b/examples_frames/dog-agility/00018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..57a12a1e3aa818fb622404cb452b4de84c8b7ea4 Binary files /dev/null and b/examples_frames/dog-agility/00018.jpg differ diff --git a/examples_frames/dog-agility/00019.jpg b/examples_frames/dog-agility/00019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3c30184f785f2d06610c0eb21cb7b50d3305ccea Binary files /dev/null and b/examples_frames/dog-agility/00019.jpg differ diff --git a/examples_frames/dog-agility_sr/00000.jpg b/examples_frames/dog-agility_sr/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..27fed5900aed17180bba86e3e1c41a26c3d169df Binary files /dev/null and b/examples_frames/dog-agility_sr/00000.jpg differ diff --git a/examples_frames/dog-agility_sr/00001.jpg b/examples_frames/dog-agility_sr/00001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e126f26fcd42aeeb96cf1a4913073a36d9faef1e Binary files /dev/null and b/examples_frames/dog-agility_sr/00001.jpg differ diff --git a/examples_frames/dog-agility_sr/00002.jpg b/examples_frames/dog-agility_sr/00002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4625a0e3edfba27bcfcc2786f4e6936c68d620dd Binary files /dev/null and b/examples_frames/dog-agility_sr/00002.jpg differ diff --git a/examples_frames/dog-agility_sr/00003.jpg b/examples_frames/dog-agility_sr/00003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d229f14d11c61a5efd680eeaa73f9d56c3806df2 Binary files /dev/null and b/examples_frames/dog-agility_sr/00003.jpg differ diff --git a/examples_frames/dog-agility_sr/00004.jpg b/examples_frames/dog-agility_sr/00004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..13549cc1a6573dec855c20e0e46ad9745115fe54 Binary files /dev/null and b/examples_frames/dog-agility_sr/00004.jpg differ diff --git a/examples_frames/dog-agility_sr/00005.jpg b/examples_frames/dog-agility_sr/00005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f3baa5e239555823af021675ed782a5ac4999926 Binary files /dev/null and b/examples_frames/dog-agility_sr/00005.jpg differ diff --git a/examples_frames/dog-agility_sr/00006.jpg b/examples_frames/dog-agility_sr/00006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d5c60a62c1a310c4095a7e6a26be3c0cf295b2f9 Binary files /dev/null and b/examples_frames/dog-agility_sr/00006.jpg differ diff --git a/examples_frames/dog-agility_sr/00007.jpg b/examples_frames/dog-agility_sr/00007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..84d30f6679b34fa0ce15186a7a58ac86f5629562 Binary files /dev/null and b/examples_frames/dog-agility_sr/00007.jpg differ diff --git a/examples_frames/dog-agility_sr/00008.jpg b/examples_frames/dog-agility_sr/00008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..75f9715c247269acb3e9bddd859a68db1f9f57d3 Binary files /dev/null and b/examples_frames/dog-agility_sr/00008.jpg differ diff --git a/examples_frames/dog-agility_sr/00009.jpg b/examples_frames/dog-agility_sr/00009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a665b2f359993ff3900b5fb5c54dcafdd44f2edd Binary files /dev/null and b/examples_frames/dog-agility_sr/00009.jpg differ diff --git a/examples_frames/dog-agility_sr/00010.jpg b/examples_frames/dog-agility_sr/00010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..545d23a4268891f4cc27cb41dd7f4a1637ec575e Binary files /dev/null and b/examples_frames/dog-agility_sr/00010.jpg differ diff --git a/examples_frames/dog-agility_sr/00011.jpg b/examples_frames/dog-agility_sr/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a35b823a253a9b67a8362962339b5fd3492e70e8 Binary files /dev/null and b/examples_frames/dog-agility_sr/00011.jpg differ diff --git a/examples_frames/dog-agility_sr/00012.jpg b/examples_frames/dog-agility_sr/00012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..48ac8b6a8dd93fb22b192167aedc81a3a03b3c95 Binary files /dev/null and b/examples_frames/dog-agility_sr/00012.jpg differ diff --git a/examples_frames/dog-agility_sr/00013.jpg b/examples_frames/dog-agility_sr/00013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d28e7a0bf34f2b85bf6012ba596dcf75468c54a2 Binary files /dev/null and b/examples_frames/dog-agility_sr/00013.jpg differ diff --git a/examples_frames/dog-agility_sr/00014.jpg b/examples_frames/dog-agility_sr/00014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e860e7e96752ff9e3c22969881d2041888b8cde9 Binary files /dev/null and b/examples_frames/dog-agility_sr/00014.jpg differ diff --git a/examples_frames/dog-agility_sr/00015.jpg b/examples_frames/dog-agility_sr/00015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4cbf214a7bb7f77a4098f05953d0128ff0625a0c Binary files /dev/null and b/examples_frames/dog-agility_sr/00015.jpg differ diff --git a/examples_frames/dog-agility_sr/00016.jpg b/examples_frames/dog-agility_sr/00016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7bc24b8e49a04ba5524197644cafc071ae1a691b Binary files /dev/null and b/examples_frames/dog-agility_sr/00016.jpg differ diff --git a/examples_frames/dog-agility_sr/00017.jpg b/examples_frames/dog-agility_sr/00017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3871030a432e5f48491992f45debca3aa299c6b6 Binary files /dev/null and b/examples_frames/dog-agility_sr/00017.jpg differ diff --git a/examples_frames/dog-agility_sr/00018.jpg b/examples_frames/dog-agility_sr/00018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..85cf7dee5c5cbe97bbbbf794b60eeb29f37412e9 Binary files /dev/null and b/examples_frames/dog-agility_sr/00018.jpg differ diff --git a/examples_frames/dog-agility_sr/00019.jpg b/examples_frames/dog-agility_sr/00019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..17ab30c3c7f6d5f23f3ec9053ec68024c6d5426c Binary files /dev/null and b/examples_frames/dog-agility_sr/00019.jpg differ diff --git a/examples_frames/dog-gooses/00000.jpg b/examples_frames/dog-gooses/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..76f55848e5d4ecaf51a9d15c83f7d2f9ab6fc115 Binary files /dev/null and b/examples_frames/dog-gooses/00000.jpg differ diff --git a/examples_frames/dog-gooses/00001.jpg b/examples_frames/dog-gooses/00001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6276318e9c6ef8782640e023c47754c5f3358049 Binary files /dev/null and b/examples_frames/dog-gooses/00001.jpg differ diff --git a/examples_frames/dog-gooses/00002.jpg b/examples_frames/dog-gooses/00002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f8d858bf64c5538e86f38bdf4a9d31b23794077d Binary files /dev/null and b/examples_frames/dog-gooses/00002.jpg differ diff --git a/examples_frames/dog-gooses/00003.jpg b/examples_frames/dog-gooses/00003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b47ab3b0d61ecd68dcdd5d724328eebbc9c82391 Binary files /dev/null and b/examples_frames/dog-gooses/00003.jpg differ diff --git a/examples_frames/dog-gooses/00004.jpg b/examples_frames/dog-gooses/00004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d25d4fa46d9d7770f4ecd2794679a93eb08a8cfa Binary files /dev/null and b/examples_frames/dog-gooses/00004.jpg differ diff --git a/examples_frames/dog-gooses/00005.jpg b/examples_frames/dog-gooses/00005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a2fed2e2c7a2843a8fd47000b7a0af1f8701f4be Binary files /dev/null and b/examples_frames/dog-gooses/00005.jpg differ diff --git a/examples_frames/dog-gooses/00006.jpg b/examples_frames/dog-gooses/00006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..277f8314615d87a8d67d2b4f1612f126740d1a27 Binary files /dev/null and b/examples_frames/dog-gooses/00006.jpg differ diff --git a/examples_frames/dog-gooses/00007.jpg b/examples_frames/dog-gooses/00007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9aae600be2427671feb442d930966f9d55cead7c Binary files /dev/null and b/examples_frames/dog-gooses/00007.jpg differ diff --git a/examples_frames/dog-gooses/00008.jpg b/examples_frames/dog-gooses/00008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d25a4d8787cf060b8a488e80acb85385554525f2 Binary files /dev/null and b/examples_frames/dog-gooses/00008.jpg differ diff --git a/examples_frames/dog-gooses/00009.jpg b/examples_frames/dog-gooses/00009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5ab187086453389fa68a454579ed3c4a5264eef4 Binary files /dev/null and b/examples_frames/dog-gooses/00009.jpg differ diff --git a/examples_frames/dog-gooses/00010.jpg b/examples_frames/dog-gooses/00010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8a59a06ce57e3ce1660d8d0ed14ceff6ac1689ed Binary files /dev/null and b/examples_frames/dog-gooses/00010.jpg differ diff --git a/examples_frames/dog-gooses/00011.jpg b/examples_frames/dog-gooses/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e27e9e54810cadd1ac62a41ec235398a05cfe34d Binary files /dev/null and b/examples_frames/dog-gooses/00011.jpg differ diff --git a/examples_frames/dog-gooses/00012.jpg b/examples_frames/dog-gooses/00012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f414df659d7b9645bcd1730da0ce52410f034d0e Binary files /dev/null and b/examples_frames/dog-gooses/00012.jpg differ diff --git a/examples_frames/dog-gooses/00013.jpg b/examples_frames/dog-gooses/00013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..325f6f8341ab69e272ebb9c7f56b194fd5200ebd Binary files /dev/null and b/examples_frames/dog-gooses/00013.jpg differ diff --git a/examples_frames/dog-gooses/00014.jpg b/examples_frames/dog-gooses/00014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3cf7e43d49af1c15fbfe7a12cf67817cf67ed101 Binary files /dev/null and b/examples_frames/dog-gooses/00014.jpg differ diff --git a/examples_frames/dog-gooses/00015.jpg b/examples_frames/dog-gooses/00015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..43f345cfbf2d395d98710b2608e48bf1307c99e3 Binary files /dev/null and b/examples_frames/dog-gooses/00015.jpg differ diff --git a/examples_frames/dog-gooses/00016.jpg b/examples_frames/dog-gooses/00016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..03f2295b3326c5c025f7e1be2138eff05189f7e5 Binary files /dev/null and b/examples_frames/dog-gooses/00016.jpg differ diff --git a/examples_frames/dog-gooses/00017.jpg b/examples_frames/dog-gooses/00017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dae774a38ab8c25e9da388886f26194758162e1e Binary files /dev/null and b/examples_frames/dog-gooses/00017.jpg differ diff --git a/examples_frames/dog-gooses/00018.jpg b/examples_frames/dog-gooses/00018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fc81fc9e08ac2cfc641cde04d620702441aa231f Binary files /dev/null and b/examples_frames/dog-gooses/00018.jpg differ diff --git a/examples_frames/dog-gooses/00019.jpg b/examples_frames/dog-gooses/00019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5690d605b9ef9b4f410a0c454d24e1150c926034 Binary files /dev/null and b/examples_frames/dog-gooses/00019.jpg differ diff --git a/examples_frames/dog-gooses_sr/00000.jpg b/examples_frames/dog-gooses_sr/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a966c99af929d244e6d10f581485e38fa091c306 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00000.jpg differ diff --git a/examples_frames/dog-gooses_sr/00001.jpg b/examples_frames/dog-gooses_sr/00001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a0fc2c2e667827323a47a3fe9f1c79d725117ff8 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00001.jpg differ diff --git a/examples_frames/dog-gooses_sr/00002.jpg b/examples_frames/dog-gooses_sr/00002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ca4511519381f1c005760e580d062eba2208fbf7 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00002.jpg differ diff --git a/examples_frames/dog-gooses_sr/00003.jpg b/examples_frames/dog-gooses_sr/00003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9c7a3f07f065ab76f7cce640e5802398676833fd Binary files /dev/null and b/examples_frames/dog-gooses_sr/00003.jpg differ diff --git a/examples_frames/dog-gooses_sr/00004.jpg b/examples_frames/dog-gooses_sr/00004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bbe57548797b85a45d9aed3a40a7df4ab12d91ee Binary files /dev/null and b/examples_frames/dog-gooses_sr/00004.jpg differ diff --git a/examples_frames/dog-gooses_sr/00005.jpg b/examples_frames/dog-gooses_sr/00005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c30db749f3d4f1dfca1228d90d90a0b8c00c11df Binary files /dev/null and b/examples_frames/dog-gooses_sr/00005.jpg differ diff --git a/examples_frames/dog-gooses_sr/00006.jpg b/examples_frames/dog-gooses_sr/00006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..37652ae4289fd9c874dd8b778d50b6e76cec2d0d Binary files /dev/null and b/examples_frames/dog-gooses_sr/00006.jpg differ diff --git a/examples_frames/dog-gooses_sr/00007.jpg b/examples_frames/dog-gooses_sr/00007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bc930629297f2418360a7c2872f64221a9b18954 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00007.jpg differ diff --git a/examples_frames/dog-gooses_sr/00008.jpg b/examples_frames/dog-gooses_sr/00008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..df0394c7c495db74a45efbd2c954f4d73d4dd238 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00008.jpg differ diff --git a/examples_frames/dog-gooses_sr/00009.jpg b/examples_frames/dog-gooses_sr/00009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5b66984c42b82b84420e7dea9e96c44c53b1cc12 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00009.jpg differ diff --git a/examples_frames/dog-gooses_sr/00010.jpg b/examples_frames/dog-gooses_sr/00010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5081792bc9bbfd921cf38de37cffb8e4fe34cd92 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00010.jpg differ diff --git a/examples_frames/dog-gooses_sr/00011.jpg b/examples_frames/dog-gooses_sr/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..db0844ba2750fb7361255aee387d1f187bb09571 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00011.jpg differ diff --git a/examples_frames/dog-gooses_sr/00012.jpg b/examples_frames/dog-gooses_sr/00012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..978cb8a8b12395fb3c689f2e845f7e9733f87487 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00012.jpg differ diff --git a/examples_frames/dog-gooses_sr/00013.jpg b/examples_frames/dog-gooses_sr/00013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..41be38aa850b985b9a3134b35c07f840aa4f1ea8 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00013.jpg differ diff --git a/examples_frames/dog-gooses_sr/00014.jpg b/examples_frames/dog-gooses_sr/00014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9dcc1f743613fd54abc0d539df27320139be8da9 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00014.jpg differ diff --git a/examples_frames/dog-gooses_sr/00015.jpg b/examples_frames/dog-gooses_sr/00015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c36a3a20299ff395652791800b10ad6033900a95 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00015.jpg differ diff --git a/examples_frames/dog-gooses_sr/00016.jpg b/examples_frames/dog-gooses_sr/00016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e5ac67a75035841413b794db61c3aab495c88179 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00016.jpg differ diff --git a/examples_frames/dog-gooses_sr/00017.jpg b/examples_frames/dog-gooses_sr/00017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6e5870d3cb1e990d7d6567e7fcff17ca18fd7583 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00017.jpg differ diff --git a/examples_frames/dog-gooses_sr/00018.jpg b/examples_frames/dog-gooses_sr/00018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..770b6d21afa463e4c0fbd61fa1562e941397a53c Binary files /dev/null and b/examples_frames/dog-gooses_sr/00018.jpg differ diff --git a/examples_frames/dog-gooses_sr/00019.jpg b/examples_frames/dog-gooses_sr/00019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e38e437ad6ec2837cf2b7aa4e1473a51f52c3bd0 Binary files /dev/null and b/examples_frames/dog-gooses_sr/00019.jpg differ diff --git a/examples_frames/elephant/00000.jpg b/examples_frames/elephant/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0a74df589848617d2726cd9f8f5305f997dd3fdf Binary files /dev/null and b/examples_frames/elephant/00000.jpg differ diff --git a/examples_frames/elephant/00001.jpg b/examples_frames/elephant/00001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fdd4d98ef4b39d930b633b771baef2cfcea0bab1 Binary files /dev/null and b/examples_frames/elephant/00001.jpg differ diff --git a/examples_frames/elephant/00002.jpg b/examples_frames/elephant/00002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8cbbedaa8bcbb2bf8e15e8a0ac8b600c5b7038b0 Binary files /dev/null and b/examples_frames/elephant/00002.jpg differ diff --git a/examples_frames/elephant/00003.jpg b/examples_frames/elephant/00003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..16ecab6542094179e88c9b3332c05e73087b014d Binary files /dev/null and b/examples_frames/elephant/00003.jpg differ diff --git a/examples_frames/elephant/00004.jpg b/examples_frames/elephant/00004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f9f954daa78456846c6aa920654087b61c36b921 Binary files /dev/null and b/examples_frames/elephant/00004.jpg differ diff --git a/examples_frames/elephant/00005.jpg b/examples_frames/elephant/00005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3ac98c4299be657fe905f36a1de8fe5bac622dc7 Binary files /dev/null and b/examples_frames/elephant/00005.jpg differ diff --git a/examples_frames/elephant/00006.jpg b/examples_frames/elephant/00006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0b62bd04e0c4230bdb7a44b1a12ad2a953ce936b Binary files /dev/null and b/examples_frames/elephant/00006.jpg differ diff --git a/examples_frames/elephant/00007.jpg b/examples_frames/elephant/00007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2681858c00bacc08428bdf0bb86934b742175a9e Binary files /dev/null and b/examples_frames/elephant/00007.jpg differ diff --git a/examples_frames/elephant/00008.jpg b/examples_frames/elephant/00008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ca604d8f5ee3dbc4d9d023b6b0224a9ed186f78a Binary files /dev/null and b/examples_frames/elephant/00008.jpg differ diff --git a/examples_frames/elephant/00009.jpg b/examples_frames/elephant/00009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5f6617e2f8da574765ca093a28713d095cc89708 Binary files /dev/null and b/examples_frames/elephant/00009.jpg differ diff --git a/examples_frames/elephant/00010.jpg b/examples_frames/elephant/00010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5ee2f9aa8f0ea654a11765269843a109553343e0 Binary files /dev/null and b/examples_frames/elephant/00010.jpg differ diff --git a/examples_frames/elephant/00011.jpg b/examples_frames/elephant/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b9453f89cc676028ad4295690f53812695891a49 Binary files /dev/null and b/examples_frames/elephant/00011.jpg differ diff --git a/examples_frames/elephant/00012.jpg b/examples_frames/elephant/00012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4f56687c1612392617831774d0645e3fe32280e0 Binary files /dev/null and b/examples_frames/elephant/00012.jpg differ diff --git a/examples_frames/elephant/00013.jpg b/examples_frames/elephant/00013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3a3b9680a61b9ed8471e66d5cd952f4f6d9ff96e Binary files /dev/null and b/examples_frames/elephant/00013.jpg differ diff --git a/examples_frames/elephant/00014.jpg b/examples_frames/elephant/00014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..058d55de339b699f5fd49a010199378b1cb5a2b7 Binary files /dev/null and b/examples_frames/elephant/00014.jpg differ diff --git a/examples_frames/elephant/00015.jpg b/examples_frames/elephant/00015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e44f9588951f1b91c86fb49c0038b4c87ee1d687 Binary files /dev/null and b/examples_frames/elephant/00015.jpg differ diff --git a/examples_frames/elephant/00016.jpg b/examples_frames/elephant/00016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a73817f8bfac19080896f5a8716f60a462f0bac2 Binary files /dev/null and b/examples_frames/elephant/00016.jpg differ diff --git a/examples_frames/elephant/00017.jpg b/examples_frames/elephant/00017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5458cebab192cb3f4ea395ee82b939771c8e6f38 Binary files /dev/null and b/examples_frames/elephant/00017.jpg differ diff --git a/examples_frames/elephant/00018.jpg b/examples_frames/elephant/00018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8cc19e243732052ee3be31dddcb81656f1e466d0 Binary files /dev/null and b/examples_frames/elephant/00018.jpg differ diff --git a/examples_frames/elephant/00019.jpg b/examples_frames/elephant/00019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dc6e3490228e340ee1e26120cb708f7eaf6e94bd Binary files /dev/null and b/examples_frames/elephant/00019.jpg differ diff --git a/examples_frames/elephant_sr/00000.jpg b/examples_frames/elephant_sr/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b6dec2713dc9f338daf40c5ee64b7c75b1b00e5a Binary files /dev/null and b/examples_frames/elephant_sr/00000.jpg differ diff --git a/examples_frames/elephant_sr/00001.jpg b/examples_frames/elephant_sr/00001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..48febcd59804b8a908b3e1443e23a25418d89607 Binary files /dev/null and b/examples_frames/elephant_sr/00001.jpg differ diff --git a/examples_frames/elephant_sr/00002.jpg b/examples_frames/elephant_sr/00002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f5d89fe9ebcbdd10038d9ffffd44929b297542a9 Binary files /dev/null and b/examples_frames/elephant_sr/00002.jpg differ diff --git a/examples_frames/elephant_sr/00003.jpg b/examples_frames/elephant_sr/00003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1bedd3004f8b5c7379fb3956133913ad373b73cf Binary files /dev/null and b/examples_frames/elephant_sr/00003.jpg differ diff --git a/examples_frames/elephant_sr/00004.jpg b/examples_frames/elephant_sr/00004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..db24083db764a6456fb00a7fee2940b90b94d56e Binary files /dev/null and b/examples_frames/elephant_sr/00004.jpg differ diff --git a/examples_frames/elephant_sr/00005.jpg b/examples_frames/elephant_sr/00005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1ae1121383141324cc5fcdc310da5da69ff0858c Binary files /dev/null and b/examples_frames/elephant_sr/00005.jpg differ diff --git a/examples_frames/elephant_sr/00006.jpg b/examples_frames/elephant_sr/00006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..25f2864c6ab0857ceb5cc69fb8b15bfdb228feb6 Binary files /dev/null and b/examples_frames/elephant_sr/00006.jpg differ diff --git a/examples_frames/elephant_sr/00007.jpg b/examples_frames/elephant_sr/00007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..241ef651afb29d28598984f3ac7add2fe89761d0 Binary files /dev/null and b/examples_frames/elephant_sr/00007.jpg differ diff --git a/examples_frames/elephant_sr/00008.jpg b/examples_frames/elephant_sr/00008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..02c756f04b7f7d960ac0c5d03de7e0e44ce0eb29 Binary files /dev/null and b/examples_frames/elephant_sr/00008.jpg differ diff --git a/examples_frames/elephant_sr/00009.jpg b/examples_frames/elephant_sr/00009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7ab5f2a39ad368dd0354b3ec21633a6f8afc9709 Binary files /dev/null and b/examples_frames/elephant_sr/00009.jpg differ diff --git a/examples_frames/elephant_sr/00010.jpg b/examples_frames/elephant_sr/00010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..40beef2eb3a33f77811c2e62355b933282fb9476 Binary files /dev/null and b/examples_frames/elephant_sr/00010.jpg differ diff --git a/examples_frames/elephant_sr/00011.jpg b/examples_frames/elephant_sr/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8365f99b022bf001854166f6c9dca29c53806eab Binary files /dev/null and b/examples_frames/elephant_sr/00011.jpg differ diff --git a/examples_frames/elephant_sr/00012.jpg b/examples_frames/elephant_sr/00012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1a50228f4f838107239cc84919d6cb38d1e299f9 Binary files /dev/null and b/examples_frames/elephant_sr/00012.jpg differ diff --git a/examples_frames/elephant_sr/00013.jpg b/examples_frames/elephant_sr/00013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a79dc8bad07edbfa78e0296ed4884efd516b9709 Binary files /dev/null and b/examples_frames/elephant_sr/00013.jpg differ diff --git a/examples_frames/elephant_sr/00014.jpg b/examples_frames/elephant_sr/00014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..12f29c33d61ac82b671ef6f3e702ec010cea89b4 Binary files /dev/null and b/examples_frames/elephant_sr/00014.jpg differ diff --git a/examples_frames/elephant_sr/00015.jpg b/examples_frames/elephant_sr/00015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0d5fffd097d4f7656383681530588e97bf75fcda Binary files /dev/null and b/examples_frames/elephant_sr/00015.jpg differ diff --git a/examples_frames/elephant_sr/00016.jpg b/examples_frames/elephant_sr/00016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6262f0dc47dd64a43ea1c7820c611d7f1fbd9a0f Binary files /dev/null and b/examples_frames/elephant_sr/00016.jpg differ diff --git a/examples_frames/elephant_sr/00017.jpg b/examples_frames/elephant_sr/00017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a9f759114c7286d02c9235a446bf298ebed808a5 Binary files /dev/null and b/examples_frames/elephant_sr/00017.jpg differ diff --git a/examples_frames/elephant_sr/00018.jpg b/examples_frames/elephant_sr/00018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ad2342c0ff122dadb22a427282a0f899bcf3e856 Binary files /dev/null and b/examples_frames/elephant_sr/00018.jpg differ diff --git a/examples_frames/elephant_sr/00019.jpg b/examples_frames/elephant_sr/00019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2120732d4560e6ebcd675f628137323a74cee0b3 Binary files /dev/null and b/examples_frames/elephant_sr/00019.jpg differ diff --git a/examples_frames/flamingo/00000.jpg b/examples_frames/flamingo/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..385fdb5326b757f1914d3559ef218b8872b02b04 Binary files /dev/null and b/examples_frames/flamingo/00000.jpg differ diff --git a/examples_frames/flamingo/00001.jpg b/examples_frames/flamingo/00001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ff7c8cee9d0e76199e6057eacdeb956fd8b128fe Binary files /dev/null and b/examples_frames/flamingo/00001.jpg differ diff --git a/examples_frames/flamingo/00002.jpg b/examples_frames/flamingo/00002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ec0e716661f7821bdd037635efe7c2cc3148a1e9 Binary files /dev/null and b/examples_frames/flamingo/00002.jpg differ diff --git a/examples_frames/flamingo/00003.jpg b/examples_frames/flamingo/00003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a7b49367e379aa9a6532387a84c99dd084ea874f Binary files /dev/null and b/examples_frames/flamingo/00003.jpg differ diff --git a/examples_frames/flamingo/00004.jpg b/examples_frames/flamingo/00004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b5cc5b2c250a102a2ee7d7a935f6df00c41c0e23 Binary files /dev/null and b/examples_frames/flamingo/00004.jpg differ diff --git a/examples_frames/flamingo/00005.jpg b/examples_frames/flamingo/00005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d87d81226356a538419cc013e07f8cf584497fdd Binary files /dev/null and b/examples_frames/flamingo/00005.jpg differ diff --git a/examples_frames/flamingo/00006.jpg b/examples_frames/flamingo/00006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7591c4365d6e5cd7441be2e82902177ffcf7299d Binary files /dev/null and b/examples_frames/flamingo/00006.jpg differ diff --git a/examples_frames/flamingo/00007.jpg b/examples_frames/flamingo/00007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d8e10b35597ff25413519b82b3e54877b39af546 Binary files /dev/null and b/examples_frames/flamingo/00007.jpg differ diff --git a/examples_frames/flamingo/00008.jpg b/examples_frames/flamingo/00008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fcacfee1c19bbf8291d706848b5df96a5a79e258 Binary files /dev/null and b/examples_frames/flamingo/00008.jpg differ diff --git a/examples_frames/flamingo/00009.jpg b/examples_frames/flamingo/00009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6b54182d55a432b906537ea30413db9168654ca0 Binary files /dev/null and b/examples_frames/flamingo/00009.jpg differ diff --git a/examples_frames/flamingo/00010.jpg b/examples_frames/flamingo/00010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..99ef08b2b79644c99887b1f017ed7976b4a133f9 Binary files /dev/null and b/examples_frames/flamingo/00010.jpg differ diff --git a/examples_frames/flamingo/00011.jpg b/examples_frames/flamingo/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d62ad41902b54a03b4fceab5e94df7c4262dbf99 Binary files /dev/null and b/examples_frames/flamingo/00011.jpg differ diff --git a/examples_frames/flamingo/00012.jpg b/examples_frames/flamingo/00012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..21f2b6abb88318fc062e504a437e5435a5748add Binary files /dev/null and b/examples_frames/flamingo/00012.jpg differ diff --git a/examples_frames/flamingo/00013.jpg b/examples_frames/flamingo/00013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..464c38cbf06bab65f7d2734e8a5f53a696a9122e Binary files /dev/null and b/examples_frames/flamingo/00013.jpg differ diff --git a/examples_frames/flamingo/00014.jpg b/examples_frames/flamingo/00014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d1f38011d98723b9c9aeb7e23be29c1d25936deb Binary files /dev/null and b/examples_frames/flamingo/00014.jpg differ diff --git a/examples_frames/flamingo/00015.jpg b/examples_frames/flamingo/00015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..df68ab79390ab8687c611bde047da680e0c19484 Binary files /dev/null and b/examples_frames/flamingo/00015.jpg differ diff --git a/examples_frames/flamingo/00016.jpg b/examples_frames/flamingo/00016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..effacbf007671ea03c42759c94aa8c6646f93bb3 Binary files /dev/null and b/examples_frames/flamingo/00016.jpg differ diff --git a/examples_frames/flamingo/00017.jpg b/examples_frames/flamingo/00017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..eaeeb77afeab35b6aab2d66ac4fa43bb724ee454 Binary files /dev/null and b/examples_frames/flamingo/00017.jpg differ diff --git a/examples_frames/flamingo/00018.jpg b/examples_frames/flamingo/00018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9c2c639a38997e1a55c0d1e768505fdaf97e92a2 Binary files /dev/null and b/examples_frames/flamingo/00018.jpg differ diff --git a/examples_frames/flamingo/00019.jpg b/examples_frames/flamingo/00019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e432029ff1e3125bc6d928dc13a5d53126249b8b Binary files /dev/null and b/examples_frames/flamingo/00019.jpg differ diff --git a/examples_frames/flamingo_sr/00000.jpg b/examples_frames/flamingo_sr/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5a1b20b7c609a06fe0ae295540dfbe1821da12c3 Binary files /dev/null and b/examples_frames/flamingo_sr/00000.jpg differ diff --git a/examples_frames/flamingo_sr/00001.jpg b/examples_frames/flamingo_sr/00001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..29ab5bf40a3197e1eae8e4c99605de711bbb6692 Binary files /dev/null and b/examples_frames/flamingo_sr/00001.jpg differ diff --git a/examples_frames/flamingo_sr/00002.jpg b/examples_frames/flamingo_sr/00002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0b0469a29441ea30edb62d8353e7794e3839a346 Binary files /dev/null and b/examples_frames/flamingo_sr/00002.jpg differ diff --git a/examples_frames/flamingo_sr/00003.jpg b/examples_frames/flamingo_sr/00003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4dc612dcd6c5bff61a5780ff270b75932a2c412f Binary files /dev/null and b/examples_frames/flamingo_sr/00003.jpg differ diff --git a/examples_frames/flamingo_sr/00004.jpg b/examples_frames/flamingo_sr/00004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..617505e46be57960f2c21c8601a882512ca86f9d Binary files /dev/null and b/examples_frames/flamingo_sr/00004.jpg differ diff --git a/examples_frames/flamingo_sr/00005.jpg b/examples_frames/flamingo_sr/00005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..becf365c272186c6d4b44b362caac181adcc534e Binary files /dev/null and b/examples_frames/flamingo_sr/00005.jpg differ diff --git a/examples_frames/flamingo_sr/00006.jpg b/examples_frames/flamingo_sr/00006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e4e082234c85c84c766204694cc2f8ad7d0a06b6 Binary files /dev/null and b/examples_frames/flamingo_sr/00006.jpg differ diff --git a/examples_frames/flamingo_sr/00007.jpg b/examples_frames/flamingo_sr/00007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..eaf4836dec41051ca6dd07cca46c76faece92dc4 Binary files /dev/null and b/examples_frames/flamingo_sr/00007.jpg differ diff --git a/examples_frames/flamingo_sr/00008.jpg b/examples_frames/flamingo_sr/00008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8ddd907ba0bb309964e30e74677804b2d873c805 Binary files /dev/null and b/examples_frames/flamingo_sr/00008.jpg differ diff --git a/examples_frames/flamingo_sr/00009.jpg b/examples_frames/flamingo_sr/00009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..552f306d6eb9b036fd3035acd3f21afa3d51ab5d Binary files /dev/null and b/examples_frames/flamingo_sr/00009.jpg differ diff --git a/examples_frames/flamingo_sr/00010.jpg b/examples_frames/flamingo_sr/00010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..edc31d78ca99b05ee0f4e646ff7d46222a4e1b7a Binary files /dev/null and b/examples_frames/flamingo_sr/00010.jpg differ diff --git a/examples_frames/flamingo_sr/00011.jpg b/examples_frames/flamingo_sr/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..86b9c0fc8ddffaeda39c5d3aac74acb38795049d Binary files /dev/null and b/examples_frames/flamingo_sr/00011.jpg differ diff --git a/examples_frames/flamingo_sr/00012.jpg b/examples_frames/flamingo_sr/00012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6694c817373539c426fd28bb53543377db1b3f3e Binary files /dev/null and b/examples_frames/flamingo_sr/00012.jpg differ diff --git a/examples_frames/flamingo_sr/00013.jpg b/examples_frames/flamingo_sr/00013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..171ba4895d761efbde92cab4053c067459cef323 Binary files /dev/null and b/examples_frames/flamingo_sr/00013.jpg differ diff --git a/examples_frames/flamingo_sr/00014.jpg b/examples_frames/flamingo_sr/00014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7cc7f9fcd8a885e143f5ce5c2d1a73341918c562 Binary files /dev/null and b/examples_frames/flamingo_sr/00014.jpg differ diff --git a/examples_frames/flamingo_sr/00015.jpg b/examples_frames/flamingo_sr/00015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..20e76702ea475e3069504a647b301f5bf775651d Binary files /dev/null and b/examples_frames/flamingo_sr/00015.jpg differ diff --git a/examples_frames/flamingo_sr/00016.jpg b/examples_frames/flamingo_sr/00016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..18306e7695abad6e6e46d9a3ed9d914ce98b220e Binary files /dev/null and b/examples_frames/flamingo_sr/00016.jpg differ diff --git a/examples_frames/flamingo_sr/00017.jpg b/examples_frames/flamingo_sr/00017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1f808cd87feb10a69edf9ea679c59fe3ff9cd545 Binary files /dev/null and b/examples_frames/flamingo_sr/00017.jpg differ diff --git a/examples_frames/flamingo_sr/00018.jpg b/examples_frames/flamingo_sr/00018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8dd9d811d1a47d416ad2808054fe40ae0fe12c54 Binary files /dev/null and b/examples_frames/flamingo_sr/00018.jpg differ diff --git a/examples_frames/flamingo_sr/00019.jpg b/examples_frames/flamingo_sr/00019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c6d5ae45be297212ba4236ff6f1dd3a34fc30b1b Binary files /dev/null and b/examples_frames/flamingo_sr/00019.jpg differ diff --git a/examples_frames/koala/00000.jpg b/examples_frames/koala/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..68e1485a7e91293e9930bf153a32fd000fb03864 Binary files /dev/null and b/examples_frames/koala/00000.jpg differ diff --git a/examples_frames/koala/00001.jpg b/examples_frames/koala/00001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3975f89816a0e3f534cfd7b216f87f076342017f Binary files /dev/null and b/examples_frames/koala/00001.jpg differ diff --git a/examples_frames/koala/00002.jpg b/examples_frames/koala/00002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7a3896d0f19bf7da35cd3af05bf627358862a51f Binary files /dev/null and b/examples_frames/koala/00002.jpg differ diff --git a/examples_frames/koala/00003.jpg b/examples_frames/koala/00003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..71e6ebab233b6b9d51db3bd715b6d1c63a0a2840 Binary files /dev/null and b/examples_frames/koala/00003.jpg differ diff --git a/examples_frames/koala/00004.jpg b/examples_frames/koala/00004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6d9b404456408b9fe90c4a429eeeb41399163f3a Binary files /dev/null and b/examples_frames/koala/00004.jpg differ diff --git a/examples_frames/koala/00005.jpg b/examples_frames/koala/00005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b7c9c11a27724d201a29a01987eae5bb9fd92e77 Binary files /dev/null and b/examples_frames/koala/00005.jpg differ diff --git a/examples_frames/koala/00006.jpg b/examples_frames/koala/00006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fcd477636799fed295ef0977de148f9fbeca4b5f Binary files /dev/null and b/examples_frames/koala/00006.jpg differ diff --git a/examples_frames/koala/00007.jpg b/examples_frames/koala/00007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1ecc90fd22951300251ed29fd06b994c1dee5353 Binary files /dev/null and b/examples_frames/koala/00007.jpg differ diff --git a/examples_frames/koala/00008.jpg b/examples_frames/koala/00008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2382d0354da6e6908851732bf32ab1300cdb5592 Binary files /dev/null and b/examples_frames/koala/00008.jpg differ diff --git a/examples_frames/koala/00009.jpg b/examples_frames/koala/00009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..24c4b160367b176205ecd4018d7c16dc7194c1ca Binary files /dev/null and b/examples_frames/koala/00009.jpg differ diff --git a/examples_frames/koala/00010.jpg b/examples_frames/koala/00010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..70710c156f9cb0ba01130b6e7fffcf23b444e6c6 Binary files /dev/null and b/examples_frames/koala/00010.jpg differ diff --git a/examples_frames/koala/00011.jpg b/examples_frames/koala/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cd73016f621bb08200edd94197ce8f0f4e441d77 Binary files /dev/null and b/examples_frames/koala/00011.jpg differ diff --git a/examples_frames/koala/00012.jpg b/examples_frames/koala/00012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..65c2d30222e9640504ac0648f2510d6271091243 Binary files /dev/null and b/examples_frames/koala/00012.jpg differ diff --git a/examples_frames/koala/00013.jpg b/examples_frames/koala/00013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..202396bf8fc979e5499b7fc89586bdc317340d65 Binary files /dev/null and b/examples_frames/koala/00013.jpg differ diff --git a/examples_frames/koala/00014.jpg b/examples_frames/koala/00014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f1ffaa0b8b84001e63959a9826d958e5db8ca1a6 Binary files /dev/null and b/examples_frames/koala/00014.jpg differ diff --git a/examples_frames/koala/00015.jpg b/examples_frames/koala/00015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7a13af2da3eefb768cd139975c5dcdb84442ce35 Binary files /dev/null and b/examples_frames/koala/00015.jpg differ diff --git a/examples_frames/koala/00016.jpg b/examples_frames/koala/00016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b6a7f6dbbaeada9c73f7887da94d5af38d5407f8 Binary files /dev/null and b/examples_frames/koala/00016.jpg differ diff --git a/examples_frames/koala/00017.jpg b/examples_frames/koala/00017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ffe0f88c6b1534e5486fe1b288b030cd3755faee Binary files /dev/null and b/examples_frames/koala/00017.jpg differ diff --git a/examples_frames/koala/00018.jpg b/examples_frames/koala/00018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6d9b62d90d9a45cea1c01b860f5c02153c713b14 Binary files /dev/null and b/examples_frames/koala/00018.jpg differ diff --git a/examples_frames/koala/00019.jpg b/examples_frames/koala/00019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3b4f0f242bf0d0bbacaa931455b2b5cabcb8a645 Binary files /dev/null and b/examples_frames/koala/00019.jpg differ diff --git a/examples_frames/koala_sr/00000.jpg b/examples_frames/koala_sr/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a037e218e1727f0bbc8c849b2801d7f710d0e8b5 Binary files /dev/null and b/examples_frames/koala_sr/00000.jpg differ diff --git a/examples_frames/koala_sr/00001.jpg b/examples_frames/koala_sr/00001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..15ee1e0849464c30bc08522c7803791625475f25 Binary files /dev/null and b/examples_frames/koala_sr/00001.jpg differ diff --git a/examples_frames/koala_sr/00002.jpg b/examples_frames/koala_sr/00002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5f58f382152dcca240d3172333a1acc27edd8eac Binary files /dev/null and b/examples_frames/koala_sr/00002.jpg differ diff --git a/examples_frames/koala_sr/00003.jpg b/examples_frames/koala_sr/00003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c4481de7bc80fc7bf09d56b9858ceb57780fd9a4 Binary files /dev/null and b/examples_frames/koala_sr/00003.jpg differ diff --git a/examples_frames/koala_sr/00004.jpg b/examples_frames/koala_sr/00004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..727f552089391e724120cba1bf2dc0ce32a04527 Binary files /dev/null and b/examples_frames/koala_sr/00004.jpg differ diff --git a/examples_frames/koala_sr/00005.jpg b/examples_frames/koala_sr/00005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..169425f92c9197b7300892174d03be6e42927a28 Binary files /dev/null and b/examples_frames/koala_sr/00005.jpg differ diff --git a/examples_frames/koala_sr/00006.jpg b/examples_frames/koala_sr/00006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..395d371b40ba941c0723a0030a0ae92d429048c9 Binary files /dev/null and b/examples_frames/koala_sr/00006.jpg differ diff --git a/examples_frames/koala_sr/00007.jpg b/examples_frames/koala_sr/00007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d8c441b6f5736b3a05b6c18faec02417edeae409 Binary files /dev/null and b/examples_frames/koala_sr/00007.jpg differ diff --git a/examples_frames/koala_sr/00008.jpg b/examples_frames/koala_sr/00008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8bc7dce25fe849948baccc7ab47f4bf1c23d1b3f Binary files /dev/null and b/examples_frames/koala_sr/00008.jpg differ diff --git a/examples_frames/koala_sr/00009.jpg b/examples_frames/koala_sr/00009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9cc415c514c11d9a3260a7fbe059c5799d3f7002 Binary files /dev/null and b/examples_frames/koala_sr/00009.jpg differ diff --git a/examples_frames/koala_sr/00010.jpg b/examples_frames/koala_sr/00010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ad595afe38f2835416bf8ff40978eaa65005b741 Binary files /dev/null and b/examples_frames/koala_sr/00010.jpg differ diff --git a/examples_frames/koala_sr/00011.jpg b/examples_frames/koala_sr/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..76d8e816d14fd2f2150872f2641ec97ca17b071a Binary files /dev/null and b/examples_frames/koala_sr/00011.jpg differ diff --git a/examples_frames/koala_sr/00012.jpg b/examples_frames/koala_sr/00012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a58072064d14860074ee8db483179e5ac0f4e2b5 Binary files /dev/null and b/examples_frames/koala_sr/00012.jpg differ diff --git a/examples_frames/koala_sr/00013.jpg b/examples_frames/koala_sr/00013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a63e5c984b8136783364226ecd60f983ac162bc1 Binary files /dev/null and b/examples_frames/koala_sr/00013.jpg differ diff --git a/examples_frames/koala_sr/00014.jpg b/examples_frames/koala_sr/00014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e61874b0ec0a66844229daf3c3f9a10a651b0217 Binary files /dev/null and b/examples_frames/koala_sr/00014.jpg differ diff --git a/examples_frames/koala_sr/00015.jpg b/examples_frames/koala_sr/00015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..58a9a90197d4d02ed0220646439ec88e55ea5232 Binary files /dev/null and b/examples_frames/koala_sr/00015.jpg differ diff --git a/examples_frames/koala_sr/00016.jpg b/examples_frames/koala_sr/00016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8e73c2227240eec14c3f36ef29805ed2151edc2c Binary files /dev/null and b/examples_frames/koala_sr/00016.jpg differ diff --git a/examples_frames/koala_sr/00017.jpg b/examples_frames/koala_sr/00017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1ee83a4070189962f49e37758c874922c4c0d5f9 Binary files /dev/null and b/examples_frames/koala_sr/00017.jpg differ diff --git a/examples_frames/koala_sr/00018.jpg b/examples_frames/koala_sr/00018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a7a2fc8127018082a2a07a645db58d884c624b93 Binary files /dev/null and b/examples_frames/koala_sr/00018.jpg differ diff --git a/examples_frames/koala_sr/00019.jpg b/examples_frames/koala_sr/00019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a432b09b11c5c31bb6f522bc4a31bd632bb84bde Binary files /dev/null and b/examples_frames/koala_sr/00019.jpg differ diff --git a/examples_frames/rhino/00000.jpg b/examples_frames/rhino/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3c3e9083d12272605633b1d9e1e1de3f16dbaf87 Binary files /dev/null and b/examples_frames/rhino/00000.jpg differ diff --git a/examples_frames/rhino/00001.jpg b/examples_frames/rhino/00001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f1cb44473a42c9989d3279bdf6c19a2f2dc2445e Binary files /dev/null and b/examples_frames/rhino/00001.jpg differ diff --git a/examples_frames/rhino/00002.jpg b/examples_frames/rhino/00002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bd661afa34684d47f67b2ea67be73dd4a04bf096 Binary files /dev/null and b/examples_frames/rhino/00002.jpg differ diff --git a/examples_frames/rhino/00003.jpg b/examples_frames/rhino/00003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f1aab3f13146a71c1a1ba664168384b90176b1bf Binary files /dev/null and b/examples_frames/rhino/00003.jpg differ diff --git a/examples_frames/rhino/00004.jpg b/examples_frames/rhino/00004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..07cc0c2f8df4b695f32eaf94a64ef7bde20ed878 Binary files /dev/null and b/examples_frames/rhino/00004.jpg differ diff --git a/examples_frames/rhino/00005.jpg b/examples_frames/rhino/00005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b51b7f73dbd0db3255434fcf9ebbbcac3f2a0a6f Binary files /dev/null and b/examples_frames/rhino/00005.jpg differ diff --git a/examples_frames/rhino/00006.jpg b/examples_frames/rhino/00006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..201407e996f344c0f39a6b1054aafc0d31a6bf2e Binary files /dev/null and b/examples_frames/rhino/00006.jpg differ diff --git a/examples_frames/rhino/00007.jpg b/examples_frames/rhino/00007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d3a908df0f6ec51e36f6fc46e7fc9653123b1094 Binary files /dev/null and b/examples_frames/rhino/00007.jpg differ diff --git a/examples_frames/rhino/00008.jpg b/examples_frames/rhino/00008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..05c5855ea93c3b941c32ba0c9d2003b01550a854 Binary files /dev/null and b/examples_frames/rhino/00008.jpg differ diff --git a/examples_frames/rhino/00009.jpg b/examples_frames/rhino/00009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1bce91a5f3dd9d4b7bef6ee8802e8d163a961c1e Binary files /dev/null and b/examples_frames/rhino/00009.jpg differ diff --git a/examples_frames/rhino/00010.jpg b/examples_frames/rhino/00010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cc7f77ffbba62ca57eb8df4d8fa879a105722a7c Binary files /dev/null and b/examples_frames/rhino/00010.jpg differ diff --git a/examples_frames/rhino/00011.jpg b/examples_frames/rhino/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d5fca7fc05f7cff22e30aceefa6c23681ef17de8 Binary files /dev/null and b/examples_frames/rhino/00011.jpg differ diff --git a/examples_frames/rhino/00012.jpg b/examples_frames/rhino/00012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d2bcf9ccd9b21170b22e96fab44e976038cdf290 Binary files /dev/null and b/examples_frames/rhino/00012.jpg differ diff --git a/examples_frames/rhino/00013.jpg b/examples_frames/rhino/00013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8b04c4dedf14a53a96212c6a9901030dfcafd8ee Binary files /dev/null and b/examples_frames/rhino/00013.jpg differ diff --git a/examples_frames/rhino/00014.jpg b/examples_frames/rhino/00014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..492ca78c00936730c877ccd2491a5363c5e24cfc Binary files /dev/null and b/examples_frames/rhino/00014.jpg differ diff --git a/examples_frames/rhino/00015.jpg b/examples_frames/rhino/00015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6c19d7b161d0d980ba959b0f373616ee44a13ad3 Binary files /dev/null and b/examples_frames/rhino/00015.jpg differ diff --git a/examples_frames/rhino/00016.jpg b/examples_frames/rhino/00016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6833b39be0c76b0e2f02d73a828519b9e2cbe800 Binary files /dev/null and b/examples_frames/rhino/00016.jpg differ diff --git a/examples_frames/rhino/00017.jpg b/examples_frames/rhino/00017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cee8d6f222f338c3ac82da21632e2189d71cad8c Binary files /dev/null and b/examples_frames/rhino/00017.jpg differ diff --git a/examples_frames/rhino/00018.jpg b/examples_frames/rhino/00018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..39aa745a9abcead92398961f81cbd43ca24085c9 Binary files /dev/null and b/examples_frames/rhino/00018.jpg differ diff --git a/examples_frames/rhino/00019.jpg b/examples_frames/rhino/00019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dd0e5974201396c686b08ff9f215213a683a3736 Binary files /dev/null and b/examples_frames/rhino/00019.jpg differ diff --git a/examples_frames/rhino_sr/00000.jpg b/examples_frames/rhino_sr/00000.jpg new file mode 100644 index 0000000000000000000000000000000000000000..238a8e0284a7c7c9ca0300d39390f6140f8c6c6f Binary files /dev/null and b/examples_frames/rhino_sr/00000.jpg differ diff --git a/examples_frames/rhino_sr/00001.jpg b/examples_frames/rhino_sr/00001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7c54457c50fc5f15af60adf6bee9c88c4647b3c9 Binary files /dev/null and b/examples_frames/rhino_sr/00001.jpg differ diff --git a/examples_frames/rhino_sr/00002.jpg b/examples_frames/rhino_sr/00002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..da8e715758e7640d0128030772f33e171d088d90 Binary files /dev/null and b/examples_frames/rhino_sr/00002.jpg differ diff --git a/examples_frames/rhino_sr/00003.jpg b/examples_frames/rhino_sr/00003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7105cb65d5a11d00af5a366c8953b30747f16123 Binary files /dev/null and b/examples_frames/rhino_sr/00003.jpg differ diff --git a/examples_frames/rhino_sr/00004.jpg b/examples_frames/rhino_sr/00004.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d766def544f477b7d329a7fbcb2ffb1adcffe0be Binary files /dev/null and b/examples_frames/rhino_sr/00004.jpg differ diff --git a/examples_frames/rhino_sr/00005.jpg b/examples_frames/rhino_sr/00005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..112d9c1de31729d85691f8e31deceeb5ffe73578 Binary files /dev/null and b/examples_frames/rhino_sr/00005.jpg differ diff --git a/examples_frames/rhino_sr/00006.jpg b/examples_frames/rhino_sr/00006.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3bdba7a6c56c95fbec705d1144ce657f85dbbb20 Binary files /dev/null and b/examples_frames/rhino_sr/00006.jpg differ diff --git a/examples_frames/rhino_sr/00007.jpg b/examples_frames/rhino_sr/00007.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ce07cc8bd43d59f3e6c1bf3833243a9958b7c99c Binary files /dev/null and b/examples_frames/rhino_sr/00007.jpg differ diff --git a/examples_frames/rhino_sr/00008.jpg b/examples_frames/rhino_sr/00008.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9292b16fa73fd9de3c5ae82460db34887b72464a Binary files /dev/null and b/examples_frames/rhino_sr/00008.jpg differ diff --git a/examples_frames/rhino_sr/00009.jpg b/examples_frames/rhino_sr/00009.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fb2717b5fdf18c342ab7e30e5d58eba5d602f2dc Binary files /dev/null and b/examples_frames/rhino_sr/00009.jpg differ diff --git a/examples_frames/rhino_sr/00010.jpg b/examples_frames/rhino_sr/00010.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6f0c909d3820b16f31cef3d63f723b4206fd0e89 Binary files /dev/null and b/examples_frames/rhino_sr/00010.jpg differ diff --git a/examples_frames/rhino_sr/00011.jpg b/examples_frames/rhino_sr/00011.jpg new file mode 100644 index 0000000000000000000000000000000000000000..074a3117071fabf559f86a5da751c3640ffb0e95 Binary files /dev/null and b/examples_frames/rhino_sr/00011.jpg differ diff --git a/examples_frames/rhino_sr/00012.jpg b/examples_frames/rhino_sr/00012.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cbb8ea812ae3b7d71e95514624a8394a3c890c6a Binary files /dev/null and b/examples_frames/rhino_sr/00012.jpg differ diff --git a/examples_frames/rhino_sr/00013.jpg b/examples_frames/rhino_sr/00013.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ea35d70034471eb12ec31d133e86e0d57e28ba7e Binary files /dev/null and b/examples_frames/rhino_sr/00013.jpg differ diff --git a/examples_frames/rhino_sr/00014.jpg b/examples_frames/rhino_sr/00014.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0d6842357d0607f779ccd9a376799e18e7fda585 Binary files /dev/null and b/examples_frames/rhino_sr/00014.jpg differ diff --git a/examples_frames/rhino_sr/00015.jpg b/examples_frames/rhino_sr/00015.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dbbfc445cabaedfe1e7dca028c1c2aa11962888b Binary files /dev/null and b/examples_frames/rhino_sr/00015.jpg differ diff --git a/examples_frames/rhino_sr/00016.jpg b/examples_frames/rhino_sr/00016.jpg new file mode 100644 index 0000000000000000000000000000000000000000..89f71430efc347f44f94527a343bc2eb90f9b0c6 Binary files /dev/null and b/examples_frames/rhino_sr/00016.jpg differ diff --git a/examples_frames/rhino_sr/00017.jpg b/examples_frames/rhino_sr/00017.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2d7f45a658e0cdcb8ed2ee4b73ea6db1e855691d Binary files /dev/null and b/examples_frames/rhino_sr/00017.jpg differ diff --git a/examples_frames/rhino_sr/00018.jpg b/examples_frames/rhino_sr/00018.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8b44d64e72039c0ed6456239ba7e748eb07c3ccf Binary files /dev/null and b/examples_frames/rhino_sr/00018.jpg differ diff --git a/examples_frames/rhino_sr/00019.jpg b/examples_frames/rhino_sr/00019.jpg new file mode 100644 index 0000000000000000000000000000000000000000..03b8350897e03fbc302c2f1422c7f925e0fdd383 Binary files /dev/null and b/examples_frames/rhino_sr/00019.jpg differ diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0071fff6d776ad1e04fe4c817f91ce501cfaef5b --- /dev/null +++ b/model/__init__.py @@ -0,0 +1,12 @@ +from . import config + +from .controlnet import ControlledUnetModel, ControlNet +from .vae import AutoencoderKL +from .clip import FrozenOpenCLIPEmbedder + +from .cldm import ControlLDM +from .gaussian_diffusion import Diffusion + +from .swinir import SwinIR +from .bsrnet import RRDBNet +from .scunet import SCUNet diff --git a/model/__pycache__/__init__.cpython-310.pyc b/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9a6492e937e527ac2142428e3f5e4e86212484c Binary files /dev/null and b/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/model/__pycache__/__init__.cpython-39.pyc b/model/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e9b193e8f02039e4e265edeb42d4ad9cf429c34 Binary files /dev/null and b/model/__pycache__/__init__.cpython-39.pyc differ diff --git a/model/__pycache__/attention.cpython-310.pyc b/model/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e25f7ad373d7db487a4e6ad05b7bafe47e6a5bc7 Binary files /dev/null and b/model/__pycache__/attention.cpython-310.pyc differ diff --git a/model/__pycache__/attention.cpython-39.pyc b/model/__pycache__/attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..365698d3788dfcd442b7a87f6236602a24f49bc4 Binary files /dev/null and b/model/__pycache__/attention.cpython-39.pyc differ diff --git a/model/__pycache__/bsrnet.cpython-310.pyc b/model/__pycache__/bsrnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdc0bafa5103024ed4341aea613db4cb15405fb3 Binary files /dev/null and b/model/__pycache__/bsrnet.cpython-310.pyc differ diff --git a/model/__pycache__/bsrnet.cpython-39.pyc b/model/__pycache__/bsrnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad6bf067b217f3aea1565bfdda15039e9973517b Binary files /dev/null and b/model/__pycache__/bsrnet.cpython-39.pyc differ diff --git a/model/__pycache__/cldm.cpython-310.pyc b/model/__pycache__/cldm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f866791af56e188af0a2f077837c02de79cfe7f Binary files /dev/null and b/model/__pycache__/cldm.cpython-310.pyc differ diff --git a/model/__pycache__/cldm.cpython-39.pyc b/model/__pycache__/cldm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97fc27b2df4f6374a5dfecc140d91e410222681a Binary files /dev/null and b/model/__pycache__/cldm.cpython-39.pyc differ diff --git a/model/__pycache__/clip.cpython-310.pyc b/model/__pycache__/clip.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d0dd10953902ce7c2a473ff9e8ab18c167ab0df Binary files /dev/null and b/model/__pycache__/clip.cpython-310.pyc differ diff --git a/model/__pycache__/clip.cpython-39.pyc b/model/__pycache__/clip.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e60b13b46cde33ee1abd5c7ade2ffe084b49e20d Binary files /dev/null and b/model/__pycache__/clip.cpython-39.pyc differ diff --git a/model/__pycache__/config.cpython-310.pyc b/model/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b763e2f0e9d6206e1f372faacf671583a49aa23 Binary files /dev/null and b/model/__pycache__/config.cpython-310.pyc differ diff --git a/model/__pycache__/config.cpython-39.pyc b/model/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..516b624b645d14a804b34efd6ea52789d47b1ff9 Binary files /dev/null and b/model/__pycache__/config.cpython-39.pyc differ diff --git a/model/__pycache__/controlnet.cpython-310.pyc b/model/__pycache__/controlnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3218a4dcfe2092eee050c1b32c4ac2c68f1afe9e Binary files /dev/null and b/model/__pycache__/controlnet.cpython-310.pyc differ diff --git a/model/__pycache__/controlnet.cpython-39.pyc b/model/__pycache__/controlnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bde0130dddd3ce42ef3bf382d6e47f443891f10e Binary files /dev/null and b/model/__pycache__/controlnet.cpython-39.pyc differ diff --git a/model/__pycache__/distributions.cpython-310.pyc b/model/__pycache__/distributions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c20011d4026a9bbdd5ea754aff0d8653d6f09f5a Binary files /dev/null and b/model/__pycache__/distributions.cpython-310.pyc differ diff --git a/model/__pycache__/distributions.cpython-39.pyc b/model/__pycache__/distributions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1912e166798afbdfd3b2147542fab6c5c85e7fa2 Binary files /dev/null and b/model/__pycache__/distributions.cpython-39.pyc differ diff --git a/model/__pycache__/gaussian_diffusion.cpython-310.pyc b/model/__pycache__/gaussian_diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bafaa8c6839ced8fd6721ff56c4ab70fcb3cffd6 Binary files /dev/null and b/model/__pycache__/gaussian_diffusion.cpython-310.pyc differ diff --git a/model/__pycache__/gaussian_diffusion.cpython-39.pyc b/model/__pycache__/gaussian_diffusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f58f6a73b5b5caceb773f6af8c347748c95a5d8 Binary files /dev/null and b/model/__pycache__/gaussian_diffusion.cpython-39.pyc differ diff --git a/model/__pycache__/scunet.cpython-310.pyc b/model/__pycache__/scunet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8df2532a519c0776df822de4ac5871fab900591a Binary files /dev/null and b/model/__pycache__/scunet.cpython-310.pyc differ diff --git a/model/__pycache__/scunet.cpython-39.pyc b/model/__pycache__/scunet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c16d865667d34761e68d47ce238a0310958806d2 Binary files /dev/null and b/model/__pycache__/scunet.cpython-39.pyc differ diff --git a/model/__pycache__/swinir.cpython-310.pyc b/model/__pycache__/swinir.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f23acd9d78afd1411df5da8742ab91152099f28 Binary files /dev/null and b/model/__pycache__/swinir.cpython-310.pyc differ diff --git a/model/__pycache__/swinir.cpython-39.pyc b/model/__pycache__/swinir.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5185b75c4d5d4397a79c2fd7abb036776d9800ab Binary files /dev/null and b/model/__pycache__/swinir.cpython-39.pyc differ diff --git a/model/__pycache__/unet.cpython-310.pyc b/model/__pycache__/unet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfd2b117732914d7da301bc5894379e7995fcd1e Binary files /dev/null and b/model/__pycache__/unet.cpython-310.pyc differ diff --git a/model/__pycache__/unet.cpython-39.pyc b/model/__pycache__/unet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d59d154e2547c56b7dd092e80ffb6dde541d9336 Binary files /dev/null and b/model/__pycache__/unet.cpython-39.pyc differ diff --git a/model/__pycache__/util.cpython-310.pyc b/model/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95a6f1803934941d0d4bad0c6585bb38d9f7c199 Binary files /dev/null and b/model/__pycache__/util.cpython-310.pyc differ diff --git a/model/__pycache__/util.cpython-39.pyc b/model/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34cc4a24424f92d609ff6230371f7c72cd598e9b Binary files /dev/null and b/model/__pycache__/util.cpython-39.pyc differ diff --git a/model/__pycache__/vae.cpython-310.pyc b/model/__pycache__/vae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8a2ec8ce2275fdd16e2f904ab38bbbc5c6fb185 Binary files /dev/null and b/model/__pycache__/vae.cpython-310.pyc differ diff --git a/model/__pycache__/vae.cpython-39.pyc b/model/__pycache__/vae.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bf3eb9c1fe087d66a3bc61e75019183662a7c6a Binary files /dev/null and b/model/__pycache__/vae.cpython-39.pyc differ diff --git a/model/attention.py b/model/attention.py new file mode 100755 index 0000000000000000000000000000000000000000..e1cd96cc457b4bade9565036396e00083d0d162f --- /dev/null +++ b/model/attention.py @@ -0,0 +1,299 @@ +from packaging import version +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +from typing import Optional, Any + +from model.util import ( + checkpoint, zero_module, exists, default +) +from model.config import Config, AttnMode + + +# CrossAttn precision handling +import os +_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + print(f"Setting up {self.__class__.__name__} (vanilla). Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads.") + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + # with torch.autocast(enabled=False, device_type = 'cuda'): + with torch.autocast(enabled=False, device_type="cuda" if str(x.device).startswith("cuda") else "cpu"): + q, k = q.float(), k.float() + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + print(f"Setting up {self.__class__.__name__} (xformers). Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads.") + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + # print(context_dim, query_dim) + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + # import ipdb; ipdb.set_trace() + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = Config.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class SDPCrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + print(f"Setting up {self.__class__.__name__} (sdp). Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads.") + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = F.scaled_dot_product_attention(q, k, v) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + AttnMode.VANILLA: CrossAttention, # vanilla attention + AttnMode.XFORMERS: MemoryEfficientCrossAttention, + AttnMode.SDP: SDPCrossAttention + } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + attn_cls = self.ATTENTION_MODES[Config.attn_mode] + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None, label=None): + return checkpoint(self._forward, (x, context, label), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None, label=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None, label=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i], label=label) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in diff --git a/model/bsrnet.py b/model/bsrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..376066cad431acff14f76e5477d89fe2712e667f --- /dev/null +++ b/model/bsrnet.py @@ -0,0 +1,104 @@ +# From BSRGAN +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + + +def initialize_weights(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class ResidualDenseBlock_5C(nn.Module): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class RRDBNet(nn.Module): + def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4): + super(RRDBNet, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + self.sf = sf + print([in_nc, out_nc, nf, nb, gc, sf]) + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + #### upsampling + self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + if self.sf==4: + self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.conv_first(x) + trunk = self.trunk_conv(self.RRDB_trunk(fea)) + fea = fea + trunk + + fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) + if self.sf==4: + fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + return out diff --git a/model/cldm.py b/model/cldm.py new file mode 100644 index 0000000000000000000000000000000000000000..2956c657d962ea03286c165acef4375aef81bd0d --- /dev/null +++ b/model/cldm.py @@ -0,0 +1,192 @@ +from typing import Tuple, Set, List, Dict + +import torch +from torch import nn + +from model import ( + ControlledUnetModel, ControlNet, + AutoencoderKL, FrozenOpenCLIPEmbedder +) +from utils.common import sliding_windows, count_vram_usage, gaussian_weights + + +def disabled_train(self: nn.Module) -> nn.Module: + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class ControlLDM(nn.Module): + + def __init__( + self, + unet_cfg, + vae_cfg, + clip_cfg, + controlnet_cfg, + latent_scale_factor, + VidToMe_cfg, + latent_warp_cfg, + ): + super().__init__() + self.unet = ControlledUnetModel(**unet_cfg) + self.vae = AutoencoderKL(**vae_cfg) + self.clip = FrozenOpenCLIPEmbedder(**clip_cfg) + self.controlnet = ControlNet(**controlnet_cfg) + self.scale_factor = latent_scale_factor + self.control_scales = [1.0] * 13 + + self.latent_control = latent_warp_cfg.latent_control + self.latent_warp_period = latent_warp_cfg.warp_period + self.latent_merge_period = latent_warp_cfg.merge_period + self.controller = None + + self.ToMe_period = VidToMe_cfg.ToMe_period + + self.merge_ratio = VidToMe_cfg.merge_ratio + self.merge_global = VidToMe_cfg.merge_global + self.global_merge_ratio = VidToMe_cfg.global_merge_ratio + self.seed = VidToMe_cfg.seed + self.batch_size = VidToMe_cfg.batch_size + self.align_batch = VidToMe_cfg.align_batch + self.global_rand = VidToMe_cfg.global_rand + + if self.latent_control: + from controller.controller import AttentionControl + self.controller = AttentionControl(warp_period=self.latent_warp_period, \ + merge_period=self.latent_merge_period, \ + ToMe_period=self.ToMe_period, \ + merge_ratio=self.merge_ratio, ) + + # if self.ToMe: + if self.ToMe_period[0] == 0: + print("[INFO] activate token merging ") + self.activate_vidtome() + + + def activate_vidtome(self): + import vidtome + # import ipdb; ipdb.set_trace() + vidtome.apply_patch(self, self.merge_ratio[0], self.merge_global, self.global_merge_ratio, + seed = self.seed, batch_size = self.batch_size, align_batch = self.align_batch, global_rand = self.global_rand) + + + @torch.no_grad() + def load_pretrained_sd(self, sd: Dict[str, torch.Tensor]) -> Set[str]: + module_map = { + "unet": "model.diffusion_model", + "vae": "first_stage_model", + "clip": "cond_stage_model", + } + modules = [("unet", self.unet), ("vae", self.vae), ("clip", self.clip)] + used = set() + for name, module in modules: + init_sd = {} + scratch_sd = module.state_dict() + for key in scratch_sd: + target_key = ".".join([module_map[name], key]) + init_sd[key] = sd[target_key].clone() + used.add(target_key) + module.load_state_dict(init_sd, strict=True) + unused = set(sd.keys()) - used + # NOTE: this is slightly different from previous version, which haven't switched + # the UNet to eval mode and disabled the requires_grad flag. + for module in [self.vae, self.clip, self.unet]: + module.eval() + module.train = disabled_train + for p in module.parameters(): + p.requires_grad = False + return unused + + @torch.no_grad() + def load_controlnet_from_ckpt(self, sd: Dict[str, torch.Tensor]) -> None: + self.controlnet.load_state_dict(sd, strict=True) + + @torch.no_grad() + def load_controlnet_from_unet(self) -> Tuple[Set[str]]: + unet_sd = self.unet.state_dict() + scratch_sd = self.controlnet.state_dict() + init_sd = {} + init_with_new_zero = set() + init_with_scratch = set() + for key in scratch_sd: + if key in unet_sd: + this, target = scratch_sd[key], unet_sd[key] + if this.size() == target.size(): + init_sd[key] = target.clone() + else: + d_ic = this.size(1) - target.size(1) + oc, _, h, w = this.size() + zeros = torch.zeros((oc, d_ic, h, w), dtype=target.dtype) + init_sd[key] = torch.cat((target, zeros), dim=1) + init_with_new_zero.add(key) + else: + init_sd[key] = scratch_sd[key].clone() + init_with_scratch.add(key) + self.controlnet.load_state_dict(init_sd, strict=True) + return init_with_new_zero, init_with_scratch + + def vae_encode(self, image: torch.Tensor, sample: bool=True, batch_size: int=0) -> torch.Tensor: + if sample: + return self.vae.encode(image, batch_size=batch_size).sample() * self.scale_factor + else: + return self.vae.encode(image, batch_size=batch_size).mode() * self.scale_factor + + def vae_encode_tiled(self, image: torch.Tensor, tile_size: int, tile_stride: int, sample: bool=True) -> torch.Tensor: + bs, _, h, w = image.shape + z = torch.zeros((bs, 4, h // 8, w // 8), dtype=torch.float32, device=image.device) + count = torch.zeros_like(z, dtype=torch.float32) + weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None] + weights = torch.tensor(weights, dtype=torch.float32, device=image.device) + tiles = sliding_windows(h // 8, w // 8, tile_size // 8, tile_stride // 8) + for hi, hi_end, wi, wi_end in tiles: + tile_image = image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] + z[:, :, hi:hi_end, wi:wi_end] += self.vae_encode(tile_image, sample=sample) * weights + count[:, :, hi:hi_end, wi:wi_end] += weights + z.div_(count) + return z + + def vae_decode(self, z: torch.Tensor, batch_size: int=0) -> torch.Tensor: + return self.vae.decode(z / self.scale_factor, batch_size=batch_size) + + @count_vram_usage + def vae_decode_tiled(self, z: torch.Tensor, tile_size: int, tile_stride: int) -> torch.Tensor: + bs, _, h, w = z.shape + image = torch.zeros((bs, 3, h * 8, w * 8), dtype=torch.float32, device=z.device) + count = torch.zeros_like(image, dtype=torch.float32) + weights = gaussian_weights(tile_size * 8, tile_size * 8)[None, None] + weights = torch.tensor(weights, dtype=torch.float32, device=z.device) + tiles = sliding_windows(h, w, tile_size, tile_stride) + for hi, hi_end, wi, wi_end in tiles: + tile_z = z[:, :, hi:hi_end, wi:wi_end] + image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += self.vae_decode(tile_z) * weights + count[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += weights + image.div_(count) + return image + + def prepare_condition(self, clean: torch.Tensor, txt: List[str]) -> Dict[str, torch.Tensor]: + return dict( + c_txt=self.clip.encode(txt), + c_img=self.vae_encode(clean * 2 - 1, sample=False, batch_size=5) + ) + + @count_vram_usage + def prepare_condition_tiled(self, clean: torch.Tensor, txt: List[str], tile_size: int, tile_stride: int) -> Dict[str, torch.Tensor]: + return dict( + c_txt=self.clip.encode(txt), + c_img=self.vae_encode_tiled(clean * 2 - 1, tile_size, tile_stride, sample=False) + ) + + def forward(self, x_noisy, t, cond): + c_txt = cond["c_txt"] + c_img = cond["c_img"] + control = self.controlnet( + x_noisy, hint=c_img, + timesteps=t, context=c_txt + ) + control = [c * scale for c, scale in zip(control, self.control_scales)] + eps = self.unet( + x_noisy, timesteps=t, + context=c_txt, control=control, only_mid_control=False + ) + return eps diff --git a/model/clip.py b/model/clip.py new file mode 100755 index 0000000000000000000000000000000000000000..01535ad364a44bed72f416dabe0762fe2dcef78f --- /dev/null +++ b/model/clip.py @@ -0,0 +1,65 @@ +from typing import List +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from model.open_clip import CLIP, tokenize + +### pretrained model path +# _VITH14 = dict( +# laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), +# ) + +class FrozenOpenCLIPEmbedder(nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, embed_dim, vision_cfg, text_cfg, layer="last"): + super().__init__() + assert layer in self.LAYERS + # model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + model = CLIP(embed_dim, dict(vision_cfg), dict(text_cfg)) + del model.visual + self.model = model + + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def forward(self, tokens): + z = self.encode_with_transformer(tokens) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text: List[str]) -> torch.Tensor: + # convert a batch of text to tensor + tokens = tokenize(text) + # move tensor to model device + tokens = tokens.to(next(self.model.parameters()).device) + return self(tokens) diff --git a/model/config.py b/model/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8d698f1abad4eaba1bfe12f8851fe5c05e45bab1 --- /dev/null +++ b/model/config.py @@ -0,0 +1,62 @@ +import os +from typing import Optional, Literal +from types import ModuleType +import enum +from packaging import version + +import torch + +# collect system information +if version.parse(torch.__version__) >= version.parse("2.0.0"): + SDP_IS_AVAILABLE = True +else: + SDP_IS_AVAILABLE = False + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + + +class AttnMode(enum.Enum): + SDP = 0 + XFORMERS = 1 + VANILLA = 2 + + +class Config: + xformers: Optional[ModuleType] = None + attn_mode: AttnMode = AttnMode.VANILLA + + +# initialize attention mode +if XFORMERS_IS_AVAILBLE: + Config.attn_mode = AttnMode.XFORMERS + print(f"use xformers attention as default") +elif SDP_IS_AVAILABLE: + Config.attn_mode = AttnMode.SDP + print(f"use sdp attention as default") +else: + print(f"both sdp attention and xformers are not available, use vanilla attention (very expensive) as default") + +if XFORMERS_IS_AVAILBLE: + Config.xformers = xformers + + +# user-specified attention mode +ATTN_MODE = os.environ.get("ATTN_MODE", None) +if ATTN_MODE is not None: + assert ATTN_MODE in ["vanilla", "sdp", "xformers"] + if ATTN_MODE == "sdp": + assert SDP_IS_AVAILABLE + Config.attn_mode = AttnMode.SDP + elif ATTN_MODE == "xformers": + assert XFORMERS_IS_AVAILBLE + Config.attn_mode = AttnMode.XFORMERS + else: + Config.attn_mode = AttnMode.VANILLA + print(f"set attention mode to {ATTN_MODE}") +else: + print("keep default attention mode") diff --git a/model/controlnet.py b/model/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1d9f54690bf0f9d9b4796267d00f40eb230822d2 --- /dev/null +++ b/model/controlnet.py @@ -0,0 +1,281 @@ +import torch +import torch as th +import torch.nn as nn + +from model.util import ( + conv_nd, + linear, + zero_module, + timestep_embedding, + exists +) +from model.attention import SpatialTransformer +from model.unet import ( + TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock, UNetModel +) + + +class ControlledUnetModel(UNetModel): + + def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs): + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + h = x.type(self.dtype) + # print(f"[INFO] attention head: {h.shape}") + for module in self.input_blocks: + h = module(h, emb, context, label="unet_down") + hs.append(h) + h = self.middle_block(h, emb, context, label="unet_mid") + + if control is not None: + h += control.pop() + for i, module in enumerate(self.output_blocks): + if only_mid_control or control is None: + h = torch.cat([h, hs.pop()], dim=1) + else: + h = torch.cat([h, hs.pop() + control.pop()], dim=1) + h = module(h, emb, context, label="unet_up") + + h = h.type(x.dtype) + # print(f"[INFO] attention out: {self.out(h).shape}") + return self.out(h) + + +class ControlNet(nn.Module): + + def __init__( + self, + image_size, + in_channels, + model_channels, + hint_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.dims = dims + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels + hint_channels, model_channels, 3, padding=1) + ) + ] + ) + self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)]) + + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self.zero_convs.append(self.make_zero_conv(ch)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + self.zero_convs.append(self.make_zero_conv(ch)) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self.middle_block_out = self.make_zero_conv(ch) + self._feature_size += ch + + def make_zero_conv(self, channels): + return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) + + def forward(self, x, hint, timesteps, context, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + x = torch.cat((x, hint), dim=1) + outs = [] + + h = x.type(self.dtype) + # print(f"[INFO] controlnet attention head: {h.shape}") + for module, zero_conv in zip(self.input_blocks, self.zero_convs): + # print(f"[INFO] {module.__class__.__name__}") + h = module(h, emb, context, label="controlnet_down") + outs.append(zero_conv(h, emb, context)) + + h = self.middle_block(h, emb, context, label="controlnet_mid") + outs.append(self.middle_block_out(h, emb, context, label="controlnet_up")) + # import ipdb; ipdb.set_trace() + # print(f"[INFO] controlnet attention out: {outs[0].shape}") + return outs diff --git a/model/distributions.py b/model/distributions.py new file mode 100755 index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9 --- /dev/null +++ b/model/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/model/gaussian_diffusion.py b/model/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..54f28049cca8fe5346d2e5a9f659a93252c7747e --- /dev/null +++ b/model/gaussian_diffusion.py @@ -0,0 +1,118 @@ +from functools import partial +from typing import Tuple + +import torch +from torch import nn +import numpy as np + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + np.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=np.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + np.arange(n_timestep + 1, dtype=np.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = np.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64) + elif schedule == "sqrt": + betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas + + +def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int]) -> torch.Tensor: + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +class Diffusion(nn.Module): + + def __init__( + self, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + parameterization="eps" + ): + super().__init__() + self.num_timesteps = timesteps + self.beta_schedule = beta_schedule + self.linear_start = linear_start + self.linear_end = linear_end + self.cosine_s = cosine_s + assert parameterization in ["eps", "x0", "v"], "currently only supporting 'eps' and 'x0' and 'v'" + self.parameterization = parameterization + self.loss_type = loss_type + + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = np.sqrt(1. - alphas_cumprod) + + self.betas = betas + self.register("sqrt_alphas_cumprod", sqrt_alphas_cumprod) + self.register("sqrt_one_minus_alphas_cumprod", sqrt_one_minus_alphas_cumprod) + + def register(self, name: str, value: np.ndarray) -> None: + self.register_buffer(name, torch.tensor(value, dtype=torch.float32)) + + def q_sample(self, x_start, t, noise): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, model, x_start, t, cond): + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = model(x_noisy, t, cond) + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean() + return loss_simple diff --git a/model/open_clip/__init__.py b/model/open_clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5cacfcb7518888ce81abe7dc5cd2067522b60097 --- /dev/null +++ b/model/open_clip/__init__.py @@ -0,0 +1,4 @@ +from .model import CLIP +from .tokenizer import tokenize + +__all__ = ["CLIP", "tokenize"] diff --git a/model/open_clip/__pycache__/__init__.cpython-310.pyc b/model/open_clip/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a5a31509294743d7e0ba2c35e9189e628fba45b Binary files /dev/null and b/model/open_clip/__pycache__/__init__.cpython-310.pyc differ diff --git a/model/open_clip/__pycache__/__init__.cpython-39.pyc b/model/open_clip/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ebceb86ae47b53d3da0387fb86be2fd7bb257df Binary files /dev/null and b/model/open_clip/__pycache__/__init__.cpython-39.pyc differ diff --git a/model/open_clip/__pycache__/model.cpython-310.pyc b/model/open_clip/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db095516b51bd799a38907657376cfb447fd1eb0 Binary files /dev/null and b/model/open_clip/__pycache__/model.cpython-310.pyc differ diff --git a/model/open_clip/__pycache__/model.cpython-39.pyc b/model/open_clip/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4f0537a6cd99cd3b8eda7acc7d3060b3abfdc32 Binary files /dev/null and b/model/open_clip/__pycache__/model.cpython-39.pyc differ diff --git a/model/open_clip/__pycache__/tokenizer.cpython-310.pyc b/model/open_clip/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d097529772b9ced5261457c644408cd131c43cc6 Binary files /dev/null and b/model/open_clip/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/model/open_clip/__pycache__/tokenizer.cpython-39.pyc b/model/open_clip/__pycache__/tokenizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a19d57a1ecb64c027171438cccea97f35d00ec26 Binary files /dev/null and b/model/open_clip/__pycache__/tokenizer.cpython-39.pyc differ diff --git a/model/open_clip/__pycache__/transformer.cpython-310.pyc b/model/open_clip/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fed1d996e51fa8df580a6fbd1dc7029826017f70 Binary files /dev/null and b/model/open_clip/__pycache__/transformer.cpython-310.pyc differ diff --git a/model/open_clip/__pycache__/transformer.cpython-39.pyc b/model/open_clip/__pycache__/transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a6f82fd50af228aed82194d5bd32b9342cdb6d3 Binary files /dev/null and b/model/open_clip/__pycache__/transformer.cpython-39.pyc differ diff --git a/model/open_clip/bpe_simple_vocab_16e6.txt.gz b/model/open_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/model/open_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/model/open_clip/model.py b/model/open_clip/model.py new file mode 100644 index 0000000000000000000000000000000000000000..d940c9c592b6d7dccd54e233bca33144c1020787 --- /dev/null +++ b/model/open_clip/model.py @@ -0,0 +1,206 @@ +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from .transformer import LayerNormFp32, LayerNorm, QuickGELU, VisionTransformer, TextTransformer + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design + global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer + n_queries: int = 256 # n_queries for attentional pooler + attn_pooler_heads: int = 8 # n heads for attentional_pooling + output_tokens: bool = False + + timm_model_name: str = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + timm_drop: float = 0. # head dropout + timm_drop_path: Optional[float] = None # backbone stochastic depth + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + width: int = 512 + heads: int = 8 + layers: int = 12 + ls_init_value: Optional[float] = None # layer scale initial value + hf_model_name: str = None + hf_tokenizer_name: str = None + hf_model_pretrained: bool = True + proj: str = 'mlp' + pooler_type: str = 'mean_pooler' + embed_cls: bool = False + pad_id: int = 0 + output_tokens: bool = False + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + input_patchnorm=vision_cfg.input_patchnorm, + global_average_pool=vision_cfg.global_average_pool, + attentional_pool=vision_cfg.attentional_pool, + n_queries=vision_cfg.n_queries, + attn_pooler_heads=vision_cfg.attn_pooler_heads, + output_tokens=vision_cfg.output_tokens, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + embed_cls=text_cfg.embed_cls, + output_tokens=text_cfg.output_tokens, + pad_id=text_cfg.pad_id, + act_layer=act_layer, + norm_layer=norm_layer, + ) + return text + + +class CLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.context_length = text.context_length + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + return F.normalize(x, dim=-1) if normalize else x + + def forward( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + ): + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is not None else None + if self.output_dict: + return { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + return image_features, text_features, self.logit_scale.exp() diff --git a/model/open_clip/tokenizer.py b/model/open_clip/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..23fcfcbcb4ca051ba5bba7520918693001999282 --- /dev/null +++ b/model/open_clip/tokenizer.py @@ -0,0 +1,214 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + +# https://stackoverflow.com/q/62691279 +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['', ''] + else: + special_tokens = ['', ''] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + +_tokenizer = SimpleTokenizer() + +def decode(output_ids: torch.Tensor): + output_ids = output_ids.cpu().numpy() + return _tokenizer.decode(output_ids) + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + tokens[-1] = eot_token + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + +class HFTokenizer: + """HuggingFace tokenizer wrapper""" + + def __init__(self, tokenizer_name: str): + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + def save_pretrained(self, dest): + self.tokenizer.save_pretrained(dest) + + def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: + # same cleaning as for default tokenizer, except lowercasing + # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance + if isinstance(texts, str): + texts = [texts] + texts = [whitespace_clean(basic_clean(text)) for text in texts] + input_ids = self.tokenizer( + texts, + return_tensors='pt', + max_length=context_length, + padding='max_length', + truncation=True, + ).input_ids + return input_ids diff --git a/model/open_clip/transformer.py b/model/open_clip/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b1b356ad7a995a17f26c4e45ed2a370ef0b1f9f8 --- /dev/null +++ b/model/open_clip/transformer.py @@ -0,0 +1,736 @@ +import collections +from collections import OrderedDict +import math +from typing import Callable, Optional, Sequence, Tuple +from itertools import repeat + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + +to_2tuple = _ntuple(2) + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__(self, prob, exclude_first_token=True): + super().__init__() + assert 0 <= prob < 1. + self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + logit_scale_max=math.log(1. / 0.01), + attn_drop=0., + proj_drop=0. + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + attn += attn_mask + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = torch.bmm(attn, v) + if self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + x = x.transpose(0, 1).reshape(L, N, C) + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + n_head: int = 8, + n_queries: int = 256, + norm_layer: Callable = LayerNorm + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) + + def forward(self, x: torch.Tensor): + x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] + return out.permute(1, 0, 2) # LND -> NLD + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask + )[0] + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + + x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = Attention( + d_model, n_head, + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + ) + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): + return self.resblocks[0].mlp.c_fc.int8_original_dtype + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + global_average_pool: bool = False, + attentional_pool: bool = False, + n_queries: int = 256, + attn_pooler_heads: int = 8, + output_dim: int = 512, + patch_dropout: float = 0., + input_patchnorm: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False + ): + super().__init__() + self.output_tokens = output_tokens + image_height, image_width = self.image_size = to_2tuple(image_size) + patch_height, patch_width = self.patch_size = to_2tuple(patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.output_dim = output_dim + + # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1 + self.input_patchnorm = input_patchnorm + + if input_patchnorm: + patch_input_dim = patch_height * patch_width * 3 + self.patchnorm_pre_ln = LayerNorm(patch_input_dim) + self.conv1 = nn.Linear(patch_input_dim, width) + else: + self.patchnorm_pre_ln = nn.Identity() + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() + + self.ln_pre = norm_layer(width) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.global_average_pool = global_average_pool + if attentional_pool: + self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) + else: + self.attn_pool = None + self.ln_post = norm_layer(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. + + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.global_average_pool: + return x.mean(dim=1), x + else: + return x[:, 0], x[:, 1:] + + def forward(self, x: torch.Tensor): + + # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 + if self.input_patchnorm: + # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') + x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1]) + x = x.permute(0, 2, 4, 1, 3, 5) + x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) + x = self.patchnorm_pre_ln(x) + x = self.conv1(x) + else: + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in + x = self.patch_dropout(x) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + if self.attn_pool is not None: + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + else: + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) + + if self.proj is not None: + pooled = pooled @ self.proj + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class TextTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + ls_init_value: float = None, + output_dim: int = 512, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + embed_cls: bool = False, + pad_id: int = 0, + output_tokens: bool = False, + ): + super().__init__() + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pad_id = pad_id + + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + if embed_cls: + self.cls_emb = nn.Parameter(torch.empty(width)) + self.num_pos += 1 + else: + self.cls_emb = None + + self.token_embedding = nn.Embedding(vocab_size, width) + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.ln_final = norm_layer(width) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + if self.cls_emb is not None: + nn.init.normal_(self.cls_emb, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def build_cls_mask(self, text, cast_dtype: torch.dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def _repeat(self, t, N: int): + return t.reshape(1, 1, -1).repeat(N, 1, 1) + + def forward(self, text): + cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask + if self.cls_emb is not None: + seq_len += 1 + x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + + x = x + self.positional_embedding[:seq_len].to(cast_dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + if self.cls_emb is not None: + pooled, tokens = x[:, -1], x[:, :-1] + pooled = self.ln_final(pooled) + else: + x = self.ln_final(x) + pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x + + if self.text_projection is not None: + pooled = pooled @ self.text_projection + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + ): + + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + ) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward(self, image_embs, text_embs): + text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + seq_len = text_embs.shape[0] + + for resblock, cross_attn in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) + text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) + else: + text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + + x = text_embs.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + if self.text_projection is not None: + x = x @ self.text_projection + + return x + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable diff --git a/model/scunet.py b/model/scunet.py new file mode 100644 index 0000000000000000000000000000000000000000..081128220579f0ab00f15890012ca2176fff2ff4 --- /dev/null +++ b/model/scunet.py @@ -0,0 +1,272 @@ +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +from einops.layers.torch import Rearrange +from timm.models.layers import trunc_normal_, DropPath + + +class WMSA(nn.Module): + """ Self-attention module in Swin Transformer + """ + + def __init__(self, input_dim, output_dim, head_dim, window_size, type): + super(WMSA, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.head_dim = head_dim + self.scale = self.head_dim ** -0.5 + self.n_heads = input_dim//head_dim + self.window_size = window_size + self.type=type + self.embedding_layer = nn.Linear(self.input_dim, 3*self.input_dim, bias=True) + + # TODO recover + # self.relative_position_params = nn.Parameter(torch.zeros(self.n_heads, 2 * window_size - 1, 2 * window_size -1)) + self.relative_position_params = nn.Parameter(torch.zeros((2 * window_size - 1)*(2 * window_size -1), self.n_heads)) + + self.linear = nn.Linear(self.input_dim, self.output_dim) + + trunc_normal_(self.relative_position_params, std=.02) + self.relative_position_params = torch.nn.Parameter(self.relative_position_params.view(2*window_size-1, 2*window_size-1, self.n_heads).transpose(1,2).transpose(0,1)) + + def generate_mask(self, h, w, p, shift): + """ generating the mask of SW-MSA + Args: + shift: shift parameters in CyclicShift. + Returns: + attn_mask: should be (1 1 w p p), + """ + # supporting sqaure. + attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) + if self.type == 'W': + return attn_mask + + s = p - shift + attn_mask[-1, :, :s, :, s:, :] = True + attn_mask[-1, :, s:, :, :s, :] = True + attn_mask[:, -1, :, :s, :, s:] = True + attn_mask[:, -1, :, s:, :, :s] = True + attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)') + return attn_mask + + def forward(self, x): + """ Forward pass of Window Multi-head Self-attention module. + Args: + x: input tensor with shape of [b h w c]; + attn_mask: attention mask, fill -inf where the value is True; + Returns: + output: tensor shape [b h w c] + """ + if self.type!='W': x = torch.roll(x, shifts=(-(self.window_size//2), -(self.window_size//2)), dims=(1,2)) + x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) + h_windows = x.size(1) + w_windows = x.size(2) + # sqaure validation + # assert h_windows == w_windows + + x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) + qkv = self.embedding_layer(x) + q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0) + sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale + # Adding learnable relative embedding + sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q') + # Using Attn Mask to distinguish different subwindows. + if self.type != 'W': + attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size//2) + sim = sim.masked_fill_(attn_mask, float("-inf")) + + probs = nn.functional.softmax(sim, dim=-1) + output = torch.einsum('hbwij,hbwjc->hbwic', probs, v) + output = rearrange(output, 'h b w p c -> b w p (h c)') + output = self.linear(output) + output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) + + if self.type!='W': output = torch.roll(output, shifts=(self.window_size//2, self.window_size//2), dims=(1,2)) + return output + + def relative_embedding(self): + cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)])) + relation = cord[:, None, :] - cord[None, :, :] + self.window_size -1 + # negative is allowed + return self.relative_position_params[:, relation[:,:,0].long(), relation[:,:,1].long()] + + +class Block(nn.Module): + def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): + """ SwinTransformer Block + """ + super(Block, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + assert type in ['W', 'SW'] + self.type = type + if input_resolution <= window_size: + self.type = 'W' + + print("Block Initial Type: {}, drop_path_rate:{:.6f}".format(self.type, drop_path)) + self.ln1 = nn.LayerNorm(input_dim) + self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.ln2 = nn.LayerNorm(input_dim) + self.mlp = nn.Sequential( + nn.Linear(input_dim, 4 * input_dim), + nn.GELU(), + nn.Linear(4 * input_dim, output_dim), + ) + + def forward(self, x): + x = x + self.drop_path(self.msa(self.ln1(x))) + x = x + self.drop_path(self.mlp(self.ln2(x))) + return x + + +class ConvTransBlock(nn.Module): + def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): + """ SwinTransformer and Conv Block + """ + super(ConvTransBlock, self).__init__() + self.conv_dim = conv_dim + self.trans_dim = trans_dim + self.head_dim = head_dim + self.window_size = window_size + self.drop_path = drop_path + self.type = type + self.input_resolution = input_resolution + + assert self.type in ['W', 'SW'] + if self.input_resolution <= self.window_size: + self.type = 'W' + + self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, self.type, self.input_resolution) + self.conv1_1 = nn.Conv2d(self.conv_dim+self.trans_dim, self.conv_dim+self.trans_dim, 1, 1, 0, bias=True) + self.conv1_2 = nn.Conv2d(self.conv_dim+self.trans_dim, self.conv_dim+self.trans_dim, 1, 1, 0, bias=True) + + self.conv_block = nn.Sequential( + nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False), + nn.ReLU(True), + nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False) + ) + + def forward(self, x): + conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1) + conv_x = self.conv_block(conv_x) + conv_x + trans_x = Rearrange('b c h w -> b h w c')(trans_x) + trans_x = self.trans_block(trans_x) + trans_x = Rearrange('b h w c -> b c h w')(trans_x) + res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1)) + x = x + res + + return x + + +class SCUNet(nn.Module): + + def __init__(self, in_nc=3, config=[2,2,2,2,2,2,2], dim=64, drop_path_rate=0.0, input_resolution=256): + super(SCUNet, self).__init__() + self.config = config + self.dim = dim + self.head_dim = 32 + self.window_size = 8 + + # drop path rate for each layer + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))] + + self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)] + + begin = 0 + self.m_down1 = [ConvTransBlock(dim//2, dim//2, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution) + for i in range(config[0])] + \ + [nn.Conv2d(dim, 2*dim, 2, 2, 0, bias=False)] + + begin += config[0] + self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//2) + for i in range(config[1])] + \ + [nn.Conv2d(2*dim, 4*dim, 2, 2, 0, bias=False)] + + begin += config[1] + self.m_down3 = [ConvTransBlock(2*dim, 2*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW',input_resolution//4) + for i in range(config[2])] + \ + [nn.Conv2d(4*dim, 8*dim, 2, 2, 0, bias=False)] + + begin += config[2] + self.m_body = [ConvTransBlock(4*dim, 4*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//8) + for i in range(config[3])] + + begin += config[3] + self.m_up3 = [nn.ConvTranspose2d(8*dim, 4*dim, 2, 2, 0, bias=False),] + \ + [ConvTransBlock(2*dim, 2*dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW',input_resolution//4) + for i in range(config[4])] + + begin += config[4] + self.m_up2 = [nn.ConvTranspose2d(4*dim, 2*dim, 2, 2, 0, bias=False),] + \ + [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution//2) + for i in range(config[5])] + + begin += config[5] + self.m_up1 = [nn.ConvTranspose2d(2*dim, dim, 2, 2, 0, bias=False),] + \ + [ConvTransBlock(dim//2, dim//2, self.head_dim, self.window_size, dpr[i+begin], 'W' if not i%2 else 'SW', input_resolution) + for i in range(config[6])] + + self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)] + + self.m_head = nn.Sequential(*self.m_head) + self.m_down1 = nn.Sequential(*self.m_down1) + self.m_down2 = nn.Sequential(*self.m_down2) + self.m_down3 = nn.Sequential(*self.m_down3) + self.m_body = nn.Sequential(*self.m_body) + self.m_up3 = nn.Sequential(*self.m_up3) + self.m_up2 = nn.Sequential(*self.m_up2) + self.m_up1 = nn.Sequential(*self.m_up1) + self.m_tail = nn.Sequential(*self.m_tail) + #self.apply(self._init_weights) + + def forward(self, x0): + + h, w = x0.size()[-2:] + paddingBottom = int(np.ceil(h/64)*64-h) + paddingRight = int(np.ceil(w/64)*64-w) + if paddingBottom: + pad_bottom = x0[:, :, -paddingBottom:, :] + x0 = torch.cat((x0, pad_bottom), dim=2) + if paddingRight: + pad_right = x0[:, :, :, -paddingRight:] + x0 = torch.cat((x0, pad_right), dim=3) + + # import ipdb; ipdb.set_trace() + # x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0) + # cv2.imwrite('x0.png', (x0[0]*255).permute(1,2,0).cpu().numpy()) + x1 = self.m_head(x0) + x2 = self.m_down1(x1) + x3 = self.m_down2(x2) + x4 = self.m_down3(x3) + x = self.m_body(x4) + x = self.m_up3(x+x4) + x = self.m_up2(x+x3) + x = self.m_up1(x+x2) + x = self.m_tail(x+x1) + + x = x[..., :h, :w] + + return x + + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + + +if __name__ == '__main__': + + # torch.cuda.empty_cache() + net = SCUNet() + + x = torch.randn((2, 3, 64, 128)) + x = net(x) + print(x.shape) diff --git a/model/swinir.py b/model/swinir.py new file mode 100644 index 0000000000000000000000000000000000000000..bc37fb6221c9e645edb3cf162e5246b295fbfc3d --- /dev/null +++ b/model/swinir.py @@ -0,0 +1,905 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + +# Originally borrowed from DifFace (https://github.com/zsyOAOA/DifFace/blob/master/models/swinir.py) + +import math +from typing import Set + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + # coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + # Fix: Pass indexing="ij" to avoid warning + coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + sf: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__( + self, + img_size=64, + patch_size=1, + in_chans=3, + embed_dim=96, + depths=[6, 6, 6, 6], + num_heads=[6, 6, 6, 6], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + sf=4, + img_range=1., + upsampler='', + resi_connection='1conv', + unshuffle=False, + unshuffle_scale=None, + hq_key: str="jpg", + lq_key: str="hint", + learning_rate: float=None, + weight_decay: float=None + ) -> "SwinIR": + super(SwinIR, self).__init__() + num_in_ch = in_chans * (unshuffle_scale**2) if unshuffle else in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = sf + self.upsampler = upsampler + self.window_size = window_size + self.unshuffle_scale = unshuffle_scale + self.unshuffle = unshuffle + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + if unshuffle: + assert unshuffle_scale is not None + self.conv_first = nn.Sequential( + nn.PixelUnshuffle(sf), + nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1), + ) + else: + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None + ) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None + ) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB( + dim=embed_dim, + input_resolution=(patches_resolution[0], patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1) + ) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True) + ) + self.upsample = Upsample(sf, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep( + sf, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1]) + ) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True) + ) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + elif self.upscale == 8: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + # TODO: What's this ? + @torch.jit.ignore + def no_weight_decay(self) -> Set[str]: + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self) -> Set[str]: + return {'relative_position_bias_table'} + + def check_image_size(self, x: torch.Tensor) -> torch.Tensor: + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + elif self.upscale == 8: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up3(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self) -> int: + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops diff --git a/model/unet.py b/model/unet.py new file mode 100755 index 0000000000000000000000000000000000000000..a0783214df8989a069696203a8ce608b285c13be --- /dev/null +++ b/model/unet.py @@ -0,0 +1,722 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from model.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, + exists +) +from model.attention import SpatialTransformer + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, label=None): + for layer in self: + if isinstance(layer, TimestepBlock): + # print(f"[INFO] timestepblock") + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + # print(f"[INFO] spatialtransformer") + x = layer(x, context, label=label) + else: + # print(f"[INFO] else") + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def forward(self, x, timesteps, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) diff --git a/model/util.py b/model/util.py new file mode 100755 index 0000000000000000000000000000000000000000..298467da1bc426c3be344599e45f7a104473d387 --- /dev/null +++ b/model/util.py @@ -0,0 +1,225 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +from inspect import isfunction +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +# class CheckpointFunction(torch.autograd.Function): +# @staticmethod +# def forward(ctx, run_function, length, *args): +# ctx.run_function = run_function +# ctx.input_tensors = list(args[:length]) +# ctx.input_params = list(args[length:]) +# ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), +# "dtype": torch.get_autocast_gpu_dtype(), +# "cache_enabled": torch.is_autocast_cache_enabled()} +# with torch.no_grad(): +# output_tensors = ctx.run_function(*ctx.input_tensors) +# return output_tensors + +# @staticmethod +# def backward(ctx, *output_grads): +# ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] +# with torch.enable_grad(), \ +# torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): +# # Fixes a bug where the first op in run_function modifies the +# # Tensor storage in place, which is not allowed for detach()'d +# # Tensors. +# shallow_copies = [x.view_as(x) for x in ctx.input_tensors] +# output_tensors = ctx.run_function(*shallow_copies) +# input_grads = torch.autograd.grad( +# output_tensors, +# ctx.input_tensors + ctx.input_params, +# output_grads, +# allow_unused=True, +# ) +# del ctx.input_tensors +# del ctx.input_params +# del output_tensors +# return (None, None) + input_grads + + +# Fixes: When we set unet parameters with requires_grad=False, the original CheckpointFunction +# still tries to compute gradient for unet parameters. +# https://discuss.pytorch.org/t/get-runtimeerror-one-of-the-differentiated-tensors-does-not-require-grad-in-pytorch-lightning/179738/6 +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled()} + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + [x for x in ctx.input_params if x.requires_grad], + output_grads, + allow_unused=True, + ) + grads = list(grads) + # Assign gradients to the correct positions, matching None for those that do not require gradients + input_grads = [] + for tensor in ctx.input_tensors + ctx.input_params: + if tensor.requires_grad: + input_grads.append(grads.pop(0)) # Get the next computed gradient + else: + input_grads.append(None) # No gradient required for this tensor + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + tuple(input_grads) + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") diff --git a/model/vae.py b/model/vae.py new file mode 100755 index 0000000000000000000000000000000000000000..590f70c179883e54d67797c3b4b0e1e62bd8ba7f --- /dev/null +++ b/model/vae.py @@ -0,0 +1,674 @@ +import math +import torch +import torch.nn as nn +from torch.nn import functional as F +import numpy as np +from einops import rearrange +from typing import Optional, Any + +from model.distributions import DiagonalGaussianDistribution +from model.config import Config, AttnMode + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + print(f"building AttnBlock (vanilla) with {in_channels} in_channels") + + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + # + def __init__(self, in_channels): + super().__init__() + print(f"building MemoryEfficientAttnBlock (xformers) with {in_channels} in_channels") + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.attention_op: Optional[Any] = None + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = Config.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x+out + + +class SDPAttnBlock(nn.Module): + + def __init__(self, in_channels): + super().__init__() + print(f"building SDPAttnBlock (sdp) with {in_channels} in_channels") + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = F.scaled_dot_product_attention(q, k, v) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x+out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in ["vanilla", "sdp", "xformers", "linear", "none"], f'attn_type {attn_type} unknown' + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "sdp": + return SDPAttnBlock(in_channels) + elif attn_type == "xformers": + return MemoryEfficientAttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + raise NotImplementedError() + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, + **ignore_kwargs): + super().__init__() + ### setup attention type + if Config.attn_mode == AttnMode.SDP: + attn_type = "sdp" + elif Config.attn_mode == AttnMode.XFORMERS: + attn_type = "xformers" + else: + attn_type = "vanilla" + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + **ignorekwargs): + super().__init__() + ### setup attention type + if Config.attn_mode == AttnMode.SDP: + attn_type = "sdp" + elif Config.attn_mode == AttnMode.XFORMERS: + attn_type = "xformers" + else: + attn_type = "vanilla" + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.controller = None + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + print(f"attn") + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + + + ''' ToMe ''' + # tome_info = { + # "size": None, + # "hooks": [], + # "args": { + # "generator": None, + # "max_downsample": 2, + # "min_downsample": 1, + # "generator": None, + # "seed": 123, + # "batch_size": 1, + # "align_batch": False, + # "merge_global": False, + # "global_merge_ratio": 0, + # "local_merge_ratio": 0.9, + # "global_rand": 0.1, + # "target_stride": 4, + # "current_step": 0, + # "frame_ids": [0], + # "label": "Decoder_up", + # "downsample": 1, + # "controller": self.controller, + # } + # } + # B, C, H, W = h.shape + # h = rearrange(h, 'b c h w -> b (h w) c') + + # if tome_info["args"]["controller"] is None: + # non_pad_ratio_h, non_pad_ratio_w = 1, 1 + # print(f"[INFO] no padding removal") + # else: + # non_pad_ratio_h, non_pad_ratio_w = self.controller.non_pad_ratio + + # padding_size_w = W - int(W * non_pad_ratio_w) + # padding_size_h = H - int(H * non_pad_ratio_h) + # padding_mask = torch.zeros((H, W), device=h.device, dtype=torch.bool) + # if padding_size_w: + # padding_mask[:, -padding_size_w:] = 1 + # if padding_size_h: + # padding_mask[-padding_size_h:, :] = 1 + # padding_mask = rearrange(padding_mask, 'h w -> (h w)') + + # idx_buffer = torch.arange(H * W, device=h.device, dtype=torch.int64) + # non_pad_idx = idx_buffer[None, ~padding_mask, None] + # del idx_buffer, padding_mask + + # x_non_pad = torch.gather(h, dim=1, index=non_pad_idx.expand(B, -1, C)) + # tome_info["size"] = (int(H * non_pad_ratio_h), int(W * non_pad_ratio_w)) + # from vidtome.patch import compute_merge + # m_a, u_a, merged_tokens = compute_merge( + # self, x_non_pad, tome_info) + # x_non_pad = u_a(merged_tokens) + # h.scatter_(dim=1, index=non_pad_idx.expand(B, -1, C), src=x_non_pad) + # h = rearrange(h, 'b (h w) c -> b c h w', h=H, w=W) + ''' ToMe ended''' + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + ''' ToMe ''' + # print(f"[INFO] before merging h mean: {torch.mean(h)} h std: {torch.std(h)}") + # B, C, H, W = h.shape + # h = rearrange(h, 'b c h w -> b (h w) c') + + # padding_size_w = W - int(W * non_pad_ratio_w) + # padding_size_h = H - int(H * non_pad_ratio_h) + # padding_mask = torch.zeros((H, W), device=h.device, dtype=torch.bool) + # if padding_size_w: + # padding_mask[:, -padding_size_w:] = 1 + # if padding_size_h: + # padding_mask[-padding_size_h:, :] = 1 + # padding_mask = rearrange(padding_mask, 'h w -> (h w)') + + # idx_buffer = torch.arange(H * W, device=h.device, dtype=torch.int64) + # non_pad_idx = idx_buffer[None, ~padding_mask, None] + # del idx_buffer, padding_mask + + # x_non_pad = torch.gather(h, dim=1, index=non_pad_idx.expand(B, -1, C)) + # tome_info["size"] = (int(H * non_pad_ratio_h), int(W * non_pad_ratio_w)) + # m_a, u_a, merged_tokens = compute_merge( + # self, x_non_pad, tome_info) + # x_non_pad = u_a(merged_tokens) + # h.scatter_(dim=1, index=non_pad_idx.expand(B, -1, C), src=x_non_pad) + # h = rearrange(h, 'b (h w) c -> b c h w', h=H, w=W) + # print(f"[INFO] after merging h mean: {torch.mean(h)} h std: {torch.std(h)}") + ''' ToMe ended ''' + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + # print(f"i_level {i_level} i_block {i_block} with shape {h.shape}") + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + # import ipdb; ipdb.set_trace() + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class AutoencoderKL(nn.Module): + + def __init__(self, ddconfig, embed_dim): + super().__init__() + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + + def encode(self, x, batch_size=0): + if batch_size: + h = [] + batch_x = x.split(batch_size, dim=0) + for x_ in batch_x: + h_ = self.encoder(x_) + h += [h_] + torch.cuda.empty_cache() + h = torch.cat(h) + else: + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, batch_size=0): + z = self.post_quant_conv(z) + if batch_size: + dec = [] + batch_z = z.split(batch_size, dim=0) + for z_ in batch_z: + # decode + z_ = self.decoder(z_) + dec += [z_] + torch.cuda.empty_cache() + dec = torch.cat(dec) + else: + dec = self.decoder(z) + # import ipdb; ipdb.set_trace() + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..249628612f834e144518fed7316fc3b720a07adf --- /dev/null +++ b/requirements.txt @@ -0,0 +1,188 @@ +# This file may be used to create an environment using: +# $ conda create --name --file +# platform: linux-64 +# absl-py==2.1.0 +accelerate +# aiofiles==23.2.1 +# aiohttp==3.9.5 +# aiosignal==1.3.1 +# altair==5.3.0 +# annotated-types==0.6.0 +# antlr4-python3-runtime==4.9.3 +# anyio==4.3.0 +# asttokens==2.4.1 +# async-timeout==4.0.3 +# attrs==23.2.0 +# ca-certificates==2024.3.11==h06a4308_0 +# certifi==2024.2.2 +# charset-normalizer==3.3.2 +# click==8.1.7 +# contourpy==1.2.1 +# cycler==0.12.1 +# decorator==5.1.1 +diffusers==0.27.2 +einops==0.7.0 +# exceptiongroup==1.2.1 +# executing==2.0.1 +facexlib==0.3.0 +# fastapi==0.110.2 +# ffmpeg==4.2.2==h20bf706_0 +# ffmpy==0.3.2 +# filelock==3.13.4 +# filterpy==1.4.5 +# fonttools==4.51.0 +# freetype==2.12.1==h4a9f257_0 +# frozenlist==1.4.1 +# fsspec==2024.3.1 +ftfy==6.2.0 +# future==1.0.0 +# gdown==5.1.0 +# gmp==6.2.1==h295c915_3 +# gnutls==3.6.15==he1e5248_0 +gradio +# gradio-client==1.0.1 +# grpcio==1.62.2 +# h11==0.14.0 +# httpcore==1.0.5 +# httpx==0.27.0 +# huggingface-hub==0.22.2 +# idna==3.7 +imageio==2.34.1 +imageio-ffmpeg==0.5.1 +# imhist==0.0.4 +# importlib-metadata==7.1.0 +# importlib-resources==6.4.0 +# imwatermark==0.0.2 +# invisible-watermark==0.2.0 +# ipdb==0.13.13 +# ipython==8.18.1 +# jedi==0.19.1 +# jinja2==3.1.3 +# joblib==1.4.0 +# jsonschema==4.21.1 +# jsonschema-specifications==2023.12.1 +# kiwisolver==1.4.5 +# kornia==0.7.2 +# kornia-rs==0.1.3 +# lame==3.100==h7b6447c_0 +# lazy-loader==0.4 +# ld_impl_linux-64==2.38==h1181459_1 +# libffi==3.4.4==h6a678d5_0 +# libgcc-ng==11.2.0==h1234567_1 +# libgomp==11.2.0==h1234567_1 +# libidn2==2.3.4==h5eee18b_0 +# libopus==1.3.1==h7b6447c_0 +# libpng==1.6.39==h5eee18b_0 +# libstdcxx-ng==11.2.0==h1234567_1 +# libtasn1==4.19.0==h5eee18b_0 +# libunistring==0.9.10==h27cfd23_0 +# libvpx==1.7.0==h439df22_0 +# lightning-utilities==0.11.2 +# linkify-it-py==2.0.3 +# llvmlite==0.42.0 +# lpips==0.1.4 +mediapipe +# markdown==3.6 +# markdown-it-py==3.0.0 +# markupsafe==2.1.5 +# matplotlib==3.8.4 +# matplotlib-inline==0.1.7 +# mdit-py-plugins==0.4.0 +# mdurl==0.1.2 +# multidict==6.0.5 +# mypy-extensions==1.0.0 +# ncurses==6.4==h6a678d5_0 +# nettle==3.7.3==hbbd107a_1 +# networkx==3.2.1 +# numba==0.59.1 +numpy==1.26.4 +omegaconf==2.3.0 +# open-clip-torch==2.24.0 +opencv-python==4.9.0.80 +# openh264==2.1.1==h4ff587b_0 +# openssl==3.0.13==h7f8727e_0 +# orjson==3.10.1 +# packaging==24.0 +# pandas==2.2.2 +# parso==0.8.4 +# pexpect==4.9.0 +pillow==10.3.0 +# pip==23.3.1==py39h06a4308_0 +# prompt-toolkit==3.0.43 +# protobuf==4.25.3 +# psutil==5.9.8 +# ptyprocess==0.7.0 +# pure-eval==0.2.2 +# pycryptodome==3.20.0 +# pydantic==2.7.1 +# pydantic-core==2.18.2 +# pydeprecate==0.3.1 +# pydub==0.25.1 +# pygments==2.17.2 +# pyparsing==3.1.2 +# pyre-extensions==0.0.23 +# pysocks==1.7.1 +# python==3.9.19==h955ad1f_0 +# python-dateutil==2.9.0.post0 +# python-multipart==0.0.9 +pytorch-lightning==2.2.5 +# pytz==2024.1 +# pywavelets==1.6.0 +# pyyaml==6.0.1 +# readline==8.2==h5eee18b_0 +# referencing==0.35.0 +# regex==2023.12.25 +# requests==2.31.0 +# rich==13.7.1 +# rpds-py==0.18.0 +# ruff==0.4.2 +# safetensors==0.4.3 +# scikit-image==0.22.0 +# scikit-learn==1.4.2 +scikit-video==1.1.11 +# scipy==1.12.0 +# semantic-version==2.10.0 +# sentencepiece==0.2.0 +# setuptools==68.2.2==py39h06a4308_0 +# shellingham==1.5.4 +# six==1.16.0 +# sniffio==1.3.1 +# soupsieve==2.5 +spaces +# sqlite==3.41.2==h5eee18b_0 +# stack-data==0.6.3 +# starlette==0.37.2 +# tensorboard==2.16.2 +# tensorboard-data-server==0.7.2 +# threadpoolctl==3.5.0 +# tifffile==2024.4.24 +timm==0.9.16 +# tk==8.6.12==h1ccaba5_0 +# tokenizers==0.12.1 +# tomli==2.0.1 +# tomlkit==0.12.0 +# toolz==0.12.1 +torch==2.0.0 +torchmetrics +# torchvision==0.14.1+cu117 +tqdm==4.66.2 +# traitlets==5.14.3 +transformers +# triton==2.3.0 +# typer==0.12.3 +# typing-extensions==4.11.0 +# typing-inspect==0.9.0 +# tzdata==2024.1 +# uc-micro-py==1.0.3 +# urllib3==2.2.1 +# uvicorn==0.29.0 +# wcwidth==0.2.13 +# websockets==11.0.3 +# werkzeug==3.0.2 +# wheel==0.41.2==py39h06a4308_0 +# x264==1!157.20191217==h7b6447c_0 +xformers==0.0.19 +# xz==5.4.6==h5eee18b_0 +# yarl==1.9.4 +# zipp==3.18.1 +# zlib==1.2.13==h5eee18b_0 diff --git a/style.css b/style.css new file mode 100644 index 0000000000000000000000000000000000000000..cf501a42a5b082777ea31c7a681e69c6067ae4e8 --- /dev/null +++ b/style.css @@ -0,0 +1,92 @@ +/* +This CSS file is modified from: +https://huggingface.co/spaces/DeepFloyd/IF/blob/main/style.css +*/ + +h1 { + text-align: center; +} + +.gradio-container { + font-family: 'IBM Plex Sans', sans-serif; +} + +.gr-button { + color: white; + border-color: black; + background: black; +} + +input[type='range'] { + accent-color: black; +} + +.dark input[type='range'] { + accent-color: #dfdfdf; +} + +.container { + max-width: 1500px; + margin: auto; +} + +.gr-button:focus { + border-color: rgb(147 197 253 / var(--tw-border-opacity)); + outline: none; + box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); + --tw-border-opacity: 1; + --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); + --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color); + --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity)); + --tw-ring-opacity: .5; +} + +.gr-form { + flex: 1 1 50%; + border-top-right-radius: 0; + border-bottom-right-radius: 0; +} + +#prompt-container { + gap: 0; +} + +#prompt-text-input, +#negative-prompt-text-input { + padding: .45rem 0.625rem +} + +/* #component-16 { + border-top-width: 1px !important; + margin-top: 1em +} */ + + +.image_duplication { + position: absolute; + width: 100px; + left: 50px +} + +#component-0 { + max-width: 1500px; + margin: auto; + padding-top: 1.5rem; +} + +#share-btn-container { + display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem; margin-left: auto; +} +#share-btn { + all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important; +} +#share-btn * { + all: unset; +} +#share-btn-container div:nth-child(-n+2){ + width: auto !important; + min-height: 0px !important; +} +#share-btn-container .wrap { + display: none !important; +} \ No newline at end of file diff --git a/utils/__pycache__/batch_inference.cpython-310.pyc b/utils/__pycache__/batch_inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f25c569cce6f4f1abd57ce1ef08fb2ab85aefd4e Binary files /dev/null and b/utils/__pycache__/batch_inference.cpython-310.pyc differ diff --git a/utils/__pycache__/batch_inference.cpython-39.pyc b/utils/__pycache__/batch_inference.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2d63529661600bc3b2e8e9ac659c67fca01ae54 Binary files /dev/null and b/utils/__pycache__/batch_inference.cpython-39.pyc differ diff --git a/utils/__pycache__/common.cpython-310.pyc b/utils/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..721c249d5514cd77ed50614f27250c353f95c662 Binary files /dev/null and b/utils/__pycache__/common.cpython-310.pyc differ diff --git a/utils/__pycache__/common.cpython-39.pyc b/utils/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..135a2e38154eea6191e42d7005ae9f3c937ea680 Binary files /dev/null and b/utils/__pycache__/common.cpython-39.pyc differ diff --git a/utils/__pycache__/cond_fn.cpython-310.pyc b/utils/__pycache__/cond_fn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76c4598a85bedf9729117e11713877a6465713ef Binary files /dev/null and b/utils/__pycache__/cond_fn.cpython-310.pyc differ diff --git a/utils/__pycache__/cond_fn.cpython-39.pyc b/utils/__pycache__/cond_fn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a4c0184c05863bd83bf8d7b054508c42cbff2d5 Binary files /dev/null and b/utils/__pycache__/cond_fn.cpython-39.pyc differ diff --git a/utils/__pycache__/degradation.cpython-39.pyc b/utils/__pycache__/degradation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..806cf28ea4630a2632196fdb34d595c86a44a2da Binary files /dev/null and b/utils/__pycache__/degradation.cpython-39.pyc differ diff --git a/utils/__pycache__/face_restoration_helper.cpython-310.pyc b/utils/__pycache__/face_restoration_helper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f6ea3712146e6ac4d206347fc2de67950c50008 Binary files /dev/null and b/utils/__pycache__/face_restoration_helper.cpython-310.pyc differ diff --git a/utils/__pycache__/face_restoration_helper.cpython-39.pyc b/utils/__pycache__/face_restoration_helper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee39c2deadd2ab194784ba898fd103632e637ddc Binary files /dev/null and b/utils/__pycache__/face_restoration_helper.cpython-39.pyc differ diff --git a/utils/__pycache__/file.cpython-39.pyc b/utils/__pycache__/file.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..734e55826c918e6b660cddb23ec2ca576c325da6 Binary files /dev/null and b/utils/__pycache__/file.cpython-39.pyc differ diff --git a/utils/__pycache__/flow_utils.cpython-310.pyc b/utils/__pycache__/flow_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daa9e89f7e52f1fbfacc4c2e9321f5fb4e44b601 Binary files /dev/null and b/utils/__pycache__/flow_utils.cpython-310.pyc differ diff --git a/utils/__pycache__/flow_utils.cpython-39.pyc b/utils/__pycache__/flow_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afa6e7a6faa7ad9127d05f55d25f47dc736458e0 Binary files /dev/null and b/utils/__pycache__/flow_utils.cpython-39.pyc differ diff --git a/utils/__pycache__/helpers.cpython-310.pyc b/utils/__pycache__/helpers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8156a144313830d3210bb8e7b866f00b805773fc Binary files /dev/null and b/utils/__pycache__/helpers.cpython-310.pyc differ diff --git a/utils/__pycache__/helpers.cpython-39.pyc b/utils/__pycache__/helpers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e4a18febe956cf2cfb0f6d63357199202263a24 Binary files /dev/null and b/utils/__pycache__/helpers.cpython-39.pyc differ diff --git a/utils/__pycache__/image_utils.cpython-310.pyc b/utils/__pycache__/image_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a28c4aafc68fc89595e19e64c14214feb46b5d58 Binary files /dev/null and b/utils/__pycache__/image_utils.cpython-310.pyc differ diff --git a/utils/__pycache__/image_utils.cpython-38.pyc b/utils/__pycache__/image_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16b8e8c9edd2fd94c6f05cab845a5aadfb3c3fcf Binary files /dev/null and b/utils/__pycache__/image_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/image_utils.cpython-39.pyc b/utils/__pycache__/image_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..101e2956887483e03f5c54d9eed47d6a556fbd18 Binary files /dev/null and b/utils/__pycache__/image_utils.cpython-39.pyc differ diff --git a/utils/__pycache__/inference.cpython-310.pyc b/utils/__pycache__/inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e833c61e1b252096001de85082bb30cec4d1e3f Binary files /dev/null and b/utils/__pycache__/inference.cpython-310.pyc differ diff --git a/utils/__pycache__/inference.cpython-39.pyc b/utils/__pycache__/inference.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbbfaf8040b075c8ba0a1246c2e7501c13b4721a Binary files /dev/null and b/utils/__pycache__/inference.cpython-39.pyc differ diff --git a/utils/__pycache__/metrics.cpython-39.pyc b/utils/__pycache__/metrics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cfcfa534167e6a3dfcf070df879a32fe6b34c38 Binary files /dev/null and b/utils/__pycache__/metrics.cpython-39.pyc differ diff --git a/utils/__pycache__/sampler.cpython-310.pyc b/utils/__pycache__/sampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4b8bd28aa44a9b055172c3cf28e6ece1ee35bf9 Binary files /dev/null and b/utils/__pycache__/sampler.cpython-310.pyc differ diff --git a/utils/__pycache__/sampler.cpython-39.pyc b/utils/__pycache__/sampler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a69d0e1b80dc7d08dba3601e96880114e5e068b1 Binary files /dev/null and b/utils/__pycache__/sampler.cpython-39.pyc differ diff --git a/utils/__pycache__/video_visualizer.cpython-310.pyc b/utils/__pycache__/video_visualizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..295b0fabe29b023bebad191b9503b34690b033c4 Binary files /dev/null and b/utils/__pycache__/video_visualizer.cpython-310.pyc differ diff --git a/utils/__pycache__/video_visualizer.cpython-38.pyc b/utils/__pycache__/video_visualizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..facd6803c9622e31a48b1fb64f62f07a29e59ec6 Binary files /dev/null and b/utils/__pycache__/video_visualizer.cpython-38.pyc differ diff --git a/utils/__pycache__/video_visualizer.cpython-39.pyc b/utils/__pycache__/video_visualizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d365d4e15453b4b7d3cf03296fa2687d2fffa45 Binary files /dev/null and b/utils/__pycache__/video_visualizer.cpython-39.pyc differ diff --git a/utils/batch_inference.py b/utils/batch_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..c6141c63111a0e8937ca1c539edccba47ae61149 --- /dev/null +++ b/utils/batch_inference.py @@ -0,0 +1,436 @@ +import os +import cv2 +from typing import overload, Generator, Dict +from argparse import Namespace + +import numpy as np +import torch +import imageio +from PIL import Image +from omegaconf import OmegaConf + +from accelerate.utils import set_seed +from model.cldm import ControlLDM +from model.gaussian_diffusion import Diffusion +from model.bsrnet import RRDBNet +from model.scunet import SCUNet +from model.swinir import SwinIR +from utils.common import instantiate_from_config, load_file_from_url, count_vram_usage +from utils.face_restoration_helper import FaceRestoreHelper +from utils.helpers import ( + Pipeline, + BSRNetPipeline, SwinIRPipeline, SCUNetPipeline, + batch_bicubic_resize, + bicubic_resize, + save_video +) +from utils.cond_fn import MSEGuidance, WeightedMSEGuidance +from GMFlow.gmflow.gmflow import GMFlow +from controller.controller import AttentionControl + +MODELS = { + ### stage_1 model weights + "bsrnet": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRNet.pth", + # the following checkpoint is up-to-date, but we use the old version in our paper + # "swinir_face": "https://github.com/zsyOAOA/DifFace/releases/download/V1.0/General_Face_ffhq512.pth", + "swinir_face": "https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt", + "scunet_psnr": "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth", + "swinir_general": "https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt", + ### stage_2 model weights + "sd_v21": "https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt", + "v1_face": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_face.pth", + "v1_general": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_general.pth", + "v2": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v2.pth" +} + + +def load_model_from_url(url: str) -> Dict[str, torch.Tensor]: + sd_path = load_file_from_url(url, model_dir="weights") + sd = torch.load(sd_path, map_location="cpu") + if "state_dict" in sd: + sd = sd["state_dict"] + if list(sd.keys())[0].startswith("module"): + sd = {k[len("module."):]: v for k, v in sd.items()} + return sd + + +class InferenceLoop: + + def __init__(self, args: Namespace) -> "InferenceLoop": + self.args = args + self.loop_ctx = {} + self.pipeline: Pipeline = None + self.init_stage1_model() + self.init_stage2_model() + self.init_cond_fn() + self.init_pipeline() + + @overload + def init_stage1_model(self) -> None: + ... + + @count_vram_usage + def init_stage2_model(self) -> None: + ### load uent, vae, clip + # self.cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/inference/my_cldm.yaml")) + config = OmegaConf.load(self.args.config) + if self.args.warp_period is not None: + config.params.latent_warp_cfg.warp_period = self.args.warp_period + if self.args.merge_period is not None: + config.params.latent_warp_cfg.merge_period = self.args.merge_period + if self.args.ToMe_period is not None: + config.params.VidToMe_cfg.ToMe_period = self.args.ToMe_period + if self.args.merge_ratio is not None: + config.params.VidToMe_cfg.merge_ratio = self.args.merge_ratio + + # import ipdb; ipdb.set_trace() + self.cldm: ControlLDM = instantiate_from_config(config) + sd = load_model_from_url(MODELS["sd_v21"]) + unused = self.cldm.load_pretrained_sd(sd) + print(f"strictly load pretrained sd_v2.1, unused weights: {unused}") + ### load controlnet + control_sd = load_model_from_url(MODELS["v2"]) + + self.cldm.load_controlnet_from_ckpt(control_sd) + print(f"strictly load controlnet weight") + self.cldm.eval().to(self.args.device) + ### load diffusion + self.diffusion: Diffusion = instantiate_from_config(OmegaConf.load("configs/inference/diffusion.yaml")) + self.diffusion.to(self.args.device) + + def init_cond_fn(self) -> None: + if not self.args.guidance: + self.cond_fn = None + return + if self.args.g_loss == "mse": + cond_fn_cls = MSEGuidance + elif self.args.g_loss == "w_mse": + cond_fn_cls = WeightedMSEGuidance + else: + raise ValueError(self.args.g_loss) + self.cond_fn = cond_fn_cls( + scale=self.args.g_scale, t_start=self.args.g_start, t_stop=self.args.g_stop, + space=self.args.g_space, repeat=self.args.g_repeat + ) + + @overload + def init_pipeline(self) -> None: + ... + + def setup(self) -> None: + pass + # self.output_dir = self.args.output + # os.makedirs(self.output_dir, exist_ok=True) + + def lq_loader(self) -> Generator[np.ndarray, None, None]: + img_exts = [".png", ".jpg", ".jpeg"] + if os.path.isdir(self.args.input): + file_names = sorted([ + file_name for file_name in os.listdir(self.args.input) if os.path.splitext(file_name)[-1] in img_exts + ]) + file_paths = [os.path.join(self.args.input, file_name) for file_name in file_names] + else: + assert os.path.splitext(self.args.input)[-1] in img_exts + file_paths = [self.args.input] + + def _loader() -> Generator[np.ndarray, None, None]: + for file_path in file_paths: + ### load lq + lq = np.array(Image.open(file_path).convert("RGB")) + print(f"load lq: {file_path}") + ### set context for saving results + self.loop_ctx["file_stem"] = os.path.splitext(os.path.basename(file_path))[0] + for i in range(self.args.n_samples): + self.loop_ctx["repeat_idx"] = i + yield lq + + return _loader + + def batch_lq_loader(self) -> Generator[np.ndarray, None, None]: + img_exts = [".png", ".jpg", ".jpeg"] + print(f"[INFO] input: {self.args.input}") + if os.path.isdir(self.args.input): + file_names = sorted([ + file_name for file_name in os.listdir(self.args.input) if os.path.splitext(file_name)[-1] in img_exts + ], key=lambda x: int(x.split('.')[0])) + # file_names=file_names[30:] + # sorted([filename for filename in os.listdir(img_folder) if filename.endswith(('.png', '.jpg'))], key=lambda x: int(x.split('.')[0])) + file_paths = [os.path.join(self.args.input, file_name) for file_name in file_names] + file_paths = file_paths[:self.args.n_frames] + else: + assert os.path.splitext(self.args.input)[-1] in img_exts + file_paths = [self.args.input] + def _loader() -> Generator[np.ndarray, None, None]: + for j in range(0, len(file_paths), self.args.batch_size): + lqs, self.loop_ctx["file_stem"] = [], [] + batch = self.args.batch_size if len(file_paths) - (j + self.args.batch_size) > 2 else len(file_paths) - j + if batch != self.args.batch_size: + self.args.batch_size = batch + # sampler.model.controller.distances.clear() + if self.pipeline.cldm.controller is not None and self.pipeline.cldm.controller.distances is not None: + self.pipeline.cldm.controller.distances.clear() + + for file_path in file_paths[j:min(len(file_paths), j+batch)]: + ### load lq + print(f"[INFO] load lq: {file_path}") + lq = np.array(Image.open(file_path).convert("RGB")) + lqs.append(lq) + ### set context for saving results + self.loop_ctx["file_stem"].append(os.path.splitext(os.path.basename(file_path))[0]) + # import ipdb; ipdb.set_trace() + self.args.final_size = (lqs[0].shape[0] * self.args.upscale, lqs[0].shape[1] * self.args.upscale) + for i in range(self.args.n_samples): + self.loop_ctx["repeat_idx"] = i + yield np.array(lqs) + if j + batch == len(file_paths): + break + + return _loader + + def after_load_lq(self, lq: np.ndarray) -> np.ndarray: + return lq + + @torch.no_grad() + def run(self) -> None: + self.setup() + # We don't support batch processing since input images may have different size + loader = self.batch_lq_loader() + + ''' flow model ''' + + flow_model = GMFlow( + feature_channels=128, + num_scales=1, + upsample_factor=8, + num_head=1, + attention_type='swin', + ffn_dim_expansion=4, + num_transformer_layers=6, + ).to(self.args.device) + + checkpoint = torch.load('weights/gmflow_sintel-0c07dcb3.pth', + map_location=lambda storage, loc: storage) + weights = checkpoint['model'] if 'model' in checkpoint else checkpoint + flow_model.load_state_dict(weights, strict=False) + flow_model.eval() + + ''' flow model ended ''' + results = [] + if self.cldm.latent_control: + self.cldm.controller.set_total_step(self.args.steps) + for i, img in enumerate(loader()): + torch.cuda.empty_cache() + # import ipdb; ipdb.set_trace() + lq = img + lq = self.after_load_lq(lq) + if self.cldm.latent_control: + print(f"[INFO] set seed @ {self.args.seed}") + set_seed(self.args.seed) + samples, stage1s = self.pipeline.run( + lq, self.args.steps, 1.0, self.args.tiled, + self.args.tile_size, self.args.tile_stride, + self.args.pos_prompt, self.args.neg_prompt, self.args.cfg_scale, + self.args.better_start, + index=i, input=self.args.input, final_size=self.args.final_size, + flow_model=flow_model, + ) + + if self.cldm.controller is not None: + self.cldm.controller.set_pre_keyframe_lq(lq[self.args.batch_size // 2][None]) + results.append(samples) + + results = np.concatenate(results, axis=0) + video_path = f'DiffIR2VR_fps_10.mp4' + results = [np.array(img).astype(np.uint8) for img in results] + writer = imageio.get_writer(video_path, fps=10, codec='libx264', + macro_block_size=1) + + for img in results: + writer.append_data(img) + + writer.close() + + return video_path + + + def save(self, sample: np.ndarray) -> None: + file_stem, repeat_idx = self.loop_ctx["file_stem"], self.loop_ctx["repeat_idx"] + file_name = f"{file_stem}_{repeat_idx}.png" if self.args.n_samples > 1 else f"{file_stem}.png" + save_path = os.path.join(self.args.output, file_name) + Image.fromarray(sample).save(save_path) + print(f"save result to {save_path}") + + def batch_save(self, samples: np.ndarray, dir: str=None) -> None: + file_stems, repeat_idx = self.loop_ctx["file_stem"], self.loop_ctx["repeat_idx"] + + if dir is not None: + out_dir = os.path.join(self.args.output, dir) + else: + out_dir = os.path.join(self.args.output) + os.makedirs(out_dir, exist_ok=True) + + for file_stem, sample in zip(file_stems, samples): + file_name = f"{file_stem}_{repeat_idx}.png" if self.args.n_samples > 1 else f"{file_stem}.png" + save_path = os.path.join(out_dir, file_name) + Image.fromarray(sample).save(save_path) + print(f"save result to {save_path}") + + +class BSRInferenceLoop(InferenceLoop): + + @count_vram_usage + def init_stage1_model(self) -> None: + self.bsrnet: RRDBNet = instantiate_from_config(OmegaConf.load("configs/inference/bsrnet.yaml")) + sd = load_model_from_url(MODELS["bsrnet"]) + self.bsrnet.load_state_dict(sd, strict=True) + self.bsrnet.eval().to(self.args.device) + + def init_pipeline(self) -> None: + self.pipeline = BSRNetPipeline(self.bsrnet, self.cldm, self.diffusion, self.cond_fn, self.args.device, self.args.upscale) + + +class BFRInferenceLoop(InferenceLoop): + + @count_vram_usage + def init_stage1_model(self) -> None: + self.swinir_face: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml")) + sd = load_model_from_url(MODELS["swinir_face"]) + self.swinir_face.load_state_dict(sd, strict=True) + self.swinir_face.eval().to(self.args.device) + + def init_pipeline(self) -> None: + self.pipeline = SwinIRPipeline(self.swinir_face, self.cldm, self.diffusion, self.cond_fn, self.args.device) + + def after_load_lq(self, lq: np.ndarray) -> np.ndarray: + # For BFR task, super resolution is achieved by directly upscaling lq + return bicubic_resize(lq, self.args.upscale) + + +class BIDInferenceLoop(InferenceLoop): + + @count_vram_usage + def init_stage1_model(self) -> None: + self.scunet_psnr: SCUNet = instantiate_from_config(OmegaConf.load("configs/inference/scunet.yaml")) + sd = load_model_from_url(MODELS["scunet_psnr"]) + self.scunet_psnr.load_state_dict(sd, strict=True) + self.scunet_psnr.eval().to(self.args.device) + + def init_pipeline(self) -> None: + self.pipeline = SCUNetPipeline(self.scunet_psnr, self.cldm, self.diffusion, self.cond_fn, self.args.device) + + def after_load_lq(self, lq: np.ndarray) -> np.ndarray: + # For BID task, super resolution is achieved by directly upscaling lq + return batch_bicubic_resize(lq, self.args.upscale) + + +class V1InferenceLoop(InferenceLoop): + + @count_vram_usage + def init_stage1_model(self) -> None: + self.swinir: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml")) + if self.args.task == "fr": + sd = load_model_from_url(MODELS["swinir_face"]) + elif self.args.task == "sr": + sd = load_model_from_url(MODELS["swinir_general"]) + else: + raise ValueError(f"DiffBIR v1 doesn't support task: {self.args.task}, please use v2 by passsing '--version v2'") + self.swinir.load_state_dict(sd, strict=True) + self.swinir.eval().to(self.args.device) + + def init_pipeline(self) -> None: + self.pipeline = SwinIRPipeline(self.swinir, self.cldm, self.diffusion, self.cond_fn, self.args.device) + + def after_load_lq(self, lq: np.ndarray) -> np.ndarray: + # For BFR task, super resolution is achieved by directly upscaling lq + return bicubic_resize(lq, self.args.upscale) + + +class UnAlignedBFRInferenceLoop(InferenceLoop): + + @count_vram_usage + def init_stage1_model(self) -> None: + self.bsrnet: RRDBNet = instantiate_from_config(OmegaConf.load("configs/inference/bsrnet.yaml")) + sd = load_model_from_url(MODELS["bsrnet"]) + self.bsrnet.load_state_dict(sd, strict=True) + self.bsrnet.eval().to(self.args.device) + + self.swinir_face: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml")) + sd = load_model_from_url(MODELS["swinir_face"]) + self.swinir_face.load_state_dict(sd, strict=True) + self.swinir_face.eval().to(self.args.device) + + def init_pipeline(self) -> None: + self.pipes = { + "bg": BSRNetPipeline(self.bsrnet, self.cldm, self.diffusion, self.cond_fn, self.args.device, self.args.upscale), + "face": SwinIRPipeline(self.swinir_face, self.cldm, self.diffusion, self.cond_fn, self.args.device) + } + self.pipeline = self.pipes["face"] + + def setup(self) -> None: + super().setup() + self.cropped_face_dir = os.path.join(self.args.output, "cropped_faces") + os.makedirs(self.cropped_face_dir, exist_ok=True) + self.restored_face_dir = os.path.join(self.args.output, "restored_faces") + os.makedirs(self.restored_face_dir, exist_ok=True) + self.restored_bg_dir = os.path.join(self.args.output, "restored_backgrounds") + os.makedirs(self.restored_bg_dir, exist_ok=True) + + def lq_loader(self) -> Generator[np.ndarray, None, None]: + base_loader = super().lq_loader() + self.face_helper = FaceRestoreHelper( + device=self.args.device, + upscale_factor=1, + face_size=512, + use_parse=True, + det_model="retinaface_resnet50" + ) + + def _loader() -> Generator[np.ndarray, None, None]: + for lq in base_loader(): + ### set input image + self.face_helper.clean_all() + upscaled_bg = bicubic_resize(lq, self.args.upscale) + self.face_helper.read_image(upscaled_bg) + ### get face landmarks for each face + self.face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5) + self.face_helper.align_warp_face() + print(f"detect {len(self.face_helper.cropped_faces)} faces") + ### restore each face (has been upscaeled) + for i, lq_face in enumerate(self.face_helper.cropped_faces): + self.loop_ctx["is_face"] = True + self.loop_ctx["face_idx"] = i + self.loop_ctx["cropped_face"] = lq_face + yield lq_face + ### restore background (hasn't been upscaled) + self.loop_ctx["is_face"] = False + yield lq + + return _loader + + def after_load_lq(self, lq: np.ndarray) -> np.ndarray: + if self.loop_ctx["is_face"]: + self.pipeline = self.pipes["face"] + else: + self.pipeline = self.pipes["bg"] + return lq + + def save(self, sample: np.ndarray) -> None: + file_stem, repeat_idx = self.loop_ctx["file_stem"], self.loop_ctx["repeat_idx"] + if self.loop_ctx["is_face"]: + face_idx = self.loop_ctx["face_idx"] + file_name = f"{file_stem}_{repeat_idx}_face_{face_idx}.png" + Image.fromarray(sample).save(os.path.join(self.restored_face_dir, file_name)) + + cropped_face = self.loop_ctx["cropped_face"] + Image.fromarray(cropped_face).save(os.path.join(self.cropped_face_dir, file_name)) + + self.face_helper.add_restored_face(sample) + else: + self.face_helper.get_inverse_affine() + # paste each restored face to the input image + restored_img = self.face_helper.paste_faces_to_input_image( + upsample_img=sample + ) + file_name = f"{file_stem}_{repeat_idx}.png" + Image.fromarray(sample).save(os.path.join(self.restored_bg_dir, file_name)) + Image.fromarray(restored_img).save(os.path.join(self.output_dir, file_name)) diff --git a/utils/common.py b/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..f857cef0e5212ca331d7e97f5aec27c62ae9e873 --- /dev/null +++ b/utils/common.py @@ -0,0 +1,160 @@ +from typing import Mapping, Any, Tuple, Callable +import importlib +import os +from urllib.parse import urlparse + +import torch +from torch import Tensor +from torch.nn import functional as F +import numpy as np + +from torch.hub import download_url_to_file, get_dir + + +def get_obj_from_str(string: str, reload: bool=False) -> Any: + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config: Mapping[str, Any]) -> Any: + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + # import ipdb; ipdb.set_trace() + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def wavelet_blur(image: Tensor, radius: int): + """ + Apply wavelet blur to the input tensor. + """ + # input shape: (1, 3, H, W) + # convolution kernel + kernel_vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125], + [0.0625, 0.125, 0.0625], + ] + kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) + # add channel dimensions to the kernel to make it a 4D tensor + kernel = kernel[None, None] + # repeat the kernel across all input channels + kernel = kernel.repeat(3, 1, 1, 1) + image = F.pad(image, (radius, radius, radius, radius), mode='replicate') + # apply convolution + output = F.conv2d(image, kernel, groups=3, dilation=radius) + return output + + +def wavelet_decomposition(image: Tensor, levels=5): + """ + Apply wavelet decomposition to the input tensor. + This function only returns the low frequency & the high frequency. + """ + high_freq = torch.zeros_like(image) + for i in range(levels): + radius = 2 ** i + low_freq = wavelet_blur(image, radius) + high_freq += (image - low_freq) + image = low_freq + + return high_freq, low_freq + + +def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): + """ + Apply wavelet decomposition, so that the content will have the same color as the style. + """ + # calculate the wavelet decomposition of the content feature + content_high_freq, content_low_freq = wavelet_decomposition(content_feat) + del content_low_freq + # calculate the wavelet decomposition of the style feature + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq + # reconstruct the content feature with the style's high frequency + return content_high_freq + style_low_freq + + +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/ +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file + + +def sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> Tuple[int, int, int, int]: + hi_list = list(range(0, h - tile_size + 1, tile_stride)) + if (h - tile_size) % tile_stride != 0: + hi_list.append(h - tile_size) + + wi_list = list(range(0, w - tile_size + 1, tile_stride)) + if (w - tile_size) % tile_stride != 0: + wi_list.append(w - tile_size) + + coords = [] + for hi in hi_list: + for wi in wi_list: + coords.append((hi, hi + tile_size, wi, wi + tile_size)) + return coords + + +# https://github.com/csslc/CCSR/blob/main/model/q_sampler.py#L503 +def gaussian_weights(tile_width: int, tile_height: int) -> np.ndarray: + """Generates a gaussian mask of weights for tile contributions""" + latent_width = tile_width + latent_height = tile_height + var = 0.01 + midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1 + x_probs = [ + np.exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / np.sqrt(2 * np.pi * var) + for x in range(latent_width)] + midpoint = latent_height / 2 + y_probs = [ + np.exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / np.sqrt(2 * np.pi * var) + for y in range(latent_height)] + weights = np.outer(y_probs, x_probs) + return weights + + +COUNT_VRAM = bool(os.environ.get("COUNT_VRAM", False)) + +def count_vram_usage(func: Callable) -> Callable: + if not COUNT_VRAM: + return func + + def wrapper(*args, **kwargs): + peak_before = torch.cuda.max_memory_allocated() / (1024 ** 3) + ret = func(*args, **kwargs) + torch.cuda.synchronize() + peak_after = torch.cuda.max_memory_allocated() / (1024 ** 3) + print(f"VRAM peak before {func.__name__}: {peak_before:.5f} GB, after: {peak_after:.5f} GB") + return ret + return wrapper \ No newline at end of file diff --git a/utils/cond_fn.py b/utils/cond_fn.py new file mode 100755 index 0000000000000000000000000000000000000000..799f0ac93655c98554f3b44359faec7d2a0b12cd --- /dev/null +++ b/utils/cond_fn.py @@ -0,0 +1,98 @@ +from typing import overload, Tuple +import torch +from torch.nn import functional as F + + +class Guidance: + + def __init__(self, scale: float, t_start: int, t_stop: int, space: str, repeat: int) -> "Guidance": + """ + Initialize restoration guidance. + + Args: + scale (float): Gradient scale (denoted as `s` in our paper). The larger the gradient scale, + the closer the final result will be to the output of the first stage model. + t_start (int), t_stop (int): The timestep to start or stop guidance. Note that the sampling + process starts from t=1000 to t=0, the `t_start` should be larger than `t_stop`. + space (str): The data space for computing loss function (rgb or latent). + + Our restoration guidance is based on [GDP](https://github.com/Fayeben/GenerativeDiffusionPrior). + Thanks for their work! + """ + self.scale = scale * 3000 + self.t_start = t_start + self.t_stop = t_stop + self.target = None + self.space = space + self.repeat = repeat + + def load_target(self, target: torch.Tensor) -> None: + self.target = target + + def __call__(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]: + # avoid propagating gradient out of this scope + pred_x0 = pred_x0.detach().clone() + target_x0 = target_x0.detach().clone() + return self._forward(target_x0, pred_x0, t) + + @overload + def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]: + ... + + +class MSEGuidance(Guidance): + + def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]: + # inputs: [-1, 1], nchw, rgb + with torch.enable_grad(): + pred_x0.requires_grad_(True) + loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum() + scale = self.scale + g = -torch.autograd.grad(loss, pred_x0)[0] * scale + return g, loss.item() + + +class WeightedMSEGuidance(Guidance): + + def _get_weight(self, target: torch.Tensor) -> torch.Tensor: + # convert RGB to G + rgb_to_gray_kernel = torch.tensor([0.2989, 0.5870, 0.1140]).view(1, 3, 1, 1) + target = torch.sum(target * rgb_to_gray_kernel.to(target.device), dim=1, keepdim=True) + # initialize sobel kernel in x and y axis + G_x = [ + [1, 0, -1], + [2, 0, -2], + [1, 0, -1] + ] + G_y = [ + [1, 2, 1], + [0, 0, 0], + [-1, -2, -1] + ] + G_x = torch.tensor(G_x, dtype=target.dtype, device=target.device)[None] + G_y = torch.tensor(G_y, dtype=target.dtype, device=target.device)[None] + G = torch.stack((G_x, G_y)) + + target = F.pad(target, (1, 1, 1, 1), mode='replicate') # padding = 1 + grad = F.conv2d(target, G, stride=1) + mag = grad.pow(2).sum(dim=1, keepdim=True).sqrt() + + n, c, h, w = mag.size() + block_size = 2 + blocks = mag.view(n, c, h // block_size, block_size, w // block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous() + block_mean = blocks.sum(dim=(-2, -1), keepdim=True).tanh().repeat(1, 1, 1, 1, block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous() + block_mean = block_mean.view(n, c, h, w) + weight_map = 1 - block_mean + + return weight_map + + def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]: + # inputs: [-1, 1], nchw, rgb + with torch.no_grad(): + w = self._get_weight((target_x0 + 1) / 2) + with torch.enable_grad(): + pred_x0.requires_grad_(True) + loss = ((pred_x0 - target_x0).pow(2) * w).mean((1, 2, 3)).sum() + scale = self.scale + g = -torch.autograd.grad(loss, pred_x0)[0] * scale + return g, loss.item() diff --git a/utils/degradation.py b/utils/degradation.py new file mode 100644 index 0000000000000000000000000000000000000000..aa75976523a36d0020ddbc0db9f0f0c3d753fde9 --- /dev/null +++ b/utils/degradation.py @@ -0,0 +1,765 @@ +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/data/degradations.py +import cv2 +import math +import numpy as np +import random +import torch +from scipy import special +from scipy.stats import multivariate_normal +from torchvision.transforms.functional_tensor import rgb_to_grayscale + +# -------------------------------------------------------------------- # +# --------------------------- blur kernels --------------------------- # +# -------------------------------------------------------------------- # + + +# --------------------------- util functions --------------------------- # +def sigma_matrix2(sig_x, sig_y, theta): + """Calculate the rotated sigma matrix (two dimensional matrix). + + Args: + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + + Returns: + ndarray: Rotated sigma matrix. + """ + d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]]) + u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T)) + + +def mesh_grid(kernel_size): + """Generate the mesh grid, centering at zero. + + Args: + kernel_size (int): + + Returns: + xy (ndarray): with the shape (kernel_size, kernel_size, 2) + xx (ndarray): with the shape (kernel_size, kernel_size) + yy (ndarray): with the shape (kernel_size, kernel_size) + """ + ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) + xx, yy = np.meshgrid(ax, ax) + xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size, + 1))).reshape(kernel_size, kernel_size, 2) + return xy, xx, yy + + +def pdf2(sigma_matrix, grid): + """Calculate PDF of the bivariate Gaussian distribution. + + Args: + sigma_matrix (ndarray): with the shape (2, 2) + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + + Returns: + kernel (ndarrray): un-normalized kernel. + """ + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2)) + return kernel + + +def cdf2(d_matrix, grid): + """Calculate the CDF of the standard bivariate Gaussian distribution. + Used in skewed Gaussian distribution. + + Args: + d_matrix (ndarrasy): skew matrix. + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + + Returns: + cdf (ndarray): skewed cdf. + """ + rv = multivariate_normal([0, 0], [[1, 0], [0, 1]]) + grid = np.dot(grid, d_matrix) + cdf = rv.cdf(grid) + return cdf + + +def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True): + """Generate a bivariate isotropic or anisotropic Gaussian kernel. + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + isotropic (bool): + + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + kernel = pdf2(sigma_matrix, grid) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True): + """Generate a bivariate generalized Gaussian kernel. + + ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions`` + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta)) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True): + """Generate a plateau-like anisotropic kernel. + + 1 / (1+x^(beta)) + + Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1) + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=None, + isotropic=True): + """Randomly generate bivariate isotropic or anisotropic Gaussian kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_generalized_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=None, + isotropic=True): + """Randomly generate bivariate generalized Gaussian kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + # assume beta_range[0] < 1 < beta_range[1] + if np.random.uniform() < 0.5: + beta = np.random.uniform(beta_range[0], 1) + else: + beta = np.random.uniform(1, beta_range[1]) + + kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_plateau(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=None, + isotropic=True): + """Randomly generate bivariate plateau kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi/2, math.pi/2] + beta_range (tuple): [1, 4] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + # TODO: this may be not proper + if np.random.uniform() < 0.5: + beta = np.random.uniform(beta_range[0], 1) + else: + beta = np.random.uniform(1, beta_range[1]) + + kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic) + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + + return kernel + + +def random_mixed_kernels(kernel_list, + kernel_prob, + kernel_size=21, + sigma_x_range=(0.6, 5), + sigma_y_range=(0.6, 5), + rotation_range=(-math.pi, math.pi), + betag_range=(0.5, 8), + betap_range=(0.5, 8), + noise_range=None): + """Randomly generate mixed kernels. + + Args: + kernel_list (tuple): a list name of kernel types, + support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', + 'plateau_aniso'] + kernel_prob (tuple): corresponding kernel probability for each + kernel type + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + kernel_type = random.choices(kernel_list, kernel_prob)[0] + if kernel_type == 'iso': + kernel = random_bivariate_Gaussian( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True) + elif kernel_type == 'aniso': + kernel = random_bivariate_Gaussian( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False) + elif kernel_type == 'generalized_iso': + kernel = random_bivariate_generalized_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + betag_range, + noise_range=noise_range, + isotropic=True) + elif kernel_type == 'generalized_aniso': + kernel = random_bivariate_generalized_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + betag_range, + noise_range=noise_range, + isotropic=False) + elif kernel_type == 'plateau_iso': + kernel = random_bivariate_plateau( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True) + elif kernel_type == 'plateau_aniso': + kernel = random_bivariate_plateau( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False) + return kernel + + +np.seterr(divide='ignore', invalid='ignore') + + +def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0): + """2D sinc filter + + Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter + + Args: + cutoff (float): cutoff frequency in radians (pi is max) + kernel_size (int): horizontal and vertical size, must be odd. + pad_to (int): pad kernel size to desired size, must be odd or zero. + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + kernel = np.fromfunction( + lambda x, y: cutoff * special.j1(cutoff * np.sqrt( + (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt( + (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size]) + kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi) + kernel = kernel / np.sum(kernel) + if pad_to > kernel_size: + pad_size = (pad_to - kernel_size) // 2 + kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) + return kernel + + +# ------------------------------------------------------------- # +# --------------------------- noise --------------------------- # +# ------------------------------------------------------------- # + +# ----------------------- Gaussian Noise ----------------------- # + + +def generate_gaussian_noise(img, sigma=10, gray_noise=False): + """Generate Gaussian noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + sigma (float): Noise scale (measured in range 255). Default: 10. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + if gray_noise: + noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255. + noise = np.expand_dims(noise, axis=2).repeat(3, axis=2) + else: + noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255. + return noise + + +def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False): + """Add Gaussian noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + sigma (float): Noise scale (measured in range 255). Default: 10. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + noise = generate_gaussian_noise(img, sigma, gray_noise) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0): + """Add Gaussian noise (PyTorch version). + + Args: + img (Tensor): Shape (b, c, h, w), range[0, 1], float32. + scale (float | Tensor): Noise scale. Default: 1.0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + b, _, h, w = img.size() + if not isinstance(sigma, (float, int)): + sigma = sigma.view(img.size(0), 1, 1, 1) + if isinstance(gray_noise, (float, int)): + cal_gray_noise = gray_noise > 0 + else: + gray_noise = gray_noise.view(b, 1, 1, 1) + cal_gray_noise = torch.sum(gray_noise) > 0 + + if cal_gray_noise: + noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255. + noise_gray = noise_gray.view(b, 1, h, w) + + # always calculate color noise + noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255. + + if cal_gray_noise: + noise = noise * (1 - gray_noise) + noise_gray * gray_noise + return noise + + +def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False): + """Add Gaussian noise (PyTorch version). + + Args: + img (Tensor): Shape (b, c, h, w), range[0, 1], float32. + scale (float | Tensor): Noise scale. Default: 1.0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + noise = generate_gaussian_noise_pt(img, sigma, gray_noise) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ----------------------- Random Gaussian Noise ----------------------- # +def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0): + sigma = np.random.uniform(sigma_range[0], sigma_range[1]) + if np.random.uniform() < gray_prob: + gray_noise = True + else: + gray_noise = False + return generate_gaussian_noise(img, sigma, gray_noise) + + +def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_gaussian_noise(img, sigma_range, gray_prob) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0): + sigma = torch.rand( + img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0] + gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device) + gray_noise = (gray_noise < gray_prob).float() + return generate_gaussian_noise_pt(img, sigma, gray_noise) + + +def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ----------------------- Poisson (Shot) Noise ----------------------- # + + +def generate_poisson_noise(img, scale=1.0, gray_noise=False): + """Generate poisson noise. + + Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219 + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + scale (float): Noise scale. Default: 1.0. + gray_noise (bool): Whether generate gray noise. Default: False. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + if gray_noise: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + # round and clip image for counting vals correctly + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = len(np.unique(img)) + vals = 2**np.ceil(np.log2(vals)) + out = np.float32(np.random.poisson(img * vals) / float(vals)) + noise = out - img + if gray_noise: + noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2) + return noise * scale + + +def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False): + """Add poisson noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + scale (float): Noise scale. Default: 1.0. + gray_noise (bool): Whether generate gray noise. Default: False. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + noise = generate_poisson_noise(img, scale, gray_noise) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0): + """Generate a batch of poisson noise (PyTorch version) + + Args: + img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32. + scale (float | Tensor): Noise scale. Number or Tensor with shape (b). + Default: 1.0. + gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). + 0 for False, 1 for True. Default: 0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + b, _, h, w = img.size() + if isinstance(gray_noise, (float, int)): + cal_gray_noise = gray_noise > 0 + else: + gray_noise = gray_noise.view(b, 1, 1, 1) + cal_gray_noise = torch.sum(gray_noise) > 0 + if cal_gray_noise: + img_gray = rgb_to_grayscale(img, num_output_channels=1) + # round and clip image for counting vals correctly + img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255. + # use for-loop to get the unique values for each sample + vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)] + vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] + vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1) + out = torch.poisson(img_gray * vals) / vals + noise_gray = out - img_gray + noise_gray = noise_gray.expand(b, 3, h, w) + + # always calculate color noise + # round and clip image for counting vals correctly + img = torch.clamp((img * 255.0).round(), 0, 255) / 255. + # use for-loop to get the unique values for each sample + vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)] + vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] + vals = img.new_tensor(vals_list).view(b, 1, 1, 1) + out = torch.poisson(img * vals) / vals + noise = out - img + if cal_gray_noise: + noise = noise * (1 - gray_noise) + noise_gray * gray_noise + if not isinstance(scale, (float, int)): + scale = scale.view(b, 1, 1, 1) + return noise * scale + + +def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0): + """Add poisson noise to a batch of images (PyTorch version). + + Args: + img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32. + scale (float | Tensor): Noise scale. Number or Tensor with shape (b). + Default: 1.0. + gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). + 0 for False, 1 for True. Default: 0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + noise = generate_poisson_noise_pt(img, scale, gray_noise) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ----------------------- Random Poisson (Shot) Noise ----------------------- # + + +def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0): + scale = np.random.uniform(scale_range[0], scale_range[1]) + if np.random.uniform() < gray_prob: + gray_noise = True + else: + gray_noise = False + return generate_poisson_noise(img, scale, gray_noise) + + +def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_poisson_noise(img, scale_range, gray_prob) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0): + scale = torch.rand( + img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0] + gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device) + gray_noise = (gray_noise < gray_prob).float() + return generate_poisson_noise_pt(img, scale, gray_noise) + + +def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ------------------------------------------------------------------------ # +# --------------------------- JPEG compression --------------------------- # +# ------------------------------------------------------------------------ # + + +def add_jpg_compression(img, quality=90): + """Add JPG compression artifacts. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + quality (float): JPG compression quality. 0 for lowest quality, 100 for + best quality. Default: 90. + + Returns: + (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], + float32. + """ + img = np.clip(img, 0, 1) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] + _, encimg = cv2.imencode('.jpg', img * 255., encode_param) + img = np.float32(cv2.imdecode(encimg, 1)) / 255. + return img + + +def random_add_jpg_compression(img, quality_range=(90, 100)): + """Randomly add JPG compression artifacts. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + quality_range (tuple[float] | list[float]): JPG compression quality + range. 0 for lowest quality, 100 for best quality. + Default: (90, 100). + + Returns: + (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], + float32. + """ + quality = np.random.uniform(quality_range[0], quality_range[1]) + return add_jpg_compression(img, int(quality)) diff --git a/utils/face_restoration_helper.py b/utils/face_restoration_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc806348f2e7c46ff0ab4c0288fb4141d9eb8fe --- /dev/null +++ b/utils/face_restoration_helper.py @@ -0,0 +1,517 @@ +import cv2 +import numpy as np +import os +import torch +from torchvision.transforms.functional import normalize + +from facexlib.detection import init_detection_model +from facexlib.parsing import init_parsing_model +from facexlib.utils.misc import img2tensor, imwrite + +from utils.common import load_file_from_url + +def get_largest_face(det_faces, h, w): + + def get_location(val, length): + if val < 0: + return 0 + elif val > length: + return length + else: + return val + + face_areas = [] + for det_face in det_faces: + left = get_location(det_face[0], w) + right = get_location(det_face[2], w) + top = get_location(det_face[1], h) + bottom = get_location(det_face[3], h) + face_area = (right - left) * (bottom - top) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + return det_faces[largest_idx], largest_idx + + +def get_center_face(det_faces, h=0, w=0, center=None): + if center is not None: + center = np.array(center) + else: + center = np.array([w / 2, h / 2]) + center_dist = [] + for det_face in det_faces: + face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2]) + dist = np.linalg.norm(face_center - center) + center_dist.append(dist) + center_idx = center_dist.index(min(center_dist)) + return det_faces[center_idx], center_idx + + +class FaceRestoreHelper(object): + """Helper for the face restoration pipeline (base class).""" + + def __init__(self, + upscale_factor, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + template_3points=False, + pad_blur=False, + use_parse=False, + device=None): + self.template_3points = template_3points # improve robustness + self.upscale_factor = int(upscale_factor) + # the cropped face ratio based on the square face + self.crop_ratio = crop_ratio # (h, w) + assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1' + self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0])) + self.det_model = det_model + + if self.det_model == 'dlib': + # standard 5 landmarks for FFHQ faces with 1024 x 1024 + self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941], + [337.91089109, 488.38613861], [437.95049505, 493.51485149], + [513.58415842, 678.5049505]]) + self.face_template = self.face_template / (1024 // face_size) + elif self.template_3points: + self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) + else: + # standard 5 landmarks for FFHQ faces with 512 x 512 + # facexlib + self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935], + [201.26117, 371.41043], [313.08905, 371.15118]]) + + # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54 + # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894], + # [198.22603, 372.82502], [313.91018, 372.75659]]) + + self.face_template = self.face_template * (face_size / 512.0) + if self.crop_ratio[0] > 1: + self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2 + if self.crop_ratio[1] > 1: + self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2 + self.save_ext = save_ext + self.pad_blur = pad_blur + if self.pad_blur is True: + self.template_3points = False + + self.all_landmarks_5 = [] + self.det_faces = [] + self.affine_matrices = [] + self.inverse_affine_matrices = [] + self.cropped_faces = [] + self.restored_faces = [] + self.pad_input_imgs = [] + + if device is None: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # self.device = get_device() + else: + self.device = device + + # init face detection model + self.face_detector = init_detection_model(det_model, half=False, device=self.device) + + # init face parsing model + self.use_parse = use_parse + self.face_parse = init_parsing_model(model_name='parsenet', device=self.device) + + def set_upscale_factor(self, upscale_factor): + self.upscale_factor = upscale_factor + + def read_image(self, img): + """img can be image path or cv2 loaded image.""" + # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255] + if isinstance(img, str): + img = cv2.imread(img) + + if np.max(img) > 256: # 16-bit image + img = img / 65535 * 255 + if len(img.shape) == 2: # gray image + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif img.shape[2] == 4: # BGRA image with alpha channel + img = img[:, :, 0:3] + + self.input_img = img + # self.is_gray = is_gray(img, threshold=10) + # if self.is_gray: + # print('Grayscale input: True') + + if min(self.input_img.shape[:2])<512: + f = 512.0/min(self.input_img.shape[:2]) + self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR) + + def init_dlib(self, detection_path, landmark5_path): + """Initialize the dlib detectors and predictors.""" + try: + import dlib + except ImportError: + print('Please install dlib by running:' 'conda install -c conda-forge dlib') + detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None) + landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None) + face_detector = dlib.cnn_face_detection_model_v1(detection_path) + shape_predictor_5 = dlib.shape_predictor(landmark5_path) + return face_detector, shape_predictor_5 + + def get_face_landmarks_5_dlib(self, + only_keep_largest=False, + scale=1): + det_faces = self.face_detector(self.input_img, scale) + + if len(det_faces) == 0: + print('No face detected. Try to increase upsample_num_times.') + return 0 + else: + if only_keep_largest: + print('Detect several faces and only keep the largest.') + face_areas = [] + for i in range(len(det_faces)): + face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * ( + det_faces[i].rect.bottom() - det_faces[i].rect.top()) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + self.det_faces = [det_faces[largest_idx]] + else: + self.det_faces = det_faces + + if len(self.det_faces) == 0: + return 0 + + for face in self.det_faces: + shape = self.shape_predictor_5(self.input_img, face.rect) + landmark = np.array([[part.x, part.y] for part in shape.parts()]) + self.all_landmarks_5.append(landmark) + + return len(self.all_landmarks_5) + + + def get_face_landmarks_5(self, + only_keep_largest=False, + only_center_face=False, + resize=None, + blur_ratio=0.01, + eye_dist_threshold=None): + if self.det_model == 'dlib': + return self.get_face_landmarks_5_dlib(only_keep_largest) + + if resize is None: + scale = 1 + input_img = self.input_img + else: + h, w = self.input_img.shape[0:2] + scale = resize / min(h, w) + scale = max(1, scale) # always scale up + h, w = int(h * scale), int(w * scale) + interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR + input_img = cv2.resize(self.input_img, (w, h), interpolation=interp) + + with torch.no_grad(): + bboxes = self.face_detector.detect_faces(input_img) + + if bboxes is None or bboxes.shape[0] == 0: + return 0 + else: + bboxes = bboxes / scale + + for bbox in bboxes: + # remove faces with too small eye distance: side faces or too small faces + eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]]) + if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold): + continue + + if self.template_3points: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)]) + else: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)]) + self.all_landmarks_5.append(landmark) + self.det_faces.append(bbox[0:5]) + + if len(self.det_faces) == 0: + return 0 + if only_keep_largest: + h, w, _ = self.input_img.shape + self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]] + elif only_center_face: + h, w, _ = self.input_img.shape + self.det_faces, center_idx = get_center_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[center_idx]] + + # pad blurry images + if self.pad_blur: + self.pad_input_imgs = [] + for landmarks in self.all_landmarks_5: + # get landmarks + eye_left = landmarks[0, :] + eye_right = landmarks[1, :] + eye_avg = (eye_left + eye_right) * 0.5 + mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5 + eye_to_eye = eye_right - eye_left + eye_to_mouth = mouth_avg - eye_avg + + # Get the oriented crop rectangle + # x: half width of the oriented crop rectangle + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise + # norm with the hypotenuse: get the direction + x /= np.hypot(*x) # get the hypotenuse of a right triangle + rect_scale = 1.5 + x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) + # y: half height of the oriented crop rectangle + y = np.flipud(x) * [-1, 1] + + # c: center + c = eye_avg + eye_to_mouth * 0.1 + # quad: (left_top, left_bottom, right_bottom, right_top) + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + # qsize: side length of the square + qsize = np.hypot(*x) * 2 + border = max(int(np.rint(qsize * 0.1)), 3) + + # get pad + # pad: (width_left, height_top, width_right, height_bottom) + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = [ + max(-pad[0] + border, 1), + max(-pad[1] + border, 1), + max(pad[2] - self.input_img.shape[0] + border, 1), + max(pad[3] - self.input_img.shape[1] + border, 1) + ] + + if max(pad) > 1: + # pad image + pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + # modify landmark coords + landmarks[:, 0] += pad[0] + landmarks[:, 1] += pad[1] + # blur pad images + h, w, _ = pad_img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = int(qsize * blur_ratio) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur)) + # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0) + + pad_img = pad_img.astype('float32') + pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0) + pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255] + self.pad_input_imgs.append(pad_img) + else: + self.pad_input_imgs.append(np.copy(self.input_img)) + + return len(self.all_landmarks_5) + + def align_warp_face(self, save_cropped_path=None, border_mode='constant'): + """Align and warp faces with face template. + """ + if self.pad_blur: + assert len(self.pad_input_imgs) == len( + self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}' + for idx, landmark in enumerate(self.all_landmarks_5): + # use 5 landmarks to get affine matrix + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0] + self.affine_matrices.append(affine_matrix) + # warp and crop faces + if border_mode == 'constant': + border_mode = cv2.BORDER_CONSTANT + elif border_mode == 'reflect101': + border_mode = cv2.BORDER_REFLECT101 + elif border_mode == 'reflect': + border_mode = cv2.BORDER_REFLECT + if self.pad_blur: + input_img = self.pad_input_imgs[idx] + else: + input_img = self.input_img + cropped_face = cv2.warpAffine( + input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray + self.cropped_faces.append(cropped_face) + # save the cropped face + if save_cropped_path is not None: + path = os.path.splitext(save_cropped_path)[0] + save_path = f'{path}_{idx:02d}.{self.save_ext}' + imwrite(cropped_face, save_path) + + def get_inverse_affine(self, save_inverse_affine_path=None): + """Get inverse affine matrix.""" + for idx, affine_matrix in enumerate(self.affine_matrices): + inverse_affine = cv2.invertAffineTransform(affine_matrix) + inverse_affine *= self.upscale_factor + self.inverse_affine_matrices.append(inverse_affine) + # save inverse affine matrices + if save_inverse_affine_path is not None: + path, _ = os.path.splitext(save_inverse_affine_path) + save_path = f'{path}_{idx:02d}.pth' + torch.save(inverse_affine, save_path) + + + def add_restored_face(self, restored_face, input_face=None): + # if self.is_gray: + # restored_face = bgr2gray(restored_face) # convert img into grayscale + # if input_face is not None: + # restored_face = adain_npy(restored_face, input_face) # transfer the color + self.restored_faces.append(restored_face) + + + def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None): + h, w, _ = self.input_img.shape + h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor) + + if upsample_img is None: + # simply resize the background + # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR) + else: + upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + + assert len(self.restored_faces) == len( + self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.') + + inv_mask_borders = [] + for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices): + if face_upsampler is not None: + restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0] + inverse_affine /= self.upscale_factor + inverse_affine[:, 2] *= self.upscale_factor + face_size = (self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor) + else: + # Add an offset to inverse affine matrix, for more precise back alignment + if self.upscale_factor > 1: + extra_offset = 0.5 * self.upscale_factor + else: + extra_offset = 0 + inverse_affine[:, 2] += extra_offset + face_size = self.face_size + inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up)) + + # if draw_box or not self.use_parse: # use square parse maps + # mask = np.ones(face_size, dtype=np.float32) + # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # # remove the black borders + # inv_mask_erosion = cv2.erode( + # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) + # pasted_face = inv_mask_erosion[:, :, None] * inv_restored + # total_face_area = np.sum(inv_mask_erosion) # // 3 + # # add border + # if draw_box: + # h, w = face_size + # mask_border = np.ones((h, w, 3), dtype=np.float32) + # border = int(1400/np.sqrt(total_face_area)) + # mask_border[border:h-border, border:w-border,:] = 0 + # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) + # inv_mask_borders.append(inv_mask_border) + # if not self.use_parse: + # # compute the fusion edge based on the area of face + # w_edge = int(total_face_area**0.5) // 20 + # erosion_radius = w_edge * 2 + # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + # blur_size = w_edge * 2 + # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + # if len(upsample_img.shape) == 2: # upsample_img is gray image + # upsample_img = upsample_img[:, :, None] + # inv_soft_mask = inv_soft_mask[:, :, None] + + # always use square mask + mask = np.ones(face_size, dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # remove the black borders + inv_mask_erosion = cv2.erode( + inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) + pasted_face = inv_mask_erosion[:, :, None] * inv_restored + total_face_area = np.sum(inv_mask_erosion) # // 3 + # add border + if draw_box: + h, w = face_size + mask_border = np.ones((h, w, 3), dtype=np.float32) + border = int(1400/np.sqrt(total_face_area)) + mask_border[border:h-border, border:w-border,:] = 0 + inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) + inv_mask_borders.append(inv_mask_border) + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + if len(upsample_img.shape) == 2: # upsample_img is gray image + upsample_img = upsample_img[:, :, None] + inv_soft_mask = inv_soft_mask[:, :, None] + + # parse mask + if self.use_parse: + # inference + face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR) + face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True) + normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + face_input = torch.unsqueeze(face_input, 0).to(self.device) + with torch.no_grad(): + out = self.face_parse(face_input)[0] + out = out.argmax(dim=1).squeeze().cpu().numpy() + + parse_mask = np.zeros(out.shape) + MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0] + for idx, color in enumerate(MASK_COLORMAP): + parse_mask[out == idx] = color + # blur the mask + parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) + parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) + # remove the black borders + thres = 10 + parse_mask[:thres, :] = 0 + parse_mask[-thres:, :] = 0 + parse_mask[:, :thres] = 0 + parse_mask[:, -thres:] = 0 + parse_mask = parse_mask / 255. + + parse_mask = cv2.resize(parse_mask, face_size) + parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3) + inv_soft_parse_mask = parse_mask[:, :, None] + # pasted_face = inv_restored + fuse_mask = (inv_soft_parse_mask 256: # 16-bit image + upsample_img = upsample_img.astype(np.uint16) + else: + upsample_img = upsample_img.astype(np.uint8) + + # draw bounding box + if draw_box: + # upsample_input_img = cv2.resize(input_img, (w_up, h_up)) + img_color = np.ones([*upsample_img.shape], dtype=np.float32) + img_color[:,:,0] = 0 + img_color[:,:,1] = 255 + img_color[:,:,2] = 0 + for inv_mask_border in inv_mask_borders: + upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img + # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img + + if save_path is not None: + path = os.path.splitext(save_path)[0] + save_path = f'{path}.{self.save_ext}' + imwrite(upsample_img, save_path) + return upsample_img + + def clean_all(self): + self.all_landmarks_5 = [] + self.restored_faces = [] + self.affine_matrices = [] + self.cropped_faces = [] + self.inverse_affine_matrices = [] + self.det_faces = [] + self.pad_input_imgs = [] \ No newline at end of file diff --git a/utils/file.py b/utils/file.py new file mode 100644 index 0000000000000000000000000000000000000000..9582986ed26766bfe23288d2d21dbc98a3d8bd1b --- /dev/null +++ b/utils/file.py @@ -0,0 +1,80 @@ +import os +from typing import List, Tuple + +from urllib.parse import urlparse +from torch.hub import download_url_to_file, get_dir + + +def load_file_list(file_list_path: str) -> List[str]: + files = [] + # each line in file list contains a path of an image + with open(file_list_path, "r") as fin: + for line in fin: + path = line.strip() + if path: + files.append(path) + return files + + +def list_image_files( + img_dir: str, + exts: Tuple[str]=(".jpg", ".png", ".jpeg"), + follow_links: bool=False, + log_progress: bool=False, + log_every_n_files: int=10000, + max_size: int=-1 +) -> List[str]: + files = [] + for dir_path, _, file_names in os.walk(img_dir, followlinks=follow_links): + early_stop = False + file_names = sorted(file_names, key=lambda x: int(x.split('.')[0])) + for file_name in file_names: + if os.path.splitext(file_name)[1].lower() in exts: + if max_size >= 0 and len(files) >= max_size: + early_stop = True + break + files.append(os.path.join(dir_path, file_name)) + if log_progress and len(files) % log_every_n_files == 0: + print(f"find {len(files)} images in {img_dir}") + if early_stop: + break + return files + + +def get_file_name_parts(file_path: str) -> Tuple[str, str, str]: + parent_path, file_name = os.path.split(file_path) + stem, ext = os.path.splitext(file_name) + return parent_path, stem, ext + + +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/ +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file diff --git a/utils/flow_utils.py b/utils/flow_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..28a7169abcc9f57755378ebab13159ad83e94606 --- /dev/null +++ b/utils/flow_utils.py @@ -0,0 +1,291 @@ +import os +import sys + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import einops +from PIL import Image + +parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +gmflow_dir = os.path.join(parent_dir, 'deps/gmflow') +sys.path.insert(0, gmflow_dir) + +from GMFlow.gmflow.gmflow import GMFlow # noqa: E702 E402 F401 +from GMFlow.utils.utils import InputPadder # noqa: E702 E402 + + +def coords_grid(b, h, w, homogeneous=False, device=None): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] + + stacks = [x, y] + + if homogeneous: + ones = torch.ones_like(x) # [H, W] + stacks.append(ones) + + grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] + + grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] + + if device is not None: + grid = grid.to(device) + + return grid + + +def bilinear_sample(img, + sample_coords, + mode='bilinear', + padding_mode='zeros', + return_mask=False): + # img: [B, C, H, W] + # sample_coords: [B, 2, H, W] in image scale + if sample_coords.size(1) != 2: # [B, H, W, 2] + sample_coords = sample_coords.permute(0, 3, 1, 2) + + b, _, h, w = sample_coords.shape + + # Normalize to [-1, 1] + x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] + + img = F.grid_sample(img, + grid, + mode=mode, + padding_mode=padding_mode, + align_corners=True) + + if return_mask: + mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & ( + y_grid <= 1) # [B, H, W] + + return img, mask + + return img + + +def flow_warp(feature, + flow, + mask=False, + mode='bilinear', + padding_mode='zeros'): + b, c, h, w = feature.size() + assert flow.size(1) == 2 + + grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] + + return bilinear_sample(feature, + grid, + mode=mode, + padding_mode=padding_mode, + return_mask=mask) + + +def forward_backward_consistency_check(fwd_flow, + bwd_flow, + alpha=0.01, + beta=0.5, + return_confidence=False): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow + # (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, + dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + threshold = alpha * flow_mag + beta + + if return_confidence: + # fwd_occ = diff_fwd + # bwd_occ = diff_bwd + fwd_occ = torch.exp(-diff_fwd) + bwd_occ = torch.exp(-diff_bwd) + # import ipdb; ipdb.set_trace() + # Image.fromarray((bwd_occ * 255)[0,:,:].cpu().numpy().clip(0, 255).astype(np.uint8)).save("mask.png") + else: + fwd_occ = (diff_fwd > threshold).float() # [B, H, W] + bwd_occ = (diff_bwd > threshold).float() + + + return fwd_occ, bwd_occ + + +@torch.no_grad() +def get_warped_and_mask(flow_model, + image1, + image2, + image3=None, + pixel_consistency=False, + return_confidence=False): + if image3 is None: + image3 = image1[None] + padder = InputPadder(image1.shape, padding_factor=16) + # image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) + image1, image2 = padder.pad(image1[None], image2[None]) + results_dict = flow_model(image1, + image2, + attn_splits_list=[2], + corr_radius_list=[-1], + prop_radius_list=[-1], + pred_bidir_flow=True) + flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W] + fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0) # [1, 2, H, W] + bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0) # [1, 2, H, W] + + # results_dict_ = flow_model(image2, + # image1, + # attn_splits_list=[2], + # corr_radius_list=[-1], + # prop_radius_list=[-1], + # pred_bidir_flow=True) + # flow_pr = results_dict_['flow_preds'][-1] # [B, 2, H, W] + # fwd_flow_ = padder.unpad(flow_pr[0]).unsqueeze(0) # [1, 2, H, W] + # bwd_flow_ = padder.unpad(flow_pr[1]).unsqueeze(0) # [1, 2, H, W] + # fwd_occ_, bwd_occ_ = forward_backward_consistency_check( + # fwd_flow_, bwd_flow_, return_error=True) # [1, H, W] float + + fwd_occ, bwd_occ = forward_backward_consistency_check( + fwd_flow, bwd_flow) # [1, H, W] float + + if pixel_consistency: + warped_image1 = flow_warp(image1, padder.pad(bwd_flow)[0]) + bwd_occ = torch.clamp( + padder.pad(bwd_occ)[0] + + (abs(image2 - warped_image1).mean(dim=1) > 255 * 0.25).float(), 0, + 1) + warped_results = flow_warp(image3, bwd_flow) + if return_confidence: + fwd_err, bwd_err = forward_backward_consistency_check( + fwd_flow, bwd_flow, return_confidence=return_confidence) # [1, H, W] float + return warped_results, bwd_occ, bwd_flow, bwd_err + + return warped_results, bwd_occ, bwd_flow + + +class FlowCalc(): + + def __init__(self, model_path='./weights/gmflow_sintel-0c07dcb3.pth'): + flow_model = GMFlow( + feature_channels=128, + num_scales=1, + upsample_factor=8, + num_head=1, + attention_type='swin', + ffn_dim_expansion=4, + num_transformer_layers=6, + ).to('cuda') + + checkpoint = torch.load(model_path, + map_location=lambda storage, loc: storage) + weights = checkpoint['model'] if 'model' in checkpoint else checkpoint + flow_model.load_state_dict(weights, strict=False) + flow_model.eval() + self.model = flow_model + + @torch.no_grad() + def get_flow(self, image1, image2, save_path=None): + + if save_path is not None and os.path.exists(save_path): + bwd_flow = read_flow(save_path) + return bwd_flow + + image1 = torch.from_numpy(image1).permute(2, 0, 1).float() + image2 = torch.from_numpy(image2).permute(2, 0, 1).float() + padder = InputPadder(image1.shape, padding_factor=8) + image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) + results_dict = self.model(image1, + image2, + attn_splits_list=[2], + corr_radius_list=[-1], + prop_radius_list=[-1], + pred_bidir_flow=True) + flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W] + fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0) # [1, 2, H, W] + bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0) # [1, 2, H, W] + fwd_occ, bwd_occ = forward_backward_consistency_check( + fwd_flow, bwd_flow) # [1, H, W] float + if save_path is not None: + flow_np = bwd_flow.cpu().numpy() + np.save(save_path, flow_np) + mask_path = os.path.splitext(save_path)[0] + '.png' + bwd_occ = bwd_occ.cpu().permute(1, 2, 0).to( + torch.long).numpy() * 255 + cv2.imwrite(mask_path, bwd_occ) + + return bwd_flow + + @torch.no_grad() + def get_mask(self, image1, image2, save_path=None): + + if save_path is not None: + mask_path = os.path.splitext(save_path)[0] + '.png' + if os.path.exists(mask_path): + return read_mask(mask_path) + + image1 = torch.from_numpy(image1).permute(2, 0, 1).float() + image2 = torch.from_numpy(image2).permute(2, 0, 1).float() + padder = InputPadder(image1.shape, padding_factor=8) + image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda()) + results_dict = self.model(image1, + image2, + attn_splits_list=[2], + corr_radius_list=[-1], + prop_radius_list=[-1], + pred_bidir_flow=True) + flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W] + fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0) # [1, 2, H, W] + bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0) # [1, 2, H, W] + fwd_occ, bwd_occ = forward_backward_consistency_check( + fwd_flow, bwd_flow) # [1, H, W] float + if save_path is not None: + flow_np = bwd_flow.cpu().numpy() + np.save(save_path, flow_np) + mask_path = os.path.splitext(save_path)[0] + '.png' + bwd_occ = bwd_occ.cpu().permute(1, 2, 0).to( + torch.long).numpy() * 255 + cv2.imwrite(mask_path, bwd_occ) + + return bwd_occ + + def warp(self, img, flow, mode='bilinear'): + expand = False + if len(img.shape) == 2: + expand = True + img = np.expand_dims(img, 2) + + img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0) + dtype = img.dtype + img = img.to(torch.float) + res = flow_warp(img, flow, mode=mode) + res = res.to(dtype) + res = res[0].cpu().permute(1, 2, 0).numpy() + if expand: + res = res[:, :, 0] + return res + + +def read_flow(save_path): + flow_np = np.load(save_path) + bwd_flow = torch.from_numpy(flow_np) + return bwd_flow + + +def read_mask(save_path): + mask_path = os.path.splitext(save_path)[0] + '.png' + mask = cv2.imread(mask_path) + mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) + return mask + + +flow_calc = FlowCalc() diff --git a/utils/helpers.py b/utils/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..9988e6a1e6803e8d1fbbbe39115d61bf1722eff3 --- /dev/null +++ b/utils/helpers.py @@ -0,0 +1,442 @@ +from typing import overload, Tuple, Optional + +import os +import cv2 +import torch +from torch import nn +import torch.nn.functional as F +import torchvision.transforms as T + +import numpy as np +from glob import glob +from PIL import Image +from einops import rearrange + +from model.cldm import ControlLDM +from model.gaussian_diffusion import Diffusion +from model.bsrnet import RRDBNet +from model.swinir import SwinIR +from model.scunet import SCUNet +from utils.sampler import SpacedSampler +from utils.cond_fn import Guidance +from utils.video_visualizer import VideoVisualizer +from utils.common import wavelet_decomposition, wavelet_reconstruction, count_vram_usage + +import vidtome +from GMFlow.gmflow.gmflow import GMFlow +from utils.flow_utils import get_warped_and_mask + +def save_video(input_folder, out_path, output_name, fps=25): + video_visualizer = VideoVisualizer(path=os.path.join(out_path, output_name), + frame_size=None, + fps=fps) + input_folder = os.path.join(out_path, input_folder) + imgs = sorted([filename for filename in os.listdir(input_folder) if filename.endswith(('.png', '.jpg'))], key=lambda x: int(x.split('.')[0])) + for img in imgs: + img_pth = os.path.join(input_folder, img) + image = cv2.imread(img_pth) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + video_visualizer.add(image) + video_visualizer.save() + +def batch_bicubic_resize(img: np.ndarray, scale: float) -> np.ndarray: + + if scale != 1: + for i in range(img.shape[0]): + img[i] = bicubic_resize(img[i], scale) + # pil = Image.fromarray(img) + # res = pil.resize(tuple(int(x * scale) for x in pil.size), Image.BICUBIC) + return img + +def bicubic_resize(img: np.ndarray, scale: float) -> np.ndarray: + + if scale != 1: + pil = Image.fromarray(img) + res = pil.resize(tuple(int(x * scale) for x in pil.size), Image.BICUBIC) + return np.array(res) + + +def resize_short_edge_to(imgs: torch.Tensor, size: int) -> torch.Tensor: + _, _, h, w = imgs.size() + if h == w: + new_h, new_w = size, size + elif h < w: + new_h, new_w = size, int(w * (size / h)) + else: + new_h, new_w = int(h * (size / w)), size + return F.interpolate(imgs, size=(new_h, new_w), mode="bicubic", antialias=True) + + +def pad_to_multiples_of(imgs: torch.Tensor, multiple: int) -> torch.Tensor: + _, _, h, w = imgs.size() + if h % multiple == 0 and w % multiple == 0: + return imgs.clone() + # get_pad = lambda x: (x // multiple + 1) * multiple - x + get_pad = lambda x: (x // multiple + int(x % multiple != 0)) * multiple - x + ph, pw = get_pad(h), get_pad(w) + return F.pad(imgs, pad=(0, pw, 0, ph), mode="constant", value=0) + + +class Pipeline: + + def __init__(self, stage1_model: nn.Module, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None: + self.stage1_model = stage1_model + self.cldm = cldm + self.diffusion = diffusion + self.cond_fn = cond_fn + self.device = device + self.final_size: Tuple[int] = None + + def set_final_size(self, lq: torch.Tensor) -> None: + h, w = lq.shape[2:] + self.final_size = (h, w) + + @overload + def run_stage1(self, lq: torch.Tensor) -> torch.Tensor: + ... + + @count_vram_usage + def run_stage2( + self, + clean: torch.Tensor, + steps: int, + strength: float, + tiled: bool, + tile_size: int, + tile_stride: int, + pos_prompt: str, + neg_prompt: str, + cfg_scale: float, + better_start: float, + index: int = 0, + input: str = None + ) -> torch.Tensor: + ### preprocess + bs, _, ori_h, ori_w = clean.shape + # pad: ensure that height & width are multiples of 64 + pad_clean = pad_to_multiples_of(clean, multiple=64) + h, w = pad_clean.shape[2:] + if self.cldm.controller is not None: + self.cldm.controller.cldm = self.cldm + self.cldm.controller.non_pad_ratio = (ori_h / h, ori_w / w) + self.cldm.vae.decoder.controller = self.cldm.controller + # prepare conditon + if not tiled: + cond = self.cldm.prepare_condition(pad_clean, [pos_prompt] * bs) + uncond = self.cldm.prepare_condition(pad_clean, [neg_prompt] * bs) + else: + cond = self.cldm.prepare_condition_tiled(pad_clean, [pos_prompt] * bs, tile_size, tile_stride) + uncond = self.cldm.prepare_condition_tiled(pad_clean, [neg_prompt] * bs, tile_size, tile_stride) + if self.cond_fn: + self.cond_fn.load_target(pad_clean * 2 - 1) + old_control_scales = self.cldm.control_scales + self.cldm.control_scales = [strength] * 13 + if better_start: + # using noised low frequency part of condition as a better start point of + # reverse sampling, which can prevent our model from generating noise in + # image background. + _, low_freq = wavelet_decomposition(pad_clean) + # low_freq = pad_clean + if not tiled: + x_0 = self.cldm.vae_encode(low_freq, batch_size=5) + else: + x_0 = self.cldm.vae_encode_tiled(low_freq, tile_size, tile_stride) + x_T = self.diffusion.q_sample( + x_0, + torch.full((bs, ), self.diffusion.num_timesteps - 1, dtype=torch.long, device=self.device), + torch.randn(x_0.shape, dtype=torch.float32, device=self.device) + ) + # print(f"diffusion sqrt_alphas_cumprod: {self.diffusion.sqrt_alphas_cumprod[-1]}") + else: + if self.cldm.latent_control: + print(f"[INFO] random initialize {bs} same latents") + x_T = 1 * torch.randn((1, 4, h // 8, w // 8), dtype=torch.float32, device=self.device) + x_T = x_T.repeat(bs, 1, 1, 1) + else: + print(f"[INFO] random initialize {bs} latents") + x_T = torch.randn((bs, 4, h // 8, w // 8), dtype=torch.float32, device=self.device) + ''' loaded latents ''' + # t = 981 + # latent_fname = f'noisy_latents_{t}.pt' + # # model_key = config.model_key.split('/')[-1] + # model_key = "stable-diffusion-2-1-base" + # inversion_path = os.path.join("latents", os.path.basename(input), "latents") + # # outputs/bear_4_BD/latents/stable-diffusion-v1-5/noisy_latents_981.pt + # lp = os.path.join(inversion_path, model_key, latent_fname) + # latents = torch.load(lp) + + # # init_noise = latents.to(dtype).to(args.device) + # x_T = latents[index][None].to(torch.float32).to(self.device) + # print(f"[INFO] loaded latents[{index}]") + ''' loaded latent ended ''' + ### run sampler + sampler = SpacedSampler(self.diffusion.betas) + z = sampler.sample( + model=self.cldm, device=self.device, steps=steps, batch_size=bs, x_size=(4, h // 8, w // 8), + cond=cond, uncond=uncond, cfg_scale=cfg_scale, x_T=x_T, progress=True, + progress_leave=True, cond_fn=self.cond_fn, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, + non_pad_ratio=(ori_h / h, ori_w / w) + ) + if not tiled: + if ori_w > 1500: + x = self.cldm.vae_decode(z, batch_size=2) + else: + x = self.cldm.vae_decode(z, batch_size=5) + else: + x = self.cldm.vae_decode_tiled(z, tile_size // 8, tile_stride // 8) + ### postprocess + self.cldm.control_scales = old_control_scales + sample = x[:, :, :ori_h, :ori_w] + return sample + + @torch.no_grad() + def run( + self, + lq: np.ndarray, + steps: int, + strength: float, + tiled: bool, + tile_size: int, + tile_stride: int, + pos_prompt: str, + neg_prompt: str, + cfg_scale: float, + better_start: bool, + index: int = 0, + input: str = None, + final_size: Tuple[int] = None, + flow_model: GMFlow = None, + hq: np.ndarray = None + ) -> np.ndarray: + # image to tensor + lq = torch.tensor((lq / 255.).clip(0, 1), dtype=torch.float32, device=self.device) + lq = rearrange(lq, "n h w c -> n c h w").contiguous() + # set pipeline output size + if final_size is None: + self.set_final_size(lq) + else: + self.final_size = final_size + + clean = self.run_stage1(lq) + print(f"[INFO] {clean.shape}") + # import ipdb; ipdb.set_trace() + # clean = F.interpolate(lq, size=clean.shape[-2:], mode='bicubic', align_corners=False) + ''' hq flow & occlusion mask ''' + # hq = torch.tensor((hq / 255.).clip(0, 1), dtype=torch.float32, device=self.device) + # hq = rearrange(hq, "n h w c -> n c h w").contiguous() + # hq = resize_short_edge_to(hq, size=512) + # pre_keyframe_lq = None + + # if self.cldm.controller is not None and \ + # self.cldm.controller.step_store["pre_keyframe_lq"] is not None: + + # pre_keyframe_lq = self.cldm.controller.step_store["pre_keyframe_lq"] + # pre_keyframe_lq = torch.tensor((pre_keyframe_lq / 255.).clip(0, 1), dtype=torch.float32, device=self.device) + # pre_keyframe_lq = rearrange(pre_keyframe_lq, "n h w c -> n c h w").contiguous() + # pre_keyframe_lq = resize_short_edge_to(pre_keyframe_lq, size=512) + # pre_keyframe_clean = pre_keyframe_lq[0] + # # pre_keyframe_clean = self.run_stage1(pre_keyframe_lq)[0] + + # flows, masks, confids = [], [], [] + # mid = lq.shape[0] // 2 + # for k in range(lq.shape[0]): + # if k == mid: + # if pre_keyframe_lq is not None: + # tar_img = (torch.clamp(hq[mid], 0 ,1) * 255).float().to(self.device) + # src_img = (torch.clamp(pre_keyframe_clean, 0 ,1) * 255).float().to(self.device) + # else: + # flows.append(None) + # masks.append(None) + # confids.append(None) + # continue + # else: + # tar_img = (torch.clamp(hq[k], 0 ,1) * 255).float().to(self.device) + # src_img = (torch.clamp(hq[mid], 0 ,1) * 255).float().to(self.device) + # # tar_img = stage1_x[0].float().to(args.device) + # _, bwd_occ, bwd_flow, bwd_confid = get_warped_and_mask( + # flow_model, src_img, tar_img, image3=None, pixel_consistency=False, return_confidence=True) + # blend_mask = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))( + # F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4)) + + # blend_mask = torch.clamp(blend_mask + bwd_occ, 0, 1) + # blend_mask = 1 - F.max_pool2d(blend_mask, kernel_size=8) + + # bwd_confid = F.max_pool2d(bwd_confid, kernel_size=8) + + # bwd_flow = F.interpolate(bwd_flow / 8.0, scale_factor=1. / 8, mode='bilinear') + + # # _, _, h, w = bwd_flow.shape + # # bwd_flow = pad_to_multiples_of(bwd_flow, 8) + # # padding_ratio = w / bwd_flow.shape[3] + # blend_mask = pad_to_multiples_of(blend_mask[None], 8)[0] + # # bwd_confid = pad_to_multiples_of(bwd_confid[None], 8)[0] + # flows.append(bwd_flow) + # masks.append(blend_mask) + # confids.append(bwd_confid) + + # if self.cldm.controller is not None: + # self.cldm.controller.set_warp(flows, masks, flow_confids=confids) + + ''' flow & occlusion mask ''' + pre_keyframe_lq = None + + if self.cldm.controller is not None and \ + self.cldm.controller.step_store["pre_keyframe_lq"] is not None: + + pre_keyframe_lq = self.cldm.controller.step_store["pre_keyframe_lq"] + pre_keyframe_lq = torch.tensor((pre_keyframe_lq / 255.).clip(0, 1), dtype=torch.float32, device=self.device) + pre_keyframe_lq = rearrange(pre_keyframe_lq, "n h w c -> n c h w").contiguous() + pre_keyframe_clean = self.run_stage1(pre_keyframe_lq)[0] + + flows, masks, confids = [], [], [] + flows2, confids2 = [], [] + mid = lq.shape[0] // 2 + for k in range(lq.shape[0]): + if k == mid: + if pre_keyframe_lq is not None: + tar_img = (torch.clamp(clean[mid], 0 ,1) * 255).float().to(self.device) + src_img = (torch.clamp(pre_keyframe_clean, 0 ,1) * 255).float().to(self.device) + else: + flows.append(None) + masks.append(None) + confids.append(None) + continue + else: + tar_img = (torch.clamp(clean[k], 0 ,1) * 255).float().to(self.device) + src_img = (torch.clamp(clean[mid], 0 ,1) * 255).float().to(self.device) + # tar_img = stage1_x[0].float().to(args.device) + _, bwd_occ, bwd_flow, bwd_confid = get_warped_and_mask( + flow_model, src_img, tar_img, image3=None, pixel_consistency=False, return_confidence=True) + blend_mask = T.GaussianBlur(kernel_size=(9, 9), sigma=(18, 18))( + F.max_pool2d(bwd_occ, kernel_size=9, stride=1, padding=4)) + blend_mask = torch.clamp(blend_mask + bwd_occ, 0, 1) + blend_mask = 1 - F.max_pool2d(blend_mask, kernel_size=8) + + blend_mask = 1 - F.max_pool2d(bwd_occ, kernel_size=8) + + bwd_confid2 = F.max_pool2d(bwd_confid, kernel_size=16) + bwd_flow2 = F.interpolate(bwd_flow / 16.0, scale_factor=1. / 16, mode='bilinear') + + + bwd_confid = F.max_pool2d(bwd_confid, kernel_size=8) + bwd_flow = F.interpolate(bwd_flow / 8.0, scale_factor=1. / 8, mode='bilinear') + + # _, _, h, w = bwd_flow.shape + # bwd_flow = pad_to_multiples_of(bwd_flow, 8) + # padding_ratio = w / bwd_flow.shape[3] + blend_mask = pad_to_multiples_of(blend_mask[None], 8)[0] + # bwd_confid = pad_to_multiples_of(bwd_confid[None], 8)[0] + flows.append(bwd_flow) + masks.append(blend_mask) + confids.append(bwd_confid) + + flows2.append(bwd_flow2) + confids2.append(bwd_confid2) + + if self.cldm.controller is not None: + self.cldm.controller.set_warp(flows, masks, flow_confids=confids) + # import ipdb; ipdb.set_trace() + _, H, W = confids[0].shape + self.cldm.controller.set_flow_correspondence(lq.shape[0], H, W, lq.shape[0] // 2, confids, flows) + _, H, W = confids2[0].shape + self.cldm.controller.set_flow_correspondence(lq.shape[0], H, W, lq.shape[0] // 2, confids2, flows2) + for j, flow in enumerate(self.cldm.controller.step_store["flows"]): + if flow is not None: + self.cldm.controller.step_store["flows"][j] = pad_to_multiples_of(self.cldm.controller.step_store["flows"][j], 8) + # self.cldm.controller.set_warp2(flows2, confids2) + ''' flow & occlusion mask ended ''' + + + sample = self.run_stage2( + clean, steps, strength, tiled, tile_size, tile_stride, + pos_prompt, neg_prompt, cfg_scale, better_start, + index=index, input=input + ) + + if self.cldm.controller is not None: + print(f"[INFO] clearing controller correspondence scores ... ") + self.cldm.controller.step_store["corres_scores"] = None + # colorfix (borrowed from StableSR, thanks for their work) + sample = (sample + 1) / 2 + sample = wavelet_reconstruction(sample, clean) + # resize to desired output size + sample = F.interpolate(sample, size=self.final_size, mode="bicubic", antialias=True) + clean = F.interpolate(clean, size=self.final_size, mode="bilinear", antialias=True) + # tensor to image + sample = rearrange(sample * 255., "n c h w -> n h w c") + sample = sample.contiguous().clamp(0, 255).to(torch.uint8).cpu().numpy() + clean = rearrange(clean * 255., "n c h w -> n h w c") + clean = clean.contiguous().clamp(0, 255).to(torch.uint8).cpu().numpy() + return sample, clean + + +class BSRNetPipeline(Pipeline): + + def __init__(self, bsrnet: RRDBNet, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str, upscale: float) -> None: + super().__init__(bsrnet, cldm, diffusion, cond_fn, device) + self.upscale = upscale + + def set_final_size(self, lq: torch.Tensor) -> None: + h, w = lq.shape[2:] + self.final_size = (int(h * self.upscale), int(w * self.upscale)) + + @count_vram_usage + def run_stage1(self, lq: torch.Tensor) -> torch.Tensor: + # NOTE: upscale is always set to 4 in our experiments + if lq.shape[-2] > 1000: + clean = [] + for i in range(lq.shape[0]): + torch.cuda.empty_cache() + clean.append(self.stage1_model(lq[i:i+1])) + clean = torch.cat(clean, dim=0) + else: + clean = self.stage1_model(lq) + # if self.final_size[0] < 512 and self.final_size[1] < 512: + if min(self.final_size) < 512: + clean = resize_short_edge_to(clean, size=512) + else: + clean = F.interpolate(clean, size=self.final_size, mode="bicubic", antialias=True) + return clean + + +class SwinIRPipeline(Pipeline): + + def __init__(self, swinir: SwinIR, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None: + super().__init__(swinir, cldm, diffusion, cond_fn, device) + + @count_vram_usage + def run_stage1(self, lq: torch.Tensor) -> torch.Tensor: + # NOTE: lq size is always equal to 512 in our experiments + # resize: ensure the input lq size is as least 512, since SwinIR is trained on 512 resolution + if min(lq.shape[2:]) < 512: + lq = resize_short_edge_to(lq, size=512) + ori_h, ori_w = lq.shape[2:] + # pad: ensure that height & width are multiples of 64 + pad_lq = pad_to_multiples_of(lq, multiple=64) + # run + clean = self.stage1_model(pad_lq) + # remove padding + clean = clean[:, :, :ori_h, :ori_w] + return clean + + +class SCUNetPipeline(Pipeline): + + def __init__(self, scunet: SCUNet, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None: + super().__init__(scunet, cldm, diffusion, cond_fn, device) + + @count_vram_usage + def run_stage1(self, lq: torch.Tensor) -> torch.Tensor: + if lq.shape[-1] > 1500: + clean = [] + batch_lq = lq.split(2, dim=0) + for lq_ in batch_lq: + clean.append(self.stage1_model(lq_)) + torch.cuda.empty_cache() + clean = torch.cat(clean) + else: + clean = self.stage1_model(lq) + if min(clean.shape[2:]) < 512: + clean = resize_short_edge_to(clean, size=512) + # import ipdb; ipdb.set_trace() + return clean \ No newline at end of file diff --git a/utils/image/__init__.py b/utils/image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3bbe064d019e3f4ddc9f041d80ed4e23cbf9d643 --- /dev/null +++ b/utils/image/__init__.py @@ -0,0 +1,26 @@ +from .diffjpeg import DiffJPEG +from .usm_sharp import USMSharp +from .common import ( + random_crop_arr, center_crop_arr, augment, + filter2D, rgb2ycbcr_pt, auto_resize, pad +) +from .align_color import ( + wavelet_reconstruction, adaptive_instance_normalization +) + +__all__ = [ + "DiffJPEG", + + "USMSharp", + + "random_crop_arr", + "center_crop_arr", + "augment", + "filter2D", + "rgb2ycbcr_pt", + "auto_resize", + "pad", + + "wavelet_reconstruction", + "adaptive_instance_normalization" +] diff --git a/utils/image/__pycache__/__init__.cpython-39.pyc b/utils/image/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2673ae633ba44b3b3a0f642eaf699f528fe50032 Binary files /dev/null and b/utils/image/__pycache__/__init__.cpython-39.pyc differ diff --git a/utils/image/__pycache__/align_color.cpython-39.pyc b/utils/image/__pycache__/align_color.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c0c292b39511061eb4c0dea3e1352ada20c646d Binary files /dev/null and b/utils/image/__pycache__/align_color.cpython-39.pyc differ diff --git a/utils/image/__pycache__/common.cpython-39.pyc b/utils/image/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..becf2bd103747cbf2d664b17847788243757dc9b Binary files /dev/null and b/utils/image/__pycache__/common.cpython-39.pyc differ diff --git a/utils/image/__pycache__/diffjpeg.cpython-39.pyc b/utils/image/__pycache__/diffjpeg.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11c5c06a800a1d51b7af19c5e7269bbfb5aa9865 Binary files /dev/null and b/utils/image/__pycache__/diffjpeg.cpython-39.pyc differ diff --git a/utils/image/__pycache__/usm_sharp.cpython-39.pyc b/utils/image/__pycache__/usm_sharp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..308914a36800ae3040a99e73b4fca597189283b1 Binary files /dev/null and b/utils/image/__pycache__/usm_sharp.cpython-39.pyc differ diff --git a/utils/image/align_color.py b/utils/image/align_color.py new file mode 100644 index 0000000000000000000000000000000000000000..36da591b240b1daefbcbc3ca81c202e31c74b650 --- /dev/null +++ b/utils/image/align_color.py @@ -0,0 +1,119 @@ +''' +# -------------------------------------------------------------------------------- +# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py) +# -------------------------------------------------------------------------------- +''' + +import torch +from PIL import Image +from torch import Tensor +from torch.nn import functional as F +from torchvision.transforms import ToTensor, ToPILImage + + +def adain_color_fix(target: Image, source: Image): + # Convert images to tensors + to_tensor = ToTensor() + target_tensor = to_tensor(target).unsqueeze(0) + source_tensor = to_tensor(source).unsqueeze(0) + + # Apply adaptive instance normalization + result_tensor = adaptive_instance_normalization(target_tensor, source_tensor) + + # Convert tensor back to image + to_image = ToPILImage() + result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) + + return result_image + +def wavelet_color_fix(target: Image, source: Image): + # Convert images to tensors + to_tensor = ToTensor() + target_tensor = to_tensor(target).unsqueeze(0) + source_tensor = to_tensor(source).unsqueeze(0) + + # Apply wavelet reconstruction + result_tensor = wavelet_reconstruction(target_tensor, source_tensor) + + # Convert tensor back to image + to_image = ToPILImage() + result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) + + return result_image + +def calc_mean_std(feat: Tensor, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.reshape(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().reshape(b, c, 1, 1) + feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1) + return feat_mean, feat_std + +def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): + """Adaptive instance normalization. + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. + """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + +def wavelet_blur(image: Tensor, radius: int): + """ + Apply wavelet blur to the input tensor. + """ + # input shape: (1, 3, H, W) + # convolution kernel + kernel_vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125], + [0.0625, 0.125, 0.0625], + ] + kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) + # add channel dimensions to the kernel to make it a 4D tensor + kernel = kernel[None, None] + # repeat the kernel across all input channels + kernel = kernel.repeat(3, 1, 1, 1) + image = F.pad(image, (radius, radius, radius, radius), mode='replicate') + # apply convolution + output = F.conv2d(image, kernel, groups=3, dilation=radius) + return output + +def wavelet_decomposition(image: Tensor, levels=5): + """ + Apply wavelet decomposition to the input tensor. + This function only returns the low frequency & the high frequency. + """ + high_freq = torch.zeros_like(image) + for i in range(levels): + radius = 2 ** i + low_freq = wavelet_blur(image, radius) + high_freq += (image - low_freq) + image = low_freq + + return high_freq, low_freq + +def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): + """ + Apply wavelet decomposition, so that the content will have the same color as the style. + """ + # calculate the wavelet decomposition of the content feature + content_high_freq, content_low_freq = wavelet_decomposition(content_feat) + del content_low_freq + # calculate the wavelet decomposition of the style feature + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq + # reconstruct the content feature with the style's high frequency + return content_high_freq + style_low_freq \ No newline at end of file diff --git a/utils/image/common.py b/utils/image/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2e8ca3ce954778f0554968595446dcbe2a3455bb --- /dev/null +++ b/utils/image/common.py @@ -0,0 +1,235 @@ +import random +import math + +from PIL import Image +import numpy as np +import cv2 +import torch +from torch.nn import functional as F + + +# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/image_datasets.py +def center_crop_arr(pil_image, image_size): + # We are not on a new enough PIL to support the `reducing_gap` + # argument, which uses BOX downsampling at powers of two first. + # Thus, we do it by hand to improve downsample quality. + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] + + +# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/image_datasets.py +def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): + min_smaller_dim_size = math.ceil(image_size / max_crop_frac) + max_smaller_dim_size = math.ceil(image_size / min_crop_frac) + smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) + + # We are not on a new enough PIL to support the `reducing_gap` + # argument, which uses BOX downsampling at powers of two first. + # Thus, we do it by hand to improve downsample quality. + while min(*pil_image.size) >= 2 * smaller_dim_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = smaller_dim_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = random.randrange(arr.shape[0] - image_size + 1) + crop_x = random.randrange(arr.shape[1] - image_size + 1) + return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] + + +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/data/transforms.py +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + We use vertical flip and transpose for rotation implementation. + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Ratotation. Default: True. + flows (list[ndarray]: Flows to be augmented. If the input is an + ndarray, it will be transformed to a list. + Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. + + Returns: + list[ndarray] | ndarray: Augmented images and flows. If returned + results only have one element, just return ndarray. + + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: # horizontal + cv2.flip(flow, 1, flow) + flow[:, :, 0] *= -1 + if vflip: # vertical + cv2.flip(flow, 0, flow) + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + if flows is not None: + if not isinstance(flows, list): + flows = [flows] + flows = [_augment_flow(flow) for flow in flows] + if len(flows) == 1: + flows = flows[0] + return imgs, flows + else: + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/img_process_util.py +def filter2D(img, kernel): + """PyTorch version of cv2.filter2D + + Args: + img (Tensor): (b, c, h, w) + kernel (Tensor): (b, k, k) + """ + k = kernel.size(-1) + b, c, h, w = img.size() + if k % 2 == 1: + img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') + else: + raise ValueError('Wrong kernel size') + + ph, pw = img.size()[-2:] + + if kernel.size(0) == 1: + # apply the same kernel to all batch images + img = img.view(b * c, 1, ph, pw) + kernel = kernel.view(1, 1, k, k) + return F.conv2d(img, kernel, padding=0).view(b, c, h, w) + else: + img = img.view(1, b * c, ph, pw) + kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) + return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) + + +# https://github.com/XPixelGroup/BasicSR/blob/033cd6896d898fdd3dcda32e3102a792efa1b8f4/basicsr/utils/color_util.py#L186 +def rgb2ycbcr_pt(img, y_only=False): + """Convert RGB images to YCbCr images (PyTorch version). + + It implements the ITU-R BT.601 conversion for standard-definition television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + Args: + img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float. + """ + if y_only: + weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img) + out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0 + else: + weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img) + bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img) + out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias + + out_img = out_img / 255. + return out_img + + +def to_pil_image(inputs, mem_order, val_range, channel_order): + # convert inputs to numpy array + if isinstance(inputs, torch.Tensor): + inputs = inputs.cpu().numpy() + assert isinstance(inputs, np.ndarray) + + # make sure that inputs is a 4-dimension array + if mem_order in ["hwc", "chw"]: + inputs = inputs[None, ...] + mem_order = f"n{mem_order}" + # to NHWC + if mem_order == "nchw": + inputs = inputs.transpose(0, 2, 3, 1) + # to RGB + if channel_order == "bgr": + inputs = inputs[..., ::-1].copy() + else: + assert channel_order == "rgb" + + if val_range == "0,1": + inputs = inputs * 255 + elif val_range == "-1,1": + inputs = (inputs + 1) * 127.5 + else: + assert val_range == "0,255" + + inputs = inputs.clip(0, 255).astype(np.uint8) + return [inputs[i] for i in range(len(inputs))] + + +def put_text(pil_img_arr, text): + cv_img = pil_img_arr[..., ::-1].copy() + cv2.putText(cv_img, text, (10, 35), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2) + return cv_img[..., ::-1].copy() + + +def auto_resize(img: Image.Image, size: int) -> Image.Image: + short_edge = min(img.size) + if short_edge < size: + r = size / short_edge + img = img.resize( + tuple(math.ceil(x * r) for x in img.size), Image.BICUBIC + ) + else: + # make a deep copy of this image for safety + img = img.copy() + return img + + +def pad(img: np.ndarray, scale: int) -> np.ndarray: + h, w = img.shape[:2] + ph = 0 if h % scale == 0 else math.ceil(h / scale) * scale - h + pw = 0 if w % scale == 0 else math.ceil(w / scale) * scale - w + return np.pad( + img, pad_width=((0, ph), (0, pw), (0, 0)), mode="constant", + constant_values=0 + ) diff --git a/utils/image/diffjpeg.py b/utils/image/diffjpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..339282170b3a58e3bc9764c3cec0c4cfc9cd9816 --- /dev/null +++ b/utils/image/diffjpeg.py @@ -0,0 +1,492 @@ +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/diffjpeg.py +""" +Modified from https://github.com/mlomnitz/DiffJPEG + +For images not divisible by 8 +https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343 +""" +import itertools +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +# ------------------------ utils ------------------------# +y_table = np.array( + [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56], + [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92], + [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]], + dtype=np.float32).T +y_table = nn.Parameter(torch.from_numpy(y_table)) +c_table = np.empty((8, 8), dtype=np.float32) +c_table.fill(99) +c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T +c_table = nn.Parameter(torch.from_numpy(c_table)) + + +def diff_round(x): + """ Differentiable rounding function + """ + return torch.round(x) + (x - torch.round(x))**3 + + +def quality_to_factor(quality): + """ Calculate factor corresponding to quality + + Args: + quality(float): Quality for jpeg compression. + + Returns: + float: Compression factor. + """ + if quality < 50: + quality = 5000. / quality + else: + quality = 200. - quality * 2 + return quality / 100. + + +# ------------------------ compression ------------------------# +class RGB2YCbCrJpeg(nn.Module): + """ Converts RGB image to YCbCr + """ + + def __init__(self): + super(RGB2YCbCrJpeg, self).__init__() + matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]], + dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0., 128., 128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + """ + Args: + image(Tensor): batch x 3 x height x width + + Returns: + Tensor: batch x height x width x 3 + """ + image = image.permute(0, 2, 3, 1) + result = torch.tensordot(image, self.matrix, dims=1) + self.shift + return result.view(image.shape) + + +class ChromaSubsampling(nn.Module): + """ Chroma subsampling on CbCr channels + """ + + def __init__(self): + super(ChromaSubsampling, self).__init__() + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width x 3 + + Returns: + y(tensor): batch x height x width + cb(tensor): batch x height/2 x width/2 + cr(tensor): batch x height/2 x width/2 + """ + image_2 = image.permute(0, 3, 1, 2).clone() + cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) + cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) + cb = cb.permute(0, 2, 3, 1) + cr = cr.permute(0, 2, 3, 1) + return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3) + + +class BlockSplitting(nn.Module): + """ Splitting image into patches + """ + + def __init__(self): + super(BlockSplitting, self).__init__() + self.k = 8 + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x h*w/64 x h x w + """ + height, _ = image.shape[1:3] + batch_size = image.shape[0] + image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, -1, self.k, self.k) + + +class DCT8x8(nn.Module): + """ Discrete Cosine Transformation + """ + + def __init__(self): + super(DCT8x8, self).__init__() + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16) + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float()) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + image = image - 128 + result = self.scale * torch.tensordot(image, self.tensor, dims=2) + result.view(image.shape) + return result + + +class YQuantize(nn.Module): + """ JPEG Quantization for Y channel + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding): + super(YQuantize, self).__init__() + self.rounding = rounding + self.y_table = y_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + image = image.float() / (self.y_table * factor) + else: + b = factor.size(0) + table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + image = image.float() / table + image = self.rounding(image) + return image + + +class CQuantize(nn.Module): + """ JPEG Quantization for CbCr channels + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding): + super(CQuantize, self).__init__() + self.rounding = rounding + self.c_table = c_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + image = image.float() / (self.c_table * factor) + else: + b = factor.size(0) + table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + image = image.float() / table + image = self.rounding(image) + return image + + +class CompressJpeg(nn.Module): + """Full JPEG compression algorithm + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding=torch.round): + super(CompressJpeg, self).__init__() + self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling()) + self.l2 = nn.Sequential(BlockSplitting(), DCT8x8()) + self.c_quantize = CQuantize(rounding=rounding) + self.y_quantize = YQuantize(rounding=rounding) + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x 3 x height x width + + Returns: + dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8. + """ + y, cb, cr = self.l1(image * 255) + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + comp = self.l2(components[k]) + if k in ('cb', 'cr'): + comp = self.c_quantize(comp, factor=factor) + else: + comp = self.y_quantize(comp, factor=factor) + + components[k] = comp + + return components['y'], components['cb'], components['cr'] + + +# ------------------------ decompression ------------------------# + + +class YDequantize(nn.Module): + """Dequantize Y channel + """ + + def __init__(self): + super(YDequantize, self).__init__() + self.y_table = y_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + out = image * (self.y_table * factor) + else: + b = factor.size(0) + table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + out = image * table + return out + + +class CDequantize(nn.Module): + """Dequantize CbCr channel + """ + + def __init__(self): + super(CDequantize, self).__init__() + self.c_table = c_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + out = image * (self.c_table * factor) + else: + b = factor.size(0) + table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + out = image * table + return out + + +class iDCT8x8(nn.Module): + """Inverse discrete Cosine Transformation + """ + + def __init__(self): + super(iDCT8x8, self).__init__() + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float()) + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + image = image * self.alpha + result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128 + result.view(image.shape) + return result + + +class BlockMerging(nn.Module): + """Merge patches into image + """ + + def __init__(self): + super(BlockMerging, self).__init__() + + def forward(self, patches, height, width): + """ + Args: + patches(tensor) batch x height*width/64, height x width + height(int) + width(int) + + Returns: + Tensor: batch x height x width + """ + k = 8 + batch_size = patches.shape[0] + image_reshaped = patches.view(batch_size, height // k, width // k, k, k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, height, width) + + +class ChromaUpsampling(nn.Module): + """Upsample chroma layers + """ + + def __init__(self): + super(ChromaUpsampling, self).__init__() + + def forward(self, y, cb, cr): + """ + Args: + y(tensor): y channel image + cb(tensor): cb channel + cr(tensor): cr channel + + Returns: + Tensor: batch x height x width x 3 + """ + + def repeat(x, k=2): + height, width = x.shape[1:3] + x = x.unsqueeze(-1) + x = x.repeat(1, 1, k, k) + x = x.view(-1, height * k, width * k) + return x + + cb = repeat(cb) + cr = repeat(cr) + return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3) + + +class YCbCr2RGBJpeg(nn.Module): + """Converts YCbCr image to RGB JPEG + """ + + def __init__(self): + super(YCbCr2RGBJpeg, self).__init__() + + matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0, -128., -128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width x 3 + + Returns: + Tensor: batch x 3 x height x width + """ + result = torch.tensordot(image + self.shift, self.matrix, dims=1) + return result.view(image.shape).permute(0, 3, 1, 2) + + +class DeCompressJpeg(nn.Module): + """Full JPEG decompression algorithm + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding=torch.round): + super(DeCompressJpeg, self).__init__() + self.c_dequantize = CDequantize() + self.y_dequantize = YDequantize() + self.idct = iDCT8x8() + self.merging = BlockMerging() + self.chroma = ChromaUpsampling() + self.colors = YCbCr2RGBJpeg() + + def forward(self, y, cb, cr, imgh, imgw, factor=1): + """ + Args: + compressed(dict(tensor)): batch x h*w/64 x 8 x 8 + imgh(int) + imgw(int) + factor(float) + + Returns: + Tensor: batch x 3 x height x width + """ + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + if k in ('cb', 'cr'): + comp = self.c_dequantize(components[k], factor=factor) + height, width = int(imgh / 2), int(imgw / 2) + else: + comp = self.y_dequantize(components[k], factor=factor) + height, width = imgh, imgw + comp = self.idct(comp) + components[k] = self.merging(comp, height, width) + # + image = self.chroma(components['y'], components['cb'], components['cr']) + image = self.colors(image) + + image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image)) + return image / 255 + + +# ------------------------ main DiffJPEG ------------------------ # + + +class DiffJPEG(nn.Module): + """This JPEG algorithm result is slightly different from cv2. + DiffJPEG supports batch processing. + + Args: + differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round + """ + + def __init__(self, differentiable=True): + super(DiffJPEG, self).__init__() + if differentiable: + rounding = diff_round + else: + rounding = torch.round + + self.compress = CompressJpeg(rounding=rounding) + self.decompress = DeCompressJpeg(rounding=rounding) + + def forward(self, x, quality): + """ + Args: + x (Tensor): Input image, bchw, rgb, [0, 1] + quality(float): Quality factor for jpeg compression scheme. + """ + factor = quality + if isinstance(factor, (int, float)): + factor = quality_to_factor(factor) + else: + for i in range(factor.size(0)): + factor[i] = quality_to_factor(factor[i]) + h, w = x.size()[-2:] + h_pad, w_pad = 0, 0 + # why should use 16 + if h % 16 != 0: + h_pad = 16 - h % 16 + if w % 16 != 0: + w_pad = 16 - w % 16 + x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0) + + y, cb, cr = self.compress(x, factor=factor) + recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor) + recovered = recovered[:, :, 0:h, 0:w] + return recovered diff --git a/utils/image/usm_sharp.py b/utils/image/usm_sharp.py new file mode 100644 index 0000000000000000000000000000000000000000..7b83532e89204e8f88cc0e120e7d9021b489be90 --- /dev/null +++ b/utils/image/usm_sharp.py @@ -0,0 +1,29 @@ +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/img_process_util.py +import cv2 +import numpy as np +import torch + +from .common import filter2D + + +class USMSharp(torch.nn.Module): + + def __init__(self, radius=50, sigma=0): + super(USMSharp, self).__init__() + if radius % 2 == 0: + radius += 1 + self.radius = radius + kernel = cv2.getGaussianKernel(radius, sigma) + kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) + self.register_buffer('kernel', kernel) + + def forward(self, img, weight=0.5, threshold=10): + blur = filter2D(img, self.kernel) + residual = img - blur + + mask = torch.abs(residual) * 255 > threshold + mask = mask.float() + soft_mask = filter2D(mask, self.kernel) + sharp = img + weight * residual + sharp = torch.clip(sharp, 0, 1) + return soft_mask * sharp + (1 - soft_mask) * img diff --git a/utils/image_utils.py b/utils/image_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..aad569867e9235e9f287d40950dfac0f83ccbcb2 --- /dev/null +++ b/utils/image_utils.py @@ -0,0 +1,350 @@ +# python3.7 +"""Contains utility functions for image processing. + +The module is primarily built on `cv2`. But, differently, we assume all colorful +images are with `RGB` channel order by default. Also, we assume all gray-scale +images to be with shape [height, width, 1]. +""" + +import os +import cv2 +import numpy as np + +# File extensions regarding images (not including GIFs). +IMAGE_EXTENSIONS = ( + '.bmp', '.ppm', '.pgm', '.jpeg', '.jpg', '.jpe', '.jp2', '.png', '.webp', + '.tiff', '.tif' +) + +def check_file_ext(filename, *ext_list): + """Checks whether the given filename is with target extension(s). + + NOTE: If `ext_list` is empty, this function will always return `False`. + + Args: + filename: Filename to check. + *ext_list: A list of extensions. + + Returns: + `True` if the filename is with one of extensions in `ext_list`, + otherwise `False`. + """ + if len(ext_list) == 0: + return False + ext_list = [ext if ext.startswith('.') else '.' + ext for ext in ext_list] + ext_list = [ext.lower() for ext in ext_list] + basename = os.path.basename(filename) + ext = os.path.splitext(basename)[1].lower() + return ext in ext_list + + +def _check_2d_image(image): + """Checks whether a given image is valid. + + A valid image is expected to be with dtype `uint8`. Also, it should have + shape like: + + (1) (height, width, 1) # gray-scale image. + (2) (height, width, 3) # colorful image. + (3) (height, width, 4) # colorful image with transparency (RGBA) + """ + assert isinstance(image, np.ndarray) + assert image.dtype == np.uint8 + assert image.ndim == 3 and image.shape[2] in [1, 3, 4] + + +def get_blank_image(height, width, channels=3, use_black=True): + """Gets a blank image, either white of black. + + NOTE: This function will always return an image with `RGB` channel order for + color image and pixel range [0, 255]. + + Args: + height: Height of the returned image. + width: Width of the returned image. + channels: Number of channels. (default: 3) + use_black: Whether to return a black image. (default: True) + """ + shape = (height, width, channels) + if use_black: + return np.zeros(shape, dtype=np.uint8) + return np.ones(shape, dtype=np.uint8) * 255 + + +def load_image(path): + """Loads an image from disk. + + NOTE: This function will always return an image with `RGB` channel order for + color image and pixel range [0, 255]. + + Args: + path: Path to load the image from. + + Returns: + An image with dtype `np.ndarray`, or `None` if `path` does not exist. + """ + image = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if image is None: + return None + + if image.ndim == 2: + image = image[:, :, np.newaxis] + _check_2d_image(image) + if image.shape[2] == 3: + return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + if image.shape[2] == 4: + return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + return image + + +def save_image(path, image): + """Saves an image to disk. + + NOTE: The input image (if colorful) is assumed to be with `RGB` channel + order and pixel range [0, 255]. + + Args: + path: Path to save the image to. + image: Image to save. + """ + if image is None: + return + + _check_2d_image(image) + if image.shape[2] == 1: + cv2.imwrite(path, image) + elif image.shape[2] == 3: + cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) + elif image.shape[2] == 4: + cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)) + + +def resize_image(image, *args, **kwargs): + """Resizes image. + + This is a wrap of `cv2.resize()`. + + NOTE: The channel order of the input image will not be changed. + + Args: + image: Image to resize. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + An image with dtype `np.ndarray`, or `None` if `image` is empty. + """ + if image is None: + return None + + _check_2d_image(image) + if image.shape[2] == 1: # Re-expand the squeezed dim of gray-scale image. + return cv2.resize(image, *args, **kwargs)[:, :, np.newaxis] + return cv2.resize(image, *args, **kwargs) + + +def add_text_to_image(image, + text='', + position=None, + font=cv2.FONT_HERSHEY_TRIPLEX, + font_size=1.0, + line_type=cv2.LINE_8, + line_width=1, + color=(255, 255, 255)): + """Overlays text on given image. + + NOTE: The input image is assumed to be with `RGB` channel order. + + Args: + image: The image to overlay text on. + text: Text content to overlay on the image. (default: empty) + position: Target position (bottom-left corner) to add text. If not set, + center of the image will be used by default. (default: None) + font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX) + font_size: Font size of the text added. (default: 1.0) + line_type: Line type used to depict the text. (default: cv2.LINE_8) + line_width: Line width used to depict the text. (default: 1) + color: Color of the text added in `RGB` channel order. (default: + (255, 255, 255)) + + Returns: + An image with target text overlaid on. + """ + if image is None or not text: + return image + + _check_2d_image(image) + cv2.putText(img=image, + text=text, + org=position, + fontFace=font, + fontScale=font_size, + color=color, + thickness=line_width, + lineType=line_type, + bottomLeftOrigin=False) + return image + + +def preprocess_image(image, min_val=-1.0, max_val=1.0): + """Pre-processes image by adjusting the pixel range and to dtype `float32`. + + This function is particularly used to convert an image or a batch of images + to `NCHW` format, which matches the data type commonly used in deep models. + + NOTE: The input image is assumed to be with pixel range [0, 255] and with + format `HWC` or `NHWC`. The returned image will be always be with format + `NCHW`. + + Args: + image: The input image for pre-processing. + min_val: Minimum value of the output image. + max_val: Maximum value of the output image. + + Returns: + The pre-processed image. + """ + assert isinstance(image, np.ndarray) + + image = image.astype(np.float64) + image = image / 255.0 * (max_val - min_val) + min_val + + if image.ndim == 3: + image = image[np.newaxis] + assert image.ndim == 4 and image.shape[3] in [1, 3, 4] + return image.transpose(0, 3, 1, 2) + + +def postprocess_image(image, min_val=-1.0, max_val=1.0): + """Post-processes image to pixel range [0, 255] with dtype `uint8`. + + This function is particularly used to handle the results produced by deep + models. + + NOTE: The input image is assumed to be with format `NCHW`, and the returned + image will always be with format `NHWC`. + + Args: + image: The input image for post-processing. + min_val: Expected minimum value of the input image. + max_val: Expected maximum value of the input image. + + Returns: + The post-processed image. + """ + assert isinstance(image, np.ndarray) + + image = image.astype(np.float64) + image = (image - min_val) / (max_val - min_val) * 255 + image = np.clip(image + 0.5, 0, 255).astype(np.uint8) + + assert image.ndim == 4 and image.shape[1] in [1, 3, 4] + return image.transpose(0, 2, 3, 1) + + +def parse_image_size(obj): + """Parses an object to a pair of image size, i.e., (height, width). + + Args: + obj: The input object to parse image size from. + + Returns: + A two-element tuple, indicating image height and width respectively. + + Raises: + If the input is invalid, i.e., neither a list or tuple, nor a string. + """ + if obj is None or obj == '': + height = 0 + width = 0 + elif isinstance(obj, int): + height = obj + width = obj + elif isinstance(obj, (list, tuple, str, np.ndarray)): + if isinstance(obj, str): + splits = obj.replace(' ', '').split(',') + numbers = tuple(map(int, splits)) + else: + numbers = tuple(obj) + if len(numbers) == 0: + height = 0 + width = 0 + elif len(numbers) == 1: + height = int(numbers[0]) + width = int(numbers[0]) + elif len(numbers) == 2: + height = int(numbers[0]) + width = int(numbers[1]) + else: + raise ValueError('At most two elements for image size.') + else: + raise ValueError(f'Invalid type of input: `{type(obj)}`!') + + return (max(0, height), max(0, width)) + + +def get_grid_shape(size, height=0, width=0, is_portrait=False): + """Gets the shape of a grid based on the size. + + This function makes greatest effort on making the output grid square if + neither `height` nor `width` is set. If `is_portrait` is set as `False`, the + height will always be equal to or smaller than the width. For example, if + input `size = 16`, output shape will be `(4, 4)`; if input `size = 15`, + output shape will be (3, 5). Otherwise, the height will always be equal to + or larger than the width. + + Args: + size: Size (height * width) of the target grid. + height: Expected height. If `size % height != 0`, this field will be + ignored. (default: 0) + width: Expected width. If `size % width != 0`, this field will be + ignored. (default: 0) + is_portrait: Whether to return a portrait size of a landscape size. + (default: False) + + Returns: + A two-element tuple, representing height and width respectively. + """ + assert isinstance(size, int) + assert isinstance(height, int) + assert isinstance(width, int) + if size <= 0: + return (0, 0) + + if height > 0 and width > 0 and height * width != size: + height = 0 + width = 0 + + if height > 0 and width > 0 and height * width == size: + return (height, width) + if height > 0 and size % height == 0: + return (height, size // height) + if width > 0 and size % width == 0: + return (size // width, width) + + height = int(np.sqrt(size)) + while height > 0: + if size % height == 0: + width = size // height + break + height = height - 1 + + return (width, height) if is_portrait else (height, width) + + +def list_images_from_dir(directory): + """Lists all images from the given directory. + + NOTE: Do NOT support finding images recursively. + + Args: + directory: The directory to find images from. + + Returns: + A list of sorted filenames, with the directory as prefix. + """ + image_list = [] + for filename in os.listdir(directory): + if check_file_ext(filename, *IMAGE_EXTENSIONS): + image_list.append(os.path.join(directory, filename)) + return sorted(image_list) diff --git a/utils/inference.py b/utils/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f8dc0ced1f5577d38c0cab6f80a99f18866680 --- /dev/null +++ b/utils/inference.py @@ -0,0 +1,321 @@ +import os +from typing import overload, Generator, Dict +from argparse import Namespace + +import numpy as np +import torch +from PIL import Image +from omegaconf import OmegaConf + +from model.cldm import ControlLDM +from model.gaussian_diffusion import Diffusion +from model.bsrnet import RRDBNet +from model.scunet import SCUNet +from model.swinir import SwinIR +from utils.common import instantiate_from_config, load_file_from_url, count_vram_usage +from utils.face_restoration_helper import FaceRestoreHelper +from utils.helpers import ( + Pipeline, + BSRNetPipeline, SwinIRPipeline, SCUNetPipeline, + bicubic_resize +) +from utils.cond_fn import MSEGuidance, WeightedMSEGuidance + + +MODELS = { + ### stage_1 model weights + "bsrnet": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRNet.pth", + # the following checkpoint is up-to-date, but we use the old version in our paper + # "swinir_face": "https://github.com/zsyOAOA/DifFace/releases/download/V1.0/General_Face_ffhq512.pth", + "swinir_face": "https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt", + "scunet_psnr": "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth", + "swinir_general": "https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt", + ### stage_2 model weights + "sd_v21": "https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt", + "v1_face": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_face.pth", + "v1_general": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_general.pth", + "v2": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v2.pth" +} + + +def load_model_from_url(url: str) -> Dict[str, torch.Tensor]: + sd_path = load_file_from_url(url, model_dir="weights") + sd = torch.load(sd_path, map_location="cpu") + if "state_dict" in sd: + sd = sd["state_dict"] + if list(sd.keys())[0].startswith("module"): + sd = {k[len("module."):]: v for k, v in sd.items()} + return sd + + +class InferenceLoop: + + def __init__(self, args: Namespace) -> "InferenceLoop": + self.args = args + self.loop_ctx = {} + self.pipeline: Pipeline = None + self.init_stage1_model() + self.init_stage2_model() + self.init_cond_fn() + self.init_pipeline() + + @overload + def init_stage1_model(self) -> None: + ... + + @count_vram_usage + def init_stage2_model(self) -> None: + ### load uent, vae, clip + self.cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/inference/my_cldm.yaml")) + sd = load_model_from_url(MODELS["sd_v21"]) + unused = self.cldm.load_pretrained_sd(sd) + print(f"strictly load pretrained sd_v2.1, unused weights: {unused}") + ### load controlnet + if self.args.version == "v1": + if self.args.task == "fr": + control_sd = load_model_from_url(MODELS["v1_face"]) + elif self.args.task == "sr": + control_sd = load_model_from_url(MODELS["v1_general"]) + else: + raise ValueError(f"DiffBIR v1 doesn't support task: {self.args.task}, please use v2 by passsing '--version v2'") + else: + control_sd = load_model_from_url(MODELS["v2"]) + self.cldm.load_controlnet_from_ckpt(control_sd) + print(f"strictly load controlnet weight") + self.cldm.eval().to(self.args.device) + ### load diffusion + self.diffusion: Diffusion = instantiate_from_config(OmegaConf.load("configs/inference/diffusion.yaml")) + self.diffusion.to(self.args.device) + + def init_cond_fn(self) -> None: + if not self.args.guidance: + self.cond_fn = None + return + if self.args.g_loss == "mse": + cond_fn_cls = MSEGuidance + elif self.args.g_loss == "w_mse": + cond_fn_cls = WeightedMSEGuidance + else: + raise ValueError(self.args.g_loss) + self.cond_fn = cond_fn_cls( + scale=self.args.g_scale, t_start=self.args.g_start, t_stop=self.args.g_stop, + space=self.args.g_space, repeat=self.args.g_repeat + ) + + @overload + def init_pipeline(self) -> None: + ... + + def setup(self) -> None: + self.output_dir = self.args.output + os.makedirs(self.output_dir, exist_ok=True) + + def lq_loader(self) -> Generator[np.ndarray, None, None]: + img_exts = [".png", ".jpg", ".jpeg"] + if os.path.isdir(self.args.input): + file_names = sorted([ + file_name for file_name in os.listdir(self.args.input) if os.path.splitext(file_name)[-1] in img_exts + ]) + file_paths = [os.path.join(self.args.input, file_name) for file_name in file_names] + else: + assert os.path.splitext(self.args.input)[-1] in img_exts + file_paths = [self.args.input] + + def _loader() -> Generator[np.ndarray, None, None]: + for file_path in file_paths: + ### load lq + lq = np.array(Image.open(file_path).convert("RGB")) + print(f"load lq: {file_path}") + ### set context for saving results + self.loop_ctx["file_stem"] = os.path.splitext(os.path.basename(file_path))[0] + for i in range(self.args.n_samples): + self.loop_ctx["repeat_idx"] = i + yield lq + + return _loader + + def after_load_lq(self, lq: np.ndarray) -> np.ndarray: + return lq + + @torch.no_grad() + def run(self) -> None: + self.setup() + # We don't support batch processing since input images may have different size + loader = self.lq_loader() + for i, lq in enumerate(loader()): + lq = self.after_load_lq(lq) + sample = self.pipeline.run( + lq[None], self.args.steps, 1.0, self.args.tiled, + self.args.tile_size, self.args.tile_stride, + self.args.pos_prompt, self.args.neg_prompt, self.args.cfg_scale, + self.args.better_start, + index=i, input=self.args.input + )[0] + self.save(sample) + + def save(self, sample: np.ndarray) -> None: + file_stem, repeat_idx = self.loop_ctx["file_stem"], self.loop_ctx["repeat_idx"] + file_name = f"{file_stem}_{repeat_idx}.png" if self.args.n_samples > 1 else f"{file_stem}.png" + save_path = os.path.join(self.args.output, file_name) + Image.fromarray(sample).save(save_path) + print(f"save result to {save_path}") + + +class BSRInferenceLoop(InferenceLoop): + + @count_vram_usage + def init_stage1_model(self) -> None: + self.bsrnet: RRDBNet = instantiate_from_config(OmegaConf.load("configs/inference/bsrnet.yaml")) + sd = load_model_from_url(MODELS["bsrnet"]) + self.bsrnet.load_state_dict(sd, strict=True) + self.bsrnet.eval().to(self.args.device) + + def init_pipeline(self) -> None: + self.pipeline = BSRNetPipeline(self.bsrnet, self.cldm, self.diffusion, self.cond_fn, self.args.device, self.args.upscale) + + +class BFRInferenceLoop(InferenceLoop): + + @count_vram_usage + def init_stage1_model(self) -> None: + self.swinir_face: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml")) + sd = load_model_from_url(MODELS["swinir_face"]) + self.swinir_face.load_state_dict(sd, strict=True) + self.swinir_face.eval().to(self.args.device) + + def init_pipeline(self) -> None: + self.pipeline = SwinIRPipeline(self.swinir_face, self.cldm, self.diffusion, self.cond_fn, self.args.device) + + def after_load_lq(self, lq: np.ndarray) -> np.ndarray: + # For BFR task, super resolution is achieved by directly upscaling lq + return bicubic_resize(lq, self.args.upscale) + + +class BIDInferenceLoop(InferenceLoop): + + @count_vram_usage + def init_stage1_model(self) -> None: + self.scunet_psnr: SCUNet = instantiate_from_config(OmegaConf.load("configs/inference/scunet.yaml")) + sd = load_model_from_url(MODELS["scunet_psnr"]) + self.scunet_psnr.load_state_dict(sd, strict=True) + self.scunet_psnr.eval().to(self.args.device) + + def init_pipeline(self) -> None: + self.pipeline = SCUNetPipeline(self.scunet_psnr, self.cldm, self.diffusion, self.cond_fn, self.args.device) + + def after_load_lq(self, lq: np.ndarray) -> np.ndarray: + # For BID task, super resolution is achieved by directly upscaling lq + return bicubic_resize(lq, self.args.upscale) + + +class V1InferenceLoop(InferenceLoop): + + @count_vram_usage + def init_stage1_model(self) -> None: + self.swinir: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml")) + if self.args.task == "fr": + sd = load_model_from_url(MODELS["swinir_face"]) + elif self.args.task == "sr": + sd = load_model_from_url(MODELS["swinir_general"]) + else: + raise ValueError(f"DiffBIR v1 doesn't support task: {self.args.task}, please use v2 by passsing '--version v2'") + self.swinir.load_state_dict(sd, strict=True) + self.swinir.eval().to(self.args.device) + + def init_pipeline(self) -> None: + self.pipeline = SwinIRPipeline(self.swinir, self.cldm, self.diffusion, self.cond_fn, self.args.device) + + def after_load_lq(self, lq: np.ndarray) -> np.ndarray: + # For BFR task, super resolution is achieved by directly upscaling lq + return bicubic_resize(lq, self.args.upscale) + + +class UnAlignedBFRInferenceLoop(InferenceLoop): + + @count_vram_usage + def init_stage1_model(self) -> None: + self.bsrnet: RRDBNet = instantiate_from_config(OmegaConf.load("configs/inference/bsrnet.yaml")) + sd = load_model_from_url(MODELS["bsrnet"]) + self.bsrnet.load_state_dict(sd, strict=True) + self.bsrnet.eval().to(self.args.device) + + self.swinir_face: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml")) + sd = load_model_from_url(MODELS["swinir_face"]) + self.swinir_face.load_state_dict(sd, strict=True) + self.swinir_face.eval().to(self.args.device) + + def init_pipeline(self) -> None: + self.pipes = { + "bg": BSRNetPipeline(self.bsrnet, self.cldm, self.diffusion, self.cond_fn, self.args.device, self.args.upscale), + "face": SwinIRPipeline(self.swinir_face, self.cldm, self.diffusion, self.cond_fn, self.args.device) + } + self.pipeline = self.pipes["face"] + + def setup(self) -> None: + super().setup() + self.cropped_face_dir = os.path.join(self.args.output, "cropped_faces") + os.makedirs(self.cropped_face_dir, exist_ok=True) + self.restored_face_dir = os.path.join(self.args.output, "restored_faces") + os.makedirs(self.restored_face_dir, exist_ok=True) + self.restored_bg_dir = os.path.join(self.args.output, "restored_backgrounds") + os.makedirs(self.restored_bg_dir, exist_ok=True) + + def lq_loader(self) -> Generator[np.ndarray, None, None]: + base_loader = super().lq_loader() + self.face_helper = FaceRestoreHelper( + device=self.args.device, + upscale_factor=1, + face_size=512, + use_parse=True, + det_model="retinaface_resnet50" + ) + + def _loader() -> Generator[np.ndarray, None, None]: + for lq in base_loader(): + ### set input image + self.face_helper.clean_all() + upscaled_bg = bicubic_resize(lq, self.args.upscale) + self.face_helper.read_image(upscaled_bg) + ### get face landmarks for each face + self.face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5) + self.face_helper.align_warp_face() + print(f"detect {len(self.face_helper.cropped_faces)} faces") + ### restore each face (has been upscaeled) + for i, lq_face in enumerate(self.face_helper.cropped_faces): + self.loop_ctx["is_face"] = True + self.loop_ctx["face_idx"] = i + self.loop_ctx["cropped_face"] = lq_face + yield lq_face + ### restore background (hasn't been upscaled) + self.loop_ctx["is_face"] = False + yield lq + + return _loader + + def after_load_lq(self, lq: np.ndarray) -> np.ndarray: + if self.loop_ctx["is_face"]: + self.pipeline = self.pipes["face"] + else: + self.pipeline = self.pipes["bg"] + return lq + + def save(self, sample: np.ndarray) -> None: + file_stem, repeat_idx = self.loop_ctx["file_stem"], self.loop_ctx["repeat_idx"] + if self.loop_ctx["is_face"]: + face_idx = self.loop_ctx["face_idx"] + file_name = f"{file_stem}_{repeat_idx}_face_{face_idx}.png" + Image.fromarray(sample).save(os.path.join(self.restored_face_dir, file_name)) + + cropped_face = self.loop_ctx["cropped_face"] + Image.fromarray(cropped_face).save(os.path.join(self.cropped_face_dir, file_name)) + + self.face_helper.add_restored_face(sample) + else: + self.face_helper.get_inverse_affine() + # paste each restored face to the input image + restored_img = self.face_helper.paste_faces_to_input_image( + upsample_img=sample + ) + file_name = f"{file_stem}_{repeat_idx}.png" + Image.fromarray(sample).save(os.path.join(self.restored_bg_dir, file_name)) + Image.fromarray(restored_img).save(os.path.join(self.output_dir, file_name)) diff --git a/utils/metrics.py b/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c3e4961d92e0d4e42522fefcddf401270a8f26 --- /dev/null +++ b/utils/metrics.py @@ -0,0 +1,66 @@ +import torch +import lpips + +from .image import rgb2ycbcr_pt +from .common import frozen_module + + +# https://github.com/XPixelGroup/BasicSR/blob/033cd6896d898fdd3dcda32e3102a792efa1b8f4/basicsr/metrics/psnr_ssim.py#L52 +def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False): + """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version). + + Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + mse = torch.mean((img - img2)**2, dim=[1, 2, 3]) + return 10. * torch.log10(1. / (mse + 1e-8)) + + +class LPIPS: + + def __init__(self, net: str) -> None: + self.model = lpips.LPIPS(net=net) + frozen_module(self.model) + + @torch.no_grad() + def __call__(self, img1: torch.Tensor, img2: torch.Tensor, normalize: bool) -> torch.Tensor: + """ + Compute LPIPS. + + Args: + img1 (torch.Tensor): The first image (NCHW, RGB, [-1, 1]). Specify `normalize` if input + image is range in [0, 1]. + img2 (torch.Tensor): The second image (NCHW, RGB, [-1, 1]). Specify `normalize` if input + image is range in [0, 1]. + normalize (bool): If specified, the input images will be normalized from [0, 1] to [-1, 1]. + + Returns: + lpips_values (torch.Tensor): The lpips scores of this batch. + """ + return self.model(img1, img2, normalize=normalize) + + def to(self, device: str) -> "LPIPS": + self.model.to(device) + return self diff --git a/utils/realesrgan/realesrganer.py b/utils/realesrgan/realesrganer.py new file mode 100644 index 0000000000000000000000000000000000000000..ce10fa9ffe9a1c12657d216b86f6a493e8878b85 --- /dev/null +++ b/utils/realesrgan/realesrganer.py @@ -0,0 +1,339 @@ +import cv2 +import math +import numpy as np +import os +import queue +import threading +import torch +from torch.nn import functional as F + +from utils.file import load_file_from_url +from utils.realesrgan.rrdbnet import RRDBNet + +# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +class RealESRGANer(): + """A helper class for upsampling images with RealESRGAN. + + Args: + scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. + model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). + model (nn.Module): The defined network. Default: None. + tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop + input images into tiles, and then process each of them. Finally, they will be merged into one image. + 0 denotes for do not use tile. Default: 0. + tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. + pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. + half (float): Whether to use half precision during inference. Default: False. + """ + + def __init__(self, + scale, + model_path, + model=None, + tile=0, + tile_pad=10, + pre_pad=10, + half=False, + device=None): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + self.half = half + + # initialize model + # if gpu_id: + # self.device = torch.device( + # f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device + # else: + # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + + self.device = device + + # if the model_path starts with https, it will first download models to the folder: realesrgan/weights + if model_path.startswith('https://'): + model_path = load_file_from_url( + url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None) + loadnet = torch.load(model_path, map_location=torch.device('cpu')) + # prefer to use params_ema + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + model.load_state_dict(loadnet[keyname], strict=True) + model.eval() + self.model = model.to(self.device) + if self.half: + self.model = self.model.half() + + def pre_process(self, img): + """Pre-process, such as pre-pad and mod pad, so that the images can be divisible + """ + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + self.img = img.unsqueeze(0).to(self.device) + if self.half: + self.img = self.img.half() + + # pre_pad + if self.pre_pad != 0: + self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') + # mod pad for divisible borders + if self.scale == 2: + self.mod_scale = 2 + elif self.scale == 1: + self.mod_scale = 4 + if self.mod_scale is not None: + self.mod_pad_h, self.mod_pad_w = 0, 0 + _, _, h, w = self.img.size() + if (h % self.mod_scale != 0): + self.mod_pad_h = (self.mod_scale - h % self.mod_scale) + if (w % self.mod_scale != 0): + self.mod_pad_w = (self.mod_scale - w % self.mod_scale) + self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') + + def process(self): + # model inference + self.output = self.model(self.img) + + def tile_process(self): + """It will first crop input images to tiles, and then process each tile. + Finally, all the processed tiles are merged into one images. + + Modified from: https://github.com/ata4/esrgan-launcher + """ + batch, channel, height, width = self.img.shape + output_height = height * self.scale + output_width = width * self.scale + output_shape = (batch, channel, output_height, output_width) + + # start with black image + self.output = self.img.new_zeros(output_shape) + tiles_x = math.ceil(width / self.tile_size) + tiles_y = math.ceil(height / self.tile_size) + + # loop over all tiles + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * self.tile_size + ofs_y = y * self.tile_size + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + self.tile_size, width) + input_start_y = ofs_y + input_end_y = min(ofs_y + self.tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - self.tile_pad, 0) + input_end_x_pad = min(input_end_x + self.tile_pad, width) + input_start_y_pad = max(input_start_y - self.tile_pad, 0) + input_end_y_pad = min(input_end_y + self.tile_pad, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + tile_idx = y * tiles_x + x + 1 + input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] + + # upscale tile + try: + with torch.no_grad(): + output_tile = self.model(input_tile) + except RuntimeError as error: + print('Error', error) + # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') + + # output tile area on total image + output_start_x = input_start_x * self.scale + output_end_x = input_end_x * self.scale + output_start_y = input_start_y * self.scale + output_end_y = input_end_y * self.scale + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale + output_end_x_tile = output_start_x_tile + input_tile_width * self.scale + output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale + output_end_y_tile = output_start_y_tile + input_tile_height * self.scale + + # put tile into output image + self.output[:, :, output_start_y:output_end_y, + output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, + output_start_x_tile:output_end_x_tile] + + def post_process(self): + # remove extra pad + if self.mod_scale is not None: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] + # remove prepad + if self.pre_pad != 0: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] + return self.output + + @torch.no_grad() + def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'): + h_input, w_input = img.shape[0:2] + # img: numpy + img = img.astype(np.float32) + if np.max(img) > 256: # 16-bit image + max_range = 65535 + print('\tInput is a 16-bit image') + else: + max_range = 255 + img = img / max_range + if len(img.shape) == 2: # gray image + img_mode = 'L' + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + elif img.shape[2] == 4: # RGBA image with alpha channel + img_mode = 'RGBA' + alpha = img[:, :, 3] + img = img[:, :, 0:3] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if alpha_upsampler == 'realesrgan': + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = 'RGB' + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # ------------------- process image (without the alpha channel) ------------------- # + try: + with torch.no_grad(): + self.pre_process(img) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_img_t = self.post_process() + output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) + if img_mode == 'L': + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + del output_img_t + torch.cuda.empty_cache() + except RuntimeError as error: + print(f"Failed inference for RealESRGAN: {error}") + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == 'RGBA': + if alpha_upsampler == 'realesrgan': + self.pre_process(alpha) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_alpha = self.post_process() + output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: # use the cv2 resize for alpha channel + h, w = alpha.shape[0:2] + output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) + + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + + # ------------------------------ return ------------------------------ # + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + + if outscale is not None and outscale != float(self.scale): + output = cv2.resize( + output, ( + int(w_input * outscale), + int(h_input * outscale), + ), interpolation=cv2.INTER_LANCZOS4) + + return output, img_mode + + +class PrefetchReader(threading.Thread): + """Prefetch images. + + Args: + img_list (list[str]): A image list of image paths to be read. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, img_list, num_prefetch_queue): + super().__init__() + self.que = queue.Queue(num_prefetch_queue) + self.img_list = img_list + + def run(self): + for img_path in self.img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + self.que.put(img) + + self.que.put(None) + + def __next__(self): + next_item = self.que.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class IOConsumer(threading.Thread): + + def __init__(self, opt, que, qid): + super().__init__() + self._queue = que + self.qid = qid + self.opt = opt + + def run(self): + while True: + msg = self._queue.get() + if isinstance(msg, str) and msg == 'quit': + break + + output = msg['output'] + save_path = msg['save_path'] + cv2.imwrite(save_path, output) + print(f'IO worker {self.qid} is done.') + +def set_realesrgan(bg_tile, device, scale=2): + ''' + scale: options: 2, 4. Default: 2. RealESRGAN official models only support x2 and x4 upsampling. + ''' + assert isinstance(scale, int), 'Expected param scale to be an integer!' + + use_half = False + if 'cuda' in str(device): # set False in CPU/MPS mode + no_half_gpu_list = ['1650', '1660'] # set False for GPUs that don't support f16 + if not True in [gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list]: + use_half = True + + model_url = { + 2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', + 4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth' + } + + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=scale, + ) + upsampler = RealESRGANer( + scale=scale, + model_path=model_url[scale], + model=model, + tile=bg_tile, + tile_pad=40, + pre_pad=0, + device=device, + half=use_half + ) + return upsampler \ No newline at end of file diff --git a/utils/realesrgan/rrdbnet.py b/utils/realesrgan/rrdbnet.py new file mode 100644 index 0000000000000000000000000000000000000000..11b23a3bd4e640aaadd1b56c075dae7a43dcdc3b --- /dev/null +++ b/utils/realesrgan/rrdbnet.py @@ -0,0 +1,183 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + + +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. + """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Empirically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Empirically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + + +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. + + We extend ESRGAN for scale x2 and scale x1. + Note: This is one option for scale 1, scale 2 in RRDBNet. + We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64 + num_block (int): Block number in the trunk network. Defaults: 23 + num_grow_ch (int): Channels for each growth. Default: 32. + """ + + def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out \ No newline at end of file diff --git a/utils/sampler.py b/utils/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..15101634510c845f920c33df05854e1bd99858d2 --- /dev/null +++ b/utils/sampler.py @@ -0,0 +1,423 @@ +from typing import Optional, Tuple, Dict + +import copy +import torch +from torch import nn +import numpy as np +from tqdm import tqdm +from einops import rearrange +import torch.nn.functional as F + +from model.gaussian_diffusion import extract_into_tensor +from model.cldm import ControlLDM +from utils.cond_fn import Guidance +from utils.common import sliding_windows, gaussian_weights + +import vidtome +from controller.controller import AttentionControl + +def pad_to_multiples_of(imgs: torch.Tensor, multiple: int) -> torch.Tensor: + _, _, h, w = imgs.size() + if h % multiple == 0 and w % multiple == 0: + return imgs.clone() + # get_pad = lambda x: (x // multiple + 1) * multiple - x + get_pad = lambda x: (x // multiple + int(x % multiple != 0)) * multiple - x + ph, pw = get_pad(h), get_pad(w) + return F.pad(imgs, pad=(0, pw, 0, ph), mode="constant", value=0) + +# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedSampler(nn.Module): + """ + Implementation for spaced sampling schedule proposed in IDDPM. This class is designed + for sampling ControlLDM. + + https://arxiv.org/pdf/2102.09672.pdf + """ + + def __init__(self, betas: np.ndarray) -> "SpacedSampler": + super().__init__() + self.num_timesteps = len(betas) + self.original_betas = betas + self.original_alphas_cumprod = np.cumprod(1.0 - betas, axis=0) + self.context = {} + + def register(self, name: str, value: np.ndarray) -> None: + self.register_buffer(name, torch.tensor(value, dtype=torch.float32)) + + def make_schedule(self, num_steps: int) -> None: + # calcualte betas for spaced sampling + # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py + used_timesteps = space_timesteps(self.num_timesteps, str(num_steps)) + betas = [] + last_alpha_cumprod = 1.0 + for i, alpha_cumprod in enumerate(self.original_alphas_cumprod): + if i in used_timesteps: + # marginal distribution is the same as q(x_{S_t}|x_0) + betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + assert len(betas) == num_steps + self.timesteps = np.array(sorted(list(used_timesteps)), dtype=np.int32) # e.g. [0, 10, 20, ...] + + betas = np.array(betas, dtype=np.float64) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + # print(f"sampler sqrt_alphas_cumprod: {np.sqrt(alphas_cumprod)[-1]}") + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod) + sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1) + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = ( + betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + ) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + posterior_log_variance_clipped = np.log( + np.append(posterior_variance[1], posterior_variance[1:]) + ) + posterior_mean_coef1 = ( + betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) + ) + posterior_mean_coef2 = ( + (1.0 - alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - alphas_cumprod) + ) + + self.register("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod) + self.register("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod) + self.register("posterior_variance", posterior_variance) + self.register("posterior_log_variance_clipped", posterior_log_variance_clipped) + self.register("posterior_mean_coef1", posterior_mean_coef1) + self.register("posterior_mean_coef2", posterior_mean_coef2) + + def q_posterior_mean_variance(self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Implement the posterior distribution q(x_{t-1}|x_t, x_0). + + Args: + x_start (torch.Tensor): The predicted images (NCHW) in timestep `t`. + x_t (torch.Tensor): The sampled intermediate variables (NCHW) of timestep `t`. + t (torch.Tensor): Timestep (N) of `x_t`. `t` serves as an index to get + parameters for each timestep. + + Returns: + posterior_mean (torch.Tensor): Mean of the posterior distribution. + posterior_variance (torch.Tensor): Variance of the posterior distribution. + posterior_log_variance_clipped (torch.Tensor): Log variance of the posterior distribution. + """ + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def _predict_xstart_from_eps(self, x_t: torch.Tensor, t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor: + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def apply_cond_fn( + self, + model: ControlLDM, + pred_x0: torch.Tensor, + t: torch.Tensor, + index: torch.Tensor, + cond_fn: Guidance + ) -> torch.Tensor: + t_now = int(t[0].item()) + 1 + if not (cond_fn.t_stop < t_now and t_now < cond_fn.t_start): + # stop guidance + self.context["g_apply"] = False + return pred_x0 + grad_rescale = 1 / extract_into_tensor(self.posterior_mean_coef1, index, pred_x0.shape) + # apply guidance for multiple times + loss_vals = [] + for _ in range(cond_fn.repeat): + # set target and pred for gradient computation + target, pred = None, None + if cond_fn.space == "latent": + target = model.vae_encode(cond_fn.target) + pred = pred_x0 + elif cond_fn.space == "rgb": + # We need to backward gradient to x0 in latent space, so it's required + # to trace the computation graph while decoding the latent. + with torch.enable_grad(): + target = cond_fn.target + pred_x0_rg = pred_x0.detach().clone().requires_grad_(True) + pred = model.vae_decode(pred_x0_rg) + assert pred.requires_grad + else: + raise NotImplementedError(cond_fn.space) + # compute gradient + delta_pred, loss_val = cond_fn(target, pred, t_now) + loss_vals.append(loss_val) + # update pred_x0 w.r.t gradient + if cond_fn.space == "latent": + delta_pred_x0 = delta_pred + pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale + elif cond_fn.space == "rgb": + pred.backward(delta_pred) + delta_pred_x0 = pred_x0_rg.grad + pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale + else: + raise NotImplementedError(cond_fn.space) + self.context["g_apply"] = True + self.context["g_loss"] = float(np.mean(loss_vals)) + return pred_x0 + + def predict_noise( + self, + model: ControlLDM, + x: torch.Tensor, + t: torch.Tensor, + cond: Dict[str, torch.Tensor], + uncond: Optional[Dict[str, torch.Tensor]], + cfg_scale: float + ) -> torch.Tensor: + if uncond is None or cfg_scale == 1.: + model_output = model(x, t, cond) + else: + # apply classifier-free guidance + model_cond = model(x, t, cond) + model_uncond = model(x, t, uncond) + model_output = model_uncond + cfg_scale * (model_cond - model_uncond) + return model_output + + @torch.no_grad() + def predict_noise_tiled( + self, + model: ControlLDM, + x: torch.Tensor, + t: torch.Tensor, + cond: Dict[str, torch.Tensor], + uncond: Optional[Dict[str, torch.Tensor]], + cfg_scale: float, + tile_size: int, + tile_stride: int + ): + _, _, h, w = x.shape + tiles = tqdm(sliding_windows(h, w, tile_size // 8, tile_stride // 8), unit="tile", leave=False) + eps = torch.zeros_like(x) + count = torch.zeros_like(x, dtype=torch.float32) + weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None] + weights = torch.tensor(weights, dtype=torch.float32, device=x.device) + for hi, hi_end, wi, wi_end in tiles: + tiles.set_description(f"Process tile ({hi} {hi_end}), ({wi} {wi_end})") + tile_x = x[:, :, hi:hi_end, wi:wi_end] + tile_cond = { + "c_img": cond["c_img"][:, :, hi:hi_end, wi:wi_end], + "c_txt": cond["c_txt"] + } + if uncond: + tile_uncond = { + "c_img": uncond["c_img"][:, :, hi:hi_end, wi:wi_end], + "c_txt": uncond["c_txt"] + } + tile_eps = self.predict_noise(model, tile_x, t, tile_cond, tile_uncond, cfg_scale) + # accumulate noise + eps[:, :, hi:hi_end, wi:wi_end] += tile_eps * weights + count[:, :, hi:hi_end, wi:wi_end] += weights + # average on noise (score) + eps.div_(count) + return eps + + @torch.no_grad() + def p_sample( + self, + model: ControlLDM, + x: torch.Tensor, + t: torch.Tensor, + index: torch.Tensor, + cond: Dict[str, torch.Tensor], + uncond: Optional[Dict[str, torch.Tensor]], + cfg_scale: float, + cond_fn: Optional[Guidance], + tiled: bool, + tile_size: int, + tile_stride: int, + controller: Optional[AttentionControl]=None + ) -> torch.Tensor: + if tiled: + eps = self.predict_noise_tiled(model, x, t, cond, uncond, cfg_scale, tile_size, tile_stride) + else: + eps = self.predict_noise(model, x, t, cond, uncond, cfg_scale) + pred_x0 = self._predict_xstart_from_eps(x, index, eps) + if cond_fn: + assert not tiled, f"tiled sampling currently doesn't support guidance" + pred_x0 = self.apply_cond_fn(model, pred_x0, t, index, cond_fn) + + if controller is not None: + pred_x0 = controller.update_x0(pred_x0) + + model_mean, model_variance, _ = self.q_posterior_mean_variance(pred_x0, x, index) + noise = torch.randn_like(x) + nonzero_mask = ( + (index != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) + x_prev = model_mean + nonzero_mask * torch.sqrt(model_variance) * noise + return x_prev + + @torch.no_grad() + def sample( + self, + model: ControlLDM, + device: str, + steps: int, + batch_size: int, + x_size: Tuple[int], + cond: Dict[str, torch.Tensor], + uncond: Dict[str, torch.Tensor], + cfg_scale: float, + cond_fn: Optional[Guidance]=None, + tiled: bool=False, + tile_size: int=-1, + tile_stride: int=-1, + x_T: Optional[torch.Tensor]=None, + progress: bool=True, + progress_leave: bool=True, + non_pad_ratio: Tuple[float]=(1, 1), + ) -> torch.Tensor: + self.make_schedule(steps) + self.to(device) + if x_T is None: + # TODO: not convert to float32, may trigger an error + img = torch.randn((batch_size, *x_size), device=device) + else: + img = x_T + timesteps = np.flip(self.timesteps) # [1000, 950, 900, ...] + total_steps = len(self.timesteps) + iterator = tqdm(timesteps, total=total_steps, leave=progress_leave, disable=not progress) + + if model.controller is not None: + # non_pad_flows = model.controller.step_store["flows"].copy() + # for j, flow in enumerate(model.controller.step_store["flows"]): + # if flow is not None: + # model.controller.step_store["flows"][j] = pad_to_multiples_of(model.controller.step_store["flows"][j], 8) + if not (model.controller.ToMe_period[0]): + vidtome.update_patch(model, controller=model.controller) + # flows=non_pad_flows, \ + # flow_confids=model.controller.step_store["flow_confids"].copy(), ) + + + for i, step in enumerate(iterator): + torch.cuda.empty_cache() + if model.controller is not None: + model.controller.set_step(i) + if i == int((total_steps * model.controller.ToMe_period[0])): + print(f"[INFO] activating ToMe @ step {i} ...") + model.activate_vidtome() + vidtome.update_patch(model, controller=model.controller) + # flows=non_pad_flows, \ + # flow_confids=model.controller.step_store["flow_confids"].copy(), + # for j, flow in enumerate(model.controller.step_store["flows"]): + # if flow is not None: + # model.controller.step_store["flows"][j] = pad_to_multiples_of(model.controller.step_store["flows"][j], 8) + + if i <= int((total_steps * model.controller.ToMe_period[1])) and i >= int((total_steps * model.controller.ToMe_period[0])): + # ratio = model.controller.merge_ratio[0] - (i / total_steps) * (model.controller.merge_ratio[0] - model.controller.merge_ratio[1]) + ToMe_start_step = int((total_steps * model.controller.ToMe_period[0])) + ToMe_end_step = int((total_steps * model.controller.ToMe_period[1])) + s = (i - ToMe_start_step) / (ToMe_end_step - ToMe_start_step) + ratio = model.controller.merge_ratio[1] + (np.cos(np.pi / 2 * s)) * (model.controller.merge_ratio[0] - model.controller.merge_ratio[1]) + vidtome.update_patch(model, current_step=i, + local_merge_ratio = ratio) + # flows=model.controller.step_store["flows"], occlusion_masks=model.controller.step_store["occ_masks"], + # flow_confids=model.controller.step_store["flow_confids"]) + print(f"[INFO] updating merging ratio to {ratio:.3f} @ step {i} s {s:.3f} ...") + ts = torch.full((batch_size,), step, device=device, dtype=torch.long) + index = torch.full_like(ts, fill_value=total_steps - i - 1) + img = self.p_sample( + model, img, ts, index, cond, uncond, cfg_scale, cond_fn, + tiled, tile_size, tile_stride, + controller=model.controller + ) + if model.controller is not None: + # model.controller.decoded_imgs.clear() + # for img_ in img: + # sample = model.vae_decode(img_[None]) + # sample = (sample + 1) / 2 + # # sample = wavelet_reconstruction(sample, clean) + # # sample = F.interpolate(sample, size=self.final_size, mode="bicubic", antialias=True) + # sample = rearrange(sample * 255., "n c h w -> n h w c") + # sample = sample.contiguous().clamp(0, 255).to(torch.uint8).cpu().numpy() + # model.controller.decoded_imgs.append(sample) + + # img = model.controller.update_x0(img) + # img = model.controller.merge_x0(img, merge_ratio=1) + # img = model.controller.merge_x0_scores(img, merge_ratio=1) + # img = (img + model.controller.merge_x0(img, merge_ratio=1)) / 2 + + if i == int((total_steps * model.controller.ToMe_period[1])): + print(f"[INFO] removing ToMe patch @ step {i} ...") + vidtome.remove_patch(model) + + if cond_fn and self.context["g_apply"]: + loss_val = self.context["g_loss"] + desc = f"Spaced Sampler With Guidance, Loss: {loss_val:.6f}" + else: + desc = "Spaced Sampler" + iterator.set_description(desc) + + + # if model.controller is not None: + # merge.visualize_correspondence(img[0][None], img[1][None], ratio=0.05) + # img = img = model.controller.merge_x0_scores(img, merge_ratio=0.5) + return img diff --git a/utils/video_visualizer.py b/utils/video_visualizer.py new file mode 100755 index 0000000000000000000000000000000000000000..9461a7be1bd26b08eab3134acf5c7cf2b44bd4c7 --- /dev/null +++ b/utils/video_visualizer.py @@ -0,0 +1,150 @@ +import os.path +from skvideo.io import FFmpegWriter +# from image_utils import parse_image_size +# from image_utils import load_image +# from image_utils import resize_image +# from image_utils import list_images_from_dir +from .image_utils import parse_image_size +from .image_utils import load_image +from .image_utils import resize_image +from .image_utils import list_images_from_dir + + +class VideoVisualizer(object): + """Defines the video visualizer that presents images as a video.""" + + def __init__(self, + path=None, + frame_size=None, + fps=25.0, + codec='libx264', + pix_fmt='yuv420p', + crf=1): + """Initializes the video visualizer. + + Args: + path: Path to write the video. (default: None) + frame_size: Frame size, i.e., (height, width). (default: None) + fps: Frames per second. (default: 24) + codec: Codec. (default: `libx264`) + pix_fmt: Pixel format. (default: `yuv420p`) + crf: Constant rate factor, which controls the compression. The + larger this field is, the higher compression and lower quality. + `0` means no compression and consequently the highest quality. + To enable QuickTime playing (requires YUV to be 4:2:0, but + `crf = 0` results YUV to be 4:4:4), please set this field as + at least 1. (default: 1) + """ + self.set_path(path) + self.set_frame_size(frame_size) + self.set_fps(fps) + self.set_codec(codec) + self.set_pix_fmt(pix_fmt) + self.set_crf(crf) + self.video = None + + def set_path(self, path=None): + """Sets the path to save the video.""" + self.path = path + + def set_frame_size(self, frame_size=None): + """Sets the video frame size.""" + height, width = parse_image_size(frame_size) + self.frame_height = height + self.frame_width = width + + def set_fps(self, fps=25.0): + """Sets the FPS (frame per second) of the video.""" + self.fps = fps + + def set_codec(self, codec='libx264'): + """Sets the video codec.""" + self.codec = codec + + def set_pix_fmt(self, pix_fmt='yuv420p'): + """Sets the video pixel format.""" + self.pix_fmt = pix_fmt + + def set_crf(self, crf=1): + """Sets the CRF (constant rate factor) of the video.""" + self.crf = crf + + def init_video(self): + """Initializes an empty video with expected settings.""" + assert self.frame_height > 0 + assert self.frame_width > 0 + + video_setting = { + '-r': f'{self.fps:.2f}', + '-s': f'{self.frame_width}x{self.frame_height}', + '-vcodec': f'{self.codec}', + '-crf': f'{self.crf}', + '-pix_fmt': f'{self.pix_fmt}', + } + self.video = FFmpegWriter(self.path, outputdict=video_setting) + + def add(self, frame): + """Adds a frame into the video visualizer. + + NOTE: The input frame is assumed to be with `RGB` channel order. + """ + if self.video is None: + height, width = frame.shape[0:2] + if height & 1: + height -= 1 + if width & 1: + width -= 1 + # height = self.frame_height or height + # width = self.frame_width or width + self.set_frame_size((height, width)) + self.init_video() + if frame.shape[0:2] != (self.frame_height, self.frame_width): + frame = resize_image(frame, (self.frame_width, self.frame_height)) + self.video.writeFrame(frame) + + def visualize_collection(self, images, save_path=None): + """Visualizes a collection of images one by one.""" + if save_path is not None and save_path != self.path: + self.save() + self.set_path(save_path) + for image in images: + self.add(image) + self.save() + + def visualize_list(self, image_list, save_path=None): + """Visualizes a list of image files.""" + if save_path is not None and save_path != self.path: + self.save() + self.set_path(save_path) + for filename in image_list: + image = load_image(filename) + self.add(image) + self.save() + + def visualize_directory(self, directory, save_path=None): + """Visualizes all images under a directory.""" + image_list = list_images_from_dir(directory) + self.visualize_list(image_list, save_path) + + def save(self): + """Saves the video by closing the file.""" + if self.video is not None: + self.video.close() + self.video = None + self.set_path(None) + + +if __name__ == '__main__': + from glob import glob + import cv2 + video_visualizer = VideoVisualizer(path='/home/yehhh/DiffBIR/DAVIS_bear.mp4', + frame_size=None, + fps=25.0) + img_folder = "/home/yehhh/DiffBIR/inputs/bear" + imgs = sorted(glob(img_folder + '/*.png')) + imgs = sorted(glob(img_folder + '/*.jpg')) + for img in imgs: + image = cv2.imread(img) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + video_visualizer.add(image) + video_visualizer.save() \ No newline at end of file diff --git a/vidtome/__init__.py b/vidtome/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e8d0e4aa3cb60265129f904aafef111dca48a24 --- /dev/null +++ b/vidtome/__init__.py @@ -0,0 +1,4 @@ +from . import merge, patch +from .patch import apply_patch, remove_patch, update_patch, collect_from_patch + +__all__ = ["merge", "patch", "apply_patch", "remove_patch", "update_patch", "collect_from_patch"] \ No newline at end of file diff --git a/vidtome/__pycache__/__init__.cpython-310.pyc b/vidtome/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..665041569b598a0fea9969874fe62fe708288e04 Binary files /dev/null and b/vidtome/__pycache__/__init__.cpython-310.pyc differ diff --git a/vidtome/__pycache__/__init__.cpython-39.pyc b/vidtome/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..030f7a0a8e48a6ba19d4f642a58d2edd1935f4d7 Binary files /dev/null and b/vidtome/__pycache__/__init__.cpython-39.pyc differ diff --git a/vidtome/__pycache__/merge.cpython-310.pyc b/vidtome/__pycache__/merge.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be9da9b1d583d6a28ff999584dff5e5628629142 Binary files /dev/null and b/vidtome/__pycache__/merge.cpython-310.pyc differ diff --git a/vidtome/__pycache__/merge.cpython-39.pyc b/vidtome/__pycache__/merge.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cc21bf207b7333e53f6fbc6b8c077d9805cabf1 Binary files /dev/null and b/vidtome/__pycache__/merge.cpython-39.pyc differ diff --git a/vidtome/__pycache__/patch.cpython-310.pyc b/vidtome/__pycache__/patch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a3090a5303b48bf784bace5af669445ef35fd80 Binary files /dev/null and b/vidtome/__pycache__/patch.cpython-310.pyc differ diff --git a/vidtome/__pycache__/patch.cpython-39.pyc b/vidtome/__pycache__/patch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f49bde85ac01c03e16c2515099a18d408596545 Binary files /dev/null and b/vidtome/__pycache__/patch.cpython-39.pyc differ diff --git a/vidtome/__pycache__/utils.cpython-310.pyc b/vidtome/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f7a8f242473494a4a37a09e48bd8a647ac4deb7 Binary files /dev/null and b/vidtome/__pycache__/utils.cpython-310.pyc differ diff --git a/vidtome/__pycache__/utils.cpython-39.pyc b/vidtome/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e61dc98e0488f2e7a9e4122236c0ce824f215d7 Binary files /dev/null and b/vidtome/__pycache__/utils.cpython-39.pyc differ diff --git a/vidtome/merge.py b/vidtome/merge.py new file mode 100644 index 0000000000000000000000000000000000000000..6ac40a9dd92e7b1de454d9e6eb6b42d67a6571be --- /dev/null +++ b/vidtome/merge.py @@ -0,0 +1,1201 @@ +import cv2 +import time +import torch +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.cm as cm +from matplotlib.patches import ConnectionPatch +from controller.controller import AttentionControl + +from einops import repeat, rearrange +from typing import Tuple, Callable + +from vidtome.patch import PCA_token +from utils.flow_utils import coords_grid + +def do_nothing(x: torch.Tensor, mode: str = None): + return x + + +def mps_gather_workaround(input, dim, index): + if input.shape[-1] == 1: + return torch.gather( + input.unsqueeze(-1), + dim - 1 if dim < 0 else dim, + index.unsqueeze(-1) + ).squeeze(-1) + else: + return torch.gather(input, dim, index) + +def visualize_flow_correspondence(src_img: torch.Tensor, tar_img: torch.Tensor, flow: torch.Tensor, flow_confid: torch.Tensor, + ratio: float, H: int=64, out: str = "correspondence.png") -> Tuple[Callable, Callable, dict]: + if len(src_img.shape) == 4: + B, C, H, W = src_img.shape + src_img = rearrange(src_img, 'b c h w -> b (h w) c', h=H) + tar_img = rearrange(tar_img, 'b c h w -> b (h w) c', h=H) + else: + B, N, C = src_img.shape + W = N // H + + src_PCA_token = PCA_token(src_img, token_h=H) + tar_PCA_token = PCA_token(tar_img, token_h=H) + # Compute pre-frame token number. N = unm_pre + tnum * F. + + gather = mps_gather_workaround if src_img.device.type == "mps" else torch.gather + + with torch.no_grad(): + # Cosine similarity between src and dst tokens + a = src_img / src_img.norm(dim=-1, keepdim=True) + b = tar_img / tar_img.norm(dim=-1, keepdim=True) + + scores = a @ b.transpose(-1, -2) + # Can't reduce more than the # tokens in src + r = min(a.shape[1], int(a.shape[1] * ratio)) + print(f"[INFO] flow r {r} ") + # Find the most similar greedily + flow_confid = rearrange(flow_confid, 'b h w -> b (h w)') + edge_idx = flow_confid.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + + src_xy = [(id.item() % W, id.item() // W) for id in src_idx[0]] + grid = coords_grid(B, H, W).to(flow.device) + flow # [B, 2, H, W] + tar_xy = [(grid[0, 0, y, x].clamp(0, W-1).item(), \ + grid[0, 1, y, x].clamp(0, H-1).item()) for (x, y) in src_xy] + # tar_idx = torch.tensor([y * W + x for (x, y) in tar_xy]).to(src_idx.device) + + fig, ax = plt.subplots(1, 2, figsize=(8, 3)) + # Display the source and target images + ax[0].imshow(src_PCA_token, cmap='gray') + ax[1].imshow(tar_PCA_token, cmap='gray') + + ax[0].axis('off') + ax[1].axis('off') + + colors = cm.Greens(np.linspace(0.5, 1, len(src_xy))) + # Draw lines connecting corresponding points + for (x1, y1), (x2, y2), color in zip(src_xy, tar_xy, colors): + ax[0].plot(x1, y1, marker='o', color=color, markersize=0.5) # red dot in source image + ax[1].plot(x2, y2, marker='o', color=color, markersize=1) # red dot in target image + con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data", + axesA=ax[1], axesB=ax[0], color=color, linewidth=0.2) + ax[1].add_artist(con) + # plt.tight_layout() + plt.savefig(out, bbox_inches="tight") + plt.close() + +def visualize_correspondence_score(src_img: torch.Tensor, tar_img: torch.Tensor, score: torch.Tensor, + ratio: float, H: int=64, out: str = "correspondence_idx.png") -> Tuple[Callable, Callable, dict]: + if len(src_img.shape) == 4: + B, C, H, W = src_img.shape + src_img = rearrange(src_img, 'b c h w -> b (h w) c', h=H) + tar_img = rearrange(tar_img, 'b c h w -> b (h w) c', h=H) + else: + B, N, C = src_img.shape + W = N // H + + src_PCA_token = PCA_token(src_img, token_h=H) + tar_PCA_token = PCA_token(tar_img, token_h=H) + + + with torch.no_grad(): + # Can't reduce more than the # tokens in src + r = min(N, int(N * ratio)) + + node_max, node_idx = score.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + # src_idx = edge_idx[0, -r:, 0] # Merged Tokens + src_idx = edge_idx[0, :r, 0] # Merged Tokens + tar_idx = torch.gather(node_idx[0], dim=0, index=src_idx) + + src_idx = src_idx[:r] + tar_idx = tar_idx[:r] + + # x = src_idx % W + # y = src_idx // W + # src_xy + src_xy = [(id.item() % W, id.item() // W) for id in src_idx] + tar_xy = [(id.item() % W, id.item() // W) for id in tar_idx] + + fig, ax = plt.subplots(1, 2, figsize=(8, 3)) + # Display the source and target images + ax[0].imshow(src_PCA_token, cmap='gray') + ax[1].imshow(tar_PCA_token, cmap='gray') + + colors = cm.cool(np.linspace(0, 1, len(src_xy))) + # Draw lines connecting corresponding points + for (x1, y1), (x2, y2), color in zip(src_xy, tar_xy, colors): + ax[0].plot(x1, y1, marker='o', color=color, markersize=1) # red dot in source image + ax[1].plot(x2, y2, marker='o', color=color, markersize=1) # red dot in target image + con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data", + axesA=ax[1], axesB=ax[0], color=color, linewidth=0.2) + ax[1].add_artist(con) + # plt.tight_layout() + plt.savefig(out, bbox_inches="tight") + plt.close() + +def visualize_cosine_correspondence(src_img: torch.Tensor, tar_img: torch.Tensor, + ratio: float, H: int=64, out: str = "correspondence.png", + flow: torch.Tensor=None, flow_confid: torch.Tensor=None, + controller: AttentionControl=None ) -> Tuple[Callable, Callable, dict]: + if len(src_img.shape) == 4: + B, C, H, W = src_img.shape + src_img = rearrange(src_img, 'b c h w -> b (h w) c', h=H) + tar_img = rearrange(tar_img, 'b c h w -> b (h w) c', h=H) + else: + B, N, C = src_img.shape + W = N // H + # import ipdb; ipdb.set_trace() + src_PCA_token = PCA_token(src_img, token_h=H) + tar_PCA_token = PCA_token(tar_img, token_h=H) + + # Compute pre-frame token number. N = unm_pre + tnum * F. + + gather = mps_gather_workaround if src_img.device.type == "mps" else torch.gather + + with torch.no_grad(): + # Cosine similarity between src and dst tokens + a = src_img / src_img.norm(dim=-1, keepdim=True) + b = tar_img / tar_img.norm(dim=-1, keepdim=True) + + scores = a @ b.transpose(-1, -2) + + # Can't reduce more than the # tokens in src + r = min(a.shape[1], int(a.shape[1] * ratio)) + print(f"[INFO] cosine r {r} ") + # Find the most similar greedily + # import ipdb; ipdb.set_trace() + # scores *= controller.distances[H][:,:scores.shape[1]] + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., int(4*r):int(5*r), :] # Merged Tokens + # unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + # src_idx = edge_idx[..., :r, :] # Merged Tokens + + tar_idx = gather(node_idx[..., None], dim=-2, index=src_idx) + + src_xy = [(id.item() % W, id.item() // W) for id in src_idx[0]] + tar_xy = [(id.item() % W, id.item() // W) for id in tar_idx[0]] + + + fig, ax = plt.subplots(1, 2, figsize=(8, 3)) + # Display the source and target images + ax[0].imshow(src_PCA_token, cmap='spring') + ax[1].imshow(tar_PCA_token, cmap='spring') + + # Hide the axis labels + ax[0].axis('off') + ax[1].axis('off') + + # colors = cm.Reds(np.linspace(0.5, 1, len(src_xy))) + colors = cm.cool(np.linspace(0.5, 1, len(src_xy))) + # Draw lines connecting corresponding points + for (x1, y1), (x2, y2), color in zip(src_xy, tar_xy, colors): + # color = "orangered" + ax[0].plot(x1, y1, marker='o', color=color, markersize=0.5) # red dot in source image + ax[1].plot(x2, y2, marker='o', color=color, markersize=1) # red dot in target image + con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data", + axesA=ax[1], axesB=ax[0], color=color, linewidth=0.2) + ax[1].add_artist(con) + # plt.tight_layout() + plt.savefig(out, bbox_inches="tight") + plt.close() + +def visualize_correspondence(src_img: torch.Tensor, tar_img: torch.Tensor, + ratio: float, H: int=64, out: str = "correspondence.png", + flow: torch.Tensor=None, flow_confid: torch.Tensor=None, + controller: AttentionControl=None ) -> Tuple[Callable, Callable, dict]: + + if len(src_img.shape) == 4: + B, C, H, W = src_img.shape + src_img = rearrange(src_img, 'b c h w -> b (h w) c', h=H) + tar_img = rearrange(tar_img, 'b c h w -> b (h w) c', h=H) + else: + B, N, C = src_img.shape + W = N // H + src_PCA_token = PCA_token(src_img, token_h=H, n=1) + tar_PCA_token = PCA_token(tar_img, token_h=H, n=1) + # import ipdb; ipdb.set_trace() + if abs(np.mean(src_PCA_token[:20, :20]) - np.mean(tar_PCA_token[:20, :20])) > 50: + if np.mean(src_PCA_token[:20, :20]) > np.mean(tar_PCA_token[:20, :20]): + src_PCA_token = 255 - src_PCA_token + else: + tar_PCA_token = 255 - tar_PCA_token + print(f"[INFO] src_PCA_token mean {np.mean(src_PCA_token[:20, :20])} tar_PCA_token mean {np.mean(tar_PCA_token[:20, :20])} ") + # Compute pre-frame token number. N = unm_pre + tnum * F. + + gather = mps_gather_workaround if src_img.device.type == "mps" else torch.gather + + with torch.no_grad(): + # Cosine similarity between src and dst tokens + a = src_img / src_img.norm(dim=-1, keepdim=True) + b = tar_img / tar_img.norm(dim=-1, keepdim=True) + + scores = a @ b.transpose(-1, -2) + + # Can't reduce more than the # tokens in src + r = min(a.shape[1], int(a.shape[1] * ratio)) + + # Find the most similar greedily + # import ipdb; ipdb.set_trace() + print(f"[INFO] no distance weigthed ... ") + # scores *= controller.distances[H][:,:scores.shape[1]] + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + # unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + # src_idx = edge_idx[..., :r, :] # Merged Tokens + + tar_idx = gather(node_idx[..., None], dim=-2, index=src_idx) + + src_xy = [(id.item() % W, id.item() // W) for id in src_idx[0]] + tar_xy = [(id.item() % W, id.item() // W) for id in tar_idx[0]] + + # Find the most similar greedily + flow_confid = rearrange(flow_confid, 'b h w -> b (h w)') + edge_idx = flow_confid.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + + flow_src_xy = [(id.item() % W, id.item() // W) for id in src_idx[0]] + # import ipdb; ipdb.set_trace() + grid = coords_grid(B, H, W).to(flow.device) + flow # [B, 2, H, W] + flow_tar_xy = [(grid[0, 0, y, x].clamp(0, W-1).item(), \ + grid[0, 1, y, x].clamp(0, H-1).item()) for (x, y) in flow_src_xy] + + + fig, ax = plt.subplots(2, 2, figsize=(8, 4)) + + if len(controller.decoded_imgs): + step = out.split("/")[-1].split(".")[0] + + _, h_, w_, _ = controller.decoded_imgs[0].shape + mul = h_ // H + decoded_img = controller.decoded_imgs[1] + decoded_img = decoded_img[0, :, :int(W * mul), :] + if step == "49": + decoded_img = cv2.imread("/project/DiffBVR_eval/DAVIS/BDx8_results/DiffBIR_ours/cows/00001.png") + decoded_img = cv2.resize(decoded_img, (W, H)) + ax[0, 0].imshow(decoded_img, aspect='auto') + decoded_img = controller.decoded_imgs[2] + decoded_img = decoded_img[0, :, :int(W * mul), :] + if step == "49": + decoded_img = cv2.imread("/project/DiffBVR_eval/DAVIS/BDx8_results/DiffBIR_ours/cows/00002.png") + decoded_img = cv2.resize(decoded_img, (W, H)) + ax[0, 1].imshow(decoded_img, aspect='auto') + else: + # Display the source and target images + ax[0, 0].imshow(src_PCA_token, cmap='ocean', aspect='auto') + ax[0, 1].imshow(tar_PCA_token, cmap='ocean', aspect='auto') + + ax[0, 0].axis('off') + ax[0, 1].axis('off') + + ax[1, 0].imshow(src_PCA_token, cmap='Blues', aspect='auto') + ax[1, 1].imshow(tar_PCA_token, cmap='Blues', aspect='auto') + # ax[1, 0].imshow(np.mean(src_PCA_token, -1), cmap='ocean') + # ax[1, 1].imshow(np.mean(tar_PCA_token, -1), cmap='ocean') + + # Hide the axis labels + ax[1, 0].axis('off') + ax[1, 1].axis('off') + + + colors = cm.Greens(np.linspace(0.25, 0.75, len(flow_src_xy))) + # Draw lines connecting corresponding points + for (x1, y1), (x2, y2), color in zip(flow_src_xy, flow_tar_xy, colors): + # color = "mediumslateblue" + # ax[1, 0].plot(x1, y1, marker='o', color=color, markersize=1) # red dot in source image + ax[1, 1].plot(x2, y2, marker='o', color=color, markersize=1) # red dot in target image + con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data", + axesA=ax[1, 1], axesB=ax[1, 0], color=color, linewidth=0.2) + ax[1, 1].add_artist(con) + # plt.tight_layout() + colors = cm.Reds(np.linspace(0.25, 0.75, len(src_xy))) + # Draw lines connecting corresponding points + for (x1, y1), (x2, y2), color in zip(src_xy, tar_xy, colors): + # color = "orangered" + # ax[1, 0].plot(x1, y1, marker='o', color=color, markersize=1) # red dot in source image + ax[1, 1].plot(x2, y2, marker='o', color=color, markersize=1) # red dot in target image + con = ConnectionPatch(xyA=(x2, y2), xyB=(x1, y1), coordsA="data", coordsB="data", + axesA=ax[1, 1], axesB=ax[1, 0], color=color, linewidth=0.2) + ax[1, 1].add_artist(con) + + plt.subplots_adjust(wspace=0.05, hspace=0.1) + plt.savefig(out, bbox_inches="tight") + plt.close() + +# For Local Token Merging +def bipartite_soft_matching_randframe(metric: torch.Tensor, + F: int, ratio: float, unm_pre: int, generator: torch.Generator=None, + target_stride: int = 4, align_batch: bool = False, + merge_mode: str = "replace", H: int=64, + flow_merge: bool=False, + controller: AttentionControl=None) -> Tuple[Callable, Callable, dict]: + """ + Partitions the multi-frame tokens into src and dst and merges ratio of src tokens from src to dst. + Dst tokens are partitioned by choosing one random frame. + + Args: + - metric [B, N, C]: metric to use for similarity. + - F: frame number. + - ratio: ratio of src tokens to be removed (by merging). + - unm_pre: number of src tokens not merged at previous ToMe. Pre-sequence: [unm_pre|F_0|F_1|...] + - generator: random number generator + - target_stride: stride of target frame. + - align_batch: whether to align similarity matching maps of samples in the batch. True when using PnP. + - merge_mode: how to merge tokens. "mean": tokens -> Mean(src_token, dst_token); "replace": tokens -> dst_token. + + Returns: + Merge and unmerge operation according to the matching result. Return a dict including other values. + """ + B, N, _ = metric.shape + A = N // F + W = A // H + # Compute pre-frame token number. N = unm_pre + tnum * F. + tnum = (N - unm_pre) // F + if ratio <= 0: + return do_nothing, do_nothing, {"unm_num": tnum} + + gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather + + with torch.no_grad(): + # Prepare idx buffer. Ignore previous unmerged tokens. + idx_buffer = torch.arange( + N - unm_pre, device=metric.device, dtype=torch.int64) + + # Select the random target frame. + target_stride = min(target_stride, F) + # import ipdb; ipdb.set_trace() + if controller is None: + randf = torch.randint(0, target_stride, torch.Size( + [1]), generator=generator, device=generator.device) + else: + randf = torch.tensor(target_stride // 2).to(metric.device) + # print(f"[INFO] randf: {randf} ... ") + dst_select = ((torch.div(idx_buffer, tnum, rounding_mode='floor')) % + target_stride == randf).to(torch.bool) + + # a_idx: src index. b_idx: dst index + a_idx = idx_buffer[None, ~dst_select, None] + unm_pre + b_idx = idx_buffer[None, dst_select, None] + unm_pre + # import ipdb; ipdb.set_trace() + + # Add unmerged tokens to dst. + unm_buffer = torch.arange(unm_pre, device=metric.device, dtype=torch.int64)[ + None, :, None] + b_idx = torch.cat([b_idx, unm_buffer], dim=1) + + # We're finished with these + del idx_buffer, unm_buffer + + num_dst = b_idx.shape[1] + + def split(x): + # Split src, dst tokens + b, n, c = x.shape + src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c)) + dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c)) + # print(f"[INFO] {x.shape} {num_dst}") + return src, dst + + # if flow is not None: + # start = time.time() + # if len(flow) != F-1: + # mid = F // 2 + # flow_confid = flow_confid[:mid] + flow_confid[mid+1:] + # flow = flow[:mid] + flow[mid+1:] + + # flow_confid = torch.cat(flow_confid, dim=0) + # flow = torch.cat(flow, dim=0) + # flow_confid = rearrange(flow_confid, 'b h w -> 1 (b h w)') + # print(f"[INFO] flow time {time.time() - start}") + + # Cosine similarity between src and dst tokens + metric = metric / metric.norm(dim=-1, keepdim=True) + # import ipdb; ipdb.set_trace() + a, b = split(metric) + scores = a @ b.transpose(-1, -2) + + # Can't reduce more than the # tokens in src + r = min(a.shape[1], int(a.shape[1] * ratio)) + + if align_batch: + # Cat scores of all samples in the batch. When using PnP, samples are (src, neg, pos). + # Find the most similar greedily among all samples. + scores = torch.cat([*scores], dim=-1) + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = gather(node_idx[..., None], + dim=-2, index=src_idx) % num_dst # Map index to (0, num_dst - 1) + + # Use the same matching result for all samples + unm_idx = unm_idx.expand(B, -1, -1) + src_idx = src_idx.expand(B, -1, -1) + dst_idx = dst_idx.expand(B, -1, -1) + else: + + if flow_merge: + # print(f"[INFO] flow merge ... ") + # start = time.time() + # edge_idx = flow_confid.argsort(dim=-1, descending=True)[..., None] + + # unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + # src_idx = edge_idx[..., :r, :] # Merged Tokens + + # src_idx_tensor = src_idx[0, : ,0] + # f = src_idx_tensor // A + # id = src_idx_tensor % A + # x = id % W + # y = id // W + + # # Stack the results into a 2D tensor + # src_fxy = torch.stack((f, x, y), dim=1) + # # import ipdb; ipdb.set_trace() + # grid = coords_grid(F-1, H, W).to(flow.device) + flow # [F-1, 2, H, W] + + # x = grid[src_fxy[:, 0], 0, src_fxy[:, 2], src_fxy[:, 1]].clamp(0, W-1).long() + # y = grid[src_fxy[:, 0], 1, src_fxy[:, 2], src_fxy[:, 1]].clamp(0, H-1).long() + # tar_xy = torch.stack((x, y), dim=1) + # tar_idx = y * W + x + # tar_idx = rearrange(tar_idx, ' d -> 1 d 1') + # print(f"[INFO] {src_idx[0, 10, 0]} {tar_idx[0, 10, 0]}") + unm_idx = controller.flow_correspondence[H][0][:, r:, :] + src_idx = controller.flow_correspondence[H][0][:, :r, :] + tar_idx = controller.flow_correspondence[H][1][:, :r, :] + # score[src_idx[i], tar_idx[i]] = flow_confid[src_idx[i]] + # scores[:, src_idx[0, :, 0], tar_idx[0, :, 0]] = flow_confid[0, src_idx[0, :, 0]] + # import ipdb; ipdb.set_trace() + else: + + ''' distacne weighted ''' + # # if H == 64: + # # Create a tensor that represents the coordinates of each pixel + # start = time.time() + # y, x = torch.meshgrid(torch.arange(H), torch.arange(W)) + # coords = torch.stack((y, x), dim=-1).float().to(metric.device) + # coords = rearrange(coords, 'h w c -> (h w) c') + + # # Calculate the Euclidean distance between all pixels + # distances = torch.cdist(coords, coords) + # radius = W // 30 + # radius = 1 if radius == 0 else radius + # # print(f"[INFO] W: {W} Radius: {radius} ") + # distances //= radius + # distances = torch.exp(-distances) + # # distances += torch.diag_embed(torch.ones(A)).to(metric.device) + # distances = repeat(distances, 'h a -> 1 (b h) a', b=F-1) + # print(f"[INFO] {W} {torch.mean(distances)} {torch.std(distances)}") + # node_max, node_idx = scores.max(dim=-1) + # scores *= distances + # print(f"[INFO] distance not weighted ... ") + if controller is not None: + if H not in controller.distances: + controller.set_distance(F-1, H, W, W//30, metric.device) + print(f"[INFO] distance weighted ... ") + # print(f"[INFO] controller distance time {time.time() - start}") + scores *= controller.distances[H] + + # Find the most similar greedily + ''' node_idx: src_idx to tar_idx ''' + node_max, node_idx = scores.max(dim=-1) + + # src_idx_tensor = torch.arange(node_max.shape[1], device=metric.device, dtype=torch.int64) + # id = src_idx_tensor % A + # x = id % W + # y = id // W + # src_xy = torch.stack((x, y), dim=1) + + # tar_idx_tensor = node_idx[0, :] + # x = tar_idx_tensor % W + # y = tar_idx_tensor // W + # tar_xy = torch.stack((x, y), dim=1) + + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + ''' idx in all src tokens ''' + src_idx = edge_idx[..., :r, :] # Merged Tokens + tar_idx = gather(node_idx[..., None], dim=-2, index=src_idx) + # correspond_dis = gather(distance[None, ..., None], dim=-2, index=src_idx) + # import ipdb; ipdb.set_trace() + # import ipdb; ipdb.set_trace() + + + # src_idx_tensor = src_idx[0, : ,0] + # id = src_idx_tensor % A + # x = id % W + # y = id // W + # src_xy = torch.stack((x, y), dim=1) + + # tar_idx_tensor = tar_idx[0, : ,0] + # x = tar_idx_tensor % W + # y = tar_idx_tensor // W + # tar_xy = torch.stack((x, y), dim=1) + # cosine_delta = torch.sum(torch.norm((src_xy - tar_xy).float(), dim=-1)) + # import ipdb; ipdb.set_trace() + # print("&&&") + # if flow is not None: + # print(f"[INFO] Flow Delta: {flow_delta.item()} Cosine Delta: {cosine_delta.item()}") + # else: + # print(f"Cosine Delta: {cosine_delta.item()}") + + def merge(x: torch.Tensor, mode=None) -> torch.Tensor: + # Merge tokens according to matching result. + src, dst = split(x) + n, t1, c = src.shape + u_idx, s_idx, t_idx = unm_idx, src_idx, tar_idx + # print(f"[INFO] {s_idx[0, 10, 0]} {t_idx[0, 10, 0]}") + unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c)) + mode = mode if mode is not None else merge_mode + if mode != "replace": + src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c)) + # In other mode such as mean, combine matched src and dst tokens. + dst = dst.scatter_reduce(-2, t_idx.expand(-1, -1, c), + src, reduce=mode, include_self=True) + # In replace mode, just cat unmerged tokens and dst tokens. Ignore src tokens. + return torch.cat([unm, dst], dim=1) + + def unmerge(x: torch.Tensor, **kwarg) -> torch.Tensor: + # Unmerge tokens to original size according to matching result. + unm_len = unm_idx.shape[1] + unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] + b, _, c = unm.shape + u_idx, s_idx, t_idx = unm_idx, src_idx, tar_idx + # Restored src tokens take value from dst tokens + src = gather(dst, dim=-2, index=t_idx.expand(-1, -1, c)) + # Combine back to the original shape + out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype) + # Scatter dst tokens + out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst) + # Scatter unmerged tokens + out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), + dim=1, index=u_idx).expand(-1, -1, c), src=unm) + # Scatter src tokens + out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), + dim=1, index=s_idx).expand(-1, -1, c), src=src) + + return out + + # Return number of tokens not merged. + ret_dict = {"scores": scores, "unm_num": unm_idx.shape[1] if unm_idx.shape[1] is not None else 0} + return merge, unmerge, ret_dict + + +def bipartite_soft_matching_random2d_hier(metric: torch.Tensor, frame_num: int, ratio: float, unm_pre: int, generator: torch.Generator, target_stride: int = 4, adhere_src: bool = False, merge_mode: str = "replace", scores = None, coord = None, rec_field = 2) -> Tuple[Callable, Callable]: + """ + Partitions the tokens into src and dst and merges r tokens from src to dst. + Dst tokens are partitioned by choosing one randomy in each (sx, sy) region. + + Args: + - metric [B, N, C]: metric to use for similarity + - w: image width in tokens + - h: image height in tokens + - sx: stride in the x dimension for dst, must divide w + - sy: stride in the y dimension for dst, must divide h + - r: number of tokens to remove (by merging) + - no_rand: if true, disable randomness (use top left corner only) + - rand_seed: if no_rand is false, and if not None, sets random seed. + """ + B, N, _ = metric.shape + F = frame_num + nf = (N - unm_pre) // F + + if ratio <= 0: + return do_nothing, do_nothing + + gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather + + with torch.no_grad(): + + + # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead + idx_buffer = torch.arange(N - unm_pre, device=metric.device, dtype=torch.int64) + + + # randn = torch.randint(0, F, torch.Size([nf])).to(idx_buffer) * nf + # dst_indexes = torch.arange(nf, device=metric.device, dtype=torch.int64) + randn + # dst_select = torch.zeros_like(idx_buffer).to(torch.bool) + # dst_select[dst_indexes] = 1 + max_f = min(target_stride, F) + randn = torch.randint(0, max_f, torch.Size([1]), generator=generator, device = generator.device) + # randn = 0 + dst_select = ((torch.div(idx_buffer, nf, rounding_mode='floor')) % max_f == randn).to(torch.bool) + # dst_select = ((idx_buffer // nf) == 0).to(torch.bool) + a_idx = idx_buffer[None, ~dst_select, None] + unm_pre + b_idx = idx_buffer[None, dst_select, None] + unm_pre + + unm_buffer = torch.arange(unm_pre, device=metric.device, dtype=torch.int64)[None,:,None] + b_idx = torch.cat([b_idx, unm_buffer], dim = 1) + + # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices + + # We're finished with these + del idx_buffer, unm_buffer + + num_dst = b_idx.shape[1] + + def split(x): + b, n, c = x.shape + src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c)) + dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c)) + return src, dst + + def split_coord(coord): + b, n, c = coord.shape + src = gather(coord, dim=1, index=a_idx.expand(b, n - num_dst, c)) + dst = gather(coord, dim=1, index=b_idx.expand(b, num_dst, c)) + return src, dst + + + # Cosine similarity between A and B + metric = metric / metric.norm(dim=-1, keepdim=True) + a, b = split(metric) + + + if coord is not None: + src_coord, dst_coord = split_coord(coord) + mask = torch.norm(src_coord[:,:,None,:] - dst_coord[:,None,:,:], dim=-1) > rec_field + + + scores = a @ b.transpose(-1, -2) + + if coord is not None: + scores[mask] = 0 + + # Can't reduce more than the # tokens in src + r = int(a.shape[1] * ratio) + r = min(a.shape[1], r) + + + + if adhere_src: + # scores = torch.sum(scores, dim=0) + scores = torch.cat([*scores], dim = -1) + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) % num_dst + + unm_idx = unm_idx.expand(B, -1, -1) + src_idx = src_idx.expand(B, -1, -1) + dst_idx = dst_idx.expand(B, -1, -1) + else: + # scores = torch.cat([*scores][1:], dim = -1) + # node_max, node_idx = scores.max(dim=-1) + # edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + # unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + # src_idx = edge_idx[..., :r, :] # Merged Tokens + # dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) % num_dst + + # unm_idx = unm_idx.expand(B, -1, -1) + # src_idx = src_idx.expand(B, -1, -1) + # dst_idx = dst_idx.expand(B, -1, -1) + + + # Find the most similar greedily + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) + + # if adhere_src: + # unm_idx[:,...] = unm_idx[0:1] + # src_idx[:,...] = src_idx[0:1] + # dst_idx[:,...] = dst_idx[0:1] + + def merge(x: torch.Tensor, mode=None, b_select = None, **kwarg) -> torch.Tensor: + src, dst = split(x) + n, t1, c = src.shape + if b_select is not None: + if not isinstance(b_select, list): + b_select = [b_select] + u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select] + else: + u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx + + unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c)) + src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c)) + mode = mode if mode is not None else merge_mode + if mode != "replace": + dst = dst.scatter_reduce(-2, d_idx.expand(-1, -1, c), src, reduce=mode, include_self=True) + # dst = dst.scatter(-2, dst_idx.expand(n, r, c), src, reduce='add') + + # dst_cnt = torch.ones_like(dst) + # src_ones = torch.ones_like(src) + # dst_cnt = dst_cnt.scatter(-2, dst_idx.expand(n, r, c), src_ones, reduce='add') + + # dst = dst / dst_cnt + # dst2 = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode, include_self=True) + # assert torch.allclose(dst1, dst2) + + return torch.cat([unm, dst], dim=1) + + def unmerge(x: torch.Tensor, b_select = None, unm_modi = None, **kwarg) -> torch.Tensor: + unm_len = unm_idx.shape[1] + unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] + b, _, c = unm.shape + if b_select is not None: + if not isinstance(b_select, list): + b_select = [b_select] + u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select] + else: + u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx + if unm_modi is not None: + if unm_modi == "zero": + unm = torch.zeros_like(unm) + src = gather(dst, dim=-2, index=d_idx.expand(-1, -1, c)) + + # Combine back to the original shape + out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype) + out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst) + out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=u_idx).expand(-1, -1, c), src=unm) + out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=s_idx).expand(-1, -1, c), src=src) + + return out + + ret_dict = {"unm_num": unm_idx.shape[1]} + return merge, unmerge, ret_dict + +# For Global Token Merging. +def bipartite_soft_matching_2s( metric: torch.Tensor, + src_len: int, ratio: float, align_batch: bool, + merge_mode: str = "replace", unmerge_chunk: int = 0) -> Tuple[Callable, Callable, dict]: + """ + Partitions the tokens into src and dst and merges ratio of src tokens from src to dst. + Src tokens are partitioned as first src_len tokens. Others are dst tokens. + + Args: + - metric [B, N, C]: metric to use for similarity. + - src_len: src token length. [ src | dst ]: [ src_len | N - src_len ] + - ratio: ratio of src tokens to be removed (by merging). + - unm_pre: number of src tokens not merged at previous ToMe. Pre-sequence: [unm_pre|F_0|F_1|...] + - align_batch: whether to align similarity matching maps of samples in the batch. True when using PnP. + - merge_mode: how to merge tokens. "mean": tokens -> Mean(src_token, dst_token); "replace": tokens -> dst_token. + - unmerge_chunk: return which partition in unmerge. 0 for src and 1 for dst. + + Returns: + Merge and unmerge operation according to the matching result. Return a dict including other values. + """ + B, N, _ = metric.shape + + if ratio <= 0: + return do_nothing, do_nothing + + gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather + + with torch.no_grad(): + + idx_buffer = torch.arange(N, device=metric.device, dtype=torch.int64) + + # [ src | dst ]: [ src_len | N - src_len ] + a_idx = idx_buffer[None, :src_len, None] + b_idx = idx_buffer[None, src_len:, None] + + del idx_buffer + + num_dst = b_idx.shape[1] + # import ipdb; ipdb.set_trace() + def split(x): + # Split src, dst tokens + b, n, c = x.shape + # print(f"[INFO] {num_dst} {x.shape} ") + src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c)) + dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c)) + return src, dst + + # Cosine similarity between src and dst tokens + metric = metric / metric.norm(dim=-1, keepdim=True) + a, b = split(metric) + + scores = a @ b.transpose(-1, -2) + + # Can't reduce more than the # tokens in src + r = min(a.shape[1], int(a.shape[1] * ratio)) + + if align_batch: + # Cat scores of all samples in the batch. When using PnP, samples are (src, neg, pos). + # Find the most similar greedily among all samples. + scores = torch.cat([*scores], dim=-1) + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = gather(node_idx[..., None], + dim=-2, index=src_idx) % num_dst # Map index to (0, num_dst - 1) + + # Use the same matching result for all samples + unm_idx = unm_idx.expand(B, -1, -1) + src_idx = src_idx.expand(B, -1, -1) + dst_idx = dst_idx.expand(B, -1, -1) + else: + + # Find the most similar greedily + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) + + def merge(x: torch.Tensor, mode=None) -> torch.Tensor: + # Merge tokens according to matching result. + # import ipdb; ipdb.set_trace() + src, dst = split(x) + n, t1, c = src.shape + u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx + + unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c)) + mode = mode if mode is not None else merge_mode + if mode != "replace": + src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c)) + # In other mode such as mean, combine matched src and dst tokens. + dst = dst.scatter_reduce(-2, d_idx.expand(-1, -1, c), + src, reduce=mode, include_self=True) + # In replace mode, just cat unmerged tokens and dst tokens. Discard src tokens. + return torch.cat([unm, dst], dim=1) + + def unmerge(x: torch.Tensor, **kwarg) -> torch.Tensor: + # Unmerge tokens to original size according to matching result. + unm_len = unm_idx.shape[1] + unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] + b, _, c = unm.shape + u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx + # Restored src tokens take value from dst tokens + src = gather(dst, dim=-2, index=d_idx.expand(-1, -1, c)) + + # Combine back to the original shape + out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype) + # Scatter dst tokens + out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst) + # Scatter unmerged tokens + out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), + dim=1, index=u_idx).expand(-1, -1, c), src=unm) + # Scatter src tokens + out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), + dim=1, index=s_idx).expand(-1, -1, c), src=src) + + out = out[:, :src_len, :] if unmerge_chunk == 0 else out[:, src_len:, :] + return out + + ret_dict = {"unm_num": unm_idx.shape[1]} + return merge, unmerge, ret_dict + + +# Original ToMe +def bipartite_soft_matching_random2d(metric: torch.Tensor, + w: int, h: int, sx: int, sy: int, r: int, + no_rand: bool = False, + generator: torch.Generator = None) -> Tuple[Callable, Callable]: + """ + Partitions the tokens into src and dst and merges r tokens from src to dst. + Dst tokens are partitioned by choosing one randomy in each (sx, sy) region. + + Args: + - metric [B, N, C]: metric to use for similarity + - w: image width in tokens + - h: image height in tokens + - sx: stride in the x dimension for dst, must divide w + - sy: stride in the y dimension for dst, must divide h + - r: number of tokens to remove (by merging) + - no_rand: if true, disable randomness (use top left corner only) + - rand_seed: if no_rand is false, and if not None, sets random seed. + """ + B, N, _ = metric.shape + + if r <= 0: + return do_nothing, do_nothing + + gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather + + with torch.no_grad(): + hsy, wsx = h // sy, w // sx + + # For each sy by sx kernel, randomly assign one token to be dst and the rest src + if no_rand: + rand_idx = torch.zeros( + hsy, wsx, 1, device=metric.device, dtype=torch.int64) + else: + rand_idx = torch.randint( + sy*sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(metric.device) + + # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead + idx_buffer_view = torch.zeros( + hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64) + idx_buffer_view.scatter_( + dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype)) + idx_buffer_view = idx_buffer_view.view( + hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx) + + # Image is not divisible by sx or sy so we need to move it into a new buffer + if (hsy * sy) < h or (wsx * sx) < w: + idx_buffer = torch.zeros( + h, w, device=metric.device, dtype=torch.int64) + idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view + else: + idx_buffer = idx_buffer_view + + # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices + rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1) + + # We're finished with these + del idx_buffer, idx_buffer_view + + # rand_idx is currently dst|src, so split them + num_dst = hsy * wsx + a_idx = rand_idx[:, num_dst:, :] # src + b_idx = rand_idx[:, :num_dst, :] # dst + + def split(x): + C = x.shape[-1] + src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C)) + dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C)) + return src, dst + + # Cosine similarity between A and B + metric = metric / metric.norm(dim=-1, keepdim=True) + a, b = split(metric) + scores = a @ b.transpose(-1, -2) + + # Can't reduce more than the # tokens in src + r = min(a.shape[1], r) + + # Find the most similar greedily + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) + + def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: + src, dst = split(x) + n, t1, c = src.shape + + unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c)) + src = gather(src, dim=-2, index=src_idx.expand(n, r, c)) + dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode) + + return torch.cat([unm, dst], dim=1) + + def unmerge(x: torch.Tensor) -> torch.Tensor: + unm_len = unm_idx.shape[1] + unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] + _, _, c = unm.shape + + src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c)) + + # Combine back to the original shape + out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype) + out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst) + out.scatter_(dim=-2, index=gather(a_idx.expand(B, + a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm) + out.scatter_(dim=-2, index=gather(a_idx.expand(B, + a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src) + + return out + + return merge, unmerge + + +def bipartite_soft_matching_2f(metric: torch.Tensor, src_len: int, ratio: float, adhere_src: bool, merge_mode: str = "replace", scores = None, coord = None, rec_field = 2, unmerge_chunk = 0) -> Tuple[Callable, Callable]: + """ + Partitions the tokens into src and dst and merges r tokens from src to dst. + Dst tokens are partitioned by choosing one randomy in each (sx, sy) region. + + Args: + - metric [B, N, C]: metric to use for similarity + - w: image width in tokens + - h: image height in tokens + - sx: stride in the x dimension for dst, must divide w + - sy: stride in the y dimension for dst, must divide h + - r: number of tokens to remove (by merging) + - no_rand: if true, disable randomness (use top left corner only) + - rand_seed: if no_rand is false, and if not None, sets random seed. + """ + B, N, _ = metric.shape + + if ratio <= 0: + return do_nothing, do_nothing + + gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather + + with torch.no_grad(): + + + # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead + idx_buffer = torch.arange(N, device=metric.device, dtype=torch.int64) + + + # randn = torch.randint(0, F, torch.Size([nf])).to(idx_buffer) * nf + # dst_indexes = torch.arange(nf, device=metric.device, dtype=torch.int64) + randn + # dst_select = torch.zeros_like(idx_buffer).to(torch.bool) + # dst_select[dst_indexes] = 1 + # randn = 0 + # dst_select = ((idx_buffer // nf) == 0).to(torch.bool) + a_idx = idx_buffer[None, :src_len, None] + b_idx = idx_buffer[None, src_len:, None] + + + # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices + + # We're finished with these + del idx_buffer + + num_dst = b_idx.shape[1] + + def split(x): + b, n, c = x.shape + src = gather(x, dim=1, index=a_idx.expand(b, n - num_dst, c)) + dst = gather(x, dim=1, index=b_idx.expand(b, num_dst, c)) + return src, dst + + def split_coord(coord): + b, n, c = coord.shape + src = gather(coord, dim=1, index=a_idx.expand(b, n - num_dst, c)) + dst = gather(coord, dim=1, index=b_idx.expand(b, num_dst, c)) + return src, dst + + + # Cosine similarity between A and B + metric = metric / metric.norm(dim=-1, keepdim=True) + a, b = split(metric) + + + if coord is not None: + src_coord, dst_coord = split_coord(coord) + mask = torch.norm(src_coord[:,:,None,:] - dst_coord[:,None,:,:], dim=-1) > rec_field + + + scores = a @ b.transpose(-1, -2) + + if coord is not None: + scores[mask] = 0 + + # Can't reduce more than the # tokens in src + r = int(a.shape[1] * ratio) + r = min(a.shape[1], r) + + + + if adhere_src: + scores = torch.cat([*scores], dim = -1) + # scores = torch.sum(scores, dim=0) + node_max, node_idx = scores.max(dim=-1) + + # nscores = torch.cat([*scores], dim = -2) + # rev_node_max, rev_node_idx = nscores.max(dim = -2) + + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) % num_dst + + unm_idx = unm_idx.expand(B, -1, -1) + src_idx = src_idx.expand(B, -1, -1) + dst_idx = dst_idx.expand(B, -1, -1) + else: + # scores = torch.cat([*scores][1:], dim = -1) + # node_max, node_idx = scores.max(dim=-1) + # edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + # unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + # src_idx = edge_idx[..., :r, :] # Merged Tokens + # dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) % num_dst + + # unm_idx = unm_idx.expand(B, -1, -1) + # src_idx = src_idx.expand(B, -1, -1) + # dst_idx = dst_idx.expand(B, -1, -1) + + + # Find the most similar greedily + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) + + # if adhere_src: + # unm_idx[:,...] = unm_idx[0:1] + # src_idx[:,...] = src_idx[0:1] + # dst_idx[:,...] = dst_idx[0:1] + + def merge(x: torch.Tensor, mode=None, b_select = None) -> torch.Tensor: + + src, dst = split(x) + n, t1, c = src.shape + if b_select is not None: + if not isinstance(b_select, list): + b_select = [b_select] + u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select] + else: + u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx + + unm = gather(src, dim=-2, index=u_idx.expand(-1, -1, c)) + # src = gather(src, dim=-2, index=s_idx.expand(-1, -1, c)) + mode = mode if mode is not None else merge_mode + if mode != "replace": + dst = dst.scatter_reduce(-2, d_idx.expand(-1, -1, c), src, reduce=mode, include_self=True) + # dst = dst.scatter(-2, dst_idx.expand(n, r, c), src, reduce='add') + + # dst_cnt = torch.ones_like(dst) + # src_ones = torch.ones_like(src) + # dst_cnt = dst_cnt.scatter(-2, dst_idx.expand(n, r, c), src_ones, reduce='add') + + # dst = dst / dst_cnt + # dst2 = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode, include_self=True) + # assert torch.allclose(dst1, dst2) + + return torch.cat([unm, dst], dim=1) + + def unmerge(x: torch.Tensor, b_select = None, unm_modi = None) -> torch.Tensor: + + + + unm_len = unm_idx.shape[1] + unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] + b, _, c = unm.shape + if b_select is not None: + if not isinstance(b_select, list): + b_select = [b_select] + u_idx, s_idx, d_idx = unm_idx[b_select], src_idx[b_select], dst_idx[b_select] + else: + u_idx, s_idx, d_idx = unm_idx, src_idx, dst_idx + if unm_modi is not None: + if unm_modi == "zero": + unm = torch.zeros_like(unm) + src = gather(dst, dim=-2, index=d_idx.expand(-1, -1, c)) + + # Combine back to the original shape + out = torch.zeros(b, N, c, device=x.device, dtype=x.dtype) + out.scatter_(dim=-2, index=b_idx.expand(b, -1, c), src=dst) + out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=u_idx).expand(-1, -1, c), src=unm) + out.scatter_(dim=-2, index=gather(a_idx.expand(b, -1, 1), dim=1, index=s_idx).expand(-1, -1, c), src=src) + + + if unmerge_chunk == 0: + out = out[:,:src_len,:] + else: + out = out[:,src_len:,:] + + return out + + ret_dict = {"unm_num": unm_idx.shape[1]} + return merge, unmerge, ret_dict \ No newline at end of file diff --git a/vidtome/patch.py b/vidtome/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..5a5cc4d75283517c92b9c7587fa892cbe6a63594 --- /dev/null +++ b/vidtome/patch.py @@ -0,0 +1,620 @@ +import os +import math +import time +from typing import Type, Dict, Any, Tuple, Callable + +import numpy as np +from einops import rearrange, repeat +import torch +import torch.nn.functional as F + +from . import merge +from .utils import isinstance_str, init_generator, join_frame, split_frame, func_warper, join_warper, split_warper + + +def compute_merge(module: torch.nn.Module, x: torch.Tensor, tome_info: Dict[str, Any]) -> Tuple[Callable, ...]: + H, original_w = tome_info["size"] + # original_tokens = original_h * original_w + # downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1]))) + downsample = tome_info["args"]["downsample"] + + args = tome_info["args"] + # generator = module.generator + + # Frame Number and Token Number + fsize = x.shape[0] // args["batch_size"] + tsize = x.shape[1] + # Merge tokens in high resolution layers + # print(f"[INFO] {args['current_step']} downsample {downsample} time") + mid = x.shape[0] // 2 + + ''' visualize token correspondence ''' + + label = args["label"].split('_') + + # os.makedirs(os.path.join("token_0204_dis", str(args["current_step"])), exist_ok=True) + # os.makedirs(os.path.join("token_0204_dis", str(args["current_step"]), label[0]), exist_ok=True) + # out = os.path.join("token_0204_dis", str(args["current_step"]), label[0], f"correspondence_{label[1]}_{downsample}.png") + # merge.visualize_correspondence(x[0][None], x[mid][None], ratio=0.2, H=H, out=out) + + # corres_dir = "corres_no_dis_4" + # os.makedirs(corres_dir, exist_ok=True) + # if downsample == 1 and label[0] == "unet" and label[1] == "down": + # # merge.visualize_flow_correspondence(x[3][None], x[mid][None], flow=args["controller"].step_store["flows"][3], flow_confid=args["controller"].step_store["flow_confids"][3], \ + # # ratio=0.1, H=(64//downsample), out=f"flow_{label[1]}_{args['current_step']}.png") + # files = os.listdir(corres_dir) + # files = [f for f in files if f.startswith(f"{args['current_step']}")] + # print(files) + # cnt = len(files) + # # if files: + # # cnt = int(files[-1].split('_')[1].split('.')[0]) + 1 + # # else: + # # cnt = 0 + # path = os.path.join(corres_dir, f"{args['current_step']}_{cnt}.png") + # # merge.visualize_cosine_correspondence(x[3][None], x[mid][None], flow=args["controller"].step_store["flows"][3], flow_confid=args["controller"].step_store["flow_confids"][3], \ + # # ratio=0.1, H=(64//downsample), out=path, controller=args["controller"]) + # merge.visualize_correspondence(x[1][None], x[mid][None], flow=args["controller"].step_store["flows"][1], flow_confid=args["controller"].step_store["flow_confids"][1], \ + # ratio=0.1, H=(64//downsample), out=path, controller=args["controller"]) + # #ratio=args["local_merge_ratio"], H=(64//downsample), out=f"flow_correspondence.png") + + ''' visulaize token correspondence ended ''' + + if downsample <= args["max_downsample"] and downsample > args["min_downsample"]: + # print(f"[INFO] downsample: {args['min_downsample']} < {downsample} <= {args['max_downsample']} token shape: {x.shape} H: {H}") + if args["generator"] is None: + args["generator"] = init_generator(x.device) + # module.generator = module.generator.manual_seed(123) + elif args["generator"].device != x.device: + args["generator"] = init_generator(x.device, fallback=args["generator"]) + + # Local Token Merging! + + local_tokens = join_frame(x, fsize) + m_ls = [join_warper(fsize)] + u_ls = [split_warper(fsize)] + unm = 0 + curF = fsize + + # Recursive merge multi-frame tokens into one set. Such as 4->1 for 4 frames and 8->2->1 for 8 frames when target stride is 4. + while curF > 1: + current_step = args["current_step"] + if args["controller"] is not None: + controller, total_step = args["controller"], args["controller"].total_step + else: + controller, total_step = None, 1000 + + if controller is not None and label[0] == "unet" and label[1] == "down": + print(f"[INFO] flow merge @ {label[0]} {label[1]} {downsample}") + start = time.time() + m, u, ret_dict = merge.bipartite_soft_matching_randframe( + local_tokens, curF, args["local_merge_ratio"], unm, generator=args["generator"], + target_stride=x.shape[0], align_batch=args["align_batch"], + H=H, + flow_merge=True, + controller=controller, + ) + else: + m, u, ret_dict = merge.bipartite_soft_matching_randframe( + local_tokens, curF, args["local_merge_ratio"], unm, generator=args["generator"], + target_stride=x.shape[0], align_batch=args["align_batch"], + H=H, + flow_merge=False, + controller=controller, + ) + + # if controller is not None and label[1] == "up" and \ + # controller.merge_period[0] > 0 and \ + # (current_step + 5) >= min(controller.ToMe_period[1], controller.merge_period[0]) * total_step: + # # or current_step == int(controller.ToMe_period[1] * total_step)): + # print(f"[INFO] setting controller merge @ step {current_step} {label} {downsample}") + # # ret_dict["scores"].repeat(1, 4, 4) + # # import time + # # start = time.time() + # scores = ret_dict["scores"] + # if downsample > 1: + # scores = rearrange(scores, "1 (b h1 w1) (h2 w2) -> b h1 w1 h2 w2", h1=H, h2=H, b=fsize-1) + # scores = scores.repeat_interleave(downsample, dim=-1).repeat_interleave(downsample, dim=-2) + # scores = scores.repeat_interleave(downsample, dim=1).repeat_interleave(downsample, dim=2) + # scores = rearrange(scores, "b h1 w1 h2 w2 -> 1 (b h1 w1) (h2 w2)") + # # print(f"[INFO] repeat time {time.time() - start}") + # # import ipdb; ipdb.set_trace() + # # merge.visualize_correspondence_score(x[0][None], x[mid][None], score=ret_dict["scores"][:,:tsize], ratio=0.5, H=H, out="latent_correspondence_1.png") + # # merge.visualize_correspondence_score(x[0][None], x[mid][None], score=controller.step_store["corres_scores"][:,:tsize], ratio=0.5, H=H, out="latent_correspondence_1.png") + # if controller.step_store["corres_scores"] is None: + # controller.step_store["corres_scores"] = scores + # else: + # controller.step_store["corres_scores"] += scores + + + unm += ret_dict["unm_num"] + m_ls.append(m) + u_ls.append(u) + local_tokens = m(local_tokens) + + # assert (x.shape[1] - unm) % tsize == 0 + # Total token number = current frame number * per-frame token number + unmerged token number + curF = (local_tokens.shape[1] - unm) // tsize + # print(f"[INFO] curF {curF}") + + merged_tokens = local_tokens + + # Global Token Merging! + if args["merge_global"]: + if hasattr(module, "global_tokens") and module.global_tokens is not None: + # Merge local tokens with global tokens. Randomly determine merging destination. + if torch.rand(1, generator=args["generator"], device=args["generator"].device) > args["global_rand"]: + src_len = local_tokens.shape[1] + tokens = torch.cat( + [local_tokens, module.global_tokens.to(local_tokens)], dim=1) + local_chunk = 0 + else: + src_len = module.global_tokens.shape[1] + tokens = torch.cat( + [module.global_tokens.to(local_tokens), local_tokens], dim=1) + local_chunk = 1 + m, u, _ = merge.bipartite_soft_matching_2s( + tokens, src_len, args["global_merge_ratio"], args["align_batch"], unmerge_chunk=local_chunk) + merged_tokens = m(tokens) + # print(f"[INFO] global merging {local_tokens.shape} {tokens.shape} {merged_tokens.shape}") + # import ipdb; ipdb.set_trace() + m_ls.append(m) + u_ls.append(u) + + # Update global tokens with unmerged local tokens. There should be a better way to do this. + module.global_tokens = u(merged_tokens).detach().clone().cpu() + else: + module.global_tokens = local_tokens.detach().clone().cpu() + + m = func_warper(m_ls) + u = func_warper(u_ls[::-1]) + else: + m, u = (merge.do_nothing, merge.do_nothing) + merged_tokens = x + + # if args["current_step"] >= 30: + # x_ = u(m(x)) + # print(f"[INFO] {label[0]} {label[1]} {downsample}") + # for i, j in zip(x, x_): + # print(f"[INFO] mean {torch.mean(i).item()} {torch.mean(j).item()}") + # print(f"[INFO] std {torch.std(i).item()} {torch.std(j).item()}") + # import ipdb; ipdb.set_trace() + + # Return merge op, unmerge op, and merged tokens. + return m, u, merged_tokens + +def PCA_token(token: torch.Tensor, token_h=64, n=3): + from sklearn.decomposition import PCA + import cv2 + pca = PCA(n_components=n) # reduce to 2 dimensions + # Fit the PCA model to your data and apply the dimensionality reduction + token = pca.fit_transform(token[0].cpu()) + # import ipdb; ipdb.set_trace() + token = rearrange(token, '(h w) c -> h w c', h=token_h) + token = (token - token.min()) / (token.max() - token.min()) + token = (np.clip(token, 0, 1) * 255).astype(np.uint8) + cv2.imwrite(f'token.png', token) + return token + +from utils.flow_utils import flow_warp +def warp_token(module: torch.nn.Module, x: torch.Tensor, tome_info: Dict[str, Any]) -> Tuple[Callable, ...]: + original_h, original_w = tome_info["size"] + original_tokens = original_h * original_w + downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1]))) + + args = tome_info["args"] + # generator = module.generator + + # Frame Number and Token Number + fsize = x.shape[0] // args["batch_size"] + tsize = x.shape[1] + # print(f"[INFO] token size {x.shape[1]}, latent size 64 x 120, downsample {downsample} time") + # Merge tokens in high resolution layers + total_step = 50 + warp_period = (0, 1) + if downsample <= args["max_downsample"] and x.shape[1] == 64 * 120: + if args["current_step"] >= total_step * warp_period[0] and \ + args["current_step"] <= total_step * warp_period[1]: + mid = x.shape[0] // 2 + x = rearrange(x, 'b (h w) c -> b c h w', h=64) + # import ipdb; ipdb.set_trace() + # mid_x = repeat(x[mid][None], 'b c h w -> (repeat b) c h w', repeat=x.shape[0]) + for i in range(x.shape[0]): + if i == mid: + continue + x[i] = flow_warp(x[mid][None], args["flows"][i], mode='nearest')[0] * args["occlusion_masks"][i] + \ + (1 - args["occlusion_masks"][i]) * x[i] + x = rearrange(x, 'b c h w -> b (h w) c', h=64) + return x + + +def make_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]: + """ + Make a patched class on the fly so we don't have to import any specific modules. + This patch applies ToMe to the forward function of the block. + """ + + class ToMeBlock(block_class): + # Save for unpatching later + _parent = block_class + + def _forward(self, x: torch.Tensor, context: torch.Tensor = None, label: str = None) -> torch.Tensor: + # m_a, m_c, m_m, u_a, u_c, u_m = compute_merge( + # self, x, self._tome_info) + # print(f"[INFO] ~~~ ToMeblock ~~~ {label} ~~~") + + B, A, C = x.shape + original_h, original_w = self._tome_info["size"] + original_tokens = original_h * original_w + downsample = int(math.ceil(math.sqrt(original_tokens // A))) + # print(f"[INFO] downsample {downsample} time A {A} original_h {original_h} original_w {original_w}") + self._tome_info["args"]["downsample"] = downsample + H, W = original_h // downsample, original_w // downsample + + if self._tome_info["args"]["controller"] is None: + non_pad_ratio_h, non_pad_ratio_w = 1, 1 + print(f"[INFO] no padding removal") + else: + non_pad_ratio_h, non_pad_ratio_w = self._tome_info["args"]["controller"].non_pad_ratio + + padding_size_w = W - int(W * non_pad_ratio_w) + padding_size_h = H - int(H * non_pad_ratio_h) + padding_mask = torch.zeros((H, W), device=x.device, dtype=torch.bool) + if padding_size_w: + padding_mask[:, -padding_size_w:] = 1 + if padding_size_h: + padding_mask[-padding_size_h:, :] = 1 + padding_mask = rearrange(padding_mask, 'h w -> (h w)') + + idx_buffer = torch.arange(A, device=x.device, dtype=torch.int64) + non_pad_idx = idx_buffer[None, ~padding_mask, None] + # pad_idx = idx_buffer[None, padding_mask, None] + del idx_buffer, padding_mask + x_non_pad = torch.gather(x, dim=1, index=non_pad_idx.expand(B, -1, C)) + self._tome_info["args"]["label"] = label + self._tome_info["size"] = (int(H * non_pad_ratio_h), int(W * non_pad_ratio_w)) + # self._tome_info["non_pad_size"] = (int(H * non_pad_ratio_h), int(W * non_pad_ratio_w)) + # print(f"[INFO] original shape {x.shape} removed padding shape {x_non_pad.shape}") + m_a, u_a, merged_tokens = compute_merge( + self, self.norm1(x_non_pad), self._tome_info) + # current_step, total_step = self._tome_info["args"]["current_step"], self._tome_info["args"]["controller"].total_step + # print(f'[INFO] {int(self._tome_info["args"]["controller"].ToMe_period[1] * total_step)} {current_step} {total_step}') + # if downsample == 1 and label == "unet_up" and \ + # self._tome_info["args"]["controller"].merge_period[0] > 0 and \ + # (current_step >= self._tome_info["args"]["controller"].merge_period[0] * total_step \ + # or current_step == int(self._tome_info["args"]["controller"].ToMe_period[1] * total_step)): + # print(f"[INFO] setting controller merge @ step {self._tome_info['args']['current_step']}") + # self._tome_info["args"]["controller"].set_merge(m_a, u_a) + # m_a, u_a, merged_tokens = compute_merge( + # self, self.norm1(x), self._tome_info) + # This is where the meat of the computation happens + # test = u_a(self.attn1(m_a(self.norm1(x)), context=context if self.disable_self_attn else None)) + # import ipdb; ipdb.set_trace() + # x = u_a(merged_tokens) + ''' global merging ''' + if self._tome_info["args"]["controller"] is None: + print(f"[INFO] local + global merging ... ") + x_non_pad = u_a(self.attn1(merged_tokens, + context=context if self.disable_self_attn else None)) + x_non_pad + else: + x_non_pad = u_a(self.attn1(m_a(self.norm1(x_non_pad)), + context=context if self.disable_self_attn else None)) + x_non_pad + # print(label, downsample, self.disable_self_attn) + # x = u_a(self.attn1(m_a(self.norm1(x)), + # context=context if self.disable_self_attn else None)) + x + + # attn_output = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + # attn_output = warp_token(self, attn_output, self._tome_info) + # x = attn_output + x + + + # attn_out = self.attn2(self.norm2(x), context=context) + # m_a, u_a, merged_tokens = compute_merge( + # self, attn_out, self._tome_info) + # x = u_a(m_a(attn_out)) + x + + # attn_output = self.attn2(self.norm2(x), context=context) + # attn_output = warp_token(self, attn_output, self._tome_info) + # x = attn_output + x + x_non_pad = self.attn2(self.norm2(x_non_pad), context=context) + x_non_pad + x_non_pad = self.ff(self.norm3(x_non_pad)) + x_non_pad + x.scatter_(dim=1, index=non_pad_idx.expand(B, -1, C), src=x_non_pad) + del x_non_pad + self._tome_info["size"] = (original_h, original_w) + torch.cuda.empty_cache() + # x = self.attn2(self.norm2(x), context=context) + x + # x = self.ff(self.norm3(x)) + x + + return x + + return ToMeBlock + + +def make_diffusers_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]: + """ + Make a patched class for a diffusers model. + This patch applies ToMe to the forward function of the block. + """ + class ToMeBlock(block_class): + # Save for unpatching later + _parent = block_class + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ) -> torch.Tensor: + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # Merge input tokens + + m_a, u_a, merged_tokens = compute_merge( + self, norm_hidden_states, self._tome_info) + norm_hidden_states = merged_tokens + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + # tt = time.time() + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + # print(time.time() - tt) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + + # Unmerge output tokens + attn_output = u_a(attn_output) + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2( + hidden_states) + ) + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * \ + (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + return ToMeBlock + + +def hook_tome_model(model: torch.nn.Module): + """ Adds a forward pre hook to get the image size. This hook can be removed with remove_patch. """ + def hook(module, args): + # print(args[0].shape) + module._tome_info["size"] = (args[0].shape[2], args[0].shape[3]) + return None + model._tome_info["hooks"].append(model.register_forward_pre_hook(hook)) + + +def hook_tome_module(module: torch.nn.Module): + """ Adds a forward pre hook to initialize random number generator. + All modules share the same generator state to keep their randomness in VidToMe consistent in one pass. + This hook can be removed with remove_patch. """ + def hook(module, args): + # import ipdb; ipdb.set_trace() + if not hasattr(module, "generator"): + module.generator = init_generator(args[0].device) + elif module.generator.device != args[0].device: + module.generator = init_generator( + args[0].device, fallback=module.generator) + else: + return None + + # module.generator = module.generator.manual_seed(module._tome_info["args"]["seed"]) + return None + + module._tome_info["hooks"].append(module.register_forward_pre_hook(hook)) + + +def apply_patch( + model: torch.nn.Module, + local_merge_ratio: float = 0.9, + merge_global: bool = False, + global_merge_ratio = 0.8, + max_downsample: int = 2, + min_downsample: int = 0, + seed: int = 123, + batch_size: int = 2, + include_control: bool = False, + align_batch: bool = False, + target_stride: int = 4, + global_rand=0.5): + """ + Patches a stable diffusion model with VidToMe. + Apply this to the highest level stable diffusion object (i.e., it should have a .model.diffusion_model). + + Important Args: + - model: A top level Stable Diffusion module to patch in place. Should have a ".model.diffusion_model" + - local_merge_ratio: The ratio of tokens to merge locally. I.e., 0.9 would merge 90% src tokens. + If there are 4 frames in a chunk (3 src, 1 dst), the compression ratio will be 1.3 / 4.0. + And the largest compression ratio is 0.25 (when local_merge_ratio = 1.0). + Higher values result in more consistency, but with more visual quality loss. + - merge_global: Whether or not to include global token merging. + - global_merge_ratio: The ratio of tokens to merge locally. I.e., 0.8 would merge 80% src tokens. + When find significant degradation in video quality. Try to lower the value. + + Args to tinker with if you want: + - max_downsample [1, 2, 4, or 8]: Apply VidToMe to layers with at most this amount of downsampling. + E.g., 1 only applies to layers with no downsampling (4/15) while + 8 applies to all layers (15/15). I recommend a value of 1 or 2. + - seed: Manual random seed. + - batch_size: Video batch size. Number of video chunks in one pass. When processing one video, it + should be 2 (cond + uncond) or 3 (when using PnP, source + cond + uncond). + - include_control: Whether or not to patch ControlNet model. + - align_batch: Whether or not to align similarity matching maps of samples in the batch. It should + be True when using PnP as control. + - target_stride: Stride between target frames. I.e., when target_stride = 4, there is 1 target frame + in any 4 consecutive frames. + - global_rand: Probability in global token merging src/dst split. Global tokens are always src when + global_rand = 1.0 and always dst when global_rand = 0.0 . + """ + + # Make sure the module is not currently patched + remove_patch(model) + + is_diffusers = isinstance_str( + model, "DiffusionPipeline") or isinstance_str(model, "ModelMixin") + + if not is_diffusers: + if (not hasattr(model, "model") or not hasattr(model.model, "diffusion_model")) \ + and not hasattr(model, "unet"): + # Provided model not supported + raise RuntimeError( + "Provided model was not a Stable Diffusion / Latent Diffusion model, as expected.") + else: + diffusion_model = model.unet if hasattr(model, "unet") else model.model.diffusion_model + else: + # Supports "pipe.unet" and "unet" + diffusion_model = model.unet if hasattr(model, "unet") else model + + if isinstance_str(model, "StableDiffusionControlNetPipeline") and include_control: + diffusion_models = [diffusion_model, model.controlnet] + else: + diffusion_models = [diffusion_model] + + if not is_diffusers and hasattr(model, "controlnet"): + diffusion_models = [diffusion_model, model.controlnet] + + for diffusion_model in diffusion_models: + diffusion_model._tome_info = { + "size": None, + "hooks": [], + "args": { + "max_downsample": max_downsample, + "min_downsample": min_downsample, + "generator": None, + "seed": seed, + "batch_size": batch_size, + "align_batch": align_batch, + "merge_global": merge_global, + "global_merge_ratio": global_merge_ratio, + "local_merge_ratio": local_merge_ratio, + "global_rand": global_rand, + "target_stride": target_stride, + "current_step": 0, + "frame_ids": [0], + "flows": None, + "occlusion_masks": None, + "flow_confids": None, + "label": "", + "downsample": 1, + "non_pad_size": (0, 0), + "controller": None, + } + } + hook_tome_model(diffusion_model) + + for name, module in diffusion_model.named_modules(): + # If for some reason this has a different name, create an issue and I'll fix it + # if isinstance_str(module, "BasicTransformerBlock") and "down_blocks" not in name: + # print(module.__class__.__name__) + if isinstance_str(module, "BasicTransformerBlock"): + make_tome_block_fn = make_diffusers_tome_block if is_diffusers else make_tome_block + module.__class__ = make_tome_block_fn(module.__class__) + module._tome_info = diffusion_model._tome_info + hook_tome_module(module) + + # Something introduced in SD 2.0 (LDM only) + if not hasattr(module, "disable_self_attn") and not is_diffusers: + module.disable_self_attn = False + + # Something needed for older versions of diffusers + if not hasattr(module, "use_ada_layer_norm_zero") and is_diffusers: + module.use_ada_layer_norm = False + module.use_ada_layer_norm_zero = False + # import ipdb; ipdb.set_trace() + return model + + +def remove_patch(model: torch.nn.Module): + """ Removes a patch from a ToMe Diffusion module if it was already patched. """ + # For diffusers + modelu = model.unet if hasattr(model, "unet") else model + model_ls = [modelu] + if hasattr(model, "controlnet"): + model_ls.append(model.controlnet) + for model in model_ls: + for _, module in model.named_modules(): + if hasattr(module, "_tome_info"): + for hook in module._tome_info["hooks"]: + hook.remove() + module._tome_info["hooks"].clear() + + if module.__class__.__name__ == "ToMeBlock": + module.__class__ = module._parent + + return model + + +def update_patch(model: torch.nn.Module, **kwargs): + """ Update arguments in patched modules """ + # For diffusers + model0 = model.unet if hasattr(model, "unet") else model + model_ls = [model0] + if hasattr(model, "controlnet"): + model_ls.append(model.controlnet) + for model in model_ls: + for _, module in model.named_modules(): + if hasattr(module, "_tome_info"): + for k, v in kwargs.items(): + # setattr(module, k, v) + if k in module._tome_info["args"]: + module._tome_info["args"][k] = v + # print(f"[INFO] update {k} to {v} in {module.__class__.__name__}") + return model + + +def collect_from_patch(model: torch.nn.Module, attr="tome"): + """ Collect attributes in patched modules """ + # For diffusers + model0 = model.unet if hasattr(model, "unet") else model + model_ls = [model0] + if hasattr(model, "controlnet"): + model_ls.append(model.controlnet) + ret_dict = dict() + for model in model_ls: + for name, module in model.named_modules(): + if hasattr(module, attr): + res = getattr(module, attr) + ret_dict[name] = res + + return ret_dict diff --git a/vidtome/utils.py b/vidtome/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c084454b292891440332fbe4fdd6becb0679e2c9 --- /dev/null +++ b/vidtome/utils.py @@ -0,0 +1,59 @@ +import torch +from einops import rearrange + +def isinstance_str(x: object, cls_name: str): + """ + Checks whether x has any class *named* cls_name in its ancestry. + Doesn't require access to the class's implementation. + + Useful for patching! + """ + for _cls in x.__class__.__mro__: + if _cls.__name__ == cls_name: + return True + + return False + +def init_generator(device: torch.device, fallback: torch.Generator=None): + """ + Forks the current default random generator given device. + """ + if device.type == "cpu": + return torch.Generator(device="cpu").set_state(torch.get_rng_state()) + elif device.type == "cuda": + return torch.Generator(device=device).set_state(torch.cuda.get_rng_state()) + else: + if fallback is None: + return init_generator(torch.device("cpu")) + else: + return fallback + +def join_frame(x, fsize): + """ Join multi-frame tokens """ + x = rearrange(x, "(B F) N C -> B (F N) C", F=fsize) + return x + +def split_frame(x, fsize): + """ Split multi-frame tokens """ + x = rearrange(x, "B (F N) C -> (B F) N C", F=fsize) + return x + +def func_warper(funcs): + """ Warp a function sequence """ + def fn(x, **kwarg): + for func in funcs: + x = func(x, **kwarg) + return x + return fn + +def join_warper(fsize): + def fn(x): + x = join_frame(x, fsize) + return x + return fn + +def split_warper(fsize): + def fn(x): + x = split_frame(x, fsize) + return x + return fn diff --git a/weights/BSRNet.pth b/weights/BSRNet.pth new file mode 100644 index 0000000000000000000000000000000000000000..795a44e8fd4d64612dbf5faff11a5918f5424210 --- /dev/null +++ b/weights/BSRNet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa633d80ff4db5a546740ae8e3baebe925fc84d3c03e3e9002493acc5e88c3ec +size 67046751 diff --git a/weights/gmflow_sintel-0c07dcb3.pth b/weights/gmflow_sintel-0c07dcb3.pth new file mode 100644 index 0000000000000000000000000000000000000000..206c914272a1e515b0881350eab262ba36ce07ef --- /dev/null +++ b/weights/gmflow_sintel-0c07dcb3.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c07dcb35770464f38a5ff4de18c04177b242dc5de8cd2068adf46f3d4fe193a +size 18768907 diff --git a/weights/scunet_color_real_psnr.pth b/weights/scunet_color_real_psnr.pth new file mode 100644 index 0000000000000000000000000000000000000000..5642477ee11c1b655656dd377e1764e02dbba990 --- /dev/null +++ b/weights/scunet_color_real_psnr.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa78899ba2caec9d235a900e91d96c689da71c42029230c2028b00f09f809c2e +size 71982841 diff --git a/weights/v2-1_512-ema-pruned.ckpt b/weights/v2-1_512-ema-pruned.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..6b3b1390add5046368c3e1229fde7bffc6532323 --- /dev/null +++ b/weights/v2-1_512-ema-pruned.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:88ecb782561455673c4b78d05093494b9c539fc6bfc08f3a9a4a0dd7b0b10f36 +size 5214865159 diff --git a/weights/v2.pth b/weights/v2.pth new file mode 100644 index 0000000000000000000000000000000000000000..ff27b75c416412ab3a63f2366a9ffcdf764a6435 --- /dev/null +++ b/weights/v2.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c023992d042a4e966a335b084fae7a4e38b3ac3d0df819398c839ff371ebaba5 +size 1452694018