diff --git a/3DTopia/.gitignore b/3DTopia/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0be8b7cbb5e914722cdeb68149788f8ddeab2820 --- /dev/null +++ b/3DTopia/.gitignore @@ -0,0 +1,4 @@ +__pycache__ +checkpoints +results +tmp \ No newline at end of file diff --git a/3DTopia/LICENSE b/3DTopia/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/3DTopia/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/3DTopia/README.md b/3DTopia/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a9db3069cd3c8da9a29aafffb3d43b7644563b93 --- /dev/null +++ b/3DTopia/README.md @@ -0,0 +1,65 @@ +

+ + logo + +

+
+

3DTopia

+ A two-stage text-to-3D generation model. The first stage uses diffusion model to quickly generate candidates. The second stage refines the assets chosen from the first stage. + +https://github.com/3DTopia/3DTopia/assets/23376858/c9716cf0-6e61-4983-82b2-2e8f579bd46c + +
+ +## News + +[2024/01/18] We release a text-to-3D model 3DTopia! + +## 1. Quick Start + +### 1.1 Install Environment for this Repository +We recommend using Anaconda to manage the environment. +```bash +conda env create -f environment.yml +``` + +### 1.2 Install Second Stage Refiner +Please refer to [threefiner](https://github.com/3DTopia/threefiner) to install our second stage mesh refiner. We have tested installing both environments together with Pytorch 1.12.0 and CUDA 11.3. + +### 1.3 Download Checkpoints \[Optional\] +We have implemented automatic checkpoint download for both `gradio_demo.py` and `sample_stage1.py`. If you prefer to download manually, you may download checkpoint `3dtopia_diffusion_state_dict.ckpt` or `model.safetensors` from [huggingface](https://huggingface.co/hongfz16/3DTopia). + +### Q&A +- If you encounter this error in the second stage `ImportError: /lib64/libc.so.6: version 'GLIBC_2.25' not found`, try to install a lower version of pymeshlab by `pip install pymeshlab==0.2`. + +## 2. Inference + +### 2.1 First Stage +Run the following command to sample `a robot` as the first stage. Results will be located under the folder `results`. +```bash +python -u sample_stage1.py --text "a robot" --samples 1 --sampler ddim --steps 200 --cfg_scale 7.5 --seed 0 +``` + +Arguments: +- `--ckpt` specifies checkpoint file path; +- `--test_folder` controls which subfolder to put all the results; +- `--seed` will fix random seeds; `--sampler` can be set to `ddim` for DDIM sampling (By default, we use 1000 steps DDPM sampling); +- `--steps` controls sampling steps only for DDIM; +- `--samples` controls number of samples; +- `--text` is the input text; +- `--no_video` and `--no_mcubes` suppress rendering multi-view videos and marching cubes, which are by-default enabled; +- `--mcubes_res` controls the resolution of the 3D volumn sampled for marching cubes; One can lower this resolution to save graphics memory; +- `--render_res` controls the resolution of the rendered video; + +### 2.2 Second Stage +There are two steps as the second stage refinement. Here is a simple example. Please refer to [threefiner](https://github.com/3DTopia/threefiner) for more detailed usage. +```bash +# step 1 +threefiner sd --mesh results/default/stage1/a_robot_0_0.ply --prompt "a robot" --text_dir --front_dir='-y' --outdir results/default/stage2/ --save a_robot_0_0_sd.glb +# step 2 +threefiner if2 --mesh results/default/stage2/a_robot_0_0_sd.glb --prompt "a robot" --outdir results/default/stage2/ --save a_robot_0_0_if2.glb +``` +The resulting mesh can be found at `results/default/stage2/a_robot_0_0_if2.glb` + +## 3. Acknowledgement +We thank the community for building and open-sourcing the foundation of this work. Specifically, we want to thank [EG3D](https://github.com/NVlabs/eg3d), [Stable Diffusion](https://github.com/CompVis/stable-diffusion) for their codes. We also want to thank [Objaverse](https://objaverse.allenai.org) for the wonderful dataset. diff --git a/3DTopia/assets/3dtopia.jpeg b/3DTopia/assets/3dtopia.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..b4c1f6b718da9519547a2ce7fc766c7ad29247b9 Binary files /dev/null and b/3DTopia/assets/3dtopia.jpeg differ diff --git a/3DTopia/assets/sample_data/pose/000000.txt b/3DTopia/assets/sample_data/pose/000000.txt new file mode 100644 index 0000000000000000000000000000000000000000..f06623177e8b5f19d1fb96b5b0d0441ae6048f4f --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000000.txt @@ -0,0 +1 @@ +-0.8414713144302368 -0.5386366844177246 -0.04239124804735184 0.05086996778845787 3.72529200376448e-07 -0.07845887541770935 0.9969174861907959 -1.1963008642196655 -0.5403022766113281 0.838877260684967 0.06602128595113754 -0.07922526448965073 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000001.txt b/3DTopia/assets/sample_data/pose/000001.txt new file mode 100644 index 0000000000000000000000000000000000000000..eb2e88460ba96961a3509c60a3d1bc363cf61731 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000001.txt @@ -0,0 +1 @@ +-0.9320390224456787 -0.3607495129108429 -0.03410102799534798 0.04092103987932205 -2.0861622829215776e-07 -0.0941082164645195 0.9955618977546692 -1.1946742534637451 -0.3623576760292053 0.9279026389122009 0.08771242946386337 -0.105255126953125 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000002.txt b/3DTopia/assets/sample_data/pose/000002.txt new file mode 100644 index 0000000000000000000000000000000000000000..70a984ac99d014a045401e25d51e90476bb2468b --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000002.txt @@ -0,0 +1 @@ +-0.9854495525360107 -0.1689407229423523 -0.0186510868370533 0.022381475195288658 1.5646213569198153e-07 -0.10973447561264038 0.9939608573913574 -1.1927531957626343 -0.1699671447277069 0.9794984459877014 0.10813764482736588 -0.12976518273353577 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000003.txt b/3DTopia/assets/sample_data/pose/000003.txt new file mode 100644 index 0000000000000000000000000000000000000000..93692ad3e17af9bca69cbe8be8ab2e7b89d93fae --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000003.txt @@ -0,0 +1 @@ +-0.9995737075805664 0.028960729017853737 0.0036580152809619904 -0.0043916041031479836 -5.648469141306123e-07 -0.12533310055732727 0.9921146631240845 -1.1905378103256226 0.02919083461165428 0.9916918873786926 0.1252794861793518 -0.15033574402332306 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000004.txt b/3DTopia/assets/sample_data/pose/000004.txt new file mode 100644 index 0000000000000000000000000000000000000000..dd0b6b5c01faf4ac7bb4d8f549f990f61da59387 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000004.txt @@ -0,0 +1 @@ +-0.973847508430481 0.2249356210231781 0.03201328590512276 -0.03841566666960716 1.5646217832454568e-07 -0.14090147614479065 0.9900236129760742 -1.188028335571289 0.22720229625701904 0.9641320705413818 0.1372164487838745 -0.16465960443019867 -0.0 -0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000005.txt b/3DTopia/assets/sample_data/pose/000005.txt new file mode 100644 index 0000000000000000000000000000000000000000..e00dfcbb2a62f6f5f93b195e4c7579b7eae30a1d --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000005.txt @@ -0,0 +1 @@ +-0.9092977046966553 0.4110230505466461 0.06509938091039658 -0.078119657933712 -3.83704957584996e-07 -0.15643461048603058 0.987688422203064 -1.1852262020111084 0.416146457195282 0.8981026411056519 0.14224585890769958 -0.17069458961486816 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000006.txt b/3DTopia/assets/sample_data/pose/000006.txt new file mode 100644 index 0000000000000000000000000000000000000000..86bbbaef592616de66c5adaaafdad35054f7f257 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000006.txt @@ -0,0 +1 @@ +-0.8084962368011475 0.5797380208969116 0.1011803075671196 -0.12141657620668411 -2.2351736461700966e-08 -0.17192888259887695 0.9851093292236328 -1.1821314096450806 0.5885012149810791 0.7964572906494141 0.13900375366210938 -0.16680487990379333 -0.0 -0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000007.txt b/3DTopia/assets/sample_data/pose/000007.txt new file mode 100644 index 0000000000000000000000000000000000000000..f89eba6a7e6ad8112a94d48cbb0c89b1089aad2f --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000007.txt @@ -0,0 +1 @@ +-0.6754631996154785 0.7243322134017944 0.13817371428012848 -0.1658085733652115 -1.6391271628890536e-07 -0.18738147616386414 0.9822871685028076 -1.1787446737289429 0.7373935580253601 0.6634989380836487 0.1265692412853241 -0.15188303589820862 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000008.txt b/3DTopia/assets/sample_data/pose/000008.txt new file mode 100644 index 0000000000000000000000000000000000000000..21c17dfd0c3c174320d8f8cdc59cd1ca04f3f5fc --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000008.txt @@ -0,0 +1 @@ +-0.5155014991760254 0.8390849828720093 0.17376619577407837 -0.20851942896842957 1.2665987014770508e-07 -0.20278730988502502 0.9792228937149048 -1.175067663192749 0.8568887114524841 0.5047908425331116 0.1045369878411293 -0.12544460594654083 -0.0 -0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000009.txt b/3DTopia/assets/sample_data/pose/000009.txt new file mode 100644 index 0000000000000000000000000000000000000000..ed0c8b01f0cec6488b4dd0d164cf5ef1db14edcd --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000009.txt @@ -0,0 +1 @@ +-0.3349881172180176 0.91953045129776 0.20553946495056152 -0.24664731323719025 -1.0430810704065152e-07 -0.21814337372779846 0.9759166836738586 -1.1710999011993408 0.9422222375869751 0.3269205689430237 0.07307547330856323 -0.08769046515226364 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000010.txt b/3DTopia/assets/sample_data/pose/000010.txt new file mode 100644 index 0000000000000000000000000000000000000000..89bb8d731884d8d45904d6cc2c7fb16faecf5516 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000010.txt @@ -0,0 +1 @@ +-0.1411200612783432 0.9626389741897583 0.23110917210578918 -0.2773309648036957 -1.8998981943241233e-07 -0.2334454208612442 0.9723699688911438 -1.1668438911437988 0.9899925589561462 0.13722087442874908 0.03294399753212929 -0.03953259065747261 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000011.txt b/3DTopia/assets/sample_data/pose/000011.txt new file mode 100644 index 0000000000000000000000000000000000000000..fd2e8b687234a44c62961db3c50562677acd04ae --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000011.txt @@ -0,0 +1 @@ +0.05837392061948776 0.9669322967529297 0.24826295673847198 -0.29791900515556335 -1.9557775843281888e-08 -0.2486870288848877 0.9685839414596558 -1.1622999906539917 0.9982947707176208 -0.05654003843665123 -0.014516821131110191 0.01742047443985939 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000012.txt b/3DTopia/assets/sample_data/pose/000012.txt new file mode 100644 index 0000000000000000000000000000000000000000..ec690a4bf437c985b44db3b268164f33fc4a6152 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000012.txt @@ -0,0 +1 @@ +0.2555410861968994 0.9325323700904846 0.255111962556839 -0.3061343729496002 -7.450580596923828e-09 -0.26387304067611694 0.964557409286499 -1.1574687957763672 0.9667981863021851 -0.24648404121398926 -0.06743041425943375 0.08091648668050766 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000013.txt b/3DTopia/assets/sample_data/pose/000013.txt new file mode 100644 index 0000000000000000000000000000000000000000..bdd0f7851359ddbde82bf598055d47fbdcc35225 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000013.txt @@ -0,0 +1 @@ +0.4425203502178192 0.861151397228241 0.25018739700317383 -0.300225168466568 -2.458690460116486e-07 -0.2789909243583679 0.9602935910224915 -1.1523523330688477 0.8967583179473877 -0.42494940757751465 -0.12345901876688004 0.14815115928649902 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000014.txt b/3DTopia/assets/sample_data/pose/000014.txt new file mode 100644 index 0000000000000000000000000000000000000000..782685d881a0e2f7f93f64b9f7fd229064cd0ec0 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000014.txt @@ -0,0 +1 @@ +0.6118577718734741 0.7560015320777893 0.23257626593112946 -0.27909165620803833 2.9802322387695312e-08 -0.29404014348983765 0.955793023109436 -1.146951675415039 0.7909678220748901 -0.584809422492981 -0.1799107939004898 0.21589307487010956 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000015.txt b/3DTopia/assets/sample_data/pose/000015.txt new file mode 100644 index 0000000000000000000000000000000000000000..acbfd3638de2adc6042114baee3656d9d716c045 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000015.txt @@ -0,0 +1 @@ +0.756802499294281 0.6216520071029663 0.2019868791103363 -0.2423844039440155 -1.4901159417490817e-08 -0.3090169131755829 0.9510564804077148 -1.1412678956985474 0.6536435484886169 -0.7197620272636414 -0.23386473953723907 0.28063780069351196 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000016.txt b/3DTopia/assets/sample_data/pose/000016.txt new file mode 100644 index 0000000000000000000000000000000000000000..013a71d628da4714379a1a3461d11c4d7ed897ec --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000016.txt @@ -0,0 +1 @@ +0.8715758323669434 0.46382859349250793 0.15880392491817474 -0.19056479632854462 -8.940693874137651e-08 -0.3239172697067261 0.9460852146148682 -1.1353023052215576 0.4902608096599579 -0.8245849609375 -0.28231847286224365 0.33878225088119507 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000017.txt b/3DTopia/assets/sample_data/pose/000017.txt new file mode 100644 index 0000000000000000000000000000000000000000..953c867619ee156d85a30b81f43c46b727c72908 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000017.txt @@ -0,0 +1 @@ +0.951602041721344 0.2891636788845062 0.10410525649785995 -0.1249263659119606 7.450580596923828e-09 -0.33873775601387024 0.9408808350563049 -1.1290569305419922 0.30733293294906616 -0.8953441381454468 -0.32234352827072144 0.3868124783039093 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000018.txt b/3DTopia/assets/sample_data/pose/000018.txt new file mode 100644 index 0000000000000000000000000000000000000000..e9b756715528be1e34cd4c39a7e33d9026885e1a --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000018.txt @@ -0,0 +1 @@ +0.9936910271644592 0.1049124225974083 0.039643093943595886 -0.04757172241806984 -0.0 -0.35347482562065125 0.9354441165924072 -1.1225329637527466 0.11215253174304962 -0.9295423626899719 -0.3512447774410248 0.421493798494339 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000019.txt b/3DTopia/assets/sample_data/pose/000019.txt new file mode 100644 index 0000000000000000000000000000000000000000..0b9c1a7fda4a3ac75568ab601fa815a06f447ab5 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000019.txt @@ -0,0 +1 @@ +0.9961645603179932 -0.08135451376438141 -0.032210517674684525 0.038652628660202026 1.862645149230957e-09 -0.3681243658065796 0.9297765493392944 -1.1157318353652954 -0.0874989926815033 -0.9262105226516724 -0.3667125105857849 0.4400551915168762 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000020.txt b/3DTopia/assets/sample_data/pose/000020.txt new file mode 100644 index 0000000000000000000000000000000000000000..656c47232cb297949c8b24916d4026845e409ea5 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000020.txt @@ -0,0 +1 @@ +0.9589242935180664 -0.2620696723461151 -0.10855279117822647 0.13026338815689087 1.4901161193847656e-08 -0.38268333673477173 0.9238795638084412 -1.108655333518982 -0.28366219997406006 -0.8859305381774902 -0.36696434020996094 0.4403572976589203 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000021.txt b/3DTopia/assets/sample_data/pose/000021.txt new file mode 100644 index 0000000000000000000000000000000000000000..39cb359918f5d10abbc0d3362dd9a5ecf67ff103 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000021.txt @@ -0,0 +1 @@ +0.8834545612335205 -0.4299834668636322 -0.18607036769390106 0.22328442335128784 4.470348002882929e-08 -0.39714762568473816 0.9177546501159668 -1.101305603981018 -0.4685167372226715 -0.8107945919036865 -0.35086193680763245 0.4210345447063446 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000022.txt b/3DTopia/assets/sample_data/pose/000022.txt new file mode 100644 index 0000000000000000000000000000000000000000..c0094b6672bb6244a2cd51b02c37fc8051b5c90a --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000022.txt @@ -0,0 +1 @@ +0.7727645039558411 -0.578461229801178 -0.2611852288246155 0.3134223520755768 -2.980232949312267e-08 -0.41151440143585205 0.9114034175872803 -1.093684196472168 -0.6346929669380188 -0.7043001651763916 -0.31800374388694763 0.38160452246665955 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000023.txt b/3DTopia/assets/sample_data/pose/000023.txt new file mode 100644 index 0000000000000000000000000000000000000000..6a0887ec9ec4f11fb3d6390abaf7783cecc0ea6f --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000023.txt @@ -0,0 +1 @@ +0.6312665939331055 -0.7017531991004944 -0.33021968603134155 0.3962639272212982 1.4901161193847656e-08 -0.4257790148258209 0.9048272371292114 -1.0857925415039062 -0.775566041469574 -0.5711871385574341 -0.26878002285957336 0.32253631949424744 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000024.txt b/3DTopia/assets/sample_data/pose/000024.txt new file mode 100644 index 0000000000000000000000000000000000000000..687344b34e28ac322c26e376e7c8393ec4e72365 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000024.txt @@ -0,0 +1 @@ +0.46460211277008057 -0.7952209711074829 -0.38957479596138 0.467489629983902 -0.0 -0.4399392306804657 0.8980275988578796 -1.0776331424713135 -0.8855195641517639 -0.4172254800796509 -0.20439667999744415 0.2452760487794876 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000025.txt b/3DTopia/assets/sample_data/pose/000025.txt new file mode 100644 index 0000000000000000000000000000000000000000..388b1e20e526d3a4d0ffc9f326ea1d75c6e7b6c0 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000025.txt @@ -0,0 +1 @@ +0.279415488243103 -0.8555179834365845 -0.43590813875198364 0.5230898261070251 -7.450580596923828e-09 -0.4539904296398163 0.8910065293312073 -1.069207787513733 -0.960170328617096 -0.24896103143692017 -0.12685194611549377 0.1522223800420761 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000026.txt b/3DTopia/assets/sample_data/pose/000026.txt new file mode 100644 index 0000000000000000000000000000000000000000..64b583bbe7e213a2b9dd13c20b4cbf27917572d9 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000026.txt @@ -0,0 +1 @@ +0.08308916538953781 -0.8807101845741272 -0.4663105905056 0.5595741271972656 2.7939665869780583e-07 -0.46792876720428467 0.8837661147117615 -1.060518741607666 -0.9965419769287109 -0.07343138754367828 -0.0388796441257 0.046656012535095215 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000027.txt b/3DTopia/assets/sample_data/pose/000027.txt new file mode 100644 index 0000000000000000000000000000000000000000..6936a913ed3666bd4f705c4445c2e1692ef56484 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000027.txt @@ -0,0 +1 @@ +-0.11654932051897049 -0.8703341484069824 -0.4784710705280304 0.5741645693778992 1.1175869474300271e-07 -0.481754332780838 0.8763062953948975 -1.0515679121017456 -0.9931849241256714 0.1021328940987587 0.05614820867776871 -0.06737759709358215 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000028.txt b/3DTopia/assets/sample_data/pose/000028.txt new file mode 100644 index 0000000000000000000000000000000000000000..d425c87f033bea37fc20adfada19f8afb77895e0 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000028.txt @@ -0,0 +1 @@ +-0.31154143810272217 -0.8254019618034363 -0.4708009660243988 0.5649611949920654 1.4901161193847656e-08 -0.4954586327075958 0.8686314821243286 -1.0423579216003418 -0.9502326250076294 0.27061471343040466 0.15435591340065002 -0.18522702157497406 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000029.txt b/3DTopia/assets/sample_data/pose/000029.txt new file mode 100644 index 0000000000000000000000000000000000000000..7886694a8b39dea814cd563fba9620cf1617381a --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000029.txt @@ -0,0 +1 @@ +-0.494113564491272 -0.7483266592025757 -0.44255948066711426 0.5310712456703186 -1.4901161193847656e-07 -0.5090416669845581 0.860741913318634 -1.0328905582427979 -0.8693974018096924 0.4253043532371521 0.25152426958084106 -0.3018290102481842 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000030.txt b/3DTopia/assets/sample_data/pose/000030.txt new file mode 100644 index 0000000000000000000000000000000000000000..9fc9f1e48c461ab064a28a6a7da8a9682c5d9dbc --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000030.txt @@ -0,0 +1 @@ +-0.6569865345954895 -0.6428073644638062 -0.3939129412174225 0.4726954400539398 1.4901162970204496e-08 -0.522498607635498 0.8526401519775391 -1.0231680870056152 -0.7539023160934448 0.5601730942726135 0.3432745933532715 -0.41192948818206787 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000031.txt b/3DTopia/assets/sample_data/pose/000031.txt new file mode 100644 index 0000000000000000000000000000000000000000..b9be832443b12709ae39f75016318dc78ac5077d --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000031.txt @@ -0,0 +1 @@ +-0.7936679124832153 -0.5136478543281555 -0.3259711265563965 0.39116519689559937 -2.2351741790771484e-07 -0.5358269214630127 0.8443279266357422 -1.0131936073303223 -0.6083512902259827 0.6701160073280334 0.4252684712409973 -0.5103223323822021 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000032.txt b/3DTopia/assets/sample_data/pose/000032.txt new file mode 100644 index 0000000000000000000000000000000000000000..d3c77ddb53b7661d1f25778b6f0925a437178a7b --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000032.txt @@ -0,0 +1 @@ +-0.8987078070640564 -0.36654093861579895 -0.2407723218202591 0.28892695903778076 1.639126594454865e-07 -0.5490229725837708 0.8358070850372314 -1.0029689073562622 -0.43854713439941406 0.7511465549468994 0.49341118335723877 -0.5920935273170471 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000033.txt b/3DTopia/assets/sample_data/pose/000033.txt new file mode 100644 index 0000000000000000000000000000000000000000..13dba8064732a2969c047bd7eee5438ac99303cd --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000033.txt @@ -0,0 +1 @@ +-0.9679195880889893 -0.2078111618757248 -0.14122851192951202 0.1694747358560562 -1.8626440123625798e-07 -0.56208336353302 0.827080488204956 -0.9924965500831604 -0.2512587904930115 0.8005476593971252 0.5440514087677002 -0.6528617739677429 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000034.txt b/3DTopia/assets/sample_data/pose/000034.txt new file mode 100644 index 0000000000000000000000000000000000000000..3d2298ea6b094f1da0aaefb40b2cfdda07b2e08e --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000034.txt @@ -0,0 +1 @@ +-0.9985430240631104 -0.04414258524775505 -0.031028015539050102 0.03722957894206047 -3.3453090964030707e-06 -0.5750053524971008 0.8181495666503906 -0.9817797541618347 -0.053956516087055206 0.8169578313827515 0.5741673707962036 -0.6890010833740234 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000035.txt b/3DTopia/assets/sample_data/pose/000035.txt new file mode 100644 index 0000000000000000000000000000000000000000..f2aa76ce87ec4f9eb23655cd4f18d0b2bde3f1c6 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000035.txt @@ -0,0 +1 @@ +-0.9893580675125122 0.11771200597286224 0.08552265912294388 -0.10262733697891235 -1.043080928297968e-07 -0.5877853035926819 0.8090168237686157 -0.970820426940918 0.14549998939037323 0.8004074692726135 0.5815301537513733 -0.6978363394737244 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000036.txt b/3DTopia/assets/sample_data/pose/000036.txt new file mode 100644 index 0000000000000000000000000000000000000000..e6dd057f262657b11171b5e9e263897f9b962595 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000036.txt @@ -0,0 +1 @@ +-0.940730631351471 0.2712166905403137 0.20363546907901764 -0.24436251819133759 1.6391278734317893e-07 -0.6004202365875244 0.7996845841407776 -0.9596214294433594 0.3391546905040741 0.7522878646850586 0.5648337006568909 -0.6778002977371216 -0.0 -0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000037.txt b/3DTopia/assets/sample_data/pose/000037.txt new file mode 100644 index 0000000000000000000000000000000000000000..e2eb5ccbae9d91c08c51f7896bfb9c434d313b0f --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000037.txt @@ -0,0 +1 @@ +-0.8545990586280823 0.4103184938430786 0.31827494502067566 -0.3819308578968048 -5.513428504855256e-07 -0.6129069924354553 0.7901549935340881 -0.9481860995292664 0.519288182258606 0.6752656102180481 0.5237899422645569 -0.6285476088523865 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000038.txt b/3DTopia/assets/sample_data/pose/000038.txt new file mode 100644 index 0000000000000000000000000000000000000000..8c67481b528183a88bda479577977d8a912b3acc --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000038.txt @@ -0,0 +1 @@ +-0.7343969345092773 0.5296936631202698 0.42436450719833374 -0.5092376470565796 -5.9604616353681195e-08 -0.6252426505088806 0.7804303765296936 -0.936516523361206 0.6787199378013611 0.5731458067893982 0.45917630195617676 -0.5510116219520569 -0.0 -0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000039.txt b/3DTopia/assets/sample_data/pose/000039.txt new file mode 100644 index 0000000000000000000000000000000000000000..5539121de371664f1ad0987a401926f9542decdc --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000039.txt @@ -0,0 +1 @@ +-0.5849173665046692 0.6249576807022095 0.5170100331306458 -0.6204122304916382 -5.960463056453591e-08 -0.6374240517616272 0.770513117313385 -0.9246158599853516 0.811092734336853 0.4506864845752716 0.37284043431282043 -0.4474082589149475 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000040.txt b/3DTopia/assets/sample_data/pose/000040.txt new file mode 100644 index 0000000000000000000000000000000000000000..7340764fc6dc28476c22bd62d2b8b7c130508936 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000040.txt @@ -0,0 +1 @@ +-0.4121185541152954 0.692828893661499 0.5917317271232605 -0.7100781798362732 -1.341104507446289e-07 -0.6494481563568115 0.7604060173034668 -0.9124871492385864 0.9111302495002747 0.31337738037109375 0.26764971017837524 -0.32117941975593567 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000041.txt b/3DTopia/assets/sample_data/pose/000041.txt new file mode 100644 index 0000000000000000000000000000000000000000..5b93615c6eed7a5500d1936ae081317ad38e9ac6 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000041.txt @@ -0,0 +1 @@ +-0.22289007902145386 0.7312408685684204 0.6446757316589355 -0.7736107707023621 -1.1920927533992653e-07 -0.6613120436668396 0.7501108646392822 -0.9001333117485046 0.9748435020446777 0.1671922653913498 0.14739994704723358 -0.1768796592950821 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000042.txt b/3DTopia/assets/sample_data/pose/000042.txt new file mode 100644 index 0000000000000000000000000000000000000000..fa90ee264a33681b93ab18b934cd7343a2a291d1 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000042.txt @@ -0,0 +1 @@ +-0.02477543242275715 0.7394039630889893 0.6728058457374573 -0.8073671460151672 -1.7136329688582919e-07 -0.6730126738548279 0.7396309971809387 -0.887557327747345 0.9996929168701172 0.018324699252843857 0.01667424477636814 -0.020009009167551994 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000043.txt b/3DTopia/assets/sample_data/pose/000043.txt new file mode 100644 index 0000000000000000000000000000000000000000..4dff3f84326c1434af139ddba028c26ef6fa464d --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000043.txt @@ -0,0 +1 @@ +0.17432667315006256 0.7178065776824951 0.6740651726722717 -0.8088783025741577 -5.215405707303944e-08 -0.6845470666885376 0.7289686799049377 -0.8747623562812805 0.9846878051757812 -0.12707868218421936 -0.11933477967977524 0.14320188760757446 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000044.txt b/3DTopia/assets/sample_data/pose/000044.txt new file mode 100644 index 0000000000000000000000000000000000000000..676b06b5ef91c0926a8f2f1c8e3104de89e7753d --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000044.txt @@ -0,0 +1 @@ +0.3664790987968445 0.6681637167930603 0.6474955081939697 -0.7769947052001953 -2.9802322387695312e-08 -0.6959127187728882 0.7181264162063599 -0.8617515563964844 0.9304263591766357 -0.263178288936615 -0.25503745675086975 0.3060450553894043 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000045.txt b/3DTopia/assets/sample_data/pose/000045.txt new file mode 100644 index 0000000000000000000000000000000000000000..9b3e98714504ae3743baa7e85023c241c9c39d7a --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000045.txt @@ -0,0 +1 @@ +0.5440210700035095 0.5933132171630859 0.5933132171630859 -0.7119758129119873 1.4901161193847656e-08 -0.7071067690849304 0.7071068286895752 -0.8485281467437744 0.8390715718269348 -0.38468101620674133 -0.38468098640441895 0.46161726117134094 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000046.txt b/3DTopia/assets/sample_data/pose/000046.txt new file mode 100644 index 0000000000000000000000000000000000000000..de5ee9d3b5e7aa35e8cc57db0cad4d8c9e10a118 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000046.txt @@ -0,0 +1 @@ +0.699874758720398 0.4970666766166687 0.5129328966140747 -0.6155195832252502 -1.4901161193847656e-08 -0.7181262969970703 0.6959128975868225 -0.8350953459739685 0.7142656445503235 -0.4870518445968628 -0.5025984048843384 0.6031181812286377 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000047.txt b/3DTopia/assets/sample_data/pose/000047.txt new file mode 100644 index 0000000000000000000000000000000000000000..292398d3996a7b4e9c17c561cde7b08df6d13644 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000047.txt @@ -0,0 +1 @@ +0.8278263807296753 0.38402023911476135 0.408939927816391 -0.49072784185409546 4.470348358154297e-08 -0.728968620300293 0.6845471262931824 -0.8214565515518188 0.5609843134880066 -0.56668621301651 -0.6034594774246216 0.7241514921188354 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000048.txt b/3DTopia/assets/sample_data/pose/000048.txt new file mode 100644 index 0000000000000000000000000000000000000000..c7017f7c4bd6ab51b19f5fca5a032ff5ff71618a --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000048.txt @@ -0,0 +1 @@ +0.9227754473686218 0.2593373954296112 0.28500810265541077 -0.3420097231864929 -0.0 -0.7396311163902283 0.6730124950408936 -0.8076150417327881 0.38533815741539 -0.6210393905639648 -0.682513415813446 0.81901615858078 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000049.txt b/3DTopia/assets/sample_data/pose/000049.txt new file mode 100644 index 0000000000000000000000000000000000000000..13928ac07d2057760018824acfeecadb15d7413e --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000049.txt @@ -0,0 +1 @@ +0.9809362888336182 0.1285126805305481 0.14576902985572815 -0.17492283880710602 -7.450581485102248e-09 -0.7501111030578613 0.66131192445755 -0.7935742139816284 0.19432991743087769 -0.6487048268318176 -0.7358112335205078 0.8829733729362488 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000050.txt b/3DTopia/assets/sample_data/pose/000050.txt new file mode 100644 index 0000000000000000000000000000000000000000..764270d7e9b8b544b87908ce1c470cc55b4e9473 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000050.txt @@ -0,0 +1 @@ +0.9999901652336121 -0.0028742607682943344 -0.0033653262071311474 0.004038391634821892 -2.3283061589829401e-10 -0.7604058980941772 0.6494479179382324 -0.7793375253677368 -0.00442569749429822 -0.6494415998458862 -0.7603984475135803 0.9124780297279358 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000051.txt b/3DTopia/assets/sample_data/pose/000051.txt new file mode 100644 index 0000000000000000000000000000000000000000..c5794890908d9dbabb26e1429bfc337a52782633 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000051.txt @@ -0,0 +1 @@ +0.979177713394165 -0.12940019369125366 -0.15641796588897705 0.1877015084028244 -7.450580596923828e-09 -0.7705132365226746 0.6374240517616272 -0.7649087905883789 -0.2030048966407776 -0.624151349067688 -0.7544693946838379 0.9053632020950317 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000052.txt b/3DTopia/assets/sample_data/pose/000052.txt new file mode 100644 index 0000000000000000000000000000000000000000..9e4b173f044faee03f774286251bfe3ef1b096b6 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000052.txt @@ -0,0 +1 @@ +0.9193285703659058 -0.24602729082107544 -0.30709224939346313 0.3685106933116913 -0.0 -0.7804304361343384 0.6252426505088806 -0.7502912282943726 -0.39349088072776794 -0.5748034119606018 -0.7174719572067261 0.8609663844108582 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000053.txt b/3DTopia/assets/sample_data/pose/000053.txt new file mode 100644 index 0000000000000000000000000000000000000000..4ba7967d04ca6116b6a6de9fdeedaa997f5c4136 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000053.txt @@ -0,0 +1 @@ +0.8228285908699036 -0.3483087122440338 -0.4490368962287903 0.5388442873954773 -0.0 -0.7901550531387329 0.6129070520401001 -0.7354884147644043 -0.5682896375656128 -0.5043174624443054 -0.6501621007919312 0.7801946997642517 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000054.txt b/3DTopia/assets/sample_data/pose/000054.txt new file mode 100644 index 0000000000000000000000000000000000000000..18bb9aa82b377f174a1f78a8793725eba9593f80 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000054.txt @@ -0,0 +1 @@ +0.6935251355171204 -0.4325622320175171 -0.5761187076568604 0.6913425326347351 -1.4901159417490817e-08 -0.7996845245361328 0.6004201769828796 -0.7205043435096741 -0.7204324007034302 -0.4164064824581146 -0.5546013116836548 0.6655217409133911 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000055.txt b/3DTopia/assets/sample_data/pose/000055.txt new file mode 100644 index 0000000000000000000000000000000000000000..5c795995f90fa82ceaf094217099c6cb0f4417d6 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000055.txt @@ -0,0 +1 @@ +0.5365728735923767 -0.49600499868392944 -0.6826922297477722 0.8192306160926819 4.470348713425665e-08 -0.8090170621871948 0.5877854228019714 -0.7053423523902893 -0.8438540697097778 -0.3153897225856781 -0.4340965449810028 0.5209159851074219 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000056.txt b/3DTopia/assets/sample_data/pose/000056.txt new file mode 100644 index 0000000000000000000000000000000000000000..d3d4f889b29e5366d8c01f2a081bd03e1b6c6a09 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000056.txt @@ -0,0 +1 @@ +0.3582288920879364 -0.5368444323539734 -0.7638519406318665 0.9166225790977478 1.490115550950577e-07 -0.8181496262550354 0.575005292892456 -0.6900063157081604 -0.93363356590271 -0.20598354935646057 -0.2930847704410553 0.3517022430896759 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000057.txt b/3DTopia/assets/sample_data/pose/000057.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca2c0b4a3adb1181f77b0cdbf518ff7c8cad69a0 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000057.txt @@ -0,0 +1 @@ +0.1656038910150528 -0.5543226599693298 -0.8156602382659912 0.9787925481796265 7.4505797087454084e-09 -0.8270803093910217 0.5620836615562439 -0.6744999885559082 -0.9861923456192017 -0.0930832177400589 -0.136967733502388 0.1643616110086441 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000058.txt b/3DTopia/assets/sample_data/pose/000058.txt new file mode 100644 index 0000000000000000000000000000000000000000..557f498858ea589f14c5ed04bdae4cbb81939903 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000058.txt @@ -0,0 +1 @@ +-0.03362308070063591 -0.5487122535705566 -0.8353347778320312 1.0024018287658691 4.749744064724837e-08 -0.8358075618743896 0.5490226745605469 -0.6588274240493774 -0.9994345307350159 0.01845986768603325 0.028102422133088112 -0.03372287005186081 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000059.txt b/3DTopia/assets/sample_data/pose/000059.txt new file mode 100644 index 0000000000000000000000000000000000000000..e4ee736ef15c8763363013a8f4bd528fec551aec --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000059.txt @@ -0,0 +1 @@ +-0.23150981962680817 -0.5212697386741638 -0.8213896751403809 0.9856675863265991 7.450580596923828e-09 -0.8443279266357422 0.5358267426490784 -0.6429921984672546 -0.9728325605392456 0.12404916435480118 0.1954701989889145 -0.23456427454948425 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000060.txt b/3DTopia/assets/sample_data/pose/000060.txt new file mode 100644 index 0000000000000000000000000000000000000000..5a1a4f0de215b9b5d45d6c2a1b2dc47848914eae --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000060.txt @@ -0,0 +1 @@ +-0.42016705870628357 -0.474139541387558 -0.7737252116203308 0.9284706711769104 1.4901151246249356e-07 -0.8526401519775391 0.5224984884262085 -0.6269983053207397 -0.9074463844299316 0.2195366770029068 0.35825133323669434 -0.4299015402793884 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000061.txt b/3DTopia/assets/sample_data/pose/000061.txt new file mode 100644 index 0000000000000000000000000000000000000000..9d93876315c8f72cb7560190146125be50bd9318 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000061.txt @@ -0,0 +1 @@ +-0.5920736193656921 -0.4102281630039215 -0.6936582326889038 0.8323898315429688 -2.980232594040899e-08 -0.8607421517372131 0.5090413093566895 -0.6108497381210327 -0.8058839440345764 0.3013899028301239 0.5096226930618286 -0.6115471720695496 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000062.txt b/3DTopia/assets/sample_data/pose/000062.txt new file mode 100644 index 0000000000000000000000000000000000000000..eff3056cc60ea6abf104750d715c890dafd3ebef --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000062.txt @@ -0,0 +1 @@ +-0.7403759360313416 -0.3330437242984772 -0.5838878750801086 0.7006657123565674 1.043080928297968e-07 -0.8686315417289734 0.4954584836959839 -0.5945504903793335 -0.6721928119659424 0.36682555079460144 0.6431138515472412 -0.7717366218566895 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000063.txt b/3DTopia/assets/sample_data/pose/000063.txt new file mode 100644 index 0000000000000000000000000000000000000000..dd26b45880fdfe87414665a9f4e62e9c8c63763a --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000063.txt @@ -0,0 +1 @@ +-0.8591620922088623 -0.2465151995420456 -0.44840940833091736 0.5380914807319641 4.4703490686970326e-08 -0.8763067126274109 0.4817536771297455 -0.5781044363975525 -0.5117037892341614 0.4139043986797333 0.7528894543647766 -0.9034671783447266 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000064.txt b/3DTopia/assets/sample_data/pose/000064.txt new file mode 100644 index 0000000000000000000000000000000000000000..132633e40f4c478d77778bba5642750d3711fd92 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000064.txt @@ -0,0 +1 @@ +-0.9436956644058228 -0.15479804575443268 -0.29236292839050293 0.35083532333374023 -1.043081283569336e-07 -0.8837655782699585 0.46792975068092346 -0.5615156292915344 -0.3308148980140686 0.4415833055973053 0.8340057730674744 -1.0008068084716797 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000065.txt b/3DTopia/assets/sample_data/pose/000065.txt new file mode 100644 index 0000000000000000000000000000000000000000..f09ada342d0fc5da8de63ee7fba3456573a95c2b --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000065.txt @@ -0,0 +1 @@ +-0.9906071424484253 -0.06207740679383278 -0.12183358520269394 0.146200492978096 8.195635103902532e-08 -0.891006588935852 0.45399045944213867 -0.5447887182235718 -0.13673707842826843 0.44972628355026245 0.8826374411582947 -1.0591652393341064 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000066.txt b/3DTopia/assets/sample_data/pose/000066.txt new file mode 100644 index 0000000000000000000000000000000000000000..64173ae9e839c86a2a67043f6f5280963d5ac9d7 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000066.txt @@ -0,0 +1 @@ +-0.9980266094207764 0.027624979615211487 0.0563855841755867 -0.06766645610332489 -1.7657868056630832e-06 -0.8980275988578796 0.4399391710758209 -0.5279269218444824 0.06278911978006363 0.43907099962234497 0.8962554335594177 -1.0755064487457275 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000067.txt b/3DTopia/assets/sample_data/pose/000067.txt new file mode 100644 index 0000000000000000000000000000000000000000..b1cd5fefa04afc735826d918a8ab7b65a4e1b4c4 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000067.txt @@ -0,0 +1 @@ +-0.9656579494476318 0.11062507331371307 0.23508931696414948 -0.28210771083831787 -4.023314090773056e-07 -0.9048269987106323 0.4257793426513672 -0.5109351277351379 0.2598170340061188 0.4111570417881012 0.8737534880638123 -1.0485038757324219 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000068.txt b/3DTopia/assets/sample_data/pose/000068.txt new file mode 100644 index 0000000000000000000000000000000000000000..f860f797e8891165b6cc501ef4a61a21e33144c6 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000068.txt @@ -0,0 +1 @@ +-0.894791305065155 0.18373513221740723 0.40692755579948425 -0.48831334710121155 -2.682209014892578e-07 -0.9114032983779907 0.41151440143585205 -0.4938172996044159 0.4464847445487976 0.36821937561035156 0.8155156970024109 -0.9786188006401062 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000069.txt b/3DTopia/assets/sample_data/pose/000069.txt new file mode 100644 index 0000000000000000000000000000000000000000..ec94ae96b22c3a416e27d5794ab8046f8a6d8633 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000069.txt @@ -0,0 +1 @@ +-0.7882521152496338 0.24438586831092834 0.5647425055503845 -0.6776911020278931 2.9802318834981634e-08 -0.9177546501159668 0.39714789390563965 -0.47657743096351624 0.6153523921966553 0.3130526840686798 0.7234220504760742 -0.8681064248085022 -0.0 -0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000070.txt b/3DTopia/assets/sample_data/pose/000070.txt new file mode 100644 index 0000000000000000000000000000000000000000..10ca2355685da194750f3114d96886af2ae6e946 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000070.txt @@ -0,0 +1 @@ +-0.6502879858016968 0.29071998596191406 0.7018599510192871 -0.8422322273254395 -2.9802318834981634e-08 -0.9238795638084412 0.38268351554870605 -0.45922014117240906 0.7596877217292786 0.24885447323322296 0.6007877588272095 -0.7209452986717224 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000071.txt b/3DTopia/assets/sample_data/pose/000071.txt new file mode 100644 index 0000000000000000000000000000000000000000..1d086455038440042ad1f62d0178bee54bb029e1 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000071.txt @@ -0,0 +1 @@ +-0.48639875650405884 0.3216439187526703 0.8123796582221985 -0.9748559594154358 -7.450575623124678e-08 -0.9297764897346497 0.36812451481819153 -0.4417493939399719 0.873736560344696 0.1790553480386734 0.4522421360015869 -0.5426904559135437 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000072.txt b/3DTopia/assets/sample_data/pose/000072.txt new file mode 100644 index 0000000000000000000000000000000000000000..253bdb8edec9e4eb40174ab5526043a7a296708e --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000072.txt @@ -0,0 +1 @@ +-0.30311834812164307 0.33684486150741577 0.8914337754249573 -1.069720983505249 -9.68574909165909e-08 -0.9354440569877625 0.35347482562065125 -0.4241698682308197 0.95295250415802 0.10714472085237503 0.2835502326488495 -0.3402603268623352 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000073.txt b/3DTopia/assets/sample_data/pose/000073.txt new file mode 100644 index 0000000000000000000000000000000000000000..1065ec11a970cb433f82159667976f17b869b7ad --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000073.txt @@ -0,0 +1 @@ +-0.10775362700223923 0.3367651700973511 0.9354028105735779 -1.122483253479004 1.1175870895385742e-08 -0.9408808946609497 0.33873745799064636 -0.40648552775382996 0.9941775798797607 0.03650019317865372 0.10138332843780518 -0.1216600239276886 -0.0 -0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000074.txt b/3DTopia/assets/sample_data/pose/000074.txt new file mode 100644 index 0000000000000000000000000000000000000000..936daa35e0a7e2920faf188120d95254364af547 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000074.txt @@ -0,0 +1 @@ +0.09190679341554642 0.3225465416908264 0.9420811533927917 -1.1304973363876343 -7.450580596923828e-09 -0.9460853934288025 0.3239174783229828 -0.3887008726596832 0.9957676529884338 -0.02977021597325802 -0.08695167303085327 0.10434205830097198 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000075.txt b/3DTopia/assets/sample_data/pose/000075.txt new file mode 100644 index 0000000000000000000000000000000000000000..6459bfbf10f90125316735b8fb2ed5930c17c0bb --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000075.txt @@ -0,0 +1 @@ +0.2879031300544739 0.2959333658218384 0.910788357257843 -1.0929460525512695 -0.0 -0.9510565400123596 0.3090173006057739 -0.37082037329673767 0.9576596021652222 -0.08896704763174057 -0.27381211519241333 0.3285748362541199 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000076.txt b/3DTopia/assets/sample_data/pose/000076.txt new file mode 100644 index 0000000000000000000000000000000000000000..e43d0f90f4fdaa1ca35987f4ce7752e58cd5946a --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000076.txt @@ -0,0 +1 @@ +0.4724216163158417 0.25915926694869995 0.8424096703529358 -1.0108915567398071 -5.960463056453591e-08 -0.9557929039001465 0.2940405011177063 -0.3528483808040619 0.8813725709915161 -0.13891111314296722 -0.45153722167015076 0.5418452024459839 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000077.txt b/3DTopia/assets/sample_data/pose/000077.txt new file mode 100644 index 0000000000000000000000000000000000000000..662781cae983bf3ae319dbed50dede17db50e6db --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000077.txt @@ -0,0 +1 @@ +0.6381067037582397 0.2148085981607437 0.7393761277198792 -0.8872514367103577 -1.4901159417490817e-08 -0.9602935910224915 0.27899107336997986 -0.3347893953323364 0.7699478268623352 -0.17802608013153076 -0.6127697825431824 0.7353238463401794 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000078.txt b/3DTopia/assets/sample_data/pose/000078.txt new file mode 100644 index 0000000000000000000000000000000000000000..8e2d7a807fdfeb4e3023d3a797d0b608bfcd0a9a --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000078.txt @@ -0,0 +1 @@ +0.7783521413803101 0.1656668782234192 0.6055761575698853 -0.726691484451294 7.450580596923828e-09 -0.964557409286499 0.26387304067611694 -0.31664761900901794 0.6278279423713684 -0.20538613200187683 -0.7507652640342712 0.9009183049201965 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000079.txt b/3DTopia/assets/sample_data/pose/000079.txt new file mode 100644 index 0000000000000000000000000000000000000000..24ba57c4edabd091941cd2a2fb2fda2b1cf94bcb --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000079.txt @@ -0,0 +1 @@ +0.8875669240951538 0.1145661398768425 0.4462054967880249 -0.5354465842247009 7.4505797087454084e-09 -0.9685830473899841 0.2486899495124817 -0.29842785000801086 0.4606785774230957 -0.22072899341583252 -0.8596823811531067 1.0316189527511597 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000080.txt b/3DTopia/assets/sample_data/pose/000080.txt new file mode 100644 index 0000000000000000000000000000000000000000..b74aa513a77caee515d08792c10e1422e2b7c072 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000080.txt @@ -0,0 +1 @@ +0.9613975286483765 0.0642356276512146 0.26756060123443604 -0.3210725784301758 -0.0 -0.9723699688911438 0.233445405960083 -0.2801344394683838 0.27516335248947144 -0.22443383932113647 -0.9348340034484863 1.1218007802963257 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000081.txt b/3DTopia/assets/sample_data/pose/000081.txt new file mode 100644 index 0000000000000000000000000000000000000000..e23aaf83ca0b1a6f9823145f9d23729ecc6aa851 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000081.txt @@ -0,0 +1 @@ +0.9969000816345215 0.017163122072815895 0.07678337395191193 -0.09214004129171371 1.862645371275562e-09 -0.9759168028831482 0.2181432992219925 -0.26177191734313965 0.078678198158741 -0.2174670696258545 -0.9728915691375732 1.16746985912323 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000082.txt b/3DTopia/assets/sample_data/pose/000082.txt new file mode 100644 index 0000000000000000000000000000000000000000..884d4a666c7ff8885c74f4ff22f123a25e86cdb5 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000082.txt @@ -0,0 +1 @@ +0.9926593899726868 -0.024525828659534454 -0.1184307411313057 0.14211684465408325 -0.0 -0.9792227745056152 0.20278730988502502 -0.2433447241783142 -0.12094360589981079 -0.20129872858524323 -0.972034752368927 1.1664414405822754 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000083.txt b/3DTopia/assets/sample_data/pose/000083.txt new file mode 100644 index 0000000000000000000000000000000000000000..22b28a7dc788c0d18d7b8fb27841332257e4a5db --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000083.txt @@ -0,0 +1 @@ +0.9488444924354553 -0.05916447937488556 -0.310151070356369 0.372181236743927 3.725290742551124e-09 -0.9822872877120972 0.187381312251091 -0.22485758364200592 -0.3157437741756439 -0.17779573798179626 -0.932037889957428 1.1184452772140503 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000084.txt b/3DTopia/assets/sample_data/pose/000084.txt new file mode 100644 index 0000000000000000000000000000000000000000..328296c9245ecf6dd9f22d0f4807869a6661795c --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000084.txt @@ -0,0 +1 @@ +0.8672021627426147 -0.08561316877603531 -0.4905413091182709 0.588649570941925 -0.0 -0.9851093292236328 0.17192910611629486 -0.20631492137908936 -0.49795621633529663 -0.1490972936153412 -0.8542889356613159 1.0251468420028687 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000085.txt b/3DTopia/assets/sample_data/pose/000085.txt new file mode 100644 index 0000000000000000000000000000000000000000..1d753f8a18d6c2b1704e5c8f395d2318c470f194 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000085.txt @@ -0,0 +1 @@ +0.7509872913360596 -0.10329629480838776 -0.6521871089935303 0.7826245427131653 -0.0 -0.9876883625984192 0.15643447637557983 -0.1877213567495346 -0.6603167057037354 -0.11748029291629791 -0.7417413592338562 0.8900896906852722 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000086.txt b/3DTopia/assets/sample_data/pose/000086.txt new file mode 100644 index 0000000000000000000000000000000000000000..dcfa47155ee42f92e895cff4f456a09cb13b9a5c --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000086.txt @@ -0,0 +1 @@ +0.6048324704170227 -0.11220713704824448 -0.7884079813957214 0.946089506149292 4.470348002882929e-08 -0.9900237321853638 0.14090131223201752 -0.16908152401447296 -0.7963526844978333 -0.08522169291973114 -0.5987984538078308 0.718558669090271 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000087.txt b/3DTopia/assets/sample_data/pose/000087.txt new file mode 100644 index 0000000000000000000000000000000000000000..e3874f74764b3d722fc17ec2887449e02ed64866 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000087.txt @@ -0,0 +1 @@ +0.43456563353538513 -0.1128801479935646 -0.8935384154319763 1.0722460746765137 3.725290298461914e-09 -0.9921147227287292 0.12533323466777802 -0.15039989352226257 -0.9006401896476746 -0.05446551740169525 -0.43113893270492554 0.5173667073249817 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000088.txt b/3DTopia/assets/sample_data/pose/000088.txt new file mode 100644 index 0000000000000000000000000000000000000000..7f91baf2142506bb58bb4adc9b969d367b20358e --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000088.txt @@ -0,0 +1 @@ +0.24697357416152954 -0.10633499920368195 -0.9631701111793518 1.15580415725708 1.8626447051417472e-09 -0.9939608573913574 0.10973432660102844 -0.1316811591386795 -0.9690220952033997 -0.02710147760808468 -0.24548210203647614 0.29457858204841614 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000089.txt b/3DTopia/assets/sample_data/pose/000089.txt new file mode 100644 index 0000000000000000000000000000000000000000..01730422d0cebff267fe7caaa136318f455b21e4 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000089.txt @@ -0,0 +1 @@ +0.049535539001226425 -0.09399279206991196 -0.9943397641181946 1.1932077407836914 6.51925802230835e-09 -0.9955620169639587 0.09410832822322845 -0.11292997747659683 -0.9987723231315613 -0.0046617123298347 -0.04931569844484329 0.05917895957827568 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000090.txt b/3DTopia/assets/sample_data/pose/000090.txt new file mode 100644 index 0000000000000000000000000000000000000000..b60a7c97118e2848dcfe298ebfd1c23b7bcade64 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000090.txt @@ -0,0 +1 @@ +-0.14987720549106598 -0.07757285982370377 -0.98565673828125 1.1827882528305054 -0.0 -0.9969173669815063 0.07845908403396606 -0.09415092319250107 -0.9887046217918396 0.011759229004383087 0.14941516518592834 -0.1792982518672943 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000091.txt b/3DTopia/assets/sample_data/pose/000091.txt new file mode 100644 index 0000000000000000000000000000000000000000..a14b8cf84c53f5acc76d1b80f245aadf6c5dad75 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000091.txt @@ -0,0 +1 @@ +-0.34331491589546204 -0.058974120765924454 -0.937366783618927 1.1248406171798706 3.7252885221050747e-09 -0.9980267882347107 0.06279051303863525 -0.0753486305475235 -0.939220130443573 0.021556934341788292 0.34263747930526733 -0.4111650288105011 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000092.txt b/3DTopia/assets/sample_data/pose/000092.txt new file mode 100644 index 0000000000000000000000000000000000000000..2a1e764a5cd4a1ce03586527814eeeafc2d91e99 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000092.txt @@ -0,0 +1 @@ +-0.5230658054351807 -0.04014846310019493 -0.8513460755348206 1.0216155052185059 3.7252898543727042e-09 -0.9988899230957031 0.04710644856095314 -0.05652773752808571 -0.8522922396659851 0.02463977038860321 0.5224851369857788 -0.6269820332527161 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000093.txt b/3DTopia/assets/sample_data/pose/000093.txt new file mode 100644 index 0000000000000000000000000000000000000000..953d8bb0c4d870bdba217954835590a2acfb9e2f --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000093.txt @@ -0,0 +1 @@ +-0.6819634437561035 -0.022973379120230675 -0.7310251593589783 0.8772302269935608 -9.313223969797946e-09 -0.9995065331459045 0.03141074627637863 -0.03769290819764137 -0.7313860654830933 0.02142098918557167 0.6816269159317017 -0.8179523944854736 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000094.txt b/3DTopia/assets/sample_data/pose/000094.txt new file mode 100644 index 0000000000000000000000000000000000000000..077a7907fa1c1cc7701873b81a59f27a86591ddd --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000094.txt @@ -0,0 +1 @@ +-0.8136734962463379 -0.009131004102528095 -0.5812501311302185 0.6975001096725464 -4.190950253502024e-09 -0.9998766779899597 0.015707319602370262 -0.018848778679966927 -0.5813218355178833 0.012780634686350822 0.8135731816291809 -0.9762880206108093 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000095.txt b/3DTopia/assets/sample_data/pose/000095.txt new file mode 100644 index 0000000000000000000000000000000000000000..481ab52b37e4e23f1c5de5c779c81f73e477e395 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000095.txt @@ -0,0 +1 @@ +-0.9129449129104614 -1.2935611884759817e-15 -0.4080818295478821 0.4896984398365021 4.2351617070456256e-22 -1.0 3.169856057190978e-15 -3.8038288779917735e-15 -0.4080818295478821 2.8939047931950254e-15 0.9129449129104614 -1.0955342054367065 0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000096.txt b/3DTopia/assets/sample_data/pose/000096.txt new file mode 100644 index 0000000000000000000000000000000000000000..7774578fccfa2776f29df14063d1b9806ee9a20c --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000096.txt @@ -0,0 +1 @@ +-0.9758200645446777 0.003433206817135215 -0.21854621171951294 0.2622557282447815 -8.381896066111949e-09 -0.9998766183853149 -0.015707315877079964 0.018848782405257225 -0.2185731828212738 -0.015327518805861473 0.9756997227668762 -1.1708403825759888 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000097.txt b/3DTopia/assets/sample_data/pose/000097.txt new file mode 100644 index 0000000000000000000000000000000000000000..3e6fdbaa06f2f98f1f12b4a4cae0e24ad2a52202 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000097.txt @@ -0,0 +1 @@ +-0.9997926354408264 0.0006389844347722828 -0.020337846130132675 0.024408958852291107 1.5809190756499447e-07 -0.9995064735412598 -0.03141075372695923 0.037692904472351074 -0.02034788206219673 -0.03140425682067871 0.9992992877960205 -1.1991593837738037 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000098.txt b/3DTopia/assets/sample_data/pose/000098.txt new file mode 100644 index 0000000000000000000000000000000000000000..c760319b66cb65b2e30cd921736ef9cc5749cf4c --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000098.txt @@ -0,0 +1 @@ +-0.9839062690734863 -0.00841712299734354 0.17848443984985352 -0.2141815721988678 1.3038505386475663e-08 -0.9988898634910583 -0.04710642620921135 0.05652773752808571 0.1786828190088272 -0.04634832963347435 0.9828140139579773 -1.1793771982192993 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000099.txt b/3DTopia/assets/sample_data/pose/000099.txt new file mode 100644 index 0000000000000000000000000000000000000000..2d49a8f5085e4de83c1a31dacd59e0f782855080 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000099.txt @@ -0,0 +1 @@ +-0.9287950992584229 -0.023269735276699066 0.36986207962036133 -0.4438343644142151 -1.4901159417490817e-08 -0.9980266094207764 -0.06279050558805466 0.0753486156463623 0.3705933690071106 -0.05831952393054962 0.9269623160362244 -1.1123547554016113 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000100.txt b/3DTopia/assets/sample_data/pose/000100.txt new file mode 100644 index 0000000000000000000000000000000000000000..f9904c7007ec8ce29c9d0ef9d4b20141ed8e3a74 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000100.txt @@ -0,0 +1 @@ +-0.8366557955741882 -0.04297434538602829 0.5460407733917236 -0.6552488803863525 -3.725291186640334e-09 -0.9969173669815063 -0.07845912128686905 0.09415092319250107 0.5477291941642761 -0.06564325839281082 0.8340766429901123 -1.000891923904419 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000101.txt b/3DTopia/assets/sample_data/pose/000101.txt new file mode 100644 index 0000000000000000000000000000000000000000..7e85d82ccedb21034bc93d929d56b5b0f9cc4698 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000101.txt @@ -0,0 +1 @@ +-0.7111608982086182 -0.06616085022687912 0.6999088525772095 -0.839890718460083 -1.4901154088420299e-08 -0.9955620169639587 -0.09410828351974487 0.11292998492717743 0.7030289173126221 -0.06692616641521454 0.7080047130584717 -0.8496062159538269 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000102.txt b/3DTopia/assets/sample_data/pose/000102.txt new file mode 100644 index 0000000000000000000000000000000000000000..f52b5d44b44d916edc72ac097a931fb20b69aab8 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000102.txt @@ -0,0 +1 @@ +-0.557314932346344 -0.09111247956752777 0.8252866864204407 -0.9903444051742554 7.450577044210149e-09 -0.993960976600647 -0.10973427444696426 0.1316811740398407 0.8303009271621704 -0.06115657836198807 0.5539493560791016 -0.6647393703460693 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000103.txt b/3DTopia/assets/sample_data/pose/000103.txt new file mode 100644 index 0000000000000000000000000000000000000000..628a79aa312adfdc54645db2f3186e6876244f3c --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000103.txt @@ -0,0 +1 @@ +-0.3812505602836609 -0.115867018699646 0.9171820282936096 -1.1006184816360474 -3.725290298461914e-09 -0.9921146631240845 -0.12533321976661682 0.15039989352226257 0.9244717359542847 -0.04778335988521576 0.3782442808151245 -0.4538930654525757 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000104.txt b/3DTopia/assets/sample_data/pose/000104.txt new file mode 100644 index 0000000000000000000000000000000000000000..511f249cc3bb9f7cd2c1952ddffc165697aae457 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000104.txt @@ -0,0 +1 @@ +-0.1899867206811905 -0.13833492994308472 0.971992015838623 -1.1663905382156372 7.450580596923828e-09 -0.990023672580719 -0.14090122282505035 0.16908153891563416 0.9817867279052734 -0.026769354939460754 0.18809135258197784 -0.22570960223674774 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000105.txt b/3DTopia/assets/sample_data/pose/000105.txt new file mode 100644 index 0000000000000000000000000000000000000000..7bf130e319aebed4e688318b892f049659653bb3 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000105.txt @@ -0,0 +1 @@ +0.008851335383951664 -0.15642881393432617 0.9876496195793152 -1.185179591178894 1.7462300494486271e-09 -0.9876883029937744 -0.15643493831157684 0.1877213567495346 0.9999608993530273 0.0013846629299223423 -0.008742359466850758 0.01049080304801464 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000106.txt b/3DTopia/assets/sample_data/pose/000106.txt new file mode 100644 index 0000000000000000000000000000000000000000..dd2446d7e26c97486770c44c0c16aac0b33be82f --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000106.txt @@ -0,0 +1 @@ +0.20733638107776642 -0.16819317638874054 0.9637025594711304 -1.1564432382583618 -7.450580596923828e-09 -0.9851093292236328 -0.171929270029068 0.20631493628025055 0.9782696962356567 0.035647179931402206 -0.20424900949001312 0.24509884417057037 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000107.txt b/3DTopia/assets/sample_data/pose/000107.txt new file mode 100644 index 0000000000000000000000000000000000000000..c63eef9320d8658f8550ed8aba13409f58845b0c --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000107.txt @@ -0,0 +1 @@ +0.39755570888519287 -0.1719369888305664 0.9013251662254333 -1.081590175628662 -2.2351741790771484e-08 -0.9822872877120972 -0.1873813271522522 0.22485756874084473 0.9175780415534973 0.07449448853731155 -0.39051389694213867 0.4686166048049927 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000108.txt b/3DTopia/assets/sample_data/pose/000108.txt new file mode 100644 index 0000000000000000000000000000000000000000..d18b02f5b68c2656ccca9e3fb434e603868fe18e --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000108.txt @@ -0,0 +1 @@ +0.5719255805015564 -0.16634754836559296 0.8032617568969727 -0.9639140367507935 2.9802318834981634e-08 -0.9792227745056152 -0.20278732478618622 0.2433447390794754 0.8203054070472717 0.11597928404808044 -0.5600425601005554 0.672051191329956 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000109.txt b/3DTopia/assets/sample_data/pose/000109.txt new file mode 100644 index 0000000000000000000000000000000000000000..166a0b7be54186c836c494541344aa9fa9fd3c31 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000109.txt @@ -0,0 +1 @@ +0.7234946489334106 -0.1505908966064453 0.6737046241760254 -0.8084455132484436 -7.450580596923828e-09 -0.9759168028831482 -0.21814334392547607 0.26177188754081726 0.6903300285339355 0.1578255444765091 -0.7060705423355103 0.8472847938537598 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000110.txt b/3DTopia/assets/sample_data/pose/000110.txt new file mode 100644 index 0000000000000000000000000000000000000000..70e94906e0987d8359ab076191051939d4a136df --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000110.txt @@ -0,0 +1 @@ +0.8462203741073608 -0.12438744306564331 0.5181108713150024 -0.6217329502105713 -0.0 -0.972369909286499 -0.2334454357624054 0.2801344394683838 0.5328330993652344 0.19754627346992493 -0.8228392004966736 0.9874071478843689 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000111.txt b/3DTopia/assets/sample_data/pose/000111.txt new file mode 100644 index 0000000000000000000000000000000000000000..09e241678a4175b8286de5dbde5176e05649dbdf --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000111.txt @@ -0,0 +1 @@ +0.9352098703384399 -0.08805953711271286 0.34296926856040955 -0.41156312823295593 -0.0 -0.9685831069946289 -0.24868986010551453 0.2984278202056885 0.3540937900543213 0.23257721960544586 -0.9058284759521484 1.0869944095611572 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000112.txt b/3DTopia/assets/sample_data/pose/000112.txt new file mode 100644 index 0000000000000000000000000000000000000000..890cbd2954f25f2baa3d4592b868bbbb20b918b4 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000112.txt @@ -0,0 +1 @@ +0.9869155883789062 -0.04254636913537979 0.15552331507205963 -0.18662789463996887 -7.450581485102248e-09 -0.9645574688911438 -0.26387304067611694 0.31664755940437317 0.16123799979686737 0.26042044162750244 -0.9519367218017578 1.1423239707946777 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000113.txt b/3DTopia/assets/sample_data/pose/000113.txt new file mode 100644 index 0000000000000000000000000000000000000000..033ec4314be76c46782822ce4e48fdca921af492 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000113.txt @@ -0,0 +1 @@ +0.9992760419845581 0.010614471510052681 -0.036535248160362244 0.04384230822324753 9.313225746154785e-10 -0.960293710231781 -0.27899110317230225 0.3347893953323364 -0.038045912981033325 0.27878910303115845 -0.9595984220504761 1.1515181064605713 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000114.txt b/3DTopia/assets/sample_data/pose/000114.txt new file mode 100644 index 0000000000000000000000000000000000000000..2f45dbf2b8a907c1cab96db926977bed501de707 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000114.txt @@ -0,0 +1 @@ +0.9717984795570374 0.06933855265378952 -0.2253885120153427 0.2704661190509796 -1.4901162970204496e-08 -0.9557930827140808 -0.29404035210609436 0.3528483510017395 -0.23581309616565704 0.2857479453086853 -0.9288381934165955 1.1146059036254883 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000115.txt b/3DTopia/assets/sample_data/pose/000115.txt new file mode 100644 index 0000000000000000000000000000000000000000..9ec438c94b2f7224f4b43e7e7fc095543c906619 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000115.txt @@ -0,0 +1 @@ +0.9055783748626709 0.1310785412788391 -0.40341824293136597 0.4841018319129944 1.4901162970204496e-08 -0.9510565996170044 -0.30901703238487244 0.37082037329673767 -0.4241790473461151 0.2798391580581665 -0.8612562417984009 1.0335074663162231 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000116.txt b/3DTopia/assets/sample_data/pose/000116.txt new file mode 100644 index 0000000000000000000000000000000000000000..635a38de7120930fa33c9ce5644faae62e131de3 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000116.txt @@ -0,0 +1 @@ +0.8032556176185608 0.19293642044067383 -0.5635209679603577 0.6762250065803528 -0.0 -0.9460852742195129 -0.3239175081253052 0.3887009024620056 -0.595634400844574 0.2601885497570038 -0.7599483132362366 0.9119382500648499 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000117.txt b/3DTopia/assets/sample_data/pose/000117.txt new file mode 100644 index 0000000000000000000000000000000000000000..6b97f76a4e73f094109721b801d6143752ec7a3c --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000117.txt @@ -0,0 +1 @@ +0.6689097285270691 0.2517986595630646 -0.6993976831436157 0.8392772674560547 -2.9802322387695312e-08 -0.9408807754516602 -0.3387379050254822 0.40648549795150757 -0.7433436512947083 0.22658511996269226 -0.6293643712997437 0.7552372217178345 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000118.txt b/3DTopia/assets/sample_data/pose/000118.txt new file mode 100644 index 0000000000000000000000000000000000000000..a0059c94ccb540c18eaff754b8aa9993f6aa77dd --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000118.txt @@ -0,0 +1 @@ +0.5078965425491333 0.30448970198631287 -0.805808424949646 0.966969907283783 -2.9802322387695312e-08 -0.9354440569877625 -0.3534749150276184 0.4241698086261749 -0.8614181280136108 0.17952869832515717 -0.4751087725162506 0.5701305866241455 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000119.txt b/3DTopia/assets/sample_data/pose/000119.txt new file mode 100644 index 0000000000000000000000000000000000000000..8155cd49cc7d2f1b840ec963913fdecf443e2476 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000119.txt @@ -0,0 +1 @@ +0.32663485407829285 0.34793341159820557 -0.8787785172462463 1.0545341968536377 -4.4703469370688254e-08 -0.9297763109207153 -0.3681248426437378 0.4417494237422943 -0.9451503753662109 0.12024238705635071 -0.3036973476409912 0.3644372224807739 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000120.txt b/3DTopia/assets/sample_data/pose/000120.txt new file mode 100644 index 0000000000000000000000000000000000000000..453e91e44313cd67b3db21655d9c28889113920c --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000120.txt @@ -0,0 +1 @@ +0.1323518306016922 0.3793167471885681 -0.9157520532608032 1.0989023447036743 1.4901161193847656e-08 -0.9238796234130859 -0.38268330693244934 0.4592200517654419 -0.9912027716636658 0.05064881592988968 -0.1222771555185318 0.14673250913619995 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000121.txt b/3DTopia/assets/sample_data/pose/000121.txt new file mode 100644 index 0000000000000000000000000000000000000000..b8c0d781bee51b0bedecad3a4d3a7f1bb3696f5c --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000121.txt @@ -0,0 +1 @@ +-0.06720828264951706 0.3962496519088745 -0.9156795144081116 1.0988155603408813 -3.166496043149891e-08 -0.9177546501159668 -0.39714762568473816 0.476577490568161 -0.9977388381958008 -0.0266916174441576 0.06168072298169136 -0.07401663064956665 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000122.txt b/3DTopia/assets/sample_data/pose/000122.txt new file mode 100644 index 0000000000000000000000000000000000000000..c4264455ebecff57680a44302b389b3e0f933d0a --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000122.txt @@ -0,0 +1 @@ +-0.2640887498855591 0.39690470695495605 -0.8790467977523804 1.0548564195632935 -1.0430807861894209e-07 -0.9114033579826355 -0.4115141034126282 0.4938172996044159 -0.9644981622695923 -0.1086762398481369 0.2406913787126541 -0.28882941603660583 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000123.txt b/3DTopia/assets/sample_data/pose/000123.txt new file mode 100644 index 0000000000000000000000000000000000000000..933e6ad09b0591b1e6064a4090051f2654729453 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000123.txt @@ -0,0 +1 @@ +-0.45044049620628357 0.3801383674144745 -0.807835042476654 0.9694024920463562 -1.0430805730266002e-07 -0.9048270583152771 -0.42577916383743286 0.5109351277351379 -0.8928060531616211 -0.19178827106952667 0.4075707793235779 -0.489084929227829 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000124.txt b/3DTopia/assets/sample_data/pose/000124.txt new file mode 100644 index 0000000000000000000000000000000000000000..4f48bf53bcde81245411808371cb73a0261919e6 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000124.txt @@ -0,0 +1 @@ +-0.6188351511955261 0.3455813229084015 -0.7054193019866943 0.8465033173561096 -4.470347292340193e-08 -0.8980275392532349 -0.439939022064209 0.5279269218444824 -0.7855206727981567 -0.2722497582435608 0.5557310581207275 -0.6668769717216492 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000125.txt b/3DTopia/assets/sample_data/pose/000125.txt new file mode 100644 index 0000000000000000000000000000000000000000..7b1a8475036401e325178ba2ad9184218377e036 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000125.txt @@ -0,0 +1 @@ +-0.7625583410263062 0.2936951220035553 -0.576409101486206 0.6916911005973816 -4.4703462265260896e-08 -0.8910065293312073 -0.4539904296398163 0.5447885990142822 -0.6469191312789917 -0.3461942672729492 0.6794444918632507 -0.8153334856033325 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000126.txt b/3DTopia/assets/sample_data/pose/000126.txt new file mode 100644 index 0000000000000000000000000000000000000000..338c3452f049a63246b7b3913ba4225f26a96e33 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000126.txt @@ -0,0 +1 @@ +-0.8758810758590698 0.22578869760036469 -0.42644059658050537 0.5117289423942566 -4.470348002882929e-08 -0.8837655782699585 -0.4679297208786011 0.5615156888961792 -0.48252683877944946 -0.4098508059978485 0.774073600769043 -0.9288883209228516 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000127.txt b/3DTopia/assets/sample_data/pose/000127.txt new file mode 100644 index 0000000000000000000000000000000000000000..de9dc4ed1416526cc4fabeb0d5defe351813f2a9 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000127.txt @@ -0,0 +1 @@ +-0.9542849063873291 0.14399497210979462 -0.2619266211986542 0.3143114745616913 3.650783639841393e-07 -0.8763066530227661 -0.4817536175251007 0.5781043767929077 -0.2988981604576111 -0.459730327129364 0.8362461924552917 -1.003495693206787 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000128.txt b/3DTopia/assets/sample_data/pose/000128.txt new file mode 100644 index 0000000000000000000000000000000000000000..4afe2894717f298b381ac546ad1dd410ff91e788 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000128.txt @@ -0,0 +1 @@ +-0.9946445822715759 0.05120658501982689 -0.08977368474006653 0.1077304556965828 -5.103644298287691e-07 -0.8686315417289734 -0.4954585134983063 0.5945504903793335 -0.10335099697113037 -0.49280521273612976 0.8639796376228333 -1.0367757081985474 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000129.txt b/3DTopia/assets/sample_data/pose/000129.txt new file mode 100644 index 0000000000000000000000000000000000000000..92e56dd71cc479b36a5d57724ffcd5f1186b52a8 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000129.txt @@ -0,0 +1 @@ +-0.9953508973121643 -0.04902723804116249 0.08289926499128342 -0.09948068112134933 6.332989528345934e-07 -0.8607419729232788 -0.5090413689613342 0.6108497381210327 0.0963117778301239 -0.5066748857498169 0.8567403554916382 -1.0280886888504028 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000130.txt b/3DTopia/assets/sample_data/pose/000130.txt new file mode 100644 index 0000000000000000000000000000000000000000..ecd4e0ee7c552170eae6ae7b1b9ad3efdb114cfb --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000130.txt @@ -0,0 +1 @@ +-0.9563760161399841 -0.1526419222354889 0.24908897280693054 -0.2989071309566498 7.450580596923828e-09 -0.8526401519775391 -0.5224985480308533 0.6269983053207397 0.2921384572982788 -0.49970507621765137 0.8154445886611938 -0.9785334467887878 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000131.txt b/3DTopia/assets/sample_data/pose/000131.txt new file mode 100644 index 0000000000000000000000000000000000000000..37207cc643e3c0f610c452f13e7e4cb727e350a1 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000131.txt @@ -0,0 +1 @@ +-0.8792728781700134 -0.2552239000797272 0.4021683931350708 -0.48260238766670227 7.450576333667414e-08 -0.844327986240387 -0.5358267426490784 0.6429921984672546 0.476317822933197 -0.47113797068595886 0.7423946261405945 -0.8908737301826477 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000132.txt b/3DTopia/assets/sample_data/pose/000132.txt new file mode 100644 index 0000000000000000000000000000000000000000..8c91ab4790a62fe18fb51c98bcd3267cee3b96f2 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000132.txt @@ -0,0 +1 @@ +-0.767116367816925 -0.35220232605934143 0.536176860332489 -0.6434125900268555 -0.0 -0.8358073830604553 -0.5490227341651917 0.6588274240493774 0.6415076851844788 -0.4211643934249878 0.6411615014076233 -0.7693938612937927 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000133.txt b/3DTopia/assets/sample_data/pose/000133.txt new file mode 100644 index 0000000000000000000000000000000000000000..663c86d135a5adaf022d0d4181a456face2bf48d --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000133.txt @@ -0,0 +1 @@ +-0.6243770122528076 -0.4390561878681183 0.6460515856742859 -0.7752619981765747 5.960462345910855e-08 -0.8270806670188904 -0.5620833039283752 0.674500048160553 0.7811228632926941 -0.3509519398212433 0.5164101123809814 -0.6196922063827515 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000134.txt b/3DTopia/assets/sample_data/pose/000134.txt new file mode 100644 index 0000000000000000000000000000000000000000..ce2f7e85679bd8570b11952368e620fac314fdad --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000134.txt @@ -0,0 +1 @@ +-0.45674604177474976 -0.5115229487419128 0.7278236150741577 -0.8733885884284973 4.470348002882929e-08 -0.8181498050689697 -0.5750052332878113 0.6900063753128052 0.8895970582962036 -0.26263129711151123 0.37368670105934143 -0.4484238922595978 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000135.txt b/3DTopia/assets/sample_data/pose/000135.txt new file mode 100644 index 0000000000000000000000000000000000000000..d9ee972511034a81abeb0ce52a353c6bfa118964 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000135.txt @@ -0,0 +1 @@ +-0.27090585231781006 -0.5658054351806641 0.7787646651268005 -0.9345173835754395 1.4901162970204496e-08 -0.8090171217918396 -0.5877851843833923 0.7053421139717102 0.962605893611908 -0.159234419465065 0.21916747093200684 -0.26300084590911865 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000136.txt b/3DTopia/assets/sample_data/pose/000136.txt new file mode 100644 index 0000000000000000000000000000000000000000..54f5dc33b1954e1c00dd58012f9a07f7837d1219 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000136.txt @@ -0,0 +1 @@ +-0.0742654800415039 -0.5987615585327148 0.7974767088890076 -0.9569714665412903 -7.4505797087454084e-09 -0.7996850609779358 -0.6004195213317871 0.7205042839050293 0.9972383975982666 -0.04459046944975853 0.059388987720012665 -0.07126673310995102 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000137.txt b/3DTopia/assets/sample_data/pose/000137.txt new file mode 100644 index 0000000000000000000000000000000000000000..539b9094884bbc1dfd05b284ed81410379f30ca4 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000137.txt @@ -0,0 +1 @@ +0.12533560395240784 -0.6080741286277771 0.7839240431785583 -0.9407090544700623 3.725290298461914e-09 -0.7901548743247986 -0.6129072308540344 0.7354885339736938 0.9921144247055054 0.07681908458471298 -0.09903453290462494 0.1188414990901947 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000138.txt b/3DTopia/assets/sample_data/pose/000138.txt new file mode 100644 index 0000000000000000000000000000000000000000..edd169da92400cf23729351c5e4b5cda4904e343 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000138.txt @@ -0,0 +1 @@ +0.3199399411678314 -0.5923787355422974 0.7394092082977295 -0.8872910737991333 -0.0 -0.7804303765296936 -0.6252428889274597 0.7502911686897278 0.9474378824234009 0.20004017651081085 -0.24969083070755005 0.2996290326118469 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000139.txt b/3DTopia/assets/sample_data/pose/000139.txt new file mode 100644 index 0000000000000000000000000000000000000000..ceb637943c6e219c3e9dd2250a0635aa63851ef6 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000139.txt @@ -0,0 +1 @@ +0.5017891526222229 -0.5513653755187988 0.6664860248565674 -0.7997834086418152 1.192092469182171e-07 -0.770513117313385 -0.6374239921569824 0.7649087309837341 0.8649898171424866 0.3198525011539459 -0.38663509488105774 0.4639623761177063 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000140.txt b/3DTopia/assets/sample_data/pose/000140.txt new file mode 100644 index 0000000000000000000000000000000000000000..61e450b42271cc785ec348dc49a5a61705a33b8c --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000140.txt @@ -0,0 +1 @@ +0.6636338233947754 -0.4858245551586151 0.5688273906707764 -0.6825927495956421 8.940696716308594e-08 -0.7604058980941772 -0.6494481563568115 0.7793376445770264 0.7480576038360596 0.43099576234817505 -0.5046311020851135 0.6055574417114258 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000141.txt b/3DTopia/assets/sample_data/pose/000141.txt new file mode 100644 index 0000000000000000000000000000000000000000..127403537b562462ed6dc6f8b739ecae8b1d2e5a --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000141.txt @@ -0,0 +1 @@ +0.7990213632583618 -0.3976486325263977 0.4510436952114105 -0.5412524342536926 2.9802318834981634e-08 -0.7501109838485718 -0.66131192445755 0.793574333190918 0.6013026237487793 0.5284023284912109 -0.5993546843528748 0.7192258834838867 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000142.txt b/3DTopia/assets/sample_data/pose/000142.txt new file mode 100644 index 0000000000000000000000000000000000000000..8e60c3f744ad13545f419459d0d546dac43680db --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000142.txt @@ -0,0 +1 @@ +0.9025545120239258 -0.2897827625274658 0.31846699118614197 -0.38216039538383484 -2.9802322387695312e-08 -0.7396310567855835 -0.6730126142501831 0.8076150417327881 0.43057551980018616 0.607430636882782 -0.6675573587417603 0.8010690212249756 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000143.txt b/3DTopia/assets/sample_data/pose/000143.txt new file mode 100644 index 0000000000000000000000000000000000000000..287c8aa1d607520837208c984c89f47387edd729 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000143.txt @@ -0,0 +1 @@ +0.970105767250061 -0.16612771153450012 0.17690807580947876 -0.21228961646556854 7.450580596923828e-09 -0.728968620300293 -0.6845470666885376 0.8214565515518188 0.24268268048763275 0.6640830039978027 -0.707176685333252 0.8486118316650391 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000144.txt b/3DTopia/assets/sample_data/pose/000144.txt new file mode 100644 index 0000000000000000000000000000000000000000..be8c242941a09da32614d3bb1b80f8723a7f8c48 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000144.txt @@ -0,0 +1 @@ +0.9989818334579468 -0.03139603137969971 0.03239819407463074 -0.038877833634614944 1.862645371275562e-09 -0.7181263566017151 -0.6959128975868225 0.8350954055786133 0.04511489346623421 0.6952042579650879 -0.7173951864242554 0.8608742952346802 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000145.txt b/3DTopia/assets/sample_data/pose/000145.txt new file mode 100644 index 0000000000000000000000000000000000000000..03d849a0a2e74ff583ac675767c2467801cd8f02 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000145.txt @@ -0,0 +1 @@ +0.9880316853523254 0.1090722382068634 -0.1090722382068634 0.13088670372962952 -0.0 -0.7071068286895752 -0.7071068286895752 0.848528265953064 -0.15425144135951996 0.6986439228057861 -0.6986439228057861 0.8383726477622986 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000146.txt b/3DTopia/assets/sample_data/pose/000146.txt new file mode 100644 index 0000000000000000000000000000000000000000..76962ce71616e1212648a9cc68c9df8d1af1bbb0 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000146.txt @@ -0,0 +1 @@ +0.9376917481422424 0.2495262175798416 -0.24180757999420166 0.2901691794395447 2.980232594040899e-08 -0.6959127187728882 -0.7181264758110046 0.8617516160011292 -0.3474683463573456 0.6733812093734741 -0.6525515913963318 0.7830621004104614 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000147.txt b/3DTopia/assets/sample_data/pose/000147.txt new file mode 100644 index 0000000000000000000000000000000000000000..d94caf48ba11227454895acbe160a5c35fe08b24 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000147.txt @@ -0,0 +1 @@ +0.8499690294265747 0.38404446840286255 -0.3606417775154114 0.4327700436115265 -1.4901161193847656e-08 -0.6845471262931824 -0.728968620300293 0.8747624158859253 -0.5268326997756958 0.619600772857666 -0.5818438529968262 0.6982126832008362 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000148.txt b/3DTopia/assets/sample_data/pose/000148.txt new file mode 100644 index 0000000000000000000000000000000000000000..0521e665ca29959ad23ce635da508f1e673821ce --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000148.txt @@ -0,0 +1 @@ +0.7283607721328735 0.5067906975746155 -0.46114397048950195 0.553372859954834 2.9802322387695312e-08 -0.6730124950408936 -0.7396311163902283 0.8875573873519897 -0.6851938366889954 0.5387182831764221 -0.4901959300041199 0.588235080242157 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000149.txt b/3DTopia/assets/sample_data/pose/000149.txt new file mode 100644 index 0000000000000000000000000000000000000000..247211ba513d96f8a8bbfd06470b338387f490b6 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000149.txt @@ -0,0 +1 @@ +0.5777150988578796 0.6122695207595825 -0.5397883057594299 0.6477459669113159 5.960465188081798e-08 -0.6613119840621948 -0.7501110434532166 0.9001331925392151 -0.8162385821342468 0.4333503842353821 -0.38204991817474365 0.4584598243236542 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000150.txt b/3DTopia/assets/sample_data/pose/000150.txt new file mode 100644 index 0000000000000000000000000000000000000000..6e4259f98d11229016597f447fda80f832b2ebd5 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000150.txt @@ -0,0 +1 @@ +0.4040374457836151 0.6955756545066833 -0.5940772891044617 0.7128933072090149 -2.2351731843173184e-07 -0.6494478583335876 -0.760405957698822 0.912487268447876 -0.9147422909736633 0.30723246932029724 -0.26240116357803345 0.3148817718029022 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000151.txt b/3DTopia/assets/sample_data/pose/000151.txt new file mode 100644 index 0000000000000000000000000000000000000000..e51fabb814d266646dedf7cca87b970ef7a0c5a1 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000151.txt @@ -0,0 +1 @@ +0.21425242722034454 0.7526208758354187 -0.622621476650238 0.7471462488174438 -5.215405352032576e-08 -0.6374234557151794 -0.7705134749412537 0.924615740776062 -0.9767781496047974 0.1650843769311905 -0.1365695297718048 0.16388361155986786 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000152.txt b/3DTopia/assets/sample_data/pose/000152.txt new file mode 100644 index 0000000000000000000000000000000000000000..02f0d719f3488d06de0c57f81847788a949e7740 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000152.txt @@ -0,0 +1 @@ +0.01592571847140789 0.7803348898887634 -0.6251589059829712 0.7501961588859558 -6.938351759799843e-08 -0.6252382397651672 -0.7804338932037354 0.936516523361206 -0.9998730421066284 0.01242893747985363 -0.009957351721823215 0.01194903813302517 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000153.txt b/3DTopia/assets/sample_data/pose/000153.txt new file mode 100644 index 0000000000000000000000000000000000000000..b5446b2e71626b4e006e7f84b339e1a5252fb931 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000153.txt @@ -0,0 +1 @@ +-0.18303552269935608 0.7768062353134155 -0.6025526523590088 0.7230633497238159 -1.043080928297968e-07 -0.6129070520401001 -0.7901548743247986 0.9481860995292664 -0.9831060767173767 -0.14462649822235107 0.11218380182981491 -0.13462068140506744 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000154.txt b/3DTopia/assets/sample_data/pose/000154.txt new file mode 100644 index 0000000000000000000000000000000000000000..03f3501e10aae7752f61e4b2bf11395797ab6693 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000154.txt @@ -0,0 +1 @@ +-0.3747004270553589 0.7414241433143616 -0.5566772818565369 0.6680126190185547 -4.470347292340193e-08 -0.6004204750061035 -0.799684464931488 0.9596216082572937 -0.9271458387374878 -0.2996421754360199 0.2249777913093567 -0.26997312903404236 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000155.txt b/3DTopia/assets/sample_data/pose/000155.txt new file mode 100644 index 0000000000000000000000000000000000000000..9adba264cdd3fa5131f7e00537dbbb41cd4b2341 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000155.txt @@ -0,0 +1 @@ +-0.5514266490936279 0.6749008893966675 -0.49034419655799866 0.5884130001068115 1.4901161193847656e-08 -0.5877853035926819 -0.80901700258255 0.970820426940918 -0.8342233896255493 -0.44611358642578125 0.32412049174308777 -0.38894450664520264 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000156.txt b/3DTopia/assets/sample_data/pose/000156.txt new file mode 100644 index 0000000000000000000000000000000000000000..083b436903364a1b52b5ef97e083c72d298f50ae --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000156.txt @@ -0,0 +1 @@ +-0.7061696648597717 0.5792850852012634 -0.40712860226631165 0.48855406045913696 1.4901168299275014e-08 -0.5750053524971008 -0.818149745464325 0.9817796945571899 -0.7080430388450623 -0.5777523517608643 0.4060514271259308 -0.48726147413253784 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000157.txt b/3DTopia/assets/sample_data/pose/000157.txt new file mode 100644 index 0000000000000000000000000000000000000000..0a6b93de25e0422c2661ad190d8bbd0cfadaf003 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000157.txt @@ -0,0 +1 @@ +-0.8327592015266418 0.4579005837440491 -0.3111884593963623 0.37342676520347595 -3.7252868878567824e-07 -0.5620833039283752 -0.827080488204956 0.9924965500831604 -0.553634524345398 -0.6887590885162354 0.4680800437927246 -0.5616962909698486 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000158.txt b/3DTopia/assets/sample_data/pose/000158.txt new file mode 100644 index 0000000000000000000000000000000000000000..f32d22f520ff6c3fa804bc923787587db83ff477 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000158.txt @@ -0,0 +1 @@ +-0.9261496067047119 0.3152289092540741 -0.2070665806531906 0.24848027527332306 -7.45057349149647e-08 -0.5490228533744812 -0.835807204246521 1.0029689073562622 -0.3771549463272095 -0.7740828990936279 0.5084771513938904 -0.6101730465888977 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000159.txt b/3DTopia/assets/sample_data/pose/000159.txt new file mode 100644 index 0000000000000000000000000000000000000000..d059c9453bb4b6d6036e6357d94b15aae3cc8035 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000159.txt @@ -0,0 +1 @@ +-0.9826175570487976 0.15674014389514923 -0.09946972131729126 0.11936488747596741 -4.917378646496218e-07 -0.5358267426490784 -0.8443277478218079 1.0131934881210327 -0.18563862144947052 -0.8296516537666321 0.5265126824378967 -0.6318156123161316 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000160.txt b/3DTopia/assets/sample_data/pose/000160.txt new file mode 100644 index 0000000000000000000000000000000000000000..47ea199c68e79262de73bd5b2f7a84c18a7c2b2a --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000160.txt @@ -0,0 +1 @@ +-0.9999117851257324 -0.011320343241095543 0.006937115918844938 -0.008324497379362583 -0.0 -0.522498607635498 -0.8526401519775391 1.0231682062149048 0.013276812620460987 -0.8525649905204773 0.5224524140357971 -0.626943051815033 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000161.txt b/3DTopia/assets/sample_data/pose/000161.txt new file mode 100644 index 0000000000000000000000000000000000000000..30010e886c7d3b90c795b97930ecd7b672f6843f --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000161.txt @@ -0,0 +1 @@ +-0.9773425459861755 -0.18218766152858734 0.10774486511945724 -0.12929487228393555 5.662440685227921e-07 -0.5090415477752686 -0.8607418537139893 1.0328903198242188 0.2116631716489792 -0.8412397503852844 0.49750807881355286 -0.5970093607902527 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000162.txt b/3DTopia/assets/sample_data/pose/000162.txt new file mode 100644 index 0000000000000000000000000000000000000000..18de1b495916f082a77529102796b17272bfb15b --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000162.txt @@ -0,0 +1 @@ +-0.9158093929290771 -0.3488530218601227 0.1989825814962387 -0.2387789785861969 -2.6822073095900123e-07 -0.49545881152153015 -0.8686313033103943 1.0423578023910522 0.40161237120628357 -0.7955010533332825 0.45374563336372375 -0.5444949865341187 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000163.txt b/3DTopia/assets/sample_data/pose/000163.txt new file mode 100644 index 0000000000000000000000000000000000000000..0ffb1eb0f1f1b14928f8608d55d7caf394f7dfd6 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000163.txt @@ -0,0 +1 @@ +-0.8177659511566162 -0.5043586492538452 0.2772735059261322 -0.3327282667160034 1.4901154088420299e-08 -0.48175370693206787 -0.8763065338134766 1.051567792892456 0.575550377368927 -0.7166138887405396 0.39396175742149353 -0.47275421023368835 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000164.txt b/3DTopia/assets/sample_data/pose/000164.txt new file mode 100644 index 0000000000000000000000000000000000000000..d6dbce3289f8ad6e960998dd2bf14f9b2d1d3aa7 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000164.txt @@ -0,0 +1 @@ +-0.6871210336685181 -0.6420933604240417 0.33997073769569397 -0.4079653024673462 3.1292415769712534e-07 -0.4679299294948578 -0.8837653398513794 1.0605188608169556 0.7265424728393555 -0.6072539687156677 0.3215245306491852 -0.385829359292984 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000165.txt b/3DTopia/assets/sample_data/pose/000165.txt new file mode 100644 index 0000000000000000000000000000000000000000..9fcf6f650fa615f11a899719603cd1ce69db335e --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000165.txt @@ -0,0 +1 @@ +-0.5290825366973877 -0.7560816407203674 0.3852425813674927 -0.4622913897037506 1.3411039390121005e-07 -0.4539903998374939 -0.8910064697265625 1.069207787513733 0.8485701680183411 -0.47141605615615845 0.24019837379455566 -0.28823819756507874 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000166.txt b/3DTopia/assets/sample_data/pose/000166.txt new file mode 100644 index 0000000000000000000000000000000000000000..87b4b3e158ffe731093160489ff82a88974dddf2 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000166.txt @@ -0,0 +1 @@ +-0.3499513566493988 -0.8412432074546814 0.4121206998825073 -0.49454501271247864 2.0861612881617475e-07 -0.4399392604827881 -0.8980274200439453 1.0776331424713135 0.936767578125 -0.3142661154270172 0.15395735204219818 -0.18474876880645752 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000167.txt b/3DTopia/assets/sample_data/pose/000167.txt new file mode 100644 index 0000000000000000000000000000000000000000..1570aaacef31dad805db29f459e59a1d671fedfb --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000167.txt @@ -0,0 +1 @@ +-0.15686866641044617 -0.8936248421669006 0.42050766944885254 -0.5046095252037048 8.940696005765858e-08 -0.4257791042327881 -0.9048271179199219 1.0857925415039062 0.9876194596290588 -0.14193904399871826 0.06679143011569977 -0.08014968037605286 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000168.txt b/3DTopia/assets/sample_data/pose/000168.txt new file mode 100644 index 0000000000000000000000000000000000000000..28e1bdfd3ed2fa47360b21a91ddf4dc100c462df --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000168.txt @@ -0,0 +1 @@ +0.04246791824698448 -0.910581648349762 0.4111417531967163 -0.4933716952800751 2.60770320892334e-08 -0.41151300072669983 -0.9114038348197937 1.0936838388442993 0.9990978240966797 0.03870541974902153 -0.017476091161370277 0.020971447229385376 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000169.txt b/3DTopia/assets/sample_data/pose/000169.txt new file mode 100644 index 0000000000000000000000000000000000000000..b9de0ca4e16008afcb9074467e065abac6a06e1f --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000169.txt @@ -0,0 +1 @@ +0.2401116043329239 -0.890906035900116 0.3855292797088623 -0.462635338306427 -1.4901161193847656e-08 -0.3971477150917053 -0.9177546501159668 1.1013054847717285 0.970745325088501 0.2203635573387146 -0.09535978734493256 0.11443176120519638 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000170.txt b/3DTopia/assets/sample_data/pose/000170.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca76a0491d972790bc92a6ff2e60ddd24604f2c2 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000170.txt @@ -0,0 +1 @@ +0.42818257212638855 -0.834902822971344 0.34582775831222534 -0.41499361395835876 -2.9802318834981634e-08 -0.3826831877231598 -0.9238796234130859 1.1086554527282715 0.9036921262741089 0.39558911323547363 -0.1638583242893219 0.1966300755739212 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000171.txt b/3DTopia/assets/sample_data/pose/000171.txt new file mode 100644 index 0000000000000000000000000000000000000000..d018746d105d9ca7bc289bd82187832022b2a396 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000171.txt @@ -0,0 +1 @@ +0.5991833806037903 -0.744390070438385 0.29472458362579346 -0.35366982221603394 1.4901157641133977e-08 -0.3681242763996124 -0.9297765493392944 1.1157318353652954 0.8006117343902588 0.557106614112854 -0.22057394683361053 0.26468899846076965 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000172.txt b/3DTopia/assets/sample_data/pose/000172.txt new file mode 100644 index 0000000000000000000000000000000000000000..02c9cdd36ee85a463ca7c1fd981ecd6b5a8c12f9 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000172.txt @@ -0,0 +1 @@ +0.7462967038154602 -0.6226441264152527 0.23527754843235016 -0.282333105802536 2.9802318834981634e-08 -0.35347476601600647 -0.9354439973831177 1.1225329637527466 0.6656134724617004 0.6981187462806702 -0.26379701495170593 0.31655651330947876 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000173.txt b/3DTopia/assets/sample_data/pose/000173.txt new file mode 100644 index 0000000000000000000000000000000000000000..0174b2409348ae3a8a2ca1c7cf957d847ee5a17f --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000173.txt @@ -0,0 +1 @@ +0.8636573553085327 -0.47427839040756226 0.1707507073879242 -0.20490090548992157 -0.0 -0.3387379050254822 -0.9408807158470154 1.1290568113327026 0.5040791630744934 0.8125985264778137 -0.29255348443984985 0.3510642647743225 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000174.txt b/3DTopia/assets/sample_data/pose/000174.txt new file mode 100644 index 0000000000000000000000000000000000000000..591a8dee55f6948e83a456fa86296bf3550aa528 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000174.txt @@ -0,0 +1 @@ +0.9465868473052979 -0.30506426095962524 0.10444684326648712 -0.12533621490001678 1.4901162970204496e-08 -0.32391735911369324 -0.9460853934288025 1.1353024244308472 0.3224489986896515 0.8955519795417786 -0.306615948677063 0.3679392337799072 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000175.txt b/3DTopia/assets/sample_data/pose/000175.txt new file mode 100644 index 0000000000000000000000000000000000000000..e03ff318cf1447e40b5c296bba185eb54fe3663a --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000175.txt @@ -0,0 +1 @@ +0.9917788505554199 -0.1217007040977478 0.03954293206334114 -0.04745154082775116 -3.725290298461914e-09 -0.3090168833732605 -0.9510565400123596 1.141268014907837 0.12796369194984436 0.9432377815246582 -0.3064764142036438 0.36777186393737793 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000176.txt b/3DTopia/assets/sample_data/pose/000176.txt new file mode 100644 index 0000000000000000000000000000000000000000..a03b53267607cc86cfb041843ebf308eb288e916 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000176.txt @@ -0,0 +1 @@ +0.997431755065918 0.06845686584711075 -0.021060077473521233 0.025272099301218987 -0.0 -0.2940402626991272 -0.955793023109436 1.146951675415039 -0.07162310928106308 0.9533383250236511 -0.2932851016521454 0.35194218158721924 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000177.txt b/3DTopia/assets/sample_data/pose/000177.txt new file mode 100644 index 0000000000000000000000000000000000000000..f577381efeecf95d9bb59657c115ba47f8ceb97a --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000177.txt @@ -0,0 +1 @@ +0.9633201360702515 0.257699191570282 -0.07486848533153534 0.08984224498271942 -6.705520405603238e-08 -0.27899086475372314 -0.9602935314178467 1.1523525714874268 -0.26835453510284424 0.9250701665878296 -0.26875752210617065 0.3225092887878418 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000178.txt b/3DTopia/assets/sample_data/pose/000178.txt new file mode 100644 index 0000000000000000000000000000000000000000..8baf4f70fb1b4b8c16f7f13212a36eff7ffd40e7 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000178.txt @@ -0,0 +1 @@ +0.8908040523529053 0.4382828176021576 -0.1199006661772728 0.143880695104599 -0.0 -0.2638731598854065 -0.9645572900772095 1.1574686765670776 -0.45438748598098755 0.8592315912246704 -0.23505929112434387 0.2820710241794586 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000179.txt b/3DTopia/assets/sample_data/pose/000179.txt new file mode 100644 index 0000000000000000000000000000000000000000..79dd4508d344313742c389fbc5eb1794defab2a7 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000179.txt @@ -0,0 +1 @@ +0.7827745079994202 0.6027545928955078 -0.15476103127002716 0.1857132613658905 -1.4901161193847656e-08 -0.24868980050086975 -0.9685832262039185 1.1622997522354126 -0.6223054528236389 0.7581822276115417 -0.1946680247783661 0.23360170423984528 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000180.txt b/3DTopia/assets/sample_data/pose/000180.txt new file mode 100644 index 0000000000000000000000000000000000000000..14c85ed729cb52a1e206fb6e5c39309874fe9426 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000180.txt @@ -0,0 +1 @@ +0.6435381174087524 0.7442656755447388 -0.17868220806121826 0.21441881358623505 -5.960463766996327e-08 -0.23344513773918152 -0.972369909286499 1.1668438911437988 -0.7654140591621399 0.625757098197937 -0.15023081004619598 0.18027718365192413 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000181.txt b/3DTopia/assets/sample_data/pose/000181.txt new file mode 100644 index 0000000000000000000000000000000000000000..a668f50004a5a71ff3758a9f2890a8dd86d4458f --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000181.txt @@ -0,0 +1 @@ +0.4786456823348999 0.8568629026412964 -0.1915312558412552 0.2298377901315689 -1.4901159417490817e-08 -0.21814295649528503 -0.9759168028831482 1.1711000204086304 -0.8780080676078796 0.46711835265159607 -0.10441319644451141 0.12529604136943817 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000182.txt b/3DTopia/assets/sample_data/pose/000182.txt new file mode 100644 index 0000000000000000000000000000000000000000..314070f0cd1522f64866ede307832dab2294bdcc --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000182.txt @@ -0,0 +1 @@ +0.2946713864803314 0.9357439875602722 -0.19378307461738586 0.23253990709781647 2.607702853651972e-08 -0.2027871459722519 -0.9792227745056152 1.17506742477417 -0.9555985927581787 0.28854894638061523 -0.05975561589002609 0.07170679420232773 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000183.txt b/3DTopia/assets/sample_data/pose/000183.txt new file mode 100644 index 0000000000000000000000000000000000000000..366bcc215c52775da6e8e4544bc05fedaa1d21be --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000183.txt @@ -0,0 +1 @@ +0.09894955158233643 0.9774670600891113 -0.186459481716156 0.22375409305095673 4.6566128730773926e-08 -0.18737904727458954 -0.9822876453399658 1.1787446737289429 -0.9950924515724182 0.09719691425561905 -0.01854112185537815 0.022249583154916763 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000184.txt b/3DTopia/assets/sample_data/pose/000184.txt new file mode 100644 index 0000000000000000000000000000000000000000..a12478cd73cfad215240a993b54e9e0cd6fe16ff --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000184.txt @@ -0,0 +1 @@ +-0.10071694850921631 0.9801000952720642 -0.17105454206466675 0.2052658498287201 -6.89178563106907e-08 -0.17192883789539337 -0.985109269618988 1.1821311712265015 -0.9949150085449219 -0.0992172583937645 0.01731616072356701 -0.0207794401794672 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000185.txt b/3DTopia/assets/sample_data/pose/000185.txt new file mode 100644 index 0000000000000000000000000000000000000000..6622c985b9bf721d1e955e2e63acf744a0c6770d --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000185.txt @@ -0,0 +1 @@ +-0.29636847972869873 0.9433149099349976 -0.14940685033798218 0.17928773164749146 -1.527368311826649e-07 -0.15643510222434998 -0.9876880645751953 1.1852260828018188 -0.9550734162330627 -0.2927198112010956 0.046362414956092834 -0.05563471466302872 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000186.txt b/3DTopia/assets/sample_data/pose/000186.txt new file mode 100644 index 0000000000000000000000000000000000000000..d8204fa64ab060476fc74f601d18278867097d68 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000186.txt @@ -0,0 +1 @@ +-0.4802047610282898 0.8684056997299194 -0.12359234690666199 0.14831088483333588 -1.527369306586479e-07 -0.14090117812156677 -0.9900237321853638 1.188028335571289 -0.8771565556526184 -0.4754140377044678 0.06766162067651749 -0.0811937153339386 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000187.txt b/3DTopia/assets/sample_data/pose/000187.txt new file mode 100644 index 0000000000000000000000000000000000000000..c5a6e6b7b897650d9f808da3b0f6bc606746ee74 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000187.txt @@ -0,0 +1 @@ +-0.6448968648910522 0.7582430243492126 -0.0957883819937706 0.11494607478380203 -1.4156101713069802e-07 -0.12533338367938995 -0.9921146631240845 1.190537691116333 -0.7642695903778076 -0.6398116946220398 0.08082716166973114 -0.0969923734664917 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000188.txt b/3DTopia/assets/sample_data/pose/000188.txt new file mode 100644 index 0000000000000000000000000000000000000000..da3d3f26f0049ad7d135f80d80876ffc78f99cca --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000188.txt @@ -0,0 +1 @@ +-0.783878743648529 0.6171642541885376 -0.06813523918390274 0.08176270127296448 -3.2782554626464844e-07 -0.10973420739173889 -0.9939610362052917 1.1927533149719238 -0.6209139823913574 -0.7791448831558228 0.08601849526166916 -0.1032220870256424 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000189.txt b/3DTopia/assets/sample_data/pose/000189.txt new file mode 100644 index 0000000000000000000000000000000000000000..aa24f22623a4f6c44d69320e7361fbe5d1dcff29 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000189.txt @@ -0,0 +1 @@ +-0.8916096687316895 0.45079466700553894 -0.04261254519224167 0.05113517865538597 -8.19563368281706e-08 -0.09410841017961502 -0.9955618381500244 1.1946742534637451 -0.45280423760414124 -0.8876528143882751 0.08390773087739944 -0.10068946331739426 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000190.txt b/3DTopia/assets/sample_data/pose/000190.txt new file mode 100644 index 0000000000000000000000000000000000000000..1c7b725f4f8bf00720fef46d91da5ea9c1df6038 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000190.txt @@ -0,0 +1 @@ +-0.9637949466705322 0.2658207416534424 -0.020919324830174446 0.02510467730462551 -1.2833611435780767e-06 -0.07845940440893173 -0.9969170093536377 1.196300745010376 -0.26664260029792786 -0.9608241319656372 0.07561864703893661 -0.09074220806360245 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000191.txt b/3DTopia/assets/sample_data/pose/000191.txt new file mode 100644 index 0000000000000000000000000000000000000000..016bad1f7081ba063ced918647343f73f0ba3bca --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000191.txt @@ -0,0 +1 @@ +-0.9975574016571045 0.06970969587564468 -0.004386159125715494 0.005263194907456636 3.86498697935167e-07 -0.06279079616069794 -0.9980266094207764 1.1976321935653687 -0.06984754651784897 -0.9955891370773315 0.06263715028762817 -0.07516458630561829 -0.0 0.0 0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000192.txt b/3DTopia/assets/sample_data/pose/000192.txt new file mode 100644 index 0000000000000000000000000000000000000000..a1350253ca0a3fab00909d501c9b03c75a81a618 --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000192.txt @@ -0,0 +1 @@ +-0.9915500283241272 -0.1295798271894455 0.0061101061291992664 -0.007333071436733007 7.31088050542894e-07 -0.0471065416932106 -0.9988898634910583 1.1986677646636963 0.12972381711006165 -0.9904493689537048 0.04670844227075577 -0.05605006963014603 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000193.txt b/3DTopia/assets/sample_data/pose/000193.txt new file mode 100644 index 0000000000000000000000000000000000000000..c48c861782fa785615b975d66ac418dbad46d76a --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000193.txt @@ -0,0 +1 @@ +-0.9460127353668213 -0.32396966218948364 0.0101803382858634 -0.012217401526868343 8.381903739973495e-07 -0.03141067922115326 -0.9995065927505493 1.199407935142517 0.324129581451416 -0.9455459117889404 0.02971518225967884 -0.03565797209739685 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/assets/sample_data/pose/000194.txt b/3DTopia/assets/sample_data/pose/000194.txt new file mode 100644 index 0000000000000000000000000000000000000000..addc79d52f19d732eb2b17c1eb4bcca99088534b --- /dev/null +++ b/3DTopia/assets/sample_data/pose/000194.txt @@ -0,0 +1 @@ +-0.8627605438232422 -0.5055502653121948 0.007941310293972492 -0.009530180133879185 6.407498744920304e-07 -0.01570744998753071 -0.9998766183853149 1.1998521089553833 0.5056126117706299 -0.8626541495323181 0.013552011922001839 -0.016261987388134003 -0.0 0.0 -0.0 1.0 diff --git a/3DTopia/configs/default.yaml b/3DTopia/configs/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9fc30020f42cdc132b73509b0c48f8a9f58910e0 --- /dev/null +++ b/3DTopia/configs/default.yaml @@ -0,0 +1,72 @@ +model: + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + shift_scale: 2 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "triplane" + cond_stage_key: "caption" + image_size: 32 + channels: 8 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.5147210212065061 + use_ema: False + learning_rate: 5e-5 + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 + in_channels: 8 + out_channels: 8 + model_channels: 320 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + context_dim: 768 + transformer_depth: 1 + use_checkpoint: True + legacy: False + + first_stage_config: + target: model.triplane_vae.AutoencoderKLRollOut + params: + embed_dim: 8 + learning_rate: 1e-5 + norm: False + renderer_type: eg3d + ddconfig: + double_z: true + z_channels: 8 + resolution: 256 + in_channels: 32 + out_ch: 32 + ch: 128 + ch_mult: + - 2 + - 4 + - 4 + - 8 + num_res_blocks: 2 + attn_resolutions: [32] + dropout: 0.0 + lossconfig: + kl_weight: 1e-5 + rec_weight: 1 + latent_tv_weight: 2e-3 + renderer_config: + rgbnet_dim: -1 + rgbnet_width: 128 + sigma_dim: 12 + c_dim: 20 + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPTextEmbedder + diff --git a/3DTopia/environment.yml b/3DTopia/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..2415cf057c34303455a4843b7db3b96a71955666 --- /dev/null +++ b/3DTopia/environment.yml @@ -0,0 +1,150 @@ +name: 3dtopia +channels: + - pytorch + - anaconda + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - brotli=1.0.9=h166bdaf_7 + - brotli-bin=1.0.9=h166bdaf_7 + - bzip2=1.0.8=h7f98852_4 + - ca-certificates=2023.5.7=hbcca054_0 + - certifi=2023.5.7=pyhd8ed1ab_0 + - charset-normalizer=3.1.0=pyhd8ed1ab_0 + - colorama=0.4.6=pyhd8ed1ab_0 + - cudatoolkit=11.3.1=h9edb442_10 + - ffmpeg=4.3.2=hca11adc_0 + - freetype=2.10.4=h0708190_1 + - fsspec=2023.5.0=pyh1a96a4e_0 + - gmp=6.2.1=h58526e2_0 + - gnutls=3.6.13=h85f3911_1 + - idna=3.4=pyhd8ed1ab_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - jpeg=9e=h166bdaf_1 + - lame=3.100=h7f98852_1001 + - lcms2=2.12=hddcbb42_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libbrotlicommon=1.0.9=h166bdaf_7 + - libbrotlidec=1.0.9=h166bdaf_7 + - libbrotlienc=1.0.9=h166bdaf_7 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libpng=1.6.37=h21135ba_2 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtiff=4.2.0=hecacb30_2 + - libwebp-base=1.2.2=h7f98852_1 + - lightning-utilities=0.8.0=pyhd8ed1ab_0 + - lz4-c=1.9.3=h9c3ff4c_1 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py38h95df7f1_0 + - mkl_fft=1.3.1=py38h8666266_1 + - mkl_random=1.2.2=py38h1abd341_0 + - ncurses=6.4=h6a678d5_0 + - nettle=3.6=he412f7d_0 + - numpy=1.24.3=py38h14f4228_0 + - numpy-base=1.24.3=py38h31eccc5_0 + - olefile=0.46=pyh9f0ad1d_1 + - openh264=2.1.1=h780b84a_0 + - openjpeg=2.4.0=hb52868f_1 + - openssl=1.1.1u=h7f8727e_0 + - packaging=23.1=pyhd8ed1ab_0 + - pip=23.0.1=py38h06a4308_0 + - pixman-cos6-x86_64=0.32.8=4 + - pysocks=1.7.1=pyha2e5f31_6 + - python=3.8.16=h7a1cb2a_3 + - python_abi=3.8=2_cp38 + - pytorch=1.12.0=py3.8_cuda11.3_cudnn8.3.2_0 + - pytorch-lightning=2.0.2=pyhd8ed1ab_0 + - pytorch-mutex=1.0=cuda + - pyyaml=6.0=py38h0a891b7_4 + - readline=8.2=h5eee18b_0 + - requests=2.31.0=pyhd8ed1ab_0 + - setuptools=67.8.0=py38h06a4308_0 + - six=1.16.0=pyh6c4a22f_0 + - sqlite=3.41.2=h5eee18b_0 + - tk=8.6.12=h1ccaba5_0 + - torchaudio=0.12.0=py38_cu113 + - torchmetrics=0.11.4=pyhd8ed1ab_0 + - torchvision=0.13.0=py38_cu113 + - tqdm=4.65.0=pyhd8ed1ab_1 + - typing_extensions=4.6.3=pyha770c72_0 + - urllib3=2.0.2=pyhd8ed1ab_0 + - wheel=0.38.4=py38h06a4308_0 + - x264=1!161.3030=h7f98852_1 + - xorg-x11-server-common-cos6-x86_64=1.17.4=4 + - xorg-x11-server-xvfb-cos6-x86_64=1.17.4=4 + - xz=5.4.2=h5eee18b_0 + - yaml=0.2.5=h7f98852_2 + - zlib=1.2.13=h5eee18b_0 + - zstd=1.5.2=ha4553b6_0 + - pip: + - antlr4-python3-runtime==4.9.3 + - appdirs==1.4.4 + - asttokens==2.4.0 + - av==10.0.0 + - backcall==0.2.0 + - click==8.1.3 + - git+https://github.com/openai/CLIP.git + - contourpy==1.1.1 + - cycler==0.12.1 + - decorator==5.1.1 + - docker-pycreds==0.4.0 + - einops==0.6.1 + - executing==1.2.0 + - filelock==3.12.2 + - fonttools==4.43.1 + - ftfy==6.1.1 + - gitdb==4.0.10 + - gitpython==3.1.31 + - huggingface-hub==0.16.4 + - imageio==2.31.0 + - imageio-ffmpeg==0.4.8 + - importlib-resources==6.1.0 + - ipdb==0.13.13 + - ipython==8.12.2 + - jedi==0.19.0 + - kiwisolver==1.4.5 + - kornia==0.6.0 + - lpips==0.1.4 + - matplotlib==3.7.3 + - matplotlib-inline==0.1.6 + - omegaconf==2.3.0 + - open-clip-torch==2.20.0 + - opencv-python==4.7.0.72 + - parso==0.8.3 + - pathtools==0.1.2 + - pexpect==4.8.0 + - pickleshare==0.7.5 + - pillow==9.5.0 + - prompt-toolkit==3.0.39 + - protobuf==3.20.3 + - psutil==5.9.5 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - pygments==2.16.1 + - pymcubes==0.1.4 + - pyparsing==3.1.1 + - pytorch-fid==0.3.0 + - pytorch-msssim==1.0.0 + - regex==2023.6.3 + - safetensors==0.3.3 + - scipy==1.10.1 + - sentencepiece==0.1.99 + - sentry-sdk==1.25.0 + - setproctitle==1.3.2 + - smmap==5.0.0 + - stack-data==0.6.2 + - timm==0.9.7 + - tokenizers==0.12.1 + - tomli==2.0.1 + - traitlets==5.9.0 + - transformers + - trimesh==4.0.2 + - vit-pytorch==1.2.2 + - wandb==0.15.3 + - wcwidth==0.2.6 + - zipp==3.17.0 diff --git a/3DTopia/gradio_demo.py b/3DTopia/gradio_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..f02e3354ae21a91699428ffa7e7ed454d1c5832b --- /dev/null +++ b/3DTopia/gradio_demo.py @@ -0,0 +1,334 @@ +import os +import cv2 +import time +import json +import torch +import mcubes +import trimesh +import datetime +import argparse +import subprocess +import numpy as np +import gradio as gr +from tqdm import tqdm +import imageio.v2 as imageio +import pytorch_lightning as pl +from omegaconf import OmegaConf + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.dpm_solver import DPMSolverSampler + +from utility.initialize import instantiate_from_config, get_obj_from_str +from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes +from utility.triplane_renderer.renderer import get_rays, to8b +from safetensors.torch import load_file +from huggingface_hub import hf_hub_download + +import warnings +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=DeprecationWarning) + +def add_text(rgb, caption): + font = cv2.FONT_HERSHEY_SIMPLEX + # org + gap = 10 + org = (gap, gap) + # fontScale + fontScale = 0.3 + # Blue color in BGR + color = (255, 0, 0) + # Line thickness of 2 px + thickness = 1 + break_caption = [] + for i in range(len(caption) // 30 + 1): + break_caption_i = caption[i*30:(i+1)*30] + break_caption.append(break_caption_i) + for i, bci in enumerate(break_caption): + cv2.putText(rgb, bci, (gap, gap*(i+1)), font, fontScale, color, thickness, cv2.LINE_AA) + return rgb + +config = "configs/default.yaml" +# ckpt = "checkpoints/3dtopia_diffusion_state_dict.ckpt" +ckpt = hf_hub_download(repo_id="hongfz16/3DTopia", filename="model.safetensors") +configs = OmegaConf.load(config) +os.makedirs("tmp", exist_ok=True) + +if ckpt.endswith(".ckpt"): + model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(ckpt, map_location='cpu', strict=False, **configs.model.params) +elif ckpt.endswith(".safetensors"): + model = get_obj_from_str(configs.model["target"])(**configs.model.params) + model_ckpt = load_file(ckpt) + model.load_state_dict(model_ckpt) +else: + raise NotImplementedError +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") +model = model.to(device) +sampler = DDIMSampler(model) + +img_size = configs.model.params.unet_config.params.image_size +channels = configs.model.params.unet_config.params.in_channels +shape = [channels, img_size, img_size * 3] + +pose_folder = 'assets/sample_data/pose' +poses_fname = sorted([os.path.join(pose_folder, f) for f in os.listdir(pose_folder)]) +batch_rays_list = [] +H = 128 +ratio = 512 // H +for p in poses_fname: + c2w = np.loadtxt(p).reshape(4, 4) + c2w[:3, 3] *= 2.2 + c2w = np.array([ + [1, 0, 0, 0], + [0, 0, -1, 0], + [0, 1, 0, 0], + [0, 0, 0, 1] + ]) @ c2w + + k = np.array([ + [560 / ratio, 0, H * 0.5], + [0, 560 / ratio, H * 0.5], + [0, 0, 1] + ]) + + rays_o, rays_d = get_rays(H, H, torch.Tensor(k), torch.Tensor(c2w[:3, :4])) + coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, H-1, H), indexing='ij'), -1) + coords = torch.reshape(coords, [-1,2]).long() + rays_o = rays_o[coords[:, 0], coords[:, 1]] + rays_d = rays_d[coords[:, 0], coords[:, 1]] + batch_rays = torch.stack([rays_o, rays_d], 0) + batch_rays_list.append(batch_rays) +batch_rays_list = torch.stack(batch_rays_list, 0) + +def marching_cube(b, text, global_info): + # prepare volumn for marching cube + res = 128 + assert 'decode_res' in global_info + decode_res = global_info['decode_res'] + c_list = torch.linspace(-1.2, 1.2, steps=res) + grid_x, grid_y, grid_z = torch.meshgrid( + c_list, c_list, c_list, indexing='ij' + ) + coords = torch.stack([grid_x, grid_y, grid_z], -1).to(device) + plane_axes = generate_planes() + feats = sample_from_planes( + plane_axes, decode_res[b:b+1].reshape(1, 3, -1, 256, 256), coords.reshape(1, -1, 3), padding_mode='zeros', box_warp=2.4 + ) + fake_dirs = torch.zeros_like(coords) + fake_dirs[..., 0] = 1 + out = model.first_stage_model.triplane_decoder.decoder(feats, fake_dirs) + u = out['sigma'].reshape(res, res, res).detach().cpu().numpy() + del out + + # marching cube + vertices, triangles = mcubes.marching_cubes(u, 10) + min_bound = np.array([-1.2, -1.2, -1.2]) + max_bound = np.array([1.2, 1.2, 1.2]) + vertices = vertices / (res - 1) * (max_bound - min_bound)[None, :] + min_bound[None, :] + pt_vertices = torch.from_numpy(vertices).to(device) + + # extract vertices color + res_triplane = 256 + render_kwargs = { + 'depth_resolution': 128, + 'disparity_space_sampling': False, + 'box_warp': 2.4, + 'depth_resolution_importance': 128, + 'clamp_mode': 'softplus', + 'white_back': True, + 'det': True + } + rays_o_list = [ + np.array([0, 0, 2]), + np.array([0, 0, -2]), + np.array([0, 2, 0]), + np.array([0, -2, 0]), + np.array([2, 0, 0]), + np.array([-2, 0, 0]), + ] + rgb_final = None + diff_final = None + for rays_o in tqdm(rays_o_list): + rays_o = torch.from_numpy(rays_o.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device) + rays_d = pt_vertices.reshape(-1, 3) - rays_o + rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1) + dist = torch.norm(pt_vertices.reshape(-1, 3) - rays_o, dim=-1).cpu().numpy().reshape(-1) + + render_out = model.first_stage_model.triplane_decoder( + decode_res[b:b+1].reshape(1, 3, -1, res_triplane, res_triplane), + rays_o.unsqueeze(0), rays_d.unsqueeze(0), render_kwargs, + whole_img=False, tvloss=False + ) + rgb = render_out['rgb_marched'].reshape(-1, 3).detach().cpu().numpy() + depth = render_out['depth_final'].reshape(-1).detach().cpu().numpy() + depth_diff = np.abs(dist - depth) + + if rgb_final is None: + rgb_final = rgb.copy() + diff_final = depth_diff.copy() + + else: + ind = diff_final > depth_diff + rgb_final[ind] = rgb[ind] + diff_final[ind] = depth_diff[ind] + + # bgr to rgb + rgb_final = np.stack([ + rgb_final[:, 2], rgb_final[:, 1], rgb_final[:, 0] + ], -1) + + # export to ply + mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=(rgb_final * 255).astype(np.uint8)) + path = os.path.join('tmp', f"{text.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.ply") + trimesh.exchange.export.export_mesh(mesh, path, file_type='ply') + + del vertices, triangles, rgb_final + torch.cuda.empty_cache() + + return path + +def infer(prompt, samples, steps, scale, seed, global_info): + prompt = prompt.replace('/', '') + pl.seed_everything(seed) + batch_size = samples + with torch.no_grad(): + noise = None + c = model.get_learned_conditioning([prompt]) + unconditional_c = torch.zeros_like(c) + sample, _ = sampler.sample( + S=steps, + batch_size=batch_size, + shape=shape, + verbose=False, + x_T = noise, + conditioning = c.repeat(batch_size, 1, 1), + unconditional_guidance_scale=scale, + unconditional_conditioning=unconditional_c.repeat(batch_size, 1, 1) + ) + decode_res = model.decode_first_stage(sample) + + big_video_list = [] + + global_info['decode_res'] = decode_res + + for b in range(batch_size): + def render_img(v): + rgb_sample, _ = model.first_stage_model.render_triplane_eg3d_decoder( + decode_res[b:b+1], batch_rays_list[v:v+1].to(device), torch.zeros(1, H, H, 3).to(device), + ) + rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0] + rgb_sample = np.stack( + [rgb_sample[..., 2], rgb_sample[..., 1], rgb_sample[..., 0]], -1 + ) + rgb_sample = add_text(rgb_sample, str(b)) + return rgb_sample + + view_num = len(batch_rays_list) + video_list = [] + for v in tqdm(range(view_num//8*3, view_num//8*5, 2)): + rgb_sample = render_img(v) + video_list.append(rgb_sample) + big_video_list.append(video_list) + # if batch_size == 2: + # cat_video_list = [ + # np.concatenate([big_video_list[j][i] for j in range(len(big_video_list))], 1) \ + # for i in range(len(big_video_list[0])) + # ] + # elif batch_size > 2: + # if batch_size == 3: + # big_video_list.append( + # [np.zeros_like(f) for f in big_video_list[0]] + # ) + # cat_video_list = [ + # np.concatenate([ + # np.concatenate([big_video_list[0][i], big_video_list[1][i]], 1), + # np.concatenate([big_video_list[2][i], big_video_list[3][i]], 1), + # ], 0) \ + # for i in range(len(big_video_list[0])) + # ] + # else: + # cat_video_list = big_video_list[0] + + for _ in range(4 - batch_size): + big_video_list.append( + [np.zeros_like(f) + 255 for f in big_video_list[0]] + ) + cat_video_list = [ + np.concatenate([ + np.concatenate([big_video_list[0][i], big_video_list[1][i]], 1), + np.concatenate([big_video_list[2][i], big_video_list[3][i]], 1), + ], 0) \ + for i in range(len(big_video_list[0])) + ] + + path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4" + imageio.mimwrite(path, np.stack(cat_video_list, 0)) + + return global_info, path + +def infer_stage2(prompt, selection, seed, global_info): + prompt = prompt.replace('/', '') + mesh_path = marching_cube(int(selection), prompt, global_info) + mesh_name = mesh_path.split('/')[-1][:-4] + + if2_cmd = f"threefiner if2 --mesh {mesh_path} --prompt \"{prompt}\" --outdir tmp --save {mesh_name}_if2.glb --text_dir --front_dir=-y" + print(if2_cmd) + # os.system(if2_cmd) + subprocess.Popen(if2_cmd, shell=True).wait() + torch.cuda.empty_cache() + + video_path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4" + render_cmd = f"kire {os.path.join('tmp', mesh_name + '_if2.glb')} --save_video {video_path} --wogui --force_cuda_rast --H 256 --W 256" + print(render_cmd) + # os.system(render_cmd) + subprocess.Popen(render_cmd, shell=True).wait() + torch.cuda.empty_cache() + + return video_path, os.path.join('tmp', mesh_name + '_if2.glb') + +block = gr.Blocks() + +with block: + global_info = gr.State(dict()) + with gr.Row(): + with gr.Column(): + with gr.Row(): + text = gr.Textbox( + label = "Enter your prompt", + max_lines = 1, + placeholder = "Enter your prompt", + container = False, + ) + btn = gr.Button("Generate 3D") + gallery = gr.Video(height=512) + advanced_button = gr.Button("Advanced options", elem_id="advanced-btn") + with gr.Row(elem_id="advanced-options"): + samples = gr.Slider(label="Number of Samples", minimum=1, maximum=4, value=4, step=1) + steps = gr.Slider(label="Steps", minimum=1, maximum=500, value=50, step=1) + scale = gr.Slider( + label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1 + ) + seed = gr.Slider( + label="Seed", + minimum=0, + maximum=2147483647, + step=1, + randomize=True, + ) + gr.on([text.submit, btn.click], infer, inputs=[text, samples, steps, scale, seed, global_info], outputs=[global_info, gallery]) + advanced_button.click( + None, + [], + text, + ) + with gr.Column(): + with gr.Row(): + dropdown = gr.Dropdown( + ['0', '1', '2', '3'], label="Choose a Candidate For Stage2", value='0' + ) + btn_stage2 = gr.Button("Start Refinement") + gallery = gr.Video(height=512) + download = gr.File(label="Download Mesh", file_count="single", height=100) + gr.on([btn_stage2.click], infer_stage2, inputs=[text, dropdown, seed, global_info], outputs=[gallery, download]) + +block.launch(share=True) diff --git a/3DTopia/ldm/data/__init__.py b/3DTopia/ldm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/3DTopia/ldm/data/base.py b/3DTopia/ldm/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..b196c2f7aa583a3e8bc4aad9f943df0c4dae0da7 --- /dev/null +++ b/3DTopia/ldm/data/base.py @@ -0,0 +1,23 @@ +from abc import abstractmethod +from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset + + +class Txt2ImgIterableBaseDataset(IterableDataset): + ''' + Define an interface to make the IterableDatasets for text2img data chainable + ''' + def __init__(self, num_records=0, valid_ids=None, size=256): + super().__init__() + self.num_records = num_records + self.valid_ids = valid_ids + self.sample_ids = valid_ids + self.size = size + + print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + + def __len__(self): + return self.num_records + + @abstractmethod + def __iter__(self): + pass \ No newline at end of file diff --git a/3DTopia/ldm/data/imagenet.py b/3DTopia/ldm/data/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..1c473f9c6965b22315dbb289eff8247c71bdc790 --- /dev/null +++ b/3DTopia/ldm/data/imagenet.py @@ -0,0 +1,394 @@ +import os, yaml, pickle, shutil, tarfile, glob +import cv2 +import albumentations +import PIL +import numpy as np +import torchvision.transforms.functional as TF +from omegaconf import OmegaConf +from functools import partial +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset, Subset + +import taming.data.utils as tdu +from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve +from taming.data.imagenet import ImagePaths + +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light + + +def synset2idx(path_to_yaml="data/index_synset.yaml"): + with open(path_to_yaml) as f: + di2s = yaml.load(f) + return dict((v,k) for k,v in di2s.items()) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config or OmegaConf.create() + if not type(self.config)==dict: + self.config = OmegaConf.to_container(self.config) + self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) + self.process_images = True # if False we skip loading & processing images and self.data contains filepaths + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._prepare_human_to_integer_label() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _prepare_human_to_integer_label(self): + URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" + self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") + if (not os.path.exists(self.human2integer)): + download(URL, self.human2integer) + with open(self.human2integer, "r") as f: + lines = f.read().splitlines() + assert len(lines) == 1000 + self.human2integer_dict = dict() + for line in lines: + value, key = line.split(":") + self.human2integer_dict[key] = int(value) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + if not self.keep_orig_class_label: + self.class_labels = [class_dict[s] for s in self.synsets] + else: + self.class_labels = [self.synset2idx[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + + if self.process_images: + self.size = retrieve(self.config, "size", default=256) + self.data = ImagePaths(self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) + else: + self.data = self.abspaths + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.process_images = process_images + self.data_root = data_root + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.data_root = data_root + self.process_images = process_images + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + + +class ImageNetSR(Dataset): + def __init__(self, size=None, + degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., + random_crop=True): + """ + Imagenet Superresolution Dataloader + Performs following ops in order: + 1. crops a crop of size s from image either as random or center crop + 2. resizes crop to size with cv2.area_interpolation + 3. degrades resized crop with degradation_fn + + :param size: resizing to size after cropping + :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light + :param downscale_f: Low Resolution Downsample factor + :param min_crop_f: determines crop size s, + where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) + :param max_crop_f: "" + :param data_root: + :param random_crop: + """ + self.base = self.get_base() + assert size + assert (size / downscale_f).is_integer() + self.size = size + self.LR_size = int(size / downscale_f) + self.min_crop_f = min_crop_f + self.max_crop_f = max_crop_f + assert(max_crop_f <= 1.) + self.center_crop = not random_crop + + self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) + + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + + if degradation == "bsrgan": + self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) + + elif degradation == "bsrgan_light": + self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) + + else: + interpolation_fn = { + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, + }[degradation] + + self.pil_interpolation = degradation.startswith("pil_") + + if self.pil_interpolation: + self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) + + else: + self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, + interpolation=interpolation_fn) + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = Image.open(example["file_path_"]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + image = np.array(image).astype(np.uint8) + + min_side_len = min(image.shape[:2]) + crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) + crop_side_len = int(crop_side_len) + + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) + + else: + self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) + + image = self.cropper(image=image)["image"] + image = self.image_rescaler(image=image)["image"] + + if self.pil_interpolation: + image_pil = PIL.Image.fromarray(image) + LR_image = self.degradation_process(image_pil) + LR_image = np.array(LR_image).astype(np.uint8) + + else: + LR_image = self.degradation_process(image=image)["image"] + + example["image"] = (image/127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + + return example + + +class ImageNetSRTrain(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_train_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetTrain(process_images=False,) + return Subset(dset, indices) + + +class ImageNetSRValidation(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_val_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetValidation(process_images=False,) + return Subset(dset, indices) diff --git a/3DTopia/ldm/data/lsun.py b/3DTopia/ldm/data/lsun.py new file mode 100644 index 0000000000000000000000000000000000000000..6256e45715ff0b57c53f985594d27cbbbff0e68e --- /dev/null +++ b/3DTopia/ldm/data/lsun.py @@ -0,0 +1,92 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class LSUNBase(Dataset): + def __init__(self, + txt_file, + data_root, + size=None, + interpolation="bicubic", + flip_p=0.5 + ): + self.data_paths = txt_file + self.data_root = data_root + with open(self.data_paths, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) + for l in self.image_paths], + } + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example + + +class LSUNChurchesTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) + + +class LSUNChurchesValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", + flip_p=flip_p, **kwargs) + + +class LSUNBedroomsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) + + +class LSUNBedroomsValidation(LSUNBase): + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", + flip_p=flip_p, **kwargs) + + +class LSUNCatsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) + + +class LSUNCatsValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", + flip_p=flip_p, **kwargs) diff --git a/3DTopia/ldm/lr_scheduler.py b/3DTopia/ldm/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..be39da9ca6dacc22bf3df9c7389bbb403a4a3ade --- /dev/null +++ b/3DTopia/ldm/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/3DTopia/ldm/models/autoencoder.py b/3DTopia/ldm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6a9c4f45498561953b8085981609b2a3298a5473 --- /dev/null +++ b/3DTopia/ldm/models/autoencoder.py @@ -0,0 +1,443 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/3DTopia/ldm/models/diffusion/__init__.py b/3DTopia/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/3DTopia/ldm/models/diffusion/classifier.py b/3DTopia/ldm/models/diffusion/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..67e98b9d8ffb96a150b517497ace0a242d7163ef --- /dev/null +++ b/3DTopia/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/3DTopia/ldm/models/diffusion/ddim.py b/3DTopia/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..fb31215db5c3f3f703f15987d7eee6a179c9f7ec --- /dev/null +++ b/3DTopia/ldm/models/diffusion/ddim.py @@ -0,0 +1,241 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ + extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + return x_dec \ No newline at end of file diff --git a/3DTopia/ldm/models/diffusion/ddpm.py b/3DTopia/ldm/models/diffusion/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..79a84e0632a2303d6a863e27a72e55a34fa629bb --- /dev/null +++ b/3DTopia/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1746 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import os +import wandb +import torch +import imageio +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from module.model_2d import DiagonalGaussianDistribution +from ldm.modules.distributions.distributions import normal_kl +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from utility.triplane_renderer.renderer import to8b + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + learning_rate=1e-4, + shift_scale=None, + ): + super().__init__() + assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.beta_schedule = beta_schedule + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, shift_scale=shift_scale) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + self.learning_rate = learning_rate + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, shift_scale=None): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s, shift_scale=shift_scale) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + print("Using timesteps of {}".format(self.num_timesteps)) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # print("sqrt_alphas_cumprod", np.sqrt(alphas_cumprod)) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like(self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def predict_start_from_z_and_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + elif self.parameterization == "v": + x_recon = self.predict_start_from_z_and_v(x, t, model_out) + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if isinstance(x, list): + return x + if len(x.shape) == 3: + x = x[..., None] + # x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=False, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=False, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=False, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_shift=0.0, + scale_by_std=False, + use_3daware=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + self.scale_shift = scale_shift + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + self.use_3daware = use_3daware + + self.is_test = False + + self.test_mode = None + self.test_tag = "" + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, shift_scale=None): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s, shift_scale) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.mode() + # z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * (z + self.scale_shift) + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c).float() + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox']: + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + # assert not predict_cids + # if predict_cids: + # if z.dim() == 4: + # z = torch.argmax(z.exp(), dim=1).long() + # z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + # z = rearrange(z, 'b h w c -> b c h w').contiguous() + + # import os + # import random + # import string + # z_np = z.detach().cpu().numpy() + # fname = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + '.npy' + # with open(os.path.join('/mnt/lustre/hongfangzhou.p/AE3D/tmp', fname), 'wb') as f: + # np.save(f, z_np) + + z = 1. / self.scale_factor * z - self.scale_shift + + # if hasattr(self, "split_input_params"): + # if self.split_input_params["patch_distributed_vq"]: + # ks = self.split_input_params["ks"] # eg. (128, 128) + # stride = self.split_input_params["stride"] # eg. (64, 64) + # uf = self.split_input_params["vqf"] + # bs, nc, h, w = z.shape + # if ks[0] > h or ks[1] > w: + # ks = (min(ks[0], h), min(ks[1], w)) + # print("reducing Kernel") + + # if stride[0] > h or stride[1] > w: + # stride = (min(stride[0], h), min(stride[1], w)) + # print("reducing stride") + + # fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + # z = unfold(z) # (bn, nc * prod(**ks), L) + # # 1. Reshape to img shape + # z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # # 2. apply model loop over last dim + # if isinstance(self.first_stage_model, VQModelInterface): + # output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + # force_not_quantize=predict_cids or force_not_quantize) + # for i in range(z.shape[-1])] + # else: + + # output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + # for i in range(z.shape[-1])] + + # o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + # o = o * weighting + # # Reverse 1. reshape to img shape + # o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # # stitch crops together + # decoded = fold(o) + # decoded = decoded / normalization # norm is shape (1, 1, h, w) + # return decoded + # else: + # if isinstance(self.first_stage_model, VQModelInterface): + # return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + # else: + # return self.first_stage_model.decode(z) + + # else: + # if isinstance(self.first_stage_model, VQModelInterface): + # return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + # else: + return self.first_stage_model.decode(z, unrollout=True) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z - self.scale_shift + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + # if hasattr(self, "split_input_params"): + # if self.split_input_params["patch_distributed_vq"]: + # ks = self.split_input_params["ks"] # eg. (128, 128) + # stride = self.split_input_params["stride"] # eg. (64, 64) + # df = self.split_input_params["vqf"] + # self.split_input_params['original_image_size'] = x.shape[-2:] + # bs, nc, h, w = x.shape + # if ks[0] > h or ks[1] > w: + # ks = (min(ks[0], h), min(ks[1], w)) + # print("reducing Kernel") + + # if stride[0] > h or stride[1] > w: + # stride = (min(stride[0], h), min(stride[1], w)) + # print("reducing stride") + + # fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + # z = unfold(x) # (bn, nc * prod(**ks), L) + # # Reshape to img shape + # z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + # for i in range(z.shape[-1])] + + # o = torch.stack(output_list, axis=-1) + # o = o * weighting + + # # Reverse reshape to img shape + # o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # # stitch crops together + # decoded = fold(o) + # decoded = decoded / normalization + # return decoded + + # else: + # return self.first_stage_model.encode(x) + # else: + return self.first_stage_model.encode(x, rollout=True) + + def get_norm(self, x): + norm = torch.linalg.norm(x, dim=-1, keepdim=True) + norm[norm == 0] = 1 + + assert norm.shape[-1] == 1 + assert norm.shape[0] == x.shape[0] + assert norm.shape[1] == x.shape[1] + assert x.shape[1] == 1 + + return norm + + def random_text_feature_noise(self, c): + noise = torch.randn_like(c) + # alpha = 0.999 + alpha = 1 + nc = alpha * c / self.get_norm(c) + (1 - alpha) * noise / self.get_norm(noise) + nc = nc / self.get_norm(nc) + + import random + if random.randint(0, 10) == 0: + nc[:] = 0 + nc = c + + return nc + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + # print("Random augment text feature...") + c = self.random_text_feature_noise(c) + loss = self(x, c) + return loss + + def forward(self, x, c=None, return_inter=False, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, return_inter=return_inter, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.AvgPool2d((res, 1)) + y_mp = torch.nn.AvgPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + # B, C, H, W = h.shape + # h_xy = th.cat([h[..., 0:(W//3)], h[..., (W//3):(2*W//3)].mean(-1).unsqueeze(-1).repeat(1, 1, 1, W//3), h[..., (2*W//3):W].mean(-2).unsqueeze(-2).repeat(1, 1, H, 1)], 1) + # h_xz = th.cat([h[..., (W//3):(2*W//3)], h[..., 0:(W//3)].mean(-1).unsqueeze(-1).repeat(1, 1, 1, W//3), h[..., (2*W//3):W].mean(-1).unsqueeze(-1).repeat(1, 1, 1, W//3)], 1) + # h_zy = th.cat([h[..., (2*W//3):W], h[..., 0:(W//3)].mean(-2).unsqueeze(-2).repeat(1, 1, H, 1), h[..., (W//3):(2*W//3)].mean(-2).unsqueeze(-2).repeat(1, 1, H, 1)], 1) + # h = th.cat([h_xy, h_xz, h_zy], -1) + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + print(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + print(cut_cond.shape) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + print(adapted_cond.shape) + adapted_cond = self.get_learned_conditioning(adapted_cond) + print(adapted_cond.shape) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + print(adapted_cond.shape) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + if self.use_3daware: + x_noisy_3daware = self.to3daware(x_noisy) + x_recon = self.model(x_noisy_3daware, t, **cond) + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None, return_inter=False): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t.to(self.logvar.device)].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + if return_inter: + return loss, loss_dict, self.predict_start_from_noise(x_noisy, t=t, noise=model_output) + else: + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + elif self.parameterization == "v": + x_recon = self.predict_start_from_z_and_v(x, t, model_out) + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + # if self.is_test and i % 50 == 0: + # decode_res = self.decode_first_stage(img) + # rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder( + # decode_res, self.batch_rays, self.batch_img, + # ) + # rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0] + # imageio.imwrite(os.path.join(self.logger.log_dir, "sample_process_{}.png".format(i)), rgb_sample) + # colorize_res = self.first_stage_model.to_rgb(img) + # imageio.imwrite(os.path.join(self.logger.log_dir, "sample_process_latent_{}.png".format(i)), colorize_res[0]) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size * 3) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + # x, c = self.get_input(batch, self.first_stage_key) + # self.batch_rays = batch['batch_rays'][0][1:2] + # self.batch_img = batch['img'][0][1:2] + # self.is_test = True + # self.test_schedule(x[0:1]) + # exit(0) + + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + # _, loss_dict_ema = self.shared_step(batch) + x, c = self.get_input(batch, self.first_stage_key) + _, loss_dict_ema, inter_res = self(x, c, return_inter=True) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True) + + if batch_idx < 2: + if self.num_timesteps < 1000: + x_T = self.q_sample(x_start=x[0:1], t=torch.full((1,), self.num_timesteps-1, device=x.device, dtype=torch.long), noise=torch.randn_like(x[0:1])) + print("Specifying x_T when sampling!") + else: + x_T = None + with self.ema_scope(): + res = self.sample(c, 1, shape=x[0:1].shape, x_T = x_T) + decode_res = self.decode_first_stage(res) + decode_input = self.decode_first_stage(x[:1]) + decode_output = self.decode_first_stage(inter_res[:1]) + + colorize_res = self.first_stage_model.to_rgb(res)[0] + colorize_x = self.first_stage_model.to_rgb(x[:1])[0] + # imageio.imwrite(os.path.join(self.logger.log_dir, "sample_{}_{}.png".format(batch_idx, 0)), colorize_res[0]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "gt_{}_{}.png".format(batch_idx, 0)), colorize_x[0]) + + rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder( + decode_res, batch['batch_rays'][0], batch['img'][0], + ) + rgb_input, _ = self.first_stage_model.render_triplane_eg3d_decoder( + decode_input, batch['batch_rays'][0], batch['img'][0], + ) + rgb_output, _ = self.first_stage_model.render_triplane_eg3d_decoder( + decode_output, batch['batch_rays'][0], batch['img'][0], + ) + rgb_sample = to8b(rgb_sample.detach().cpu().numpy()) + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_output = to8b(rgb_output.detach().cpu().numpy()) + + if rgb_sample.shape[0] == 1: + rgb_all = np.concatenate([rgb_sample[0], rgb_input[0], rgb_output[0]], 1) + else: + rgb_all = np.concatenate([rgb_sample[1], rgb_input[1], rgb_output[1]], 1) + + rgb_all = np.stack([rgb_all[..., 2], rgb_all[..., 1], rgb_all[..., 0]], -1) + + if self.model.conditioning_key is not None: + if self.cond_stage_key == 'img_cond': + cond_img = super().get_input(batch, self.cond_stage_key)[0].permute(1, 2, 0) + rgb_all = np.concatenate([rgb_all, to8b(cond_img.cpu().numpy())], 1) + else: + import cv2 + font = cv2.FONT_HERSHEY_SIMPLEX + # org + org = (50, 50) + # fontScale + fontScale = 1 + # Blue color in BGR + color = (255, 0, 0) + # Line thickness of 2 px + thickness = 2 + caption = super().get_input(batch, self.cond_stage_key)[0] + break_caption = [] + for i in range(len(caption) // 30 + 1): + break_caption_i = caption[i*30:(i+1)*30] + break_caption.append(break_caption_i) + for i, bci in enumerate(break_caption): + cv2.putText(rgb_all, bci, (50, 50*(i+1)), font, fontScale, color, thickness, cv2.LINE_AA) + + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)], + "val/colorize_rse": [wandb.Image(colorize_res)], + "val/colorize_x": [wandb.Image(colorize_x)], + }) + + @torch.no_grad() + def test_schedule(self, x_start, freq=50): + noise = torch.randn_like(x_start) + img_list = [] + latent_list = [] + for t in tqdm(range(self.num_timesteps)): + if t % freq == 0: + t_long = torch.Tensor([t,]).long().to(x_start.device) + x_noisy = self.q_sample(x_start=x_start, t=t_long, noise=noise) + decode_res = self.decode_first_stage(x_noisy) + rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder( + decode_res, self.batch_rays, self.batch_img, + ) + rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0] + # imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_{}.png".format(t)), rgb_sample) + colorize_res = self.first_stage_model.to_rgb(x_noisy) + # imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_latent_{}.png".format(t)), colorize_res[0]) + img_list.append(rgb_sample) + latent_list.append(colorize_res[0]) + imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_{}_{}_{}_{}.png".format(self.linear_start, self.linear_end, self.beta_schedule, self.scale_factor)), np.concatenate(img_list, 1)) + imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_latent_{}_{}_{}_{}.png".format(self.linear_start, self.linear_end, self.beta_schedule, self.scale_factor)), np.concatenate(latent_list, 1)) + + @torch.no_grad() + def test_step(self, batch, batch_idx): + x, c = self.get_input(batch, self.first_stage_key) + if self.test_mode == 'fid': + bs = x.shape[0] + else: + bs = 1 + if self.test_mode == 'noise_schedule': + self.batch_rays = batch['batch_rays'][0][33:34] + self.batch_img = batch['img'][0][33:34] + self.is_test = True + self.test_schedule(x) + exit(0) + with self.ema_scope(): + if c is not None: + res = self.sample(c[:bs], bs, shape=x[0:bs].shape) + else: + res = self.sample(None, bs, shape=x[0:bs].shape) + decode_res = self.decode_first_stage(res) + if self.test_mode == 'fid': + folder = os.path.join(self.logger.log_dir, 'FID_' + self.test_tag) + if not os.path.exists(folder): + os.makedirs(folder, exist_ok=True) + rgb_sample_list = [] + for b in range(bs): + rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder( + decode_res[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb_sample = to8b(rgb_sample.detach().cpu().numpy()) + rgb_sample_list.append(rgb_sample) + for i in range(len(rgb_sample_list)): + for v in range(rgb_sample_list[i].shape[0]): + imageio.imwrite(os.path.join(folder, "sample_{}_{}_{}.png".format(batch_idx, i, v)), rgb_sample_list[i][v]) + elif self.test_mode == 'sample': + colorize_res = self.first_stage_model.to_rgb(res) + colorize_x = self.first_stage_model.to_rgb(x[:1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "sample_{}_{}.png".format(batch_idx, 0)), colorize_res[0]) + imageio.imwrite(os.path.join(self.logger.log_dir, "gt_{}_{}.png".format(batch_idx, 0)), colorize_x[0]) + if self.model.conditioning_key is not None: + cond_img = super().get_input(batch, self.cond_stage_key)[0].permute(1, 2, 0) + cond_img = to8b(cond_img.cpu().numpy()) + imageio.imwrite(os.path.join(self.logger.log_dir, "cond_{}_{}.png".format(batch_idx, 0)), cond_img) + for b in range(bs): + video = [] + for v in tqdm(range(batch['batch_rays'].shape[1])): + rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder( + decode_res[b:b+1], batch['batch_rays'][0][v:v+1], batch['img'][0][v:v+1], + ) + rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0] + video.append(rgb_sample) + imageio.mimwrite(os.path.join(self.logger.log_dir, "sample_{}_{}.mp4".format(batch_idx, b)), video, fps=24) + print("Saving to {}".format(os.path.join(self.logger.log_dir, "sample_{}_{}.mp4".format(batch_idx, b)))) + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, **kwargs): + + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs diff --git a/3DTopia/ldm/models/diffusion/ddpm_preprocess.py b/3DTopia/ldm/models/diffusion/ddpm_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..74926ca5dcfd7e4b1f283e7cd03de42ced5e74ee --- /dev/null +++ b/3DTopia/ldm/models/diffusion/ddpm_preprocess.py @@ -0,0 +1,1716 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import os +import wandb +import torch +import imageio +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from module.model_2d import DiagonalGaussianDistribution +from ldm.modules.distributions.distributions import normal_kl +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from utility.triplane_renderer.renderer import to8b + +import ipdb +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + learning_rate=1e-4, + shift_scale=None, + ): + super().__init__() + assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.beta_schedule = beta_schedule + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, shift_scale=shift_scale) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + self.learning_rate = learning_rate + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, shift_scale=None): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s, shift_scale=shift_scale) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + print("Using timesteps of {}".format(self.num_timesteps)) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # print("sqrt_alphas_cumprod", np.sqrt(alphas_cumprod)) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like(self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def predict_start_from_z_and_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + elif self.parameterization == "v": + x_recon = self.predict_start_from_z_and_v(x, t, model_out) + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if isinstance(x, list): + return x + # if len(x.shape) == 3: + # x = x[..., None] + # x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=False, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=False, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=False, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_shift=0.0, + scale_by_std=False, + use_3daware=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + self.scale_shift = scale_shift + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + # self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + self.use_3daware = use_3daware + + self.is_test = False + + self.test_mode = None + self.test_tag = "" + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, shift_scale=None): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s, shift_scale) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * (z + self.scale_shift) + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c).float() + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None): + #ipdb.set_trace() + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + #ipdb.set_trace() + z = x.to(self.device) #[:1,:8,:32,:96] + z = self.scale_factor * (z + self.scale_shift) + # encoder_posterior = self.encode_first_stage(x) + # z = self.get_first_stage_encoding(encoder_posterior).detach() + #ipdb.set_trace() + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox']: + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + #ipdb.set_trace() + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + # assert not predict_cids + # if predict_cids: + # if z.dim() == 4: + # z = torch.argmax(z.exp(), dim=1).long() + # z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + # z = rearrange(z, 'b h w c -> b c h w').contiguous() + + # import os + # import random + # import string + # z_np = z.detach().cpu().numpy() + # fname = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + '.npy' + # with open(os.path.join('/mnt/lustre/hongfangzhou.p/AE3D/tmp', fname), 'wb') as f: + # np.save(f, z_np) + + z = 1. / self.scale_factor * z - self.scale_shift + + # if hasattr(self, "split_input_params"): + # if self.split_input_params["patch_distributed_vq"]: + # ks = self.split_input_params["ks"] # eg. (128, 128) + # stride = self.split_input_params["stride"] # eg. (64, 64) + # uf = self.split_input_params["vqf"] + # bs, nc, h, w = z.shape + # if ks[0] > h or ks[1] > w: + # ks = (min(ks[0], h), min(ks[1], w)) + # print("reducing Kernel") + + # if stride[0] > h or stride[1] > w: + # stride = (min(stride[0], h), min(stride[1], w)) + # print("reducing stride") + + # fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + # z = unfold(z) # (bn, nc * prod(**ks), L) + # # 1. Reshape to img shape + # z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # # 2. apply model loop over last dim + # if isinstance(self.first_stage_model, VQModelInterface): + # output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + # force_not_quantize=predict_cids or force_not_quantize) + # for i in range(z.shape[-1])] + # else: + + # output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + # for i in range(z.shape[-1])] + + # o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + # o = o * weighting + # # Reverse 1. reshape to img shape + # o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # # stitch crops together + # decoded = fold(o) + # decoded = decoded / normalization # norm is shape (1, 1, h, w) + # return decoded + # else: + # if isinstance(self.first_stage_model, VQModelInterface): + # return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + # else: + # return self.first_stage_model.decode(z) + + # else: + # if isinstance(self.first_stage_model, VQModelInterface): + # return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + # else: + return self.first_stage_model.decode(z, unrollout=True) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z - self.scale_shift + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + # if hasattr(self, "split_input_params"): + # if self.split_input_params["patch_distributed_vq"]: + # ks = self.split_input_params["ks"] # eg. (128, 128) + # stride = self.split_input_params["stride"] # eg. (64, 64) + # df = self.split_input_params["vqf"] + # self.split_input_params['original_image_size'] = x.shape[-2:] + # bs, nc, h, w = x.shape + # if ks[0] > h or ks[1] > w: + # ks = (min(ks[0], h), min(ks[1], w)) + # print("reducing Kernel") + + # if stride[0] > h or stride[1] > w: + # stride = (min(stride[0], h), min(stride[1], w)) + # print("reducing stride") + + # fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + # z = unfold(x) # (bn, nc * prod(**ks), L) + # # Reshape to img shape + # z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + # for i in range(z.shape[-1])] + + # o = torch.stack(output_list, axis=-1) + # o = o * weighting + + # # Reverse reshape to img shape + # o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # # stitch crops together + # decoded = fold(o) + # decoded = decoded / normalization + # return decoded + + # else: + # return self.first_stage_model.encode(x) + # else: + return self.first_stage_model.encode(x, rollout=True) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, cond=None, return_inter=False, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + #ipdb.set_trace() + if self.model.conditioning_key is not None: + assert cond is not None + if self.cond_stage_trainable: + cond = self.get_learned_conditioning(cond) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond.float())) + #ipdb.set_trace() + return self.p_losses(x, cond, t, return_inter=return_inter, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.AvgPool2d((res, 1)) + y_mp = torch.nn.AvgPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + # B, C, H, W = h.shape + # h_xy = th.cat([h[..., 0:(W//3)], h[..., (W//3):(2*W//3)].mean(-1).unsqueeze(-1).repeat(1, 1, 1, W//3), h[..., (2*W//3):W].mean(-2).unsqueeze(-2).repeat(1, 1, H, 1)], 1) + # h_xz = th.cat([h[..., (W//3):(2*W//3)], h[..., 0:(W//3)].mean(-1).unsqueeze(-1).repeat(1, 1, 1, W//3), h[..., (2*W//3):W].mean(-1).unsqueeze(-1).repeat(1, 1, 1, W//3)], 1) + # h_zy = th.cat([h[..., (2*W//3):W], h[..., 0:(W//3)].mean(-2).unsqueeze(-2).repeat(1, 1, H, 1), h[..., (W//3):(2*W//3)].mean(-2).unsqueeze(-2).repeat(1, 1, H, 1)], 1) + # h = th.cat([h_xy, h_xz, h_zy], -1) + + def apply_model(self, x_noisy, t, cond, return_ids=False): + #ipdb.set_trace() + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + print(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + print(cut_cond.shape) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + print(adapted_cond.shape) + adapted_cond = self.get_learned_conditioning(adapted_cond) + print(adapted_cond.shape) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + print(adapted_cond.shape) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + if self.use_3daware: + x_noisy_3daware = self.to3daware(x_noisy) + x_recon = self.model(x_noisy_3daware, t, **cond) + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None, return_inter=False): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t.to(self.logvar.device)].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + if return_inter: + return loss, loss_dict, self.predict_start_from_noise(x_noisy, t=t, noise=model_output) + else: + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + elif self.parameterization == "v": + x_recon = self.predict_start_from_z_and_v(x, t, model_out) + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + # if self.is_test and i % 50 == 0: + # decode_res = self.decode_first_stage(img) + # rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder( + # decode_res, self.batch_rays, self.batch_img, + # ) + # rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0] + # imageio.imwrite(os.path.join(self.logger.log_dir, "sample_process_{}.png".format(i)), rgb_sample) + # colorize_res = self.first_stage_model.to_rgb(img) + # imageio.imwrite(os.path.join(self.logger.log_dir, "sample_process_latent_{}.png".format(i)), colorize_res[0]) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size * 3) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + # x, c = self.get_input(batch, self.first_stage_key) + # self.batch_rays = batch['batch_rays'][0][1:2] + # self.batch_img = batch['img'][0][1:2] + # self.is_test = True + # self.test_schedule(x[0:1]) + # exit(0) + + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + # _, loss_dict_ema = self.shared_step(batch) + x, c = self.get_input(batch, self.first_stage_key) + _, loss_dict_ema, inter_res = self(x, c, return_inter=True) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True) + + if batch_idx < 2: + if self.num_timesteps < 1000: + x_T = self.q_sample(x_start=x[0:1], t=torch.full((1,), self.num_timesteps-1, device=x.device, dtype=torch.long), noise=torch.randn_like(x[0:1])) + print("Specifying x_T when sampling!") + else: + x_T = None + with self.ema_scope(): + res = self.sample(c, 1, shape=x[0:1].shape, x_T = x_T) + decode_res = self.decode_first_stage(res) + decode_input = self.decode_first_stage(x[:1]) + decode_output = self.decode_first_stage(inter_res[:1]) + + colorize_res = self.first_stage_model.to_rgb(res)[0] + colorize_x = self.first_stage_model.to_rgb(x[:1])[0] + # imageio.imwrite(os.path.join(self.logger.log_dir, "sample_{}_{}.png".format(batch_idx, 0)), colorize_res[0]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "gt_{}_{}.png".format(batch_idx, 0)), colorize_x[0]) + + rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder( + decode_res, batch['batch_rays'][0], batch['img'][0], + ) + rgb_input, _ = self.first_stage_model.render_triplane_eg3d_decoder( + decode_input, batch['batch_rays'][0], batch['img'][0], + ) + rgb_output, _ = self.first_stage_model.render_triplane_eg3d_decoder( + decode_output, batch['batch_rays'][0], batch['img'][0], + ) + rgb_sample = to8b(rgb_sample.detach().cpu().numpy()) + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_output = to8b(rgb_output.detach().cpu().numpy()) + + if rgb_sample.shape[0] == 1: + rgb_all = np.concatenate([rgb_sample[0], rgb_input[0], rgb_output[0]], 1) + else: + rgb_all = np.concatenate([rgb_sample[1], rgb_input[1], rgb_output[1]], 1) + + + if self.model.conditioning_key is not None: + if self.cond_stage_key == 'img_cond': + cond_img = super().get_input(batch, self.cond_stage_key)[0].permute(1, 2, 0) + rgb_all = np.concatenate([rgb_all, to8b(cond_img.cpu().numpy())], 1) + elif 'caption' in self.cond_stage_key: + import cv2 + font = cv2.FONT_HERSHEY_SIMPLEX + # org + org = (50, 50) + # fontScale + fontScale = 1 + # Blue color in BGR + color = (255, 0, 0) + # Line thickness of 2 px + thickness = 2 + caption = super().get_input(batch, 'caption')[0] + break_caption = [] + for i in range(len(caption) // 30 + 1): + break_caption_i = caption[i*30:(i+1)*30] + break_caption.append(break_caption_i) + for i, bci in enumerate(break_caption): + cv2.putText(rgb_all, bci, (50, 50*(i+1)), font, fontScale, color, thickness, cv2.LINE_AA) + + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)], + "val/colorize_rse": [wandb.Image(colorize_res)], + "val/colorize_x": [wandb.Image(colorize_x)], + }) + + @torch.no_grad() + def test_schedule(self, x_start, freq=50): + noise = torch.randn_like(x_start) + img_list = [] + latent_list = [] + for t in tqdm(range(self.num_timesteps)): + if t % freq == 0: + t_long = torch.Tensor([t,]).long().to(x_start.device) + x_noisy = self.q_sample(x_start=x_start, t=t_long, noise=noise) + decode_res = self.decode_first_stage(x_noisy) + rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder( + decode_res, self.batch_rays, self.batch_img, + ) + rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0] + # imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_{}.png".format(t)), rgb_sample) + colorize_res = self.first_stage_model.to_rgb(x_noisy) + # imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_latent_{}.png".format(t)), colorize_res[0]) + img_list.append(rgb_sample) + latent_list.append(colorize_res[0]) + imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_{}_{}_{}_{}.png".format(self.linear_start, self.linear_end, self.beta_schedule, self.scale_factor)), np.concatenate(img_list, 1)) + imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_latent_{}_{}_{}_{}.png".format(self.linear_start, self.linear_end, self.beta_schedule, self.scale_factor)), np.concatenate(latent_list, 1)) + + @torch.no_grad() + def test_step(self, batch, batch_idx): + x, c = self.get_input(batch, self.first_stage_key) + if self.test_mode == 'fid': + bs = x.shape[0] + else: + bs = 1 + if self.test_mode == 'noise_schedule': + self.batch_rays = batch['batch_rays'][0][33:34] + self.batch_img = batch['img'][0][33:34] + self.is_test = True + self.test_schedule(x) + exit(0) + with self.ema_scope(): + if c is not None: + res = self.sample(c[:bs], bs, shape=x[0:bs].shape) + else: + res = self.sample(None, bs, shape=x[0:bs].shape) + decode_res = self.decode_first_stage(res) + if self.test_mode == 'fid': + folder = os.path.join(self.logger.log_dir, 'FID_' + self.test_tag) + if not os.path.exists(folder): + os.makedirs(folder, exist_ok=True) + rgb_sample_list = [] + for b in range(bs): + rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder( + decode_res[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb_sample = to8b(rgb_sample.detach().cpu().numpy()) + rgb_sample_list.append(rgb_sample) + for i in range(len(rgb_sample_list)): + for v in range(rgb_sample_list[i].shape[0]): + imageio.imwrite(os.path.join(folder, "sample_{}_{}_{}.png".format(batch_idx, i, v)), rgb_sample_list[i][v]) + elif self.test_mode == 'sample': + colorize_res = self.first_stage_model.to_rgb(res) + colorize_x = self.first_stage_model.to_rgb(x[:1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "sample_{}_{}.png".format(batch_idx, 0)), colorize_res[0]) + imageio.imwrite(os.path.join(self.logger.log_dir, "gt_{}_{}.png".format(batch_idx, 0)), colorize_x[0]) + if self.model.conditioning_key is not None: + cond_img = super().get_input(batch, self.cond_stage_key)[0].permute(1, 2, 0) + cond_img = to8b(cond_img.cpu().numpy()) + imageio.imwrite(os.path.join(self.logger.log_dir, "cond_{}_{}.png".format(batch_idx, 0)), cond_img) + for b in range(bs): + video = [] + for v in tqdm(range(batch['batch_rays'].shape[1])): + rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder( + decode_res[b:b+1], batch['batch_rays'][0][v:v+1], batch['img'][0][v:v+1], + ) + rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0] + video.append(rgb_sample) + imageio.mimwrite(os.path.join(self.logger.log_dir, "sample_{}_{}.mp4".format(batch_idx, b)), video, fps=24) + print("Saving to {}".format(os.path.join(self.logger.log_dir, "sample_{}_{}.mp4".format(batch_idx, b)))) + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, **kwargs): + + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs diff --git a/3DTopia/ldm/models/diffusion/dpm_solver/__init__.py b/3DTopia/ldm/models/diffusion/dpm_solver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08 --- /dev/null +++ b/3DTopia/ldm/models/diffusion/dpm_solver/__init__.py @@ -0,0 +1 @@ +from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/3DTopia/ldm/models/diffusion/dpm_solver/dpm_solver.py b/3DTopia/ldm/models/diffusion/dpm_solver/dpm_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb64e0c78cc3520f92d79db3124c85fc3cfb9b4 --- /dev/null +++ b/3DTopia/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -0,0 +1,1184 @@ +import torch +import torch.nn.functional as F +import math + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + """Construct a DPM-Solver. + + We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). + If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). + If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). + In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. + The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. + thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. + max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = model_fn + self.noise_schedule = noise_schedule + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders)).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + ) + else: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5), dims) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((x.shape[0],)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'adaptive': + with torch.no_grad(): + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in range(1, order): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final and steps < 15: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, vec_t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order,] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for i, order in enumerate(orders): + t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x + + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,)*(dims - 1)] \ No newline at end of file diff --git a/3DTopia/ldm/models/diffusion/dpm_solver/sampler.py b/3DTopia/ldm/models/diffusion/dpm_solver/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..2c42d6f964d92658e769df95a81dec92250e5a99 --- /dev/null +++ b/3DTopia/ldm/models/diffusion/dpm_solver/sampler.py @@ -0,0 +1,82 @@ +"""SAMPLING ONLY.""" + +import torch + +from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver + + +class DPMSolverSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type="noise", + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + + return x.to(device), None diff --git a/3DTopia/ldm/models/diffusion/plms.py b/3DTopia/ldm/models/diffusion/plms.py new file mode 100644 index 0000000000000000000000000000000000000000..78eeb1003aa45d27bdbfc6b4a1d7ccbff57cd2e3 --- /dev/null +++ b/3DTopia/ldm/models/diffusion/plms.py @@ -0,0 +1,236 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/3DTopia/ldm/modules/attention.py b/3DTopia/ldm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..f4eff39ccb6d75daa764f6eb70a7cef024fb5a3f --- /dev/null +++ b/3DTopia/ldm/modules/attention.py @@ -0,0 +1,261 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from ldm.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = self.proj_out(x) + return x + x_in \ No newline at end of file diff --git a/3DTopia/ldm/modules/diffusionmodules/__init__.py b/3DTopia/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/3DTopia/ldm/modules/diffusionmodules/model.py b/3DTopia/ldm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..533e589a2024f1d7c52093d8c472c3b1b6617e26 --- /dev/null +++ b/3DTopia/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,835 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ldm.util import instantiate_from_config +from ldm.modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + diff --git a/3DTopia/ldm/modules/diffusionmodules/openaimodel.py b/3DTopia/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..43edcfbcad0cc3e26d9979734a67b1ca0e593392 --- /dev/null +++ b/3DTopia/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,965 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, ch, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + # if h.shape[1] == 640: + # # h[:, :320] = h[:, :320] * 1.4 + # # print("here") + # h = h * 1.2 + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/3DTopia/ldm/modules/diffusionmodules/triplane_3daware_unet.py b/3DTopia/ldm/modules/diffusionmodules/triplane_3daware_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..1358d5f9e24d47e8e8142c87292374c755fd7e53 --- /dev/null +++ b/3DTopia/ldm/modules/diffusionmodules/triplane_3daware_unet.py @@ -0,0 +1,991 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + # print("Using 3d aware resblock!") + + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels * 3, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = th.nn.AvgPool2d((res, 1)) + y_mp = th.nn.AvgPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = th.flip(y_mp_rep(plane3), (3,)) + new_plane1 = th.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = th.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = th.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = th.cat([plane3, plane13, plane23], 1) + + new_plane = th.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = self.to3daware(h) + h = out_rest(h) + else: + h = h + emb_out + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + h = self.to3daware(out_norm(h)) + h = out_rest(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/3DTopia/ldm/modules/diffusionmodules/triplane_context_crossattention_unet.py b/3DTopia/ldm/modules/diffusionmodules/triplane_context_crossattention_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..4a846be183fba569d7178e2fd34cc145f9669cec --- /dev/null +++ b/3DTopia/ldm/modules/diffusionmodules/triplane_context_crossattention_unet.py @@ -0,0 +1,1126 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +from inspect import isfunction + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, TriplaneAttentionBlock): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + +def exists(val): + return val is not None + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + x = x.permute(0, 2, 1) + context = context.permute(0, 2, 1) + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -th.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out).permute(0, 2, 1) + +class CrossAttentionContext(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + + # import pdb; pdb.set_trace() + + h = self.heads + + x = x.permute(0, 2, 1) + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out).permute(0, 2, 1) + + +class TriplaneAttentionBlock(nn.Module): + def __init__( + self, + channels, + context_channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + # self.norm = normalization(channels) + self.norm1 = normalization(channels) + self.norm2 = normalization(channels) + self.norm3 = normalization(channels) + self.norm4 = normalization(channels) + self.norm5 = normalization(channels) + # self.norm6 = normalization(context_channels) + # self.norm1 = nn.LayerNorm(channels) + # self.norm2 = nn.LayerNorm(channels) + # self.norm3 = nn.LayerNorm(channels) + # self.norm4 = nn.LayerNorm(channels) + # self.norm5 = nn.LayerNorm(channels) + # self.norm6 = nn.LayerNorm(context_channels) + + self.plane1_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels) + self.plane2_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels) + self.plane3_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels) + + self.context_ca = CrossAttentionContext(channels, context_channels, self.num_heads, num_head_channels) + + def forward(self, x, context): + return checkpoint(self._forward, (x, context), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x, context): + b, c, *spatial = x.shape + res = x.shape[-2] + plane1 = x[..., :res].reshape(b, c, -1) + plane2 = x[..., res:res*2].reshape(b, c, -1) + plane3 = x[..., 2*res:3*res].reshape(b, c, -1) + x = x.reshape(b, c, -1) + + plane1_output = self.plane1_ca(self.norm1(plane1), self.norm4(x)) + plane2_output = self.plane2_ca(self.norm2(plane2), self.norm4(x)) + plane3_output = self.plane3_ca(self.norm3(plane3), self.norm4(x)) + + h = th.cat([plane1_output, plane2_output, plane3_output], -1) + + h = self.context_ca(self.norm5(h), context=context) + + return (x + h).reshape(b, c, *spatial) + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + TriplaneAttentionBlock( + ch, + context_dim, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + TriplaneAttentionBlock( + ch, + context_dim, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + TriplaneAttentionBlock( + ch, + context_dim, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, ch, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/3DTopia/ldm/modules/diffusionmodules/triplane_crossattention_unet.py b/3DTopia/ldm/modules/diffusionmodules/triplane_crossattention_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..e489ef26704b530453d0492429a260a3f13c4d9a --- /dev/null +++ b/3DTopia/ldm/modules/diffusionmodules/triplane_crossattention_unet.py @@ -0,0 +1,1058 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +from inspect import isfunction + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + +def exists(val): + return val is not None + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + x = x.permute(0, 2, 1) + context = context.permute(0, 2, 1) + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -th.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out).permute(0, 2, 1) + + +class TriplaneAttentionBlock(nn.Module): + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + + self.plane1_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels) + self.plane2_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels) + self.plane3_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + res = x.shape[-2] + plane1 = x[..., :res].reshape(b, c, -1) + plane2 = x[..., res:res*2].reshape(b, c, -1) + plane3 = x[..., 2*res:3*res].reshape(b, c, -1) + x = x.reshape(b, c, -1) + + plane1_output = self.plane1_ca(self.norm(plane1), self.norm(x)) + plane2_output = self.plane2_ca(self.norm(plane2), self.norm(x)) + plane3_output = self.plane3_ca(self.norm(plane3), self.norm(x)) + + h = th.cat([plane1_output, plane2_output, plane3_output], -1) + + return (x + h).reshape(b, c, *spatial) + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + TriplaneAttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + TriplaneAttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + TriplaneAttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, ch, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/3DTopia/ldm/modules/diffusionmodules/util.py b/3DTopia/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..1fe3bb31158f0184505b044ba1cbffe1ede3375e --- /dev/null +++ b/3DTopia/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,305 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + +def force_zero_snr(betas): + alphas = 1 - betas + alphas_bar = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_bar ** (1/2) + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - 1e-6 + alphas_bar_sqrt -= alphas_bar_sqrt_T + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + alphas_bar = alphas_bar_sqrt ** 2 + alphas = alphas_bar[1:] / alphas_bar[:-1] + alphas = torch.cat([alphas_bar[0:1], alphas], 0) + betas = 1 - alphas + return betas + +def shift_schedule(base_betas, shift_scale): + alphas = 1 - base_betas + alphas_bar = torch.cumprod(alphas, dim=0) + snr = alphas_bar / (1 - alphas_bar) # snr(1-ab)=ab; snr-snr*ab=ab; snr=(1+snr)ab; ab=snr/(1+snr) + shifted_snr = snr * ((1 / shift_scale) ** 2) + shifted_alphas_bar = shifted_snr / (1 + shifted_snr) + shifted_alphas = shifted_alphas_bar[1:] / shifted_alphas_bar[:-1] + shifted_alphas = torch.cat([shifted_alphas_bar[0:1], shifted_alphas], 0) + shifted_betas = 1 - shifted_alphas + return shifted_betas + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, shift_scale=None): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + elif schedule == 'linear_force_zero_snr': + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + betas = force_zero_snr(betas) + elif schedule == 'linear_100': + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + betas = betas[:100] + else: + raise ValueError(f"schedule '{schedule}' unknown.") + + if shift_scale is not None: + betas = shift_schedule(betas, shift_scale) + + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/3DTopia/ldm/modules/distributions/__init__.py b/3DTopia/ldm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/3DTopia/ldm/modules/distributions/distributions.py b/3DTopia/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9 --- /dev/null +++ b/3DTopia/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/3DTopia/ldm/modules/ema.py b/3DTopia/ldm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..c8c75af43565f6e140287644aaaefa97dd6e67c5 --- /dev/null +++ b/3DTopia/ldm/modules/ema.py @@ -0,0 +1,76 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/3DTopia/ldm/modules/encoders/__init__.py b/3DTopia/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/3DTopia/ldm/modules/encoders/modules.py b/3DTopia/ldm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..73508813bb3942f914553f4988606d4917f4aaf8 --- /dev/null +++ b/3DTopia/ldm/modules/encoders/modules.py @@ -0,0 +1,386 @@ +import torch +import torch.nn as nn +from functools import partial +import clip +from einops import rearrange, repeat +from transformers import CLIPTokenizer, CLIPTextModel +import kornia + +from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast # TODO: add to reuquirements + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, + device="cuda",use_tokenizer=True, embedding_dropout=0.0): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text)#.to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + + +class SpatialRescaler(nn.Module): + def __init__(self, + n_stages=1, + method='bilinear', + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) + + def forward(self,x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List +from pkg_resources import packaging + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer + +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load", "tokenize"] +_tokenizer = _Tokenizer() + +def tokenize_with_truncation(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + + context_length : int + The context length to use; all CLIP models use 77 as the context length + + truncate: bool + Whether to truncate the text in case its encoding is longer than the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder["<|startoftext|>"] + eot_token = _tokenizer.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + if truncate: + tokens = tokens[:context_length] + tokens[-1] = eot_token + else: + raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + +class FrozenCLIPTextEmbedder(nn.Module): + """ + Uses the CLIP transformer encoder for text. + """ + def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + super().__init__() + self.model, _ = clip.load(version, jit=False, device="cpu") + self.device = device + self.max_length = max_length + self.n_repeat = n_repeat + self.normalize = normalize + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + # tokens = clip.tokenize(text).to(self.device) + tokens = tokenize_with_truncation(text, truncate=True).to(self.device) + z = self.model.encode_text(tokens) + if self.normalize: + z = z / torch.linalg.norm(z, dim=1, keepdim=True) + return z + + def encode(self, text): + z = self(text) + if z.ndim==2: + z = z[:, None, :] + z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) + return z + + +class FrozenClipImageEmbedder(nn.Module): + """ + Uses the CLIP image encoder. + """ + def __init__( + self, + model='ViT-L/14', + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=False, + ): + super().__init__() + # self.model, _ = clip.load(name=model, device=device, jit=jit) + self.model, _ = clip.load(name=model, device=device) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + # x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + z = self.model.encode_image(self.preprocess(x)) + if z.ndim==2: + z = z[:, None, :] + return z + + +############### OPENCLIP ################# +import open_clip + +class OpenClipTextEmbedder(nn.Module): + def __init__( + self, + model='ViT-bigG-14', + pretrained='laion2b_s39b_b160k', + device='cuda' if torch.cuda.is_available() else 'cpu', + normalize=True,): + super().__init__() + self.model, _, _ = open_clip.create_model_and_transforms(model, pretrained=pretrained, device='cpu') + self.tokenizer = open_clip.get_tokenizer(model) + self.normalize = normalize + self.device = device + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tok_text = self.tokenizer(text).to(self.device) + z = self.model.encode_text(tok_text) + if self.normalize: + z = z / torch.linalg.norm(z, dim=1, keepdim=True) + return z + + def encode(self, text): + z = self(text) + if z.ndim==2: + z = z[:, None, :] + z = repeat(z, 'b 1 d -> b k d', k=1) + return z + +class OpenClipImageEmbedder(nn.Module): + def __init__( + self, + model='ViT-bigG-14', + pretrained='laion2b_s39b_b160k', + device='cuda' if torch.cuda.is_available() else 'cpu', + ): + super().__init__() + self.model, _, _ = open_clip.create_model_and_transforms(model, pretrained=pretrained, device=device) + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def preprocess(self, x): + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + z = self.model.encode_image(self.preprocess(x)) + if z.ndim==2: + z = z[:, None, :] + return z + +class DinoV2(nn.Module): + def __init__(self, model='dinov2_vitb14', ckpt='dino_ckpt/dinov2_vitb14_pretrain.pth'): + super().__init__() + # device='cuda' if torch.cuda.is_available() else 'cpu' + # self.model = torch.hub.load('facebookresearch/dinov2', model) + self.model = torch.hub.load('dinov2', model, source='local', pretrained=False) + self.model.load_state_dict(torch.load(ckpt)) + # self.model = self.model.to(device) + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]), persistent=False) + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]), persistent=False) + + def preprocess(self, x): + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=False) + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + return self.model.forward_features(self.preprocess(x))['x_norm_patchtokens'] + +if __name__ == "__main__": + from ldm.util import count_params + model = FrozenCLIPEmbedder() + count_params(model, verbose=True) \ No newline at end of file diff --git a/3DTopia/ldm/modules/image_degradation/__init__.py b/3DTopia/ldm/modules/image_degradation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc --- /dev/null +++ b/3DTopia/ldm/modules/image_degradation/__init__.py @@ -0,0 +1,2 @@ +from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr +from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/3DTopia/ldm/modules/image_degradation/bsrgan.py b/3DTopia/ldm/modules/image_degradation/bsrgan.py new file mode 100644 index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b --- /dev/null +++ b/3DTopia/ldm/modules/image_degradation/bsrgan.py @@ -0,0 +1,730 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image":image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') + + diff --git a/3DTopia/ldm/modules/image_degradation/bsrgan_light.py b/3DTopia/ldm/modules/image_degradation/bsrgan_light.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1f823996bf559e9b015ea9aa2b3cd38dd13af1 --- /dev/null +++ b/3DTopia/ldm/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,650 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2/4 + wd = wd/4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + # elif i == 1: + # image = add_blur(image, sf=sf) + + if i == 0: + pass + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image": image} + return example + + + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/3DTopia/ldm/modules/image_degradation/utils/test.png b/3DTopia/ldm/modules/image_degradation/utils/test.png new file mode 100644 index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6 Binary files /dev/null and b/3DTopia/ldm/modules/image_degradation/utils/test.png differ diff --git a/3DTopia/ldm/modules/image_degradation/utils_image.py b/3DTopia/ldm/modules/image_degradation/utils_image.py new file mode 100644 index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98 --- /dev/null +++ b/3DTopia/ldm/modules/image_degradation/utils_image.py @@ -0,0 +1,916 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) +# print(w1) +# print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/3DTopia/ldm/modules/losses/__init__.py b/3DTopia/ldm/modules/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..876d7c5bd6e3245ee77feb4c482b7a8143604ad5 --- /dev/null +++ b/3DTopia/ldm/modules/losses/__init__.py @@ -0,0 +1 @@ +from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file diff --git a/3DTopia/ldm/modules/losses/contperceptual.py b/3DTopia/ldm/modules/losses/contperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..672c1e32a1389def02461c0781339681060c540e --- /dev/null +++ b/3DTopia/ldm/modules/losses/contperceptual.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? + + +class LPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_loss="hinge"): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, inputs, reconstructions, posteriors, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", + weights=None): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights*nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log + diff --git a/3DTopia/ldm/modules/losses/vqperceptual.py b/3DTopia/ldm/modules/losses/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..f69981769e4bd5462600458c4fcf26620f7e4306 --- /dev/null +++ b/3DTopia/ldm/modules/losses/vqperceptual.py @@ -0,0 +1,167 @@ +import torch +from torch import nn +import torch.nn.functional as F +from einops import repeat + +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.losses.lpips import LPIPS +from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss + + +def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): + assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] + loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) + loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) + loss_real = (weights * loss_real).sum() / weights.sum() + loss_fake = (weights * loss_fake).sum() / weights.sum() + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + +def l1(x, y): + return torch.abs(x-y) + + +def l2(x, y): + return torch.pow((x-y), 2) + + +class VQLPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", + pixel_loss="l1"): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + assert perceptual_loss in ["lpips", "clips", "dists"] + assert pixel_loss in ["l1", "l2"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + if perceptual_loss == "lpips": + print(f"{self.__class__.__name__}: Running with LPIPS.") + self.perceptual_loss = LPIPS().eval() + else: + raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") + self.perceptual_weight = perceptual_weight + + if pixel_loss == "l1": + self.pixel_loss = l1 + else: + self.pixel_loss = l2 + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + self.n_classes = n_classes + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", predicted_indices=None): + if not exists(codebook_loss): + codebook_loss = torch.tensor([0.]).to(inputs.device) + #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + if predicted_indices is not None: + assert self.n_classes is not None + with torch.no_grad(): + perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) + log[f"{split}/perplexity"] = perplexity + log[f"{split}/cluster_usage"] = cluster_usage + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/3DTopia/ldm/modules/x_transformer.py b/3DTopia/ldm/modules/x_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc15bf9cfe0111a910e7de33d04ffdec3877576 --- /dev/null +++ b/3DTopia/ldm/modules/x_transformer.py @@ -0,0 +1,641 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat, reduce + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out + diff --git a/3DTopia/ldm/util.py b/3DTopia/ldm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..8ba38853e7a07228cc2c187742b5c45d7359b3f9 --- /dev/null +++ b/3DTopia/ldm/util.py @@ -0,0 +1,203 @@ +import importlib + +import torch +import numpy as np +from collections import abc +from einops import rearrange +from functools import partial + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i: i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/3DTopia/model/auto_regressive.py b/3DTopia/model/auto_regressive.py new file mode 100644 index 0000000000000000000000000000000000000000..e9de51a3596cf5d226c4bf5bf3745d784c977510 --- /dev/null +++ b/3DTopia/model/auto_regressive.py @@ -0,0 +1,412 @@ +import imageio +import os, math +import wandb +import torch +import torch.nn.functional as F +import pytorch_lightning as pl + +from utility.initialize import instantiate_from_config +from taming.modules.util import SOSProvider +from utility.triplane_renderer.renderer import get_embedder, NeRF, run_network, render_path1, to8b, img2mse, mse2psnr +import numpy as np + +from tqdm import tqdm + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class Net2NetTransformer(pl.LightningModule): + def __init__(self, + transformer_config, + first_stage_config, + cond_stage_config, + permuter_config=None, + ckpt_path=None, + ignore_keys=[], + first_stage_key="triplane", + cond_stage_key="depth", + downsample_cond_size=-1, + pkeep=1.0, + sos_token=0, + unconditional=True, + learning_rate=1e-4, + ): + super().__init__() + self.be_unconditional = unconditional + self.sos_token = sos_token + self.first_stage_key = first_stage_key + # self.cond_stage_key = cond_stage_key + self.init_first_stage_from_ckpt(first_stage_config) + # self.init_cond_stage_from_ckpt(cond_stage_config) + if permuter_config is None: + permuter_config = {"target": "taming.modules.transformer.permuter.Identity"} + self.permuter = instantiate_from_config(config=permuter_config) + self.transformer = instantiate_from_config(config=transformer_config) + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.downsample_cond_size = downsample_cond_size + self.pkeep = pkeep + self.learning_rate = learning_rate + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + for k in sd.keys(): + for ik in ignore_keys: + if k.startswith(ik): + self.print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def init_first_stage_from_ckpt(self, config): + model = instantiate_from_config(config) + # model = model.eval() + # model.train = disabled_train + + self.first_stage_model = model + + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + self.first_stage_model.vector_quantizer.training = False + self.first_stage_model.vector_quantizer.embedding.update = False + + def init_cond_stage_from_ckpt(self, config): + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__" or self.be_unconditional: + print(f"Using no cond stage. Assuming the training is intended to be unconditional. " + f"Prepending {self.sos_token} as a sos token.") + self.be_unconditional = True + self.cond_stage_key = self.first_stage_key + self.cond_stage_model = SOSProvider(self.sos_token) + else: + model = instantiate_from_config(config) + model = model.eval() + model.train = disabled_train + self.cond_stage_model = model + + def forward(self, x, c): + # one step to produce the logits + _, z_indices = self.encode_to_z(x) + # _, c_indices = self.encode_to_c(c) + + if self.training and self.pkeep < 1.0: + mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape, + device=z_indices.device)) + mask = mask.round().to(dtype=torch.int64) + r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size) + a_indices = mask*z_indices+(1-mask)*r_indices + else: + a_indices = z_indices + + c_indices = torch.zeros_like(z_indices[:, 0:1]) + self.transformer.config.vocab_size - 1 + cz_indices = torch.cat((c_indices, a_indices), dim=1) + + # target includes all sequence elements (no need to handle first one + # differently because we are conditioning) + target = z_indices + # make the prediction + logits, _ = self.transformer(cz_indices[:, :-1]) + # cut off conditioning outputs - output i corresponds to p(z_i | z_{ -1: + c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size)) + quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c) + if len(indices.shape) > 2: + indices = indices.view(c.shape[0], -1) + return quant_c, indices + + # @torch.no_grad() + # def decode_to_img(self, index, zshape): + # index = self.permuter(index, reverse=True) + # bhwc = (zshape[0],zshape[2],zshape[3],zshape[1]) + # quant_z = self.first_stage_model.quantize.get_codebook_entry( + # index.reshape(-1), shape=bhwc) + # x = self.first_stage_model.decode(quant_z) + # return x + + @torch.no_grad() + def decode_to_triplane(self, index, zshape): + quant_z = self.first_stage_model.vector_quantizer.dequantize(index) + quant_z = quant_z.reshape(zshape[0], zshape[2], zshape[3], zshape[1]) + quant_z = quant_z.permute(0, 3, 1, 2) + z = self.first_stage_model.decode(quant_z) + return z + + @torch.no_grad() + def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs): + log = dict() + + N = 2 + if lr_interface: + x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8) + else: + x, c = self.get_xc(batch, N) + x = x.to(device=self.device) + # c = c.to(device=self.device) + log["inputs"] = self.render_triplane(x, batch) + + quant_z, z_indices = self.encode_to_z(x) + # quant_c, c_indices = self.encode_to_c(c) + c_indices = torch.zeros_like(z_indices[:, 0:1]) + self.transformer.config.vocab_size - 1 + + # create a "half"" sample + z_start_indices = z_indices[:,:z_indices.shape[1]//2] + index_sample = self.sample(z_start_indices, c_indices, + steps=z_indices.shape[1]-z_start_indices.shape[1], + temperature=temperature if temperature is not None else 1.0, + sample=True, + top_k=top_k if top_k is not None else 100, + callback=callback if callback is not None else lambda k: None) + x_sample = self.first_stage_model.unrollout(self.decode_to_triplane(index_sample, quant_z.shape)) + log["samples_half"] = self.render_triplane(x_sample, batch) + + # sample + z_start_indices = z_indices[:, :0] + index_sample = self.sample(z_start_indices, c_indices, + steps=z_indices.shape[1], + temperature=temperature if temperature is not None else 1.0, + sample=True, + top_k=top_k if top_k is not None else 100, + callback=callback if callback is not None else lambda k: None) + x_sample_nopix = self.first_stage_model.unrollout(self.decode_to_triplane(index_sample, quant_z.shape)) + log["samples_nopix"] = self.render_triplane(x_sample_nopix, batch) + + # # det sample + # z_start_indices = z_indices[:, :0] + # index_sample = self.sample(z_start_indices, c_indices, + # steps=z_indices.shape[1], + # sample=False, + # callback=callback if callback is not None else lambda k: None) + # x_sample_det = self.first_stage_model.unrollout(self.decode_to_triplane(index_sample, quant_z.shape)) + # log["samples_det"] = self.render_triplane(x_sample_det, batch) + + # reconstruction + x_rec = self.first_stage_model.unrollout(self.decode_to_triplane(z_indices, quant_z.shape)) + # x_rec = self.first_stage_model.unrollout(self.first_stage_model(self.first_stage_model.rollout(x))[0]) + log["reconstructions"] = self.render_triplane(x_rec, batch) + + # if self.cond_stage_key in ["objects_bbox", "objects_center_points"]: + # figure_size = (x_rec.shape[2], x_rec.shape[3]) + # dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"] + # label_for_category_no = dataset.get_textual_label_for_category_no + # plotter = dataset.conditional_builders[self.cond_stage_key].plot + # log["conditioning"] = torch.zeros_like(log["reconstructions"]) + # for i in range(quant_c.shape[0]): + # log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size) + # log["conditioning_rec"] = log["conditioning"] + # elif self.cond_stage_key != "image": + # cond_rec = self.cond_stage_model.decode(quant_c) + # if self.cond_stage_key == "segmentation": + # # get image from segmentation mask + # num_classes = cond_rec.shape[1] + + # c = torch.argmax(c, dim=1, keepdim=True) + # c = F.one_hot(c, num_classes=num_classes) + # c = c.squeeze(1).permute(0, 3, 1, 2).float() + # c = self.cond_stage_model.to_rgb(c) + + # cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True) + # cond_rec = F.one_hot(cond_rec, num_classes=num_classes) + # cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float() + # cond_rec = self.cond_stage_model.to_rgb(cond_rec) + # log["conditioning_rec"] = cond_rec + # log["conditioning"] = c + + return log + + def render_triplane(self, triplane, batch): + batch_size = triplane.shape[0] + rgb_list = [] + for b in range(batch_size): + rgb, cur_psnr_list = self.first_stage_model.render_triplane_eg3d_decoder( + triplane[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb = to8b(rgb.detach().cpu().numpy()) + rgb_list.append(rgb[1]) + + return np.stack(rgb_list, 0) + + def get_input(self, key, batch): + x = batch[key] + # if len(x.shape) == 3: + # x = x[..., None] + # if len(x.shape) == 4: + # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) + # if x.dtype == torch.double: + # x = x.float() + return x + + def get_xc(self, batch, N=None): + x = self.get_input(self.first_stage_key, batch) + # c = self.get_input(self.cond_stage_key, batch) + if N is not None: + x = x[:N] + # c = c[:N] + return x, None + + def shared_step(self, batch, batch_idx): + x, c = self.get_xc(batch) + logits, target = self(x, c) + loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1)) + return loss + + def training_step(self, batch, batch_idx): + loss = self.shared_step(batch, batch_idx) + self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + return loss + + def validation_step(self, batch, batch_idx): + loss = self.shared_step(batch, batch_idx) + self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + if batch_idx == 0: + imgs = self.log_images(batch) + for i in range(imgs['inputs'].shape[0]): + self.logger.experiment.log({ + "val/vis/inputs": [wandb.Image(imgs['inputs'][i])], + "val/vis/reconstructions": [wandb.Image(imgs['reconstructions'][i])], + "val/vis/samples_half": [wandb.Image(imgs['samples_half'][i])], + "val/vis/samples_nopix": [wandb.Image(imgs['samples_nopix'][i])], + # "val/vis/samples_det": [wandb.Image(imgs['samples_det'][i])], + }) + return loss + + def test_step(self, batch, batch_idx): + loss = self.shared_step(batch, batch_idx) + self.log("test/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + imgs = self.log_images(batch, temperature=1.8) + print("Saved to {}".format(self.logger.log_dir)) + for i in range(imgs['inputs'].shape[0]): + imageio.imwrite(os.path.join(self.logger.log_dir, "inputs_{}_{}.png".format(batch_idx, i)), imgs['inputs'][i]) + imageio.imwrite(os.path.join(self.logger.log_dir, "reconstructions_{}_{}.png".format(batch_idx, i)), imgs['reconstructions'][i]) + imageio.imwrite(os.path.join(self.logger.log_dir, "samples_half_{}_{}.png".format(batch_idx, i)), imgs['samples_half'][i]) + imageio.imwrite(os.path.join(self.logger.log_dir, "samples_nopix_{}_{}.png".format(batch_idx, i)), imgs['samples_nopix'][i]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "samples_det_{}_{}.png".format(batch_idx, i)), imgs['samples_det'][i]) + return loss + + def configure_optimizers(self): + """ + Following minGPT: + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear, ) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.transformer.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + + if pn.endswith('bias'): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # special case the position embedding parameter in the root GPT module as not decayed + no_decay.add('pos_emb') + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.transformer.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ + % (str(param_dict.keys() - union_params), ) + + # create the pytorch optimizer object + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95)) + return optimizer diff --git a/3DTopia/model/sv_vae_triplane.py b/3DTopia/model/sv_vae_triplane.py new file mode 100644 index 0000000000000000000000000000000000000000..60f9e51edc84df596d3477c25f3aeeb5b8e7150b --- /dev/null +++ b/3DTopia/model/sv_vae_triplane.py @@ -0,0 +1,111 @@ +import os +import imageio +import numpy as np +import torch +import torchvision +import torch.nn as nn +import pytorch_lightning as pl +import wandb + +import lpips +from pytorch_msssim import SSIM + +from utility.initialize import instantiate_from_config + +class VAE(pl.LightningModule): + def __init__(self, vae_configs, renderer_configs, lr=1e-3, weight_decay=1e-2, + kld_weight=1, mse_weight=1, lpips_weight=0.1, ssim_weight=0.1, + log_image_freq=50): + super().__init__() + self.save_hyperparameters() + + self.lr = lr + self.weight_decay = weight_decay + self.kld_weight = kld_weight + self.mse_weight = mse_weight + self.lpips_weight = lpips_weight + self.ssim_weight = ssim_weight + self.log_image_freq = log_image_freq + + self.vae = instantiate_from_config(vae_configs) + self.renderer = instantiate_from_config(renderer_configs) + + self.lpips_fn = lpips.LPIPS(net='alex') + self.ssim_fn = SSIM(data_range=1, size_average=True, channel=3) + + self.triplane_render_kwargs = { + 'depth_resolution': 64, + 'disparity_space_sampling': False, + 'box_warp': 2.4, + 'depth_resolution_importance': 64, + 'clamp_mode': 'softplus', + 'white_back': True, + } + + def forward(self, batch, is_train): + encoder_img, input_img, input_ray_o, input_ray_d, \ + target_img, target_ray_o, target_ray_d = batch + grid, mu, logvar = self.vae(encoder_img, is_train) + + cat_ray_o = torch.cat([input_ray_o, target_ray_o], 0) + cat_ray_d = torch.cat([input_ray_d, target_ray_d], 0) + render_out = self.renderer(torch.cat([grid, grid], 0), cat_ray_o, cat_ray_d, self.triplane_render_kwargs) + render_gt = torch.cat([input_img, target_img], 0) + + return render_out['rgb_marched'], render_out['depth_final'], \ + render_out['weights'], mu, logvar, render_gt + + def calc_loss(self, render, mu, logvar, render_gt): + mse = torch.mean((render - render_gt) ** 2) + ssim_loss = 1 - self.ssim_fn(render, render_gt) + lpips_loss = self.lpips_fn((render * 2) - 1, (render_gt * 2) - 1).mean() + kld_loss = -0.5 * torch.mean(torch.mean(1 + logvar - mu.pow(2) - logvar.exp(), 1)) + + loss = self.mse_weight * mse + self.ssim_weight * ssim_loss + \ + self.lpips_weight * lpips_loss + self.kld_weight * kld_loss + + return { + 'loss': loss, + 'mse': mse, + 'ssim': ssim_loss, + 'lpips': lpips_loss, + 'kld': kld_loss, + } + + def log_dict(self, loss_dict, prefix): + for k, v in loss_dict.items(): + self.log(prefix + k, v, on_step=True, logger=True) + + def make_grid(self, render, depth, render_gt): + bs = render.shape[0] // 2 + grid = torchvision.utils.make_grid( + torch.stack([render_gt[0], render_gt[bs], render[0], depth[0], render[bs], depth[bs]], 0)) + grid = (grid.detach().cpu().permute(1, 2, 0) * 255.).numpy().astype(np.uint8) + return grid + + def training_step(self, batch, batch_idx): + render, depth, weights, mu, logvar, render_gt = self.forward(batch, True) + loss_dict = self.calc_loss(render, mu, logvar, render_gt) + self.log_dict(loss_dict, 'train/') + if batch_idx % self.log_image_freq == 0: + self.logger.experiment.log({ + 'train/vis': [wandb.Image(self.make_grid( + render, depth, render_gt + ))] + }) + return loss_dict['loss'] + + def validation_step(self, batch, batch_idx): + render, depth, _, mu, logvar, render_gt = self.forward(batch, False) + loss_dict = self.calc_loss(render, mu, logvar, render_gt) + self.log_dict(loss_dict, 'val/') + if batch_idx % self.log_image_freq == 0: + self.logger.experiment.log({ + 'val/vis': [wandb.Image(self.make_grid( + render, depth, render_gt + ))] + }) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) + return optimizer diff --git a/3DTopia/model/triplane_vae.py b/3DTopia/model/triplane_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..43ff1973d3ed49fa3bd07c9527a6650cb3b4b393 --- /dev/null +++ b/3DTopia/model/triplane_vae.py @@ -0,0 +1,2656 @@ +import os +import imageio +import torch +import wandb +import numpy as np +import pytorch_lightning as pl +import torch.nn.functional as F + +from module.model_2d import Encoder, Decoder, DiagonalGaussianDistribution, Encoder_GroupConv, Decoder_GroupConv, Encoder_GroupConv_LateFusion, Decoder_GroupConv_LateFusion +from utility.initialize import instantiate_from_config +from utility.triplane_renderer.renderer import get_embedder, NeRF, run_network, render_path1, to8b, img2mse, mse2psnr +from utility.triplane_renderer.eg3d_renderer import Renderer_TriPlane + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + learning_rate=1e-3, + ckpt_path=None, + ignore_keys=[], + colorize_nlabels=None, + monitor=None, + decoder_ckpt=None, + norm=False, + renderer_type='nerf', + renderer_config=dict( + rgbnet_dim=18, + rgbnet_width=128, + viewpe=0, + feape=0 + ), + ): + super().__init__() + self.save_hyperparameters() + self.norm = norm + self.renderer_config = renderer_config + self.learning_rate = learning_rate + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + # self.loss = instantiate_from_config(lossconfig) + self.lossconfig = lossconfig + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + self.decoder_ckpt = decoder_ckpt + self.renderer_type = renderer_type + # if decoder_ckpt is not None: + assert self.renderer_type in ['nerf', 'eg3d'] + if self.renderer_type == 'nerf': + self.triplane_decoder, self.triplane_render_kwargs = self.create_nerf(decoder_ckpt) + elif self.renderer_type == 'eg3d': + self.triplane_decoder, self.triplane_render_kwargs = self.create_eg3d_decoder(decoder_ckpt) + else: + raise NotImplementedError + + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + self.latent_list = [] + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x, rollout=False): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def unrollout(self, *args, **kwargs): + pass + + def loss(self, inputs, reconstructions, posteriors, prefix, batch=None): + reconstructions = reconstructions.contiguous() + rec_loss = torch.abs(inputs.contiguous() - reconstructions) + rec_loss = torch.sum(rec_loss) / rec_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + loss = self.lossconfig.rec_weight * rec_loss + self.lossconfig.kl_weight * kl_loss + + ret_dict = { + prefix+'mean_rec_loss': torch.abs(inputs.contiguous() - reconstructions.contiguous()).mean().detach(), + prefix+'rec_loss': rec_loss, + prefix+'kl_loss': kl_loss, + prefix+'loss': loss, + prefix+'mean': posteriors.mean.mean(), + prefix+'logvar': posteriors.logvar.mean(), + } + + render_weight = self.lossconfig.get("render_weight", 0) + tv_weight = self.lossconfig.get("tv_weight", 0) + l1_weight = self.lossconfig.get("l1_weight", 0) + latent_tv_weight = self.lossconfig.get("latent_tv_weight", 0) + latent_l1_weight = self.lossconfig.get("latent_l1_weight", 0) + + triplane_rec = self.unrollout(reconstructions) + if render_weight > 0 and batch is not None: + rgb_rendered, target = self.render_triplane_eg3d_decoder_sample_pixel(triplane_rec, batch['batch_rays'], batch['img']) + render_loss = ((rgb_rendered - target) ** 2).sum() / rgb_rendered.shape[0] * 256 + loss += render_weight * render_loss + ret_dict[prefix + 'render_loss'] = render_loss + if tv_weight > 0: + tvloss_y = torch.abs(triplane_rec[:, :, :-1] - triplane_rec[:, :, 1:]).sum() / triplane_rec.shape[0] + tvloss_x = torch.abs(triplane_rec[:, :, :, :-1] - triplane_rec[:, :, :, 1:]).sum() / triplane_rec.shape[0] + tvloss = tvloss_y + tvloss_x + loss += tv_weight * tvloss + ret_dict[prefix + 'tv_loss'] = tvloss + if l1_weight > 0: + l1 = (triplane_rec ** 2).sum() / triplane_rec.shape[0] + loss += l1_weight * l1 + ret_dict[prefix + 'l1_loss'] = l1 + if latent_tv_weight > 0: + latent = posteriors.mean + latent_tv_y = torch.abs(latent[:, :, :-1] - latent[:, :, 1:]).sum() / latent.shape[0] + latent_tv_x = torch.abs(latent[:, :, :, :-1] - latent[:, :, :, 1:]).sum() / latent.shape[0] + latent_tv_loss = latent_tv_y + latent_tv_x + loss += latent_tv_loss * latent_tv_weight + ret_dict[prefix + 'latent_tv_loss'] = latent_tv_loss + ret_dict[prefix + 'latent_max'] = latent.max() + ret_dict[prefix + 'latent_min'] = latent.min() + if latent_l1_weight > 0: + latent = posteriors.mean + latent_l1_loss = (latent ** 2).sum() / latent.shape[0] + loss += latent_l1_loss * latent_l1_weight + ret_dict[prefix + 'latent_l1_loss'] = latent_l1_loss + + return loss, ret_dict + + def training_step(self, batch, batch_idx): + # inputs = self.get_input(batch, self.image_key) + inputs = batch['triplane'] + reconstructions, posterior = self(inputs) + + # if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/') + # self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + # if optimizer_idx == 1: + # # train the discriminator + # discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + # last_layer=self.get_last_layer(), split="train") + + # self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + # self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + # return discloss + + def validation_step(self, batch, batch_idx): + # # inputs = self.get_input(batch, self.image_key) + # inputs = batch['triplane'] + # reconstructions, posterior = self(inputs, sample_posterior=False) + # aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/') + + # # discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + # # last_layer=self.get_last_layer(), split="val") + + # # self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + # self.log_dict(log_dict_ae) + # # self.log_dict(log_dict_disc) + # return self.log_dict + + inputs = batch['triplane'] + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/') + self.log_dict(log_dict_ae) + + assert not self.norm + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def create_eg3d_decoder(self, decoder_ckpt): + triplane_decoder = Renderer_TriPlane(**self.renderer_config) + if decoder_ckpt is not None: + pretrain_pth = torch.load(decoder_ckpt, map_location='cpu') + pretrain_pth = { + '.'.join(k.split('.')[1:]): v for k, v in pretrain_pth.items() + } + triplane_decoder.load_state_dict(pretrain_pth) + render_kwargs = { + 'depth_resolution': 128, + 'disparity_space_sampling': False, + 'box_warp': 2.4, + 'depth_resolution_importance': 128, + 'clamp_mode': 'softplus', + 'white_back': True, + 'det': True + } + return triplane_decoder, render_kwargs + + def render_triplane_eg3d_decoder(self, triplane, batch_rays, target): + ray_o = batch_rays[:, 0] + ray_d = batch_rays[:, 1] + psnr_list = [] + rec_img_list = [] + res = triplane.shape[-2] + for i in range(ray_o.shape[0]): + with torch.no_grad(): + render_out = self.triplane_decoder(triplane.reshape(1, 3, -1, res, res), + ray_o[i:i+1], ray_d[i:i+1], self.triplane_render_kwargs, whole_img=True, tvloss=False) + rec_img = render_out['rgb_marched'].permute(0, 2, 3, 1) + psnr = mse2psnr(img2mse(rec_img[0], target[i])) + psnr_list.append(psnr) + rec_img_list.append(rec_img) + return torch.cat(rec_img_list, 0), psnr_list + + def render_triplane_eg3d_decoder_sample_pixel(self, triplane, batch_rays, target, sample_num=1024): + assert batch_rays.shape[1] == 1 + sel = torch.randint(batch_rays.shape[-2], [sample_num]) + ray_o = batch_rays[:, 0, 0, sel] + ray_d = batch_rays[:, 0, 1, sel] + res = triplane.shape[-2] + render_out = self.triplane_decoder(triplane.reshape(triplane.shape[0], 3, -1, res, res), + ray_o, ray_d, self.triplane_render_kwargs, whole_img=False, tvloss=False) + rec_img = render_out['rgb_marched'] + target = target.reshape(triplane.shape[0], -1, 3)[:, sel, :] + return rec_img, target + + def create_nerf(self, decoder_ckpt): + # decoder_ckpt = '/mnt/petrelfs/share_data/caoziang/shapenet_triplane_car/003000.tar' + + multires = 10 + netchunk = 1024*64 + i_embed = 0 + perturb = 0 + raw_noise_std = 0 + + triplanechannel=18 + triplanesize=256 + chunk=4096 + num_instance=1 + batch_size=1 + use_viewdirs = True + white_bkgd = False + lrate_decay = 6 + netdepth=1 + netwidth=64 + N_samples = 512 + N_importance = 0 + N_rand = 8192 + multires_views=10 + precrop_iters = 0 + precrop_frac = 0.5 + i_weights=3000 + + embed_fn, input_ch = get_embedder(multires, i_embed) + embeddirs_fn, input_ch_views = get_embedder(multires_views, i_embed) + output_ch = 4 + skips = [4] + model = NeRF(D=netdepth, W=netwidth, + input_ch=triplanechannel, size=triplanesize,output_ch=output_ch, skips=skips, + input_ch_views=input_ch_views, use_viewdirs=use_viewdirs, num_instance=num_instance) + + network_query_fn = lambda inputs, viewdirs, label,network_fn : \ + run_network(inputs, viewdirs, network_fn, + embed_fn=embed_fn, + embeddirs_fn=embeddirs_fn,label=label, + netchunk=netchunk) + + ckpt = torch.load(decoder_ckpt) + model.load_state_dict(ckpt['network_fn_state_dict']) + + render_kwargs_test = { + 'network_query_fn' : network_query_fn, + 'perturb' : perturb, + 'N_samples' : N_samples, + # 'network_fn' : model, + 'use_viewdirs' : use_viewdirs, + 'white_bkgd' : white_bkgd, + 'raw_noise_std' : raw_noise_std, + } + render_kwargs_test['ndc'] = False + render_kwargs_test['lindisp'] = False + render_kwargs_test['perturb'] = False + render_kwargs_test['raw_noise_std'] = 0. + + return model, render_kwargs_test + + def render_triplane(self, triplane, batch_rays, target, near, far, chunk=4096): + self.triplane_decoder.tri_planes.copy_(triplane.detach()) + self.triplane_render_kwargs['network_fn'] = self.triplane_decoder + # print(triplane.device) + # print(batch_rays.device) + # print(target.device) + # print(near.device) + # print(far.device) + with torch.no_grad(): + rgb, _, _, psnr_list = \ + render_path1(batch_rays, chunk, self.triplane_render_kwargs, gt_imgs=target, + near=near, far=far, label=torch.Tensor([0]).long().to(triplane.device)) + return rgb, psnr_list + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + # inputs = batch['triplane'] + # reconstructions, posterior = self(inputs, sample_posterior=False) + # aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/') + # self.log_dict(log_dict_ae) + + # batch_size = inputs.shape[0] + # psnr_list = [] # between rec and gt + # psnr_input_list = [] # between input and gt + # psnr_rec_list = [] # between input and rec + + # mean = torch.Tensor([ + # 0.2820, 0.4103, -0.2988, 0.1491, 0.4429, -0.3117, 0.2830, 0.4115, + # -0.3032, 0.1530, 0.4466, -0.3165, 0.2617, 0.3837, -0.2692, 0.1098, + # 0.4101, -0.2922 + # ]).reshape(1, 18, 1, 1).to(inputs.device) + # std = torch.Tensor([ + # 1.1696, 1.1287, 1.1733, 1.1583, 1.1238, 1.1675, 1.1978, 1.1585, 1.1949, + # 1.1660, 1.1576, 1.1998, 1.1987, 1.1546, 1.1930, 1.1724, 1.1450, 1.2027 + # ]).reshape(1, 18, 1, 1).to(inputs.device) + + # if self.norm: + # reconstructions_unnormalize = reconstructions * std + mean + # else: + # reconstructions_unnormalize = reconstructions + + # for b in range(batch_size): + # # rgb_input, cur_psnr_list_input = self.render_triplane( + # # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + # # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + # # ) + # # rgb, cur_psnr_list = self.render_triplane( + # # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + # # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + # # ) + + # if self.renderer_type == 'nerf': + # rgb_input, cur_psnr_list_input = self.render_triplane( + # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + # ) + # rgb, cur_psnr_list = self.render_triplane( + # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + # ) + # elif self.renderer_type == 'eg3d': + # rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + # ) + # rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + # ) + # else: + # raise NotImplementedError + + # cur_psnr_list_rec = [] + # for i in range(rgb.shape[0]): + # cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + # rgb_input = to8b(rgb_input.detach().cpu().numpy()) + # rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + # rgb = to8b(rgb.detach().cpu().numpy()) + + # if batch_idx < 1: + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + # psnr_list += cur_psnr_list + # psnr_input_list += cur_psnr_list_input + # psnr_rec_list += cur_psnr_list_rec + + # self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + # self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + # self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + inputs = batch['triplane'] + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + # colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0] + # res = inputs.shape[1] + # colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0] + # colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0] + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2) + + np_z = z.detach().cpu().numpy() + # with open(os.path.join(self.logger.log_dir, "latent_{}.npz".format(batch_idx)), 'wb') as f: + # np.save(f, np_z) + + self.latent_list.append(np_z) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + # opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + # lr=lr, betas=(0.5, 0.9)) + # return [opt_ae, opt_disc], [] + return opt_ae + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + latent = np.concatenate(self.latent_list) + q75, q25 = np.percentile(latent.reshape(-1), [75 ,25]) + median = np.median(latent.reshape(-1)) + iqr = q75 - q25 + norm_iqr = iqr * 0.7413 + print("Norm IQR: {}".format(norm_iqr)) + print("Inverse Norm IQR: {}".format(1/norm_iqr)) + print("Median: {}".format(median)) + + +class AutoencoderKLRollOut(AutoencoderKL): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def encode(self, x, rollout=False): + if rollout: + x = self.rollout(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + z = self.post_quant_conv(z) + dec = self.decoder(z) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/') + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/') + self.log_dict(log_dict_ae) + + assert not self.norm + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/') + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + # if batch_idx < 1: + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + + reconstructions = self.unrollout(reconstructions) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + # mean = torch.Tensor([ + # 0.2820, 0.4103, -0.2988, 0.1491, 0.4429, -0.3117, 0.2830, 0.4115, + # -0.3032, 0.1530, 0.4466, -0.3165, 0.2617, 0.3837, -0.2692, 0.1098, + # 0.4101, -0.2922 + # ]).reshape(1, 18, 1, 1).to(inputs.device) + # std = torch.Tensor([ + # 1.1696, 1.1287, 1.1733, 1.1583, 1.1238, 1.1675, 1.1978, 1.1585, 1.1949, + # 1.1660, 1.1576, 1.1998, 1.1987, 1.1546, 1.1930, 1.1724, 1.1450, 1.2027 + # ]).reshape(1, 18, 1, 1).to(inputs.device) + + mean = torch.Tensor([ + -1.8449, -1.8242, 0.9667, -1.0187, 1.0647, -0.5422, -1.8632, -1.8435, + 0.9314, -1.0261, 1.0356, -0.5484, -1.8543, -1.8348, 0.9109, -1.0169, + 1.0160, -0.5467 + ]).reshape(1, 18, 1, 1).to(inputs.device) + std = torch.Tensor([ + 1.7593, 1.6127, 2.7132, 1.5500, 2.7893, 0.7707, 2.1114, 1.9198, 2.6586, + 1.8021, 2.5473, 1.0305, 1.7042, 1.7507, 2.4270, 1.4365, 2.2511, 0.8792 + ]).reshape(1, 18, 1, 1).to(inputs.device) + + if self.norm: + reconstructions_unnormalize = reconstructions * std + mean + else: + reconstructions_unnormalize = reconstructions + + # for b in range(batch_size): + # if self.renderer_type == 'nerf': + # rgb_input, cur_psnr_list_input = self.render_triplane( + # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + # ) + # rgb, cur_psnr_list = self.render_triplane( + # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + # ) + # elif self.renderer_type == 'eg3d': + # rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + # ) + # rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + # ) + # else: + # raise NotImplementedError + + # cur_psnr_list_rec = [] + # for i in range(rgb.shape[0]): + # cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + # rgb_input = to8b(rgb_input.detach().cpu().numpy()) + # rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + # rgb = to8b(rgb.detach().cpu().numpy()) + + # # if batch_idx < 1: + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + # psnr_list += cur_psnr_list + # psnr_input_list += cur_psnr_list_input + # psnr_rec_list += cur_psnr_list_rec + + # self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + # self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + # self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + +class AutoencoderKLRollOut3DAware(AutoencoderKL): + def __init__(self, *args, **kwargs): + try: + ckpt_path = kwargs['ckpt_path'] + kwargs['ckpt_path'] = None + except: + ckpt_path = None + + super().__init__(*args, **kwargs) + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + + ddconfig = kwargs['ddconfig'] + ddconfig['z_channels'] *= 3 + del self.decoder + del self.post_quant_conv + self.decoder = Decoder(**ddconfig) + self.post_quant_conv = torch.nn.Conv2d(kwargs['embed_dim'] * 3, ddconfig["z_channels"], 1) + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.MaxPool2d((res, 1)) + y_mp = torch.nn.MaxPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def encode(self, x, rollout=False): + if rollout: + x = self.to3daware(self.rollout(x)) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + z = self.to3daware(z) + z = self.post_quant_conv(z) + dec = self.decoder(z) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs)) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None) + self.log_dict(log_dict_ae) + + assert not self.norm + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_3daware(self, plane): + x = plane.float() + if not hasattr(self, "colorize_3daware"): + self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0] + res = inputs.shape[1] + colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0] + colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0] + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2) + + reconstructions = self.unrollout(reconstructions) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + +class AutoencoderKLRollOut3DAwareOnlyInput(AutoencoderKL): + def __init__(self, *args, **kwargs): + try: + ckpt_path = kwargs['ckpt_path'] + kwargs['ckpt_path'] = None + except: + ckpt_path = None + + super().__init__(*args, **kwargs) + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + + # ddconfig = kwargs['ddconfig'] + # ddconfig['z_channels'] *= 3 + # self.decoder = Decoder(**ddconfig) + # self.post_quant_conv = torch.nn.Conv2d(kwargs['embed_dim'] * 3, ddconfig["z_channels"], 1) + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.MaxPool2d((res, 1)) + y_mp = torch.nn.MaxPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def encode(self, x, rollout=False): + if rollout: + x = self.to3daware(self.rollout(x)) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + # z = self.to3daware(z) + z = self.post_quant_conv(z) + dec = self.decoder(z) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs)) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/') + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/') + self.log_dict(log_dict_ae) + + assert not self.norm + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/') + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + + reconstructions = self.unrollout(reconstructions) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + +class AutoencoderKLRollOut3DAwareMeanPool(AutoencoderKL): + def __init__(self, *args, **kwargs): + try: + ckpt_path = kwargs['ckpt_path'] + kwargs['ckpt_path'] = None + except: + ckpt_path = None + + super().__init__(*args, **kwargs) + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + + ddconfig = kwargs['ddconfig'] + ddconfig['z_channels'] *= 3 + self.decoder = Decoder(**ddconfig) + self.post_quant_conv = torch.nn.Conv2d(kwargs['embed_dim'] * 3, ddconfig["z_channels"], 1) + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.AvgPool2d((res, 1)) + y_mp = torch.nn.AvgPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def encode(self, x, rollout=False): + if rollout: + x = self.to3daware(self.rollout(x)) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + z = self.to3daware(z) + z = self.post_quant_conv(z) + dec = self.decoder(z) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs)) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/') + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/') + self.log_dict(log_dict_ae) + + assert not self.norm + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_3daware(self, plane): + x = plane.float() + if not hasattr(self, "colorize_3daware"): + self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/') + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0] + res = inputs.shape[1] + colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0] + colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0] + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2) + + reconstructions = self.unrollout(reconstructions) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + +class AutoencoderKLGroupConv(AutoencoderKL): + def __init__(self, *args, **kwargs): + try: + ckpt_path = kwargs['ckpt_path'] + kwargs['ckpt_path'] = None + except: + ckpt_path = None + + super().__init__(*args, **kwargs) + self.latent_list = [] + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + + ddconfig = kwargs['ddconfig'] + # ddconfig['z_channels'] *= 3 + del self.decoder + del self.encoder + self.encoder = Encoder_GroupConv(**ddconfig) + self.decoder = Decoder_GroupConv(**ddconfig) + + if "mean" in ddconfig: + print("Using mean std!!") + self.triplane_mean = torch.Tensor(ddconfig['mean']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float() + self.triplane_std = torch.Tensor(ddconfig['std']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float() + else: + self.triplane_mean = None + self.triplane_std = None + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.MaxPool2d((res, 1)) + y_mp = torch.nn.MaxPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def encode(self, x, rollout=False): + if rollout: + # x = self.to3daware(self.rollout(x)) + x = self.rollout(x) + if self.triplane_mean is not None: + x = (x - self.triplane_mean.to(x.device)) / self.triplane_std.to(x.device) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + # z = self.to3daware(z) + z = self.post_quant_conv(z) + dec = self.decoder(z) + if self.triplane_mean is not None: + dec = dec * self.triplane_std.to(dec.device) + self.triplane_mean.to(dec.device) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None) + self.log_dict(log_dict_ae) + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + assert not self.norm + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + rgb_input = np.stack([rgb_input[..., 2], rgb_input[..., 1], rgb_input[..., 0]], -1) + rgb = np.stack([rgb[..., 2], rgb[..., 1], rgb[..., 0]], -1) + + if b % 2 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)], + "val/latent_vis": [wandb.Image(colorize_z)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_3daware(self, plane): + x = plane.float() + if not hasattr(self, "colorize_3daware"): + self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + + import os + import random + import string + # z_np = z.detach().cpu().numpy() + z_np = inputs.detach().cpu().numpy() + fname = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + '.npy' + with open(os.path.join('/mnt/lustre/hongfangzhou.p/AE3D/tmp', fname), 'wb') as f: + np.save(f, z_np) + + # colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0] + # res = inputs.shape[1] + # colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0] + # colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0] + if batch_idx < 0: + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2) + + np_z = z.detach().cpu().numpy() + # with open(os.path.join(self.logger.log_dir, "latent_{}.npz".format(batch_idx)), 'wb') as f: + # np.save(f, np_z) + + self.latent_list.append(np_z) + + reconstructions = self.unrollout(reconstructions) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + if True: + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + latent = np.concatenate(self.latent_list) + q75, q25 = np.percentile(latent.reshape(-1), [75 ,25]) + median = np.median(latent.reshape(-1)) + iqr = q75 - q25 + norm_iqr = iqr * 0.7413 + print("Norm IQR: {}".format(norm_iqr)) + print("Inverse Norm IQR: {}".format(1/norm_iqr)) + print("Median: {}".format(median)) + + def loss(self, inputs, reconstructions, posteriors, prefix, batch=None): + reconstructions = reconstructions.contiguous() + # rec_loss = torch.abs(inputs.contiguous() - reconstructions) + # rec_loss = torch.sum(rec_loss) / rec_loss.shape[0] + rec_loss = F.mse_loss(inputs.contiguous(), reconstructions) + kl_loss = posteriors.kl() + # kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + kl_loss = kl_loss.mean() + loss = self.lossconfig.rec_weight * rec_loss + self.lossconfig.kl_weight * kl_loss + + ret_dict = { + prefix+'mean_rec_loss': torch.abs(inputs.contiguous() - reconstructions.contiguous()).mean().detach(), + prefix+'rec_loss': rec_loss, + prefix+'kl_loss': kl_loss, + prefix+'loss': loss, + prefix+'mean': posteriors.mean.mean(), + prefix+'logvar': posteriors.logvar.mean(), + } + + + latent = posteriors.mean + ret_dict[prefix + 'latent_max'] = latent.max() + ret_dict[prefix + 'latent_min'] = latent.min() + + render_weight = self.lossconfig.get("render_weight", 0) + tv_weight = self.lossconfig.get("tv_weight", 0) + l1_weight = self.lossconfig.get("l1_weight", 0) + latent_tv_weight = self.lossconfig.get("latent_tv_weight", 0) + latent_l1_weight = self.lossconfig.get("latent_l1_weight", 0) + + triplane_rec = self.unrollout(reconstructions) + if render_weight > 0 and batch is not None: + rgb_rendered, target = self.render_triplane_eg3d_decoder_sample_pixel(triplane_rec, batch['batch_rays'], batch['img']) + # render_loss = ((rgb_rendered - target) ** 2).sum() / rgb_rendered.shape[0] * 256 + render_loss = F.mse_loss(rgb_rendered, target) + loss += render_weight * render_loss + ret_dict[prefix + 'render_loss'] = render_loss + if tv_weight > 0: + tvloss_y = F.mse_loss(triplane_rec[:, :, :-1], triplane_rec[:, :, 1:]) + tvloss_x = F.mse_loss(triplane_rec[:, :, :, :-1], triplane_rec[:, :, :, 1:]) + tvloss = tvloss_y + tvloss_x + loss += tv_weight * tvloss + ret_dict[prefix + 'tv_loss'] = tvloss + if l1_weight > 0: + l1 = (triplane_rec ** 2).mean() + loss += l1_weight * l1 + ret_dict[prefix + 'l1_loss'] = l1 + if latent_tv_weight > 0: + latent = posteriors.mean + latent_tv_y = F.mse_loss(latent[:, :, :-1], latent[:, :, 1:]) + latent_tv_x = F.mse_loss(latent[:, :, :, :-1], latent[:, :, :, 1:]) + latent_tv_loss = latent_tv_y + latent_tv_x + loss += latent_tv_loss * latent_tv_weight + ret_dict[prefix + 'latent_tv_loss'] = latent_tv_loss + if latent_l1_weight > 0: + latent = posteriors.mean + latent_l1_loss = (latent ** 2).mean() + loss += latent_l1_loss * latent_l1_weight + ret_dict[prefix + 'latent_l1_loss'] = latent_l1_loss + + return loss, ret_dict + + +class AutoencoderKLGroupConvLateFusion(AutoencoderKL): + def __init__(self, *args, **kwargs): + try: + ckpt_path = kwargs['ckpt_path'] + kwargs['ckpt_path'] = None + except: + ckpt_path = None + + super().__init__(*args, **kwargs) + self.latent_list = [] + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + + ddconfig = kwargs['ddconfig'] + del self.decoder + del self.encoder + self.encoder = Encoder_GroupConv_LateFusion(**ddconfig) + self.decoder = Decoder_GroupConv_LateFusion(**ddconfig) + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.MaxPool2d((res, 1)) + y_mp = torch.nn.MaxPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def encode(self, x, rollout=False): + if rollout: + x = self.rollout(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + z = self.post_quant_conv(z) + dec = self.decoder(z) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None) + self.log_dict(log_dict_ae) + + assert not self.norm + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_3daware(self, plane): + x = plane.float() + if not hasattr(self, "colorize_3daware"): + self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + # colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0] + # res = inputs.shape[1] + # colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0] + # colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0] + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2) + + np_z = z.detach().cpu().numpy() + # with open(os.path.join(self.logger.log_dir, "latent_{}.npz".format(batch_idx)), 'wb') as f: + # np.save(f, np_z) + + self.latent_list.append(np_z) + + reconstructions = self.unrollout(reconstructions) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + latent = np.concatenate(self.latent_list) + q75, q25 = np.percentile(latent.reshape(-1), [75 ,25]) + median = np.median(latent.reshape(-1)) + iqr = q75 - q25 + norm_iqr = iqr * 0.7413 + print("Norm IQR: {}".format(norm_iqr)) + print("Inverse Norm IQR: {}".format(1/norm_iqr)) + print("Median: {}".format(median)) + + +from module.model_2d import ViTEncoder, ViTDecoder + +class AutoencoderVIT(AutoencoderKL): + def __init__(self, *args, **kwargs): + try: + ckpt_path = kwargs['ckpt_path'] + kwargs['ckpt_path'] = None + except: + ckpt_path = None + + super().__init__(*args, **kwargs) + self.latent_list = [] + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + + ddconfig = kwargs['ddconfig'] + # ddconfig['z_channels'] *= 3 + del self.decoder + del self.encoder + del self.quant_conv + del self.post_quant_conv + + assert ddconfig["z_channels"] == 256 + self.encoder = ViTEncoder( + image_size=(256, 256*3), + patch_size=(256//32, 256//32), + dim=768, + depth=12, + heads=12, + mlp_dim=3072, + channels=8) + self.decoder = ViTDecoder( + image_size=(256, 256*3), + patch_size=(256//32, 256//32), + dim=768, + depth=12, + heads=12, + mlp_dim=3072, + channels=8) + + self.quant_conv = torch.nn.Conv2d(768, 2*self.embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, 768, 1) + + if "mean" in ddconfig: + print("Using mean std!!") + self.triplane_mean = torch.Tensor(ddconfig['mean']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float() + self.triplane_std = torch.Tensor(ddconfig['std']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float() + else: + self.triplane_mean = None + self.triplane_std = None + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.MaxPool2d((res, 1)) + y_mp = torch.nn.MaxPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def encode(self, x, rollout=False): + if rollout: + # x = self.to3daware(self.rollout(x)) + x = self.rollout(x) + if self.triplane_mean is not None: + x = (x - self.triplane_mean.to(x.device)) / self.triplane_std.to(x.device) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + # z = self.to3daware(z) + z = self.post_quant_conv(z) + dec = self.decoder(z) + if self.triplane_mean is not None: + dec = dec * self.triplane_std.to(dec.device) + self.triplane_mean.to(dec.device) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None) + self.log_dict(log_dict_ae) + + assert not self.norm + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_3daware(self, plane): + x = plane.float() + if not hasattr(self, "colorize_3daware"): + self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + + import os + import random + import string + # z_np = z.detach().cpu().numpy() + z_np = inputs.detach().cpu().numpy() + fname = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + '.npy' + with open(os.path.join('/mnt/lustre/hongfangzhou.p/AE3D/tmp', fname), 'wb') as f: + np.save(f, z_np) + + # colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0] + # res = inputs.shape[1] + # colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0] + # colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0] + # if batch_idx < 10: + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2) + + np_z = z.detach().cpu().numpy() + # with open(os.path.join(self.logger.log_dir, "latent_{}.npz".format(batch_idx)), 'wb') as f: + # np.save(f, np_z) + + self.latent_list.append(np_z) + + reconstructions = self.unrollout(reconstructions) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + if True: + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + # if batch_idx < 10: + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + latent = np.concatenate(self.latent_list) + q75, q25 = np.percentile(latent.reshape(-1), [75 ,25]) + median = np.median(latent.reshape(-1)) + iqr = q75 - q25 + norm_iqr = iqr * 0.7413 + print("Norm IQR: {}".format(norm_iqr)) + print("Inverse Norm IQR: {}".format(1/norm_iqr)) + print("Median: {}".format(median)) diff --git a/3DTopia/model/triplane_vqvae.py b/3DTopia/model/triplane_vqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6be8fae4ad14394eb1f98477f5483e0ae04115 --- /dev/null +++ b/3DTopia/model/triplane_vqvae.py @@ -0,0 +1,418 @@ +import os +import imageio +import torch +import wandb +import numpy as np +import pytorch_lightning as pl +import torch.nn.functional as F + +from module.model_2d import Encoder, Decoder, DiagonalGaussianDistribution, Encoder_GroupConv, Decoder_GroupConv, Encoder_GroupConv_LateFusion, Decoder_GroupConv_LateFusion +from utility.initialize import instantiate_from_config +from utility.triplane_renderer.renderer import get_embedder, NeRF, run_network, render_path1, to8b, img2mse, mse2psnr +from utility.triplane_renderer.eg3d_renderer import Renderer_TriPlane +from module.quantise import VectorQuantiser +from module.quantize_taming import EMAVectorQuantizer, VectorQuantizer2, QuantizeEMAReset + +class CVQVAE(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + learning_rate=1e-3, + ckpt_path=None, + ignore_keys=[], + colorize_nlabels=None, + monitor=None, + decoder_ckpt=None, + norm=True, + renderer_type='nerf', + is_cvqvae=False, + renderer_config=dict( + rgbnet_dim=18, + rgbnet_width=128, + viewpe=0, + feape=0 + ), + vector_quantizer_config=dict( + num_embed=1024, + beta=0.25, + distance='cos', + anchor='closest', + first_batch=False, + contras_loss=True, + ) + ): + super().__init__() + self.save_hyperparameters() + self.norm = norm + self.renderer_config = renderer_config + self.learning_rate = learning_rate + + ddconfig['double_z'] = False + self.encoder = Encoder_GroupConv(**ddconfig) + self.decoder = Decoder_GroupConv(**ddconfig) + + self.lossconfig = lossconfig + + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.decoder_ckpt = decoder_ckpt + self.renderer_type = renderer_type + if decoder_ckpt is not None: + self.triplane_decoder, self.triplane_render_kwargs = self.create_eg3d_decoder(decoder_ckpt) + + vector_quantizer_config['embed_dim'] = embed_dim + + if is_cvqvae: + self.vector_quantizer = VectorQuantiser( + **vector_quantizer_config + ) + else: + self.vector_quantizer = EMAVectorQuantizer( + n_embed=vector_quantizer_config['num_embed'], + codebook_dim = embed_dim, + beta=vector_quantizer_config['beta'] + ) + # self.vector_quantizer = VectorQuantizer2( + # n_e = vector_quantizer_config['num_embed'], + # e_dim = embed_dim, + # beta = vector_quantizer_config['beta'] + # ) + # self.vector_quantizer = QuantizeEMAReset( + # nb_code = vector_quantizer_config['num_embed'], + # code_dim = embed_dim, + # mu = vector_quantizer_config['beta'], + # ) + + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + self.latent_list = [] + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=True) + print(f"Restored from {path}") + + def encode(self, x, rollout=False): + if rollout: + x = self.rollout(x) + h = self.encoder(x) + moments = self.quant_conv(h) + z_q, loss, (perplexity, min_encodings, encoding_indices) = self.vector_quantizer(moments) + return z_q, loss, perplexity, encoding_indices + + def decode(self, z, unrollout=False): + z = self.post_quant_conv(z) + dec = self.decoder(z) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input): + z_q, loss, perplexity, encoding_indices = self.encode(input) + dec = self.decode(z_q) + return dec, loss, perplexity, encoding_indices + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.MaxPool2d((res, 1)) + y_mp = torch.nn.MaxPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, vq_loss, perplexity, encoding_indices = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, vq_loss, prefix='train/', batch=batch) + log_dict_ae['train/perplexity'] = perplexity + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, vq_loss, perplexity, encoding_indices = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, vq_loss, prefix='val/', batch=None) + log_dict_ae['val/perplexity'] = perplexity + self.log_dict(log_dict_ae) + + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)], + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_3daware(self, plane): + x = plane.float() + if not hasattr(self, "colorize_3daware"): + self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, vq_loss, perplexity, encoding_indices = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, vq_loss, prefix='test/', batch=None) + log_dict_ae['test/perplexity'] = perplexity + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + + reconstructions = self.unrollout(reconstructions) + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + if True: + for b in range(batch_size): + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + latent = np.concatenate(self.latent_list) + q75, q25 = np.percentile(latent.reshape(-1), [75 ,25]) + median = np.median(latent.reshape(-1)) + iqr = q75 - q25 + norm_iqr = iqr * 0.7413 + print("Norm IQR: {}".format(norm_iqr)) + print("Inverse Norm IQR: {}".format(1/norm_iqr)) + print("Median: {}".format(median)) + + def loss(self, inputs, reconstructions, vq_loss, prefix, batch=None): + reconstructions = reconstructions.contiguous() + rec_loss = F.mse_loss(inputs.contiguous(), reconstructions) + loss = self.lossconfig.rec_weight * rec_loss + self.lossconfig.vq_weight * vq_loss + + ret_dict = { + prefix+'mean_rec_loss': torch.abs(inputs.contiguous() - reconstructions.contiguous()).mean().detach(), + prefix+'rec_loss': rec_loss, + prefix+'vq_loss': vq_loss, + prefix+'loss': loss, + } + + render_weight = self.lossconfig.get("render_weight", 0) + tv_weight = self.lossconfig.get("tv_weight", 0) + l1_weight = self.lossconfig.get("l1_weight", 0) + latent_tv_weight = self.lossconfig.get("latent_tv_weight", 0) + latent_l1_weight = self.lossconfig.get("latent_l1_weight", 0) + + triplane_rec = self.unrollout(reconstructions) + if render_weight > 0 and batch is not None: + rgb_rendered, target = self.render_triplane_eg3d_decoder_sample_pixel(triplane_rec, batch['batch_rays'], batch['img']) + render_loss = F.mse(rgb_rendered, target) + loss += render_weight * render_loss + ret_dict[prefix + 'render_loss'] = render_loss + if tv_weight > 0: + tvloss_y = torch.abs(triplane_rec[:, :, :-1] - triplane_rec[:, :, 1:]).mean() + tvloss_x = torch.abs(triplane_rec[:, :, :, :-1] - triplane_rec[:, :, :, 1:]).mean() + tvloss = tvloss_y + tvloss_x + loss += tv_weight * tvloss + ret_dict[prefix + 'tv_loss'] = tvloss + if l1_weight > 0: + l1 = (triplane_rec ** 2).mean() + loss += l1_weight * l1 + ret_dict[prefix + 'l1_loss'] = l1 + + ret_dict[prefix+'loss'] = loss + + return loss, ret_dict + + def create_eg3d_decoder(self, decoder_ckpt): + triplane_decoder = Renderer_TriPlane(**self.renderer_config) + pretrain_pth = torch.load(decoder_ckpt, map_location='cpu') + pretrain_pth = { + '.'.join(k.split('.')[1:]): v for k, v in pretrain_pth.items() + } + # import pdb; pdb.set_trace() + triplane_decoder.load_state_dict(pretrain_pth) + render_kwargs = { + 'depth_resolution': 128, + 'disparity_space_sampling': False, + 'box_warp': 2.4, + 'depth_resolution_importance': 128, + 'clamp_mode': 'softplus', + 'white_back': True, + 'det': True + } + return triplane_decoder, render_kwargs + + def render_triplane_eg3d_decoder(self, triplane, batch_rays, target): + ray_o = batch_rays[:, 0] + ray_d = batch_rays[:, 1] + psnr_list = [] + rec_img_list = [] + res = triplane.shape[-2] + for i in range(ray_o.shape[0]): + with torch.no_grad(): + render_out = self.triplane_decoder(triplane.reshape(1, 3, -1, res, res), + ray_o[i:i+1], ray_d[i:i+1], self.triplane_render_kwargs, whole_img=True, tvloss=False) + rec_img = render_out['rgb_marched'].permute(0, 2, 3, 1) + psnr = mse2psnr(img2mse(rec_img[0], target[i])) + psnr_list.append(psnr) + rec_img_list.append(rec_img) + return torch.cat(rec_img_list, 0), psnr_list + + def render_triplane_eg3d_decoder_sample_pixel(self, triplane, batch_rays, target, sample_num=1024): + assert batch_rays.shape[1] == 1 + sel = torch.randint(batch_rays.shape[-2], [sample_num]) + ray_o = batch_rays[:, 0, 0, sel] + ray_d = batch_rays[:, 0, 1, sel] + res = triplane.shape[-2] + render_out = self.triplane_decoder(triplane.reshape(triplane.shape[0], 3, -1, res, res), + ray_o, ray_d, self.triplane_render_kwargs, whole_img=False, tvloss=False) + rec_img = render_out['rgb_marched'] + target = target.reshape(triplane.shape[0], -1, 3)[:, sel, :] + return rec_img, target + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters())+ + list(self.vector_quantizer.parameters()), + lr=lr) + return opt_ae diff --git a/3DTopia/module/model_2d.py b/3DTopia/module/model_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..e501c20953c37314b00c4689b37c2c969a60c47d --- /dev/null +++ b/3DTopia/module/model_2d.py @@ -0,0 +1,2206 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from utility.initialize import instantiate_from_config +from .nn_2d import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none", "vanilla_groupconv", "crossattention"], f'attn_type {attn_type} unknown' + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == 'vanilla_groupconv': + return AttnBlock_GroupConv(in_channels) + elif attn_type == 'crossattention': + num_heads = 8 + return TriplaneAttentionBlock(in_channels, num_heads, in_channels // num_heads, True) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + # print("Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +class ResnetBlock_GroupConv(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels * 3, 32 * 3) + self.conv1 = torch.nn.Conv2d(in_channels * 3, + out_channels * 3, + kernel_size=3, + stride=1, + padding=1, + groups=3) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels * 3, 32 * 3) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels * 3, + out_channels * 3, + kernel_size=3, + stride=1, + padding=1, + groups=3) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels * 3, + out_channels * 3, + kernel_size=3, + stride=1, + padding=1, + groups=3) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels * 3, + out_channels * 3, + kernel_size=1, + stride=1, + padding=0, + groups=3) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + assert temb is None + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +def rollout(triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + +def unrollout(triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + +class Upsample_GroupConv(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels * 3, + in_channels * 3, + kernel_size=3, + stride=1, + padding=1, + groups=3) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample_GroupConv(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels * 3, + in_channels * 3, + kernel_size=3, + stride=2, + padding=0, + groups=3) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + +class AttnBlock_GroupConv(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x, temp=None): + x = rollout(x) + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return unrollout(x+h_) + + +from torch import nn, einsum +from inspect import isfunction +from einops import rearrange, repeat + +def exists(val): + return val is not None + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + x = x.permute(0, 2, 1) + context = context.permute(0, 2, 1) + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out).permute(0, 2, 1) + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +class TriplaneAttentionBlock(nn.Module): + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + + self.plane1_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels) + self.plane2_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels) + self.plane3_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels) + + def forward(self, x, temp=None): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + x = rollout(x) + + b, c, *spatial = x.shape + res = x.shape[-2] + plane1 = x[..., :res].reshape(b, c, -1) + plane2 = x[..., res:res*2].reshape(b, c, -1) + plane3 = x[..., 2*res:3*res].reshape(b, c, -1) + x = x.reshape(b, c, -1) + + plane1_output = self.plane1_ca(self.norm(plane1), self.norm(x)) + plane2_output = self.plane2_ca(self.norm(plane2), self.norm(x)) + plane3_output = self.plane3_ca(self.norm(plane3), self.norm(x)) + + h = torch.cat([plane1_output, plane2_output, plane3_output], -1) + + x = (x + h).reshape(b, c, *spatial) + + return unrollout(x) + + +class Encoder_GroupConv(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, + attn_type="vanilla_groupconv", mid_layers=1, + **ignore_kwargs): + super().__init__() + assert not use_linear_attn + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + # self.conv_in = torch.nn.Conv2d(in_channels, + # self.ch, + # kernel_size=3, + # stride=1, + # padding=1) + self.conv_in = torch.nn.Conv2d(in_channels * 3, + self.ch * 3, + kernel_size=3, + stride=1, + padding=1, + groups=3) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock_GroupConv(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample_GroupConv(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.attn_type = attn_type + self.mid = nn.Module() + if attn_type == 'crossattention': + self.mid.block_1 = nn.ModuleList() + for _ in range(mid_layers): + self.mid.block_1.append( + ResnetBlock_GroupConv(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + ) + self.mid.block_1.append( + make_attn(block_in, attn_type=attn_type) + ) + else: + self.mid.block_1 = ResnetBlock_GroupConv(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock_GroupConv(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in * 3, 32 * 3) + self.conv_out = torch.nn.Conv2d(block_in * 3, + 2*z_channels * 3 if double_z else z_channels * 3, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + x = unrollout(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + if self.attn_type == 'crossattention': + for m in self.mid.block_1: + h = m(h, temb) + else: + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + + h = rollout(h) + + return h + +class Decoder_GroupConv(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla_groupconv", mid_layers=1, **ignorekwargs): + super().__init__() + assert not use_linear_attn + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + # print("Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels * 3, + block_in * 3, + kernel_size=3, + stride=1, + padding=1, + groups=3) + + # middle + self.mid = nn.Module() + self.attn_type = attn_type + if attn_type == 'crossattention': + self.mid.block_1 = nn.ModuleList() + for _ in range(mid_layers): + self.mid.block_1.append( + ResnetBlock_GroupConv(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + ) + self.mid.block_1.append( + make_attn(block_in, attn_type=attn_type) + ) + else: + self.mid.block_1 = ResnetBlock_GroupConv(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock_GroupConv(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock_GroupConv(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample_GroupConv(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in * 3, 32 * 3) + self.conv_out = torch.nn.Conv2d(block_in * 3, + out_ch * 3, + kernel_size=3, + stride=1, + padding=1, + groups=3) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + z = unrollout(z) + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + if self.attn_type == 'crossattention': + for m in self.mid.block_1: + h = m(h, temb) + else: + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + + h = rollout(h) + + return h + + + +# not success attempts +class CrossAttnFuseBlock_GroupConv(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q0 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k0 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v0 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.q1 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k1 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v1 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.q2 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k2 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v2 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out0 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out1 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out2 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + self.fuse_out = torch.nn.Conv2d(in_channels * 3, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + x = rollout(x) + + b, c, *spatial = x.shape + res = x.shape[-2] + plane1 = x[..., :res].reshape(b, c, res, res) + plane2 = x[..., res:res*2].reshape(b, c, res, res) + plane3 = x[..., 2*res:3*res].reshape(b, c, res, res) + + # h_ = x + # h_ = self.norm(h_) + # q = self.q(h_) + # k = self.k(h_) + # v = self.v(h_) + + q0 = self.q0(self.norm(plane2)) + k0 = self.k0(self.norm(plane2)) + v0 = self.v0(self.norm(plane2)) + + q1 = self.q1(self.norm(plane2)) + k1 = self.k1(self.norm(plane1)) + v1 = self.v1(self.norm(plane1)) + + q2 = self.q2(self.norm(plane2)) + k2 = self.k2(self.norm(plane3)) + v2 = self.v2(self.norm(plane3)) + + def compute_attention(q, k, v): + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + return h_ + + h0 = compute_attention(q0, k0, v0) + h0 = self.proj_out0(h0) + + h1 = compute_attention(q1, k1, v1) + h1 = self.proj_out1(h1) + + h2 = compute_attention(q2, k2, v2) + h2 = self.proj_out2(h2) + + fuse_out = self.fuse_out( + torch.cat([h0, h1, h2], 1) + ) + + return fuse_out + +class CrossAttnDecodeBlock_GroupConv(nn.Module): + def __init__(self, in_channels, h, w): + super().__init__() + self.in_channels = in_channels + self.h = h + self.w = w + + self.norm = Normalize(in_channels) + self.q0 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k0 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v0 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + self.q1 = torch.nn.Parameter(torch.randn(1, self.in_channels, h, w)) + self.q1.requires_grad = True + + self.k1 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v1 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + self.q2 = torch.nn.Parameter(torch.randn(1, self.in_channels, h, w)) + self.q2.requires_grad = True + + self.k2 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v2 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out0 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out1 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out2 = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + self.fuse_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + # x = rollout(x) + + b, c, *spatial = x.shape + res = x.shape[-2] + # plane1 = x[..., :res].reshape(b, c, res, res) + # plane2 = x[..., res:res*2].reshape(b, c, res, res) + # plane3 = x[..., 2*res:3*res].reshape(b, c, res, res) + + # h_ = x + # h_ = self.norm(h_) + # q = self.q(h_) + # k = self.k(h_) + # v = self.v(h_) + + q0 = self.q0(self.norm(x)) + k0 = self.k0(self.norm(x)) + v0 = self.v0(self.norm(x)) + + q1 = self.q1.repeat(b, 1, 1, 1) + k1 = self.k1(self.norm(x)) + v1 = self.v1(self.norm(x)) + + q2 = self.q2.repeat(b, 1, 1, 1) + k2 = self.k2(self.norm(x)) + v2 = self.v2(self.norm(x)) + + def compute_attention(q, k, v): + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + return h_ + + h0 = compute_attention(q0, k0, v0) + h0 = self.proj_out0(h0) + + h1 = compute_attention(q1, k1, v1) + h1 = self.proj_out1(h1) + + h2 = compute_attention(q2, k2, v2) + h2 = self.proj_out2(h2) + + fuse_out = self.fuse_out( + torch.cat([h1, h0, h2], -1) + ) + + fuse_out = unrollout(fuse_out) + + return fuse_out + +class Encoder_GroupConv_LateFusion(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla_groupconv", + **ignore_kwargs): + super().__init__() + assert not use_linear_attn + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels * 3, + self.ch * 3, + kernel_size=3, + stride=1, + padding=1, + groups=3) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock_GroupConv(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample_GroupConv(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock_GroupConv(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock_GroupConv(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # fuse to one plane + self.fuse = CrossAttnFuseBlock_GroupConv(block_in) + + # end + self.norm_out = Normalize(block_in, 32) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + x = unrollout(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + h = self.fuse(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + + # h = rollout(h) + + return h + +class Decoder_GroupConv_LateFusion(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla_groupconv", **ignorekwargs): + super().__init__() + assert not use_linear_attn + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + # print("Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # triplane decoder + self.triplane_decoder = CrossAttnDecodeBlock_GroupConv(block_in, curr_res, curr_res) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock_GroupConv(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock_GroupConv(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock_GroupConv(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample_GroupConv(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in * 3, 32 * 3) + self.conv_out = torch.nn.Conv2d(block_in * 3, + out_ch * 3, + kernel_size=3, + stride=1, + padding=1, + groups=3) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + h = self.triplane_decoder(h) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + + h = rollout(h) + + return h + + +# VIT Encoder and Decoder from https://github.com/thuanz123/enhancing-transformers/blob/main/enhancing/modules/stage1/layers.py +# ------------------------------------------------------------------------------------ +# Enhancing Transformers +# Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved. +# Licensed under the MIT License [see LICENSE for details] +# ------------------------------------------------------------------------------------ +# Modified from ViT-Pytorch (https://github.com/lucidrains/vit-pytorch) +# Copyright (c) 2020 Phil Wang. All Rights Reserved. +# ------------------------------------------------------------------------------------ + +import math +import numpy as np +from typing import Union, Tuple, List +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +def get_2d_sincos_pos_embed(embed_dim, grid_size): + """ + grid_size: int or (int, int) of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_size = (grid_size, grid_size) if type(grid_size) != tuple else grid_size + grid_h = np.arange(grid_size[0], dtype=np.float32) + grid_w = np.arange(grid_size[1], dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def init_weights(m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + w = m.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + +class PreNorm(nn.Module): + def __init__(self, dim: int, fn: nn.Module) -> None: + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor: + return self.fn(self.norm(x), **kwargs) + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int) -> None: + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.Tanh(), + nn.Linear(hidden_dim, dim) + ) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + return self.net(x) + + +class Attention(nn.Module): + def __init__(self, dim: int, heads: int = 8, dim_head: int = 64) -> None: + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Linear(inner_dim, dim) if project_out else nn.Identity() + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) + + attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale + attn = self.attend(attn) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__(self, dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int) -> None: + super().__init__() + self.layers = nn.ModuleList([]) + for idx in range(depth): + layer = nn.ModuleList([PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head)), + PreNorm(dim, FeedForward(dim, mlp_dim))]) + self.layers.append(layer) + self.norm = nn.LayerNorm(dim) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + + +class ViTEncoder(nn.Module): + def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int], + dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 3, dim_head: int = 64) -> None: + super().__init__() + image_height, image_width = image_size if isinstance(image_size, tuple) \ + else (image_size, image_size) + patch_height, patch_width = patch_size if isinstance(patch_size, tuple) \ + else (patch_size, patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' + en_pos_embedding = get_2d_sincos_pos_embed(dim, (image_height // patch_height, image_width // patch_width)) + + self.num_patches = (image_height // patch_height) * (image_width // patch_width) + self.patch_dim = channels * patch_height * patch_width + + self.to_patch_embedding = nn.Sequential( + nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size), + Rearrange('b c h w -> b (h w) c'), + ) + + self.patch_height = patch_height + self.patch_width = patch_width + self.image_height = image_height + self.image_width = image_width + self.dim = dim + + self.en_pos_embedding = nn.Parameter(torch.from_numpy(en_pos_embedding).float().unsqueeze(0), requires_grad=False) + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) + + self.apply(init_weights) + + def forward(self, img: torch.FloatTensor) -> torch.FloatTensor: + x = self.to_patch_embedding(img) + x = x + self.en_pos_embedding + x = self.transformer(x) + + x = Rearrange('b h w c -> b c h w')(x.reshape(-1, self.image_height // self.patch_height, self.image_width // self.patch_width, self.dim)) + + return x + + +class ViTDecoder(nn.Module): + def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int], + dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 3, dim_head: int = 64) -> None: + super().__init__() + image_height, image_width = image_size if isinstance(image_size, tuple) \ + else (image_size, image_size) + patch_height, patch_width = patch_size if isinstance(patch_size, tuple) \ + else (patch_size, patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' + de_pos_embedding = get_2d_sincos_pos_embed(dim, (image_height // patch_height, image_width // patch_width)) + + self.num_patches = (image_height // patch_height) * (image_width // patch_width) + self.patch_dim = channels * patch_height * patch_width + + self.patch_height = patch_height + self.patch_width = patch_width + self.image_height = image_height + self.image_width = image_width + self.dim = dim + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) + self.de_pos_embedding = nn.Parameter(torch.from_numpy(de_pos_embedding).float().unsqueeze(0), requires_grad=False) + self.to_pixel = nn.Sequential( + Rearrange('b (h w) c -> b c h w', h=image_height // patch_height), + nn.ConvTranspose2d(dim, channels, kernel_size=patch_size, stride=patch_size) + ) + + self.apply(init_weights) + + def forward(self, token: torch.FloatTensor) -> torch.FloatTensor: + token = Rearrange('b c h w -> b (h w) c')(token) + + x = token + self.de_pos_embedding + x = self.transformer(x) + x = self.to_pixel(x) + + return x + + def get_last_layer(self) -> nn.Parameter: + return self.to_pixel[-1].weight diff --git a/3DTopia/module/nn_2d.py b/3DTopia/module/nn_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..d44cc7f0b0aa6106053343d509c73bf7b466aba0 --- /dev/null +++ b/3DTopia/module/nn_2d.py @@ -0,0 +1,546 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + +# zero123/zero123/ldm/modules/diffusionmodules/util.py +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +# zero123/zero123/ldm/modules/attention.py +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + self.disable_self_attn = disable_self_attn + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, + disable_self_attn=disable_self_attn) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + x = self.proj_out(x) + return x + x_in + + +def exists(x): + return x is not None diff --git a/3DTopia/module/quantise.py b/3DTopia/module/quantise.py new file mode 100644 index 0000000000000000000000000000000000000000..07540d6d99f649a6cbd9cb3c8b6a608a77f4e4fa --- /dev/null +++ b/3DTopia/module/quantise.py @@ -0,0 +1,159 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch import einsum +from einops import rearrange + + +class VectorQuantiser(nn.Module): + """ + Improved version over vector quantiser, with the dynamic initialisation + for these unoptimised "dead" points. + num_embed: number of codebook entry + embed_dim: dimensionality of codebook entry + beta: weight for the commitment loss + distance: distance for looking up the closest code + anchor: anchor sampled methods + first_batch: if true, the offline version of our model + contras_loss: if true, use the contras_loss to further improve the performance + """ + def __init__(self, num_embed, embed_dim, beta, distance='cos', + anchor='probrandom', first_batch=False, contras_loss=False): + super().__init__() + + self.num_embed = num_embed + self.embed_dim = embed_dim + self.beta = beta + self.distance = distance + self.anchor = anchor + self.first_batch = first_batch + self.contras_loss = contras_loss + self.decay = 0.99 + self.init = False + + self.pool = FeaturePool(self.num_embed, self.embed_dim) + self.embedding = nn.Embedding(self.num_embed, self.embed_dim) + self.embedding.weight.data.uniform_(-1.0 / self.num_embed, 1.0 / self.num_embed) + self.register_buffer("embed_prob", torch.zeros(self.num_embed)) + + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" + assert rescale_logits==False, "Only for interface compatible with Gumbel" + assert return_logits==False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.embed_dim) + + # clculate the distance + if self.distance == 'l2': + # l2 distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = - torch.sum(z_flattened.detach() ** 2, dim=1, keepdim=True) - \ + torch.sum(self.embedding.weight ** 2, dim=1) + \ + 2 * torch.einsum('bd, dn-> bn', z_flattened.detach(), rearrange(self.embedding.weight, 'n d-> d n')) + elif self.distance == 'cos': + # cosine distances from z to embeddings e_j + normed_z_flattened = F.normalize(z_flattened, dim=1).detach() + normed_codebook = F.normalize(self.embedding.weight, dim=1) + d = torch.einsum('bd,dn->bn', normed_z_flattened, rearrange(normed_codebook, 'n d -> d n')) + + # encoding + sort_distance, indices = d.sort(dim=1) + # look up the closest point for the indices + encoding_indices = indices[:,-1] + encodings = torch.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=z.device) + encodings.scatter_(1, encoding_indices.unsqueeze(1), 1) + + # quantise and unflatten + z_q = torch.matmul(encodings, self.embedding.weight).view(z.shape) + # compute loss for embedding + loss = self.beta * torch.mean((z_q.detach()-z)**2) + torch.mean((z_q - z.detach()) ** 2) + # preserve gradients + z_q = z + (z_q - z).detach() + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + # count + # import pdb + # pdb.set_trace() + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + min_encodings = encodings + + # online clustered reinitialisation for unoptimized points + if self.training: + # calculate the average usage of code entries + self.embed_prob.mul_(self.decay).add_(avg_probs, alpha= 1 - self.decay) + # running average updates + if self.anchor in ['closest', 'random', 'probrandom'] and (not self.init): + # closest sampling + if self.anchor == 'closest': + sort_distance, indices = d.sort(dim=0) + random_feat = z_flattened.detach()[indices[-1,:]] + # feature pool based random sampling + elif self.anchor == 'random': + random_feat = self.pool.query(z_flattened.detach()) + # probabilitical based random sampling + elif self.anchor == 'probrandom': + norm_distance = F.softmax(d.t(), dim=1) + prob = torch.multinomial(norm_distance, num_samples=1).view(-1) + random_feat = z_flattened.detach()[prob] + # decay parameter based on the average usage + decay = torch.exp(-(self.embed_prob*self.num_embed*10)/(1-self.decay)-1e-3).unsqueeze(1).repeat(1, self.embed_dim) + self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay + if self.first_batch: + self.init = True + # contrastive loss + if self.contras_loss: + sort_distance, indices = d.sort(dim=0) + dis_pos = sort_distance[-max(1, int(sort_distance.size(0)/self.num_embed)):,:].mean(dim=0, keepdim=True) + dis_neg = sort_distance[:int(sort_distance.size(0)*1/2),:] + dis = torch.cat([dis_pos, dis_neg], dim=0).t() / 0.07 + contra_loss = F.cross_entropy(dis, torch.zeros((dis.size(0),), dtype=torch.long, device=dis.device)) + loss += contra_loss + + return z_q, loss, (perplexity, min_encodings, encoding_indices) + +class FeaturePool(): + """ + This class implements a feature buffer that stores previously encoded features + + This buffer enables us to initialize the codebook using a history of generated features + rather than the ones produced by the latest encoders + """ + def __init__(self, pool_size, dim=64): + """ + Initialize the FeaturePool class + + Parameters: + pool_size(int) -- the size of featue buffer + """ + self.pool_size = pool_size + if self.pool_size > 0: + self.nums_features = 0 + self.features = (torch.rand((pool_size, dim)) * 2 - 1)/ pool_size + + def query(self, features): + """ + return features from the pool + """ + self.features = self.features.to(features.device) + if self.nums_features < self.pool_size: + if features.size(0) > self.pool_size: # if the batch size is large enough, directly update the whole codebook + random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),)) + self.features = features[random_feat_id] + self.nums_features = self.pool_size + else: + # if the mini-batch is not large nuough, just store it for the next update + num = self.nums_features + features.size(0) + self.features[self.nums_features:num] = features + self.nums_features = num + else: + if features.size(0) > int(self.pool_size): + random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),)) + self.features = features[random_feat_id] + else: + random_id = torch.randperm(self.pool_size) + self.features[random_id[:features.size(0)]] = features + + return self.features diff --git a/3DTopia/module/quantize_taming.py b/3DTopia/module/quantize_taming.py new file mode 100644 index 0000000000000000000000000000000000000000..e018d4c07ed5ddd513ca46ea2f063000bdd6ddc1 --- /dev/null +++ b/3DTopia/module/quantize_taming.py @@ -0,0 +1,564 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch import einsum +from einops import rearrange + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. + Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for + # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be + # used wherever VectorQuantizer has been used before and is additionally + # more efficient. + def __init__(self, n_e, e_dim, beta): + super(VectorQuantizer, self).__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + ## could possible replace this here + # #\start... + # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + + min_encodings = torch.zeros( + min_encoding_indices.shape[0], self.n_e).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # dtype min encodings: torch.float32 + # min_encodings shape: torch.Size([2048, 512]) + # min_encoding_indices.shape: torch.Size([2048, 1]) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + #.........\end + + # with: + # .........\start + #min_encoding_indices = torch.argmin(d, dim=1) + #z_q = self.embedding(min_encoding_indices) + # ......\end......... (TODO) + + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + # TODO: check for more easy handling with nn.Embedding + min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) + min_encodings.scatter_(1, indices[:,None], 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: + z_q = z_q.view(shape) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantize(nn.Module): + """ + credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 + https://arxiv.org/abs/1611.01144 + """ + def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, + kl_weight=5e-4, temp_init=1.0, use_vqinterface=True, + remap=None, unknown_index="random"): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_embed = n_embed + + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + + self.use_vqinterface = use_vqinterface + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, return_logits=False): + # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + + logits = self.proj(z) + if self.remap is not None: + # continue only with used logits + full_zeros = torch.zeros_like(logits) + logits = logits[:,self.used,...] + + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + if self.remap is not None: + # go back to all entries but unused set to zero + full_zeros[:,self.used,...] = soft_one_hot + soft_one_hot = full_zeros + z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + + ind = soft_one_hot.argmax(dim=1) + if self.remap is not None: + ind = self.remap_to_used(ind) + if self.use_vqinterface: + if return_logits: + return z_q, diff, (None, None, ind), logits + return z_q, diff, (None, None, ind) + return z_q, diff, ind + + def get_codebook_entry(self, indices, shape): + b, h, w, c = shape + assert b*h*w == indices.shape[0] + indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) + if self.remap is not None: + indices = self.unmap_to_all(indices) + one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() + z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) + return z_q + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", + sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" + assert rescale_logits==False, "Only for interface compatible with Gumbel" + assert return_logits==False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = 0 + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0],-1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super().__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + self.weight = nn.Parameter(weight, requires_grad = False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + #normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + self.weight.data.copy_(embed_normalized) + + +class EMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, codebook_dim, beta, decay=0.99, eps=1e-5, + remap=None, unknown_index="random"): + super().__init__() + self.codebook_dim = codebook_dim + self.num_tokens = n_embed + self.beta = beta + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def dequantize(self, ids): + return self.embedding(ids) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + #z, 'b c h w -> b h w c' + z = rearrange(z, 'b c h w -> b h w c') + z_flattened = z.reshape(-1, self.codebook_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + if self.training and self.embedding.update: + #EMA cluster size + encodings_sum = encodings.sum(0) + self.embedding.cluster_size_ema_update(encodings_sum) + #EMA embedding average + embed_sum = encodings.transpose(0,1) @ z_flattened + self.embedding.embed_avg_ema_update(embed_sum) + #normalize embed_avg and update weight + self.embedding.weight_update(self.num_tokens) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + #z_q, 'b h w c -> b c h w' + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, (perplexity, encodings, encoding_indices) + + +class QuantizeEMAReset(nn.Module): + def __init__(self, nb_code, code_dim, mu): + super().__init__() + self.nb_code = nb_code + self.code_dim = code_dim + self.mu = mu + self.reset_codebook() + + def reset_codebook(self): + self.init = False + self.code_sum = None + self.code_count = None + device = "cuda" if torch.cuda.is_available() else "cpu" + self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).to(device)) + + def _tile(self, x): + nb_code_x, code_dim = x.shape + if nb_code_x < self.nb_code: + n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x + std = 0.01 / np.sqrt(code_dim) + out = x.repeat(n_repeats, 1) + out = out + torch.randn_like(out) * std + else : + out = x + return out + + def init_codebook(self, x): + out = self._tile(x) + self.codebook = out[:self.nb_code] + self.code_sum = self.codebook.clone() + self.code_count = torch.ones(self.nb_code, device=self.codebook.device) + self.init = True + + @torch.no_grad() + def compute_perplexity(self, code_idx) : + # Calculate new centres + code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1) + + code_count = code_onehot.sum(dim=-1) # nb_code + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + return perplexity + + @torch.no_grad() + def update_codebook(self, x, code_idx): + + code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L + code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1) + + code_sum = torch.matmul(code_onehot, x) # nb_code, w + code_count = code_onehot.sum(dim=-1) # nb_code + + out = self._tile(x) + code_rand = out[:self.nb_code] + + # Update centres + self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code + self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code + + usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float() + code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1) + + self.codebook = usage * code_update + (1 - usage) * code_rand + prob = code_count / torch.sum(code_count) + perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7))) + + + return perplexity + + def quantize(self, x): + # Calculate latent code x_l + k_w = self.codebook.t() + distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0, + keepdim=True) # (N * L, b) + _, code_idx = torch.min(distance, dim=-1) + return code_idx + + def dequantize(self, code_idx): + x = F.embedding(code_idx, self.codebook) + return x + + def forward(self, x): + N, C, H, W = x.shape + + # Preprocess + # x = self.preprocess(x) + x = rearrange(x, 'b c h w -> b h w c') + x = x.reshape(-1, self.code_dim) + + # Init codebook if not inited + if self.training and not self.init: + self.init_codebook(x) + + # quantize and dequantize through bottleneck + code_idx = self.quantize(x) + x_d = self.dequantize(code_idx) + + # Update embeddings + if self.training: + perplexity = self.update_codebook(x, code_idx) + else : + perplexity = self.compute_perplexity(code_idx) + + # Loss + commit_loss = F.mse_loss(x, x_d.detach()) + + # Passthrough + x_d = x + (x_d - x).detach() + + # Postprocess + x_d = x_d.view(N, H, W, C).permute(0, 3, 1, 2).contiguous() + + return x_d, commit_loss, (perplexity, code_idx, code_idx) \ No newline at end of file diff --git a/3DTopia/module/renderer.py b/3DTopia/module/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..bff0e703093f86bbe1693960db5574815aca6a56 --- /dev/null +++ b/3DTopia/module/renderer.py @@ -0,0 +1,463 @@ +import os +import math +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# TriPlane Utils +class MipRayMarcher2(nn.Module): + def __init__(self): + super().__init__() + + def run_forward(self, colors, densities, depths, rendering_options): + deltas = depths[:, :, 1:] - depths[:, :, :-1] + colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 + densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 + depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 + + + if rendering_options['clamp_mode'] == 'softplus': + densities_mid = F.softplus(densities_mid - 1) # activation bias of -1 makes things initialize better + else: + assert False, "MipRayMarcher only supports `clamp_mode`=`softplus`!" + + density_delta = densities_mid * deltas + + alpha = 1 - torch.exp(-density_delta) + + alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2) + weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] + + composite_rgb = torch.sum(weights * colors_mid, -2) + weight_total = weights.sum(2) + # composite_depth = torch.sum(weights * depths_mid, -2) / weight_total + composite_depth = torch.sum(weights * depths_mid, -2) + + # clip the composite to min/max range of depths + composite_depth = torch.nan_to_num(composite_depth, float('inf')) + # composite_depth = torch.nan_to_num(composite_depth, 0.) + composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths)) + + if rendering_options.get('white_back', False): + composite_rgb = composite_rgb + 1 - weight_total + + composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) + + return composite_rgb, composite_depth, weights + + def forward(self, colors, densities, depths, rendering_options): + composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options) + + return composite_rgb, composite_depth, weights + +def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: + """ + Left-multiplies MxM @ NxM. Returns NxM. + """ + res = torch.matmul(vectors4, matrix.T) + return res + +def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: + """ + Normalize vector lengths. + """ + return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) + +def torch_dot(x: torch.Tensor, y: torch.Tensor): + """ + Dot product of two tensors. + """ + return (x * y).sum(-1) + +def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): + """ + Author: Petr Kellnhofer + Intersects rays with the [-1, 1] NDC volume. + Returns min and max distance of entry. + Returns -1 for no intersection. + https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection + """ + o_shape = rays_o.shape + rays_o = rays_o.detach().reshape(-1, 3) + rays_d = rays_d.detach().reshape(-1, 3) + + + bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)] + bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)] + bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) + is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) + + # Precompute inverse for stability. + invdir = 1 / rays_d + sign = (invdir < 0).long() + + # Intersect with YZ plane. + tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + + # Intersect with XZ plane. + tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tymin) + tmax = torch.min(tmax, tymax) + + # Intersect with XY plane. + tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tzmin) + tmax = torch.min(tmax, tzmax) + + # Mark invalid. + tmin[torch.logical_not(is_valid)] = -1 + tmax[torch.logical_not(is_valid)] = -2 + + return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) + +def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): + """ + Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. + Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. + """ + # create a tensor of 'num' steps from 0 to 1 + steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) + + # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings + # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript + # "cannot statically infer the expected size of a list in this contex", hence the code below + for i in range(start.ndim): + steps = steps.unsqueeze(-1) + + # the output starts at 'start' and increments until 'stop' in each dimension + out = start[None] + steps * (stop - start)[None] + + return out + +def generate_planes(): + """ + Defines planes by the three vectors that form the "axes" of the + plane. Should work with arbitrary number of planes and planes of + arbitrary orientation. + """ + return torch.tensor([[[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], + [[1, 0, 0], + [0, 0, 1], + [0, 1, 0]], + [[0, 0, 1], + [1, 0, 0], + [0, 1, 0]]], dtype=torch.float32) + +def project_onto_planes(planes, coordinates): + """ + Does a projection of a 3D point onto a batch of 2D planes, + returning 2D plane coordinates. + Takes plane axes of shape n_planes, 3, 3 + # Takes coordinates of shape N, M, 3 + # returns projections of shape N*n_planes, M, 2 + """ + + # # ORIGINAL + # N, M, C = coordinates.shape + # xy_coords = coordinates[..., [0, 1]] + # xz_coords = coordinates[..., [0, 2]] + # zx_coords = coordinates[..., [2, 0]] + # return torch.stack([xy_coords, xz_coords, zx_coords], dim=1).reshape(N*3, M, 2) + + # FIXED + N, M, _ = coordinates.shape + xy_coords = coordinates[..., [0, 1]] + yz_coords = coordinates[..., [1, 2]] + zx_coords = coordinates[..., [2, 0]] + return torch.stack([xy_coords, yz_coords, zx_coords], dim=1).reshape(N*3, M, 2) + +def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): + assert padding_mode == 'zeros' + N, n_planes, C, H, W = plane_features.shape + _, M, _ = coordinates.shape + plane_features = plane_features.view(N*n_planes, C, H, W) + + coordinates = (2/box_warp) * coordinates # TODO: add specific box bounds + + projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) + + output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) + return output_features + +def sample_from_3dgrid(grid, coordinates): + """ + Expects coordinates in shape (batch_size, num_points_per_batch, 3) + Expects grid in shape (1, channels, H, W, D) + (Also works if grid has batch size) + Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels) + """ + batch_size, n_coords, n_dims = coordinates.shape + sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1), + coordinates.reshape(batch_size, 1, 1, -1, n_dims), + mode='bilinear', padding_mode='zeros', align_corners=False) + N, C, H, W, D = sampled_features.shape + sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C) + return sampled_features + +class FullyConnectedLayer(nn.Module): + def __init__(self, + in_features, # Number of input features. + out_features, # Number of output features. + bias = True, # Apply additive bias before the activation function? + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 1, # Learning rate multiplier. + bias_init = 0, # Initial value for the additive bias. + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.activation = activation + # self.weight = torch.nn.Parameter(torch.full([out_features, in_features], np.float32(0))) + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) + self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight.to(x.dtype) * self.weight_gain + b = self.bias + if b is not None: + b = b.to(x.dtype) + if self.bias_gain != 1: + b = b * self.bias_gain + + if self.activation == 'linear' and b is not None: + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self): + return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' + +class TriPlane_Decoder(nn.Module): + def __init__(self, dim=12, width=128): + super().__init__() + self.net = torch.nn.Sequential( + FullyConnectedLayer(dim, width), + torch.nn.Softplus(), + FullyConnectedLayer(width, width), + torch.nn.Softplus(), + FullyConnectedLayer(width, 1 + 3) + ) + + def forward(self, sampled_features): + sampled_features = sampled_features.mean(1) + x = sampled_features + + N, M, C = x.shape + x = x.view(N*M, C) + + x = self.net(x) + x = x.view(N, M, -1) + rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + sigma = x[..., 0:1] + return {'rgb': rgb, 'sigma': sigma} + +class Renderer_TriPlane(nn.Module): + def __init__(self, rgbnet_dim=18, rgbnet_width=128): + super(Renderer_TriPlane, self).__init__() + self.decoder = TriPlane_Decoder(dim=rgbnet_dim//3, width=rgbnet_width) + self.ray_marcher = MipRayMarcher2() + self.plane_axes = generate_planes() + + def forward(self, planes, ray_origins, ray_directions, rendering_options, whole_img=False): + self.plane_axes = self.plane_axes.to(ray_origins.device) + + ray_start, ray_end = get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp']) + is_ray_valid = ray_end > ray_start + if torch.any(is_ray_valid).item(): + ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() + ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() + depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + + batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape + + # Coarse Pass + sample_coordinates = (ray_origins.unsqueeze(-2) + depths_coarse * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) + sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3) + + + out = self.run_model(planes, self.decoder, sample_coordinates, sample_directions, rendering_options) + colors_coarse = out['rgb'] + densities_coarse = out['sigma'] + colors_coarse = colors_coarse.reshape(batch_size, num_rays, samples_per_ray, colors_coarse.shape[-1]) + densities_coarse = densities_coarse.reshape(batch_size, num_rays, samples_per_ray, 1) + + # Fine Pass + N_importance = rendering_options['depth_resolution_importance'] + if N_importance > 0: + _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) + + depths_fine = self.sample_importance(depths_coarse, weights, N_importance) + + sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, N_importance, -1).reshape(batch_size, -1, 3) + sample_coordinates = (ray_origins.unsqueeze(-2) + depths_fine * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) + + out = self.run_model(planes, self.decoder, sample_coordinates, sample_directions, rendering_options) + colors_fine = out['rgb'] + densities_fine = out['sigma'] + colors_fine = colors_fine.reshape(batch_size, num_rays, N_importance, colors_fine.shape[-1]) + densities_fine = densities_fine.reshape(batch_size, num_rays, N_importance, 1) + + all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse, + depths_fine, colors_fine, densities_fine) + + # Aggregate + rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options) + else: + rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) + + # return rgb_final, depth_final, weights.sum(2) + if whole_img: + H = W = int(ray_origins.shape[1] ** 0.5) + rgb_final = rgb_final.permute(0, 2, 1).reshape(-1, 3, H, W).contiguous() + depth_final = depth_final.permute(0, 2, 1).reshape(-1, 1, H, W).contiguous() + depth_final = (depth_final - depth_final.min()) / (depth_final.max() - depth_final.min()) + depth_final = depth_final.repeat(1, 3, 1, 1) + # rgb_final = torch.clip(rgb_final, min=0, max=1) + rgb_final = (rgb_final + 1) / 2. + weights = weights.sum(2).reshape(rgb_final.shape[0], rgb_final.shape[2], rgb_final.shape[3]) + return { + 'rgb_marched': rgb_final, + 'depth_final': depth_final, + 'weights': weights, + } + else: + rgb_final = (rgb_final + 1) / 2. + return { + 'rgb_marched': rgb_final, + 'depth_final': depth_final, + } + + def run_model(self, planes, decoder, sample_coordinates, sample_directions, options): + sampled_features = sample_from_planes(self.plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp']) + + out = decoder(sampled_features) + if options.get('density_noise', 0) > 0: + out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise'] + return out + + def sort_samples(self, all_depths, all_colors, all_densities): + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + return all_depths, all_colors, all_densities + + def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2): + all_depths = torch.cat([depths1, depths2], dim = -2) + all_colors = torch.cat([colors1, colors2], dim = -2) + all_densities = torch.cat([densities1, densities2], dim = -2) + + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + + return all_depths, all_colors, all_densities + + def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False): + """ + Return depths of approximately uniformly spaced samples along rays. + """ + N, M, _ = ray_origins.shape + if disparity_space_sampling: + depths_coarse = torch.linspace(0, + 1, + depth_resolution, + device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = 1/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse) + else: + if type(ray_start) == torch.Tensor: + depths_coarse = linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3) + depth_delta = (ray_end - ray_start) / (depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None] + else: + depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = (ray_end - ray_start)/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + + return depths_coarse + + def sample_importance(self, z_vals, weights, N_importance): + """ + Return depths of importance sampled points along rays. See NeRF importance sampling for more. + """ + with torch.no_grad(): + batch_size, num_rays, samples_per_ray, _ = z_vals.shape + + z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) + weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher + + # smooth weights + weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1) + weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze() + weights = weights + 0.01 + + z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) + importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], + N_importance).detach().reshape(batch_size, num_rays, N_importance, 1) + return importance_z_vals + + def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5): + """ + Sample @N_importance samples from @bins with distribution defined by @weights. + Inputs: + bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" + weights: (N_rays, N_samples_) + N_importance: the number of samples to draw from the distribution + det: deterministic or not + eps: a small number to prevent division by zero + Outputs: + samples: the sampled samples + """ + N_rays, N_samples_ = weights.shape + weights = weights + eps # prevent division by zero (don't do inplace op!) + pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) + cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function + cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1) + # padded to 0~1 inclusive + + if det: + u = torch.linspace(0, 1, N_importance, device=bins.device) + u = u.expand(N_rays, N_importance) + else: + u = torch.rand(N_rays, N_importance, device=bins.device) + u = u.contiguous() + + inds = torch.searchsorted(cdf, u, right=True) + below = torch.clamp_min(inds-1, 0) + above = torch.clamp_max(inds, N_samples_) + + inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance) + cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2) + bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) + + denom = cdf_g[...,1]-cdf_g[...,0] + denom[denomnchpwq', x) + x = x.reshape(b, 3, self.out_channel//3, self.out_reso, self.out_reso).contiguous() + return x + + +class SingleImageToTriplaneVAE(nn.Module): + def __init__(self, backbone='dino_vits8', input_reso=256, out_reso=128, out_channel=18, z_dim=32, + decoder_depth=16, decoder_heads=16, decoder_mlp_dim=1024, decoder_dim_head=64, dropout=0): + super().__init__() + self.backbone = backbone + + self.input_image_size = input_reso + self.out_reso = out_reso + self.out_channel = out_channel + self.z_dim = z_dim + + self.decoder_depth = decoder_depth + self.decoder_heads = decoder_heads + self.decoder_mlp_dim = decoder_mlp_dim + self.decoder_dim_head = decoder_dim_head + + self.dropout = dropout + self.patch_size = 8 if '8' in backbone else 16 + + if 'dino' in backbone: + self.vit = torch.hub.load('facebookresearch/dino:main', backbone) + self.embed_dim = self.vit.embed_dim + self.preprocess = None + else: + raise NotImplementedError + + self.fc_mu = nn.Linear(self.embed_dim, self.z_dim) + self.fc_var = nn.Linear(self.embed_dim, self.z_dim) + + self.vit_decoder = TriplaneDecoder((self.input_image_size // self.patch_size) ** 2, self.z_dim, + depth=self.decoder_depth, heads=self.decoder_heads, mlp_dim=self.decoder_mlp_dim, + out_channel=self.out_channel, out_reso=self.out_reso, dim_head = self.decoder_dim_head, dropout=0) + + def forward(self, x, is_train): + assert x.shape[-1] == self.input_image_size + bs = x.shape[0] + if 'dino' in self.backbone: + z = self.vit.get_intermediate_layers(x, n=1)[0][:, 1:] # [bs, 1024, self.embed_dim] + else: + raise NotImplementedError + + z = z.reshape(-1, z.shape[-1]) + mu = self.fc_mu(z) + logvar = self.fc_var(z) + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + if is_train: + rep_z = eps * std + mu + else: + rep_z = eps + rep_z = rep_z.reshape(bs, -1, self.z_dim) + out = self.vit_decoder(rep_z) + + return out, mu, logvar diff --git a/3DTopia/sample_stage1.py b/3DTopia/sample_stage1.py new file mode 100644 index 0000000000000000000000000000000000000000..60a1192df3f429604c28721b94fe9b9eb28e07c5 --- /dev/null +++ b/3DTopia/sample_stage1.py @@ -0,0 +1,299 @@ +import os +import cv2 +import json +import torch +import mcubes +import trimesh +import argparse +import numpy as np +from tqdm import tqdm +import imageio.v2 as imageio +import pytorch_lightning as pl +from omegaconf import OmegaConf + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.dpm_solver import DPMSolverSampler + +from utility.initialize import instantiate_from_config, get_obj_from_str +from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes +from utility.triplane_renderer.renderer import get_rays, to8b +from safetensors.torch import load_file +from huggingface_hub import hf_hub_download + +import warnings +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=DeprecationWarning) + +def add_text(rgb, caption): + font = cv2.FONT_HERSHEY_SIMPLEX + # org + gap = 30 + org = (gap, gap) + # fontScale + fontScale = 0.6 + # Blue color in BGR + color = (255, 0, 0) + # Line thickness of 2 px + thickness = 1 + break_caption = [] + for i in range(len(caption) // 30 + 1): + break_caption_i = caption[i*30:(i+1)*30] + break_caption.append(break_caption_i) + for i, bci in enumerate(break_caption): + cv2.putText(rgb, bci, (gap, gap*(i+1)), font, fontScale, color, thickness, cv2.LINE_AA) + return rgb + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default='configs/default.yaml') + parser.add_argument("--ckpt", type=str, default=None) + parser.add_argument("--test_folder", type=str, default="stage1") + parser.add_argument("--seed", type=int, default=None) + parser.add_argument("--sampler", type=str, default="ddpm") + parser.add_argument("--samples", type=int, default=1) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--steps", type=int, default=1000) + parser.add_argument("--text", nargs='+', default='a robot') + parser.add_argument("--text_file", type=str, default=None) + parser.add_argument("--no_video", action='store_true', default=False) + parser.add_argument("--render_res", type=int, default=128) + parser.add_argument("--no_mcubes", action='store_true', default=False) + parser.add_argument("--mcubes_res", type=int, default=128) + parser.add_argument("--cfg_scale", type=float, default=1) + args = parser.parse_args() + + if args.text is not None: + text = [' '.join(args.text),] + elif args.text_file is not None: + if args.text_file.endswith('.json'): + with open(args.text_file, 'r') as f: + json_file = json.load(f) + text = json_file + text = [l.strip('.') for l in text] + else: + with open(args.text_file, 'r') as f: + text = f.readlines() + text = [l.strip() for l in text] + else: + raise NotImplementedError + + print(text) + + configs = OmegaConf.load(args.config) + if args.seed is not None: + pl.seed_everything(args.seed) + + log_dir = os.path.join('results', args.config.split('/')[-1].split('.')[0], args.test_folder) + os.makedirs(log_dir, exist_ok=True) + + if args.ckpt == None: + ckpt = hf_hub_download(repo_id="hongfz16/3DTopia", filename="model.safetensors") + else: + ckpt = args.ckpt + + if ckpt.endswith(".ckpt"): + model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(ckpt, map_location='cpu', strict=False, **configs.model.params) + elif ckpt.endswith(".safetensors"): + model = get_obj_from_str(configs.model["target"])(**configs.model.params) + model_ckpt = load_file(ckpt) + model.load_state_dict(model_ckpt) + else: + raise NotImplementedError + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + class DummySampler: + def __init__(self, model): + self.model = model + + def sample(self, S, batch_size, shape, verbose, conditioning=None, *args, **kwargs): + return self.model.sample( + conditioning, batch_size, shape=[batch_size, ] + shape, *args, **kwargs + ), None + + if args.sampler == 'dpm': + raise NotImplementedError + # sampler = DPMSolverSampler(model) + elif args.sampler == 'plms': + raise NotImplementedError + # sampler = PLMSSampler(model) + elif args.sampler == 'ddim': + sampler = DDIMSampler(model) + elif args.sampler == 'ddpm': + sampler = DummySampler(model) + else: + raise NotImplementedError + + img_size = configs.model.params.unet_config.params.image_size + channels = configs.model.params.unet_config.params.in_channels + shape = [channels, img_size, img_size * 3] + plane_axes = generate_planes() + + pose_folder = 'assets/sample_data/pose' + poses_fname = sorted([os.path.join(pose_folder, f) for f in os.listdir(pose_folder)]) + batch_rays_list = [] + H = args.render_res + ratio = 512 // H + for p in poses_fname: + c2w = np.loadtxt(p).reshape(4, 4) + c2w[:3, 3] *= 2.2 + c2w = np.array([ + [1, 0, 0, 0], + [0, 0, -1, 0], + [0, 1, 0, 0], + [0, 0, 0, 1] + ]) @ c2w + + k = np.array([ + [560 / ratio, 0, H * 0.5], + [0, 560 / ratio, H * 0.5], + [0, 0, 1] + ]) + + rays_o, rays_d = get_rays(H, H, torch.Tensor(k), torch.Tensor(c2w[:3, :4])) + coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, H-1, H), indexing='ij'), -1) + coords = torch.reshape(coords, [-1,2]).long() + rays_o = rays_o[coords[:, 0], coords[:, 1]] + rays_d = rays_d[coords[:, 0], coords[:, 1]] + batch_rays = torch.stack([rays_o, rays_d], 0) + batch_rays_list.append(batch_rays) + batch_rays_list = torch.stack(batch_rays_list, 0) + + for text_idx, text_i in enumerate(text): + text_connect = '_'.join(text_i.split(' ')) + for s in range(args.samples): + batch_size = args.batch_size + with torch.no_grad(): + # with model.ema_scope(): + noise = None + c = model.get_learned_conditioning([text_i]) + unconditional_c = torch.zeros_like(c) + if args.cfg_scale != 1: + assert args.sampler == 'ddim' + sample, _ = sampler.sample( + S=args.steps, + batch_size=batch_size, + shape=shape, + verbose=False, + x_T = noise, + conditioning = c.repeat(batch_size, 1, 1), + unconditional_guidance_scale=args.cfg_scale, + unconditional_conditioning=unconditional_c.repeat(batch_size, 1, 1) + ) + else: + sample, _ = sampler.sample( + S=args.steps, + batch_size=batch_size, + shape=shape, + verbose=False, + x_T = noise, + conditioning = c.repeat(batch_size, 1, 1), + ) + decode_res = model.decode_first_stage(sample) + + for b in range(batch_size): + def render_img(v): + rgb_sample, _ = model.first_stage_model.render_triplane_eg3d_decoder( + decode_res[b:b+1], batch_rays_list[v:v+1].to(device), torch.zeros(1, H, H, 3).to(device), + ) + rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0] + rgb_sample = np.stack( + [rgb_sample[..., 2], rgb_sample[..., 1], rgb_sample[..., 0]], -1 + ) + # rgb_sample = add_text(rgb_sample, text_i) + return rgb_sample + + if not args.no_mcubes: + # prepare volumn for marching cube + res = args.mcubes_res + c_list = torch.linspace(-1.2, 1.2, steps=res) + grid_x, grid_y, grid_z = torch.meshgrid( + c_list, c_list, c_list, indexing='ij' + ) + coords = torch.stack([grid_x, grid_y, grid_z], -1).to(device) + plane_axes = generate_planes() + feats = sample_from_planes( + plane_axes, decode_res[b:b+1].reshape(1, 3, -1, 256, 256), coords.reshape(1, -1, 3), padding_mode='zeros', box_warp=2.4 + ) + fake_dirs = torch.zeros_like(coords) + fake_dirs[..., 0] = 1 + out = model.first_stage_model.triplane_decoder.decoder(feats, fake_dirs) + u = out['sigma'].reshape(res, res, res).detach().cpu().numpy() + del out + + # marching cube + vertices, triangles = mcubes.marching_cubes(u, 10) + min_bound = np.array([-1.2, -1.2, -1.2]) + max_bound = np.array([1.2, 1.2, 1.2]) + vertices = vertices / (res - 1) * (max_bound - min_bound)[None, :] + min_bound[None, :] + pt_vertices = torch.from_numpy(vertices).to(device) + + # extract vertices color + res_triplane = 256 + render_kwargs = { + 'depth_resolution': 128, + 'disparity_space_sampling': False, + 'box_warp': 2.4, + 'depth_resolution_importance': 128, + 'clamp_mode': 'softplus', + 'white_back': True, + 'det': True + } + rays_o_list = [ + np.array([0, 0, 2]), + np.array([0, 0, -2]), + np.array([0, 2, 0]), + np.array([0, -2, 0]), + np.array([2, 0, 0]), + np.array([-2, 0, 0]), + ] + rgb_final = None + diff_final = None + for rays_o in tqdm(rays_o_list): + rays_o = torch.from_numpy(rays_o.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device) + rays_d = pt_vertices.reshape(-1, 3) - rays_o + rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1) + dist = torch.norm(pt_vertices.reshape(-1, 3) - rays_o, dim=-1).cpu().numpy().reshape(-1) + + render_out = model.first_stage_model.triplane_decoder( + decode_res[b:b+1].reshape(1, 3, -1, res_triplane, res_triplane), + rays_o.unsqueeze(0), rays_d.unsqueeze(0), render_kwargs, + whole_img=False, tvloss=False + ) + rgb = render_out['rgb_marched'].reshape(-1, 3).detach().cpu().numpy() + depth = render_out['depth_final'].reshape(-1).detach().cpu().numpy() + depth_diff = np.abs(dist - depth) + + if rgb_final is None: + rgb_final = rgb.copy() + diff_final = depth_diff.copy() + + else: + ind = diff_final > depth_diff + rgb_final[ind] = rgb[ind] + diff_final[ind] = depth_diff[ind] + + + # bgr to rgb + rgb_final = np.stack([ + rgb_final[:, 2], rgb_final[:, 1], rgb_final[:, 0] + ], -1) + + # export to ply + mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=(rgb_final * 255).astype(np.uint8)) + trimesh.exchange.export.export_mesh(mesh, os.path.join(log_dir, f"{text_connect}_{s}_{b}.ply"), file_type='ply') + + if not args.no_video: + view_num = len(batch_rays_list) + video_list = [] + for v in tqdm(range(view_num//4, view_num//4 * 3, 2)): + rgb_sample = render_img(v) + video_list.append(rgb_sample) + imageio.mimwrite(os.path.join(log_dir, "{}_{}_{}.mp4".format(text_connect, s, b)), np.stack(video_list, 0)) + else: + rgb_sample = render_img(104) + imageio.imwrite(os.path.join(log_dir, "{}_{}_{}.jpg".format(text_connect, s, b)), rgb_sample) + +if __name__ == '__main__': + main() diff --git a/3DTopia/taming/data/ade20k.py b/3DTopia/taming/data/ade20k.py new file mode 100644 index 0000000000000000000000000000000000000000..366dae97207dbb8356598d636e14ad084d45bc76 --- /dev/null +++ b/3DTopia/taming/data/ade20k.py @@ -0,0 +1,124 @@ +import os +import numpy as np +import cv2 +import albumentations +from PIL import Image +from torch.utils.data import Dataset + +from taming.data.sflckr import SegmentationBase # for examples included in repo + + +class Examples(SegmentationBase): + def __init__(self, size=256, random_crop=False, interpolation="bicubic"): + super().__init__(data_csv="data/ade20k_examples.txt", + data_root="data/ade20k_images", + segmentation_root="data/ade20k_segmentations", + size=size, random_crop=random_crop, + interpolation=interpolation, + n_labels=151, shift_segmentation=False) + + +# With semantic map and scene label +class ADE20kBase(Dataset): + def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None): + self.split = self.get_split() + self.n_labels = 151 # unknown + 150 + self.data_csv = {"train": "data/ade20k_train.txt", + "validation": "data/ade20k_test.txt"}[self.split] + self.data_root = "data/ade20k_root" + with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f: + self.scene_categories = f.read().splitlines() + self.scene_categories = dict(line.split() for line in self.scene_categories) + with open(self.data_csv, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, "images", l) + for l in self.image_paths], + "relative_segmentation_path_": [l.replace(".jpg", ".png") + for l in self.image_paths], + "segmentation_path_": [os.path.join(self.data_root, "annotations", + l.replace(".jpg", ".png")) + for l in self.image_paths], + "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")] + for l in self.image_paths], + } + + size = None if size is not None and size<=0 else size + self.size = size + if crop_size is None: + self.crop_size = size if size is not None else None + else: + self.crop_size = crop_size + if self.size is not None: + self.interpolation = interpolation + self.interpolation = { + "nearest": cv2.INTER_NEAREST, + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] + self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=self.interpolation) + self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=cv2.INTER_NEAREST) + + if crop_size is not None: + self.center_crop = not random_crop + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) + else: + self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) + self.preprocessor = self.cropper + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + if self.size is not None: + image = self.image_rescaler(image=image)["image"] + segmentation = Image.open(example["segmentation_path_"]) + segmentation = np.array(segmentation).astype(np.uint8) + if self.size is not None: + segmentation = self.segmentation_rescaler(image=segmentation)["image"] + if self.size is not None: + processed = self.preprocessor(image=image, mask=segmentation) + else: + processed = {"image": image, "mask": segmentation} + example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) + segmentation = processed["mask"] + onehot = np.eye(self.n_labels)[segmentation] + example["segmentation"] = onehot + return example + + +class ADE20kTrain(ADE20kBase): + # default to random_crop=True + def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None): + super().__init__(config=config, size=size, random_crop=random_crop, + interpolation=interpolation, crop_size=crop_size) + + def get_split(self): + return "train" + + +class ADE20kValidation(ADE20kBase): + def get_split(self): + return "validation" + + +if __name__ == "__main__": + dset = ADE20kValidation() + ex = dset[0] + for k in ["image", "scene_category", "segmentation"]: + print(type(ex[k])) + try: + print(ex[k].shape) + except: + print(ex[k]) diff --git a/3DTopia/taming/data/annotated_objects_coco.py b/3DTopia/taming/data/annotated_objects_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..af000ecd943d7b8a85d7eb70195c9ecd10ab5edc --- /dev/null +++ b/3DTopia/taming/data/annotated_objects_coco.py @@ -0,0 +1,139 @@ +import json +from itertools import chain +from pathlib import Path +from typing import Iterable, Dict, List, Callable, Any +from collections import defaultdict + +from tqdm import tqdm + +from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset +from taming.data.helper_types import Annotation, ImageDescription, Category + +COCO_PATH_STRUCTURE = { + 'train': { + 'top_level': '', + 'instances_annotations': 'annotations/instances_train2017.json', + 'stuff_annotations': 'annotations/stuff_train2017.json', + 'files': 'train2017' + }, + 'validation': { + 'top_level': '', + 'instances_annotations': 'annotations/instances_val2017.json', + 'stuff_annotations': 'annotations/stuff_val2017.json', + 'files': 'val2017' + } +} + + +def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]: + return { + str(img['id']): ImageDescription( + id=img['id'], + license=img.get('license'), + file_name=img['file_name'], + coco_url=img['coco_url'], + original_size=(img['width'], img['height']), + date_captured=img.get('date_captured'), + flickr_url=img.get('flickr_url') + ) + for img in description_json + } + + +def load_categories(category_json: Iterable) -> Dict[str, Category]: + return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name']) + for cat in category_json if cat['name'] != 'other'} + + +def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription], + category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]: + annotations = defaultdict(list) + total = sum(len(a) for a in annotations_json) + for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total): + image_id = str(ann['image_id']) + if image_id not in image_descriptions: + raise ValueError(f'image_id [{image_id}] has no image description.') + category_id = ann['category_id'] + try: + category_no = category_no_for_id(str(category_id)) + except KeyError: + continue + + width, height = image_descriptions[image_id].original_size + bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height) + + annotations[image_id].append( + Annotation( + id=ann['id'], + area=bbox[2]*bbox[3], # use bbox area + is_group_of=ann['iscrowd'], + image_id=ann['image_id'], + bbox=bbox, + category_id=str(category_id), + category_no=category_no + ) + ) + return dict(annotations) + + +class AnnotatedObjectsCoco(AnnotatedObjectsDataset): + def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs): + """ + @param data_path: is the path to the following folder structure: + coco/ + ├── annotations + │ ├── instances_train2017.json + │ ├── instances_val2017.json + │ ├── stuff_train2017.json + │ └── stuff_val2017.json + ├── train2017 + │ ├── 000000000009.jpg + │ ├── 000000000025.jpg + │ └── ... + ├── val2017 + │ ├── 000000000139.jpg + │ ├── 000000000285.jpg + │ └── ... + @param: split: one of 'train' or 'validation' + @param: desired image size (give square images) + """ + super().__init__(**kwargs) + self.use_things = use_things + self.use_stuff = use_stuff + + with open(self.paths['instances_annotations']) as f: + inst_data_json = json.load(f) + with open(self.paths['stuff_annotations']) as f: + stuff_data_json = json.load(f) + + category_jsons = [] + annotation_jsons = [] + if self.use_things: + category_jsons.append(inst_data_json['categories']) + annotation_jsons.append(inst_data_json['annotations']) + if self.use_stuff: + category_jsons.append(stuff_data_json['categories']) + annotation_jsons.append(stuff_data_json['annotations']) + + self.categories = load_categories(chain(*category_jsons)) + self.filter_categories() + self.setup_category_id_and_number() + + self.image_descriptions = load_image_descriptions(inst_data_json['images']) + annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split) + self.annotations = self.filter_object_number(annotations, self.min_object_area, + self.min_objects_per_image, self.max_objects_per_image) + self.image_ids = list(self.annotations.keys()) + self.clean_up_annotations_and_image_descriptions() + + def get_path_structure(self) -> Dict[str, str]: + if self.split not in COCO_PATH_STRUCTURE: + raise ValueError(f'Split [{self.split} does not exist for COCO data.]') + return COCO_PATH_STRUCTURE[self.split] + + def get_image_path(self, image_id: str) -> Path: + return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name) + + def get_image_description(self, image_id: str) -> Dict[str, Any]: + # noinspection PyProtectedMember + return self.image_descriptions[image_id]._asdict() diff --git a/3DTopia/taming/data/annotated_objects_dataset.py b/3DTopia/taming/data/annotated_objects_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..53cc346a1c76289a4964d7dc8a29582172f33dc0 --- /dev/null +++ b/3DTopia/taming/data/annotated_objects_dataset.py @@ -0,0 +1,218 @@ +from pathlib import Path +from typing import Optional, List, Callable, Dict, Any, Union +import warnings + +import PIL.Image as pil_image +from torch import Tensor +from torch.utils.data import Dataset +from torchvision import transforms + +from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder +from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder +from taming.data.conditional_builder.utils import load_object_from_string +from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType +from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \ + Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor + + +class AnnotatedObjectsDataset(Dataset): + def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int, + min_object_area: float, min_objects_per_image: int, max_objects_per_image: int, + crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool, + encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "", + no_object_classes: Optional[int] = None): + self.data_path = data_path + self.split = split + self.keys = keys + self.target_image_size = target_image_size + self.min_object_area = min_object_area + self.min_objects_per_image = min_objects_per_image + self.max_objects_per_image = max_objects_per_image + self.crop_method = crop_method + self.random_flip = random_flip + self.no_tokens = no_tokens + self.use_group_parameter = use_group_parameter + self.encode_crop = encode_crop + + self.annotations = None + self.image_descriptions = None + self.categories = None + self.category_ids = None + self.category_number = None + self.image_ids = None + self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip) + self.paths = self.build_paths(self.data_path) + self._conditional_builders = None + self.category_allow_list = None + if category_allow_list_target: + allow_list = load_object_from_string(category_allow_list_target) + self.category_allow_list = {name for name, _ in allow_list} + self.category_mapping = {} + if category_mapping_target: + self.category_mapping = load_object_from_string(category_mapping_target) + self.no_object_classes = no_object_classes + + def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]: + top_level = Path(top_level) + sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()} + for path in sub_paths.values(): + if not path.exists(): + raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.') + return sub_paths + + @staticmethod + def load_image_from_disk(path: Path) -> Image: + return pil_image.open(path).convert('RGB') + + @staticmethod + def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool): + transform_functions = [] + if crop_method == 'none': + transform_functions.append(transforms.Resize((target_image_size, target_image_size))) + elif crop_method == 'center': + transform_functions.extend([ + transforms.Resize(target_image_size), + CenterCropReturnCoordinates(target_image_size) + ]) + elif crop_method == 'random-1d': + transform_functions.extend([ + transforms.Resize(target_image_size), + RandomCrop1dReturnCoordinates(target_image_size) + ]) + elif crop_method == 'random-2d': + transform_functions.extend([ + Random2dCropReturnCoordinates(target_image_size), + transforms.Resize(target_image_size) + ]) + elif crop_method is None: + return None + else: + raise ValueError(f'Received invalid crop method [{crop_method}].') + if random_flip: + transform_functions.append(RandomHorizontalFlipReturn()) + transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.)) + return transform_functions + + def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor): + crop_bbox = None + flipped = None + for t in self.transform_functions: + if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)): + crop_bbox, x = t(x) + elif isinstance(t, RandomHorizontalFlipReturn): + flipped, x = t(x) + else: + x = t(x) + return crop_bbox, flipped, x + + @property + def no_classes(self) -> int: + return self.no_object_classes if self.no_object_classes else len(self.categories) + + @property + def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder: + # cannot set this up in init because no_classes is only known after loading data in init of superclass + if self._conditional_builders is None: + self._conditional_builders = { + 'objects_center_points': ObjectsCenterPointsConditionalBuilder( + self.no_classes, + self.max_objects_per_image, + self.no_tokens, + self.encode_crop, + self.use_group_parameter, + getattr(self, 'use_additional_parameters', False) + ), + 'objects_bbox': ObjectsBoundingBoxConditionalBuilder( + self.no_classes, + self.max_objects_per_image, + self.no_tokens, + self.encode_crop, + self.use_group_parameter, + getattr(self, 'use_additional_parameters', False) + ) + } + return self._conditional_builders + + def filter_categories(self) -> None: + if self.category_allow_list: + self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list} + if self.category_mapping: + self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping} + + def setup_category_id_and_number(self) -> None: + self.category_ids = list(self.categories.keys()) + self.category_ids.sort() + if '/m/01s55n' in self.category_ids: + self.category_ids.remove('/m/01s55n') + self.category_ids.append('/m/01s55n') + self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)} + if self.category_allow_list is not None and self.category_mapping is None \ + and len(self.category_ids) != len(self.category_allow_list): + warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. ' + 'Make sure all names in category_allow_list exist.') + + def clean_up_annotations_and_image_descriptions(self) -> None: + image_id_set = set(self.image_ids) + self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set} + self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set} + + @staticmethod + def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float, + min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]: + filtered = {} + for image_id, annotations in all_annotations.items(): + annotations_with_min_area = [a for a in annotations if a.area > min_object_area] + if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image: + filtered[image_id] = annotations_with_min_area + return filtered + + def __len__(self): + return len(self.image_ids) + + def __getitem__(self, n: int) -> Dict[str, Any]: + image_id = self.get_image_id(n) + sample = self.get_image_description(image_id) + sample['annotations'] = self.get_annotation(image_id) + + if 'image' in self.keys: + sample['image_path'] = str(self.get_image_path(image_id)) + sample['image'] = self.load_image_from_disk(sample['image_path']) + sample['image'] = convert_pil_to_tensor(sample['image']) + sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image']) + sample['image'] = sample['image'].permute(1, 2, 0) + + for conditional, builder in self.conditional_builders.items(): + if conditional in self.keys: + sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped']) + + if self.keys: + # only return specified keys + sample = {key: sample[key] for key in self.keys} + return sample + + def get_image_id(self, no: int) -> str: + return self.image_ids[no] + + def get_annotation(self, image_id: str) -> str: + return self.annotations[image_id] + + def get_textual_label_for_category_id(self, category_id: str) -> str: + return self.categories[category_id].name + + def get_textual_label_for_category_no(self, category_no: int) -> str: + return self.categories[self.get_category_id(category_no)].name + + def get_category_number(self, category_id: str) -> int: + return self.category_number[category_id] + + def get_category_id(self, category_no: int) -> str: + return self.category_ids[category_no] + + def get_image_description(self, image_id: str) -> Dict[str, Any]: + raise NotImplementedError() + + def get_path_structure(self): + raise NotImplementedError + + def get_image_path(self, image_id: str) -> Path: + raise NotImplementedError diff --git a/3DTopia/taming/data/annotated_objects_open_images.py b/3DTopia/taming/data/annotated_objects_open_images.py new file mode 100644 index 0000000000000000000000000000000000000000..aede6803d2cef7a74ca784e7907d35fba6c71239 --- /dev/null +++ b/3DTopia/taming/data/annotated_objects_open_images.py @@ -0,0 +1,137 @@ +from collections import defaultdict +from csv import DictReader, reader as TupleReader +from pathlib import Path +from typing import Dict, List, Any +import warnings + +from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset +from taming.data.helper_types import Annotation, Category +from tqdm import tqdm + +OPEN_IMAGES_STRUCTURE = { + 'train': { + 'top_level': '', + 'class_descriptions': 'class-descriptions-boxable.csv', + 'annotations': 'oidv6-train-annotations-bbox.csv', + 'file_list': 'train-images-boxable.csv', + 'files': 'train' + }, + 'validation': { + 'top_level': '', + 'class_descriptions': 'class-descriptions-boxable.csv', + 'annotations': 'validation-annotations-bbox.csv', + 'file_list': 'validation-images.csv', + 'files': 'validation' + }, + 'test': { + 'top_level': '', + 'class_descriptions': 'class-descriptions-boxable.csv', + 'annotations': 'test-annotations-bbox.csv', + 'file_list': 'test-images.csv', + 'files': 'test' + } +} + + +def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str], + category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]: + annotations: Dict[str, List[Annotation]] = defaultdict(list) + with open(descriptor_path) as file: + reader = DictReader(file) + for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'): + width = float(row['XMax']) - float(row['XMin']) + height = float(row['YMax']) - float(row['YMin']) + area = width * height + category_id = row['LabelName'] + if category_id in category_mapping: + category_id = category_mapping[category_id] + if area >= min_object_area and category_id in category_no_for_id: + annotations[row['ImageID']].append( + Annotation( + id=i, + image_id=row['ImageID'], + source=row['Source'], + category_id=category_id, + category_no=category_no_for_id[category_id], + confidence=float(row['Confidence']), + bbox=(float(row['XMin']), float(row['YMin']), width, height), + area=area, + is_occluded=bool(int(row['IsOccluded'])), + is_truncated=bool(int(row['IsTruncated'])), + is_group_of=bool(int(row['IsGroupOf'])), + is_depiction=bool(int(row['IsDepiction'])), + is_inside=bool(int(row['IsInside'])) + ) + ) + if 'train' in str(descriptor_path) and i < 14000000: + warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].') + return dict(annotations) + + +def load_image_ids(csv_path: Path) -> List[str]: + with open(csv_path) as file: + reader = DictReader(file) + return [row['image_name'] for row in reader] + + +def load_categories(csv_path: Path) -> Dict[str, Category]: + with open(csv_path) as file: + reader = TupleReader(file) + return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader} + + +class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset): + def __init__(self, use_additional_parameters: bool, **kwargs): + """ + @param data_path: is the path to the following folder structure: + open_images/ + │ oidv6-train-annotations-bbox.csv + ├── class-descriptions-boxable.csv + ├── oidv6-train-annotations-bbox.csv + ├── test + │ ├── 000026e7ee790996.jpg + │ ├── 000062a39995e348.jpg + │ └── ... + ├── test-annotations-bbox.csv + ├── test-images.csv + ├── train + │ ├── 000002b66c9c498e.jpg + │ ├── 000002b97e5471a0.jpg + │ └── ... + ├── train-images-boxable.csv + ├── validation + │ ├── 0001eeaf4aed83f9.jpg + │ ├── 0004886b7d043cfd.jpg + │ └── ... + ├── validation-annotations-bbox.csv + └── validation-images.csv + @param: split: one of 'train', 'validation' or 'test' + @param: desired image size (returns square images) + """ + + super().__init__(**kwargs) + self.use_additional_parameters = use_additional_parameters + + self.categories = load_categories(self.paths['class_descriptions']) + self.filter_categories() + self.setup_category_id_and_number() + + self.image_descriptions = {} + annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping, + self.category_number) + self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image, + self.max_objects_per_image) + self.image_ids = list(self.annotations.keys()) + self.clean_up_annotations_and_image_descriptions() + + def get_path_structure(self) -> Dict[str, str]: + if self.split not in OPEN_IMAGES_STRUCTURE: + raise ValueError(f'Split [{self.split} does not exist for Open Images data.]') + return OPEN_IMAGES_STRUCTURE[self.split] + + def get_image_path(self, image_id: str) -> Path: + return self.paths['files'].joinpath(f'{image_id:0>16}.jpg') + + def get_image_description(self, image_id: str) -> Dict[str, Any]: + image_path = self.get_image_path(image_id) + return {'file_path': str(image_path), 'file_name': image_path.name} diff --git a/3DTopia/taming/data/base.py b/3DTopia/taming/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e21667df4ce4baa6bb6aad9f8679bd756e2ffdb7 --- /dev/null +++ b/3DTopia/taming/data/base.py @@ -0,0 +1,70 @@ +import bisect +import numpy as np +import albumentations +from PIL import Image +from torch.utils.data import Dataset, ConcatDataset + + +class ConcatDatasetWithIndex(ConcatDataset): + """Modified from original pytorch code to return dataset idx""" + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError("absolute value of index should not exceed dataset length") + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx], dataset_idx + + +class ImagePaths(Dataset): + def __init__(self, paths, size=None, random_crop=False, labels=None): + self.size = size + self.random_crop = random_crop + + self.labels = dict() if labels is None else labels + self.labels["file_path_"] = paths + self._length = len(paths) + + if self.size is not None and self.size > 0: + self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) + if not self.random_crop: + self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) + else: + self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) + self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) + else: + self.preprocessor = lambda **kwargs: kwargs + + def __len__(self): + return self._length + + def preprocess_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + image = self.preprocessor(image=image)["image"] + image = (image/127.5 - 1.0).astype(np.float32) + return image + + def __getitem__(self, i): + example = dict() + example["image"] = self.preprocess_image(self.labels["file_path_"][i]) + for k in self.labels: + example[k] = self.labels[k][i] + return example + + +class NumpyPaths(ImagePaths): + def preprocess_image(self, image_path): + image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 + image = np.transpose(image, (1,2,0)) + image = Image.fromarray(image, mode="RGB") + image = np.array(image).astype(np.uint8) + image = self.preprocessor(image=image)["image"] + image = (image/127.5 - 1.0).astype(np.float32) + return image diff --git a/3DTopia/taming/data/coco.py b/3DTopia/taming/data/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2f7838448cb63dcf96daffe9470d58566d975a --- /dev/null +++ b/3DTopia/taming/data/coco.py @@ -0,0 +1,176 @@ +import os +import json +import albumentations +import numpy as np +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset + +from taming.data.sflckr import SegmentationBase # for examples included in repo + + +class Examples(SegmentationBase): + def __init__(self, size=256, random_crop=False, interpolation="bicubic"): + super().__init__(data_csv="data/coco_examples.txt", + data_root="data/coco_images", + segmentation_root="data/coco_segmentations", + size=size, random_crop=random_crop, + interpolation=interpolation, + n_labels=183, shift_segmentation=True) + + +class CocoBase(Dataset): + """needed for (image, caption, segmentation) pairs""" + def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False, + crop_size=None, force_no_crop=False, given_files=None): + self.split = self.get_split() + self.size = size + if crop_size is None: + self.crop_size = size + else: + self.crop_size = crop_size + + self.onehot = onehot_segmentation # return segmentation as rgb or one hot + self.stuffthing = use_stuffthing # include thing in segmentation + if self.onehot and not self.stuffthing: + raise NotImplemented("One hot mode is only supported for the " + "stuffthings version because labels are stored " + "a bit different.") + + data_json = datajson + with open(data_json) as json_file: + self.json_data = json.load(json_file) + self.img_id_to_captions = dict() + self.img_id_to_filepath = dict() + self.img_id_to_segmentation_filepath = dict() + + assert data_json.split("/")[-1] in ["captions_train2017.json", + "captions_val2017.json"] + if self.stuffthing: + self.segmentation_prefix = ( + "data/cocostuffthings/val2017" if + data_json.endswith("captions_val2017.json") else + "data/cocostuffthings/train2017") + else: + self.segmentation_prefix = ( + "data/coco/annotations/stuff_val2017_pixelmaps" if + data_json.endswith("captions_val2017.json") else + "data/coco/annotations/stuff_train2017_pixelmaps") + + imagedirs = self.json_data["images"] + self.labels = {"image_ids": list()} + for imgdir in tqdm(imagedirs, desc="ImgToPath"): + self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"]) + self.img_id_to_captions[imgdir["id"]] = list() + pngfilename = imgdir["file_name"].replace("jpg", "png") + self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join( + self.segmentation_prefix, pngfilename) + if given_files is not None: + if pngfilename in given_files: + self.labels["image_ids"].append(imgdir["id"]) + else: + self.labels["image_ids"].append(imgdir["id"]) + + capdirs = self.json_data["annotations"] + for capdir in tqdm(capdirs, desc="ImgToCaptions"): + # there are in average 5 captions per image + self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) + + self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) + if self.split=="validation": + self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) + else: + self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) + self.preprocessor = albumentations.Compose( + [self.rescaler, self.cropper], + additional_targets={"segmentation": "image"}) + if force_no_crop: + self.rescaler = albumentations.Resize(height=self.size, width=self.size) + self.preprocessor = albumentations.Compose( + [self.rescaler], + additional_targets={"segmentation": "image"}) + + def __len__(self): + return len(self.labels["image_ids"]) + + def preprocess_image(self, image_path, segmentation_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + + segmentation = Image.open(segmentation_path) + if not self.onehot and not segmentation.mode == "RGB": + segmentation = segmentation.convert("RGB") + segmentation = np.array(segmentation).astype(np.uint8) + if self.onehot: + assert self.stuffthing + # stored in caffe format: unlabeled==255. stuff and thing from + # 0-181. to be compatible with the labels in + # https://github.com/nightrome/cocostuff/blob/master/labels.txt + # we shift stuffthing one to the right and put unlabeled in zero + # as long as segmentation is uint8 shifting to right handles the + # latter too + assert segmentation.dtype == np.uint8 + segmentation = segmentation + 1 + + processed = self.preprocessor(image=image, segmentation=segmentation) + image, segmentation = processed["image"], processed["segmentation"] + image = (image / 127.5 - 1.0).astype(np.float32) + + if self.onehot: + assert segmentation.dtype == np.uint8 + # make it one hot + n_labels = 183 + flatseg = np.ravel(segmentation) + onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) + onehot[np.arange(flatseg.size), flatseg] = True + onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) + segmentation = onehot + else: + segmentation = (segmentation / 127.5 - 1.0).astype(np.float32) + return image, segmentation + + def __getitem__(self, i): + img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] + seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] + image, segmentation = self.preprocess_image(img_path, seg_path) + captions = self.img_id_to_captions[self.labels["image_ids"][i]] + # randomly draw one of all available captions per image + caption = captions[np.random.randint(0, len(captions))] + example = {"image": image, + "caption": [str(caption[0])], + "segmentation": segmentation, + "img_path": img_path, + "seg_path": seg_path, + "filename_": img_path.split(os.sep)[-1] + } + return example + + +class CocoImagesAndCaptionsTrain(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False): + super().__init__(size=size, + dataroot="data/coco/train2017", + datajson="data/coco/annotations/captions_train2017.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) + + def get_split(self): + return "train" + + +class CocoImagesAndCaptionsValidation(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, + given_files=None): + super().__init__(size=size, + dataroot="data/coco/val2017", + datajson="data/coco/annotations/captions_val2017.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, + given_files=given_files) + + def get_split(self): + return "validation" diff --git a/3DTopia/taming/data/conditional_builder/objects_bbox.py b/3DTopia/taming/data/conditional_builder/objects_bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..15881e76b7ab2a914df8f2dfe08ae4f0c6c511b5 --- /dev/null +++ b/3DTopia/taming/data/conditional_builder/objects_bbox.py @@ -0,0 +1,60 @@ +from itertools import cycle +from typing import List, Tuple, Callable, Optional + +from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont +from more_itertools.recipes import grouper +from taming.data.image_transforms import convert_pil_to_tensor +from torch import LongTensor, Tensor + +from taming.data.helper_types import BoundingBox, Annotation +from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder +from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ + pad_list, get_plot_font_size, absolute_bbox + + +class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): + @property + def object_descriptor_length(self) -> int: + return 3 + + def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: + object_triples = [ + (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) + for ann in annotations + ] + empty_triple = (self.none, self.none, self.none) + object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) + return object_triples + + def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: + conditional_list = conditional.tolist() + crop_coordinates = None + if self.encode_crop: + crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) + conditional_list = conditional_list[:-2] + object_triples = grouper(conditional_list, 3) + assert conditional.shape[0] == self.embedding_dim + return [ + (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) + for object_triple in object_triples if object_triple[0] != self.none + ], crop_coordinates + + def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], + line_width: int = 3, font_size: Optional[int] = None) -> Tensor: + plot = pil_image.new('RGB', figure_size, WHITE) + draw = pil_img_draw.Draw(plot) + font = ImageFont.truetype( + "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", + size=get_plot_font_size(font_size, figure_size) + ) + width, height = plot.size + description, crop_coordinates = self.inverse_build(conditional) + for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): + annotation = self.representation_to_annotation(representation) + class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) + bbox = absolute_bbox(bbox, width, height) + draw.rectangle(bbox, outline=color, width=line_width) + draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) + if crop_coordinates is not None: + draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) + return convert_pil_to_tensor(plot) / 127.5 - 1. diff --git a/3DTopia/taming/data/conditional_builder/objects_center_points.py b/3DTopia/taming/data/conditional_builder/objects_center_points.py new file mode 100644 index 0000000000000000000000000000000000000000..9a480329cc47fb38a7b8729d424e092b77d40749 --- /dev/null +++ b/3DTopia/taming/data/conditional_builder/objects_center_points.py @@ -0,0 +1,168 @@ +import math +import random +import warnings +from itertools import cycle +from typing import List, Optional, Tuple, Callable + +from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont +from more_itertools.recipes import grouper +from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \ + additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \ + absolute_bbox, rescale_annotations +from taming.data.helper_types import BoundingBox, Annotation +from taming.data.image_transforms import convert_pil_to_tensor +from torch import LongTensor, Tensor + + +class ObjectsCenterPointsConditionalBuilder: + def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool, + use_group_parameter: bool, use_additional_parameters: bool): + self.no_object_classes = no_object_classes + self.no_max_objects = no_max_objects + self.no_tokens = no_tokens + self.encode_crop = encode_crop + self.no_sections = int(math.sqrt(self.no_tokens)) + self.use_group_parameter = use_group_parameter + self.use_additional_parameters = use_additional_parameters + + @property + def none(self) -> int: + return self.no_tokens - 1 + + @property + def object_descriptor_length(self) -> int: + return 2 + + @property + def embedding_dim(self) -> int: + extra_length = 2 if self.encode_crop else 0 + return self.no_max_objects * self.object_descriptor_length + extra_length + + def tokenize_coordinates(self, x: float, y: float) -> int: + """ + Express 2d coordinates with one number. + Example: assume self.no_tokens = 16, then no_sections = 4: + 0 0 0 0 + 0 0 # 0 + 0 0 0 0 + 0 0 0 x + Then the # position corresponds to token 6, the x position to token 15. + @param x: float in [0, 1] + @param y: float in [0, 1] + @return: discrete tokenized coordinate + """ + x_discrete = int(round(x * (self.no_sections - 1))) + y_discrete = int(round(y * (self.no_sections - 1))) + return y_discrete * self.no_sections + x_discrete + + def coordinates_from_token(self, token: int) -> (float, float): + x = token % self.no_sections + y = token // self.no_sections + return x / (self.no_sections - 1), y / (self.no_sections - 1) + + def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox: + x0, y0 = self.coordinates_from_token(token1) + x1, y1 = self.coordinates_from_token(token2) + return x0, y0, x1 - x0, y1 - y0 + + def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]: + return self.tokenize_coordinates(bbox[0], bbox[1]), \ + self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3]) + + def inverse_build(self, conditional: LongTensor) \ + -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]: + conditional_list = conditional.tolist() + crop_coordinates = None + if self.encode_crop: + crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) + conditional_list = conditional_list[:-2] + table_of_content = grouper(conditional_list, self.object_descriptor_length) + assert conditional.shape[0] == self.embedding_dim + return [ + (object_tuple[0], self.coordinates_from_token(object_tuple[1])) + for object_tuple in table_of_content if object_tuple[0] != self.none + ], crop_coordinates + + def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], + line_width: int = 3, font_size: Optional[int] = None) -> Tensor: + plot = pil_image.new('RGB', figure_size, WHITE) + draw = pil_img_draw.Draw(plot) + circle_size = get_circle_size(figure_size) + font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf', + size=get_plot_font_size(font_size, figure_size)) + width, height = plot.size + description, crop_coordinates = self.inverse_build(conditional) + for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)): + x_abs, y_abs = x * width, y * height + ann = self.representation_to_annotation(representation) + label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann) + ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size] + draw.ellipse(ellipse_bbox, fill=color, width=0) + draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font) + if crop_coordinates is not None: + draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) + return convert_pil_to_tensor(plot) / 127.5 - 1. + + def object_representation(self, annotation: Annotation) -> int: + modifier = 0 + if self.use_group_parameter: + modifier |= 1 * (annotation.is_group_of is True) + if self.use_additional_parameters: + modifier |= 2 * (annotation.is_occluded is True) + modifier |= 4 * (annotation.is_depiction is True) + modifier |= 8 * (annotation.is_inside is True) + return annotation.category_no + self.no_object_classes * modifier + + def representation_to_annotation(self, representation: int) -> Annotation: + category_no = representation % self.no_object_classes + modifier = representation // self.no_object_classes + # noinspection PyTypeChecker + return Annotation( + area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None, + category_no=category_no, + is_group_of=bool((modifier & 1) * self.use_group_parameter), + is_occluded=bool((modifier & 2) * self.use_additional_parameters), + is_depiction=bool((modifier & 4) * self.use_additional_parameters), + is_inside=bool((modifier & 8) * self.use_additional_parameters) + ) + + def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]: + return list(self.token_pair_from_bbox(crop_coordinates)) + + def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: + object_tuples = [ + (self.object_representation(a), + self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2)) + for a in annotations + ] + empty_tuple = (self.none, self.none) + object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects) + return object_tuples + + def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \ + -> LongTensor: + if len(annotations) == 0: + warnings.warn('Did not receive any annotations.') + if len(annotations) > self.no_max_objects: + warnings.warn('Received more annotations than allowed.') + annotations = annotations[:self.no_max_objects] + + if not crop_coordinates: + crop_coordinates = FULL_CROP + + random.shuffle(annotations) + annotations = filter_annotations(annotations, crop_coordinates) + if self.encode_crop: + annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip) + if horizontal_flip: + crop_coordinates = horizontally_flip_bbox(crop_coordinates) + extra = self._crop_encoder(crop_coordinates) + else: + annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip) + extra = [] + + object_tuples = self._make_object_descriptors(annotations) + flattened = [token for tuple_ in object_tuples for token in tuple_] + extra + assert len(flattened) == self.embedding_dim + assert all(0 <= value < self.no_tokens for value in flattened) + return LongTensor(flattened) diff --git a/3DTopia/taming/data/conditional_builder/utils.py b/3DTopia/taming/data/conditional_builder/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d0ee175f2e05a80dbc71c22acbecb22dddadbb42 --- /dev/null +++ b/3DTopia/taming/data/conditional_builder/utils.py @@ -0,0 +1,105 @@ +import importlib +from typing import List, Any, Tuple, Optional + +from taming.data.helper_types import BoundingBox, Annotation + +# source: seaborn, color palette tab10 +COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), + (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)] +BLACK = (0, 0, 0) +GRAY_75 = (63, 63, 63) +GRAY_50 = (127, 127, 127) +GRAY_25 = (191, 191, 191) +WHITE = (255, 255, 255) +FULL_CROP = (0., 0., 1., 1.) + + +def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: + """ + Give intersection area of two rectangles. + @param rectangle1: (x0, y0, w, h) of first rectangle + @param rectangle2: (x0, y0, w, h) of second rectangle + """ + rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] + rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] + x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) + y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) + return x_overlap * y_overlap + + +def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: + return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] + + +def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: + bbox = relative_bbox + bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height + return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) + + +def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: + return list_ + [pad_element for _ in range(pad_to_length - len(list_))] + + +def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ + List[Annotation]: + def clamp(x: float): + return max(min(x, 1.), 0.) + + def rescale_bbox(bbox: BoundingBox) -> BoundingBox: + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + if flip: + x0 = 1 - (x0 + w) + return x0, y0, w, h + + return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] + + +def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: + return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] + + +def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: + sl = slice(1) if short else slice(None) + string = '' + if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): + return string + if annotation.is_group_of: + string += 'group'[sl] + ',' + if annotation.is_occluded: + string += 'occluded'[sl] + ',' + if annotation.is_depiction: + string += 'depiction'[sl] + ',' + if annotation.is_inside: + string += 'inside'[sl] + return '(' + string.strip(",") + ')' + + +def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: + if font_size is None: + font_size = 10 + if max(figure_size) >= 256: + font_size = 12 + if max(figure_size) >= 512: + font_size = 15 + return font_size + + +def get_circle_size(figure_size: Tuple[int, int]) -> int: + circle_size = 2 + if max(figure_size) >= 256: + circle_size = 3 + if max(figure_size) >= 512: + circle_size = 4 + return circle_size + + +def load_object_from_string(object_string: str) -> Any: + """ + Source: https://stackoverflow.com/a/10773699 + """ + module_name, class_name = object_string.rsplit(".", 1) + return getattr(importlib.import_module(module_name), class_name) diff --git a/3DTopia/taming/data/custom.py b/3DTopia/taming/data/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..33f302a4b55ba1e8ec282ec3292b6263c06dfb91 --- /dev/null +++ b/3DTopia/taming/data/custom.py @@ -0,0 +1,38 @@ +import os +import numpy as np +import albumentations +from torch.utils.data import Dataset + +from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex + + +class CustomBase(Dataset): + def __init__(self, *args, **kwargs): + super().__init__() + self.data = None + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + example = self.data[i] + return example + + + +class CustomTrain(CustomBase): + def __init__(self, size, training_images_list_file): + super().__init__() + with open(training_images_list_file, "r") as f: + paths = f.read().splitlines() + self.data = ImagePaths(paths=paths, size=size, random_crop=False) + + +class CustomTest(CustomBase): + def __init__(self, size, test_images_list_file): + super().__init__() + with open(test_images_list_file, "r") as f: + paths = f.read().splitlines() + self.data = ImagePaths(paths=paths, size=size, random_crop=False) + + diff --git a/3DTopia/taming/data/faceshq.py b/3DTopia/taming/data/faceshq.py new file mode 100644 index 0000000000000000000000000000000000000000..6912d04b66a6d464c1078e4b51d5da290f5e767e --- /dev/null +++ b/3DTopia/taming/data/faceshq.py @@ -0,0 +1,134 @@ +import os +import numpy as np +import albumentations +from torch.utils.data import Dataset + +from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex + + +class FacesBase(Dataset): + def __init__(self, *args, **kwargs): + super().__init__() + self.data = None + self.keys = None + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + example = self.data[i] + ex = {} + if self.keys is not None: + for k in self.keys: + ex[k] = example[k] + else: + ex = example + return ex + + +class CelebAHQTrain(FacesBase): + def __init__(self, size, keys=None): + super().__init__() + root = "data/celebahq" + with open("data/celebahqtrain.txt", "r") as f: + relpaths = f.read().splitlines() + paths = [os.path.join(root, relpath) for relpath in relpaths] + self.data = NumpyPaths(paths=paths, size=size, random_crop=False) + self.keys = keys + + +class CelebAHQValidation(FacesBase): + def __init__(self, size, keys=None): + super().__init__() + root = "data/celebahq" + with open("data/celebahqvalidation.txt", "r") as f: + relpaths = f.read().splitlines() + paths = [os.path.join(root, relpath) for relpath in relpaths] + self.data = NumpyPaths(paths=paths, size=size, random_crop=False) + self.keys = keys + + +class FFHQTrain(FacesBase): + def __init__(self, size, keys=None): + super().__init__() + root = "data/ffhq" + with open("data/ffhqtrain.txt", "r") as f: + relpaths = f.read().splitlines() + paths = [os.path.join(root, relpath) for relpath in relpaths] + self.data = ImagePaths(paths=paths, size=size, random_crop=False) + self.keys = keys + + +class FFHQValidation(FacesBase): + def __init__(self, size, keys=None): + super().__init__() + root = "data/ffhq" + with open("data/ffhqvalidation.txt", "r") as f: + relpaths = f.read().splitlines() + paths = [os.path.join(root, relpath) for relpath in relpaths] + self.data = ImagePaths(paths=paths, size=size, random_crop=False) + self.keys = keys + + +class FacesHQTrain(Dataset): + # CelebAHQ [0] + FFHQ [1] + def __init__(self, size, keys=None, crop_size=None, coord=False): + d1 = CelebAHQTrain(size=size, keys=keys) + d2 = FFHQTrain(size=size, keys=keys) + self.data = ConcatDatasetWithIndex([d1, d2]) + self.coord = coord + if crop_size is not None: + self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) + if self.coord: + self.cropper = albumentations.Compose([self.cropper], + additional_targets={"coord": "image"}) + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + ex, y = self.data[i] + if hasattr(self, "cropper"): + if not self.coord: + out = self.cropper(image=ex["image"]) + ex["image"] = out["image"] + else: + h,w,_ = ex["image"].shape + coord = np.arange(h*w).reshape(h,w,1)/(h*w) + out = self.cropper(image=ex["image"], coord=coord) + ex["image"] = out["image"] + ex["coord"] = out["coord"] + ex["class"] = y + return ex + + +class FacesHQValidation(Dataset): + # CelebAHQ [0] + FFHQ [1] + def __init__(self, size, keys=None, crop_size=None, coord=False): + d1 = CelebAHQValidation(size=size, keys=keys) + d2 = FFHQValidation(size=size, keys=keys) + self.data = ConcatDatasetWithIndex([d1, d2]) + self.coord = coord + if crop_size is not None: + self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) + if self.coord: + self.cropper = albumentations.Compose([self.cropper], + additional_targets={"coord": "image"}) + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + ex, y = self.data[i] + if hasattr(self, "cropper"): + if not self.coord: + out = self.cropper(image=ex["image"]) + ex["image"] = out["image"] + else: + h,w,_ = ex["image"].shape + coord = np.arange(h*w).reshape(h,w,1)/(h*w) + out = self.cropper(image=ex["image"], coord=coord) + ex["image"] = out["image"] + ex["coord"] = out["coord"] + ex["class"] = y + return ex diff --git a/3DTopia/taming/data/helper_types.py b/3DTopia/taming/data/helper_types.py new file mode 100644 index 0000000000000000000000000000000000000000..fb51e301da08602cfead5961c4f7e1d89f6aba79 --- /dev/null +++ b/3DTopia/taming/data/helper_types.py @@ -0,0 +1,49 @@ +from typing import Dict, Tuple, Optional, NamedTuple, Union +from PIL.Image import Image as pil_image +from torch import Tensor + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +Image = Union[Tensor, pil_image] +BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h +CropMethodType = Literal['none', 'random', 'center', 'random-2d'] +SplitType = Literal['train', 'validation', 'test'] + + +class ImageDescription(NamedTuple): + id: int + file_name: str + original_size: Tuple[int, int] # w, h + url: Optional[str] = None + license: Optional[int] = None + coco_url: Optional[str] = None + date_captured: Optional[str] = None + flickr_url: Optional[str] = None + flickr_id: Optional[str] = None + coco_id: Optional[str] = None + + +class Category(NamedTuple): + id: str + super_category: Optional[str] + name: str + + +class Annotation(NamedTuple): + area: float + image_id: str + bbox: BoundingBox + category_no: int + category_id: str + id: Optional[int] = None + source: Optional[str] = None + confidence: Optional[float] = None + is_group_of: Optional[bool] = None + is_truncated: Optional[bool] = None + is_occluded: Optional[bool] = None + is_depiction: Optional[bool] = None + is_inside: Optional[bool] = None + segmentation: Optional[Dict] = None diff --git a/3DTopia/taming/data/image_transforms.py b/3DTopia/taming/data/image_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..657ac332174e0ac72f68315271ffbd757b771a0f --- /dev/null +++ b/3DTopia/taming/data/image_transforms.py @@ -0,0 +1,132 @@ +import random +import warnings +from typing import Union + +import torch +from torch import Tensor +from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor +from torchvision.transforms.functional import _get_image_size as get_image_size + +from taming.data.helper_types import BoundingBox, Image + +pil_to_tensor = PILToTensor() + + +def convert_pil_to_tensor(image: Image) -> Tensor: + with warnings.catch_warnings(): + # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 + warnings.simplefilter("ignore") + return pil_to_tensor(image) + + +class RandomCrop1dReturnCoordinates(RandomCrop): + def forward(self, img: Image) -> (BoundingBox, Image): + """ + Additionally to cropping, returns the relative coordinates of the crop bounding box. + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + Bounding box: x0, y0, w, h + PIL Image or Tensor: Cropped image. + + Based on: + torchvision.transforms.RandomCrop, torchvision 1.7.0 + """ + if self.padding is not None: + img = F.pad(img, self.padding, self.fill, self.padding_mode) + + width, height = get_image_size(img) + # pad the width if needed + if self.pad_if_needed and width < self.size[1]: + padding = [self.size[1] - width, 0] + img = F.pad(img, padding, self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and height < self.size[0]: + padding = [0, self.size[0] - height] + img = F.pad(img, padding, self.fill, self.padding_mode) + + i, j, h, w = self.get_params(img, self.size) + bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h + return bbox, F.crop(img, i, j, h, w) + + +class Random2dCropReturnCoordinates(torch.nn.Module): + """ + Additionally to cropping, returns the relative coordinates of the crop bounding box. + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + Bounding box: x0, y0, w, h + PIL Image or Tensor: Cropped image. + + Based on: + torchvision.transforms.RandomCrop, torchvision 1.7.0 + """ + + def __init__(self, min_size: int): + super().__init__() + self.min_size = min_size + + def forward(self, img: Image) -> (BoundingBox, Image): + width, height = get_image_size(img) + max_size = min(width, height) + if max_size <= self.min_size: + size = max_size + else: + size = random.randint(self.min_size, max_size) + top = random.randint(0, height - size) + left = random.randint(0, width - size) + bbox = left / width, top / height, size / width, size / height + return bbox, F.crop(img, top, left, size, size) + + +class CenterCropReturnCoordinates(CenterCrop): + @staticmethod + def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: + if width > height: + w = height / width + h = 1.0 + x0 = 0.5 - w / 2 + y0 = 0. + else: + w = 1.0 + h = width / height + x0 = 0. + y0 = 0.5 - h / 2 + return x0, y0, w, h + + def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): + """ + Additionally to cropping, returns the relative coordinates of the crop bounding box. + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + Bounding box: x0, y0, w, h + PIL Image or Tensor: Cropped image. + Based on: + torchvision.transforms.RandomHorizontalFlip (version 1.7.0) + """ + width, height = get_image_size(img) + return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) + + +class RandomHorizontalFlipReturn(RandomHorizontalFlip): + def forward(self, img: Image) -> (bool, Image): + """ + Additionally to flipping, returns a boolean whether it was flipped or not. + Args: + img (PIL Image or Tensor): Image to be flipped. + + Returns: + flipped: whether the image was flipped or not + PIL Image or Tensor: Randomly flipped image. + + Based on: + torchvision.transforms.RandomHorizontalFlip (version 1.7.0) + """ + if torch.rand(1) < self.p: + return True, F.hflip(img) + return False, img diff --git a/3DTopia/taming/data/imagenet.py b/3DTopia/taming/data/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..9a02ec44ba4af9e993f58c91fa43482a4ecbe54c --- /dev/null +++ b/3DTopia/taming/data/imagenet.py @@ -0,0 +1,558 @@ +import os, tarfile, glob, shutil +import yaml +import numpy as np +from tqdm import tqdm +from PIL import Image +import albumentations +from omegaconf import OmegaConf +from torch.utils.data import Dataset + +from taming.data.base import ImagePaths +from taming.util import download, retrieve +import taming.data.utils as bdu + + +def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"): + synsets = [] + with open(path_to_yaml) as f: + di2s = yaml.load(f) + for idx in indices: + synsets.append(str(di2s[idx])) + print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets))) + return synsets + + +def str_to_indices(string): + """Expects a string in the format '32-123, 256, 280-321'""" + assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string) + subs = string.split(",") + indices = [] + for sub in subs: + subsubs = sub.split("-") + assert len(subsubs) > 0 + if len(subsubs) == 1: + indices.append(int(subsubs[0])) + else: + rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))] + indices.extend(rang) + return sorted(indices) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config or OmegaConf.create() + if not type(self.config)==dict: + self.config = OmegaConf.to_container(self.config) + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + self.class_labels = [class_dict[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + self.data = ImagePaths(self.abspaths, + labels=labels, + size=retrieve(self.config, "size", default=0), + random_crop=self.random_crop) + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def _prepare(self): + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + if not bdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + bdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def _prepare(self): + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + if not bdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + bdu.mark_prepared(self.root) + + +def get_preprocessor(size=None, random_crop=False, additional_targets=None, + crop_size=None): + if size is not None and size > 0: + transforms = list() + rescaler = albumentations.SmallestMaxSize(max_size = size) + transforms.append(rescaler) + if not random_crop: + cropper = albumentations.CenterCrop(height=size,width=size) + transforms.append(cropper) + else: + cropper = albumentations.RandomCrop(height=size,width=size) + transforms.append(cropper) + flipper = albumentations.HorizontalFlip() + transforms.append(flipper) + preprocessor = albumentations.Compose(transforms, + additional_targets=additional_targets) + elif crop_size is not None and crop_size > 0: + if not random_crop: + cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) + else: + cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) + transforms = [cropper] + preprocessor = albumentations.Compose(transforms, + additional_targets=additional_targets) + else: + preprocessor = lambda **kwargs: kwargs + return preprocessor + + +def rgba_to_depth(x): + assert x.dtype == np.uint8 + assert len(x.shape) == 3 and x.shape[2] == 4 + y = x.copy() + y.dtype = np.float32 + y = y.reshape(x.shape[:2]) + return np.ascontiguousarray(y) + + +class BaseWithDepth(Dataset): + DEFAULT_DEPTH_ROOT="data/imagenet_depth" + + def __init__(self, config=None, size=None, random_crop=False, + crop_size=None, root=None): + self.config = config + self.base_dset = self.get_base_dset() + self.preprocessor = get_preprocessor( + size=size, + crop_size=crop_size, + random_crop=random_crop, + additional_targets={"depth": "image"}) + self.crop_size = crop_size + if self.crop_size is not None: + self.rescaler = albumentations.Compose( + [albumentations.SmallestMaxSize(max_size = self.crop_size)], + additional_targets={"depth": "image"}) + if root is not None: + self.DEFAULT_DEPTH_ROOT = root + + def __len__(self): + return len(self.base_dset) + + def preprocess_depth(self, path): + rgba = np.array(Image.open(path)) + depth = rgba_to_depth(rgba) + depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min()) + depth = 2.0*depth-1.0 + return depth + + def __getitem__(self, i): + e = self.base_dset[i] + e["depth"] = self.preprocess_depth(self.get_depth_path(e)) + # up if necessary + h,w,c = e["image"].shape + if self.crop_size and min(h,w) < self.crop_size: + # have to upscale to be able to crop - this just uses bilinear + out = self.rescaler(image=e["image"], depth=e["depth"]) + e["image"] = out["image"] + e["depth"] = out["depth"] + transformed = self.preprocessor(image=e["image"], depth=e["depth"]) + e["image"] = transformed["image"] + e["depth"] = transformed["depth"] + return e + + +class ImageNetTrainWithDepth(BaseWithDepth): + # default to random_crop=True + def __init__(self, random_crop=True, sub_indices=None, **kwargs): + self.sub_indices = sub_indices + super().__init__(random_crop=random_crop, **kwargs) + + def get_base_dset(self): + if self.sub_indices is None: + return ImageNetTrain() + else: + return ImageNetTrain({"sub_indices": self.sub_indices}) + + def get_depth_path(self, e): + fid = os.path.splitext(e["relpath"])[0]+".png" + fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid) + return fid + + +class ImageNetValidationWithDepth(BaseWithDepth): + def __init__(self, sub_indices=None, **kwargs): + self.sub_indices = sub_indices + super().__init__(**kwargs) + + def get_base_dset(self): + if self.sub_indices is None: + return ImageNetValidation() + else: + return ImageNetValidation({"sub_indices": self.sub_indices}) + + def get_depth_path(self, e): + fid = os.path.splitext(e["relpath"])[0]+".png" + fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid) + return fid + + +class RINTrainWithDepth(ImageNetTrainWithDepth): + def __init__(self, config=None, size=None, random_crop=True, crop_size=None): + sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319" + super().__init__(config=config, size=size, random_crop=random_crop, + sub_indices=sub_indices, crop_size=crop_size) + + +class RINValidationWithDepth(ImageNetValidationWithDepth): + def __init__(self, config=None, size=None, random_crop=False, crop_size=None): + sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319" + super().__init__(config=config, size=size, random_crop=random_crop, + sub_indices=sub_indices, crop_size=crop_size) + + +class DRINExamples(Dataset): + def __init__(self): + self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"}) + with open("data/drin_examples.txt", "r") as f: + relpaths = f.read().splitlines() + self.image_paths = [os.path.join("data/drin_images", + relpath) for relpath in relpaths] + self.depth_paths = [os.path.join("data/drin_depth", + relpath.replace(".JPEG", ".png")) for relpath in relpaths] + + def __len__(self): + return len(self.image_paths) + + def preprocess_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + image = self.preprocessor(image=image)["image"] + image = (image/127.5 - 1.0).astype(np.float32) + return image + + def preprocess_depth(self, path): + rgba = np.array(Image.open(path)) + depth = rgba_to_depth(rgba) + depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min()) + depth = 2.0*depth-1.0 + return depth + + def __getitem__(self, i): + e = dict() + e["image"] = self.preprocess_image(self.image_paths[i]) + e["depth"] = self.preprocess_depth(self.depth_paths[i]) + transformed = self.preprocessor(image=e["image"], depth=e["depth"]) + e["image"] = transformed["image"] + e["depth"] = transformed["depth"] + return e + + +def imscale(x, factor, keepshapes=False, keepmode="bicubic"): + if factor is None or factor==1: + return x + + dtype = x.dtype + assert dtype in [np.float32, np.float64] + assert x.min() >= -1 + assert x.max() <= 1 + + keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR, + "bicubic": Image.BICUBIC}[keepmode] + + lr = (x+1.0)*127.5 + lr = lr.clip(0,255).astype(np.uint8) + lr = Image.fromarray(lr) + + h, w, _ = x.shape + nh = h//factor + nw = w//factor + assert nh > 0 and nw > 0, (nh, nw) + + lr = lr.resize((nw,nh), Image.BICUBIC) + if keepshapes: + lr = lr.resize((w,h), keepmode) + lr = np.array(lr)/127.5-1.0 + lr = lr.astype(dtype) + + return lr + + +class ImageNetScale(Dataset): + def __init__(self, size=None, crop_size=None, random_crop=False, + up_factor=None, hr_factor=None, keep_mode="bicubic"): + self.base = self.get_base() + + self.size = size + self.crop_size = crop_size if crop_size is not None else self.size + self.random_crop = random_crop + self.up_factor = up_factor + self.hr_factor = hr_factor + self.keep_mode = keep_mode + + transforms = list() + + if self.size is not None and self.size > 0: + rescaler = albumentations.SmallestMaxSize(max_size = self.size) + self.rescaler = rescaler + transforms.append(rescaler) + + if self.crop_size is not None and self.crop_size > 0: + if len(transforms) == 0: + self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size) + + if not self.random_crop: + cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size) + else: + cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size) + transforms.append(cropper) + + if len(transforms) > 0: + if self.up_factor is not None: + additional_targets = {"lr": "image"} + else: + additional_targets = None + self.preprocessor = albumentations.Compose(transforms, + additional_targets=additional_targets) + else: + self.preprocessor = lambda **kwargs: kwargs + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = example["image"] + # adjust resolution + image = imscale(image, self.hr_factor, keepshapes=False) + h,w,c = image.shape + if self.crop_size and min(h,w) < self.crop_size: + # have to upscale to be able to crop - this just uses bilinear + image = self.rescaler(image=image)["image"] + if self.up_factor is None: + image = self.preprocessor(image=image)["image"] + example["image"] = image + else: + lr = imscale(image, self.up_factor, keepshapes=True, + keepmode=self.keep_mode) + + out = self.preprocessor(image=image, lr=lr) + example["image"] = out["image"] + example["lr"] = out["lr"] + + return example + +class ImageNetScaleTrain(ImageNetScale): + def __init__(self, random_crop=True, **kwargs): + super().__init__(random_crop=random_crop, **kwargs) + + def get_base(self): + return ImageNetTrain() + +class ImageNetScaleValidation(ImageNetScale): + def get_base(self): + return ImageNetValidation() + + +from skimage.feature import canny +from skimage.color import rgb2gray + + +class ImageNetEdges(ImageNetScale): + def __init__(self, up_factor=1, **kwargs): + super().__init__(up_factor=1, **kwargs) + + def __getitem__(self, i): + example = self.base[i] + image = example["image"] + h,w,c = image.shape + if self.crop_size and min(h,w) < self.crop_size: + # have to upscale to be able to crop - this just uses bilinear + image = self.rescaler(image=image)["image"] + + lr = canny(rgb2gray(image), sigma=2) + lr = lr.astype(np.float32) + lr = lr[:,:,None][:,:,[0,0,0]] + + out = self.preprocessor(image=image, lr=lr) + example["image"] = out["image"] + example["lr"] = out["lr"] + + return example + + +class ImageNetEdgesTrain(ImageNetEdges): + def __init__(self, random_crop=True, **kwargs): + super().__init__(random_crop=random_crop, **kwargs) + + def get_base(self): + return ImageNetTrain() + +class ImageNetEdgesValidation(ImageNetEdges): + def get_base(self): + return ImageNetValidation() diff --git a/3DTopia/taming/data/open_images_helper.py b/3DTopia/taming/data/open_images_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..8feb7c6e705fc165d2983303192aaa88f579b243 --- /dev/null +++ b/3DTopia/taming/data/open_images_helper.py @@ -0,0 +1,379 @@ +open_images_unify_categories_for_coco = { + '/m/03bt1vf': '/m/01g317', + '/m/04yx4': '/m/01g317', + '/m/05r655': '/m/01g317', + '/m/01bl7v': '/m/01g317', + '/m/0cnyhnx': '/m/01xq0k1', + '/m/01226z': '/m/018xm', + '/m/05ctyq': '/m/018xm', + '/m/058qzx': '/m/04ctx', + '/m/06pcq': '/m/0l515', + '/m/03m3pdh': '/m/02crq1', + '/m/046dlr': '/m/01x3z', + '/m/0h8mzrc': '/m/01x3z', +} + + +top_300_classes_plus_coco_compatibility = [ + ('Man', 1060962), + ('Clothing', 986610), + ('Tree', 748162), + ('Woman', 611896), + ('Person', 610294), + ('Human face', 442948), + ('Girl', 175399), + ('Building', 162147), + ('Car', 159135), + ('Plant', 155704), + ('Human body', 137073), + ('Flower', 133128), + ('Window', 127485), + ('Human arm', 118380), + ('House', 114365), + ('Wheel', 111684), + ('Suit', 99054), + ('Human hair', 98089), + ('Human head', 92763), + ('Chair', 88624), + ('Boy', 79849), + ('Table', 73699), + ('Jeans', 57200), + ('Tire', 55725), + ('Skyscraper', 53321), + ('Food', 52400), + ('Footwear', 50335), + ('Dress', 50236), + ('Human leg', 47124), + ('Toy', 46636), + ('Tower', 45605), + ('Boat', 43486), + ('Land vehicle', 40541), + ('Bicycle wheel', 34646), + ('Palm tree', 33729), + ('Fashion accessory', 32914), + ('Glasses', 31940), + ('Bicycle', 31409), + ('Furniture', 30656), + ('Sculpture', 29643), + ('Bottle', 27558), + ('Dog', 26980), + ('Snack', 26796), + ('Human hand', 26664), + ('Bird', 25791), + ('Book', 25415), + ('Guitar', 24386), + ('Jacket', 23998), + ('Poster', 22192), + ('Dessert', 21284), + ('Baked goods', 20657), + ('Drink', 19754), + ('Flag', 18588), + ('Houseplant', 18205), + ('Tableware', 17613), + ('Airplane', 17218), + ('Door', 17195), + ('Sports uniform', 17068), + ('Shelf', 16865), + ('Drum', 16612), + ('Vehicle', 16542), + ('Microphone', 15269), + ('Street light', 14957), + ('Cat', 14879), + ('Fruit', 13684), + ('Fast food', 13536), + ('Animal', 12932), + ('Vegetable', 12534), + ('Train', 12358), + ('Horse', 11948), + ('Flowerpot', 11728), + ('Motorcycle', 11621), + ('Fish', 11517), + ('Desk', 11405), + ('Helmet', 10996), + ('Truck', 10915), + ('Bus', 10695), + ('Hat', 10532), + ('Auto part', 10488), + ('Musical instrument', 10303), + ('Sunglasses', 10207), + ('Picture frame', 10096), + ('Sports equipment', 10015), + ('Shorts', 9999), + ('Wine glass', 9632), + ('Duck', 9242), + ('Wine', 9032), + ('Rose', 8781), + ('Tie', 8693), + ('Butterfly', 8436), + ('Beer', 7978), + ('Cabinetry', 7956), + ('Laptop', 7907), + ('Insect', 7497), + ('Goggles', 7363), + ('Shirt', 7098), + ('Dairy Product', 7021), + ('Marine invertebrates', 7014), + ('Cattle', 7006), + ('Trousers', 6903), + ('Van', 6843), + ('Billboard', 6777), + ('Balloon', 6367), + ('Human nose', 6103), + ('Tent', 6073), + ('Camera', 6014), + ('Doll', 6002), + ('Coat', 5951), + ('Mobile phone', 5758), + ('Swimwear', 5729), + ('Strawberry', 5691), + ('Stairs', 5643), + ('Goose', 5599), + ('Umbrella', 5536), + ('Cake', 5508), + ('Sun hat', 5475), + ('Bench', 5310), + ('Bookcase', 5163), + ('Bee', 5140), + ('Computer monitor', 5078), + ('Hiking equipment', 4983), + ('Office building', 4981), + ('Coffee cup', 4748), + ('Curtain', 4685), + ('Plate', 4651), + ('Box', 4621), + ('Tomato', 4595), + ('Coffee table', 4529), + ('Office supplies', 4473), + ('Maple', 4416), + ('Muffin', 4365), + ('Cocktail', 4234), + ('Castle', 4197), + ('Couch', 4134), + ('Pumpkin', 3983), + ('Computer keyboard', 3960), + ('Human mouth', 3926), + ('Christmas tree', 3893), + ('Mushroom', 3883), + ('Swimming pool', 3809), + ('Pastry', 3799), + ('Lavender (Plant)', 3769), + ('Football helmet', 3732), + ('Bread', 3648), + ('Traffic sign', 3628), + ('Common sunflower', 3597), + ('Television', 3550), + ('Bed', 3525), + ('Cookie', 3485), + ('Fountain', 3484), + ('Paddle', 3447), + ('Bicycle helmet', 3429), + ('Porch', 3420), + ('Deer', 3387), + ('Fedora', 3339), + ('Canoe', 3338), + ('Carnivore', 3266), + ('Bowl', 3202), + ('Human eye', 3166), + ('Ball', 3118), + ('Pillow', 3077), + ('Salad', 3061), + ('Beetle', 3060), + ('Orange', 3050), + ('Drawer', 2958), + ('Platter', 2937), + ('Elephant', 2921), + ('Seafood', 2921), + ('Monkey', 2915), + ('Countertop', 2879), + ('Watercraft', 2831), + ('Helicopter', 2805), + ('Kitchen appliance', 2797), + ('Personal flotation device', 2781), + ('Swan', 2739), + ('Lamp', 2711), + ('Boot', 2695), + ('Bronze sculpture', 2693), + ('Chicken', 2677), + ('Taxi', 2643), + ('Juice', 2615), + ('Cowboy hat', 2604), + ('Apple', 2600), + ('Tin can', 2590), + ('Necklace', 2564), + ('Ice cream', 2560), + ('Human beard', 2539), + ('Coin', 2536), + ('Candle', 2515), + ('Cart', 2512), + ('High heels', 2441), + ('Weapon', 2433), + ('Handbag', 2406), + ('Penguin', 2396), + ('Rifle', 2352), + ('Violin', 2336), + ('Skull', 2304), + ('Lantern', 2285), + ('Scarf', 2269), + ('Saucer', 2225), + ('Sheep', 2215), + ('Vase', 2189), + ('Lily', 2180), + ('Mug', 2154), + ('Parrot', 2140), + ('Human ear', 2137), + ('Sandal', 2115), + ('Lizard', 2100), + ('Kitchen & dining room table', 2063), + ('Spider', 1977), + ('Coffee', 1974), + ('Goat', 1926), + ('Squirrel', 1922), + ('Cello', 1913), + ('Sushi', 1881), + ('Tortoise', 1876), + ('Pizza', 1870), + ('Studio couch', 1864), + ('Barrel', 1862), + ('Cosmetics', 1841), + ('Moths and butterflies', 1841), + ('Convenience store', 1817), + ('Watch', 1792), + ('Home appliance', 1786), + ('Harbor seal', 1780), + ('Luggage and bags', 1756), + ('Vehicle registration plate', 1754), + ('Shrimp', 1751), + ('Jellyfish', 1730), + ('French fries', 1723), + ('Egg (Food)', 1698), + ('Football', 1697), + ('Musical keyboard', 1683), + ('Falcon', 1674), + ('Candy', 1660), + ('Medical equipment', 1654), + ('Eagle', 1651), + ('Dinosaur', 1634), + ('Surfboard', 1630), + ('Tank', 1628), + ('Grape', 1624), + ('Lion', 1624), + ('Owl', 1622), + ('Ski', 1613), + ('Waste container', 1606), + ('Frog', 1591), + ('Sparrow', 1585), + ('Rabbit', 1581), + ('Pen', 1546), + ('Sea lion', 1537), + ('Spoon', 1521), + ('Sink', 1512), + ('Teddy bear', 1507), + ('Bull', 1495), + ('Sofa bed', 1490), + ('Dragonfly', 1479), + ('Brassiere', 1478), + ('Chest of drawers', 1472), + ('Aircraft', 1466), + ('Human foot', 1463), + ('Pig', 1455), + ('Fork', 1454), + ('Antelope', 1438), + ('Tripod', 1427), + ('Tool', 1424), + ('Cheese', 1422), + ('Lemon', 1397), + ('Hamburger', 1393), + ('Dolphin', 1390), + ('Mirror', 1390), + ('Marine mammal', 1387), + ('Giraffe', 1385), + ('Snake', 1368), + ('Gondola', 1364), + ('Wheelchair', 1360), + ('Piano', 1358), + ('Cupboard', 1348), + ('Banana', 1345), + ('Trumpet', 1335), + ('Lighthouse', 1333), + ('Invertebrate', 1317), + ('Carrot', 1268), + ('Sock', 1260), + ('Tiger', 1241), + ('Camel', 1224), + ('Parachute', 1224), + ('Bathroom accessory', 1223), + ('Earrings', 1221), + ('Headphones', 1218), + ('Skirt', 1198), + ('Skateboard', 1190), + ('Sandwich', 1148), + ('Saxophone', 1141), + ('Goldfish', 1136), + ('Stool', 1104), + ('Traffic light', 1097), + ('Shellfish', 1081), + ('Backpack', 1079), + ('Sea turtle', 1078), + ('Cucumber', 1075), + ('Tea', 1051), + ('Toilet', 1047), + ('Roller skates', 1040), + ('Mule', 1039), + ('Bust', 1031), + ('Broccoli', 1030), + ('Crab', 1020), + ('Oyster', 1019), + ('Cannon', 1012), + ('Zebra', 1012), + ('French horn', 1008), + ('Grapefruit', 998), + ('Whiteboard', 997), + ('Zucchini', 997), + ('Crocodile', 992), + + ('Clock', 960), + ('Wall clock', 958), + + ('Doughnut', 869), + ('Snail', 868), + + ('Baseball glove', 859), + + ('Panda', 830), + ('Tennis racket', 830), + + ('Pear', 652), + + ('Bagel', 617), + ('Oven', 616), + ('Ladybug', 615), + ('Shark', 615), + ('Polar bear', 614), + ('Ostrich', 609), + + ('Hot dog', 473), + ('Microwave oven', 467), + ('Fire hydrant', 20), + ('Stop sign', 20), + ('Parking meter', 20), + ('Bear', 20), + ('Flying disc', 20), + ('Snowboard', 20), + ('Tennis ball', 20), + ('Kite', 20), + ('Baseball bat', 20), + ('Kitchen knife', 20), + ('Knife', 20), + ('Submarine sandwich', 20), + ('Computer mouse', 20), + ('Remote control', 20), + ('Toaster', 20), + ('Sink', 20), + ('Refrigerator', 20), + ('Alarm clock', 20), + ('Wall clock', 20), + ('Scissors', 20), + ('Hair dryer', 20), + ('Toothbrush', 20), + ('Suitcase', 20) +] diff --git a/3DTopia/taming/data/sflckr.py b/3DTopia/taming/data/sflckr.py new file mode 100644 index 0000000000000000000000000000000000000000..91101be5953b113f1e58376af637e43f366b3dee --- /dev/null +++ b/3DTopia/taming/data/sflckr.py @@ -0,0 +1,91 @@ +import os +import numpy as np +import cv2 +import albumentations +from PIL import Image +from torch.utils.data import Dataset + + +class SegmentationBase(Dataset): + def __init__(self, + data_csv, data_root, segmentation_root, + size=None, random_crop=False, interpolation="bicubic", + n_labels=182, shift_segmentation=False, + ): + self.n_labels = n_labels + self.shift_segmentation = shift_segmentation + self.data_csv = data_csv + self.data_root = data_root + self.segmentation_root = segmentation_root + with open(self.data_csv, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) + for l in self.image_paths], + "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) + for l in self.image_paths] + } + + size = None if size is not None and size<=0 else size + self.size = size + if self.size is not None: + self.interpolation = interpolation + self.interpolation = { + "nearest": cv2.INTER_NEAREST, + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] + self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=self.interpolation) + self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=cv2.INTER_NEAREST) + self.center_crop = not random_crop + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) + else: + self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) + self.preprocessor = self.cropper + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + if self.size is not None: + image = self.image_rescaler(image=image)["image"] + segmentation = Image.open(example["segmentation_path_"]) + assert segmentation.mode == "L", segmentation.mode + segmentation = np.array(segmentation).astype(np.uint8) + if self.shift_segmentation: + # used to support segmentations containing unlabeled==255 label + segmentation = segmentation+1 + if self.size is not None: + segmentation = self.segmentation_rescaler(image=segmentation)["image"] + if self.size is not None: + processed = self.preprocessor(image=image, + mask=segmentation + ) + else: + processed = {"image": image, + "mask": segmentation + } + example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) + segmentation = processed["mask"] + onehot = np.eye(self.n_labels)[segmentation] + example["segmentation"] = onehot + return example + + +class Examples(SegmentationBase): + def __init__(self, size=None, random_crop=False, interpolation="bicubic"): + super().__init__(data_csv="data/sflckr_examples.txt", + data_root="data/sflckr_images", + segmentation_root="data/sflckr_segmentations", + size=size, random_crop=random_crop, interpolation=interpolation) diff --git a/3DTopia/taming/data/utils.py b/3DTopia/taming/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2b3c3d53cd2b6c72b481b59834cf809d3735b394 --- /dev/null +++ b/3DTopia/taming/data/utils.py @@ -0,0 +1,169 @@ +import collections +import os +import tarfile +import urllib +import zipfile +from pathlib import Path + +import numpy as np +import torch +from taming.data.helper_types import Annotation +from torch._six import string_classes +from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format +from tqdm import tqdm + + +def unpack(path): + if path.endswith("tar.gz"): + with tarfile.open(path, "r:gz") as tar: + tar.extractall(path=os.path.split(path)[0]) + elif path.endswith("tar"): + with tarfile.open(path, "r:") as tar: + tar.extractall(path=os.path.split(path)[0]) + elif path.endswith("zip"): + with zipfile.ZipFile(path, "r") as f: + f.extractall(path=os.path.split(path)[0]) + else: + raise NotImplementedError( + "Unknown file extension: {}".format(os.path.splitext(path)[1]) + ) + + +def reporthook(bar): + """tqdm progress bar for downloads.""" + + def hook(b=1, bsize=1, tsize=None): + if tsize is not None: + bar.total = tsize + bar.update(b * bsize - bar.n) + + return hook + + +def get_root(name): + base = "data/" + root = os.path.join(base, name) + os.makedirs(root, exist_ok=True) + return root + + +def is_prepared(root): + return Path(root).joinpath(".ready").exists() + + +def mark_prepared(root): + Path(root).joinpath(".ready").touch() + + +def prompt_download(file_, source, target_dir, content_dir=None): + targetpath = os.path.join(target_dir, file_) + while not os.path.exists(targetpath): + if content_dir is not None and os.path.exists( + os.path.join(target_dir, content_dir) + ): + break + print( + "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath) + ) + if content_dir is not None: + print( + "Or place its content into '{}'.".format( + os.path.join(target_dir, content_dir) + ) + ) + input("Press Enter when done...") + return targetpath + + +def download_url(file_, url, target_dir): + targetpath = os.path.join(target_dir, file_) + os.makedirs(target_dir, exist_ok=True) + with tqdm( + unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ + ) as bar: + urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) + return targetpath + + +def download_urls(urls, target_dir): + paths = dict() + for fname, url in urls.items(): + outpath = download_url(fname, url, target_dir) + paths[fname] = outpath + return paths + + +def quadratic_crop(x, bbox, alpha=1.0): + """bbox is xmin, ymin, xmax, ymax""" + im_h, im_w = x.shape[:2] + bbox = np.array(bbox, dtype=np.float32) + bbox = np.clip(bbox, 0, max(im_h, im_w)) + center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + l = int(alpha * max(w, h)) + l = max(l, 2) + + required_padding = -1 * min( + center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) + ) + required_padding = int(np.ceil(required_padding)) + if required_padding > 0: + padding = [ + [required_padding, required_padding], + [required_padding, required_padding], + ] + padding += [[0, 0]] * (len(x.shape) - 2) + x = np.pad(x, padding, "reflect") + center = center[0] + required_padding, center[1] + required_padding + xmin = int(center[0] - l / 2) + ymin = int(center[1] - l / 2) + return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) + + +def custom_collate(batch): + r"""source: pytorch 1.9.0, only one modification to original code """ + + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, torch.Tensor): + out = None + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + storage = elem.storage()._new_shared(numel) + out = elem.new(storage) + return torch.stack(batch, 0, out=out) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(default_collate_err_msg_format.format(elem.dtype)) + + return custom_collate([torch.as_tensor(b) for b in batch]) + elif elem.shape == (): # scalars + return torch.as_tensor(batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float64) + elif isinstance(elem, int): + return torch.tensor(batch) + elif isinstance(elem, string_classes): + return batch + elif isinstance(elem, collections.abc.Mapping): + return {key: custom_collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple + return elem_type(*(custom_collate(samples) for samples in zip(*batch))) + if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added + return batch # added + elif isinstance(elem, collections.abc.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError('each element in list of batch should be of equal size') + transposed = zip(*batch) + return [custom_collate(samples) for samples in transposed] + + raise TypeError(default_collate_err_msg_format.format(elem_type)) diff --git a/3DTopia/taming/lr_scheduler.py b/3DTopia/taming/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..e598ed120159c53da6820a55ad86b89f5c70c82d --- /dev/null +++ b/3DTopia/taming/lr_scheduler.py @@ -0,0 +1,34 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n): + return self.schedule(n) + diff --git a/3DTopia/taming/models/cond_transformer.py b/3DTopia/taming/models/cond_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c63730fa86ac1b92b37af14c14fb696595b1ab --- /dev/null +++ b/3DTopia/taming/models/cond_transformer.py @@ -0,0 +1,352 @@ +import os, math +import torch +import torch.nn.functional as F +import pytorch_lightning as pl + +from main import instantiate_from_config +from taming.modules.util import SOSProvider + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class Net2NetTransformer(pl.LightningModule): + def __init__(self, + transformer_config, + first_stage_config, + cond_stage_config, + permuter_config=None, + ckpt_path=None, + ignore_keys=[], + first_stage_key="image", + cond_stage_key="depth", + downsample_cond_size=-1, + pkeep=1.0, + sos_token=0, + unconditional=False, + ): + super().__init__() + self.be_unconditional = unconditional + self.sos_token = sos_token + self.first_stage_key = first_stage_key + self.cond_stage_key = cond_stage_key + self.init_first_stage_from_ckpt(first_stage_config) + self.init_cond_stage_from_ckpt(cond_stage_config) + if permuter_config is None: + permuter_config = {"target": "taming.modules.transformer.permuter.Identity"} + self.permuter = instantiate_from_config(config=permuter_config) + self.transformer = instantiate_from_config(config=transformer_config) + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.downsample_cond_size = downsample_cond_size + self.pkeep = pkeep + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + for k in sd.keys(): + for ik in ignore_keys: + if k.startswith(ik): + self.print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def init_first_stage_from_ckpt(self, config): + model = instantiate_from_config(config) + model = model.eval() + model.train = disabled_train + self.first_stage_model = model + + def init_cond_stage_from_ckpt(self, config): + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__" or self.be_unconditional: + print(f"Using no cond stage. Assuming the training is intended to be unconditional. " + f"Prepending {self.sos_token} as a sos token.") + self.be_unconditional = True + self.cond_stage_key = self.first_stage_key + self.cond_stage_model = SOSProvider(self.sos_token) + else: + model = instantiate_from_config(config) + model = model.eval() + model.train = disabled_train + self.cond_stage_model = model + + def forward(self, x, c): + # one step to produce the logits + _, z_indices = self.encode_to_z(x) + _, c_indices = self.encode_to_c(c) + + if self.training and self.pkeep < 1.0: + mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape, + device=z_indices.device)) + mask = mask.round().to(dtype=torch.int64) + r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size) + a_indices = mask*z_indices+(1-mask)*r_indices + else: + a_indices = z_indices + + cz_indices = torch.cat((c_indices, a_indices), dim=1) + + # target includes all sequence elements (no need to handle first one + # differently because we are conditioning) + target = z_indices + # make the prediction + logits, _ = self.transformer(cz_indices[:, :-1]) + # cut off conditioning outputs - output i corresponds to p(z_i | z_{ -1: + c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size)) + quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c) + if len(indices.shape) > 2: + indices = indices.view(c.shape[0], -1) + return quant_c, indices + + @torch.no_grad() + def decode_to_img(self, index, zshape): + index = self.permuter(index, reverse=True) + bhwc = (zshape[0],zshape[2],zshape[3],zshape[1]) + quant_z = self.first_stage_model.quantize.get_codebook_entry( + index.reshape(-1), shape=bhwc) + x = self.first_stage_model.decode(quant_z) + return x + + @torch.no_grad() + def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs): + log = dict() + + N = 4 + if lr_interface: + x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8) + else: + x, c = self.get_xc(batch, N) + x = x.to(device=self.device) + c = c.to(device=self.device) + + quant_z, z_indices = self.encode_to_z(x) + quant_c, c_indices = self.encode_to_c(c) + + # create a "half"" sample + z_start_indices = z_indices[:,:z_indices.shape[1]//2] + index_sample = self.sample(z_start_indices, c_indices, + steps=z_indices.shape[1]-z_start_indices.shape[1], + temperature=temperature if temperature is not None else 1.0, + sample=True, + top_k=top_k if top_k is not None else 100, + callback=callback if callback is not None else lambda k: None) + x_sample = self.decode_to_img(index_sample, quant_z.shape) + + # sample + z_start_indices = z_indices[:, :0] + index_sample = self.sample(z_start_indices, c_indices, + steps=z_indices.shape[1], + temperature=temperature if temperature is not None else 1.0, + sample=True, + top_k=top_k if top_k is not None else 100, + callback=callback if callback is not None else lambda k: None) + x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape) + + # det sample + z_start_indices = z_indices[:, :0] + index_sample = self.sample(z_start_indices, c_indices, + steps=z_indices.shape[1], + sample=False, + callback=callback if callback is not None else lambda k: None) + x_sample_det = self.decode_to_img(index_sample, quant_z.shape) + + # reconstruction + x_rec = self.decode_to_img(z_indices, quant_z.shape) + + log["inputs"] = x + log["reconstructions"] = x_rec + + if self.cond_stage_key in ["objects_bbox", "objects_center_points"]: + figure_size = (x_rec.shape[2], x_rec.shape[3]) + dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"] + label_for_category_no = dataset.get_textual_label_for_category_no + plotter = dataset.conditional_builders[self.cond_stage_key].plot + log["conditioning"] = torch.zeros_like(log["reconstructions"]) + for i in range(quant_c.shape[0]): + log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size) + log["conditioning_rec"] = log["conditioning"] + elif self.cond_stage_key != "image": + cond_rec = self.cond_stage_model.decode(quant_c) + if self.cond_stage_key == "segmentation": + # get image from segmentation mask + num_classes = cond_rec.shape[1] + + c = torch.argmax(c, dim=1, keepdim=True) + c = F.one_hot(c, num_classes=num_classes) + c = c.squeeze(1).permute(0, 3, 1, 2).float() + c = self.cond_stage_model.to_rgb(c) + + cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True) + cond_rec = F.one_hot(cond_rec, num_classes=num_classes) + cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float() + cond_rec = self.cond_stage_model.to_rgb(cond_rec) + log["conditioning_rec"] = cond_rec + log["conditioning"] = c + + log["samples_half"] = x_sample + log["samples_nopix"] = x_sample_nopix + log["samples_det"] = x_sample_det + return log + + def get_input(self, key, batch): + x = batch[key] + if len(x.shape) == 3: + x = x[..., None] + if len(x.shape) == 4: + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) + if x.dtype == torch.double: + x = x.float() + return x + + def get_xc(self, batch, N=None): + x = self.get_input(self.first_stage_key, batch) + c = self.get_input(self.cond_stage_key, batch) + if N is not None: + x = x[:N] + c = c[:N] + return x, c + + def shared_step(self, batch, batch_idx): + x, c = self.get_xc(batch) + logits, target = self(x, c) + loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1)) + return loss + + def training_step(self, batch, batch_idx): + loss = self.shared_step(batch, batch_idx) + self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + return loss + + def validation_step(self, batch, batch_idx): + loss = self.shared_step(batch, batch_idx) + self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + return loss + + def configure_optimizers(self): + """ + Following minGPT: + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear, ) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.transformer.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + + if pn.endswith('bias'): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # special case the position embedding parameter in the root GPT module as not decayed + no_decay.add('pos_emb') + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.transformer.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ + % (str(param_dict.keys() - union_params), ) + + # create the pytorch optimizer object + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95)) + return optimizer diff --git a/3DTopia/taming/models/dummy_cond_stage.py b/3DTopia/taming/models/dummy_cond_stage.py new file mode 100644 index 0000000000000000000000000000000000000000..6e19938078752e09b926a3e749907ee99a258ca0 --- /dev/null +++ b/3DTopia/taming/models/dummy_cond_stage.py @@ -0,0 +1,22 @@ +from torch import Tensor + + +class DummyCondStage: + def __init__(self, conditional_key): + self.conditional_key = conditional_key + self.train = None + + def eval(self): + return self + + @staticmethod + def encode(c: Tensor): + return c, None, (None, None, c) + + @staticmethod + def decode(c: Tensor): + return c + + @staticmethod + def to_rgb(c: Tensor): + return c diff --git a/3DTopia/taming/models/vqgan.py b/3DTopia/taming/models/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..a6950baa5f739111cd64c17235dca8be3a5f8037 --- /dev/null +++ b/3DTopia/taming/models/vqgan.py @@ -0,0 +1,404 @@ +import torch +import torch.nn.functional as F +import pytorch_lightning as pl + +from main import instantiate_from_config + +from taming.modules.diffusionmodules.model import Encoder, Decoder +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer +from taming.modules.vqvae.quantize import GumbelQuantize +from taming.modules.vqvae.quantize import EMAVectorQuantizer + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.image_key = image_key + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input): + quant, diff, _ = self.encode(input) + dec = self.decode(quant) + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) + return x.float() + + def training_step(self, batch, batch_idx, optimizer_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + rec_loss = log_dict_ae["val/rec_loss"] + self.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + self.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQSegmentationModel(VQModel): + def __init__(self, n_labels, *args, **kwargs): + super().__init__(*args, **kwargs) + self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1)) + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + return opt_ae + + def training_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train") + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val") + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + total_loss = log_dict_ae["val/total_loss"] + self.log("val/total_loss", total_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + return aeloss + + @torch.no_grad() + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + # convert logits to indices + xrec = torch.argmax(xrec, dim=1, keepdim=True) + xrec = F.one_hot(xrec, num_classes=x.shape[1]) + xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float() + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + return log + + +class VQNoDiscModel(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None + ): + super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim, + ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key, + colorize_nlabels=colorize_nlabels) + + def training_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train") + output = pl.TrainResult(minimize=aeloss) + output.log("train/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return output + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val") + rec_loss = log_dict_ae["val/rec_loss"] + output = pl.EvalResult(checkpoint_on=rec_loss) + output.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log_dict(log_dict_ae) + + return output + + def configure_optimizers(self): + optimizer = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=self.learning_rate, betas=(0.5, 0.9)) + return optimizer + + +class GumbelVQ(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + temperature_scheduler_config, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + kl_weight=1e-8, + remap=None, + ): + + z_channels = ddconfig["z_channels"] + super().__init__(ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=ignore_keys, + image_key=image_key, + colorize_nlabels=colorize_nlabels, + monitor=monitor, + ) + + self.loss.n_classes = n_embed + self.vocab_size = n_embed + + self.quantize = GumbelQuantize(z_channels, embed_dim, + n_embed=n_embed, + kl_weight=kl_weight, temp_init=1.0, + remap=remap) + + self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def temperature_scheduling(self): + self.quantize.temperature = self.temperature_scheduler(self.global_step) + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode_code(self, code_b): + raise NotImplementedError + + def training_step(self, batch, batch_idx, optimizer_idx): + self.temperature_scheduling() + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + rec_loss = log_dict_ae["val/rec_loss"] + self.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + # encode + h = self.encoder(x) + h = self.quant_conv(h) + quant, _, _ = self.quantize(h) + # decode + x_rec = self.decode(quant) + log["inputs"] = x + log["reconstructions"] = x_rec + return log + + +class EMAVQ(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__(ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=ignore_keys, + image_key=image_key, + colorize_nlabels=colorize_nlabels, + monitor=monitor, + ) + self.quantize = EMAVectorQuantizer(n_embed=n_embed, + embedding_dim=embed_dim, + beta=0.25, + remap=remap) + def configure_optimizers(self): + lr = self.learning_rate + #Remove self.quantize from parameter list since it is updated via EMA + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] \ No newline at end of file diff --git a/3DTopia/taming/modules/diffusionmodules/model.py b/3DTopia/taming/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a5db6aa2ef915e270f1ae135e4a9918fdd884c --- /dev/null +++ b/3DTopia/taming/modules/diffusionmodules/model.py @@ -0,0 +1,776 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, t=None): + #assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x): + #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VUNet(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + in_channels, c_channels, + resolution, z_channels, use_timestep=False, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(c_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + self.z_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=1, + stride=1, + padding=0) + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=2*block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, z): + #assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + z = self.z_in(z) + h = torch.cat((h,z),dim=1) + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + diff --git a/3DTopia/taming/modules/discriminator/model.py b/3DTopia/taming/modules/discriminator/model.py new file mode 100644 index 0000000000000000000000000000000000000000..2aaa3110d0a7bcd05de7eca1e45101589ca5af05 --- /dev/null +++ b/3DTopia/taming/modules/discriminator/model.py @@ -0,0 +1,67 @@ +import functools +import torch.nn as nn + + +from taming.modules.util import ActNorm + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) diff --git a/3DTopia/taming/modules/losses/__init__.py b/3DTopia/taming/modules/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d09caf9eb805f849a517f1b23503e1a4d6ea1ec5 --- /dev/null +++ b/3DTopia/taming/modules/losses/__init__.py @@ -0,0 +1,2 @@ +from taming.modules.losses.vqperceptual import DummyLoss + diff --git a/3DTopia/taming/modules/losses/lpips.py b/3DTopia/taming/modules/losses/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..a7280447694ffc302a7636e7e4d6183408e0aa95 --- /dev/null +++ b/3DTopia/taming/modules/losses/lpips.py @@ -0,0 +1,123 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +import torch +import torch.nn as nn +from torchvision import models +from collections import namedtuple + +from taming.util import get_ckpt_path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") + self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name) + model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2,3],keepdim=keepdim) + diff --git a/3DTopia/taming/modules/losses/segmentation.py b/3DTopia/taming/modules/losses/segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..4ba77deb5159a6307ed2acba9945e4764a4ff0a5 --- /dev/null +++ b/3DTopia/taming/modules/losses/segmentation.py @@ -0,0 +1,22 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class BCELoss(nn.Module): + def forward(self, prediction, target): + loss = F.binary_cross_entropy_with_logits(prediction,target) + return loss, {} + + +class BCELossWithQuant(nn.Module): + def __init__(self, codebook_weight=1.): + super().__init__() + self.codebook_weight = codebook_weight + + def forward(self, qloss, target, prediction, split): + bce_loss = F.binary_cross_entropy_with_logits(prediction,target) + loss = bce_loss + self.codebook_weight*qloss + return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/bce_loss".format(split): bce_loss.detach().mean(), + "{}/quant_loss".format(split): qloss.detach().mean() + } diff --git a/3DTopia/taming/modules/losses/vqperceptual.py b/3DTopia/taming/modules/losses/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..c2febd445728479d4cd9aacdb2572cb1f1af04db --- /dev/null +++ b/3DTopia/taming/modules/losses/vqperceptual.py @@ -0,0 +1,136 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from taming.modules.losses.lpips import LPIPS +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init + + +class DummyLoss(nn.Module): + def __init__(self): + super().__init__() + + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + +class VQLPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge"): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train"): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/3DTopia/taming/modules/misc/coord.py b/3DTopia/taming/modules/misc/coord.py new file mode 100644 index 0000000000000000000000000000000000000000..ee69b0c897b6b382ae673622e420f55e494f5b09 --- /dev/null +++ b/3DTopia/taming/modules/misc/coord.py @@ -0,0 +1,31 @@ +import torch + +class CoordStage(object): + def __init__(self, n_embed, down_factor): + self.n_embed = n_embed + self.down_factor = down_factor + + def eval(self): + return self + + def encode(self, c): + """fake vqmodel interface""" + assert 0.0 <= c.min() and c.max() <= 1.0 + b,ch,h,w = c.shape + assert ch == 1 + + c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, + mode="area") + c = c.clamp(0.0, 1.0) + c = self.n_embed*c + c_quant = c.round() + c_ind = c_quant.to(dtype=torch.long) + + info = None, None, c_ind + return c_quant, None, info + + def decode(self, c): + c = c/self.n_embed + c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, + mode="nearest") + return c diff --git a/3DTopia/taming/modules/transformer/mingpt.py b/3DTopia/taming/modules/transformer/mingpt.py new file mode 100644 index 0000000000000000000000000000000000000000..d14b7b68117f4b9f297b2929397cd4f55089334c --- /dev/null +++ b/3DTopia/taming/modules/transformer/mingpt.py @@ -0,0 +1,415 @@ +""" +taken from: https://github.com/karpathy/minGPT/ +GPT model: +- the initial stem consists of a combination of token encoding and a positional encoding +- the meat of it is a uniform sequence of Transformer blocks + - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block + - all blocks feed into a central residual pathway similar to resnets +- the final decoder is a linear projection into a vanilla Softmax classifier +""" + +import math +import logging + +import torch +import torch.nn as nn +from torch.nn import functional as F +from transformers import top_k_top_p_filtering + +logger = logging.getLogger(__name__) + + +class GPTConfig: + """ base GPT config, params common to all GPT versions """ + embd_pdrop = 0.1 + resid_pdrop = 0.1 + attn_pdrop = 0.1 + + def __init__(self, vocab_size, block_size, **kwargs): + self.vocab_size = vocab_size + self.block_size = block_size + for k,v in kwargs.items(): + setattr(self, k, v) + + +class GPT1Config(GPTConfig): + """ GPT-1 like network roughly 125M params """ + n_layer = 12 + n_head = 12 + n_embd = 768 + + +class CausalSelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. + It is possible to use torch.nn.MultiheadAttention here but I am including an + explicit implementation here to show that there is nothing too scary here. + """ + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(config.n_embd, config.n_embd) + self.query = nn.Linear(config.n_embd, config.n_embd) + self.value = nn.Linear(config.n_embd, config.n_embd) + # regularization + self.attn_drop = nn.Dropout(config.attn_pdrop) + self.resid_drop = nn.Dropout(config.resid_pdrop) + # output projection + self.proj = nn.Linear(config.n_embd, config.n_embd) + # causal mask to ensure that attention is only applied to the left in the input sequence + mask = torch.tril(torch.ones(config.block_size, + config.block_size)) + if hasattr(config, "n_unmasked"): + mask[:config.n_unmasked, :config.n_unmasked] = 1 + self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) + self.n_head = config.n_head + + def forward(self, x, layer_past=None): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + present = torch.stack((k, v)) + if layer_past is not None: + past_key, past_value = layer_past + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + if layer_past is None: + att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) + + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y, present # TODO: check that this does not break anything + + +class Block(nn.Module): + """ an unassuming Transformer block """ + def __init__(self, config): + super().__init__() + self.ln1 = nn.LayerNorm(config.n_embd) + self.ln2 = nn.LayerNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.mlp = nn.Sequential( + nn.Linear(config.n_embd, 4 * config.n_embd), + nn.GELU(), # nice + nn.Linear(4 * config.n_embd, config.n_embd), + nn.Dropout(config.resid_pdrop), + ) + + def forward(self, x, layer_past=None, return_present=False): + # TODO: check that training still works + if return_present: assert not self.training + # layer past: tuple of length two with B, nh, T, hs + attn, present = self.attn(self.ln1(x), layer_past=layer_past) + + x = x + attn + x = x + self.mlp(self.ln2(x)) + if layer_past is not None or return_present: + return x, present + return x + + +class GPT(nn.Module): + """ the full GPT language model, with a context size of block_size """ + def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256, + embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): + super().__init__() + config = GPTConfig(vocab_size=vocab_size, block_size=block_size, + embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, + n_layer=n_layer, n_head=n_head, n_embd=n_embd, + n_unmasked=n_unmasked) + # input embedding stem + self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) + self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) + self.drop = nn.Dropout(config.embd_pdrop) + # transformer + self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) + # decoder head + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.block_size = config.block_size + self.apply(self._init_weights) + self.config = config + logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, idx, embeddings=None, targets=None): + # forward the GPT model + token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector + + if embeddings is not None: # prepend explicit embeddings + token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) + + t = token_embeddings.shape[1] + assert t <= self.block_size, "Cannot forward, model block size is exhausted." + position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector + x = self.drop(token_embeddings + position_embeddings) + x = self.blocks(x) + x = self.ln_f(x) + logits = self.head(x) + + # if we are given some desired targets also calculate the loss + loss = None + if targets is not None: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + + return logits, loss + + def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None): + # inference only + assert not self.training + token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector + if embeddings is not None: # prepend explicit embeddings + token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) + + if past is not None: + assert past_length is not None + past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head + past_shape = list(past.shape) + expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head] + assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}" + position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector + else: + position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :] + + x = self.drop(token_embeddings + position_embeddings) + presents = [] # accumulate over layers + for i, block in enumerate(self.blocks): + x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True) + presents.append(present) + + x = self.ln_f(x) + logits = self.head(x) + # if we are given some desired targets also calculate the loss + loss = None + if targets is not None: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + + return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head + + +class DummyGPT(nn.Module): + # for debugging + def __init__(self, add_value=1): + super().__init__() + self.add_value = add_value + + def forward(self, idx): + return idx + self.add_value, None + + +class CodeGPT(nn.Module): + """Takes in semi-embeddings""" + def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256, + embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): + super().__init__() + config = GPTConfig(vocab_size=vocab_size, block_size=block_size, + embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, + n_layer=n_layer, n_head=n_head, n_embd=n_embd, + n_unmasked=n_unmasked) + # input embedding stem + self.tok_emb = nn.Linear(in_channels, config.n_embd) + self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) + self.drop = nn.Dropout(config.embd_pdrop) + # transformer + self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) + # decoder head + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.block_size = config.block_size + self.apply(self._init_weights) + self.config = config + logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, idx, embeddings=None, targets=None): + # forward the GPT model + token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector + + if embeddings is not None: # prepend explicit embeddings + token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) + + t = token_embeddings.shape[1] + assert t <= self.block_size, "Cannot forward, model block size is exhausted." + position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector + x = self.drop(token_embeddings + position_embeddings) + x = self.blocks(x) + x = self.taming_cinln_f(x) + logits = self.head(x) + + # if we are given some desired targets also calculate the loss + loss = None + if targets is not None: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + + return logits, loss + + + +#### sampling utils + +def top_k_logits(logits, k): + v, ix = torch.topk(logits, k) + out = logits.clone() + out[out < v[:, [-1]]] = -float('Inf') + return out + +@torch.no_grad() +def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): + """ + take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in + the sequence, feeding the predictions back into the model each time. Clearly the sampling + has quadratic complexity unlike an RNN that is only linear, and has a finite context window + of block_size, unlike an RNN that has an infinite context window. + """ + block_size = model.get_block_size() + model.eval() + for k in range(steps): + x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed + logits, _ = model(x_cond) + # pluck the logits at the final step and scale by temperature + logits = logits[:, -1, :] / temperature + # optionally crop probabilities to only the top k options + if top_k is not None: + logits = top_k_logits(logits, top_k) + # apply softmax to convert to probabilities + probs = F.softmax(logits, dim=-1) + # sample from the distribution or take the most likely + if sample: + ix = torch.multinomial(probs, num_samples=1) + else: + _, ix = torch.topk(probs, k=1, dim=-1) + # append to the sequence and continue + x = torch.cat((x, ix), dim=1) + + return x + + +@torch.no_grad() +def sample_with_past(x, model, steps, temperature=1., sample_logits=True, + top_k=None, top_p=None, callback=None): + # x is conditioning + sample = x + cond_len = x.shape[1] + past = None + for n in range(steps): + if callback is not None: + callback(n) + logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1)) + if past is None: + past = [present] + else: + past.append(present) + logits = logits[:, -1, :] / temperature + if top_k is not None: + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + + probs = F.softmax(logits, dim=-1) + if not sample_logits: + _, x = torch.topk(probs, k=1, dim=-1) + else: + x = torch.multinomial(probs, num_samples=1) + # append to the sequence and continue + sample = torch.cat((sample, x), dim=1) + del past + sample = sample[:, cond_len:] # cut conditioning off + return sample + + +#### clustering utils + +class KMeans(nn.Module): + def __init__(self, ncluster=512, nc=3, niter=10): + super().__init__() + self.ncluster = ncluster + self.nc = nc + self.niter = niter + self.shape = (3,32,32) + self.register_buffer("C", torch.zeros(self.ncluster,nc)) + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def is_initialized(self): + return self.initialized.item() == 1 + + @torch.no_grad() + def initialize(self, x): + N, D = x.shape + assert D == self.nc, D + c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random + for i in range(self.niter): + # assign all pixels to the closest codebook element + a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1) + # move each codebook element to be the mean of the pixels that assigned to it + c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)]) + # re-assign any poorly positioned codebook elements + nanix = torch.any(torch.isnan(c), dim=1) + ndead = nanix.sum().item() + print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead)) + c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters + + self.C.copy_(c) + self.initialized.fill_(1) + + + def forward(self, x, reverse=False, shape=None): + if not reverse: + # flatten + bs,c,h,w = x.shape + assert c == self.nc + x = x.reshape(bs,c,h*w,1) + C = self.C.permute(1,0) + C = C.reshape(1,c,1,self.ncluster) + a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices + return a + else: + # flatten + bs, HW = x.shape + """ + c = self.C.reshape( 1, self.nc, 1, self.ncluster) + c = c[bs*[0],:,:,:] + c = c[:,:,HW*[0],:] + x = x.reshape(bs, 1, HW, 1) + x = x[:,3*[0],:,:] + x = torch.gather(c, dim=3, index=x) + """ + x = self.C[x] + x = x.permute(0,2,1) + shape = shape if shape is not None else self.shape + x = x.reshape(bs, *shape) + + return x diff --git a/3DTopia/taming/modules/transformer/permuter.py b/3DTopia/taming/modules/transformer/permuter.py new file mode 100644 index 0000000000000000000000000000000000000000..0d43bb135adde38d94bf18a7e5edaa4523cd95cf --- /dev/null +++ b/3DTopia/taming/modules/transformer/permuter.py @@ -0,0 +1,248 @@ +import torch +import torch.nn as nn +import numpy as np + + +class AbstractPermuter(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + def forward(self, x, reverse=False): + raise NotImplementedError + + +class Identity(AbstractPermuter): + def __init__(self): + super().__init__() + + def forward(self, x, reverse=False): + return x + + +class Subsample(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + C = 1 + indices = np.arange(H*W).reshape(C,H,W) + while min(H, W) > 1: + indices = indices.reshape(C,H//2,2,W//2,2) + indices = indices.transpose(0,2,4,1,3) + indices = indices.reshape(C*4,H//2, W//2) + H = H//2 + W = W//2 + C = C*4 + assert H == W == 1 + idx = torch.tensor(indices.ravel()) + self.register_buffer('forward_shuffle_idx', + nn.Parameter(idx, requires_grad=False)) + self.register_buffer('backward_shuffle_idx', + nn.Parameter(torch.argsort(idx), requires_grad=False)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +def mortonify(i, j): + """(i,j) index to linear morton code""" + i = np.uint64(i) + j = np.uint64(j) + + z = np.uint(0) + + for pos in range(32): + z = (z | + ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) | + ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1)) + ) + return z + + +class ZCurve(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)] + idx = np.argsort(reverseidx) + idx = torch.tensor(idx) + reverseidx = torch.tensor(reverseidx) + self.register_buffer('forward_shuffle_idx', + idx) + self.register_buffer('backward_shuffle_idx', + reverseidx) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +class SpiralOut(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + assert H == W + size = W + indices = np.arange(size*size).reshape(size,size) + + i0 = size//2 + j0 = size//2-1 + + i = i0 + j = j0 + + idx = [indices[i0, j0]] + step_mult = 0 + for c in range(1, size//2+1): + step_mult += 1 + # steps left + for k in range(step_mult): + i = i - 1 + j = j + idx.append(indices[i, j]) + + # step down + for k in range(step_mult): + i = i + j = j + 1 + idx.append(indices[i, j]) + + step_mult += 1 + if c < size//2: + # step right + for k in range(step_mult): + i = i + 1 + j = j + idx.append(indices[i, j]) + + # step up + for k in range(step_mult): + i = i + j = j - 1 + idx.append(indices[i, j]) + else: + # end reached + for k in range(step_mult-1): + i = i + 1 + idx.append(indices[i, j]) + + assert len(idx) == size*size + idx = torch.tensor(idx) + self.register_buffer('forward_shuffle_idx', idx) + self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +class SpiralIn(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + assert H == W + size = W + indices = np.arange(size*size).reshape(size,size) + + i0 = size//2 + j0 = size//2-1 + + i = i0 + j = j0 + + idx = [indices[i0, j0]] + step_mult = 0 + for c in range(1, size//2+1): + step_mult += 1 + # steps left + for k in range(step_mult): + i = i - 1 + j = j + idx.append(indices[i, j]) + + # step down + for k in range(step_mult): + i = i + j = j + 1 + idx.append(indices[i, j]) + + step_mult += 1 + if c < size//2: + # step right + for k in range(step_mult): + i = i + 1 + j = j + idx.append(indices[i, j]) + + # step up + for k in range(step_mult): + i = i + j = j - 1 + idx.append(indices[i, j]) + else: + # end reached + for k in range(step_mult-1): + i = i + 1 + idx.append(indices[i, j]) + + assert len(idx) == size*size + idx = idx[::-1] + idx = torch.tensor(idx) + self.register_buffer('forward_shuffle_idx', idx) + self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +class Random(nn.Module): + def __init__(self, H, W): + super().__init__() + indices = np.random.RandomState(1).permutation(H*W) + idx = torch.tensor(indices.ravel()) + self.register_buffer('forward_shuffle_idx', idx) + self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +class AlternateParsing(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + indices = np.arange(W*H).reshape(H,W) + for i in range(1, H, 2): + indices[i, :] = indices[i, ::-1] + idx = indices.flatten() + assert len(idx) == H*W + idx = torch.tensor(idx) + self.register_buffer('forward_shuffle_idx', idx) + self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +if __name__ == "__main__": + p0 = AlternateParsing(16, 16) + print(p0.forward_shuffle_idx) + print(p0.backward_shuffle_idx) + + x = torch.randint(0, 768, size=(11, 256)) + y = p0(x) + xre = p0(y, reverse=True) + assert torch.equal(x, xre) + + p1 = SpiralOut(2, 2) + print(p1.forward_shuffle_idx) + print(p1.backward_shuffle_idx) diff --git a/3DTopia/taming/modules/util.py b/3DTopia/taming/modules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..9ee16385d8b1342a2d60a5f1aa5cadcfbe934bd8 --- /dev/null +++ b/3DTopia/taming/modules/util.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn + + +def count_params(model): + total_params = sum(p.numel() for p in model.parameters()) + return total_params + + +class ActNorm(nn.Module): + def __init__(self, num_features, logdet=False, affine=True, + allow_reverse_init=False): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:,:,None,None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height*width*torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:,:,None,None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class Labelator(AbstractEncoder): + """Net2Net Interface for Class-Conditional Model""" + def __init__(self, n_classes, quantize_interface=True): + super().__init__() + self.n_classes = n_classes + self.quantize_interface = quantize_interface + + def encode(self, c): + c = c[:,None] + if self.quantize_interface: + return c, None, [None, None, c.long()] + return c + + +class SOSProvider(AbstractEncoder): + # for unconditional training + def __init__(self, sos_token, quantize_interface=True): + super().__init__() + self.sos_token = sos_token + self.quantize_interface = quantize_interface + + def encode(self, x): + # get batch size from data and replicate sos_token + c = torch.ones(x.shape[0], 1)*self.sos_token + c = c.long().to(x.device) + if self.quantize_interface: + return c, None, [None, None, c] + return c diff --git a/3DTopia/taming/modules/vqvae/quantize.py b/3DTopia/taming/modules/vqvae/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..d75544e41fa01bce49dd822b1037963d62f79b51 --- /dev/null +++ b/3DTopia/taming/modules/vqvae/quantize.py @@ -0,0 +1,445 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch import einsum +from einops import rearrange + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. + Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for + # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be + # used wherever VectorQuantizer has been used before and is additionally + # more efficient. + def __init__(self, n_e, e_dim, beta): + super(VectorQuantizer, self).__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + ## could possible replace this here + # #\start... + # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + + min_encodings = torch.zeros( + min_encoding_indices.shape[0], self.n_e).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # dtype min encodings: torch.float32 + # min_encodings shape: torch.Size([2048, 512]) + # min_encoding_indices.shape: torch.Size([2048, 1]) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + #.........\end + + # with: + # .........\start + #min_encoding_indices = torch.argmin(d, dim=1) + #z_q = self.embedding(min_encoding_indices) + # ......\end......... (TODO) + + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + # TODO: check for more easy handling with nn.Embedding + min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) + min_encodings.scatter_(1, indices[:,None], 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: + z_q = z_q.view(shape) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantize(nn.Module): + """ + credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 + https://arxiv.org/abs/1611.01144 + """ + def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, + kl_weight=5e-4, temp_init=1.0, use_vqinterface=True, + remap=None, unknown_index="random"): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_embed = n_embed + + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + + self.use_vqinterface = use_vqinterface + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, return_logits=False): + # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + + logits = self.proj(z) + if self.remap is not None: + # continue only with used logits + full_zeros = torch.zeros_like(logits) + logits = logits[:,self.used,...] + + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + if self.remap is not None: + # go back to all entries but unused set to zero + full_zeros[:,self.used,...] = soft_one_hot + soft_one_hot = full_zeros + z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + + ind = soft_one_hot.argmax(dim=1) + if self.remap is not None: + ind = self.remap_to_used(ind) + if self.use_vqinterface: + if return_logits: + return z_q, diff, (None, None, ind), logits + return z_q, diff, (None, None, ind) + return z_q, diff, ind + + def get_codebook_entry(self, indices, shape): + b, h, w, c = shape + assert b*h*w == indices.shape[0] + indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) + if self.remap is not None: + indices = self.unmap_to_all(indices) + one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() + z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) + return z_q + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", + sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" + assert rescale_logits==False, "Only for interface compatible with Gumbel" + assert return_logits==False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0],-1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super().__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + self.weight = nn.Parameter(weight, requires_grad = False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + #normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + self.weight.data.copy_(embed_normalized) + + +class EMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, + remap=None, unknown_index="random"): + super().__init__() + self.codebook_dim = codebook_dim + self.num_tokens = num_tokens + self.beta = beta + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + #z, 'b c h w -> b h w c' + z = rearrange(z, 'b c h w -> b h w c') + z_flattened = z.reshape(-1, self.codebook_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + if self.training and self.embedding.update: + #EMA cluster size + encodings_sum = encodings.sum(0) + self.embedding.cluster_size_ema_update(encodings_sum) + #EMA embedding average + embed_sum = encodings.transpose(0,1) @ z_flattened + self.embedding.embed_avg_ema_update(embed_sum) + #normalize embed_avg and update weight + self.embedding.weight_update(self.num_tokens) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + #z_q, 'b h w c -> b c h w' + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, (perplexity, encodings, encoding_indices) diff --git a/3DTopia/taming/util.py b/3DTopia/taming/util.py new file mode 100644 index 0000000000000000000000000000000000000000..06053e5defb87977f9ab07e69bf4da12201de9b7 --- /dev/null +++ b/3DTopia/taming/util.py @@ -0,0 +1,157 @@ +import os, hashlib +import requests +from tqdm import tqdm + +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class KeyNotFoundError(Exception): + def __init__(self, cause, keys=None, visited=None): + self.cause = cause + self.keys = keys + self.visited = visited + messages = list() + if keys is not None: + messages.append("Key not found: {}".format(keys)) + if visited is not None: + messages.append("Visited: {}".format(visited)) + messages.append("Cause:\n{}".format(cause)) + message = "\n".join(messages) + super().__init__(message) + + +def retrieve( + list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False +): + """Given a nested list or dict return the desired value at key expanding + callable nodes if necessary and :attr:`expand` is ``True``. The expansion + is done in-place. + + Parameters + ---------- + list_or_dict : list or dict + Possibly nested list or dictionary. + key : str + key/to/value, path like string describing all keys necessary to + consider to get to the desired value. List indices can also be + passed here. + splitval : str + String that defines the delimiter between keys of the + different depth levels in `key`. + default : obj + Value returned if :attr:`key` is not found. + expand : bool + Whether to expand callable nodes on the path or not. + + Returns + ------- + The desired value or if :attr:`default` is not ``None`` and the + :attr:`key` is not found returns ``default``. + + Raises + ------ + Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is + ``None``. + """ + + keys = key.split(splitval) + + success = True + try: + visited = [] + parent = None + last_key = None + for key in keys: + if callable(list_or_dict): + if not expand: + raise KeyNotFoundError( + ValueError( + "Trying to get past callable node with expand=False." + ), + keys=keys, + visited=visited, + ) + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + + last_key = key + parent = list_or_dict + + try: + if isinstance(list_or_dict, dict): + list_or_dict = list_or_dict[key] + else: + list_or_dict = list_or_dict[int(key)] + except (KeyError, IndexError, ValueError) as e: + raise KeyNotFoundError(e, keys=keys, visited=visited) + + visited += [key] + # final expansion of retrieved value + if expand and callable(list_or_dict): + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + except KeyNotFoundError as e: + if default is None: + raise e + else: + list_or_dict = default + success = False + + if not pass_success: + return list_or_dict + else: + return list_or_dict, success + + +if __name__ == "__main__": + config = {"keya": "a", + "keyb": "b", + "keyc": + {"cc1": 1, + "cc2": 2, + } + } + from omegaconf import OmegaConf + config = OmegaConf.create(config) + print(config) + retrieve(config, "keya") + diff --git a/3DTopia/utility/initialize.py b/3DTopia/utility/initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..ac572368414198aef0efd6ab3b1d43f5c22b0331 --- /dev/null +++ b/3DTopia/utility/initialize.py @@ -0,0 +1,13 @@ +import importlib + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + +def instantiate_from_config(config): + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) diff --git a/3DTopia/utility/mcubes_from_latent.py b/3DTopia/utility/mcubes_from_latent.py new file mode 100644 index 0000000000000000000000000000000000000000..bcbad2a16ae1e3e3473f6f2341f2058699540bcb --- /dev/null +++ b/3DTopia/utility/mcubes_from_latent.py @@ -0,0 +1,160 @@ +import os +import torch +import argparse +import mcubes +import trimesh +import numpy as np +from tqdm import tqdm +from omegaconf import OmegaConf +from utility.initialize import instantiate_from_config, get_obj_from_str +from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes + +# load model +parser = argparse.ArgumentParser() +parser.add_argument("--config", type=str, default=None, required=True) +parser.add_argument("--ckpt", type=str, default=None, required=True) +args = parser.parse_args() +configs = OmegaConf.load(args.config) +device = 'cuda' +vae = get_obj_from_str(configs.model.params.first_stage_config['target'])(**configs.model.params.first_stage_config['params']) +vae = vae.to(device) +vae.eval() + +model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(args.ckpt, map_location='cpu', strict=False, **configs.model.params) +model = model.to(device) + +def extract_mesh(triplane_fname, save_name=None): + latent = torch.from_numpy(np.load(triplane_fname)).to(device) + with torch.no_grad(): + with model.ema_scope(): + triplane = model.decode_first_stage(latent) + + # prepare volumn for marching cube + res = 128 + c_list = torch.linspace(-1.2, 1.2, steps=res) + grid_x, grid_y, grid_z = torch.meshgrid( + c_list, c_list, c_list, indexing='ij' + ) + coords = torch.stack([grid_x, grid_y, grid_z], -1).to(device) # 256x256x256x3 + plane_axes = generate_planes() + feats = sample_from_planes( + plane_axes, triplane.reshape(1, 3, -1, 256, 256), coords.reshape(1, -1, 3), padding_mode='zeros', box_warp=2.4 + ) + fake_dirs = torch.zeros_like(coords) + fake_dirs[..., 0] = 1 + with torch.no_grad(): + out = vae.triplane_decoder.decoder(feats, fake_dirs) + u = out['sigma'].reshape(res, res, res).detach().cpu().numpy() + del out + + # marching cube + vertices, triangles = mcubes.marching_cubes(u, 8) + min_bound = np.array([-1.2, -1.2, -1.2]) + max_bound = np.array([1.2, 1.2, 1.2]) + vertices = vertices / (res - 1) * (max_bound - min_bound)[None, :] + min_bound[None, :] + pt_vertices = torch.from_numpy(vertices).to(device) + + # extract vertices color + res_triplane = 256 + # rays_d = torch.from_numpy(-vertices / np.sqrt((vertices ** 2).sum(-1)).reshape(-1, 1)).to(device).unsqueeze(0) + # rays_o = -rays_d * 2.0 + render_kwargs = { + 'depth_resolution': 128, + 'disparity_space_sampling': False, + 'box_warp': 2.4, + 'depth_resolution_importance': 128, + 'clamp_mode': 'softplus', + 'white_back': True, + 'det': True + } + # render_out = vae.triplane_decoder(triplane.reshape(1, 3, -1, res_triplane, res_triplane), rays_o, rays_d, render_kwargs, whole_img=False, tvloss=False) + # rgb = render_out['rgb_marched'].reshape(-1, 3).detach().cpu().numpy() + # rgb = (rgb * 255).astype(np.uint8) + rays_o_list = [ + np.array([0, 0, 2]), + np.array([0, 0, -2]), + np.array([0, 2, 0]), + np.array([0, -2, 0]), + np.array([2, 0, 0]), + np.array([-2, 0, 0]), + ] + rgb_final = None + diff_final = None + for rays_o in tqdm(rays_o_list): + rays_o = torch.from_numpy(rays_o.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device) + rays_d = pt_vertices.reshape(-1, 3) - rays_o + rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1) + dist = torch.norm(pt_vertices.reshape(-1, 3) - rays_o, dim=-1).cpu().numpy().reshape(-1) + + # batch_size = 2**14 + # batch_num = (rays_o.shape[0] // batch_size) + 1 + # rgb_list = [] + # depth_diff_list = [] + # for b in range(batch_num): + # cur_rays_o = rays_o[b * batch_size: (b + 1) * batch_size] + # cur_rays_d = rays_d[b * batch_size: (b + 1) * batch_size] + with torch.no_grad(): + render_out = vae.triplane_decoder(triplane.reshape(1, 3, -1, res_triplane, res_triplane), + rays_o.unsqueeze(0), rays_d.unsqueeze(0), render_kwargs, + whole_img=False, tvloss=False) + rgb = render_out['rgb_marched'].reshape(-1, 3).detach().cpu().numpy() + depth = render_out['depth_final'].reshape(-1).detach().cpu().numpy() + depth_diff = np.abs(dist - depth) + + # rgb_list.append(rgb) + # depth_diff_list.append(depth_diff) + + # del render_out + # torch.cuda.empty_cache() + + # rgb = np.concatenate(rgb_list, 0) + # depth_diff = np.concatenate(depth_diff_list, 0) + + if rgb_final is None: + rgb_final = rgb.copy() + diff_final = depth_diff.copy() + + else: + ind = diff_final > depth_diff + rgb_final[ind] = rgb[ind] + diff_final[ind] = depth_diff[ind] + + + # bgr to rgb + rgb_final = np.stack([ + rgb_final[:, 2], rgb_final[:, 1], rgb_final[:, 0] + ], -1) + + # export to ply + mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=(rgb_final * 255).astype(np.uint8)) + if save_name: + trimesh.exchange.export.export_mesh(mesh, save_name, file_type='ply') + else: + trimesh.exchange.export.export_mesh(mesh, triplane_fname[:-4] + '.ply', file_type='ply') + +# load triplane +# fname = 'log/diff_res32ch8_preprocess_ca_text/sample_mesh_1/sample_16_0.npy' +# u = np.load(fname) +# triplane_fname = 'log/diff_res32ch8_preprocess_ca_text/sample_mesh_1/triplane_16_0.npy' +# folder = 'log/diff_res32ch8_preprocess_ca_text/sample_mesh_opt' +# folder = 'log/diff_res32ch8_preprocess_ca_text/sample_mesh_opt_simple' +folder = '/mnt/lustre/hongfangzhou.p/AE3D/log/diff_res32ch8_preprocess_ca_text_new_triplane_96_full_openaimodel_only_cap3d_high_quality_7w/sample_demo_424_prompts_for_demo_30_60_10' +save_folder = folder + '_extract_mesh' +os.makedirs(save_folder, exist_ok=True) +fnames = [f.replace('_sample', 'triplane').replace('mp4', 'npy') for f in os.listdir(folder) if f.startswith('_')] +prompts = [l.strip() for l in open('test/prompts_for_demo_2.txt', 'r').readlines()][30:60] +# fnames = [os.path.join(folder, f) for f in os.listdir(folder) if (f.startswith('triplane') and f.endswith('.npy'))] +fnames = sorted(fnames) + +def extract_number(s): + return int(s.split('_')[-2]) + +def extract_id(s): + return s.split('_')[-1].split('.')[0] + +for fname in fnames: + try: + print(fname) + extract_mesh(os.path.join(folder, fname), os.path.join(save_folder, prompts[extract_number(fname)].replace(' ', '_') + '_' + extract_id(fname) + '.ply')) + except Exception as e: + print(e) diff --git a/3DTopia/utility/triplane_renderer/eg3d_renderer.py b/3DTopia/utility/triplane_renderer/eg3d_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..796c47071d2a77d709cf81fba12323155f115b79 --- /dev/null +++ b/3DTopia/utility/triplane_renderer/eg3d_renderer.py @@ -0,0 +1,685 @@ +import os +import math +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# TriPlane Utils +class MipRayMarcher2(nn.Module): + def __init__(self): + super().__init__() + + def run_forward(self, colors, densities, depths, rendering_options): + deltas = depths[:, :, 1:] - depths[:, :, :-1] + colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 + densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 + depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 + + + if rendering_options['clamp_mode'] == 'softplus': + densities_mid = F.softplus(densities_mid - 1) # activation bias of -1 makes things initialize better + else: + assert False, "MipRayMarcher only supports `clamp_mode`=`softplus`!" + + density_delta = densities_mid * deltas + + alpha = 1 - torch.exp(-density_delta) + + alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2) + weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] + + composite_rgb = torch.sum(weights * colors_mid, -2) + weight_total = weights.sum(2) + # composite_depth = torch.sum(weights * depths_mid, -2) / weight_total + composite_depth = torch.sum(weights * depths_mid, -2) + + # clip the composite to min/max range of depths + composite_depth = torch.nan_to_num(composite_depth, float('inf')) + # composite_depth = torch.nan_to_num(composite_depth, 0.) + composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths)) + + if rendering_options.get('white_back', False): + composite_rgb = composite_rgb + 1 - weight_total + + composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) + + return composite_rgb, composite_depth, weights + + def forward(self, colors, densities, depths, rendering_options): + composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options) + + return composite_rgb, composite_depth, weights + +def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: + """ + Left-multiplies MxM @ NxM. Returns NxM. + """ + res = torch.matmul(vectors4, matrix.T) + return res + +def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: + """ + Normalize vector lengths. + """ + return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) + +def torch_dot(x: torch.Tensor, y: torch.Tensor): + """ + Dot product of two tensors. + """ + return (x * y).sum(-1) + +def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): + """ + Author: Petr Kellnhofer + Intersects rays with the [-1, 1] NDC volume. + Returns min and max distance of entry. + Returns -1 for no intersection. + https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection + """ + o_shape = rays_o.shape + rays_o = rays_o.detach().reshape(-1, 3) + rays_d = rays_d.detach().reshape(-1, 3) + + + bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)] + bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)] + bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) + is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) + + # Precompute inverse for stability. + invdir = 1 / rays_d + sign = (invdir < 0).long() + + # Intersect with YZ plane. + tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + + # Intersect with XZ plane. + tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tymin) + tmax = torch.min(tmax, tymax) + + # Intersect with XY plane. + tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tzmin) + tmax = torch.min(tmax, tzmax) + + # Mark invalid. + tmin[torch.logical_not(is_valid)] = -1 + tmax[torch.logical_not(is_valid)] = -2 + + return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) + +def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): + """ + Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. + Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. + """ + # create a tensor of 'num' steps from 0 to 1 + steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) + + # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings + # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript + # "cannot statically infer the expected size of a list in this contex", hence the code below + for i in range(start.ndim): + steps = steps.unsqueeze(-1) + + # the output starts at 'start' and increments until 'stop' in each dimension + out = start[None] + steps * (stop - start)[None] + + return out + +def generate_planes(): + """ + Defines planes by the three vectors that form the "axes" of the + plane. Should work with arbitrary number of planes and planes of + arbitrary orientation. + """ + return torch.tensor([[[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], + [[1, 0, 0], + [0, 0, 1], + [0, 1, 0]], + [[0, 0, 1], + [1, 0, 0], + [0, 1, 0]]], dtype=torch.float32) + +def project_onto_planes(planes, coordinates): + """ + Does a projection of a 3D point onto a batch of 2D planes, + returning 2D plane coordinates. + Takes plane axes of shape n_planes, 3, 3 + # Takes coordinates of shape N, M, 3 + # returns projections of shape N*n_planes, M, 2 + """ + + # # ORIGINAL + # N, M, C = coordinates.shape + # xy_coords = coordinates[..., [0, 1]] + # xz_coords = coordinates[..., [0, 2]] + # zx_coords = coordinates[..., [2, 0]] + # return torch.stack([xy_coords, xz_coords, zx_coords], dim=1).reshape(N*3, M, 2) + + # FIXED + N, M, _ = coordinates.shape + xy_coords = coordinates[..., [0, 1]] + yz_coords = coordinates[..., [1, 2]] + zx_coords = coordinates[..., [2, 0]] + return torch.stack([xy_coords, yz_coords, zx_coords], dim=1).reshape(N*3, M, 2) + +def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): + assert padding_mode == 'zeros' + N, n_planes, C, H, W = plane_features.shape + _, M, _ = coordinates.shape + plane_features = plane_features.view(N*n_planes, C, H, W) + + coordinates = (2/box_warp) * coordinates # TODO: add specific box bounds + + projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) + + output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) + return output_features + +def sample_from_3dgrid(grid, coordinates): + """ + Expects coordinates in shape (batch_size, num_points_per_batch, 3) + Expects grid in shape (1, channels, H, W, D) + (Also works if grid has batch size) + Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels) + """ + batch_size, n_coords, n_dims = coordinates.shape + sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1), + coordinates.reshape(batch_size, 1, 1, -1, n_dims), + mode='bilinear', padding_mode='zeros', align_corners=False) + N, C, H, W, D = sampled_features.shape + sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C) + return sampled_features + +class FullyConnectedLayer(nn.Module): + def __init__(self, + in_features, # Number of input features. + out_features, # Number of output features. + bias = True, # Apply additive bias before the activation function? + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 1, # Learning rate multiplier. + bias_init = 0, # Initial value for the additive bias. + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.activation = activation + # self.weight = torch.nn.Parameter(torch.full([out_features, in_features], np.float32(0))) + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) + self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight.to(x.dtype) * self.weight_gain + b = self.bias + if b is not None: + b = b.to(x.dtype) + if self.bias_gain != 1: + b = b * self.bias_gain + + if self.activation == 'linear' and b is not None: + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self): + return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' + + +def positional_encoding(positions, freqs): + freq_bands = (2**torch.arange(freqs).float()).to(positions.device) # (F,) + pts = (positions[..., None] * freq_bands).reshape( + positions.shape[:-1] + (freqs * positions.shape[-1], )) # (..., DF) + pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1) + return pts + +# class TriPlane_Decoder(nn.Module): +# def __init__(self, dim=12, width=128): +# super().__init__() +# self.net = torch.nn.Sequential( +# FullyConnectedLayer(dim, width), +# torch.nn.Softplus(), +# FullyConnectedLayer(width, width), +# torch.nn.Softplus(), +# FullyConnectedLayer(width, 1 + 3) +# ) + +# def forward(self, sampled_features, viewdir): +# sampled_features = sampled_features.mean(1) +# x = sampled_features + +# N, M, C = x.shape +# x = x.view(N*M, C) + +# x = self.net(x) +# x = x.view(N, M, -1) +# rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF +# sigma = x[..., 0:1] +# return {'rgb': rgb, 'sigma': sigma} + +class TriPlane_Decoder(nn.Module): + def __init__(self, dim=12, width=128): + super().__init__() + self.net = torch.nn.Sequential( + FullyConnectedLayer(dim, width), + torch.nn.Softplus(), + FullyConnectedLayer(width, width), + torch.nn.Softplus(), + FullyConnectedLayer(width, width), + torch.nn.Softplus(), + FullyConnectedLayer(width, width), + torch.nn.Softplus(), + FullyConnectedLayer(width, 1 + 3) + ) + + # def forward(self, sampled_features, viewdir): + # #ipdb.set_trace() + # sampled_features = sampled_features.mean(1) + # x = sampled_features + + # N, M, C = x.shape + # x = x.view(N*M, C) + + # x = self.net(x) + # x = x.view(N, M, -1) + # rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + # sigma = x[..., 0:1] + # return {'rgb': rgb, 'sigma': sigma} + def forward(self, sampled_features, viewdir): + M = sampled_features.shape[-2] + batch_size = 256 * 256 + num_batches = M // batch_size + if num_batches * batch_size < M: + num_batches += 1 + res = { + 'rgb': [], + 'sigma': [], + } + + for b in range(num_batches): + p = b * batch_size + b_sampled_features = sampled_features[:, :, p:p+batch_size] + b_res = self._forward(b_sampled_features) + res['rgb'].append(b_res['rgb']) + res['sigma'].append(b_res['sigma']) + res['rgb'] = torch.cat(res['rgb'], -2) + res['sigma'] = torch.cat(res['sigma'], -2) + + return res + + def _forward(self, sampled_features): + # N, _, M, C = sampled_features.shape + sampled_features = sampled_features.mean(1) + x = sampled_features + N, M, C = x.shape + x = x.view(N*M, C) + + x = self.net(x) + x = x.view(N, M, -1) + rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + sigma = x[..., 0:1] + + # assert self.sigma_dim + self.c_dim == C + # sigma_features = sampled_features[..., :self.sigma_dim] + # rgb_features = sampled_features[..., -self.c_dim:] + # sigma_features = sigma_features.permute(0, 2, 1, 3).reshape(N * M, self.sigma_dim * 3) + # rgb_features = rgb_features.permute(0, 2, 1, 3).reshape(N * M, self.c_dim * 3) + + # x = torch.cat([self.sigmanet(sigma_features), self.rgbnet(rgb_features)], -1) + # x = x.view(N, M, -1) + # rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + # sigma = x[..., 0:1] + return {'rgb': rgb, 'sigma': sigma} + +class TriPlane_Decoder_PE(nn.Module): + def __init__(self, dim=12, width=128, viewpe=2, feape=2): + super().__init__() + assert viewpe > 0 and feape > 0 + self.viewpe = viewpe + self.feape = feape + # self.densitynet = torch.nn.Sequential( + # FullyConnectedLayer(dim + 2*feape*dim, width), + # torch.nn.Softplus() + # ) + # self.densityout = FullyConnectedLayer(width, 1) + # self.rgbnet = torch.nn.Sequential( + # FullyConnectedLayer(width + 3 + 2 * viewpe * 3, width), + # torch.nn.Softplus(), + # FullyConnectedLayer(width, 3) + # ) + self.net = torch.nn.Sequential( + FullyConnectedLayer(dim+2*feape*dim+3+2*viewpe*3, width), + torch.nn.Softplus(), + FullyConnectedLayer(width, width), + torch.nn.Softplus(), + FullyConnectedLayer(width, 1 + 3) + ) + + def forward(self, sampled_features,viewdir): + sampled_features = sampled_features.mean(1) + x = sampled_features + N, M, C = x.shape + x = x.view(N*M, C) + viewdir = viewdir.view(N*M, 3) + x_pe = positional_encoding(x, self.feape) + viewdir_pe = positional_encoding(viewdir, self.viewpe) + + x = torch.cat([x, x_pe, viewdir, viewdir_pe], -1) + x = self.net(x) + x = x.view(N, M, -1) + rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + sigma = x[..., 0:1] + # layer1 = self.densitynet(torch.cat([x, x_pe], -1)) + # sigma = self.densityout(layer1).view(N, M, 1) + # rgb = self.rgbnet(torch.cat([layer1, viewdir, viewdir_pe], -1)).view(N, M, -1) + # rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 + return {'rgb': rgb, 'sigma': sigma} + +class TriPlane_Decoder_Decompose(nn.Module): + def __init__(self, sigma_dim=12, c_dim=12, width=128): + super().__init__() + self.rgbnet = torch.nn.Sequential( + FullyConnectedLayer(c_dim * 3, width), + torch.nn.Softplus(), + FullyConnectedLayer(width, width), + torch.nn.Softplus(), + FullyConnectedLayer(width, 3) + ) + self.sigmanet = torch.nn.Sequential( + FullyConnectedLayer(sigma_dim * 3, width), + torch.nn.Softplus(), + FullyConnectedLayer(width, width), + torch.nn.Softplus(), + FullyConnectedLayer(width, 1) + ) + self.sigma_dim = sigma_dim + self.c_dim = c_dim + + def forward(self, sampled_features, viewdir): + M = sampled_features.shape[-2] + batch_size = 256 * 256 + num_batches = M // batch_size + if num_batches * batch_size < M: + num_batches += 1 + res = { + 'rgb': [], + 'sigma': [], + } + + for b in range(num_batches): + p = b * batch_size + b_sampled_features = sampled_features[:, :, p:p+batch_size] + b_res = self._forward(b_sampled_features) + res['rgb'].append(b_res['rgb']) + res['sigma'].append(b_res['sigma']) + res['rgb'] = torch.cat(res['rgb'], -2) + res['sigma'] = torch.cat(res['sigma'], -2) + + return res + + def _forward(self, sampled_features): + N, _, M, C = sampled_features.shape + assert self.sigma_dim + self.c_dim == C + sigma_features = sampled_features[..., :self.sigma_dim] + rgb_features = sampled_features[..., -self.c_dim:] + sigma_features = sigma_features.permute(0, 2, 1, 3).reshape(N * M, self.sigma_dim * 3) + rgb_features = rgb_features.permute(0, 2, 1, 3).reshape(N * M, self.c_dim * 3) + + x = torch.cat([self.sigmanet(sigma_features), self.rgbnet(rgb_features)], -1) + x = x.view(N, M, -1) + rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + sigma = x[..., 0:1] + return {'rgb': rgb, 'sigma': sigma} + +class Renderer_TriPlane(nn.Module): + # def __init__(self, rgbnet_dim=18, rgbnet_width=128, viewpe=0, feape=0): + # super(Renderer_TriPlane, self).__init__() + # if viewpe > 0 or feape > 0: + # self.decoder = TriPlane_Decoder_PE(dim=rgbnet_dim//3, width=rgbnet_width, viewpe=viewpe, feape=feape) + # else: + # self.decoder = TriPlane_Decoder(dim=rgbnet_dim//3, width=rgbnet_width) + # self.ray_marcher = MipRayMarcher2() + # self.plane_axes = generate_planes() + + def __init__(self, rgbnet_dim=18, rgbnet_width=128, viewpe=0, feape=0, sigma_dim=0, c_dim=0): + super(Renderer_TriPlane, self).__init__() + if viewpe > 0 and feape > 0: + self.decoder = TriPlane_Decoder_PE(dim=rgbnet_dim//3, width=rgbnet_width, viewpe=viewpe, feape=feape) + elif sigma_dim > 0 and c_dim > 0: + self.decoder = TriPlane_Decoder_Decompose(sigma_dim=sigma_dim, c_dim=c_dim, width=rgbnet_width) + else: + self.decoder = TriPlane_Decoder(dim=rgbnet_dim, width=rgbnet_width) + self.ray_marcher = MipRayMarcher2() + self.plane_axes = generate_planes() + + def forward(self, planes, ray_origins, ray_directions, rendering_options, whole_img=False, tvloss=False): + self.plane_axes = self.plane_axes.to(ray_origins.device) + + ray_start, ray_end = get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp']) + is_ray_valid = ray_end > ray_start + if torch.any(is_ray_valid).item(): + ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() + ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() + depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'], + rendering_options['det']) + + batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape + + # Coarse Pass + sample_coordinates = (ray_origins.unsqueeze(-2) + depths_coarse * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) + sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3) + + + out = self.run_model(planes, self.decoder, sample_coordinates, sample_directions, rendering_options) + colors_coarse = out['rgb'] + densities_coarse = out['sigma'] + colors_coarse = colors_coarse.reshape(batch_size, num_rays, samples_per_ray, colors_coarse.shape[-1]) + densities_coarse = densities_coarse.reshape(batch_size, num_rays, samples_per_ray, 1) + + # Fine Pass + N_importance = rendering_options['depth_resolution_importance'] + if N_importance > 0: + _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) + + depths_fine = self.sample_importance(depths_coarse, weights, N_importance, rendering_options['det']) + + sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, N_importance, -1).reshape(batch_size, -1, 3) + sample_coordinates = (ray_origins.unsqueeze(-2) + depths_fine * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) + + out = self.run_model(planes, self.decoder, sample_coordinates, sample_directions, rendering_options) + colors_fine = out['rgb'] + densities_fine = out['sigma'] + colors_fine = colors_fine.reshape(batch_size, num_rays, N_importance, colors_fine.shape[-1]) + densities_fine = densities_fine.reshape(batch_size, num_rays, N_importance, 1) + + all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse, + depths_fine, colors_fine, densities_fine) + + # Aggregate + rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options) + else: + rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) + + + if tvloss: + initial_coordinates = torch.rand((batch_size, 1000, 3), device=planes.device) * 2 - 1 + perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * 0.004 + all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) + projected_coordinates = project_onto_planes(self.plane_axes, all_coordinates).unsqueeze(1) + N, n_planes, C, H, W = planes.shape + _, M, _ = all_coordinates.shape + planes = planes.view(N*n_planes, C, H, W) + output_features = torch.nn.functional.grid_sample(planes, projected_coordinates.float(), mode='bilinear', padding_mode='zeros', align_corners=False).permute(0, 3, 2, 1).reshape(batch_size, n_planes, M, C) + sigma = self.decoder(output_features)['sigma'] + sigma_initial = sigma[:, :sigma.shape[1]//2] + sigma_perturbed = sigma[:, sigma.shape[1]//2:] + TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) + else: + TVloss = None + + # return rgb_final, depth_final, weights.sum(2) + if whole_img: + H = W = int(ray_origins.shape[1] ** 0.5) + rgb_final = rgb_final.permute(0, 2, 1).reshape(-1, 3, H, W).contiguous() + depth_final = depth_final.permute(0, 2, 1).reshape(-1, 1, H, W).contiguous() + depth_final = (depth_final - depth_final.min()) / (depth_final.max() - depth_final.min()) + depth_final = depth_final.repeat(1, 3, 1, 1) + # rgb_final = torch.clip(rgb_final, min=0, max=1) + rgb_final = (rgb_final + 1) / 2. + weights = weights.sum(2).reshape(rgb_final.shape[0], rgb_final.shape[2], rgb_final.shape[3]) + return { + 'rgb_marched': rgb_final, + 'depth_final': depth_final, + 'weights': weights, + 'tvloss': TVloss, + } + else: + rgb_final = (rgb_final + 1) / 2. + return { + 'rgb_marched': rgb_final, + 'depth_final': depth_final, + 'tvloss': TVloss, + } + + def run_model(self, planes, decoder, sample_coordinates, sample_directions, options): + sampled_features = sample_from_planes(self.plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp']) + + out = decoder(sampled_features, sample_directions) + if options.get('density_noise', 0) > 0: + out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise'] + return out + + def sort_samples(self, all_depths, all_colors, all_densities): + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + return all_depths, all_colors, all_densities + + def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2): + all_depths = torch.cat([depths1, depths2], dim = -2) + all_colors = torch.cat([colors1, colors2], dim = -2) + all_densities = torch.cat([densities1, densities2], dim = -2) + + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + + return all_depths, all_colors, all_densities + + def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False, det=False): + """ + Return depths of approximately uniformly spaced samples along rays. + """ + N, M, _ = ray_origins.shape + if disparity_space_sampling: + depths_coarse = torch.linspace(0, + 1, + depth_resolution, + device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = 1/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse) + else: + if type(ray_start) == torch.Tensor: + depths_coarse = linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3) + depth_delta = (ray_end - ray_start) / (depth_resolution - 1) + if det: + depths_coarse += 0.5 * depth_delta[..., None] + else: + depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None] + else: + depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = (ray_end - ray_start)/(depth_resolution - 1) + if det: + depths_coarse += 0.5 * depth_delta + else: + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + + return depths_coarse + + def sample_importance(self, z_vals, weights, N_importance, det=False): + """ + Return depths of importance sampled points along rays. See NeRF importance sampling for more. + """ + with torch.no_grad(): + batch_size, num_rays, samples_per_ray, _ = z_vals.shape + + z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) + weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher + + # smooth weights + weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1) + weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze() + weights = weights + 0.01 + + z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) + importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], + N_importance, det=det).detach().reshape(batch_size, num_rays, N_importance, 1) + return importance_z_vals + + def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5): + """ + Sample @N_importance samples from @bins with distribution defined by @weights. + Inputs: + bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" + weights: (N_rays, N_samples_) + N_importance: the number of samples to draw from the distribution + det: deterministic or not + eps: a small number to prevent division by zero + Outputs: + samples: the sampled samples + """ + N_rays, N_samples_ = weights.shape + weights = weights + eps # prevent division by zero (don't do inplace op!) + pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) + cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function + cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1) + # padded to 0~1 inclusive + + if det: + u = torch.linspace(0, 1, N_importance, device=bins.device) + u = u.expand(N_rays, N_importance) + else: + u = torch.rand(N_rays, N_importance, device=bins.device) + u = u.contiguous() + + inds = torch.searchsorted(cdf, u, right=True) + below = torch.clamp_min(inds-1, 0) + above = torch.clamp_max(inds, N_samples_) + + inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance) + cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2) + bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) + + denom = cdf_g[...,1]-cdf_g[...,0] + denom[denom 0.: + # get intervals between samples + mids = .5 * (z_vals[...,1:] + z_vals[...,:-1]) + upper = torch.cat([mids, z_vals[...,-1:]], -1) + lower = torch.cat([z_vals[...,:1], mids], -1) + # stratified samples in those intervals + t_rand = torch.rand(z_vals.shape) + + # Pytest, overwrite u with numpy's fixed random numbers + if pytest: + np.random.seed(0) + t_rand = np.random.rand(*list(z_vals.shape)) + t_rand = torch.Tensor(t_rand) + + z_vals = lower + (upper - lower) * t_rand + + pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3] + + +# raw = run_network(pts) + raw = network_query_fn(pts, viewdirs, label,network_fn) + rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) + + + ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map} + if retraw: + ret['raw'] = raw + if N_importance > 0: + ret['rgb0'] = rgb_map_0 + ret['disp0'] = disp_map_0 + ret['acc0'] = acc_map_0 + ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays] + + for k in ret: + if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()): + print(f"! [Numerical Error] {k} contains nan or inf.") + + return ret + +def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False): + """Transforms model's predictions to semantically meaningful values. + Args: + raw: [num_rays, num_samples along ray, 4]. Prediction from model. + z_vals: [num_rays, num_samples along ray]. Integration time. + rays_d: [num_rays, 3]. Direction of each ray. + Returns: + rgb_map: [num_rays, 3]. Estimated RGB color of a ray. + disp_map: [num_rays]. Disparity map. Inverse of depth map. + acc_map: [num_rays]. Sum of weights along each ray. + weights: [num_rays, num_samples]. Weights assigned to each sampled color. + depth_map: [num_rays]. Estimated distance to object. + """ + #ipdb.set_trace() + act_ff=nn.Softplus() + + raw2alpha = lambda raw, dists, act_fn=act_ff: 1.-torch.exp(-act_fn(raw)*dists) + + dists = z_vals[...,1:] - z_vals[...,:-1] + dists = torch.cat([dists, torch.Tensor([1e10]).to(dists.device).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples] + + dists = dists * torch.norm(rays_d[...,None,:], dim=-1) + + rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3] + noise = 0. + if raw_noise_std > 0.: + noise = torch.randn(raw[...,3].shape) * raw_noise_std + + # Overwrite randomly sampled data if pytest + if pytest: + np.random.seed(0) + noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std + noise = torch.Tensor(noise) + #ipdb.set_trace() + alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples] + + #ipdb.set_trace() + weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0],alpha.shape[1], 1)).to(alpha.device), 1.-alpha + 1e-10], -1), -1)[:,:, :-1] + rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3] + + depth_map = torch.sum(weights * z_vals, -1) + disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1)) + acc_map = torch.sum(weights, -1) + + if white_bkgd: + rgb_map = rgb_map + (1.-acc_map[...,None]) + + return rgb_map, disp_map, acc_map, weights, depth_map + + +def get_rays(H, W, K, c2w): + i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij' + i = i.t() + j = j.t() + dirs = torch.stack([(i-K[0][2])/K[0][0], (j-K[1][2])/K[1][1], torch.ones_like(i)], -1) + # Rotate ray directions from camera frame to the world frame + rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] + # Translate camera frame's origin to the world frame. It is the origin of all rays. + rays_o = c2w[:3,-1].expand(rays_d.shape) + return rays_o, rays_d diff --git a/app.py b/app.py index fde396687d61ef4293c054f93797c08a8999f5b5..38997b1fc168785c5f63bf22b842939aba5de306 100644 --- a/app.py +++ b/app.py @@ -1,43 +1,25 @@ -print("Start importing everything.") - import pytorch_lightning as pl -print("pl") - import os import sys import cv2 import time import json -print("Start importing everything.") import torch -print("Start importing everything.") import mcubes -print("Start importing everything.") import trimesh -print("Start importing everything.") import datetime import argparse import subprocess import numpy as np -print("Start importing everything.") import gradio as gr from tqdm import tqdm -print("Start importing everything.") import imageio.v2 as imageio -print("Start importing everything.") from omegaconf import OmegaConf -print("Start importing everything.") from safetensors.torch import load_file -print("Start importing everything.") from huggingface_hub import hf_hub_download -print("Importing everything done.") - -os.system("git clone https://github.com/3DTopia/3DTopia.git") sys.path.append("3DTopia") -print("Github clone done.") - from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.dpm_solver import DPMSolverSampler