diff --git a/3DTopia/.gitignore b/3DTopia/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..0be8b7cbb5e914722cdeb68149788f8ddeab2820
--- /dev/null
+++ b/3DTopia/.gitignore
@@ -0,0 +1,4 @@
+__pycache__
+checkpoints
+results
+tmp
\ No newline at end of file
diff --git a/3DTopia/LICENSE b/3DTopia/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/3DTopia/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/3DTopia/README.md b/3DTopia/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a9db3069cd3c8da9a29aafffb3d43b7644563b93
--- /dev/null
+++ b/3DTopia/README.md
@@ -0,0 +1,65 @@
+
+
+
+
+
3DTopia
+ A two-stage text-to-3D generation model. The first stage uses diffusion model to quickly generate candidates. The second stage refines the assets chosen from the first stage.
+
+https://github.com/3DTopia/3DTopia/assets/23376858/c9716cf0-6e61-4983-82b2-2e8f579bd46c
+
+
+
+## News
+
+[2024/01/18] We release a text-to-3D model 3DTopia!
+
+## 1. Quick Start
+
+### 1.1 Install Environment for this Repository
+We recommend using Anaconda to manage the environment.
+```bash
+conda env create -f environment.yml
+```
+
+### 1.2 Install Second Stage Refiner
+Please refer to [threefiner](https://github.com/3DTopia/threefiner) to install our second stage mesh refiner. We have tested installing both environments together with Pytorch 1.12.0 and CUDA 11.3.
+
+### 1.3 Download Checkpoints \[Optional\]
+We have implemented automatic checkpoint download for both `gradio_demo.py` and `sample_stage1.py`. If you prefer to download manually, you may download checkpoint `3dtopia_diffusion_state_dict.ckpt` or `model.safetensors` from [huggingface](https://huggingface.co/hongfz16/3DTopia).
+
+### Q&A
+- If you encounter this error in the second stage `ImportError: /lib64/libc.so.6: version 'GLIBC_2.25' not found`, try to install a lower version of pymeshlab by `pip install pymeshlab==0.2`.
+
+## 2. Inference
+
+### 2.1 First Stage
+Run the following command to sample `a robot` as the first stage. Results will be located under the folder `results`.
+```bash
+python -u sample_stage1.py --text "a robot" --samples 1 --sampler ddim --steps 200 --cfg_scale 7.5 --seed 0
+```
+
+Arguments:
+- `--ckpt` specifies checkpoint file path;
+- `--test_folder` controls which subfolder to put all the results;
+- `--seed` will fix random seeds; `--sampler` can be set to `ddim` for DDIM sampling (By default, we use 1000 steps DDPM sampling);
+- `--steps` controls sampling steps only for DDIM;
+- `--samples` controls number of samples;
+- `--text` is the input text;
+- `--no_video` and `--no_mcubes` suppress rendering multi-view videos and marching cubes, which are by-default enabled;
+- `--mcubes_res` controls the resolution of the 3D volumn sampled for marching cubes; One can lower this resolution to save graphics memory;
+- `--render_res` controls the resolution of the rendered video;
+
+### 2.2 Second Stage
+There are two steps as the second stage refinement. Here is a simple example. Please refer to [threefiner](https://github.com/3DTopia/threefiner) for more detailed usage.
+```bash
+# step 1
+threefiner sd --mesh results/default/stage1/a_robot_0_0.ply --prompt "a robot" --text_dir --front_dir='-y' --outdir results/default/stage2/ --save a_robot_0_0_sd.glb
+# step 2
+threefiner if2 --mesh results/default/stage2/a_robot_0_0_sd.glb --prompt "a robot" --outdir results/default/stage2/ --save a_robot_0_0_if2.glb
+```
+The resulting mesh can be found at `results/default/stage2/a_robot_0_0_if2.glb`
+
+## 3. Acknowledgement
+We thank the community for building and open-sourcing the foundation of this work. Specifically, we want to thank [EG3D](https://github.com/NVlabs/eg3d), [Stable Diffusion](https://github.com/CompVis/stable-diffusion) for their codes. We also want to thank [Objaverse](https://objaverse.allenai.org) for the wonderful dataset.
diff --git a/3DTopia/assets/3dtopia.jpeg b/3DTopia/assets/3dtopia.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..b4c1f6b718da9519547a2ce7fc766c7ad29247b9
Binary files /dev/null and b/3DTopia/assets/3dtopia.jpeg differ
diff --git a/3DTopia/assets/sample_data/pose/000000.txt b/3DTopia/assets/sample_data/pose/000000.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f06623177e8b5f19d1fb96b5b0d0441ae6048f4f
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000000.txt
@@ -0,0 +1 @@
+-0.8414713144302368 -0.5386366844177246 -0.04239124804735184 0.05086996778845787 3.72529200376448e-07 -0.07845887541770935 0.9969174861907959 -1.1963008642196655 -0.5403022766113281 0.838877260684967 0.06602128595113754 -0.07922526448965073 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000001.txt b/3DTopia/assets/sample_data/pose/000001.txt
new file mode 100644
index 0000000000000000000000000000000000000000..eb2e88460ba96961a3509c60a3d1bc363cf61731
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000001.txt
@@ -0,0 +1 @@
+-0.9320390224456787 -0.3607495129108429 -0.03410102799534798 0.04092103987932205 -2.0861622829215776e-07 -0.0941082164645195 0.9955618977546692 -1.1946742534637451 -0.3623576760292053 0.9279026389122009 0.08771242946386337 -0.105255126953125 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000002.txt b/3DTopia/assets/sample_data/pose/000002.txt
new file mode 100644
index 0000000000000000000000000000000000000000..70a984ac99d014a045401e25d51e90476bb2468b
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000002.txt
@@ -0,0 +1 @@
+-0.9854495525360107 -0.1689407229423523 -0.0186510868370533 0.022381475195288658 1.5646213569198153e-07 -0.10973447561264038 0.9939608573913574 -1.1927531957626343 -0.1699671447277069 0.9794984459877014 0.10813764482736588 -0.12976518273353577 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000003.txt b/3DTopia/assets/sample_data/pose/000003.txt
new file mode 100644
index 0000000000000000000000000000000000000000..93692ad3e17af9bca69cbe8be8ab2e7b89d93fae
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000003.txt
@@ -0,0 +1 @@
+-0.9995737075805664 0.028960729017853737 0.0036580152809619904 -0.0043916041031479836 -5.648469141306123e-07 -0.12533310055732727 0.9921146631240845 -1.1905378103256226 0.02919083461165428 0.9916918873786926 0.1252794861793518 -0.15033574402332306 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000004.txt b/3DTopia/assets/sample_data/pose/000004.txt
new file mode 100644
index 0000000000000000000000000000000000000000..dd0b6b5c01faf4ac7bb4d8f549f990f61da59387
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000004.txt
@@ -0,0 +1 @@
+-0.973847508430481 0.2249356210231781 0.03201328590512276 -0.03841566666960716 1.5646217832454568e-07 -0.14090147614479065 0.9900236129760742 -1.188028335571289 0.22720229625701904 0.9641320705413818 0.1372164487838745 -0.16465960443019867 -0.0 -0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000005.txt b/3DTopia/assets/sample_data/pose/000005.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e00dfcbb2a62f6f5f93b195e4c7579b7eae30a1d
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000005.txt
@@ -0,0 +1 @@
+-0.9092977046966553 0.4110230505466461 0.06509938091039658 -0.078119657933712 -3.83704957584996e-07 -0.15643461048603058 0.987688422203064 -1.1852262020111084 0.416146457195282 0.8981026411056519 0.14224585890769958 -0.17069458961486816 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000006.txt b/3DTopia/assets/sample_data/pose/000006.txt
new file mode 100644
index 0000000000000000000000000000000000000000..86bbbaef592616de66c5adaaafdad35054f7f257
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000006.txt
@@ -0,0 +1 @@
+-0.8084962368011475 0.5797380208969116 0.1011803075671196 -0.12141657620668411 -2.2351736461700966e-08 -0.17192888259887695 0.9851093292236328 -1.1821314096450806 0.5885012149810791 0.7964572906494141 0.13900375366210938 -0.16680487990379333 -0.0 -0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000007.txt b/3DTopia/assets/sample_data/pose/000007.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f89eba6a7e6ad8112a94d48cbb0c89b1089aad2f
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000007.txt
@@ -0,0 +1 @@
+-0.6754631996154785 0.7243322134017944 0.13817371428012848 -0.1658085733652115 -1.6391271628890536e-07 -0.18738147616386414 0.9822871685028076 -1.1787446737289429 0.7373935580253601 0.6634989380836487 0.1265692412853241 -0.15188303589820862 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000008.txt b/3DTopia/assets/sample_data/pose/000008.txt
new file mode 100644
index 0000000000000000000000000000000000000000..21c17dfd0c3c174320d8f8cdc59cd1ca04f3f5fc
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000008.txt
@@ -0,0 +1 @@
+-0.5155014991760254 0.8390849828720093 0.17376619577407837 -0.20851942896842957 1.2665987014770508e-07 -0.20278730988502502 0.9792228937149048 -1.175067663192749 0.8568887114524841 0.5047908425331116 0.1045369878411293 -0.12544460594654083 -0.0 -0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000009.txt b/3DTopia/assets/sample_data/pose/000009.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ed0c8b01f0cec6488b4dd0d164cf5ef1db14edcd
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000009.txt
@@ -0,0 +1 @@
+-0.3349881172180176 0.91953045129776 0.20553946495056152 -0.24664731323719025 -1.0430810704065152e-07 -0.21814337372779846 0.9759166836738586 -1.1710999011993408 0.9422222375869751 0.3269205689430237 0.07307547330856323 -0.08769046515226364 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000010.txt b/3DTopia/assets/sample_data/pose/000010.txt
new file mode 100644
index 0000000000000000000000000000000000000000..89bb8d731884d8d45904d6cc2c7fb16faecf5516
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000010.txt
@@ -0,0 +1 @@
+-0.1411200612783432 0.9626389741897583 0.23110917210578918 -0.2773309648036957 -1.8998981943241233e-07 -0.2334454208612442 0.9723699688911438 -1.1668438911437988 0.9899925589561462 0.13722087442874908 0.03294399753212929 -0.03953259065747261 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000011.txt b/3DTopia/assets/sample_data/pose/000011.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fd2e8b687234a44c62961db3c50562677acd04ae
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000011.txt
@@ -0,0 +1 @@
+0.05837392061948776 0.9669322967529297 0.24826295673847198 -0.29791900515556335 -1.9557775843281888e-08 -0.2486870288848877 0.9685839414596558 -1.1622999906539917 0.9982947707176208 -0.05654003843665123 -0.014516821131110191 0.01742047443985939 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000012.txt b/3DTopia/assets/sample_data/pose/000012.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ec690a4bf437c985b44db3b268164f33fc4a6152
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000012.txt
@@ -0,0 +1 @@
+0.2555410861968994 0.9325323700904846 0.255111962556839 -0.3061343729496002 -7.450580596923828e-09 -0.26387304067611694 0.964557409286499 -1.1574687957763672 0.9667981863021851 -0.24648404121398926 -0.06743041425943375 0.08091648668050766 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000013.txt b/3DTopia/assets/sample_data/pose/000013.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bdd0f7851359ddbde82bf598055d47fbdcc35225
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000013.txt
@@ -0,0 +1 @@
+0.4425203502178192 0.861151397228241 0.25018739700317383 -0.300225168466568 -2.458690460116486e-07 -0.2789909243583679 0.9602935910224915 -1.1523523330688477 0.8967583179473877 -0.42494940757751465 -0.12345901876688004 0.14815115928649902 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000014.txt b/3DTopia/assets/sample_data/pose/000014.txt
new file mode 100644
index 0000000000000000000000000000000000000000..782685d881a0e2f7f93f64b9f7fd229064cd0ec0
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000014.txt
@@ -0,0 +1 @@
+0.6118577718734741 0.7560015320777893 0.23257626593112946 -0.27909165620803833 2.9802322387695312e-08 -0.29404014348983765 0.955793023109436 -1.146951675415039 0.7909678220748901 -0.584809422492981 -0.1799107939004898 0.21589307487010956 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000015.txt b/3DTopia/assets/sample_data/pose/000015.txt
new file mode 100644
index 0000000000000000000000000000000000000000..acbfd3638de2adc6042114baee3656d9d716c045
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000015.txt
@@ -0,0 +1 @@
+0.756802499294281 0.6216520071029663 0.2019868791103363 -0.2423844039440155 -1.4901159417490817e-08 -0.3090169131755829 0.9510564804077148 -1.1412678956985474 0.6536435484886169 -0.7197620272636414 -0.23386473953723907 0.28063780069351196 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000016.txt b/3DTopia/assets/sample_data/pose/000016.txt
new file mode 100644
index 0000000000000000000000000000000000000000..013a71d628da4714379a1a3461d11c4d7ed897ec
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000016.txt
@@ -0,0 +1 @@
+0.8715758323669434 0.46382859349250793 0.15880392491817474 -0.19056479632854462 -8.940693874137651e-08 -0.3239172697067261 0.9460852146148682 -1.1353023052215576 0.4902608096599579 -0.8245849609375 -0.28231847286224365 0.33878225088119507 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000017.txt b/3DTopia/assets/sample_data/pose/000017.txt
new file mode 100644
index 0000000000000000000000000000000000000000..953c867619ee156d85a30b81f43c46b727c72908
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000017.txt
@@ -0,0 +1 @@
+0.951602041721344 0.2891636788845062 0.10410525649785995 -0.1249263659119606 7.450580596923828e-09 -0.33873775601387024 0.9408808350563049 -1.1290569305419922 0.30733293294906616 -0.8953441381454468 -0.32234352827072144 0.3868124783039093 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000018.txt b/3DTopia/assets/sample_data/pose/000018.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e9b756715528be1e34cd4c39a7e33d9026885e1a
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000018.txt
@@ -0,0 +1 @@
+0.9936910271644592 0.1049124225974083 0.039643093943595886 -0.04757172241806984 -0.0 -0.35347482562065125 0.9354441165924072 -1.1225329637527466 0.11215253174304962 -0.9295423626899719 -0.3512447774410248 0.421493798494339 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000019.txt b/3DTopia/assets/sample_data/pose/000019.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0b9c1a7fda4a3ac75568ab601fa815a06f447ab5
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000019.txt
@@ -0,0 +1 @@
+0.9961645603179932 -0.08135451376438141 -0.032210517674684525 0.038652628660202026 1.862645149230957e-09 -0.3681243658065796 0.9297765493392944 -1.1157318353652954 -0.0874989926815033 -0.9262105226516724 -0.3667125105857849 0.4400551915168762 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000020.txt b/3DTopia/assets/sample_data/pose/000020.txt
new file mode 100644
index 0000000000000000000000000000000000000000..656c47232cb297949c8b24916d4026845e409ea5
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000020.txt
@@ -0,0 +1 @@
+0.9589242935180664 -0.2620696723461151 -0.10855279117822647 0.13026338815689087 1.4901161193847656e-08 -0.38268333673477173 0.9238795638084412 -1.108655333518982 -0.28366219997406006 -0.8859305381774902 -0.36696434020996094 0.4403572976589203 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000021.txt b/3DTopia/assets/sample_data/pose/000021.txt
new file mode 100644
index 0000000000000000000000000000000000000000..39cb359918f5d10abbc0d3362dd9a5ecf67ff103
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000021.txt
@@ -0,0 +1 @@
+0.8834545612335205 -0.4299834668636322 -0.18607036769390106 0.22328442335128784 4.470348002882929e-08 -0.39714762568473816 0.9177546501159668 -1.101305603981018 -0.4685167372226715 -0.8107945919036865 -0.35086193680763245 0.4210345447063446 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000022.txt b/3DTopia/assets/sample_data/pose/000022.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c0094b6672bb6244a2cd51b02c37fc8051b5c90a
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000022.txt
@@ -0,0 +1 @@
+0.7727645039558411 -0.578461229801178 -0.2611852288246155 0.3134223520755768 -2.980232949312267e-08 -0.41151440143585205 0.9114034175872803 -1.093684196472168 -0.6346929669380188 -0.7043001651763916 -0.31800374388694763 0.38160452246665955 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000023.txt b/3DTopia/assets/sample_data/pose/000023.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6a0887ec9ec4f11fb3d6390abaf7783cecc0ea6f
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000023.txt
@@ -0,0 +1 @@
+0.6312665939331055 -0.7017531991004944 -0.33021968603134155 0.3962639272212982 1.4901161193847656e-08 -0.4257790148258209 0.9048272371292114 -1.0857925415039062 -0.775566041469574 -0.5711871385574341 -0.26878002285957336 0.32253631949424744 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000024.txt b/3DTopia/assets/sample_data/pose/000024.txt
new file mode 100644
index 0000000000000000000000000000000000000000..687344b34e28ac322c26e376e7c8393ec4e72365
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000024.txt
@@ -0,0 +1 @@
+0.46460211277008057 -0.7952209711074829 -0.38957479596138 0.467489629983902 -0.0 -0.4399392306804657 0.8980275988578796 -1.0776331424713135 -0.8855195641517639 -0.4172254800796509 -0.20439667999744415 0.2452760487794876 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000025.txt b/3DTopia/assets/sample_data/pose/000025.txt
new file mode 100644
index 0000000000000000000000000000000000000000..388b1e20e526d3a4d0ffc9f326ea1d75c6e7b6c0
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000025.txt
@@ -0,0 +1 @@
+0.279415488243103 -0.8555179834365845 -0.43590813875198364 0.5230898261070251 -7.450580596923828e-09 -0.4539904296398163 0.8910065293312073 -1.069207787513733 -0.960170328617096 -0.24896103143692017 -0.12685194611549377 0.1522223800420761 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000026.txt b/3DTopia/assets/sample_data/pose/000026.txt
new file mode 100644
index 0000000000000000000000000000000000000000..64b583bbe7e213a2b9dd13c20b4cbf27917572d9
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000026.txt
@@ -0,0 +1 @@
+0.08308916538953781 -0.8807101845741272 -0.4663105905056 0.5595741271972656 2.7939665869780583e-07 -0.46792876720428467 0.8837661147117615 -1.060518741607666 -0.9965419769287109 -0.07343138754367828 -0.0388796441257 0.046656012535095215 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000027.txt b/3DTopia/assets/sample_data/pose/000027.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6936a913ed3666bd4f705c4445c2e1692ef56484
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000027.txt
@@ -0,0 +1 @@
+-0.11654932051897049 -0.8703341484069824 -0.4784710705280304 0.5741645693778992 1.1175869474300271e-07 -0.481754332780838 0.8763062953948975 -1.0515679121017456 -0.9931849241256714 0.1021328940987587 0.05614820867776871 -0.06737759709358215 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000028.txt b/3DTopia/assets/sample_data/pose/000028.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d425c87f033bea37fc20adfada19f8afb77895e0
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000028.txt
@@ -0,0 +1 @@
+-0.31154143810272217 -0.8254019618034363 -0.4708009660243988 0.5649611949920654 1.4901161193847656e-08 -0.4954586327075958 0.8686314821243286 -1.0423579216003418 -0.9502326250076294 0.27061471343040466 0.15435591340065002 -0.18522702157497406 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000029.txt b/3DTopia/assets/sample_data/pose/000029.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7886694a8b39dea814cd563fba9620cf1617381a
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000029.txt
@@ -0,0 +1 @@
+-0.494113564491272 -0.7483266592025757 -0.44255948066711426 0.5310712456703186 -1.4901161193847656e-07 -0.5090416669845581 0.860741913318634 -1.0328905582427979 -0.8693974018096924 0.4253043532371521 0.25152426958084106 -0.3018290102481842 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000030.txt b/3DTopia/assets/sample_data/pose/000030.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9fc9f1e48c461ab064a28a6a7da8a9682c5d9dbc
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000030.txt
@@ -0,0 +1 @@
+-0.6569865345954895 -0.6428073644638062 -0.3939129412174225 0.4726954400539398 1.4901162970204496e-08 -0.522498607635498 0.8526401519775391 -1.0231680870056152 -0.7539023160934448 0.5601730942726135 0.3432745933532715 -0.41192948818206787 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000031.txt b/3DTopia/assets/sample_data/pose/000031.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b9be832443b12709ae39f75016318dc78ac5077d
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000031.txt
@@ -0,0 +1 @@
+-0.7936679124832153 -0.5136478543281555 -0.3259711265563965 0.39116519689559937 -2.2351741790771484e-07 -0.5358269214630127 0.8443279266357422 -1.0131936073303223 -0.6083512902259827 0.6701160073280334 0.4252684712409973 -0.5103223323822021 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000032.txt b/3DTopia/assets/sample_data/pose/000032.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d3c77ddb53b7661d1f25778b6f0925a437178a7b
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000032.txt
@@ -0,0 +1 @@
+-0.8987078070640564 -0.36654093861579895 -0.2407723218202591 0.28892695903778076 1.639126594454865e-07 -0.5490229725837708 0.8358070850372314 -1.0029689073562622 -0.43854713439941406 0.7511465549468994 0.49341118335723877 -0.5920935273170471 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000033.txt b/3DTopia/assets/sample_data/pose/000033.txt
new file mode 100644
index 0000000000000000000000000000000000000000..13dba8064732a2969c047bd7eee5438ac99303cd
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000033.txt
@@ -0,0 +1 @@
+-0.9679195880889893 -0.2078111618757248 -0.14122851192951202 0.1694747358560562 -1.8626440123625798e-07 -0.56208336353302 0.827080488204956 -0.9924965500831604 -0.2512587904930115 0.8005476593971252 0.5440514087677002 -0.6528617739677429 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000034.txt b/3DTopia/assets/sample_data/pose/000034.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3d2298ea6b094f1da0aaefb40b2cfdda07b2e08e
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000034.txt
@@ -0,0 +1 @@
+-0.9985430240631104 -0.04414258524775505 -0.031028015539050102 0.03722957894206047 -3.3453090964030707e-06 -0.5750053524971008 0.8181495666503906 -0.9817797541618347 -0.053956516087055206 0.8169578313827515 0.5741673707962036 -0.6890010833740234 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000035.txt b/3DTopia/assets/sample_data/pose/000035.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f2aa76ce87ec4f9eb23655cd4f18d0b2bde3f1c6
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000035.txt
@@ -0,0 +1 @@
+-0.9893580675125122 0.11771200597286224 0.08552265912294388 -0.10262733697891235 -1.043080928297968e-07 -0.5877853035926819 0.8090168237686157 -0.970820426940918 0.14549998939037323 0.8004074692726135 0.5815301537513733 -0.6978363394737244 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000036.txt b/3DTopia/assets/sample_data/pose/000036.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e6dd057f262657b11171b5e9e263897f9b962595
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000036.txt
@@ -0,0 +1 @@
+-0.940730631351471 0.2712166905403137 0.20363546907901764 -0.24436251819133759 1.6391278734317893e-07 -0.6004202365875244 0.7996845841407776 -0.9596214294433594 0.3391546905040741 0.7522878646850586 0.5648337006568909 -0.6778002977371216 -0.0 -0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000037.txt b/3DTopia/assets/sample_data/pose/000037.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e2eb5ccbae9d91c08c51f7896bfb9c434d313b0f
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000037.txt
@@ -0,0 +1 @@
+-0.8545990586280823 0.4103184938430786 0.31827494502067566 -0.3819308578968048 -5.513428504855256e-07 -0.6129069924354553 0.7901549935340881 -0.9481860995292664 0.519288182258606 0.6752656102180481 0.5237899422645569 -0.6285476088523865 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000038.txt b/3DTopia/assets/sample_data/pose/000038.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8c67481b528183a88bda479577977d8a912b3acc
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000038.txt
@@ -0,0 +1 @@
+-0.7343969345092773 0.5296936631202698 0.42436450719833374 -0.5092376470565796 -5.9604616353681195e-08 -0.6252426505088806 0.7804303765296936 -0.936516523361206 0.6787199378013611 0.5731458067893982 0.45917630195617676 -0.5510116219520569 -0.0 -0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000039.txt b/3DTopia/assets/sample_data/pose/000039.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5539121de371664f1ad0987a401926f9542decdc
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000039.txt
@@ -0,0 +1 @@
+-0.5849173665046692 0.6249576807022095 0.5170100331306458 -0.6204122304916382 -5.960463056453591e-08 -0.6374240517616272 0.770513117313385 -0.9246158599853516 0.811092734336853 0.4506864845752716 0.37284043431282043 -0.4474082589149475 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000040.txt b/3DTopia/assets/sample_data/pose/000040.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7340764fc6dc28476c22bd62d2b8b7c130508936
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000040.txt
@@ -0,0 +1 @@
+-0.4121185541152954 0.692828893661499 0.5917317271232605 -0.7100781798362732 -1.341104507446289e-07 -0.6494481563568115 0.7604060173034668 -0.9124871492385864 0.9111302495002747 0.31337738037109375 0.26764971017837524 -0.32117941975593567 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000041.txt b/3DTopia/assets/sample_data/pose/000041.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5b93615c6eed7a5500d1936ae081317ad38e9ac6
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000041.txt
@@ -0,0 +1 @@
+-0.22289007902145386 0.7312408685684204 0.6446757316589355 -0.7736107707023621 -1.1920927533992653e-07 -0.6613120436668396 0.7501108646392822 -0.9001333117485046 0.9748435020446777 0.1671922653913498 0.14739994704723358 -0.1768796592950821 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000042.txt b/3DTopia/assets/sample_data/pose/000042.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fa90ee264a33681b93ab18b934cd7343a2a291d1
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000042.txt
@@ -0,0 +1 @@
+-0.02477543242275715 0.7394039630889893 0.6728058457374573 -0.8073671460151672 -1.7136329688582919e-07 -0.6730126738548279 0.7396309971809387 -0.887557327747345 0.9996929168701172 0.018324699252843857 0.01667424477636814 -0.020009009167551994 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000043.txt b/3DTopia/assets/sample_data/pose/000043.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4dff3f84326c1434af139ddba028c26ef6fa464d
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000043.txt
@@ -0,0 +1 @@
+0.17432667315006256 0.7178065776824951 0.6740651726722717 -0.8088783025741577 -5.215405707303944e-08 -0.6845470666885376 0.7289686799049377 -0.8747623562812805 0.9846878051757812 -0.12707868218421936 -0.11933477967977524 0.14320188760757446 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000044.txt b/3DTopia/assets/sample_data/pose/000044.txt
new file mode 100644
index 0000000000000000000000000000000000000000..676b06b5ef91c0926a8f2f1c8e3104de89e7753d
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000044.txt
@@ -0,0 +1 @@
+0.3664790987968445 0.6681637167930603 0.6474955081939697 -0.7769947052001953 -2.9802322387695312e-08 -0.6959127187728882 0.7181264162063599 -0.8617515563964844 0.9304263591766357 -0.263178288936615 -0.25503745675086975 0.3060450553894043 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000045.txt b/3DTopia/assets/sample_data/pose/000045.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9b3e98714504ae3743baa7e85023c241c9c39d7a
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000045.txt
@@ -0,0 +1 @@
+0.5440210700035095 0.5933132171630859 0.5933132171630859 -0.7119758129119873 1.4901161193847656e-08 -0.7071067690849304 0.7071068286895752 -0.8485281467437744 0.8390715718269348 -0.38468101620674133 -0.38468098640441895 0.46161726117134094 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000046.txt b/3DTopia/assets/sample_data/pose/000046.txt
new file mode 100644
index 0000000000000000000000000000000000000000..de5ee9d3b5e7aa35e8cc57db0cad4d8c9e10a118
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000046.txt
@@ -0,0 +1 @@
+0.699874758720398 0.4970666766166687 0.5129328966140747 -0.6155195832252502 -1.4901161193847656e-08 -0.7181262969970703 0.6959128975868225 -0.8350953459739685 0.7142656445503235 -0.4870518445968628 -0.5025984048843384 0.6031181812286377 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000047.txt b/3DTopia/assets/sample_data/pose/000047.txt
new file mode 100644
index 0000000000000000000000000000000000000000..292398d3996a7b4e9c17c561cde7b08df6d13644
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000047.txt
@@ -0,0 +1 @@
+0.8278263807296753 0.38402023911476135 0.408939927816391 -0.49072784185409546 4.470348358154297e-08 -0.728968620300293 0.6845471262931824 -0.8214565515518188 0.5609843134880066 -0.56668621301651 -0.6034594774246216 0.7241514921188354 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000048.txt b/3DTopia/assets/sample_data/pose/000048.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c7017f7c4bd6ab51b19f5fca5a032ff5ff71618a
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000048.txt
@@ -0,0 +1 @@
+0.9227754473686218 0.2593373954296112 0.28500810265541077 -0.3420097231864929 -0.0 -0.7396311163902283 0.6730124950408936 -0.8076150417327881 0.38533815741539 -0.6210393905639648 -0.682513415813446 0.81901615858078 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000049.txt b/3DTopia/assets/sample_data/pose/000049.txt
new file mode 100644
index 0000000000000000000000000000000000000000..13928ac07d2057760018824acfeecadb15d7413e
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000049.txt
@@ -0,0 +1 @@
+0.9809362888336182 0.1285126805305481 0.14576902985572815 -0.17492283880710602 -7.450581485102248e-09 -0.7501111030578613 0.66131192445755 -0.7935742139816284 0.19432991743087769 -0.6487048268318176 -0.7358112335205078 0.8829733729362488 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000050.txt b/3DTopia/assets/sample_data/pose/000050.txt
new file mode 100644
index 0000000000000000000000000000000000000000..764270d7e9b8b544b87908ce1c470cc55b4e9473
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000050.txt
@@ -0,0 +1 @@
+0.9999901652336121 -0.0028742607682943344 -0.0033653262071311474 0.004038391634821892 -2.3283061589829401e-10 -0.7604058980941772 0.6494479179382324 -0.7793375253677368 -0.00442569749429822 -0.6494415998458862 -0.7603984475135803 0.9124780297279358 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000051.txt b/3DTopia/assets/sample_data/pose/000051.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c5794890908d9dbabb26e1429bfc337a52782633
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000051.txt
@@ -0,0 +1 @@
+0.979177713394165 -0.12940019369125366 -0.15641796588897705 0.1877015084028244 -7.450580596923828e-09 -0.7705132365226746 0.6374240517616272 -0.7649087905883789 -0.2030048966407776 -0.624151349067688 -0.7544693946838379 0.9053632020950317 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000052.txt b/3DTopia/assets/sample_data/pose/000052.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9e4b173f044faee03f774286251bfe3ef1b096b6
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000052.txt
@@ -0,0 +1 @@
+0.9193285703659058 -0.24602729082107544 -0.30709224939346313 0.3685106933116913 -0.0 -0.7804304361343384 0.6252426505088806 -0.7502912282943726 -0.39349088072776794 -0.5748034119606018 -0.7174719572067261 0.8609663844108582 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000053.txt b/3DTopia/assets/sample_data/pose/000053.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4ba7967d04ca6116b6a6de9fdeedaa997f5c4136
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000053.txt
@@ -0,0 +1 @@
+0.8228285908699036 -0.3483087122440338 -0.4490368962287903 0.5388442873954773 -0.0 -0.7901550531387329 0.6129070520401001 -0.7354884147644043 -0.5682896375656128 -0.5043174624443054 -0.6501621007919312 0.7801946997642517 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000054.txt b/3DTopia/assets/sample_data/pose/000054.txt
new file mode 100644
index 0000000000000000000000000000000000000000..18bb9aa82b377f174a1f78a8793725eba9593f80
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000054.txt
@@ -0,0 +1 @@
+0.6935251355171204 -0.4325622320175171 -0.5761187076568604 0.6913425326347351 -1.4901159417490817e-08 -0.7996845245361328 0.6004201769828796 -0.7205043435096741 -0.7204324007034302 -0.4164064824581146 -0.5546013116836548 0.6655217409133911 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000055.txt b/3DTopia/assets/sample_data/pose/000055.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5c795995f90fa82ceaf094217099c6cb0f4417d6
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000055.txt
@@ -0,0 +1 @@
+0.5365728735923767 -0.49600499868392944 -0.6826922297477722 0.8192306160926819 4.470348713425665e-08 -0.8090170621871948 0.5877854228019714 -0.7053423523902893 -0.8438540697097778 -0.3153897225856781 -0.4340965449810028 0.5209159851074219 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000056.txt b/3DTopia/assets/sample_data/pose/000056.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d3d4f889b29e5366d8c01f2a081bd03e1b6c6a09
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000056.txt
@@ -0,0 +1 @@
+0.3582288920879364 -0.5368444323539734 -0.7638519406318665 0.9166225790977478 1.490115550950577e-07 -0.8181496262550354 0.575005292892456 -0.6900063157081604 -0.93363356590271 -0.20598354935646057 -0.2930847704410553 0.3517022430896759 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000057.txt b/3DTopia/assets/sample_data/pose/000057.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ca2c0b4a3adb1181f77b0cdbf518ff7c8cad69a0
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000057.txt
@@ -0,0 +1 @@
+0.1656038910150528 -0.5543226599693298 -0.8156602382659912 0.9787925481796265 7.4505797087454084e-09 -0.8270803093910217 0.5620836615562439 -0.6744999885559082 -0.9861923456192017 -0.0930832177400589 -0.136967733502388 0.1643616110086441 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000058.txt b/3DTopia/assets/sample_data/pose/000058.txt
new file mode 100644
index 0000000000000000000000000000000000000000..557f498858ea589f14c5ed04bdae4cbb81939903
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000058.txt
@@ -0,0 +1 @@
+-0.03362308070063591 -0.5487122535705566 -0.8353347778320312 1.0024018287658691 4.749744064724837e-08 -0.8358075618743896 0.5490226745605469 -0.6588274240493774 -0.9994345307350159 0.01845986768603325 0.028102422133088112 -0.03372287005186081 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000059.txt b/3DTopia/assets/sample_data/pose/000059.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e4ee736ef15c8763363013a8f4bd528fec551aec
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000059.txt
@@ -0,0 +1 @@
+-0.23150981962680817 -0.5212697386741638 -0.8213896751403809 0.9856675863265991 7.450580596923828e-09 -0.8443279266357422 0.5358267426490784 -0.6429921984672546 -0.9728325605392456 0.12404916435480118 0.1954701989889145 -0.23456427454948425 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000060.txt b/3DTopia/assets/sample_data/pose/000060.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5a1a4f0de215b9b5d45d6c2a1b2dc47848914eae
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000060.txt
@@ -0,0 +1 @@
+-0.42016705870628357 -0.474139541387558 -0.7737252116203308 0.9284706711769104 1.4901151246249356e-07 -0.8526401519775391 0.5224984884262085 -0.6269983053207397 -0.9074463844299316 0.2195366770029068 0.35825133323669434 -0.4299015402793884 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000061.txt b/3DTopia/assets/sample_data/pose/000061.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9d93876315c8f72cb7560190146125be50bd9318
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000061.txt
@@ -0,0 +1 @@
+-0.5920736193656921 -0.4102281630039215 -0.6936582326889038 0.8323898315429688 -2.980232594040899e-08 -0.8607421517372131 0.5090413093566895 -0.6108497381210327 -0.8058839440345764 0.3013899028301239 0.5096226930618286 -0.6115471720695496 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000062.txt b/3DTopia/assets/sample_data/pose/000062.txt
new file mode 100644
index 0000000000000000000000000000000000000000..eff3056cc60ea6abf104750d715c890dafd3ebef
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000062.txt
@@ -0,0 +1 @@
+-0.7403759360313416 -0.3330437242984772 -0.5838878750801086 0.7006657123565674 1.043080928297968e-07 -0.8686315417289734 0.4954584836959839 -0.5945504903793335 -0.6721928119659424 0.36682555079460144 0.6431138515472412 -0.7717366218566895 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000063.txt b/3DTopia/assets/sample_data/pose/000063.txt
new file mode 100644
index 0000000000000000000000000000000000000000..dd26b45880fdfe87414665a9f4e62e9c8c63763a
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000063.txt
@@ -0,0 +1 @@
+-0.8591620922088623 -0.2465151995420456 -0.44840940833091736 0.5380914807319641 4.4703490686970326e-08 -0.8763067126274109 0.4817536771297455 -0.5781044363975525 -0.5117037892341614 0.4139043986797333 0.7528894543647766 -0.9034671783447266 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000064.txt b/3DTopia/assets/sample_data/pose/000064.txt
new file mode 100644
index 0000000000000000000000000000000000000000..132633e40f4c478d77778bba5642750d3711fd92
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000064.txt
@@ -0,0 +1 @@
+-0.9436956644058228 -0.15479804575443268 -0.29236292839050293 0.35083532333374023 -1.043081283569336e-07 -0.8837655782699585 0.46792975068092346 -0.5615156292915344 -0.3308148980140686 0.4415833055973053 0.8340057730674744 -1.0008068084716797 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000065.txt b/3DTopia/assets/sample_data/pose/000065.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f09ada342d0fc5da8de63ee7fba3456573a95c2b
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000065.txt
@@ -0,0 +1 @@
+-0.9906071424484253 -0.06207740679383278 -0.12183358520269394 0.146200492978096 8.195635103902532e-08 -0.891006588935852 0.45399045944213867 -0.5447887182235718 -0.13673707842826843 0.44972628355026245 0.8826374411582947 -1.0591652393341064 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000066.txt b/3DTopia/assets/sample_data/pose/000066.txt
new file mode 100644
index 0000000000000000000000000000000000000000..64173ae9e839c86a2a67043f6f5280963d5ac9d7
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000066.txt
@@ -0,0 +1 @@
+-0.9980266094207764 0.027624979615211487 0.0563855841755867 -0.06766645610332489 -1.7657868056630832e-06 -0.8980275988578796 0.4399391710758209 -0.5279269218444824 0.06278911978006363 0.43907099962234497 0.8962554335594177 -1.0755064487457275 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000067.txt b/3DTopia/assets/sample_data/pose/000067.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b1cd5fefa04afc735826d918a8ab7b65a4e1b4c4
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000067.txt
@@ -0,0 +1 @@
+-0.9656579494476318 0.11062507331371307 0.23508931696414948 -0.28210771083831787 -4.023314090773056e-07 -0.9048269987106323 0.4257793426513672 -0.5109351277351379 0.2598170340061188 0.4111570417881012 0.8737534880638123 -1.0485038757324219 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000068.txt b/3DTopia/assets/sample_data/pose/000068.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f860f797e8891165b6cc501ef4a61a21e33144c6
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000068.txt
@@ -0,0 +1 @@
+-0.894791305065155 0.18373513221740723 0.40692755579948425 -0.48831334710121155 -2.682209014892578e-07 -0.9114032983779907 0.41151440143585205 -0.4938172996044159 0.4464847445487976 0.36821937561035156 0.8155156970024109 -0.9786188006401062 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000069.txt b/3DTopia/assets/sample_data/pose/000069.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ec94ae96b22c3a416e27d5794ab8046f8a6d8633
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000069.txt
@@ -0,0 +1 @@
+-0.7882521152496338 0.24438586831092834 0.5647425055503845 -0.6776911020278931 2.9802318834981634e-08 -0.9177546501159668 0.39714789390563965 -0.47657743096351624 0.6153523921966553 0.3130526840686798 0.7234220504760742 -0.8681064248085022 -0.0 -0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000070.txt b/3DTopia/assets/sample_data/pose/000070.txt
new file mode 100644
index 0000000000000000000000000000000000000000..10ca2355685da194750f3114d96886af2ae6e946
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000070.txt
@@ -0,0 +1 @@
+-0.6502879858016968 0.29071998596191406 0.7018599510192871 -0.8422322273254395 -2.9802318834981634e-08 -0.9238795638084412 0.38268351554870605 -0.45922014117240906 0.7596877217292786 0.24885447323322296 0.6007877588272095 -0.7209452986717224 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000071.txt b/3DTopia/assets/sample_data/pose/000071.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1d086455038440042ad1f62d0178bee54bb029e1
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000071.txt
@@ -0,0 +1 @@
+-0.48639875650405884 0.3216439187526703 0.8123796582221985 -0.9748559594154358 -7.450575623124678e-08 -0.9297764897346497 0.36812451481819153 -0.4417493939399719 0.873736560344696 0.1790553480386734 0.4522421360015869 -0.5426904559135437 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000072.txt b/3DTopia/assets/sample_data/pose/000072.txt
new file mode 100644
index 0000000000000000000000000000000000000000..253bdb8edec9e4eb40174ab5526043a7a296708e
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000072.txt
@@ -0,0 +1 @@
+-0.30311834812164307 0.33684486150741577 0.8914337754249573 -1.069720983505249 -9.68574909165909e-08 -0.9354440569877625 0.35347482562065125 -0.4241698682308197 0.95295250415802 0.10714472085237503 0.2835502326488495 -0.3402603268623352 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000073.txt b/3DTopia/assets/sample_data/pose/000073.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1065ec11a970cb433f82159667976f17b869b7ad
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000073.txt
@@ -0,0 +1 @@
+-0.10775362700223923 0.3367651700973511 0.9354028105735779 -1.122483253479004 1.1175870895385742e-08 -0.9408808946609497 0.33873745799064636 -0.40648552775382996 0.9941775798797607 0.03650019317865372 0.10138332843780518 -0.1216600239276886 -0.0 -0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000074.txt b/3DTopia/assets/sample_data/pose/000074.txt
new file mode 100644
index 0000000000000000000000000000000000000000..936daa35e0a7e2920faf188120d95254364af547
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000074.txt
@@ -0,0 +1 @@
+0.09190679341554642 0.3225465416908264 0.9420811533927917 -1.1304973363876343 -7.450580596923828e-09 -0.9460853934288025 0.3239174783229828 -0.3887008726596832 0.9957676529884338 -0.02977021597325802 -0.08695167303085327 0.10434205830097198 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000075.txt b/3DTopia/assets/sample_data/pose/000075.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6459bfbf10f90125316735b8fb2ed5930c17c0bb
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000075.txt
@@ -0,0 +1 @@
+0.2879031300544739 0.2959333658218384 0.910788357257843 -1.0929460525512695 -0.0 -0.9510565400123596 0.3090173006057739 -0.37082037329673767 0.9576596021652222 -0.08896704763174057 -0.27381211519241333 0.3285748362541199 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000076.txt b/3DTopia/assets/sample_data/pose/000076.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e43d0f90f4fdaa1ca35987f4ce7752e58cd5946a
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000076.txt
@@ -0,0 +1 @@
+0.4724216163158417 0.25915926694869995 0.8424096703529358 -1.0108915567398071 -5.960463056453591e-08 -0.9557929039001465 0.2940405011177063 -0.3528483808040619 0.8813725709915161 -0.13891111314296722 -0.45153722167015076 0.5418452024459839 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000077.txt b/3DTopia/assets/sample_data/pose/000077.txt
new file mode 100644
index 0000000000000000000000000000000000000000..662781cae983bf3ae319dbed50dede17db50e6db
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000077.txt
@@ -0,0 +1 @@
+0.6381067037582397 0.2148085981607437 0.7393761277198792 -0.8872514367103577 -1.4901159417490817e-08 -0.9602935910224915 0.27899107336997986 -0.3347893953323364 0.7699478268623352 -0.17802608013153076 -0.6127697825431824 0.7353238463401794 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000078.txt b/3DTopia/assets/sample_data/pose/000078.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8e2d7a807fdfeb4e3023d3a797d0b608bfcd0a9a
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000078.txt
@@ -0,0 +1 @@
+0.7783521413803101 0.1656668782234192 0.6055761575698853 -0.726691484451294 7.450580596923828e-09 -0.964557409286499 0.26387304067611694 -0.31664761900901794 0.6278279423713684 -0.20538613200187683 -0.7507652640342712 0.9009183049201965 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000079.txt b/3DTopia/assets/sample_data/pose/000079.txt
new file mode 100644
index 0000000000000000000000000000000000000000..24ba57c4edabd091941cd2a2fb2fda2b1cf94bcb
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000079.txt
@@ -0,0 +1 @@
+0.8875669240951538 0.1145661398768425 0.4462054967880249 -0.5354465842247009 7.4505797087454084e-09 -0.9685830473899841 0.2486899495124817 -0.29842785000801086 0.4606785774230957 -0.22072899341583252 -0.8596823811531067 1.0316189527511597 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000080.txt b/3DTopia/assets/sample_data/pose/000080.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b74aa513a77caee515d08792c10e1422e2b7c072
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000080.txt
@@ -0,0 +1 @@
+0.9613975286483765 0.0642356276512146 0.26756060123443604 -0.3210725784301758 -0.0 -0.9723699688911438 0.233445405960083 -0.2801344394683838 0.27516335248947144 -0.22443383932113647 -0.9348340034484863 1.1218007802963257 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000081.txt b/3DTopia/assets/sample_data/pose/000081.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e23aaf83ca0b1a6f9823145f9d23729ecc6aa851
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000081.txt
@@ -0,0 +1 @@
+0.9969000816345215 0.017163122072815895 0.07678337395191193 -0.09214004129171371 1.862645371275562e-09 -0.9759168028831482 0.2181432992219925 -0.26177191734313965 0.078678198158741 -0.2174670696258545 -0.9728915691375732 1.16746985912323 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000082.txt b/3DTopia/assets/sample_data/pose/000082.txt
new file mode 100644
index 0000000000000000000000000000000000000000..884d4a666c7ff8885c74f4ff22f123a25e86cdb5
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000082.txt
@@ -0,0 +1 @@
+0.9926593899726868 -0.024525828659534454 -0.1184307411313057 0.14211684465408325 -0.0 -0.9792227745056152 0.20278730988502502 -0.2433447241783142 -0.12094360589981079 -0.20129872858524323 -0.972034752368927 1.1664414405822754 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000083.txt b/3DTopia/assets/sample_data/pose/000083.txt
new file mode 100644
index 0000000000000000000000000000000000000000..22b28a7dc788c0d18d7b8fb27841332257e4a5db
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000083.txt
@@ -0,0 +1 @@
+0.9488444924354553 -0.05916447937488556 -0.310151070356369 0.372181236743927 3.725290742551124e-09 -0.9822872877120972 0.187381312251091 -0.22485758364200592 -0.3157437741756439 -0.17779573798179626 -0.932037889957428 1.1184452772140503 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000084.txt b/3DTopia/assets/sample_data/pose/000084.txt
new file mode 100644
index 0000000000000000000000000000000000000000..328296c9245ecf6dd9f22d0f4807869a6661795c
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000084.txt
@@ -0,0 +1 @@
+0.8672021627426147 -0.08561316877603531 -0.4905413091182709 0.588649570941925 -0.0 -0.9851093292236328 0.17192910611629486 -0.20631492137908936 -0.49795621633529663 -0.1490972936153412 -0.8542889356613159 1.0251468420028687 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000085.txt b/3DTopia/assets/sample_data/pose/000085.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1d753f8a18d6c2b1704e5c8f395d2318c470f194
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000085.txt
@@ -0,0 +1 @@
+0.7509872913360596 -0.10329629480838776 -0.6521871089935303 0.7826245427131653 -0.0 -0.9876883625984192 0.15643447637557983 -0.1877213567495346 -0.6603167057037354 -0.11748029291629791 -0.7417413592338562 0.8900896906852722 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000086.txt b/3DTopia/assets/sample_data/pose/000086.txt
new file mode 100644
index 0000000000000000000000000000000000000000..dcfa47155ee42f92e895cff4f456a09cb13b9a5c
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000086.txt
@@ -0,0 +1 @@
+0.6048324704170227 -0.11220713704824448 -0.7884079813957214 0.946089506149292 4.470348002882929e-08 -0.9900237321853638 0.14090131223201752 -0.16908152401447296 -0.7963526844978333 -0.08522169291973114 -0.5987984538078308 0.718558669090271 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000087.txt b/3DTopia/assets/sample_data/pose/000087.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e3874f74764b3d722fc17ec2887449e02ed64866
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000087.txt
@@ -0,0 +1 @@
+0.43456563353538513 -0.1128801479935646 -0.8935384154319763 1.0722460746765137 3.725290298461914e-09 -0.9921147227287292 0.12533323466777802 -0.15039989352226257 -0.9006401896476746 -0.05446551740169525 -0.43113893270492554 0.5173667073249817 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000088.txt b/3DTopia/assets/sample_data/pose/000088.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7f91baf2142506bb58bb4adc9b969d367b20358e
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000088.txt
@@ -0,0 +1 @@
+0.24697357416152954 -0.10633499920368195 -0.9631701111793518 1.15580415725708 1.8626447051417472e-09 -0.9939608573913574 0.10973432660102844 -0.1316811591386795 -0.9690220952033997 -0.02710147760808468 -0.24548210203647614 0.29457858204841614 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000089.txt b/3DTopia/assets/sample_data/pose/000089.txt
new file mode 100644
index 0000000000000000000000000000000000000000..01730422d0cebff267fe7caaa136318f455b21e4
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000089.txt
@@ -0,0 +1 @@
+0.049535539001226425 -0.09399279206991196 -0.9943397641181946 1.1932077407836914 6.51925802230835e-09 -0.9955620169639587 0.09410832822322845 -0.11292997747659683 -0.9987723231315613 -0.0046617123298347 -0.04931569844484329 0.05917895957827568 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000090.txt b/3DTopia/assets/sample_data/pose/000090.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b60a7c97118e2848dcfe298ebfd1c23b7bcade64
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000090.txt
@@ -0,0 +1 @@
+-0.14987720549106598 -0.07757285982370377 -0.98565673828125 1.1827882528305054 -0.0 -0.9969173669815063 0.07845908403396606 -0.09415092319250107 -0.9887046217918396 0.011759229004383087 0.14941516518592834 -0.1792982518672943 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000091.txt b/3DTopia/assets/sample_data/pose/000091.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a14b8cf84c53f5acc76d1b80f245aadf6c5dad75
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000091.txt
@@ -0,0 +1 @@
+-0.34331491589546204 -0.058974120765924454 -0.937366783618927 1.1248406171798706 3.7252885221050747e-09 -0.9980267882347107 0.06279051303863525 -0.0753486305475235 -0.939220130443573 0.021556934341788292 0.34263747930526733 -0.4111650288105011 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000092.txt b/3DTopia/assets/sample_data/pose/000092.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2a1e764a5cd4a1ce03586527814eeeafc2d91e99
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000092.txt
@@ -0,0 +1 @@
+-0.5230658054351807 -0.04014846310019493 -0.8513460755348206 1.0216155052185059 3.7252898543727042e-09 -0.9988899230957031 0.04710644856095314 -0.05652773752808571 -0.8522922396659851 0.02463977038860321 0.5224851369857788 -0.6269820332527161 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000093.txt b/3DTopia/assets/sample_data/pose/000093.txt
new file mode 100644
index 0000000000000000000000000000000000000000..953d8bb0c4d870bdba217954835590a2acfb9e2f
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000093.txt
@@ -0,0 +1 @@
+-0.6819634437561035 -0.022973379120230675 -0.7310251593589783 0.8772302269935608 -9.313223969797946e-09 -0.9995065331459045 0.03141074627637863 -0.03769290819764137 -0.7313860654830933 0.02142098918557167 0.6816269159317017 -0.8179523944854736 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000094.txt b/3DTopia/assets/sample_data/pose/000094.txt
new file mode 100644
index 0000000000000000000000000000000000000000..077a7907fa1c1cc7701873b81a59f27a86591ddd
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000094.txt
@@ -0,0 +1 @@
+-0.8136734962463379 -0.009131004102528095 -0.5812501311302185 0.6975001096725464 -4.190950253502024e-09 -0.9998766779899597 0.015707319602370262 -0.018848778679966927 -0.5813218355178833 0.012780634686350822 0.8135731816291809 -0.9762880206108093 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000095.txt b/3DTopia/assets/sample_data/pose/000095.txt
new file mode 100644
index 0000000000000000000000000000000000000000..481ab52b37e4e23f1c5de5c779c81f73e477e395
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000095.txt
@@ -0,0 +1 @@
+-0.9129449129104614 -1.2935611884759817e-15 -0.4080818295478821 0.4896984398365021 4.2351617070456256e-22 -1.0 3.169856057190978e-15 -3.8038288779917735e-15 -0.4080818295478821 2.8939047931950254e-15 0.9129449129104614 -1.0955342054367065 0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000096.txt b/3DTopia/assets/sample_data/pose/000096.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7774578fccfa2776f29df14063d1b9806ee9a20c
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000096.txt
@@ -0,0 +1 @@
+-0.9758200645446777 0.003433206817135215 -0.21854621171951294 0.2622557282447815 -8.381896066111949e-09 -0.9998766183853149 -0.015707315877079964 0.018848782405257225 -0.2185731828212738 -0.015327518805861473 0.9756997227668762 -1.1708403825759888 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000097.txt b/3DTopia/assets/sample_data/pose/000097.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3e6fdbaa06f2f98f1f12b4a4cae0e24ad2a52202
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000097.txt
@@ -0,0 +1 @@
+-0.9997926354408264 0.0006389844347722828 -0.020337846130132675 0.024408958852291107 1.5809190756499447e-07 -0.9995064735412598 -0.03141075372695923 0.037692904472351074 -0.02034788206219673 -0.03140425682067871 0.9992992877960205 -1.1991593837738037 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000098.txt b/3DTopia/assets/sample_data/pose/000098.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c760319b66cb65b2e30cd921736ef9cc5749cf4c
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000098.txt
@@ -0,0 +1 @@
+-0.9839062690734863 -0.00841712299734354 0.17848443984985352 -0.2141815721988678 1.3038505386475663e-08 -0.9988898634910583 -0.04710642620921135 0.05652773752808571 0.1786828190088272 -0.04634832963347435 0.9828140139579773 -1.1793771982192993 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000099.txt b/3DTopia/assets/sample_data/pose/000099.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2d49a8f5085e4de83c1a31dacd59e0f782855080
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000099.txt
@@ -0,0 +1 @@
+-0.9287950992584229 -0.023269735276699066 0.36986207962036133 -0.4438343644142151 -1.4901159417490817e-08 -0.9980266094207764 -0.06279050558805466 0.0753486156463623 0.3705933690071106 -0.05831952393054962 0.9269623160362244 -1.1123547554016113 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000100.txt b/3DTopia/assets/sample_data/pose/000100.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f9904c7007ec8ce29c9d0ef9d4b20141ed8e3a74
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000100.txt
@@ -0,0 +1 @@
+-0.8366557955741882 -0.04297434538602829 0.5460407733917236 -0.6552488803863525 -3.725291186640334e-09 -0.9969173669815063 -0.07845912128686905 0.09415092319250107 0.5477291941642761 -0.06564325839281082 0.8340766429901123 -1.000891923904419 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000101.txt b/3DTopia/assets/sample_data/pose/000101.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7e85d82ccedb21034bc93d929d56b5b0f9cc4698
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000101.txt
@@ -0,0 +1 @@
+-0.7111608982086182 -0.06616085022687912 0.6999088525772095 -0.839890718460083 -1.4901154088420299e-08 -0.9955620169639587 -0.09410828351974487 0.11292998492717743 0.7030289173126221 -0.06692616641521454 0.7080047130584717 -0.8496062159538269 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000102.txt b/3DTopia/assets/sample_data/pose/000102.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f52b5d44b44d916edc72ac097a931fb20b69aab8
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000102.txt
@@ -0,0 +1 @@
+-0.557314932346344 -0.09111247956752777 0.8252866864204407 -0.9903444051742554 7.450577044210149e-09 -0.993960976600647 -0.10973427444696426 0.1316811740398407 0.8303009271621704 -0.06115657836198807 0.5539493560791016 -0.6647393703460693 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000103.txt b/3DTopia/assets/sample_data/pose/000103.txt
new file mode 100644
index 0000000000000000000000000000000000000000..628a79aa312adfdc54645db2f3186e6876244f3c
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000103.txt
@@ -0,0 +1 @@
+-0.3812505602836609 -0.115867018699646 0.9171820282936096 -1.1006184816360474 -3.725290298461914e-09 -0.9921146631240845 -0.12533321976661682 0.15039989352226257 0.9244717359542847 -0.04778335988521576 0.3782442808151245 -0.4538930654525757 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000104.txt b/3DTopia/assets/sample_data/pose/000104.txt
new file mode 100644
index 0000000000000000000000000000000000000000..511f249cc3bb9f7cd2c1952ddffc165697aae457
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000104.txt
@@ -0,0 +1 @@
+-0.1899867206811905 -0.13833492994308472 0.971992015838623 -1.1663905382156372 7.450580596923828e-09 -0.990023672580719 -0.14090122282505035 0.16908153891563416 0.9817867279052734 -0.026769354939460754 0.18809135258197784 -0.22570960223674774 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000105.txt b/3DTopia/assets/sample_data/pose/000105.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7bf130e319aebed4e688318b892f049659653bb3
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000105.txt
@@ -0,0 +1 @@
+0.008851335383951664 -0.15642881393432617 0.9876496195793152 -1.185179591178894 1.7462300494486271e-09 -0.9876883029937744 -0.15643493831157684 0.1877213567495346 0.9999608993530273 0.0013846629299223423 -0.008742359466850758 0.01049080304801464 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000106.txt b/3DTopia/assets/sample_data/pose/000106.txt
new file mode 100644
index 0000000000000000000000000000000000000000..dd2446d7e26c97486770c44c0c16aac0b33be82f
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000106.txt
@@ -0,0 +1 @@
+0.20733638107776642 -0.16819317638874054 0.9637025594711304 -1.1564432382583618 -7.450580596923828e-09 -0.9851093292236328 -0.171929270029068 0.20631493628025055 0.9782696962356567 0.035647179931402206 -0.20424900949001312 0.24509884417057037 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000107.txt b/3DTopia/assets/sample_data/pose/000107.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c63eef9320d8658f8550ed8aba13409f58845b0c
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000107.txt
@@ -0,0 +1 @@
+0.39755570888519287 -0.1719369888305664 0.9013251662254333 -1.081590175628662 -2.2351741790771484e-08 -0.9822872877120972 -0.1873813271522522 0.22485756874084473 0.9175780415534973 0.07449448853731155 -0.39051389694213867 0.4686166048049927 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000108.txt b/3DTopia/assets/sample_data/pose/000108.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d18b02f5b68c2656ccca9e3fb434e603868fe18e
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000108.txt
@@ -0,0 +1 @@
+0.5719255805015564 -0.16634754836559296 0.8032617568969727 -0.9639140367507935 2.9802318834981634e-08 -0.9792227745056152 -0.20278732478618622 0.2433447390794754 0.8203054070472717 0.11597928404808044 -0.5600425601005554 0.672051191329956 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000109.txt b/3DTopia/assets/sample_data/pose/000109.txt
new file mode 100644
index 0000000000000000000000000000000000000000..166a0b7be54186c836c494541344aa9fa9fd3c31
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000109.txt
@@ -0,0 +1 @@
+0.7234946489334106 -0.1505908966064453 0.6737046241760254 -0.8084455132484436 -7.450580596923828e-09 -0.9759168028831482 -0.21814334392547607 0.26177188754081726 0.6903300285339355 0.1578255444765091 -0.7060705423355103 0.8472847938537598 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000110.txt b/3DTopia/assets/sample_data/pose/000110.txt
new file mode 100644
index 0000000000000000000000000000000000000000..70e94906e0987d8359ab076191051939d4a136df
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000110.txt
@@ -0,0 +1 @@
+0.8462203741073608 -0.12438744306564331 0.5181108713150024 -0.6217329502105713 -0.0 -0.972369909286499 -0.2334454357624054 0.2801344394683838 0.5328330993652344 0.19754627346992493 -0.8228392004966736 0.9874071478843689 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000111.txt b/3DTopia/assets/sample_data/pose/000111.txt
new file mode 100644
index 0000000000000000000000000000000000000000..09e241678a4175b8286de5dbde5176e05649dbdf
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000111.txt
@@ -0,0 +1 @@
+0.9352098703384399 -0.08805953711271286 0.34296926856040955 -0.41156312823295593 -0.0 -0.9685831069946289 -0.24868986010551453 0.2984278202056885 0.3540937900543213 0.23257721960544586 -0.9058284759521484 1.0869944095611572 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000112.txt b/3DTopia/assets/sample_data/pose/000112.txt
new file mode 100644
index 0000000000000000000000000000000000000000..890cbd2954f25f2baa3d4592b868bbbb20b918b4
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000112.txt
@@ -0,0 +1 @@
+0.9869155883789062 -0.04254636913537979 0.15552331507205963 -0.18662789463996887 -7.450581485102248e-09 -0.9645574688911438 -0.26387304067611694 0.31664755940437317 0.16123799979686737 0.26042044162750244 -0.9519367218017578 1.1423239707946777 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000113.txt b/3DTopia/assets/sample_data/pose/000113.txt
new file mode 100644
index 0000000000000000000000000000000000000000..033ec4314be76c46782822ce4e48fdca921af492
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000113.txt
@@ -0,0 +1 @@
+0.9992760419845581 0.010614471510052681 -0.036535248160362244 0.04384230822324753 9.313225746154785e-10 -0.960293710231781 -0.27899110317230225 0.3347893953323364 -0.038045912981033325 0.27878910303115845 -0.9595984220504761 1.1515181064605713 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000114.txt b/3DTopia/assets/sample_data/pose/000114.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2f45dbf2b8a907c1cab96db926977bed501de707
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000114.txt
@@ -0,0 +1 @@
+0.9717984795570374 0.06933855265378952 -0.2253885120153427 0.2704661190509796 -1.4901162970204496e-08 -0.9557930827140808 -0.29404035210609436 0.3528483510017395 -0.23581309616565704 0.2857479453086853 -0.9288381934165955 1.1146059036254883 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000115.txt b/3DTopia/assets/sample_data/pose/000115.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9ec438c94b2f7224f4b43e7e7fc095543c906619
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000115.txt
@@ -0,0 +1 @@
+0.9055783748626709 0.1310785412788391 -0.40341824293136597 0.4841018319129944 1.4901162970204496e-08 -0.9510565996170044 -0.30901703238487244 0.37082037329673767 -0.4241790473461151 0.2798391580581665 -0.8612562417984009 1.0335074663162231 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000116.txt b/3DTopia/assets/sample_data/pose/000116.txt
new file mode 100644
index 0000000000000000000000000000000000000000..635a38de7120930fa33c9ce5644faae62e131de3
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000116.txt
@@ -0,0 +1 @@
+0.8032556176185608 0.19293642044067383 -0.5635209679603577 0.6762250065803528 -0.0 -0.9460852742195129 -0.3239175081253052 0.3887009024620056 -0.595634400844574 0.2601885497570038 -0.7599483132362366 0.9119382500648499 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000117.txt b/3DTopia/assets/sample_data/pose/000117.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6b97f76a4e73f094109721b801d6143752ec7a3c
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000117.txt
@@ -0,0 +1 @@
+0.6689097285270691 0.2517986595630646 -0.6993976831436157 0.8392772674560547 -2.9802322387695312e-08 -0.9408807754516602 -0.3387379050254822 0.40648549795150757 -0.7433436512947083 0.22658511996269226 -0.6293643712997437 0.7552372217178345 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000118.txt b/3DTopia/assets/sample_data/pose/000118.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a0059c94ccb540c18eaff754b8aa9993f6aa77dd
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000118.txt
@@ -0,0 +1 @@
+0.5078965425491333 0.30448970198631287 -0.805808424949646 0.966969907283783 -2.9802322387695312e-08 -0.9354440569877625 -0.3534749150276184 0.4241698086261749 -0.8614181280136108 0.17952869832515717 -0.4751087725162506 0.5701305866241455 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000119.txt b/3DTopia/assets/sample_data/pose/000119.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8155cd49cc7d2f1b840ec963913fdecf443e2476
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000119.txt
@@ -0,0 +1 @@
+0.32663485407829285 0.34793341159820557 -0.8787785172462463 1.0545341968536377 -4.4703469370688254e-08 -0.9297763109207153 -0.3681248426437378 0.4417494237422943 -0.9451503753662109 0.12024238705635071 -0.3036973476409912 0.3644372224807739 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000120.txt b/3DTopia/assets/sample_data/pose/000120.txt
new file mode 100644
index 0000000000000000000000000000000000000000..453e91e44313cd67b3db21655d9c28889113920c
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000120.txt
@@ -0,0 +1 @@
+0.1323518306016922 0.3793167471885681 -0.9157520532608032 1.0989023447036743 1.4901161193847656e-08 -0.9238796234130859 -0.38268330693244934 0.4592200517654419 -0.9912027716636658 0.05064881592988968 -0.1222771555185318 0.14673250913619995 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000121.txt b/3DTopia/assets/sample_data/pose/000121.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b8c0d781bee51b0bedecad3a4d3a7f1bb3696f5c
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000121.txt
@@ -0,0 +1 @@
+-0.06720828264951706 0.3962496519088745 -0.9156795144081116 1.0988155603408813 -3.166496043149891e-08 -0.9177546501159668 -0.39714762568473816 0.476577490568161 -0.9977388381958008 -0.0266916174441576 0.06168072298169136 -0.07401663064956665 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000122.txt b/3DTopia/assets/sample_data/pose/000122.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c4264455ebecff57680a44302b389b3e0f933d0a
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000122.txt
@@ -0,0 +1 @@
+-0.2640887498855591 0.39690470695495605 -0.8790467977523804 1.0548564195632935 -1.0430807861894209e-07 -0.9114033579826355 -0.4115141034126282 0.4938172996044159 -0.9644981622695923 -0.1086762398481369 0.2406913787126541 -0.28882941603660583 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000123.txt b/3DTopia/assets/sample_data/pose/000123.txt
new file mode 100644
index 0000000000000000000000000000000000000000..933e6ad09b0591b1e6064a4090051f2654729453
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000123.txt
@@ -0,0 +1 @@
+-0.45044049620628357 0.3801383674144745 -0.807835042476654 0.9694024920463562 -1.0430805730266002e-07 -0.9048270583152771 -0.42577916383743286 0.5109351277351379 -0.8928060531616211 -0.19178827106952667 0.4075707793235779 -0.489084929227829 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000124.txt b/3DTopia/assets/sample_data/pose/000124.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4f48bf53bcde81245411808371cb73a0261919e6
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000124.txt
@@ -0,0 +1 @@
+-0.6188351511955261 0.3455813229084015 -0.7054193019866943 0.8465033173561096 -4.470347292340193e-08 -0.8980275392532349 -0.439939022064209 0.5279269218444824 -0.7855206727981567 -0.2722497582435608 0.5557310581207275 -0.6668769717216492 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000125.txt b/3DTopia/assets/sample_data/pose/000125.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7b1a8475036401e325178ba2ad9184218377e036
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000125.txt
@@ -0,0 +1 @@
+-0.7625583410263062 0.2936951220035553 -0.576409101486206 0.6916911005973816 -4.4703462265260896e-08 -0.8910065293312073 -0.4539904296398163 0.5447885990142822 -0.6469191312789917 -0.3461942672729492 0.6794444918632507 -0.8153334856033325 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000126.txt b/3DTopia/assets/sample_data/pose/000126.txt
new file mode 100644
index 0000000000000000000000000000000000000000..338c3452f049a63246b7b3913ba4225f26a96e33
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000126.txt
@@ -0,0 +1 @@
+-0.8758810758590698 0.22578869760036469 -0.42644059658050537 0.5117289423942566 -4.470348002882929e-08 -0.8837655782699585 -0.4679297208786011 0.5615156888961792 -0.48252683877944946 -0.4098508059978485 0.774073600769043 -0.9288883209228516 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000127.txt b/3DTopia/assets/sample_data/pose/000127.txt
new file mode 100644
index 0000000000000000000000000000000000000000..de9dc4ed1416526cc4fabeb0d5defe351813f2a9
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000127.txt
@@ -0,0 +1 @@
+-0.9542849063873291 0.14399497210979462 -0.2619266211986542 0.3143114745616913 3.650783639841393e-07 -0.8763066530227661 -0.4817536175251007 0.5781043767929077 -0.2988981604576111 -0.459730327129364 0.8362461924552917 -1.003495693206787 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000128.txt b/3DTopia/assets/sample_data/pose/000128.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4afe2894717f298b381ac546ad1dd410ff91e788
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000128.txt
@@ -0,0 +1 @@
+-0.9946445822715759 0.05120658501982689 -0.08977368474006653 0.1077304556965828 -5.103644298287691e-07 -0.8686315417289734 -0.4954585134983063 0.5945504903793335 -0.10335099697113037 -0.49280521273612976 0.8639796376228333 -1.0367757081985474 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000129.txt b/3DTopia/assets/sample_data/pose/000129.txt
new file mode 100644
index 0000000000000000000000000000000000000000..92e56dd71cc479b36a5d57724ffcd5f1186b52a8
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000129.txt
@@ -0,0 +1 @@
+-0.9953508973121643 -0.04902723804116249 0.08289926499128342 -0.09948068112134933 6.332989528345934e-07 -0.8607419729232788 -0.5090413689613342 0.6108497381210327 0.0963117778301239 -0.5066748857498169 0.8567403554916382 -1.0280886888504028 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000130.txt b/3DTopia/assets/sample_data/pose/000130.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ecd4e0ee7c552170eae6ae7b1b9ad3efdb114cfb
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000130.txt
@@ -0,0 +1 @@
+-0.9563760161399841 -0.1526419222354889 0.24908897280693054 -0.2989071309566498 7.450580596923828e-09 -0.8526401519775391 -0.5224985480308533 0.6269983053207397 0.2921384572982788 -0.49970507621765137 0.8154445886611938 -0.9785334467887878 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000131.txt b/3DTopia/assets/sample_data/pose/000131.txt
new file mode 100644
index 0000000000000000000000000000000000000000..37207cc643e3c0f610c452f13e7e4cb727e350a1
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000131.txt
@@ -0,0 +1 @@
+-0.8792728781700134 -0.2552239000797272 0.4021683931350708 -0.48260238766670227 7.450576333667414e-08 -0.844327986240387 -0.5358267426490784 0.6429921984672546 0.476317822933197 -0.47113797068595886 0.7423946261405945 -0.8908737301826477 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000132.txt b/3DTopia/assets/sample_data/pose/000132.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8c91ab4790a62fe18fb51c98bcd3267cee3b96f2
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000132.txt
@@ -0,0 +1 @@
+-0.767116367816925 -0.35220232605934143 0.536176860332489 -0.6434125900268555 -0.0 -0.8358073830604553 -0.5490227341651917 0.6588274240493774 0.6415076851844788 -0.4211643934249878 0.6411615014076233 -0.7693938612937927 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000133.txt b/3DTopia/assets/sample_data/pose/000133.txt
new file mode 100644
index 0000000000000000000000000000000000000000..663c86d135a5adaf022d0d4181a456face2bf48d
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000133.txt
@@ -0,0 +1 @@
+-0.6243770122528076 -0.4390561878681183 0.6460515856742859 -0.7752619981765747 5.960462345910855e-08 -0.8270806670188904 -0.5620833039283752 0.674500048160553 0.7811228632926941 -0.3509519398212433 0.5164101123809814 -0.6196922063827515 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000134.txt b/3DTopia/assets/sample_data/pose/000134.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ce2f7e85679bd8570b11952368e620fac314fdad
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000134.txt
@@ -0,0 +1 @@
+-0.45674604177474976 -0.5115229487419128 0.7278236150741577 -0.8733885884284973 4.470348002882929e-08 -0.8181498050689697 -0.5750052332878113 0.6900063753128052 0.8895970582962036 -0.26263129711151123 0.37368670105934143 -0.4484238922595978 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000135.txt b/3DTopia/assets/sample_data/pose/000135.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d9ee972511034a81abeb0ce52a353c6bfa118964
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000135.txt
@@ -0,0 +1 @@
+-0.27090585231781006 -0.5658054351806641 0.7787646651268005 -0.9345173835754395 1.4901162970204496e-08 -0.8090171217918396 -0.5877851843833923 0.7053421139717102 0.962605893611908 -0.159234419465065 0.21916747093200684 -0.26300084590911865 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000136.txt b/3DTopia/assets/sample_data/pose/000136.txt
new file mode 100644
index 0000000000000000000000000000000000000000..54f5dc33b1954e1c00dd58012f9a07f7837d1219
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000136.txt
@@ -0,0 +1 @@
+-0.0742654800415039 -0.5987615585327148 0.7974767088890076 -0.9569714665412903 -7.4505797087454084e-09 -0.7996850609779358 -0.6004195213317871 0.7205042839050293 0.9972383975982666 -0.04459046944975853 0.059388987720012665 -0.07126673310995102 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000137.txt b/3DTopia/assets/sample_data/pose/000137.txt
new file mode 100644
index 0000000000000000000000000000000000000000..539b9094884bbc1dfd05b284ed81410379f30ca4
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000137.txt
@@ -0,0 +1 @@
+0.12533560395240784 -0.6080741286277771 0.7839240431785583 -0.9407090544700623 3.725290298461914e-09 -0.7901548743247986 -0.6129072308540344 0.7354885339736938 0.9921144247055054 0.07681908458471298 -0.09903453290462494 0.1188414990901947 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000138.txt b/3DTopia/assets/sample_data/pose/000138.txt
new file mode 100644
index 0000000000000000000000000000000000000000..edd169da92400cf23729351c5e4b5cda4904e343
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000138.txt
@@ -0,0 +1 @@
+0.3199399411678314 -0.5923787355422974 0.7394092082977295 -0.8872910737991333 -0.0 -0.7804303765296936 -0.6252428889274597 0.7502911686897278 0.9474378824234009 0.20004017651081085 -0.24969083070755005 0.2996290326118469 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000139.txt b/3DTopia/assets/sample_data/pose/000139.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ceb637943c6e219c3e9dd2250a0635aa63851ef6
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000139.txt
@@ -0,0 +1 @@
+0.5017891526222229 -0.5513653755187988 0.6664860248565674 -0.7997834086418152 1.192092469182171e-07 -0.770513117313385 -0.6374239921569824 0.7649087309837341 0.8649898171424866 0.3198525011539459 -0.38663509488105774 0.4639623761177063 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000140.txt b/3DTopia/assets/sample_data/pose/000140.txt
new file mode 100644
index 0000000000000000000000000000000000000000..61e450b42271cc785ec348dc49a5a61705a33b8c
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000140.txt
@@ -0,0 +1 @@
+0.6636338233947754 -0.4858245551586151 0.5688273906707764 -0.6825927495956421 8.940696716308594e-08 -0.7604058980941772 -0.6494481563568115 0.7793376445770264 0.7480576038360596 0.43099576234817505 -0.5046311020851135 0.6055574417114258 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000141.txt b/3DTopia/assets/sample_data/pose/000141.txt
new file mode 100644
index 0000000000000000000000000000000000000000..127403537b562462ed6dc6f8b739ecae8b1d2e5a
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000141.txt
@@ -0,0 +1 @@
+0.7990213632583618 -0.3976486325263977 0.4510436952114105 -0.5412524342536926 2.9802318834981634e-08 -0.7501109838485718 -0.66131192445755 0.793574333190918 0.6013026237487793 0.5284023284912109 -0.5993546843528748 0.7192258834838867 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000142.txt b/3DTopia/assets/sample_data/pose/000142.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8e60c3f744ad13545f419459d0d546dac43680db
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000142.txt
@@ -0,0 +1 @@
+0.9025545120239258 -0.2897827625274658 0.31846699118614197 -0.38216039538383484 -2.9802322387695312e-08 -0.7396310567855835 -0.6730126142501831 0.8076150417327881 0.43057551980018616 0.607430636882782 -0.6675573587417603 0.8010690212249756 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000143.txt b/3DTopia/assets/sample_data/pose/000143.txt
new file mode 100644
index 0000000000000000000000000000000000000000..287c8aa1d607520837208c984c89f47387edd729
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000143.txt
@@ -0,0 +1 @@
+0.970105767250061 -0.16612771153450012 0.17690807580947876 -0.21228961646556854 7.450580596923828e-09 -0.728968620300293 -0.6845470666885376 0.8214565515518188 0.24268268048763275 0.6640830039978027 -0.707176685333252 0.8486118316650391 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000144.txt b/3DTopia/assets/sample_data/pose/000144.txt
new file mode 100644
index 0000000000000000000000000000000000000000..be8c242941a09da32614d3bb1b80f8723a7f8c48
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000144.txt
@@ -0,0 +1 @@
+0.9989818334579468 -0.03139603137969971 0.03239819407463074 -0.038877833634614944 1.862645371275562e-09 -0.7181263566017151 -0.6959128975868225 0.8350954055786133 0.04511489346623421 0.6952042579650879 -0.7173951864242554 0.8608742952346802 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000145.txt b/3DTopia/assets/sample_data/pose/000145.txt
new file mode 100644
index 0000000000000000000000000000000000000000..03d849a0a2e74ff583ac675767c2467801cd8f02
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000145.txt
@@ -0,0 +1 @@
+0.9880316853523254 0.1090722382068634 -0.1090722382068634 0.13088670372962952 -0.0 -0.7071068286895752 -0.7071068286895752 0.848528265953064 -0.15425144135951996 0.6986439228057861 -0.6986439228057861 0.8383726477622986 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000146.txt b/3DTopia/assets/sample_data/pose/000146.txt
new file mode 100644
index 0000000000000000000000000000000000000000..76962ce71616e1212648a9cc68c9df8d1af1bbb0
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000146.txt
@@ -0,0 +1 @@
+0.9376917481422424 0.2495262175798416 -0.24180757999420166 0.2901691794395447 2.980232594040899e-08 -0.6959127187728882 -0.7181264758110046 0.8617516160011292 -0.3474683463573456 0.6733812093734741 -0.6525515913963318 0.7830621004104614 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000147.txt b/3DTopia/assets/sample_data/pose/000147.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d94caf48ba11227454895acbe160a5c35fe08b24
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000147.txt
@@ -0,0 +1 @@
+0.8499690294265747 0.38404446840286255 -0.3606417775154114 0.4327700436115265 -1.4901161193847656e-08 -0.6845471262931824 -0.728968620300293 0.8747624158859253 -0.5268326997756958 0.619600772857666 -0.5818438529968262 0.6982126832008362 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000148.txt b/3DTopia/assets/sample_data/pose/000148.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0521e665ca29959ad23ce635da508f1e673821ce
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000148.txt
@@ -0,0 +1 @@
+0.7283607721328735 0.5067906975746155 -0.46114397048950195 0.553372859954834 2.9802322387695312e-08 -0.6730124950408936 -0.7396311163902283 0.8875573873519897 -0.6851938366889954 0.5387182831764221 -0.4901959300041199 0.588235080242157 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000149.txt b/3DTopia/assets/sample_data/pose/000149.txt
new file mode 100644
index 0000000000000000000000000000000000000000..247211ba513d96f8a8bbfd06470b338387f490b6
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000149.txt
@@ -0,0 +1 @@
+0.5777150988578796 0.6122695207595825 -0.5397883057594299 0.6477459669113159 5.960465188081798e-08 -0.6613119840621948 -0.7501110434532166 0.9001331925392151 -0.8162385821342468 0.4333503842353821 -0.38204991817474365 0.4584598243236542 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000150.txt b/3DTopia/assets/sample_data/pose/000150.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6e4259f98d11229016597f447fda80f832b2ebd5
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000150.txt
@@ -0,0 +1 @@
+0.4040374457836151 0.6955756545066833 -0.5940772891044617 0.7128933072090149 -2.2351731843173184e-07 -0.6494478583335876 -0.760405957698822 0.912487268447876 -0.9147422909736633 0.30723246932029724 -0.26240116357803345 0.3148817718029022 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000151.txt b/3DTopia/assets/sample_data/pose/000151.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e51fabb814d266646dedf7cca87b970ef7a0c5a1
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000151.txt
@@ -0,0 +1 @@
+0.21425242722034454 0.7526208758354187 -0.622621476650238 0.7471462488174438 -5.215405352032576e-08 -0.6374234557151794 -0.7705134749412537 0.924615740776062 -0.9767781496047974 0.1650843769311905 -0.1365695297718048 0.16388361155986786 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000152.txt b/3DTopia/assets/sample_data/pose/000152.txt
new file mode 100644
index 0000000000000000000000000000000000000000..02f0d719f3488d06de0c57f81847788a949e7740
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000152.txt
@@ -0,0 +1 @@
+0.01592571847140789 0.7803348898887634 -0.6251589059829712 0.7501961588859558 -6.938351759799843e-08 -0.6252382397651672 -0.7804338932037354 0.936516523361206 -0.9998730421066284 0.01242893747985363 -0.009957351721823215 0.01194903813302517 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000153.txt b/3DTopia/assets/sample_data/pose/000153.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b5446b2e71626b4e006e7f84b339e1a5252fb931
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000153.txt
@@ -0,0 +1 @@
+-0.18303552269935608 0.7768062353134155 -0.6025526523590088 0.7230633497238159 -1.043080928297968e-07 -0.6129070520401001 -0.7901548743247986 0.9481860995292664 -0.9831060767173767 -0.14462649822235107 0.11218380182981491 -0.13462068140506744 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000154.txt b/3DTopia/assets/sample_data/pose/000154.txt
new file mode 100644
index 0000000000000000000000000000000000000000..03f3501e10aae7752f61e4b2bf11395797ab6693
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000154.txt
@@ -0,0 +1 @@
+-0.3747004270553589 0.7414241433143616 -0.5566772818565369 0.6680126190185547 -4.470347292340193e-08 -0.6004204750061035 -0.799684464931488 0.9596216082572937 -0.9271458387374878 -0.2996421754360199 0.2249777913093567 -0.26997312903404236 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000155.txt b/3DTopia/assets/sample_data/pose/000155.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9adba264cdd3fa5131f7e00537dbbb41cd4b2341
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000155.txt
@@ -0,0 +1 @@
+-0.5514266490936279 0.6749008893966675 -0.49034419655799866 0.5884130001068115 1.4901161193847656e-08 -0.5877853035926819 -0.80901700258255 0.970820426940918 -0.8342233896255493 -0.44611358642578125 0.32412049174308777 -0.38894450664520264 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000156.txt b/3DTopia/assets/sample_data/pose/000156.txt
new file mode 100644
index 0000000000000000000000000000000000000000..083b436903364a1b52b5ef97e083c72d298f50ae
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000156.txt
@@ -0,0 +1 @@
+-0.7061696648597717 0.5792850852012634 -0.40712860226631165 0.48855406045913696 1.4901168299275014e-08 -0.5750053524971008 -0.818149745464325 0.9817796945571899 -0.7080430388450623 -0.5777523517608643 0.4060514271259308 -0.48726147413253784 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000157.txt b/3DTopia/assets/sample_data/pose/000157.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0a6b93de25e0422c2661ad190d8bbd0cfadaf003
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000157.txt
@@ -0,0 +1 @@
+-0.8327592015266418 0.4579005837440491 -0.3111884593963623 0.37342676520347595 -3.7252868878567824e-07 -0.5620833039283752 -0.827080488204956 0.9924965500831604 -0.553634524345398 -0.6887590885162354 0.4680800437927246 -0.5616962909698486 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000158.txt b/3DTopia/assets/sample_data/pose/000158.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f32d22f520ff6c3fa804bc923787587db83ff477
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000158.txt
@@ -0,0 +1 @@
+-0.9261496067047119 0.3152289092540741 -0.2070665806531906 0.24848027527332306 -7.45057349149647e-08 -0.5490228533744812 -0.835807204246521 1.0029689073562622 -0.3771549463272095 -0.7740828990936279 0.5084771513938904 -0.6101730465888977 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000159.txt b/3DTopia/assets/sample_data/pose/000159.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d059c9453bb4b6d6036e6357d94b15aae3cc8035
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000159.txt
@@ -0,0 +1 @@
+-0.9826175570487976 0.15674014389514923 -0.09946972131729126 0.11936488747596741 -4.917378646496218e-07 -0.5358267426490784 -0.8443277478218079 1.0131934881210327 -0.18563862144947052 -0.8296516537666321 0.5265126824378967 -0.6318156123161316 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000160.txt b/3DTopia/assets/sample_data/pose/000160.txt
new file mode 100644
index 0000000000000000000000000000000000000000..47ea199c68e79262de73bd5b2f7a84c18a7c2b2a
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000160.txt
@@ -0,0 +1 @@
+-0.9999117851257324 -0.011320343241095543 0.006937115918844938 -0.008324497379362583 -0.0 -0.522498607635498 -0.8526401519775391 1.0231682062149048 0.013276812620460987 -0.8525649905204773 0.5224524140357971 -0.626943051815033 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000161.txt b/3DTopia/assets/sample_data/pose/000161.txt
new file mode 100644
index 0000000000000000000000000000000000000000..30010e886c7d3b90c795b97930ecd7b672f6843f
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000161.txt
@@ -0,0 +1 @@
+-0.9773425459861755 -0.18218766152858734 0.10774486511945724 -0.12929487228393555 5.662440685227921e-07 -0.5090415477752686 -0.8607418537139893 1.0328903198242188 0.2116631716489792 -0.8412397503852844 0.49750807881355286 -0.5970093607902527 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000162.txt b/3DTopia/assets/sample_data/pose/000162.txt
new file mode 100644
index 0000000000000000000000000000000000000000..18de1b495916f082a77529102796b17272bfb15b
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000162.txt
@@ -0,0 +1 @@
+-0.9158093929290771 -0.3488530218601227 0.1989825814962387 -0.2387789785861969 -2.6822073095900123e-07 -0.49545881152153015 -0.8686313033103943 1.0423578023910522 0.40161237120628357 -0.7955010533332825 0.45374563336372375 -0.5444949865341187 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000163.txt b/3DTopia/assets/sample_data/pose/000163.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0ffb1eb0f1f1b14928f8608d55d7caf394f7dfd6
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000163.txt
@@ -0,0 +1 @@
+-0.8177659511566162 -0.5043586492538452 0.2772735059261322 -0.3327282667160034 1.4901154088420299e-08 -0.48175370693206787 -0.8763065338134766 1.051567792892456 0.575550377368927 -0.7166138887405396 0.39396175742149353 -0.47275421023368835 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000164.txt b/3DTopia/assets/sample_data/pose/000164.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d6dbce3289f8ad6e960998dd2bf14f9b2d1d3aa7
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000164.txt
@@ -0,0 +1 @@
+-0.6871210336685181 -0.6420933604240417 0.33997073769569397 -0.4079653024673462 3.1292415769712534e-07 -0.4679299294948578 -0.8837653398513794 1.0605188608169556 0.7265424728393555 -0.6072539687156677 0.3215245306491852 -0.385829359292984 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000165.txt b/3DTopia/assets/sample_data/pose/000165.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9fcf6f650fa615f11a899719603cd1ce69db335e
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000165.txt
@@ -0,0 +1 @@
+-0.5290825366973877 -0.7560816407203674 0.3852425813674927 -0.4622913897037506 1.3411039390121005e-07 -0.4539903998374939 -0.8910064697265625 1.069207787513733 0.8485701680183411 -0.47141605615615845 0.24019837379455566 -0.28823819756507874 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000166.txt b/3DTopia/assets/sample_data/pose/000166.txt
new file mode 100644
index 0000000000000000000000000000000000000000..87b4b3e158ffe731093160489ff82a88974dddf2
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000166.txt
@@ -0,0 +1 @@
+-0.3499513566493988 -0.8412432074546814 0.4121206998825073 -0.49454501271247864 2.0861612881617475e-07 -0.4399392604827881 -0.8980274200439453 1.0776331424713135 0.936767578125 -0.3142661154270172 0.15395735204219818 -0.18474876880645752 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000167.txt b/3DTopia/assets/sample_data/pose/000167.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1570aaacef31dad805db29f459e59a1d671fedfb
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000167.txt
@@ -0,0 +1 @@
+-0.15686866641044617 -0.8936248421669006 0.42050766944885254 -0.5046095252037048 8.940696005765858e-08 -0.4257791042327881 -0.9048271179199219 1.0857925415039062 0.9876194596290588 -0.14193904399871826 0.06679143011569977 -0.08014968037605286 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000168.txt b/3DTopia/assets/sample_data/pose/000168.txt
new file mode 100644
index 0000000000000000000000000000000000000000..28e1bdfd3ed2fa47360b21a91ddf4dc100c462df
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000168.txt
@@ -0,0 +1 @@
+0.04246791824698448 -0.910581648349762 0.4111417531967163 -0.4933716952800751 2.60770320892334e-08 -0.41151300072669983 -0.9114038348197937 1.0936838388442993 0.9990978240966797 0.03870541974902153 -0.017476091161370277 0.020971447229385376 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000169.txt b/3DTopia/assets/sample_data/pose/000169.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b9de0ca4e16008afcb9074467e065abac6a06e1f
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000169.txt
@@ -0,0 +1 @@
+0.2401116043329239 -0.890906035900116 0.3855292797088623 -0.462635338306427 -1.4901161193847656e-08 -0.3971477150917053 -0.9177546501159668 1.1013054847717285 0.970745325088501 0.2203635573387146 -0.09535978734493256 0.11443176120519638 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000170.txt b/3DTopia/assets/sample_data/pose/000170.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ca76a0491d972790bc92a6ff2e60ddd24604f2c2
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000170.txt
@@ -0,0 +1 @@
+0.42818257212638855 -0.834902822971344 0.34582775831222534 -0.41499361395835876 -2.9802318834981634e-08 -0.3826831877231598 -0.9238796234130859 1.1086554527282715 0.9036921262741089 0.39558911323547363 -0.1638583242893219 0.1966300755739212 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000171.txt b/3DTopia/assets/sample_data/pose/000171.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d018746d105d9ca7bc289bd82187832022b2a396
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000171.txt
@@ -0,0 +1 @@
+0.5991833806037903 -0.744390070438385 0.29472458362579346 -0.35366982221603394 1.4901157641133977e-08 -0.3681242763996124 -0.9297765493392944 1.1157318353652954 0.8006117343902588 0.557106614112854 -0.22057394683361053 0.26468899846076965 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000172.txt b/3DTopia/assets/sample_data/pose/000172.txt
new file mode 100644
index 0000000000000000000000000000000000000000..02c9cdd36ee85a463ca7c1fd981ecd6b5a8c12f9
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000172.txt
@@ -0,0 +1 @@
+0.7462967038154602 -0.6226441264152527 0.23527754843235016 -0.282333105802536 2.9802318834981634e-08 -0.35347476601600647 -0.9354439973831177 1.1225329637527466 0.6656134724617004 0.6981187462806702 -0.26379701495170593 0.31655651330947876 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000173.txt b/3DTopia/assets/sample_data/pose/000173.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0174b2409348ae3a8a2ca1c7cf957d847ee5a17f
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000173.txt
@@ -0,0 +1 @@
+0.8636573553085327 -0.47427839040756226 0.1707507073879242 -0.20490090548992157 -0.0 -0.3387379050254822 -0.9408807158470154 1.1290568113327026 0.5040791630744934 0.8125985264778137 -0.29255348443984985 0.3510642647743225 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000174.txt b/3DTopia/assets/sample_data/pose/000174.txt
new file mode 100644
index 0000000000000000000000000000000000000000..591a8dee55f6948e83a456fa86296bf3550aa528
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000174.txt
@@ -0,0 +1 @@
+0.9465868473052979 -0.30506426095962524 0.10444684326648712 -0.12533621490001678 1.4901162970204496e-08 -0.32391735911369324 -0.9460853934288025 1.1353024244308472 0.3224489986896515 0.8955519795417786 -0.306615948677063 0.3679392337799072 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000175.txt b/3DTopia/assets/sample_data/pose/000175.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e03ff318cf1447e40b5c296bba185eb54fe3663a
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000175.txt
@@ -0,0 +1 @@
+0.9917788505554199 -0.1217007040977478 0.03954293206334114 -0.04745154082775116 -3.725290298461914e-09 -0.3090168833732605 -0.9510565400123596 1.141268014907837 0.12796369194984436 0.9432377815246582 -0.3064764142036438 0.36777186393737793 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000176.txt b/3DTopia/assets/sample_data/pose/000176.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a03b53267607cc86cfb041843ebf308eb288e916
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000176.txt
@@ -0,0 +1 @@
+0.997431755065918 0.06845686584711075 -0.021060077473521233 0.025272099301218987 -0.0 -0.2940402626991272 -0.955793023109436 1.146951675415039 -0.07162310928106308 0.9533383250236511 -0.2932851016521454 0.35194218158721924 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000177.txt b/3DTopia/assets/sample_data/pose/000177.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f577381efeecf95d9bb59657c115ba47f8ceb97a
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000177.txt
@@ -0,0 +1 @@
+0.9633201360702515 0.257699191570282 -0.07486848533153534 0.08984224498271942 -6.705520405603238e-08 -0.27899086475372314 -0.9602935314178467 1.1523525714874268 -0.26835453510284424 0.9250701665878296 -0.26875752210617065 0.3225092887878418 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000178.txt b/3DTopia/assets/sample_data/pose/000178.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8baf4f70fb1b4b8c16f7f13212a36eff7ffd40e7
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000178.txt
@@ -0,0 +1 @@
+0.8908040523529053 0.4382828176021576 -0.1199006661772728 0.143880695104599 -0.0 -0.2638731598854065 -0.9645572900772095 1.1574686765670776 -0.45438748598098755 0.8592315912246704 -0.23505929112434387 0.2820710241794586 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000179.txt b/3DTopia/assets/sample_data/pose/000179.txt
new file mode 100644
index 0000000000000000000000000000000000000000..79dd4508d344313742c389fbc5eb1794defab2a7
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000179.txt
@@ -0,0 +1 @@
+0.7827745079994202 0.6027545928955078 -0.15476103127002716 0.1857132613658905 -1.4901161193847656e-08 -0.24868980050086975 -0.9685832262039185 1.1622997522354126 -0.6223054528236389 0.7581822276115417 -0.1946680247783661 0.23360170423984528 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000180.txt b/3DTopia/assets/sample_data/pose/000180.txt
new file mode 100644
index 0000000000000000000000000000000000000000..14c85ed729cb52a1e206fb6e5c39309874fe9426
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000180.txt
@@ -0,0 +1 @@
+0.6435381174087524 0.7442656755447388 -0.17868220806121826 0.21441881358623505 -5.960463766996327e-08 -0.23344513773918152 -0.972369909286499 1.1668438911437988 -0.7654140591621399 0.625757098197937 -0.15023081004619598 0.18027718365192413 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000181.txt b/3DTopia/assets/sample_data/pose/000181.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a668f50004a5a71ff3758a9f2890a8dd86d4458f
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000181.txt
@@ -0,0 +1 @@
+0.4786456823348999 0.8568629026412964 -0.1915312558412552 0.2298377901315689 -1.4901159417490817e-08 -0.21814295649528503 -0.9759168028831482 1.1711000204086304 -0.8780080676078796 0.46711835265159607 -0.10441319644451141 0.12529604136943817 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000182.txt b/3DTopia/assets/sample_data/pose/000182.txt
new file mode 100644
index 0000000000000000000000000000000000000000..314070f0cd1522f64866ede307832dab2294bdcc
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000182.txt
@@ -0,0 +1 @@
+0.2946713864803314 0.9357439875602722 -0.19378307461738586 0.23253990709781647 2.607702853651972e-08 -0.2027871459722519 -0.9792227745056152 1.17506742477417 -0.9555985927581787 0.28854894638061523 -0.05975561589002609 0.07170679420232773 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000183.txt b/3DTopia/assets/sample_data/pose/000183.txt
new file mode 100644
index 0000000000000000000000000000000000000000..366bcc215c52775da6e8e4544bc05fedaa1d21be
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000183.txt
@@ -0,0 +1 @@
+0.09894955158233643 0.9774670600891113 -0.186459481716156 0.22375409305095673 4.6566128730773926e-08 -0.18737904727458954 -0.9822876453399658 1.1787446737289429 -0.9950924515724182 0.09719691425561905 -0.01854112185537815 0.022249583154916763 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000184.txt b/3DTopia/assets/sample_data/pose/000184.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a12478cd73cfad215240a993b54e9e0cd6fe16ff
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000184.txt
@@ -0,0 +1 @@
+-0.10071694850921631 0.9801000952720642 -0.17105454206466675 0.2052658498287201 -6.89178563106907e-08 -0.17192883789539337 -0.985109269618988 1.1821311712265015 -0.9949150085449219 -0.0992172583937645 0.01731616072356701 -0.0207794401794672 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000185.txt b/3DTopia/assets/sample_data/pose/000185.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6622c985b9bf721d1e955e2e63acf744a0c6770d
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000185.txt
@@ -0,0 +1 @@
+-0.29636847972869873 0.9433149099349976 -0.14940685033798218 0.17928773164749146 -1.527368311826649e-07 -0.15643510222434998 -0.9876880645751953 1.1852260828018188 -0.9550734162330627 -0.2927198112010956 0.046362414956092834 -0.05563471466302872 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000186.txt b/3DTopia/assets/sample_data/pose/000186.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d8204fa64ab060476fc74f601d18278867097d68
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000186.txt
@@ -0,0 +1 @@
+-0.4802047610282898 0.8684056997299194 -0.12359234690666199 0.14831088483333588 -1.527369306586479e-07 -0.14090117812156677 -0.9900237321853638 1.188028335571289 -0.8771565556526184 -0.4754140377044678 0.06766162067651749 -0.0811937153339386 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000187.txt b/3DTopia/assets/sample_data/pose/000187.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c5a6e6b7b897650d9f808da3b0f6bc606746ee74
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000187.txt
@@ -0,0 +1 @@
+-0.6448968648910522 0.7582430243492126 -0.0957883819937706 0.11494607478380203 -1.4156101713069802e-07 -0.12533338367938995 -0.9921146631240845 1.190537691116333 -0.7642695903778076 -0.6398116946220398 0.08082716166973114 -0.0969923734664917 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000188.txt b/3DTopia/assets/sample_data/pose/000188.txt
new file mode 100644
index 0000000000000000000000000000000000000000..da3d3f26f0049ad7d135f80d80876ffc78f99cca
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000188.txt
@@ -0,0 +1 @@
+-0.783878743648529 0.6171642541885376 -0.06813523918390274 0.08176270127296448 -3.2782554626464844e-07 -0.10973420739173889 -0.9939610362052917 1.1927533149719238 -0.6209139823913574 -0.7791448831558228 0.08601849526166916 -0.1032220870256424 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000189.txt b/3DTopia/assets/sample_data/pose/000189.txt
new file mode 100644
index 0000000000000000000000000000000000000000..aa24f22623a4f6c44d69320e7361fbe5d1dcff29
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000189.txt
@@ -0,0 +1 @@
+-0.8916096687316895 0.45079466700553894 -0.04261254519224167 0.05113517865538597 -8.19563368281706e-08 -0.09410841017961502 -0.9955618381500244 1.1946742534637451 -0.45280423760414124 -0.8876528143882751 0.08390773087739944 -0.10068946331739426 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000190.txt b/3DTopia/assets/sample_data/pose/000190.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1c7b725f4f8bf00720fef46d91da5ea9c1df6038
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000190.txt
@@ -0,0 +1 @@
+-0.9637949466705322 0.2658207416534424 -0.020919324830174446 0.02510467730462551 -1.2833611435780767e-06 -0.07845940440893173 -0.9969170093536377 1.196300745010376 -0.26664260029792786 -0.9608241319656372 0.07561864703893661 -0.09074220806360245 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000191.txt b/3DTopia/assets/sample_data/pose/000191.txt
new file mode 100644
index 0000000000000000000000000000000000000000..016bad1f7081ba063ced918647343f73f0ba3bca
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000191.txt
@@ -0,0 +1 @@
+-0.9975574016571045 0.06970969587564468 -0.004386159125715494 0.005263194907456636 3.86498697935167e-07 -0.06279079616069794 -0.9980266094207764 1.1976321935653687 -0.06984754651784897 -0.9955891370773315 0.06263715028762817 -0.07516458630561829 -0.0 0.0 0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000192.txt b/3DTopia/assets/sample_data/pose/000192.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a1350253ca0a3fab00909d501c9b03c75a81a618
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000192.txt
@@ -0,0 +1 @@
+-0.9915500283241272 -0.1295798271894455 0.0061101061291992664 -0.007333071436733007 7.31088050542894e-07 -0.0471065416932106 -0.9988898634910583 1.1986677646636963 0.12972381711006165 -0.9904493689537048 0.04670844227075577 -0.05605006963014603 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000193.txt b/3DTopia/assets/sample_data/pose/000193.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c48c861782fa785615b975d66ac418dbad46d76a
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000193.txt
@@ -0,0 +1 @@
+-0.9460127353668213 -0.32396966218948364 0.0101803382858634 -0.012217401526868343 8.381903739973495e-07 -0.03141067922115326 -0.9995065927505493 1.199407935142517 0.324129581451416 -0.9455459117889404 0.02971518225967884 -0.03565797209739685 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/assets/sample_data/pose/000194.txt b/3DTopia/assets/sample_data/pose/000194.txt
new file mode 100644
index 0000000000000000000000000000000000000000..addc79d52f19d732eb2b17c1eb4bcca99088534b
--- /dev/null
+++ b/3DTopia/assets/sample_data/pose/000194.txt
@@ -0,0 +1 @@
+-0.8627605438232422 -0.5055502653121948 0.007941310293972492 -0.009530180133879185 6.407498744920304e-07 -0.01570744998753071 -0.9998766183853149 1.1998521089553833 0.5056126117706299 -0.8626541495323181 0.013552011922001839 -0.016261987388134003 -0.0 0.0 -0.0 1.0
diff --git a/3DTopia/configs/default.yaml b/3DTopia/configs/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9fc30020f42cdc132b73509b0c48f8a9f58910e0
--- /dev/null
+++ b/3DTopia/configs/default.yaml
@@ -0,0 +1,72 @@
+model:
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ shift_scale: 2
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "triplane"
+ cond_stage_key: "caption"
+ image_size: 32
+ channels: 8
+ cond_stage_trainable: false
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.5147210212065061
+ use_ema: False
+ learning_rate: 5e-5
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32
+ in_channels: 8
+ out_channels: 8
+ model_channels: 320
+ attention_resolutions: [4, 2, 1]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ context_dim: 768
+ transformer_depth: 1
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: model.triplane_vae.AutoencoderKLRollOut
+ params:
+ embed_dim: 8
+ learning_rate: 1e-5
+ norm: False
+ renderer_type: eg3d
+ ddconfig:
+ double_z: true
+ z_channels: 8
+ resolution: 256
+ in_channels: 32
+ out_ch: 32
+ ch: 128
+ ch_mult:
+ - 2
+ - 4
+ - 4
+ - 8
+ num_res_blocks: 2
+ attn_resolutions: [32]
+ dropout: 0.0
+ lossconfig:
+ kl_weight: 1e-5
+ rec_weight: 1
+ latent_tv_weight: 2e-3
+ renderer_config:
+ rgbnet_dim: -1
+ rgbnet_width: 128
+ sigma_dim: 12
+ c_dim: 20
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenCLIPTextEmbedder
+
diff --git a/3DTopia/environment.yml b/3DTopia/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2415cf057c34303455a4843b7db3b96a71955666
--- /dev/null
+++ b/3DTopia/environment.yml
@@ -0,0 +1,150 @@
+name: 3dtopia
+channels:
+ - pytorch
+ - anaconda
+ - conda-forge
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - blas=1.0=mkl
+ - brotli=1.0.9=h166bdaf_7
+ - brotli-bin=1.0.9=h166bdaf_7
+ - bzip2=1.0.8=h7f98852_4
+ - ca-certificates=2023.5.7=hbcca054_0
+ - certifi=2023.5.7=pyhd8ed1ab_0
+ - charset-normalizer=3.1.0=pyhd8ed1ab_0
+ - colorama=0.4.6=pyhd8ed1ab_0
+ - cudatoolkit=11.3.1=h9edb442_10
+ - ffmpeg=4.3.2=hca11adc_0
+ - freetype=2.10.4=h0708190_1
+ - fsspec=2023.5.0=pyh1a96a4e_0
+ - gmp=6.2.1=h58526e2_0
+ - gnutls=3.6.13=h85f3911_1
+ - idna=3.4=pyhd8ed1ab_0
+ - intel-openmp=2021.4.0=h06a4308_3561
+ - jpeg=9e=h166bdaf_1
+ - lame=3.100=h7f98852_1001
+ - lcms2=2.12=hddcbb42_0
+ - ld_impl_linux-64=2.38=h1181459_1
+ - libbrotlicommon=1.0.9=h166bdaf_7
+ - libbrotlidec=1.0.9=h166bdaf_7
+ - libbrotlienc=1.0.9=h166bdaf_7
+ - libffi=3.4.4=h6a678d5_0
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libpng=1.6.37=h21135ba_2
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libtiff=4.2.0=hecacb30_2
+ - libwebp-base=1.2.2=h7f98852_1
+ - lightning-utilities=0.8.0=pyhd8ed1ab_0
+ - lz4-c=1.9.3=h9c3ff4c_1
+ - mkl=2021.4.0=h06a4308_640
+ - mkl-service=2.4.0=py38h95df7f1_0
+ - mkl_fft=1.3.1=py38h8666266_1
+ - mkl_random=1.2.2=py38h1abd341_0
+ - ncurses=6.4=h6a678d5_0
+ - nettle=3.6=he412f7d_0
+ - numpy=1.24.3=py38h14f4228_0
+ - numpy-base=1.24.3=py38h31eccc5_0
+ - olefile=0.46=pyh9f0ad1d_1
+ - openh264=2.1.1=h780b84a_0
+ - openjpeg=2.4.0=hb52868f_1
+ - openssl=1.1.1u=h7f8727e_0
+ - packaging=23.1=pyhd8ed1ab_0
+ - pip=23.0.1=py38h06a4308_0
+ - pixman-cos6-x86_64=0.32.8=4
+ - pysocks=1.7.1=pyha2e5f31_6
+ - python=3.8.16=h7a1cb2a_3
+ - python_abi=3.8=2_cp38
+ - pytorch=1.12.0=py3.8_cuda11.3_cudnn8.3.2_0
+ - pytorch-lightning=2.0.2=pyhd8ed1ab_0
+ - pytorch-mutex=1.0=cuda
+ - pyyaml=6.0=py38h0a891b7_4
+ - readline=8.2=h5eee18b_0
+ - requests=2.31.0=pyhd8ed1ab_0
+ - setuptools=67.8.0=py38h06a4308_0
+ - six=1.16.0=pyh6c4a22f_0
+ - sqlite=3.41.2=h5eee18b_0
+ - tk=8.6.12=h1ccaba5_0
+ - torchaudio=0.12.0=py38_cu113
+ - torchmetrics=0.11.4=pyhd8ed1ab_0
+ - torchvision=0.13.0=py38_cu113
+ - tqdm=4.65.0=pyhd8ed1ab_1
+ - typing_extensions=4.6.3=pyha770c72_0
+ - urllib3=2.0.2=pyhd8ed1ab_0
+ - wheel=0.38.4=py38h06a4308_0
+ - x264=1!161.3030=h7f98852_1
+ - xorg-x11-server-common-cos6-x86_64=1.17.4=4
+ - xorg-x11-server-xvfb-cos6-x86_64=1.17.4=4
+ - xz=5.4.2=h5eee18b_0
+ - yaml=0.2.5=h7f98852_2
+ - zlib=1.2.13=h5eee18b_0
+ - zstd=1.5.2=ha4553b6_0
+ - pip:
+ - antlr4-python3-runtime==4.9.3
+ - appdirs==1.4.4
+ - asttokens==2.4.0
+ - av==10.0.0
+ - backcall==0.2.0
+ - click==8.1.3
+ - git+https://github.com/openai/CLIP.git
+ - contourpy==1.1.1
+ - cycler==0.12.1
+ - decorator==5.1.1
+ - docker-pycreds==0.4.0
+ - einops==0.6.1
+ - executing==1.2.0
+ - filelock==3.12.2
+ - fonttools==4.43.1
+ - ftfy==6.1.1
+ - gitdb==4.0.10
+ - gitpython==3.1.31
+ - huggingface-hub==0.16.4
+ - imageio==2.31.0
+ - imageio-ffmpeg==0.4.8
+ - importlib-resources==6.1.0
+ - ipdb==0.13.13
+ - ipython==8.12.2
+ - jedi==0.19.0
+ - kiwisolver==1.4.5
+ - kornia==0.6.0
+ - lpips==0.1.4
+ - matplotlib==3.7.3
+ - matplotlib-inline==0.1.6
+ - omegaconf==2.3.0
+ - open-clip-torch==2.20.0
+ - opencv-python==4.7.0.72
+ - parso==0.8.3
+ - pathtools==0.1.2
+ - pexpect==4.8.0
+ - pickleshare==0.7.5
+ - pillow==9.5.0
+ - prompt-toolkit==3.0.39
+ - protobuf==3.20.3
+ - psutil==5.9.5
+ - ptyprocess==0.7.0
+ - pure-eval==0.2.2
+ - pygments==2.16.1
+ - pymcubes==0.1.4
+ - pyparsing==3.1.1
+ - pytorch-fid==0.3.0
+ - pytorch-msssim==1.0.0
+ - regex==2023.6.3
+ - safetensors==0.3.3
+ - scipy==1.10.1
+ - sentencepiece==0.1.99
+ - sentry-sdk==1.25.0
+ - setproctitle==1.3.2
+ - smmap==5.0.0
+ - stack-data==0.6.2
+ - timm==0.9.7
+ - tokenizers==0.12.1
+ - tomli==2.0.1
+ - traitlets==5.9.0
+ - transformers
+ - trimesh==4.0.2
+ - vit-pytorch==1.2.2
+ - wandb==0.15.3
+ - wcwidth==0.2.6
+ - zipp==3.17.0
diff --git a/3DTopia/gradio_demo.py b/3DTopia/gradio_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..f02e3354ae21a91699428ffa7e7ed454d1c5832b
--- /dev/null
+++ b/3DTopia/gradio_demo.py
@@ -0,0 +1,334 @@
+import os
+import cv2
+import time
+import json
+import torch
+import mcubes
+import trimesh
+import datetime
+import argparse
+import subprocess
+import numpy as np
+import gradio as gr
+from tqdm import tqdm
+import imageio.v2 as imageio
+import pytorch_lightning as pl
+from omegaconf import OmegaConf
+
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.dpm_solver import DPMSolverSampler
+
+from utility.initialize import instantiate_from_config, get_obj_from_str
+from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes
+from utility.triplane_renderer.renderer import get_rays, to8b
+from safetensors.torch import load_file
+from huggingface_hub import hf_hub_download
+
+import warnings
+warnings.filterwarnings("ignore", category=UserWarning)
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+def add_text(rgb, caption):
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ # org
+ gap = 10
+ org = (gap, gap)
+ # fontScale
+ fontScale = 0.3
+ # Blue color in BGR
+ color = (255, 0, 0)
+ # Line thickness of 2 px
+ thickness = 1
+ break_caption = []
+ for i in range(len(caption) // 30 + 1):
+ break_caption_i = caption[i*30:(i+1)*30]
+ break_caption.append(break_caption_i)
+ for i, bci in enumerate(break_caption):
+ cv2.putText(rgb, bci, (gap, gap*(i+1)), font, fontScale, color, thickness, cv2.LINE_AA)
+ return rgb
+
+config = "configs/default.yaml"
+# ckpt = "checkpoints/3dtopia_diffusion_state_dict.ckpt"
+ckpt = hf_hub_download(repo_id="hongfz16/3DTopia", filename="model.safetensors")
+configs = OmegaConf.load(config)
+os.makedirs("tmp", exist_ok=True)
+
+if ckpt.endswith(".ckpt"):
+ model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(ckpt, map_location='cpu', strict=False, **configs.model.params)
+elif ckpt.endswith(".safetensors"):
+ model = get_obj_from_str(configs.model["target"])(**configs.model.params)
+ model_ckpt = load_file(ckpt)
+ model.load_state_dict(model_ckpt)
+else:
+ raise NotImplementedError
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+model = model.to(device)
+sampler = DDIMSampler(model)
+
+img_size = configs.model.params.unet_config.params.image_size
+channels = configs.model.params.unet_config.params.in_channels
+shape = [channels, img_size, img_size * 3]
+
+pose_folder = 'assets/sample_data/pose'
+poses_fname = sorted([os.path.join(pose_folder, f) for f in os.listdir(pose_folder)])
+batch_rays_list = []
+H = 128
+ratio = 512 // H
+for p in poses_fname:
+ c2w = np.loadtxt(p).reshape(4, 4)
+ c2w[:3, 3] *= 2.2
+ c2w = np.array([
+ [1, 0, 0, 0],
+ [0, 0, -1, 0],
+ [0, 1, 0, 0],
+ [0, 0, 0, 1]
+ ]) @ c2w
+
+ k = np.array([
+ [560 / ratio, 0, H * 0.5],
+ [0, 560 / ratio, H * 0.5],
+ [0, 0, 1]
+ ])
+
+ rays_o, rays_d = get_rays(H, H, torch.Tensor(k), torch.Tensor(c2w[:3, :4]))
+ coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, H-1, H), indexing='ij'), -1)
+ coords = torch.reshape(coords, [-1,2]).long()
+ rays_o = rays_o[coords[:, 0], coords[:, 1]]
+ rays_d = rays_d[coords[:, 0], coords[:, 1]]
+ batch_rays = torch.stack([rays_o, rays_d], 0)
+ batch_rays_list.append(batch_rays)
+batch_rays_list = torch.stack(batch_rays_list, 0)
+
+def marching_cube(b, text, global_info):
+ # prepare volumn for marching cube
+ res = 128
+ assert 'decode_res' in global_info
+ decode_res = global_info['decode_res']
+ c_list = torch.linspace(-1.2, 1.2, steps=res)
+ grid_x, grid_y, grid_z = torch.meshgrid(
+ c_list, c_list, c_list, indexing='ij'
+ )
+ coords = torch.stack([grid_x, grid_y, grid_z], -1).to(device)
+ plane_axes = generate_planes()
+ feats = sample_from_planes(
+ plane_axes, decode_res[b:b+1].reshape(1, 3, -1, 256, 256), coords.reshape(1, -1, 3), padding_mode='zeros', box_warp=2.4
+ )
+ fake_dirs = torch.zeros_like(coords)
+ fake_dirs[..., 0] = 1
+ out = model.first_stage_model.triplane_decoder.decoder(feats, fake_dirs)
+ u = out['sigma'].reshape(res, res, res).detach().cpu().numpy()
+ del out
+
+ # marching cube
+ vertices, triangles = mcubes.marching_cubes(u, 10)
+ min_bound = np.array([-1.2, -1.2, -1.2])
+ max_bound = np.array([1.2, 1.2, 1.2])
+ vertices = vertices / (res - 1) * (max_bound - min_bound)[None, :] + min_bound[None, :]
+ pt_vertices = torch.from_numpy(vertices).to(device)
+
+ # extract vertices color
+ res_triplane = 256
+ render_kwargs = {
+ 'depth_resolution': 128,
+ 'disparity_space_sampling': False,
+ 'box_warp': 2.4,
+ 'depth_resolution_importance': 128,
+ 'clamp_mode': 'softplus',
+ 'white_back': True,
+ 'det': True
+ }
+ rays_o_list = [
+ np.array([0, 0, 2]),
+ np.array([0, 0, -2]),
+ np.array([0, 2, 0]),
+ np.array([0, -2, 0]),
+ np.array([2, 0, 0]),
+ np.array([-2, 0, 0]),
+ ]
+ rgb_final = None
+ diff_final = None
+ for rays_o in tqdm(rays_o_list):
+ rays_o = torch.from_numpy(rays_o.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device)
+ rays_d = pt_vertices.reshape(-1, 3) - rays_o
+ rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1)
+ dist = torch.norm(pt_vertices.reshape(-1, 3) - rays_o, dim=-1).cpu().numpy().reshape(-1)
+
+ render_out = model.first_stage_model.triplane_decoder(
+ decode_res[b:b+1].reshape(1, 3, -1, res_triplane, res_triplane),
+ rays_o.unsqueeze(0), rays_d.unsqueeze(0), render_kwargs,
+ whole_img=False, tvloss=False
+ )
+ rgb = render_out['rgb_marched'].reshape(-1, 3).detach().cpu().numpy()
+ depth = render_out['depth_final'].reshape(-1).detach().cpu().numpy()
+ depth_diff = np.abs(dist - depth)
+
+ if rgb_final is None:
+ rgb_final = rgb.copy()
+ diff_final = depth_diff.copy()
+
+ else:
+ ind = diff_final > depth_diff
+ rgb_final[ind] = rgb[ind]
+ diff_final[ind] = depth_diff[ind]
+
+ # bgr to rgb
+ rgb_final = np.stack([
+ rgb_final[:, 2], rgb_final[:, 1], rgb_final[:, 0]
+ ], -1)
+
+ # export to ply
+ mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=(rgb_final * 255).astype(np.uint8))
+ path = os.path.join('tmp', f"{text.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.ply")
+ trimesh.exchange.export.export_mesh(mesh, path, file_type='ply')
+
+ del vertices, triangles, rgb_final
+ torch.cuda.empty_cache()
+
+ return path
+
+def infer(prompt, samples, steps, scale, seed, global_info):
+ prompt = prompt.replace('/', '')
+ pl.seed_everything(seed)
+ batch_size = samples
+ with torch.no_grad():
+ noise = None
+ c = model.get_learned_conditioning([prompt])
+ unconditional_c = torch.zeros_like(c)
+ sample, _ = sampler.sample(
+ S=steps,
+ batch_size=batch_size,
+ shape=shape,
+ verbose=False,
+ x_T = noise,
+ conditioning = c.repeat(batch_size, 1, 1),
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=unconditional_c.repeat(batch_size, 1, 1)
+ )
+ decode_res = model.decode_first_stage(sample)
+
+ big_video_list = []
+
+ global_info['decode_res'] = decode_res
+
+ for b in range(batch_size):
+ def render_img(v):
+ rgb_sample, _ = model.first_stage_model.render_triplane_eg3d_decoder(
+ decode_res[b:b+1], batch_rays_list[v:v+1].to(device), torch.zeros(1, H, H, 3).to(device),
+ )
+ rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0]
+ rgb_sample = np.stack(
+ [rgb_sample[..., 2], rgb_sample[..., 1], rgb_sample[..., 0]], -1
+ )
+ rgb_sample = add_text(rgb_sample, str(b))
+ return rgb_sample
+
+ view_num = len(batch_rays_list)
+ video_list = []
+ for v in tqdm(range(view_num//8*3, view_num//8*5, 2)):
+ rgb_sample = render_img(v)
+ video_list.append(rgb_sample)
+ big_video_list.append(video_list)
+ # if batch_size == 2:
+ # cat_video_list = [
+ # np.concatenate([big_video_list[j][i] for j in range(len(big_video_list))], 1) \
+ # for i in range(len(big_video_list[0]))
+ # ]
+ # elif batch_size > 2:
+ # if batch_size == 3:
+ # big_video_list.append(
+ # [np.zeros_like(f) for f in big_video_list[0]]
+ # )
+ # cat_video_list = [
+ # np.concatenate([
+ # np.concatenate([big_video_list[0][i], big_video_list[1][i]], 1),
+ # np.concatenate([big_video_list[2][i], big_video_list[3][i]], 1),
+ # ], 0) \
+ # for i in range(len(big_video_list[0]))
+ # ]
+ # else:
+ # cat_video_list = big_video_list[0]
+
+ for _ in range(4 - batch_size):
+ big_video_list.append(
+ [np.zeros_like(f) + 255 for f in big_video_list[0]]
+ )
+ cat_video_list = [
+ np.concatenate([
+ np.concatenate([big_video_list[0][i], big_video_list[1][i]], 1),
+ np.concatenate([big_video_list[2][i], big_video_list[3][i]], 1),
+ ], 0) \
+ for i in range(len(big_video_list[0]))
+ ]
+
+ path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4"
+ imageio.mimwrite(path, np.stack(cat_video_list, 0))
+
+ return global_info, path
+
+def infer_stage2(prompt, selection, seed, global_info):
+ prompt = prompt.replace('/', '')
+ mesh_path = marching_cube(int(selection), prompt, global_info)
+ mesh_name = mesh_path.split('/')[-1][:-4]
+
+ if2_cmd = f"threefiner if2 --mesh {mesh_path} --prompt \"{prompt}\" --outdir tmp --save {mesh_name}_if2.glb --text_dir --front_dir=-y"
+ print(if2_cmd)
+ # os.system(if2_cmd)
+ subprocess.Popen(if2_cmd, shell=True).wait()
+ torch.cuda.empty_cache()
+
+ video_path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4"
+ render_cmd = f"kire {os.path.join('tmp', mesh_name + '_if2.glb')} --save_video {video_path} --wogui --force_cuda_rast --H 256 --W 256"
+ print(render_cmd)
+ # os.system(render_cmd)
+ subprocess.Popen(render_cmd, shell=True).wait()
+ torch.cuda.empty_cache()
+
+ return video_path, os.path.join('tmp', mesh_name + '_if2.glb')
+
+block = gr.Blocks()
+
+with block:
+ global_info = gr.State(dict())
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ text = gr.Textbox(
+ label = "Enter your prompt",
+ max_lines = 1,
+ placeholder = "Enter your prompt",
+ container = False,
+ )
+ btn = gr.Button("Generate 3D")
+ gallery = gr.Video(height=512)
+ advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
+ with gr.Row(elem_id="advanced-options"):
+ samples = gr.Slider(label="Number of Samples", minimum=1, maximum=4, value=4, step=1)
+ steps = gr.Slider(label="Steps", minimum=1, maximum=500, value=50, step=1)
+ scale = gr.Slider(
+ label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
+ )
+ seed = gr.Slider(
+ label="Seed",
+ minimum=0,
+ maximum=2147483647,
+ step=1,
+ randomize=True,
+ )
+ gr.on([text.submit, btn.click], infer, inputs=[text, samples, steps, scale, seed, global_info], outputs=[global_info, gallery])
+ advanced_button.click(
+ None,
+ [],
+ text,
+ )
+ with gr.Column():
+ with gr.Row():
+ dropdown = gr.Dropdown(
+ ['0', '1', '2', '3'], label="Choose a Candidate For Stage2", value='0'
+ )
+ btn_stage2 = gr.Button("Start Refinement")
+ gallery = gr.Video(height=512)
+ download = gr.File(label="Download Mesh", file_count="single", height=100)
+ gr.on([btn_stage2.click], infer_stage2, inputs=[text, dropdown, seed, global_info], outputs=[gallery, download])
+
+block.launch(share=True)
diff --git a/3DTopia/ldm/data/__init__.py b/3DTopia/ldm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/3DTopia/ldm/data/base.py b/3DTopia/ldm/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..b196c2f7aa583a3e8bc4aad9f943df0c4dae0da7
--- /dev/null
+++ b/3DTopia/ldm/data/base.py
@@ -0,0 +1,23 @@
+from abc import abstractmethod
+from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
+
+
+class Txt2ImgIterableBaseDataset(IterableDataset):
+ '''
+ Define an interface to make the IterableDatasets for text2img data chainable
+ '''
+ def __init__(self, num_records=0, valid_ids=None, size=256):
+ super().__init__()
+ self.num_records = num_records
+ self.valid_ids = valid_ids
+ self.sample_ids = valid_ids
+ self.size = size
+
+ print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
+
+ def __len__(self):
+ return self.num_records
+
+ @abstractmethod
+ def __iter__(self):
+ pass
\ No newline at end of file
diff --git a/3DTopia/ldm/data/imagenet.py b/3DTopia/ldm/data/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c473f9c6965b22315dbb289eff8247c71bdc790
--- /dev/null
+++ b/3DTopia/ldm/data/imagenet.py
@@ -0,0 +1,394 @@
+import os, yaml, pickle, shutil, tarfile, glob
+import cv2
+import albumentations
+import PIL
+import numpy as np
+import torchvision.transforms.functional as TF
+from omegaconf import OmegaConf
+from functools import partial
+from PIL import Image
+from tqdm import tqdm
+from torch.utils.data import Dataset, Subset
+
+import taming.data.utils as tdu
+from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
+from taming.data.imagenet import ImagePaths
+
+from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
+
+
+def synset2idx(path_to_yaml="data/index_synset.yaml"):
+ with open(path_to_yaml) as f:
+ di2s = yaml.load(f)
+ return dict((v,k) for k,v in di2s.items())
+
+
+class ImageNetBase(Dataset):
+ def __init__(self, config=None):
+ self.config = config or OmegaConf.create()
+ if not type(self.config)==dict:
+ self.config = OmegaConf.to_container(self.config)
+ self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
+ self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
+ self._prepare()
+ self._prepare_synset_to_human()
+ self._prepare_idx_to_synset()
+ self._prepare_human_to_integer_label()
+ self._load()
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ return self.data[i]
+
+ def _prepare(self):
+ raise NotImplementedError()
+
+ def _filter_relpaths(self, relpaths):
+ ignore = set([
+ "n06596364_9591.JPEG",
+ ])
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
+ if "sub_indices" in self.config:
+ indices = str_to_indices(self.config["sub_indices"])
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
+ self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
+ files = []
+ for rpath in relpaths:
+ syn = rpath.split("/")[0]
+ if syn in synsets:
+ files.append(rpath)
+ return files
+ else:
+ return relpaths
+
+ def _prepare_synset_to_human(self):
+ SIZE = 2655750
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
+ if (not os.path.exists(self.human_dict) or
+ not os.path.getsize(self.human_dict)==SIZE):
+ download(URL, self.human_dict)
+
+ def _prepare_idx_to_synset(self):
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
+ if (not os.path.exists(self.idx2syn)):
+ download(URL, self.idx2syn)
+
+ def _prepare_human_to_integer_label(self):
+ URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
+ self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
+ if (not os.path.exists(self.human2integer)):
+ download(URL, self.human2integer)
+ with open(self.human2integer, "r") as f:
+ lines = f.read().splitlines()
+ assert len(lines) == 1000
+ self.human2integer_dict = dict()
+ for line in lines:
+ value, key = line.split(":")
+ self.human2integer_dict[key] = int(value)
+
+ def _load(self):
+ with open(self.txt_filelist, "r") as f:
+ self.relpaths = f.read().splitlines()
+ l1 = len(self.relpaths)
+ self.relpaths = self._filter_relpaths(self.relpaths)
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
+
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
+
+ unique_synsets = np.unique(self.synsets)
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
+ if not self.keep_orig_class_label:
+ self.class_labels = [class_dict[s] for s in self.synsets]
+ else:
+ self.class_labels = [self.synset2idx[s] for s in self.synsets]
+
+ with open(self.human_dict, "r") as f:
+ human_dict = f.read().splitlines()
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
+
+ self.human_labels = [human_dict[s] for s in self.synsets]
+
+ labels = {
+ "relpath": np.array(self.relpaths),
+ "synsets": np.array(self.synsets),
+ "class_label": np.array(self.class_labels),
+ "human_label": np.array(self.human_labels),
+ }
+
+ if self.process_images:
+ self.size = retrieve(self.config, "size", default=256)
+ self.data = ImagePaths(self.abspaths,
+ labels=labels,
+ size=self.size,
+ random_crop=self.random_crop,
+ )
+ else:
+ self.data = self.abspaths
+
+
+class ImageNetTrain(ImageNetBase):
+ NAME = "ILSVRC2012_train"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
+ FILES = [
+ "ILSVRC2012_img_train.tar",
+ ]
+ SIZES = [
+ 147897477120,
+ ]
+
+ def __init__(self, process_images=True, data_root=None, **kwargs):
+ self.process_images = process_images
+ self.data_root = data_root
+ super().__init__(**kwargs)
+
+ def _prepare(self):
+ if self.data_root:
+ self.root = os.path.join(self.data_root, self.NAME)
+ else:
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 1281167
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
+ default=True)
+ if not tdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ print("Extracting sub-tars.")
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
+ for subpath in tqdm(subpaths):
+ subdir = subpath[:-len(".tar")]
+ os.makedirs(subdir, exist_ok=True)
+ with tarfile.open(subpath, "r:") as tar:
+ tar.extractall(path=subdir)
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ tdu.mark_prepared(self.root)
+
+
+class ImageNetValidation(ImageNetBase):
+ NAME = "ILSVRC2012_validation"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
+ FILES = [
+ "ILSVRC2012_img_val.tar",
+ "validation_synset.txt",
+ ]
+ SIZES = [
+ 6744924160,
+ 1950000,
+ ]
+
+ def __init__(self, process_images=True, data_root=None, **kwargs):
+ self.data_root = data_root
+ self.process_images = process_images
+ super().__init__(**kwargs)
+
+ def _prepare(self):
+ if self.data_root:
+ self.root = os.path.join(self.data_root, self.NAME)
+ else:
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 50000
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
+ default=False)
+ if not tdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ vspath = os.path.join(self.root, self.FILES[1])
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
+ download(self.VS_URL, vspath)
+
+ with open(vspath, "r") as f:
+ synset_dict = f.read().splitlines()
+ synset_dict = dict(line.split() for line in synset_dict)
+
+ print("Reorganizing into synset folders")
+ synsets = np.unique(list(synset_dict.values()))
+ for s in synsets:
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
+ for k, v in synset_dict.items():
+ src = os.path.join(datadir, k)
+ dst = os.path.join(datadir, v)
+ shutil.move(src, dst)
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ tdu.mark_prepared(self.root)
+
+
+
+class ImageNetSR(Dataset):
+ def __init__(self, size=None,
+ degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
+ random_crop=True):
+ """
+ Imagenet Superresolution Dataloader
+ Performs following ops in order:
+ 1. crops a crop of size s from image either as random or center crop
+ 2. resizes crop to size with cv2.area_interpolation
+ 3. degrades resized crop with degradation_fn
+
+ :param size: resizing to size after cropping
+ :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
+ :param downscale_f: Low Resolution Downsample factor
+ :param min_crop_f: determines crop size s,
+ where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
+ :param max_crop_f: ""
+ :param data_root:
+ :param random_crop:
+ """
+ self.base = self.get_base()
+ assert size
+ assert (size / downscale_f).is_integer()
+ self.size = size
+ self.LR_size = int(size / downscale_f)
+ self.min_crop_f = min_crop_f
+ self.max_crop_f = max_crop_f
+ assert(max_crop_f <= 1.)
+ self.center_crop = not random_crop
+
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
+
+ self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
+
+ if degradation == "bsrgan":
+ self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
+
+ elif degradation == "bsrgan_light":
+ self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
+
+ else:
+ interpolation_fn = {
+ "cv_nearest": cv2.INTER_NEAREST,
+ "cv_bilinear": cv2.INTER_LINEAR,
+ "cv_bicubic": cv2.INTER_CUBIC,
+ "cv_area": cv2.INTER_AREA,
+ "cv_lanczos": cv2.INTER_LANCZOS4,
+ "pil_nearest": PIL.Image.NEAREST,
+ "pil_bilinear": PIL.Image.BILINEAR,
+ "pil_bicubic": PIL.Image.BICUBIC,
+ "pil_box": PIL.Image.BOX,
+ "pil_hamming": PIL.Image.HAMMING,
+ "pil_lanczos": PIL.Image.LANCZOS,
+ }[degradation]
+
+ self.pil_interpolation = degradation.startswith("pil_")
+
+ if self.pil_interpolation:
+ self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
+
+ else:
+ self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
+ interpolation=interpolation_fn)
+
+ def __len__(self):
+ return len(self.base)
+
+ def __getitem__(self, i):
+ example = self.base[i]
+ image = Image.open(example["file_path_"])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ image = np.array(image).astype(np.uint8)
+
+ min_side_len = min(image.shape[:2])
+ crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
+ crop_side_len = int(crop_side_len)
+
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
+
+ else:
+ self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
+
+ image = self.cropper(image=image)["image"]
+ image = self.image_rescaler(image=image)["image"]
+
+ if self.pil_interpolation:
+ image_pil = PIL.Image.fromarray(image)
+ LR_image = self.degradation_process(image_pil)
+ LR_image = np.array(LR_image).astype(np.uint8)
+
+ else:
+ LR_image = self.degradation_process(image=image)["image"]
+
+ example["image"] = (image/127.5 - 1.0).astype(np.float32)
+ example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
+
+ return example
+
+
+class ImageNetSRTrain(ImageNetSR):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def get_base(self):
+ with open("data/imagenet_train_hr_indices.p", "rb") as f:
+ indices = pickle.load(f)
+ dset = ImageNetTrain(process_images=False,)
+ return Subset(dset, indices)
+
+
+class ImageNetSRValidation(ImageNetSR):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def get_base(self):
+ with open("data/imagenet_val_hr_indices.p", "rb") as f:
+ indices = pickle.load(f)
+ dset = ImageNetValidation(process_images=False,)
+ return Subset(dset, indices)
diff --git a/3DTopia/ldm/data/lsun.py b/3DTopia/ldm/data/lsun.py
new file mode 100644
index 0000000000000000000000000000000000000000..6256e45715ff0b57c53f985594d27cbbbff0e68e
--- /dev/null
+++ b/3DTopia/ldm/data/lsun.py
@@ -0,0 +1,92 @@
+import os
+import numpy as np
+import PIL
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+
+class LSUNBase(Dataset):
+ def __init__(self,
+ txt_file,
+ data_root,
+ size=None,
+ interpolation="bicubic",
+ flip_p=0.5
+ ):
+ self.data_paths = txt_file
+ self.data_root = data_root
+ with open(self.data_paths, "r") as f:
+ self.image_paths = f.read().splitlines()
+ self._length = len(self.image_paths)
+ self.labels = {
+ "relative_file_path_": [l for l in self.image_paths],
+ "file_path_": [os.path.join(self.data_root, l)
+ for l in self.image_paths],
+ }
+
+ self.size = size
+ self.interpolation = {"linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ }[interpolation]
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = dict((k, self.labels[k][i]) for k in self.labels)
+ image = Image.open(example["file_path_"])
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+ crop = min(img.shape[0], img.shape[1])
+ h, w, = img.shape[0], img.shape[1]
+ img = img[(h - crop) // 2:(h + crop) // 2,
+ (w - crop) // 2:(w + crop) // 2]
+
+ image = Image.fromarray(img)
+ if self.size is not None:
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip(image)
+ image = np.array(image).astype(np.uint8)
+ example["image"] = (image / 127.5 - 1.0).astype(np.float32)
+ return example
+
+
+class LSUNChurchesTrain(LSUNBase):
+ def __init__(self, **kwargs):
+ super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
+
+
+class LSUNChurchesValidation(LSUNBase):
+ def __init__(self, flip_p=0., **kwargs):
+ super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
+ flip_p=flip_p, **kwargs)
+
+
+class LSUNBedroomsTrain(LSUNBase):
+ def __init__(self, **kwargs):
+ super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
+
+
+class LSUNBedroomsValidation(LSUNBase):
+ def __init__(self, flip_p=0.0, **kwargs):
+ super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
+ flip_p=flip_p, **kwargs)
+
+
+class LSUNCatsTrain(LSUNBase):
+ def __init__(self, **kwargs):
+ super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
+
+
+class LSUNCatsValidation(LSUNBase):
+ def __init__(self, flip_p=0., **kwargs):
+ super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
+ flip_p=flip_p, **kwargs)
diff --git a/3DTopia/ldm/lr_scheduler.py b/3DTopia/ldm/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..be39da9ca6dacc22bf3df9c7389bbb403a4a3ade
--- /dev/null
+++ b/3DTopia/ldm/lr_scheduler.py
@@ -0,0 +1,98 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n, **kwargs):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi))
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n,**kwargs)
+
+
+class LambdaWarmUpCosineScheduler2:
+ """
+ supports repeated iterations, configurable via lists
+ note: use with a base_lr of 1.0.
+ """
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
+ self.lr_warm_up_steps = warm_up_steps
+ self.f_start = f_start
+ self.f_min = f_min
+ self.f_max = f_max
+ self.cycle_lengths = cycle_lengths
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+ self.last_f = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def find_in_interval(self, n):
+ interval = 0
+ for cl in self.cum_cycles[1:]:
+ if n <= cl:
+ return interval
+ interval += 1
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}")
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
+ t = min(t, 1.0)
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
+ 1 + np.cos(t * np.pi))
+ self.last_f = f
+ return f
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}")
+
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
+ self.last_f = f
+ return f
+
diff --git a/3DTopia/ldm/models/autoencoder.py b/3DTopia/ldm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a9c4f45498561953b8085981609b2a3298a5473
--- /dev/null
+++ b/3DTopia/ldm/models/autoencoder.py
@@ -0,0 +1,443 @@
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+
+from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+from ldm.util import instantiate_from_config
+
+
+class VQModel(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ batch_resize_range=None,
+ scheduler_config=None,
+ lr_g_factor=1.0,
+ remap=None,
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
+ use_ema=False
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.n_embed = n_embed
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
+ remap=remap,
+ sane_index_shape=sane_index_shape)
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+ self.batch_resize_range = batch_resize_range
+ if self.batch_resize_range is not None:
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
+
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.scheduler_config = scheduler_config
+ self.lr_g_factor = lr_g_factor
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ print(f"Unexpected Keys: {unexpected}")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self)
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, emb_loss, info = self.quantize(h)
+ return quant, emb_loss, info
+
+ def encode_to_prequant(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode(self, quant):
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+ def decode_code(self, code_b):
+ quant_b = self.quantize.embed_code(code_b)
+ dec = self.decode(quant_b)
+ return dec
+
+ def forward(self, input, return_pred_indices=False):
+ quant, diff, (_,_,ind) = self.encode(input)
+ dec = self.decode(quant)
+ if return_pred_indices:
+ return dec, diff, ind
+ return dec, diff
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ if self.batch_resize_range is not None:
+ lower_size = self.batch_resize_range[0]
+ upper_size = self.batch_resize_range[1]
+ if self.global_step <= 4:
+ # do the first few batches with max size to avoid later oom
+ new_resize = upper_size
+ else:
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
+ if new_resize != x.shape[2]:
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
+ x = x.detach()
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ # https://github.com/pytorch/pytorch/issues/37142
+ # try not to fool the heuristics
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss, ind = self(x, return_pred_indices=True)
+
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train",
+ predicted_indices=ind)
+
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
+ return log_dict
+
+ def _validation_step(self, batch, batch_idx, suffix=""):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss, ind = self(x, return_pred_indices=True)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val"+suffix,
+ predicted_indices=ind
+ )
+
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val"+suffix,
+ predicted_indices=ind
+ )
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
+ self.log(f"val{suffix}/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log(f"val{suffix}/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
+ del log_dict_ae[f"val{suffix}/rec_loss"]
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr_d = self.learning_rate
+ lr_g = self.lr_g_factor*self.learning_rate
+ print("lr_d", lr_d)
+ print("lr_g", lr_g)
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr_g, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr_d, betas=(0.5, 0.9))
+
+ if self.scheduler_config is not None:
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ },
+ {
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ },
+ ]
+ return [opt_ae, opt_disc], scheduler
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if only_inputs:
+ log["inputs"] = x
+ return log
+ xrec, _ = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ if plot_ema:
+ with self.ema_scope():
+ xrec_ema, _ = self(x)
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
+ log["reconstructions_ema"] = xrec_ema
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class VQModelInterface(VQModel):
+ def __init__(self, embed_dim, *args, **kwargs):
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
+ self.embed_dim = embed_dim
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode(self, h, force_not_quantize=False):
+ # also go through quantization layer
+ if not force_not_quantize:
+ quant, emb_loss, info = self.quantize(h)
+ else:
+ quant = h
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+
+class AutoencoderKL(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ ):
+ super().__init__()
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def encode(self, x):
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+
+ if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # train the discriminator
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+ log["reconstructions"] = xrec
+ log["inputs"] = x
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+ def __init__(self, *args, vq_interface=False, **kwargs):
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
+ super().__init__()
+
+ def encode(self, x, *args, **kwargs):
+ return x
+
+ def decode(self, x, *args, **kwargs):
+ return x
+
+ def quantize(self, x, *args, **kwargs):
+ if self.vq_interface:
+ return x, None, [None, None, None]
+ return x
+
+ def forward(self, x, *args, **kwargs):
+ return x
diff --git a/3DTopia/ldm/models/diffusion/__init__.py b/3DTopia/ldm/models/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/3DTopia/ldm/models/diffusion/classifier.py b/3DTopia/ldm/models/diffusion/classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..67e98b9d8ffb96a150b517497ace0a242d7163ef
--- /dev/null
+++ b/3DTopia/ldm/models/diffusion/classifier.py
@@ -0,0 +1,267 @@
+import os
+import torch
+import pytorch_lightning as pl
+from omegaconf import OmegaConf
+from torch.nn import functional as F
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import LambdaLR
+from copy import deepcopy
+from einops import rearrange
+from glob import glob
+from natsort import natsorted
+
+from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
+from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
+
+__models__ = {
+ 'class_label': EncoderUNetModel,
+ 'segmentation': UNetModel
+}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class NoisyLatentImageClassifier(pl.LightningModule):
+
+ def __init__(self,
+ diffusion_path,
+ num_classes,
+ ckpt_path=None,
+ pool='attention',
+ label_key=None,
+ diffusion_ckpt_path=None,
+ scheduler_config=None,
+ weight_decay=1.e-2,
+ log_steps=10,
+ monitor='val/loss',
+ *args,
+ **kwargs):
+ super().__init__(*args, **kwargs)
+ self.num_classes = num_classes
+ # get latest config of diffusion model
+ diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
+ self.diffusion_config = OmegaConf.load(diffusion_config).model
+ self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
+ self.load_diffusion()
+
+ self.monitor = monitor
+ self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
+ self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
+ self.log_steps = log_steps
+
+ self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
+ else self.diffusion_model.cond_stage_key
+
+ assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
+
+ if self.label_key not in __models__:
+ raise NotImplementedError()
+
+ self.load_classifier(ckpt_path, pool)
+
+ self.scheduler_config = scheduler_config
+ self.use_scheduler = self.scheduler_config is not None
+ self.weight_decay = weight_decay
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def load_diffusion(self):
+ model = instantiate_from_config(self.diffusion_config)
+ self.diffusion_model = model.eval()
+ self.diffusion_model.train = disabled_train
+ for param in self.diffusion_model.parameters():
+ param.requires_grad = False
+
+ def load_classifier(self, ckpt_path, pool):
+ model_config = deepcopy(self.diffusion_config.params.unet_config.params)
+ model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
+ model_config.out_channels = self.num_classes
+ if self.label_key == 'class_label':
+ model_config.pool = pool
+
+ self.model = __models__[self.label_key](**model_config)
+ if ckpt_path is not None:
+ print('#####################################################################')
+ print(f'load from ckpt "{ckpt_path}"')
+ print('#####################################################################')
+ self.init_from_ckpt(ckpt_path)
+
+ @torch.no_grad()
+ def get_x_noisy(self, x, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x))
+ continuous_sqrt_alpha_cumprod = None
+ if self.diffusion_model.use_continuous_noise:
+ continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
+ # todo: make sure t+1 is correct here
+
+ return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
+ continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
+
+ def forward(self, x_noisy, t, *args, **kwargs):
+ return self.model(x_noisy, t)
+
+ @torch.no_grad()
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = rearrange(x, 'b h w c -> b c h w')
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ @torch.no_grad()
+ def get_conditioning(self, batch, k=None):
+ if k is None:
+ k = self.label_key
+ assert k is not None, 'Needs to provide label key'
+
+ targets = batch[k].to(self.device)
+
+ if self.label_key == 'segmentation':
+ targets = rearrange(targets, 'b h w c -> b c h w')
+ for down in range(self.numd):
+ h, w = targets.shape[-2:]
+ targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
+
+ # targets = rearrange(targets,'b c h w -> b h w c')
+
+ return targets
+
+ def compute_top_k(self, logits, labels, k, reduction="mean"):
+ _, top_ks = torch.topk(logits, k, dim=1)
+ if reduction == "mean":
+ return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
+ elif reduction == "none":
+ return (top_ks == labels[:, None]).float().sum(dim=-1)
+
+ def on_train_epoch_start(self):
+ # save some memory
+ self.diffusion_model.model.to('cpu')
+
+ @torch.no_grad()
+ def write_logs(self, loss, logits, targets):
+ log_prefix = 'train' if self.training else 'val'
+ log = {}
+ log[f"{log_prefix}/loss"] = loss.mean()
+ log[f"{log_prefix}/acc@1"] = self.compute_top_k(
+ logits, targets, k=1, reduction="mean"
+ )
+ log[f"{log_prefix}/acc@5"] = self.compute_top_k(
+ logits, targets, k=5, reduction="mean"
+ )
+
+ self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
+ self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
+ self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
+
+ def shared_step(self, batch, t=None):
+ x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
+ targets = self.get_conditioning(batch)
+ if targets.dim() == 4:
+ targets = targets.argmax(dim=1)
+ if t is None:
+ t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
+ else:
+ t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
+ x_noisy = self.get_x_noisy(x, t)
+ logits = self(x_noisy, t)
+
+ loss = F.cross_entropy(logits, targets, reduction='none')
+
+ self.write_logs(loss.detach(), logits.detach(), targets.detach())
+
+ loss = loss.mean()
+ return loss, logits, x_noisy, targets
+
+ def training_step(self, batch, batch_idx):
+ loss, *_ = self.shared_step(batch)
+ return loss
+
+ def reset_noise_accs(self):
+ self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
+ range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
+
+ def on_validation_start(self):
+ self.reset_noise_accs()
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ loss, *_ = self.shared_step(batch)
+
+ for t in self.noisy_acc:
+ _, logits, _, targets = self.shared_step(batch, t)
+ self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
+ self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
+
+ return loss
+
+ def configure_optimizers(self):
+ optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
+
+ if self.use_scheduler:
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [optimizer], scheduler
+
+ return optimizer
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, *args, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.diffusion_model.first_stage_key)
+ log['inputs'] = x
+
+ y = self.get_conditioning(batch)
+
+ if self.label_key == 'class_label':
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
+ log['labels'] = y
+
+ if ismap(y):
+ log['labels'] = self.diffusion_model.to_rgb(y)
+
+ for step in range(self.log_steps):
+ current_time = step * self.log_time_interval
+
+ _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
+
+ log[f'inputs@t{current_time}'] = x_noisy
+
+ pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
+ pred = rearrange(pred, 'b h w c -> b c h w')
+
+ log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
+
+ for key in log:
+ log[key] = log[key][:N]
+
+ return log
diff --git a/3DTopia/ldm/models/diffusion/ddim.py b/3DTopia/ldm/models/diffusion/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb31215db5c3f3f703f15987d7eee6a179c9f7ec
--- /dev/null
+++ b/3DTopia/ldm/models/diffusion/ddim.py
@@ -0,0 +1,241 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
+ extract_into_tensor
+
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ img, pred_x0 = outs
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+ @torch.no_grad()
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+ use_original_steps=False):
+
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ return x_dec
\ No newline at end of file
diff --git a/3DTopia/ldm/models/diffusion/ddpm.py b/3DTopia/ldm/models/diffusion/ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..79a84e0632a2303d6a863e27a72e55a34fa629bb
--- /dev/null
+++ b/3DTopia/ldm/models/diffusion/ddpm.py
@@ -0,0 +1,1746 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import os
+import wandb
+import torch
+import imageio
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager
+from functools import partial
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.rank_zero import rank_zero_only
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from module.model_2d import DiagonalGaussianDistribution
+from ldm.modules.distributions.distributions import normal_kl
+from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+from utility.triplane_renderer.renderer import to8b
+
+
+__conditioning_keys__ = {'concat': 'c_concat',
+ 'crossattn': 'c_crossattn',
+ 'adm': 'y'}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(pl.LightningModule):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(self,
+ unet_config,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.,
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.,
+ learning_rate=1e-4,
+ shift_scale=None,
+ ):
+ super().__init__()
+ assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
+ self.parameterization = parameterization
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.image_size = image_size # try conv?
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.beta_schedule = beta_schedule
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
+ count_params(self.model, verbose=True)
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, shift_scale=shift_scale)
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+
+ self.learning_rate = learning_rate
+
+
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, shift_scale=None):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s, shift_scale=shift_scale)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ print("Using timesteps of {}".format(self.num_timesteps))
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ # print("sqrt_alphas_cumprod", np.sqrt(alphas_cumprod))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+ 1. - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+ self.register_buffer('posterior_mean_coef1', to_torch(
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef2', to_torch(
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+ elif self.parameterization == "x0":
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+ elif self.parameterization == "v":
+ lvlb_weights = torch.ones_like(self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
+ else:
+ raise NotImplementedError("mu not supported")
+ # TODO how to choose this term
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+
+ def predict_start_from_z_and_v(self, x_t, t, v):
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+ )
+
+ def predict_eps_from_z_and_v(self, x_t, t, v):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, t, clip_denoised: bool):
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ elif self.parameterization == "v":
+ x_recon = self.predict_start_from_z_and_v(x, t, model_out)
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, return_intermediates=False):
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+ clip_denoised=self.clip_denoised)
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, return_intermediates=False):
+ image_size = self.image_size
+ channels = self.channels
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
+ return_intermediates=return_intermediates)
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def get_v(self, x, noise, t):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+ )
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+ else:
+ raise NotImplementedError("unknown loss type '{loss_type}'")
+
+ return loss
+
+ def p_losses(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_out = self.model(x_noisy, t)
+
+ loss_dict = {}
+ if self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
+
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+ log_prefix = 'train' if self.training else 'val'
+
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+ loss_simple = loss.mean() * self.l_simple_weight
+
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+ loss_dict.update({f'{log_prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def forward(self, x, *args, **kwargs):
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if isinstance(x, list):
+ return x
+ if len(x.shape) == 3:
+ x = x[..., None]
+ # x = rearrange(x, 'b h w c -> b c h w')
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def shared_step(self, batch):
+ x = self.get_input(batch, self.first_stage_key)
+ loss, loss_dict = self(x)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(loss_dict, prog_bar=False,
+ logger=True, on_step=True, on_epoch=True)
+
+ self.log("global_step", self.global_step,
+ prog_bar=False, logger=True, on_step=True, on_epoch=False)
+
+ if self.use_scheduler:
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+
+ return loss
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ _, loss_dict_ema = self.shared_step(batch)
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def _get_rows_from_list(self, samples):
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.first_stage_key)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+
+ # get diffusion row
+ diffusion_row = list()
+ x_start = x[:n_row]
+
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ diffusion_row.append(x_noisy)
+
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+ log["samples"] = samples
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.learn_logvar:
+ params = params + [self.logvar]
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+
+class LatentDiffusion(DDPM):
+ """main class"""
+ def __init__(self,
+ first_stage_config,
+ cond_stage_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_shift=0.0,
+ scale_by_std=False,
+ use_3daware=False,
+ *args, **kwargs):
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ if cond_stage_config == '__is_unconditional__':
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+ except:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ self.scale_shift = scale_shift
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ self.instantiate_cond_stage(cond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+
+ self.use_3daware = use_3daware
+
+ self.is_test = False
+
+ self.test_mode = None
+ self.test_tag = ""
+
+ def make_cond_schedule(self, ):
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+ self.cond_ids[:self.num_timesteps_cond] = ids
+
+ @rank_zero_only
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx):
+ # only for very first batch
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ print("### USING STD-RESCALING ###")
+
+ def register_schedule(self,
+ given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, shift_scale=None):
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s, shift_scale)
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ assert config != '__is_first_stage__'
+ assert config != '__is_unconditional__'
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
+ force_not_quantize=force_no_decoder_quantization))
+ n_imgs_per_row = len(denoise_row)
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.mode()
+ # z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+ return self.scale_factor * (z + self.scale_shift)
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c).float()
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def meshgrid(self, h, w):
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+ arr = torch.cat([y, x], dim=-1)
+ return arr
+
+ def delta_border(self, h, w):
+ """
+ :param h: height
+ :param w: width
+ :return: normalized distance to image border,
+ wtith min distance = 0 at border and max dist = 0.5 at image center
+ """
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+ arr = self.meshgrid(h, w) / lower_right_corner
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+ return edge_dist
+
+ def get_weighting(self, h, w, Ly, Lx, device):
+ weighting = self.delta_border(h, w)
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+ self.split_input_params["clip_max_weight"], )
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+ if self.split_input_params["tie_braker"]:
+ L_weighting = self.delta_border(Ly, Lx)
+ L_weighting = torch.clip(L_weighting,
+ self.split_input_params["clip_min_tie_weight"],
+ self.split_input_params["clip_max_tie_weight"])
+
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+ weighting = weighting * L_weighting
+ return weighting
+
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
+ """
+ :param x: img of size (bs, c, h, w)
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+ """
+ bs, nc, h, w = x.shape
+
+ # number of crops in image
+ Ly = (h - kernel_size[0]) // stride[0] + 1
+ Lx = (w - kernel_size[1]) // stride[1] + 1
+
+ if uf == 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+ elif uf > 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+ dilation=1, padding=0,
+ stride=(stride[0] * uf, stride[1] * uf))
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+ elif df > 1 and uf == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+ dilation=1, padding=0,
+ stride=(stride[0] // df, stride[1] // df))
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+ else:
+ raise NotImplementedError
+
+ return fold, unfold, normalization, weighting
+
+ @torch.no_grad()
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+ cond_key=None, return_original_cond=False, bs=None):
+ x = super().get_input(batch, k)
+ if bs is not None:
+ x = x[:bs]
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+ if self.model.conditioning_key is not None:
+ if cond_key is None:
+ cond_key = self.cond_stage_key
+ if cond_key != self.first_stage_key:
+ if cond_key in ['caption', 'coordinates_bbox']:
+ xc = batch[cond_key]
+ elif cond_key == 'class_label':
+ xc = batch
+ else:
+ xc = super().get_input(batch, cond_key).to(self.device)
+ else:
+ xc = x
+ if not self.cond_stage_trainable or force_c_encode:
+ if isinstance(xc, dict) or isinstance(xc, list):
+ # import pudb; pudb.set_trace()
+ c = self.get_learned_conditioning(xc)
+ else:
+ c = self.get_learned_conditioning(xc.to(self.device))
+ else:
+ c = xc
+ if bs is not None:
+ c = c[:bs]
+
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ ckey = __conditioning_keys__[self.model.conditioning_key]
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
+
+ else:
+ c = None
+ xc = None
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
+ out = [z, c]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([x, xrec])
+ if return_original_cond:
+ out.append(xc)
+
+ return out
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ # assert not predict_cids
+ # if predict_cids:
+ # if z.dim() == 4:
+ # z = torch.argmax(z.exp(), dim=1).long()
+ # z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ # z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ # import os
+ # import random
+ # import string
+ # z_np = z.detach().cpu().numpy()
+ # fname = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + '.npy'
+ # with open(os.path.join('/mnt/lustre/hongfangzhou.p/AE3D/tmp', fname), 'wb') as f:
+ # np.save(f, z_np)
+
+ z = 1. / self.scale_factor * z - self.scale_shift
+
+ # if hasattr(self, "split_input_params"):
+ # if self.split_input_params["patch_distributed_vq"]:
+ # ks = self.split_input_params["ks"] # eg. (128, 128)
+ # stride = self.split_input_params["stride"] # eg. (64, 64)
+ # uf = self.split_input_params["vqf"]
+ # bs, nc, h, w = z.shape
+ # if ks[0] > h or ks[1] > w:
+ # ks = (min(ks[0], h), min(ks[1], w))
+ # print("reducing Kernel")
+
+ # if stride[0] > h or stride[1] > w:
+ # stride = (min(stride[0], h), min(stride[1], w))
+ # print("reducing stride")
+
+ # fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+ # z = unfold(z) # (bn, nc * prod(**ks), L)
+ # # 1. Reshape to img shape
+ # z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # # 2. apply model loop over last dim
+ # if isinstance(self.first_stage_model, VQModelInterface):
+ # output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+ # force_not_quantize=predict_cids or force_not_quantize)
+ # for i in range(z.shape[-1])]
+ # else:
+
+ # output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+ # for i in range(z.shape[-1])]
+
+ # o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
+ # o = o * weighting
+ # # Reverse 1. reshape to img shape
+ # o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # # stitch crops together
+ # decoded = fold(o)
+ # decoded = decoded / normalization # norm is shape (1, 1, h, w)
+ # return decoded
+ # else:
+ # if isinstance(self.first_stage_model, VQModelInterface):
+ # return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ # else:
+ # return self.first_stage_model.decode(z)
+
+ # else:
+ # if isinstance(self.first_stage_model, VQModelInterface):
+ # return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ # else:
+ return self.first_stage_model.decode(z, unrollout=True)
+
+ # same as above but without decorator
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z - self.scale_shift
+
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ uf = self.split_input_params["vqf"]
+ bs, nc, h, w = z.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+ z = unfold(z) # (bn, nc * prod(**ks), L)
+ # 1. Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # 2. apply model loop over last dim
+ if isinstance(self.first_stage_model, VQModelInterface):
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+ force_not_quantize=predict_cids or force_not_quantize)
+ for i in range(z.shape[-1])]
+ else:
+
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
+ o = o * weighting
+ # Reverse 1. reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
+ return decoded
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ # if hasattr(self, "split_input_params"):
+ # if self.split_input_params["patch_distributed_vq"]:
+ # ks = self.split_input_params["ks"] # eg. (128, 128)
+ # stride = self.split_input_params["stride"] # eg. (64, 64)
+ # df = self.split_input_params["vqf"]
+ # self.split_input_params['original_image_size'] = x.shape[-2:]
+ # bs, nc, h, w = x.shape
+ # if ks[0] > h or ks[1] > w:
+ # ks = (min(ks[0], h), min(ks[1], w))
+ # print("reducing Kernel")
+
+ # if stride[0] > h or stride[1] > w:
+ # stride = (min(stride[0], h), min(stride[1], w))
+ # print("reducing stride")
+
+ # fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
+ # z = unfold(x) # (bn, nc * prod(**ks), L)
+ # # Reshape to img shape
+ # z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
+ # for i in range(z.shape[-1])]
+
+ # o = torch.stack(output_list, axis=-1)
+ # o = o * weighting
+
+ # # Reverse reshape to img shape
+ # o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # # stitch crops together
+ # decoded = fold(o)
+ # decoded = decoded / normalization
+ # return decoded
+
+ # else:
+ # return self.first_stage_model.encode(x)
+ # else:
+ return self.first_stage_model.encode(x, rollout=True)
+
+ def get_norm(self, x):
+ norm = torch.linalg.norm(x, dim=-1, keepdim=True)
+ norm[norm == 0] = 1
+
+ assert norm.shape[-1] == 1
+ assert norm.shape[0] == x.shape[0]
+ assert norm.shape[1] == x.shape[1]
+ assert x.shape[1] == 1
+
+ return norm
+
+ def random_text_feature_noise(self, c):
+ noise = torch.randn_like(c)
+ # alpha = 0.999
+ alpha = 1
+ nc = alpha * c / self.get_norm(c) + (1 - alpha) * noise / self.get_norm(noise)
+ nc = nc / self.get_norm(nc)
+
+ import random
+ if random.randint(0, 10) == 0:
+ nc[:] = 0
+ nc = c
+
+ return nc
+
+ def shared_step(self, batch, **kwargs):
+ x, c = self.get_input(batch, self.first_stage_key)
+ # print("Random augment text feature...")
+ c = self.random_text_feature_noise(c)
+ loss = self(x, c)
+ return loss
+
+ def forward(self, x, c=None, return_inter=False, *args, **kwargs):
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ if self.model.conditioning_key is not None:
+ assert c is not None
+ if self.cond_stage_trainable:
+ c = self.get_learned_conditioning(c)
+ if self.shorten_cond_schedule: # TODO: drop this option
+ tc = self.cond_ids[t].to(self.device)
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+ return self.p_losses(x, c, t, return_inter=return_inter, *args, **kwargs)
+
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
+ def rescale_bbox(bbox):
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
+ return x0, y0, w, h
+
+ return [rescale_bbox(b) for b in bboxes]
+
+ def to3daware(self, triplane):
+ res = triplane.shape[-2]
+ plane1 = triplane[..., :res]
+ plane2 = triplane[..., res:2*res]
+ plane3 = triplane[..., 2*res:3*res]
+
+ x_mp = torch.nn.AvgPool2d((res, 1))
+ y_mp = torch.nn.AvgPool2d((1, res))
+ x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
+ y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)
+ # for plane1
+ plane21 = x_mp_rep(plane2)
+ plane31 = torch.flip(y_mp_rep(plane3), (3,))
+ new_plane1 = torch.cat([plane1, plane21, plane31], 1)
+ # for plane2
+ plane12 = y_mp_rep(plane1)
+ plane32 = x_mp_rep(plane3)
+ new_plane2 = torch.cat([plane2, plane12, plane32], 1)
+ # for plane3
+ plane13 = torch.flip(x_mp_rep(plane1), (2,))
+ plane23 = y_mp_rep(plane2)
+ new_plane3 = torch.cat([plane3, plane13, plane23], 1)
+
+ new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
+ return new_plane
+
+ # B, C, H, W = h.shape
+ # h_xy = th.cat([h[..., 0:(W//3)], h[..., (W//3):(2*W//3)].mean(-1).unsqueeze(-1).repeat(1, 1, 1, W//3), h[..., (2*W//3):W].mean(-2).unsqueeze(-2).repeat(1, 1, H, 1)], 1)
+ # h_xz = th.cat([h[..., (W//3):(2*W//3)], h[..., 0:(W//3)].mean(-1).unsqueeze(-1).repeat(1, 1, 1, W//3), h[..., (2*W//3):W].mean(-1).unsqueeze(-1).repeat(1, 1, 1, W//3)], 1)
+ # h_zy = th.cat([h[..., (2*W//3):W], h[..., 0:(W//3)].mean(-2).unsqueeze(-2).repeat(1, 1, H, 1), h[..., (W//3):(2*W//3)].mean(-2).unsqueeze(-2).repeat(1, 1, H, 1)], 1)
+ # h = th.cat([h_xy, h_xz, h_zy], -1)
+
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
+
+ if isinstance(cond, dict):
+ # hybrid case, cond is exptected to be a dict
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+ cond = {key: cond}
+
+ if hasattr(self, "split_input_params"):
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
+ assert not return_ids
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+
+ h, w = x_noisy.shape[-2:]
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
+
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
+ # Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
+
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
+ c_key = next(iter(cond.keys())) # get key
+ c = next(iter(cond.values())) # get value
+ assert (len(c) == 1) # todo extend to list with more than one elem
+ c = c[0] # get element
+
+ c = unfold(c)
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
+
+ elif self.cond_stage_key == 'coordinates_bbox':
+ assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
+
+ # assuming padding of unfold is always 0 and its dilation is always 1
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
+ # as we are operating on latents, we need the factor from the original image size to the
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
+ rescale_latent = 2 ** (num_downs)
+
+ # get top left postions of patches as conforming for the bbbox tokenizer, therefore we
+ # need to rescale the tl patch coordinates to be in between (0,1)
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
+ for patch_nr in range(z.shape[-1])]
+
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
+ patch_limits = [(x_tl, y_tl,
+ rescale_latent * ks[0] / full_img_w,
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
+
+ # tokenize crop coordinates for the bounding boxes of the respective patches
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
+ print(patch_limits_tknzd[0].shape)
+ # cut tknzd crop position from conditioning
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
+ print(cut_cond.shape)
+
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
+ print(adapted_cond.shape)
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
+ print(adapted_cond.shape)
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
+ print(adapted_cond.shape)
+
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
+
+ else:
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
+
+ # apply model by loop over crops
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
+ assert not isinstance(output_list[0],
+ tuple) # todo cant deal with multiple model outputs check this never happens
+
+ o = torch.stack(output_list, axis=-1)
+ o = o * weighting
+ # Reverse reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ x_recon = fold(o) / normalization
+
+ else:
+ if self.use_3daware:
+ x_noisy_3daware = self.to3daware(x_noisy)
+ x_recon = self.model(x_noisy_3daware, t, **cond)
+ else:
+ x_recon = self.model(x_noisy, t, **cond)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def p_losses(self, x_start, cond, t, noise=None, return_inter=False):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_output = self.apply_model(x_noisy, t, cond)
+
+ loss_dict = {}
+ prefix = 'train' if self.training else 'val'
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError()
+
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+ logvar_t = self.logvar[t.to(self.logvar.device)].to(self.device)
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+ if self.learn_logvar:
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+ loss_dict.update({'logvar': self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+ loss += (self.original_elbo_weight * loss_vlb)
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ if return_inter:
+ return loss, loss_dict, self.predict_start_from_noise(x_noisy, t=t, noise=model_output)
+ else:
+ return loss, loss_dict
+
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ elif self.parameterization == "v":
+ x_recon = self.predict_start_from_z_and_v(x, t, model_out)
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_codebook_ids:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+ log_every_t=None):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+ if batch_size is not None:
+ b = batch_size if batch_size is not None else shape[0]
+ shape = [batch_size] + list(shape)
+ else:
+ b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+ total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+ if type(temperature) == float:
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised, return_x0=True,
+ temperature=temperature[i], noise_dropout=noise_dropout,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, start_T=None,
+ log_every_t=None):
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ # if self.is_test and i % 50 == 0:
+ # decode_res = self.decode_first_stage(img)
+ # rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder(
+ # decode_res, self.batch_rays, self.batch_img,
+ # )
+ # rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0]
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "sample_process_{}.png".format(i)), rgb_sample)
+ # colorize_res = self.first_stage_model.to_rgb(img)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "sample_process_latent_{}.png".format(i)), colorize_res[0])
+
+ img = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised)
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+ verbose=True, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, shape=None,**kwargs):
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size, self.image_size * 3)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+ return self.p_sample_loop(cond,
+ shape,
+ return_intermediates=return_intermediates, x_T=x_T,
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+ mask=mask, x0=x0)
+
+ @torch.no_grad()
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
+
+ if ddim:
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.channels, self.image_size, self.image_size)
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
+ shape,cond,verbose=False,**kwargs)
+
+ else:
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+ return_intermediates=True,**kwargs)
+
+ return samples, intermediates
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ # x, c = self.get_input(batch, self.first_stage_key)
+ # self.batch_rays = batch['batch_rays'][0][1:2]
+ # self.batch_img = batch['img'][0][1:2]
+ # self.is_test = True
+ # self.test_schedule(x[0:1])
+ # exit(0)
+
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ # _, loss_dict_ema = self.shared_step(batch)
+ x, c = self.get_input(batch, self.first_stage_key)
+ _, loss_dict_ema, inter_res = self(x, c, return_inter=True)
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+
+ if batch_idx < 2:
+ if self.num_timesteps < 1000:
+ x_T = self.q_sample(x_start=x[0:1], t=torch.full((1,), self.num_timesteps-1, device=x.device, dtype=torch.long), noise=torch.randn_like(x[0:1]))
+ print("Specifying x_T when sampling!")
+ else:
+ x_T = None
+ with self.ema_scope():
+ res = self.sample(c, 1, shape=x[0:1].shape, x_T = x_T)
+ decode_res = self.decode_first_stage(res)
+ decode_input = self.decode_first_stage(x[:1])
+ decode_output = self.decode_first_stage(inter_res[:1])
+
+ colorize_res = self.first_stage_model.to_rgb(res)[0]
+ colorize_x = self.first_stage_model.to_rgb(x[:1])[0]
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "sample_{}_{}.png".format(batch_idx, 0)), colorize_res[0])
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "gt_{}_{}.png".format(batch_idx, 0)), colorize_x[0])
+
+ rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder(
+ decode_res, batch['batch_rays'][0], batch['img'][0],
+ )
+ rgb_input, _ = self.first_stage_model.render_triplane_eg3d_decoder(
+ decode_input, batch['batch_rays'][0], batch['img'][0],
+ )
+ rgb_output, _ = self.first_stage_model.render_triplane_eg3d_decoder(
+ decode_output, batch['batch_rays'][0], batch['img'][0],
+ )
+ rgb_sample = to8b(rgb_sample.detach().cpu().numpy())
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_output = to8b(rgb_output.detach().cpu().numpy())
+
+ if rgb_sample.shape[0] == 1:
+ rgb_all = np.concatenate([rgb_sample[0], rgb_input[0], rgb_output[0]], 1)
+ else:
+ rgb_all = np.concatenate([rgb_sample[1], rgb_input[1], rgb_output[1]], 1)
+
+ rgb_all = np.stack([rgb_all[..., 2], rgb_all[..., 1], rgb_all[..., 0]], -1)
+
+ if self.model.conditioning_key is not None:
+ if self.cond_stage_key == 'img_cond':
+ cond_img = super().get_input(batch, self.cond_stage_key)[0].permute(1, 2, 0)
+ rgb_all = np.concatenate([rgb_all, to8b(cond_img.cpu().numpy())], 1)
+ else:
+ import cv2
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ # org
+ org = (50, 50)
+ # fontScale
+ fontScale = 1
+ # Blue color in BGR
+ color = (255, 0, 0)
+ # Line thickness of 2 px
+ thickness = 2
+ caption = super().get_input(batch, self.cond_stage_key)[0]
+ break_caption = []
+ for i in range(len(caption) // 30 + 1):
+ break_caption_i = caption[i*30:(i+1)*30]
+ break_caption.append(break_caption_i)
+ for i, bci in enumerate(break_caption):
+ cv2.putText(rgb_all, bci, (50, 50*(i+1)), font, fontScale, color, thickness, cv2.LINE_AA)
+
+ self.logger.experiment.log({
+ "val/vis": [wandb.Image(rgb_all)],
+ "val/colorize_rse": [wandb.Image(colorize_res)],
+ "val/colorize_x": [wandb.Image(colorize_x)],
+ })
+
+ @torch.no_grad()
+ def test_schedule(self, x_start, freq=50):
+ noise = torch.randn_like(x_start)
+ img_list = []
+ latent_list = []
+ for t in tqdm(range(self.num_timesteps)):
+ if t % freq == 0:
+ t_long = torch.Tensor([t,]).long().to(x_start.device)
+ x_noisy = self.q_sample(x_start=x_start, t=t_long, noise=noise)
+ decode_res = self.decode_first_stage(x_noisy)
+ rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder(
+ decode_res, self.batch_rays, self.batch_img,
+ )
+ rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0]
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_{}.png".format(t)), rgb_sample)
+ colorize_res = self.first_stage_model.to_rgb(x_noisy)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_latent_{}.png".format(t)), colorize_res[0])
+ img_list.append(rgb_sample)
+ latent_list.append(colorize_res[0])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_{}_{}_{}_{}.png".format(self.linear_start, self.linear_end, self.beta_schedule, self.scale_factor)), np.concatenate(img_list, 1))
+ imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_latent_{}_{}_{}_{}.png".format(self.linear_start, self.linear_end, self.beta_schedule, self.scale_factor)), np.concatenate(latent_list, 1))
+
+ @torch.no_grad()
+ def test_step(self, batch, batch_idx):
+ x, c = self.get_input(batch, self.first_stage_key)
+ if self.test_mode == 'fid':
+ bs = x.shape[0]
+ else:
+ bs = 1
+ if self.test_mode == 'noise_schedule':
+ self.batch_rays = batch['batch_rays'][0][33:34]
+ self.batch_img = batch['img'][0][33:34]
+ self.is_test = True
+ self.test_schedule(x)
+ exit(0)
+ with self.ema_scope():
+ if c is not None:
+ res = self.sample(c[:bs], bs, shape=x[0:bs].shape)
+ else:
+ res = self.sample(None, bs, shape=x[0:bs].shape)
+ decode_res = self.decode_first_stage(res)
+ if self.test_mode == 'fid':
+ folder = os.path.join(self.logger.log_dir, 'FID_' + self.test_tag)
+ if not os.path.exists(folder):
+ os.makedirs(folder, exist_ok=True)
+ rgb_sample_list = []
+ for b in range(bs):
+ rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder(
+ decode_res[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb_sample = to8b(rgb_sample.detach().cpu().numpy())
+ rgb_sample_list.append(rgb_sample)
+ for i in range(len(rgb_sample_list)):
+ for v in range(rgb_sample_list[i].shape[0]):
+ imageio.imwrite(os.path.join(folder, "sample_{}_{}_{}.png".format(batch_idx, i, v)), rgb_sample_list[i][v])
+ elif self.test_mode == 'sample':
+ colorize_res = self.first_stage_model.to_rgb(res)
+ colorize_x = self.first_stage_model.to_rgb(x[:1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "sample_{}_{}.png".format(batch_idx, 0)), colorize_res[0])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "gt_{}_{}.png".format(batch_idx, 0)), colorize_x[0])
+ if self.model.conditioning_key is not None:
+ cond_img = super().get_input(batch, self.cond_stage_key)[0].permute(1, 2, 0)
+ cond_img = to8b(cond_img.cpu().numpy())
+ imageio.imwrite(os.path.join(self.logger.log_dir, "cond_{}_{}.png".format(batch_idx, 0)), cond_img)
+ for b in range(bs):
+ video = []
+ for v in tqdm(range(batch['batch_rays'].shape[1])):
+ rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder(
+ decode_res[b:b+1], batch['batch_rays'][0][v:v+1], batch['img'][0][v:v+1],
+ )
+ rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0]
+ video.append(rgb_sample)
+ imageio.mimwrite(os.path.join(self.logger.log_dir, "sample_{}_{}.mp4".format(batch_idx, b)), video, fps=24)
+ print("Saving to {}".format(os.path.join(self.logger.log_dir, "sample_{}_{}.mp4".format(batch_idx, b))))
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, **kwargs):
+
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=N)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
+ log["conditioning"] = xc
+ elif self.cond_stage_key == 'class_label':
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
+ ddim_steps=ddim_steps,eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+ self.first_stage_model, IdentityFirstStage):
+ # also display when quantizing x0 while sampling
+ with self.ema_scope("Plotting Quantized Denoised"):
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
+ ddim_steps=ddim_steps,eta=ddim_eta,
+ quantize_denoised=True)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+ # quantize_denoised=True)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_x0_quantized"] = x_samples
+
+ if inpaint:
+ # make a simple center square
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ mask = torch.ones(N, h, w).to(self.device)
+ # zeros will be filled in
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+ mask = mask[:, None, ...]
+ with self.ema_scope("Plotting Inpaint"):
+
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_inpainting"] = x_samples
+ log["mask"] = mask
+
+ # outpaint
+ with self.ema_scope("Plotting Outpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_outpainting"] = x_samples
+
+ if plot_progressive_rows:
+ with self.ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.cond_stage_trainable:
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+ params = params + list(self.cond_stage_model.parameters())
+ if self.learn_logvar:
+ print('Diffusion model optimizing logvar')
+ params.append(self.logvar)
+ opt = torch.optim.AdamW(params, lr=lr)
+ if self.use_scheduler:
+ assert 'target' in self.scheduler_config
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def to_rgb(self, x):
+ x = x.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = nn.functional.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+
+
+class DiffusionWrapper(pl.LightningModule):
+ def __init__(self, diff_model_config, conditioning_key):
+ super().__init__()
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+ self.conditioning_key = conditioning_key
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
+
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
+ if self.conditioning_key is None:
+ out = self.diffusion_model(x, t)
+ elif self.conditioning_key == 'concat':
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t)
+ elif self.conditioning_key == 'crossattn':
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc)
+ elif self.conditioning_key == 'hybrid':
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc)
+ elif self.conditioning_key == 'adm':
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, y=cc)
+ else:
+ raise NotImplementedError()
+
+ return out
+
+
+class Layout2ImgDiffusion(LatentDiffusion):
+ # TODO: move all layout-specific hacks to this class
+ def __init__(self, cond_stage_key, *args, **kwargs):
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
+ super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+
+ def log_images(self, batch, N=8, *args, **kwargs):
+ logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+
+ key = 'train' if self.training else 'validation'
+ dset = self.trainer.datamodule.datasets[key]
+ mapper = dset.conditional_builders[self.cond_stage_key]
+
+ bbox_imgs = []
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
+ bbox_imgs.append(bboximg)
+
+ cond_img = torch.stack(bbox_imgs, dim=0)
+ logs['bbox_image'] = cond_img
+ return logs
diff --git a/3DTopia/ldm/models/diffusion/ddpm_preprocess.py b/3DTopia/ldm/models/diffusion/ddpm_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..74926ca5dcfd7e4b1f283e7cd03de42ced5e74ee
--- /dev/null
+++ b/3DTopia/ldm/models/diffusion/ddpm_preprocess.py
@@ -0,0 +1,1716 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import os
+import wandb
+import torch
+import imageio
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager
+from functools import partial
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.rank_zero import rank_zero_only
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from module.model_2d import DiagonalGaussianDistribution
+from ldm.modules.distributions.distributions import normal_kl
+from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+from utility.triplane_renderer.renderer import to8b
+
+import ipdb
+__conditioning_keys__ = {'concat': 'c_concat',
+ 'crossattn': 'c_crossattn',
+ 'adm': 'y'}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(pl.LightningModule):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(self,
+ unet_config,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.,
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.,
+ learning_rate=1e-4,
+ shift_scale=None,
+ ):
+ super().__init__()
+ assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
+ self.parameterization = parameterization
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.image_size = image_size # try conv?
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.beta_schedule = beta_schedule
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
+ count_params(self.model, verbose=True)
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, shift_scale=shift_scale)
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+
+ self.learning_rate = learning_rate
+
+
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, shift_scale=None):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s, shift_scale=shift_scale)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ print("Using timesteps of {}".format(self.num_timesteps))
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ # print("sqrt_alphas_cumprod", np.sqrt(alphas_cumprod))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+ 1. - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+ self.register_buffer('posterior_mean_coef1', to_torch(
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef2', to_torch(
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+ elif self.parameterization == "x0":
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+ elif self.parameterization == "v":
+ lvlb_weights = torch.ones_like(self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
+ else:
+ raise NotImplementedError("mu not supported")
+ # TODO how to choose this term
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+
+ def predict_start_from_z_and_v(self, x_t, t, v):
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+ )
+
+ def predict_eps_from_z_and_v(self, x_t, t, v):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, t, clip_denoised: bool):
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ elif self.parameterization == "v":
+ x_recon = self.predict_start_from_z_and_v(x, t, model_out)
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, return_intermediates=False):
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+ clip_denoised=self.clip_denoised)
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, return_intermediates=False):
+ image_size = self.image_size
+ channels = self.channels
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
+ return_intermediates=return_intermediates)
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def get_v(self, x, noise, t):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+ )
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+ else:
+ raise NotImplementedError("unknown loss type '{loss_type}'")
+
+ return loss
+
+ def p_losses(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_out = self.model(x_noisy, t)
+
+ loss_dict = {}
+ if self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
+
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+ log_prefix = 'train' if self.training else 'val'
+
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+ loss_simple = loss.mean() * self.l_simple_weight
+
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+ loss_dict.update({f'{log_prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def forward(self, x, *args, **kwargs):
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if isinstance(x, list):
+ return x
+ # if len(x.shape) == 3:
+ # x = x[..., None]
+ # x = rearrange(x, 'b h w c -> b c h w')
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def shared_step(self, batch):
+ x = self.get_input(batch, self.first_stage_key)
+ loss, loss_dict = self(x)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(loss_dict, prog_bar=False,
+ logger=True, on_step=True, on_epoch=True)
+
+ self.log("global_step", self.global_step,
+ prog_bar=False, logger=True, on_step=True, on_epoch=False)
+
+ if self.use_scheduler:
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+
+ return loss
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ _, loss_dict_ema = self.shared_step(batch)
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def _get_rows_from_list(self, samples):
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.first_stage_key)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+
+ # get diffusion row
+ diffusion_row = list()
+ x_start = x[:n_row]
+
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ diffusion_row.append(x_noisy)
+
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+ log["samples"] = samples
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.learn_logvar:
+ params = params + [self.logvar]
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+
+class LatentDiffusion(DDPM):
+ """main class"""
+ def __init__(self,
+ first_stage_config,
+ cond_stage_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_shift=0.0,
+ scale_by_std=False,
+ use_3daware=False,
+ *args, **kwargs):
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ if cond_stage_config == '__is_unconditional__':
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+ except:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ self.scale_shift = scale_shift
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ # self.instantiate_cond_stage(cond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+
+ self.use_3daware = use_3daware
+
+ self.is_test = False
+
+ self.test_mode = None
+ self.test_tag = ""
+
+ def make_cond_schedule(self, ):
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+ self.cond_ids[:self.num_timesteps_cond] = ids
+
+ @rank_zero_only
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx):
+ # only for very first batch
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ print("### USING STD-RESCALING ###")
+
+ def register_schedule(self,
+ given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, shift_scale=None):
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s, shift_scale)
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ assert config != '__is_first_stage__'
+ assert config != '__is_unconditional__'
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
+ force_not_quantize=force_no_decoder_quantization))
+ n_imgs_per_row = len(denoise_row)
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+ return self.scale_factor * (z + self.scale_shift)
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c).float()
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def meshgrid(self, h, w):
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+ arr = torch.cat([y, x], dim=-1)
+ return arr
+
+ def delta_border(self, h, w):
+ """
+ :param h: height
+ :param w: width
+ :return: normalized distance to image border,
+ wtith min distance = 0 at border and max dist = 0.5 at image center
+ """
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+ arr = self.meshgrid(h, w) / lower_right_corner
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+ return edge_dist
+
+ def get_weighting(self, h, w, Ly, Lx, device):
+ weighting = self.delta_border(h, w)
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+ self.split_input_params["clip_max_weight"], )
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+ if self.split_input_params["tie_braker"]:
+ L_weighting = self.delta_border(Ly, Lx)
+ L_weighting = torch.clip(L_weighting,
+ self.split_input_params["clip_min_tie_weight"],
+ self.split_input_params["clip_max_tie_weight"])
+
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+ weighting = weighting * L_weighting
+ return weighting
+
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
+ """
+ :param x: img of size (bs, c, h, w)
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+ """
+ bs, nc, h, w = x.shape
+
+ # number of crops in image
+ Ly = (h - kernel_size[0]) // stride[0] + 1
+ Lx = (w - kernel_size[1]) // stride[1] + 1
+
+ if uf == 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+ elif uf > 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+ dilation=1, padding=0,
+ stride=(stride[0] * uf, stride[1] * uf))
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+ elif df > 1 and uf == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+ dilation=1, padding=0,
+ stride=(stride[0] // df, stride[1] // df))
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+ else:
+ raise NotImplementedError
+
+ return fold, unfold, normalization, weighting
+
+ @torch.no_grad()
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+ cond_key=None, return_original_cond=False, bs=None):
+ #ipdb.set_trace()
+ x = super().get_input(batch, k)
+ if bs is not None:
+ x = x[:bs]
+ #ipdb.set_trace()
+ z = x.to(self.device) #[:1,:8,:32,:96]
+ z = self.scale_factor * (z + self.scale_shift)
+ # encoder_posterior = self.encode_first_stage(x)
+ # z = self.get_first_stage_encoding(encoder_posterior).detach()
+ #ipdb.set_trace()
+ if self.model.conditioning_key is not None:
+ if cond_key is None:
+ cond_key = self.cond_stage_key
+ if cond_key != self.first_stage_key:
+ if cond_key in ['caption', 'coordinates_bbox']:
+ xc = batch[cond_key]
+ elif cond_key == 'class_label':
+ xc = batch
+ else:
+ xc = super().get_input(batch, cond_key).to(self.device)
+ else:
+ xc = x
+ #ipdb.set_trace()
+ c = xc
+ if bs is not None:
+ c = c[:bs]
+
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ ckey = __conditioning_keys__[self.model.conditioning_key]
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
+
+ else:
+ c = None
+ xc = None
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
+ out = [z, c]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([x, xrec])
+ if return_original_cond:
+ out.append(xc)
+
+ return out
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ # assert not predict_cids
+ # if predict_cids:
+ # if z.dim() == 4:
+ # z = torch.argmax(z.exp(), dim=1).long()
+ # z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ # z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ # import os
+ # import random
+ # import string
+ # z_np = z.detach().cpu().numpy()
+ # fname = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + '.npy'
+ # with open(os.path.join('/mnt/lustre/hongfangzhou.p/AE3D/tmp', fname), 'wb') as f:
+ # np.save(f, z_np)
+
+ z = 1. / self.scale_factor * z - self.scale_shift
+
+ # if hasattr(self, "split_input_params"):
+ # if self.split_input_params["patch_distributed_vq"]:
+ # ks = self.split_input_params["ks"] # eg. (128, 128)
+ # stride = self.split_input_params["stride"] # eg. (64, 64)
+ # uf = self.split_input_params["vqf"]
+ # bs, nc, h, w = z.shape
+ # if ks[0] > h or ks[1] > w:
+ # ks = (min(ks[0], h), min(ks[1], w))
+ # print("reducing Kernel")
+
+ # if stride[0] > h or stride[1] > w:
+ # stride = (min(stride[0], h), min(stride[1], w))
+ # print("reducing stride")
+
+ # fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+ # z = unfold(z) # (bn, nc * prod(**ks), L)
+ # # 1. Reshape to img shape
+ # z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # # 2. apply model loop over last dim
+ # if isinstance(self.first_stage_model, VQModelInterface):
+ # output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+ # force_not_quantize=predict_cids or force_not_quantize)
+ # for i in range(z.shape[-1])]
+ # else:
+
+ # output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+ # for i in range(z.shape[-1])]
+
+ # o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
+ # o = o * weighting
+ # # Reverse 1. reshape to img shape
+ # o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # # stitch crops together
+ # decoded = fold(o)
+ # decoded = decoded / normalization # norm is shape (1, 1, h, w)
+ # return decoded
+ # else:
+ # if isinstance(self.first_stage_model, VQModelInterface):
+ # return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ # else:
+ # return self.first_stage_model.decode(z)
+
+ # else:
+ # if isinstance(self.first_stage_model, VQModelInterface):
+ # return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ # else:
+ return self.first_stage_model.decode(z, unrollout=True)
+
+ # same as above but without decorator
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z - self.scale_shift
+
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ uf = self.split_input_params["vqf"]
+ bs, nc, h, w = z.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+ z = unfold(z) # (bn, nc * prod(**ks), L)
+ # 1. Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # 2. apply model loop over last dim
+ if isinstance(self.first_stage_model, VQModelInterface):
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+ force_not_quantize=predict_cids or force_not_quantize)
+ for i in range(z.shape[-1])]
+ else:
+
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
+ o = o * weighting
+ # Reverse 1. reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
+ return decoded
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ # if hasattr(self, "split_input_params"):
+ # if self.split_input_params["patch_distributed_vq"]:
+ # ks = self.split_input_params["ks"] # eg. (128, 128)
+ # stride = self.split_input_params["stride"] # eg. (64, 64)
+ # df = self.split_input_params["vqf"]
+ # self.split_input_params['original_image_size'] = x.shape[-2:]
+ # bs, nc, h, w = x.shape
+ # if ks[0] > h or ks[1] > w:
+ # ks = (min(ks[0], h), min(ks[1], w))
+ # print("reducing Kernel")
+
+ # if stride[0] > h or stride[1] > w:
+ # stride = (min(stride[0], h), min(stride[1], w))
+ # print("reducing stride")
+
+ # fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
+ # z = unfold(x) # (bn, nc * prod(**ks), L)
+ # # Reshape to img shape
+ # z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
+ # for i in range(z.shape[-1])]
+
+ # o = torch.stack(output_list, axis=-1)
+ # o = o * weighting
+
+ # # Reverse reshape to img shape
+ # o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # # stitch crops together
+ # decoded = fold(o)
+ # decoded = decoded / normalization
+ # return decoded
+
+ # else:
+ # return self.first_stage_model.encode(x)
+ # else:
+ return self.first_stage_model.encode(x, rollout=True)
+
+ def shared_step(self, batch, **kwargs):
+ x, c = self.get_input(batch, self.first_stage_key)
+ loss = self(x, c)
+ return loss
+
+ def forward(self, x, cond=None, return_inter=False, *args, **kwargs):
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ #ipdb.set_trace()
+ if self.model.conditioning_key is not None:
+ assert cond is not None
+ if self.cond_stage_trainable:
+ cond = self.get_learned_conditioning(cond)
+ if self.shorten_cond_schedule: # TODO: drop this option
+ tc = self.cond_ids[t].to(self.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond.float()))
+ #ipdb.set_trace()
+ return self.p_losses(x, cond, t, return_inter=return_inter, *args, **kwargs)
+
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
+ def rescale_bbox(bbox):
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
+ return x0, y0, w, h
+
+ return [rescale_bbox(b) for b in bboxes]
+
+ def to3daware(self, triplane):
+ res = triplane.shape[-2]
+ plane1 = triplane[..., :res]
+ plane2 = triplane[..., res:2*res]
+ plane3 = triplane[..., 2*res:3*res]
+
+ x_mp = torch.nn.AvgPool2d((res, 1))
+ y_mp = torch.nn.AvgPool2d((1, res))
+ x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
+ y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)
+ # for plane1
+ plane21 = x_mp_rep(plane2)
+ plane31 = torch.flip(y_mp_rep(plane3), (3,))
+ new_plane1 = torch.cat([plane1, plane21, plane31], 1)
+ # for plane2
+ plane12 = y_mp_rep(plane1)
+ plane32 = x_mp_rep(plane3)
+ new_plane2 = torch.cat([plane2, plane12, plane32], 1)
+ # for plane3
+ plane13 = torch.flip(x_mp_rep(plane1), (2,))
+ plane23 = y_mp_rep(plane2)
+ new_plane3 = torch.cat([plane3, plane13, plane23], 1)
+
+ new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
+ return new_plane
+
+ # B, C, H, W = h.shape
+ # h_xy = th.cat([h[..., 0:(W//3)], h[..., (W//3):(2*W//3)].mean(-1).unsqueeze(-1).repeat(1, 1, 1, W//3), h[..., (2*W//3):W].mean(-2).unsqueeze(-2).repeat(1, 1, H, 1)], 1)
+ # h_xz = th.cat([h[..., (W//3):(2*W//3)], h[..., 0:(W//3)].mean(-1).unsqueeze(-1).repeat(1, 1, 1, W//3), h[..., (2*W//3):W].mean(-1).unsqueeze(-1).repeat(1, 1, 1, W//3)], 1)
+ # h_zy = th.cat([h[..., (2*W//3):W], h[..., 0:(W//3)].mean(-2).unsqueeze(-2).repeat(1, 1, H, 1), h[..., (W//3):(2*W//3)].mean(-2).unsqueeze(-2).repeat(1, 1, H, 1)], 1)
+ # h = th.cat([h_xy, h_xz, h_zy], -1)
+
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
+ #ipdb.set_trace()
+ if isinstance(cond, dict):
+ # hybrid case, cond is exptected to be a dict
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+ cond = {key: cond}
+
+ if hasattr(self, "split_input_params"):
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
+ assert not return_ids
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+
+ h, w = x_noisy.shape[-2:]
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
+
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
+ # Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
+
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
+ c_key = next(iter(cond.keys())) # get key
+ c = next(iter(cond.values())) # get value
+ assert (len(c) == 1) # todo extend to list with more than one elem
+ c = c[0] # get element
+
+ c = unfold(c)
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
+
+ elif self.cond_stage_key == 'coordinates_bbox':
+ assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
+
+ # assuming padding of unfold is always 0 and its dilation is always 1
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
+ # as we are operating on latents, we need the factor from the original image size to the
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
+ rescale_latent = 2 ** (num_downs)
+
+ # get top left postions of patches as conforming for the bbbox tokenizer, therefore we
+ # need to rescale the tl patch coordinates to be in between (0,1)
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
+ for patch_nr in range(z.shape[-1])]
+
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
+ patch_limits = [(x_tl, y_tl,
+ rescale_latent * ks[0] / full_img_w,
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
+
+ # tokenize crop coordinates for the bounding boxes of the respective patches
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
+ print(patch_limits_tknzd[0].shape)
+ # cut tknzd crop position from conditioning
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
+ print(cut_cond.shape)
+
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
+ print(adapted_cond.shape)
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
+ print(adapted_cond.shape)
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
+ print(adapted_cond.shape)
+
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
+
+ else:
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
+
+ # apply model by loop over crops
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
+ assert not isinstance(output_list[0],
+ tuple) # todo cant deal with multiple model outputs check this never happens
+
+ o = torch.stack(output_list, axis=-1)
+ o = o * weighting
+ # Reverse reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ x_recon = fold(o) / normalization
+
+ else:
+ if self.use_3daware:
+ x_noisy_3daware = self.to3daware(x_noisy)
+ x_recon = self.model(x_noisy_3daware, t, **cond)
+ else:
+ x_recon = self.model(x_noisy, t, **cond)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def p_losses(self, x_start, cond, t, noise=None, return_inter=False):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_output = self.apply_model(x_noisy, t, cond)
+
+ loss_dict = {}
+ prefix = 'train' if self.training else 'val'
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError()
+
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+ logvar_t = self.logvar[t.to(self.logvar.device)].to(self.device)
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+ if self.learn_logvar:
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+ loss_dict.update({'logvar': self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+ loss += (self.original_elbo_weight * loss_vlb)
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ if return_inter:
+ return loss, loss_dict, self.predict_start_from_noise(x_noisy, t=t, noise=model_output)
+ else:
+ return loss, loss_dict
+
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ elif self.parameterization == "v":
+ x_recon = self.predict_start_from_z_and_v(x, t, model_out)
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_codebook_ids:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+ log_every_t=None):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+ if batch_size is not None:
+ b = batch_size if batch_size is not None else shape[0]
+ shape = [batch_size] + list(shape)
+ else:
+ b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+ total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+ if type(temperature) == float:
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised, return_x0=True,
+ temperature=temperature[i], noise_dropout=noise_dropout,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, start_T=None,
+ log_every_t=None):
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ # if self.is_test and i % 50 == 0:
+ # decode_res = self.decode_first_stage(img)
+ # rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder(
+ # decode_res, self.batch_rays, self.batch_img,
+ # )
+ # rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0]
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "sample_process_{}.png".format(i)), rgb_sample)
+ # colorize_res = self.first_stage_model.to_rgb(img)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "sample_process_latent_{}.png".format(i)), colorize_res[0])
+
+ img = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised)
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+ verbose=True, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, shape=None,**kwargs):
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size, self.image_size * 3)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+ return self.p_sample_loop(cond,
+ shape,
+ return_intermediates=return_intermediates, x_T=x_T,
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+ mask=mask, x0=x0)
+
+ @torch.no_grad()
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
+
+ if ddim:
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.channels, self.image_size, self.image_size)
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
+ shape,cond,verbose=False,**kwargs)
+
+ else:
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+ return_intermediates=True,**kwargs)
+
+ return samples, intermediates
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ # x, c = self.get_input(batch, self.first_stage_key)
+ # self.batch_rays = batch['batch_rays'][0][1:2]
+ # self.batch_img = batch['img'][0][1:2]
+ # self.is_test = True
+ # self.test_schedule(x[0:1])
+ # exit(0)
+
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ # _, loss_dict_ema = self.shared_step(batch)
+ x, c = self.get_input(batch, self.first_stage_key)
+ _, loss_dict_ema, inter_res = self(x, c, return_inter=True)
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+
+ if batch_idx < 2:
+ if self.num_timesteps < 1000:
+ x_T = self.q_sample(x_start=x[0:1], t=torch.full((1,), self.num_timesteps-1, device=x.device, dtype=torch.long), noise=torch.randn_like(x[0:1]))
+ print("Specifying x_T when sampling!")
+ else:
+ x_T = None
+ with self.ema_scope():
+ res = self.sample(c, 1, shape=x[0:1].shape, x_T = x_T)
+ decode_res = self.decode_first_stage(res)
+ decode_input = self.decode_first_stage(x[:1])
+ decode_output = self.decode_first_stage(inter_res[:1])
+
+ colorize_res = self.first_stage_model.to_rgb(res)[0]
+ colorize_x = self.first_stage_model.to_rgb(x[:1])[0]
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "sample_{}_{}.png".format(batch_idx, 0)), colorize_res[0])
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "gt_{}_{}.png".format(batch_idx, 0)), colorize_x[0])
+
+ rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder(
+ decode_res, batch['batch_rays'][0], batch['img'][0],
+ )
+ rgb_input, _ = self.first_stage_model.render_triplane_eg3d_decoder(
+ decode_input, batch['batch_rays'][0], batch['img'][0],
+ )
+ rgb_output, _ = self.first_stage_model.render_triplane_eg3d_decoder(
+ decode_output, batch['batch_rays'][0], batch['img'][0],
+ )
+ rgb_sample = to8b(rgb_sample.detach().cpu().numpy())
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_output = to8b(rgb_output.detach().cpu().numpy())
+
+ if rgb_sample.shape[0] == 1:
+ rgb_all = np.concatenate([rgb_sample[0], rgb_input[0], rgb_output[0]], 1)
+ else:
+ rgb_all = np.concatenate([rgb_sample[1], rgb_input[1], rgb_output[1]], 1)
+
+
+ if self.model.conditioning_key is not None:
+ if self.cond_stage_key == 'img_cond':
+ cond_img = super().get_input(batch, self.cond_stage_key)[0].permute(1, 2, 0)
+ rgb_all = np.concatenate([rgb_all, to8b(cond_img.cpu().numpy())], 1)
+ elif 'caption' in self.cond_stage_key:
+ import cv2
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ # org
+ org = (50, 50)
+ # fontScale
+ fontScale = 1
+ # Blue color in BGR
+ color = (255, 0, 0)
+ # Line thickness of 2 px
+ thickness = 2
+ caption = super().get_input(batch, 'caption')[0]
+ break_caption = []
+ for i in range(len(caption) // 30 + 1):
+ break_caption_i = caption[i*30:(i+1)*30]
+ break_caption.append(break_caption_i)
+ for i, bci in enumerate(break_caption):
+ cv2.putText(rgb_all, bci, (50, 50*(i+1)), font, fontScale, color, thickness, cv2.LINE_AA)
+
+ self.logger.experiment.log({
+ "val/vis": [wandb.Image(rgb_all)],
+ "val/colorize_rse": [wandb.Image(colorize_res)],
+ "val/colorize_x": [wandb.Image(colorize_x)],
+ })
+
+ @torch.no_grad()
+ def test_schedule(self, x_start, freq=50):
+ noise = torch.randn_like(x_start)
+ img_list = []
+ latent_list = []
+ for t in tqdm(range(self.num_timesteps)):
+ if t % freq == 0:
+ t_long = torch.Tensor([t,]).long().to(x_start.device)
+ x_noisy = self.q_sample(x_start=x_start, t=t_long, noise=noise)
+ decode_res = self.decode_first_stage(x_noisy)
+ rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder(
+ decode_res, self.batch_rays, self.batch_img,
+ )
+ rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0]
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_{}.png".format(t)), rgb_sample)
+ colorize_res = self.first_stage_model.to_rgb(x_noisy)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_latent_{}.png".format(t)), colorize_res[0])
+ img_list.append(rgb_sample)
+ latent_list.append(colorize_res[0])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_{}_{}_{}_{}.png".format(self.linear_start, self.linear_end, self.beta_schedule, self.scale_factor)), np.concatenate(img_list, 1))
+ imageio.imwrite(os.path.join(self.logger.log_dir, "add_noise_latent_{}_{}_{}_{}.png".format(self.linear_start, self.linear_end, self.beta_schedule, self.scale_factor)), np.concatenate(latent_list, 1))
+
+ @torch.no_grad()
+ def test_step(self, batch, batch_idx):
+ x, c = self.get_input(batch, self.first_stage_key)
+ if self.test_mode == 'fid':
+ bs = x.shape[0]
+ else:
+ bs = 1
+ if self.test_mode == 'noise_schedule':
+ self.batch_rays = batch['batch_rays'][0][33:34]
+ self.batch_img = batch['img'][0][33:34]
+ self.is_test = True
+ self.test_schedule(x)
+ exit(0)
+ with self.ema_scope():
+ if c is not None:
+ res = self.sample(c[:bs], bs, shape=x[0:bs].shape)
+ else:
+ res = self.sample(None, bs, shape=x[0:bs].shape)
+ decode_res = self.decode_first_stage(res)
+ if self.test_mode == 'fid':
+ folder = os.path.join(self.logger.log_dir, 'FID_' + self.test_tag)
+ if not os.path.exists(folder):
+ os.makedirs(folder, exist_ok=True)
+ rgb_sample_list = []
+ for b in range(bs):
+ rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder(
+ decode_res[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb_sample = to8b(rgb_sample.detach().cpu().numpy())
+ rgb_sample_list.append(rgb_sample)
+ for i in range(len(rgb_sample_list)):
+ for v in range(rgb_sample_list[i].shape[0]):
+ imageio.imwrite(os.path.join(folder, "sample_{}_{}_{}.png".format(batch_idx, i, v)), rgb_sample_list[i][v])
+ elif self.test_mode == 'sample':
+ colorize_res = self.first_stage_model.to_rgb(res)
+ colorize_x = self.first_stage_model.to_rgb(x[:1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "sample_{}_{}.png".format(batch_idx, 0)), colorize_res[0])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "gt_{}_{}.png".format(batch_idx, 0)), colorize_x[0])
+ if self.model.conditioning_key is not None:
+ cond_img = super().get_input(batch, self.cond_stage_key)[0].permute(1, 2, 0)
+ cond_img = to8b(cond_img.cpu().numpy())
+ imageio.imwrite(os.path.join(self.logger.log_dir, "cond_{}_{}.png".format(batch_idx, 0)), cond_img)
+ for b in range(bs):
+ video = []
+ for v in tqdm(range(batch['batch_rays'].shape[1])):
+ rgb_sample, _ = self.first_stage_model.render_triplane_eg3d_decoder(
+ decode_res[b:b+1], batch['batch_rays'][0][v:v+1], batch['img'][0][v:v+1],
+ )
+ rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0]
+ video.append(rgb_sample)
+ imageio.mimwrite(os.path.join(self.logger.log_dir, "sample_{}_{}.mp4".format(batch_idx, b)), video, fps=24)
+ print("Saving to {}".format(os.path.join(self.logger.log_dir, "sample_{}_{}.mp4".format(batch_idx, b))))
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, **kwargs):
+
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=N)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
+ log["conditioning"] = xc
+ elif self.cond_stage_key == 'class_label':
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
+ ddim_steps=ddim_steps,eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+ self.first_stage_model, IdentityFirstStage):
+ # also display when quantizing x0 while sampling
+ with self.ema_scope("Plotting Quantized Denoised"):
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
+ ddim_steps=ddim_steps,eta=ddim_eta,
+ quantize_denoised=True)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+ # quantize_denoised=True)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_x0_quantized"] = x_samples
+
+ if inpaint:
+ # make a simple center square
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ mask = torch.ones(N, h, w).to(self.device)
+ # zeros will be filled in
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+ mask = mask[:, None, ...]
+ with self.ema_scope("Plotting Inpaint"):
+
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_inpainting"] = x_samples
+ log["mask"] = mask
+
+ # outpaint
+ with self.ema_scope("Plotting Outpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_outpainting"] = x_samples
+
+ if plot_progressive_rows:
+ with self.ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.cond_stage_trainable:
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+ params = params + list(self.cond_stage_model.parameters())
+ if self.learn_logvar:
+ print('Diffusion model optimizing logvar')
+ params.append(self.logvar)
+ opt = torch.optim.AdamW(params, lr=lr)
+ if self.use_scheduler:
+ assert 'target' in self.scheduler_config
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def to_rgb(self, x):
+ x = x.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = nn.functional.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+
+
+class DiffusionWrapper(pl.LightningModule):
+ def __init__(self, diff_model_config, conditioning_key):
+ super().__init__()
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+ self.conditioning_key = conditioning_key
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
+
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
+ if self.conditioning_key is None:
+ out = self.diffusion_model(x, t)
+ elif self.conditioning_key == 'concat':
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t)
+ elif self.conditioning_key == 'crossattn':
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc)
+ elif self.conditioning_key == 'hybrid':
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc)
+ elif self.conditioning_key == 'adm':
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, y=cc)
+ else:
+ raise NotImplementedError()
+
+ return out
+
+
+class Layout2ImgDiffusion(LatentDiffusion):
+ # TODO: move all layout-specific hacks to this class
+ def __init__(self, cond_stage_key, *args, **kwargs):
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
+ super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+
+ def log_images(self, batch, N=8, *args, **kwargs):
+ logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+
+ key = 'train' if self.training else 'validation'
+ dset = self.trainer.datamodule.datasets[key]
+ mapper = dset.conditional_builders[self.cond_stage_key]
+
+ bbox_imgs = []
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
+ bbox_imgs.append(bboximg)
+
+ cond_img = torch.stack(bbox_imgs, dim=0)
+ logs['bbox_image'] = cond_img
+ return logs
diff --git a/3DTopia/ldm/models/diffusion/dpm_solver/__init__.py b/3DTopia/ldm/models/diffusion/dpm_solver/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08
--- /dev/null
+++ b/3DTopia/ldm/models/diffusion/dpm_solver/__init__.py
@@ -0,0 +1 @@
+from .sampler import DPMSolverSampler
\ No newline at end of file
diff --git a/3DTopia/ldm/models/diffusion/dpm_solver/dpm_solver.py b/3DTopia/ldm/models/diffusion/dpm_solver/dpm_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdb64e0c78cc3520f92d79db3124c85fc3cfb9b4
--- /dev/null
+++ b/3DTopia/ldm/models/diffusion/dpm_solver/dpm_solver.py
@@ -0,0 +1,1184 @@
+import torch
+import torch.nn.functional as F
+import math
+
+
+class NoiseScheduleVP:
+ def __init__(
+ self,
+ schedule='discrete',
+ betas=None,
+ alphas_cumprod=None,
+ continuous_beta_0=0.1,
+ continuous_beta_1=20.,
+ ):
+ """Create a wrapper class for the forward SDE (VP type).
+
+ ***
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
+ ***
+
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+
+ log_alpha_t = self.marginal_log_mean_coeff(t)
+ sigma_t = self.marginal_std(t)
+ lambda_t = self.marginal_lambda(t)
+
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+
+ t = self.inverse_lambda(lambda_t)
+
+ ===============================================================
+
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+
+ 1. For discrete-time DPMs:
+
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+ t_i = (i + 1) / N
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+
+ Args:
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
+ and
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+
+
+ 2. For continuous-time DPMs:
+
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+ schedule are the default settings in DDPM and improved-DDPM:
+
+ Args:
+ beta_min: A `float` number. The smallest beta for the linear schedule.
+ beta_max: A `float` number. The largest beta for the linear schedule.
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+ T: A `float` number. The ending time of the forward process.
+
+ ===============================================================
+
+ Args:
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+ 'linear' or 'cosine' for continuous-time DPMs.
+ Returns:
+ A wrapper object of the forward SDE (VP type).
+
+ ===============================================================
+
+ Example:
+
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
+
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+
+ # For continuous-time DPMs (VPSDE), linear schedule:
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+
+ """
+
+ if schedule not in ['discrete', 'linear', 'cosine']:
+ raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
+
+ self.schedule = schedule
+ if schedule == 'discrete':
+ if betas is not None:
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+ else:
+ assert alphas_cumprod is not None
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
+ self.total_N = len(log_alphas)
+ self.T = 1.
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
+ else:
+ self.total_N = 1000
+ self.beta_0 = continuous_beta_0
+ self.beta_1 = continuous_beta_1
+ self.cosine_s = 0.008
+ self.cosine_beta_max = 999.
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+ self.schedule = schedule
+ if schedule == 'cosine':
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
+ self.T = 0.9946
+ else:
+ self.T = 1.
+
+ def marginal_log_mean_coeff(self, t):
+ """
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
+ """
+ if self.schedule == 'discrete':
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
+ elif self.schedule == 'linear':
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+ elif self.schedule == 'cosine':
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+ return log_alpha_t
+
+ def marginal_alpha(self, t):
+ """
+ Compute alpha_t of a given continuous-time label t in [0, T].
+ """
+ return torch.exp(self.marginal_log_mean_coeff(t))
+
+ def marginal_std(self, t):
+ """
+ Compute sigma_t of a given continuous-time label t in [0, T].
+ """
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+ def marginal_lambda(self, t):
+ """
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+ """
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+ return log_mean_coeff - log_std
+
+ def inverse_lambda(self, lamb):
+ """
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+ """
+ if self.schedule == 'linear':
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ Delta = self.beta_0**2 + tmp
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+ elif self.schedule == 'discrete':
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
+ return t.reshape((-1,))
+ else:
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
+ t = t_fn(log_alpha)
+ return t
+
+
+def model_wrapper(
+ model,
+ noise_schedule,
+ model_type="noise",
+ model_kwargs={},
+ guidance_type="uncond",
+ condition=None,
+ unconditional_condition=None,
+ guidance_scale=1.,
+ classifier_fn=None,
+ classifier_kwargs={},
+):
+ """Create a wrapper function for the noise prediction model.
+
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
+
+ We support four types of the diffusion model by setting `model_type`:
+
+ 1. "noise": noise prediction model. (Trained by predicting noise).
+
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
+
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+ arXiv preprint arXiv:2202.00512 (2022).
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+ arXiv preprint arXiv:2210.02303 (2022).
+
+ 4. "score": marginal score function. (Trained by denoising score matching).
+ Note that the score function and the noise prediction model follows a simple relationship:
+ ```
+ noise(x_t, t) = -sigma_t * score(x_t, t)
+ ```
+
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
+ 1. "uncond": unconditional sampling by DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ The input `classifier_fn` has the following format:
+ ``
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+ ``
+
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+ ``
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+ arXiv preprint arXiv:2207.12598 (2022).
+
+
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+ or continuous-time labels (i.e. epsilon to T).
+
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
+ ``
+ def model_fn(x, t_continuous) -> noise:
+ t_input = get_model_input_time(t_continuous)
+ return noise_pred(model, x, t_input, **model_kwargs)
+ ``
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+
+ ===============================================================
+
+ Args:
+ model: A diffusion model with the corresponding format described above.
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ model_type: A `str`. The parameterization type of the diffusion model.
+ "noise" or "x_start" or "v" or "score".
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
+ guidance_type: A `str`. The type of the guidance for sampling.
+ "uncond" or "classifier" or "classifier-free".
+ condition: A pytorch tensor. The condition for the guided sampling.
+ Only used for "classifier" or "classifier-free" guidance type.
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+ Only used for "classifier-free" guidance type.
+ guidance_scale: A `float`. The scale for the guided sampling.
+ classifier_fn: A classifier function. Only used for the classifier guidance.
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+ Returns:
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
+ """
+
+ def get_model_input_time(t_continuous):
+ """
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+ For continuous-time DPMs, we just use `t_continuous`.
+ """
+ if noise_schedule.schedule == 'discrete':
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+ else:
+ return t_continuous
+
+ def noise_pred_fn(x, t_continuous, cond=None):
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ t_input = get_model_input_time(t_continuous)
+ if cond is None:
+ output = model(x, t_input, **model_kwargs)
+ else:
+ output = model(x, t_input, cond, **model_kwargs)
+ if model_type == "noise":
+ return output
+ elif model_type == "x_start":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+ elif model_type == "v":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+ elif model_type == "score":
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return -expand_dims(sigma_t, dims) * output
+
+ def cond_grad_fn(x, t_input):
+ """
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+ """
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+ def model_fn(x, t_continuous):
+ """
+ The noise predicition model function that is used for DPM-Solver.
+ """
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ if guidance_type == "uncond":
+ return noise_pred_fn(x, t_continuous)
+ elif guidance_type == "classifier":
+ assert classifier_fn is not None
+ t_input = get_model_input_time(t_continuous)
+ cond_grad = cond_grad_fn(x, t_input)
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ noise = noise_pred_fn(x, t_continuous)
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+ elif guidance_type == "classifier-free":
+ if guidance_scale == 1. or unconditional_condition is None:
+ return noise_pred_fn(x, t_continuous, cond=condition)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t_continuous] * 2)
+ c_in = torch.cat([unconditional_condition, condition])
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+ assert model_type in ["noise", "x_start", "v"]
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
+ return model_fn
+
+
+class DPM_Solver:
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
+ """Construct a DPM-Solver.
+
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
+
+ Args:
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+ ``
+ def model_fn(x, t_continuous):
+ return noise
+ ``
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
+
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
+ """
+ self.model = model_fn
+ self.noise_schedule = noise_schedule
+ self.predict_x0 = predict_x0
+ self.thresholding = thresholding
+ self.max_val = max_val
+
+ def noise_prediction_fn(self, x, t):
+ """
+ Return the noise prediction model.
+ """
+ return self.model(x, t)
+
+ def data_prediction_fn(self, x, t):
+ """
+ Return the data prediction model (with thresholding).
+ """
+ noise = self.noise_prediction_fn(x, t)
+ dims = x.dim()
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+ if self.thresholding:
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0
+
+ def model_fn(self, x, t):
+ """
+ Convert the model to the noise prediction model or the data prediction model.
+ """
+ if self.predict_x0:
+ return self.data_prediction_fn(x, t)
+ else:
+ return self.noise_prediction_fn(x, t)
+
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
+ """Compute the intermediate time steps for sampling.
+
+ Args:
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ N: A `int`. The total number of the spacing of the time steps.
+ device: A torch device.
+ Returns:
+ A pytorch tensor of the time steps, with the shape (N + 1,).
+ """
+ if skip_type == 'logSNR':
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
+ elif skip_type == 'time_uniform':
+ return torch.linspace(t_T, t_0, N + 1).to(device)
+ elif skip_type == 'time_quadratic':
+ t_order = 2
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
+ return t
+ else:
+ raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+ """
+ Get the order of each step for sampling by the singlestep DPM-Solver.
+
+ We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
+ - If order == 1:
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
+ - If order == 2:
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If order == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
+
+ ============================================
+ Args:
+ order: A `int`. The max order for the solver (2 or 3).
+ steps: A `int`. The total number of function evaluations (NFE).
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ device: A torch device.
+ Returns:
+ orders: A list of the solver order of each step.
+ """
+ if order == 3:
+ K = steps // 3 + 1
+ if steps % 3 == 0:
+ orders = [3,] * (K - 2) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [3,] * (K - 1) + [1]
+ else:
+ orders = [3,] * (K - 1) + [2]
+ elif order == 2:
+ if steps % 2 == 0:
+ K = steps // 2
+ orders = [2,] * K
+ else:
+ K = steps // 2 + 1
+ orders = [2,] * (K - 1) + [1]
+ elif order == 1:
+ K = 1
+ orders = [1,] * steps
+ else:
+ raise ValueError("'order' must be '1' or '2' or '3'.")
+ if skip_type == 'logSNR':
+ # To reproduce the results in DPM-Solver paper
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+ else:
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders)).to(device)]
+ return timesteps_outer, orders
+
+ def denoise_to_zero_fn(self, x, s):
+ """
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
+ """
+ return self.data_prediction_fn(x, s)
+
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
+ """
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_1 = torch.expm1(-h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+ else:
+ phi_1 = torch.expm1(h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the second-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 0.5
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_1 = torch.expm1(-h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (model_s1 - model_s)
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_1 = torch.expm1(h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 1. / 3.
+ if r2 is None:
+ r2 = 2. / 3.
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ lambda_s2 = lambda_s + r2 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ s2 = ns.inverse_lambda(lambda_s2)
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t)
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_12 = torch.expm1(-r2 * h)
+ phi_1 = torch.expm1(-h)
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
+ phi_2 = phi_1 / h + 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(sigma_s2 / sigma_s, dims) * x
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + expand_dims(alpha_t * phi_2, dims) * D1
+ - expand_dims(alpha_t * phi_3, dims) * D2
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_12 = torch.expm1(r2 * h)
+ phi_1 = torch.expm1(h)
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
+ phi_2 = phi_1 / h - 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - expand_dims(sigma_t * phi_2, dims) * D1
+ - expand_dims(sigma_t * phi_3, dims) * D2
+ )
+
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
+ else:
+ return x_t
+
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
+ """
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_1, model_prev_0 = model_prev_list
+ t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0 = h_0 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ if self.predict_x0:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
+ )
+ else:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
+ )
+ return x_t
+
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
+ """
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_1 = lambda_prev_1 - lambda_prev_2
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0, r1 = h_0 / h, h_1 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
+ if self.predict_x0:
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5), dims) * D2
+ )
+ else:
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5), dims) * D2
+ )
+ return x_t
+
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, r2=None):
+ """
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
+ elif order == 2:
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1)
+ elif order == 3:
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
+ """
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
+ elif order == 2:
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ elif order == 3:
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'):
+ """
+ The adaptive step size solver based on singlestep DPM-Solver.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `t_T`.
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ h_init: A `float`. The initial step size (for logSNR).
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
+ Returns:
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
+
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
+ """
+ ns = self.noise_schedule
+ s = t_T * torch.ones((x.shape[0],)).to(x)
+ lambda_s = ns.marginal_lambda(s)
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
+ h = h_init * torch.ones_like(s).to(x)
+ x_prev = x
+ nfe = 0
+ if order == 2:
+ r1 = 0.5
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
+ elif order == 3:
+ r1, r2 = 1. / 3., 2. / 3.
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
+ else:
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
+ while torch.abs((s - t_0)).mean() > t_err:
+ t = ns.inverse_lambda(lambda_s + h)
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
+ E = norm_fn((x_higher - x_lower) / delta).max()
+ if torch.all(E <= 1.):
+ x = x_higher
+ s = t
+ x_prev = x_lower
+ lambda_s = ns.marginal_lambda(s)
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
+ nfe += order
+ print('adaptive solver nfe', nfe)
+ return x
+
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+ atol=0.0078, rtol=0.05,
+ ):
+ """
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+
+ =====================================================
+
+ We support the following algorithms for both noise prediction model and data prediction model:
+ - 'singlestep':
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
+ The total number of function evaluations (NFE) == `steps`.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ - If `order` == 1:
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If `order` == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
+ - 'multistep':
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
+ We initialize the first `order` values by lower order multistep solvers.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ Denote K = steps.
+ - If `order` == 1:
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
+ - If `order` == 3:
+ - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
+ - 'singlestep_fixed':
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
+ - 'adaptive':
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
+ (NFE) and the sample quality.
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+
+ =====================================================
+
+ Some advices for choosing the algorithm:
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+ skip_type='time_uniform', method='singlestep')
+ - For **guided sampling with large guidance scale** by DPMs:
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+ skip_type='time_uniform', method='multistep')
+
+ We support three types of `skip_type`:
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
+ - 'time_quadratic': quadratic time for the time steps.
+
+ =====================================================
+ Args:
+ x: A pytorch tensor. The initial value at time `t_start`
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+ steps: A `int`. The total number of function evaluations (NFE).
+ t_start: A `float`. The starting time of the sampling.
+ If `T` is None, we use self.noise_schedule.T (default is 1.0).
+ t_end: A `float`. The ending time of the sampling.
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
+ For discrete-time DPMs:
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+ For continuous-time DPMs:
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+ order: A `int`. The order of DPM-Solver.
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+
+ This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+ score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
+ for diffusion models sampling by diffusion SDEs for low-resolutional images
+ (such as CIFAR-10). However, we observed that such trick does not matter for
+ high-resolutional images. As it needs an additional NFE, we do not recommend
+ it for high-resolutional images.
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
+ (especially for steps <= 10). So we recommend to set it to be `True`.
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ Returns:
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
+
+ """
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+ t_T = self.noise_schedule.T if t_start is None else t_start
+ device = x.device
+ if method == 'adaptive':
+ with torch.no_grad():
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
+ elif method == 'multistep':
+ assert steps >= order
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+ assert timesteps.shape[0] - 1 == steps
+ with torch.no_grad():
+ vec_t = timesteps[0].expand((x.shape[0]))
+ model_prev_list = [self.model_fn(x, vec_t)]
+ t_prev_list = [vec_t]
+ # Init the first `order` values by lower order multistep DPM-Solver.
+ for init_order in range(1, order):
+ vec_t = timesteps[init_order].expand(x.shape[0])
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type)
+ model_prev_list.append(self.model_fn(x, vec_t))
+ t_prev_list.append(vec_t)
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
+ for step in range(order, steps + 1):
+ vec_t = timesteps[step].expand(x.shape[0])
+ if lower_order_final and steps < 15:
+ step_order = min(order, steps + 1 - step)
+ else:
+ step_order = order
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type)
+ for i in range(order - 1):
+ t_prev_list[i] = t_prev_list[i + 1]
+ model_prev_list[i] = model_prev_list[i + 1]
+ t_prev_list[-1] = vec_t
+ # We do not need to evaluate the final model value.
+ if step < steps:
+ model_prev_list[-1] = self.model_fn(x, vec_t)
+ elif method in ['singlestep', 'singlestep_fixed']:
+ if method == 'singlestep':
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device)
+ elif method == 'singlestep_fixed':
+ K = steps // order
+ orders = [order,] * K
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+ for i, order in enumerate(orders):
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device)
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+ vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
+ h = lambda_inner[-1] - lambda_inner[0]
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
+ if denoise_to_zero:
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+ return x
+
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+def interpolate_fn(x, xp, yp):
+ """
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
+
+ Args:
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+ yp: PyTorch tensor with shape [C, K].
+ Returns:
+ The function values f(x), with shape [N, C].
+ """
+ N, K = x.shape[0], xp.shape[1]
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+ x_idx = torch.argmin(x_indices, dim=2)
+ cand_start_idx = x_idx - 1
+ start_idx = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(1, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+ start_idx2 = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(0, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+ return cand
+
+
+def expand_dims(v, dims):
+ """
+ Expand the tensor `v` to the dim `dims`.
+
+ Args:
+ `v`: a PyTorch tensor with shape [N].
+ `dim`: a `int`.
+ Returns:
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+ """
+ return v[(...,) + (None,)*(dims - 1)]
\ No newline at end of file
diff --git a/3DTopia/ldm/models/diffusion/dpm_solver/sampler.py b/3DTopia/ldm/models/diffusion/dpm_solver/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c42d6f964d92658e769df95a81dec92250e5a99
--- /dev/null
+++ b/3DTopia/ldm/models/diffusion/dpm_solver/sampler.py
@@ -0,0 +1,82 @@
+"""SAMPLING ONLY."""
+
+import torch
+
+from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+
+
+class DPMSolverSampler(object):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+ self.model = model
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+
+ # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
+
+ device = self.model.betas.device
+ if x_T is None:
+ img = torch.randn(size, device=device)
+ else:
+ img = x_T
+
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+ model_fn = model_wrapper(
+ lambda x, t, c: self.model.apply_model(x, t, c),
+ ns,
+ model_type="noise",
+ guidance_type="classifier-free",
+ condition=conditioning,
+ unconditional_condition=unconditional_conditioning,
+ guidance_scale=unconditional_guidance_scale,
+ )
+
+ dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+ x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+
+ return x.to(device), None
diff --git a/3DTopia/ldm/models/diffusion/plms.py b/3DTopia/ldm/models/diffusion/plms.py
new file mode 100644
index 0000000000000000000000000000000000000000..78eeb1003aa45d27bdbfc6b4a1d7ccbff57cd2e3
--- /dev/null
+++ b/3DTopia/ldm/models/diffusion/plms.py
@@ -0,0 +1,236 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+
+
+class PLMSSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ if ddim_eta != 0:
+ raise ValueError('ddim_eta must be 0 for PLMS')
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for PLMS sampling is {size}')
+
+ samples, intermediates = self.plms_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def plms_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+ old_eps = []
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ old_eps=old_eps, t_next=ts_next)
+ img, pred_x0, e_t = outs
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
diff --git a/3DTopia/ldm/modules/attention.py b/3DTopia/ldm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4eff39ccb6d75daa764f6eb70a7cef024fb5a3f
--- /dev/null
+++ b/3DTopia/ldm/modules/attention.py
@@ -0,0 +1,261 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return{el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
+ super().__init__()
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x)) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ """
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+ for d in range(depth)]
+ )
+
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c')
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
+ x = self.proj_out(x)
+ return x + x_in
\ No newline at end of file
diff --git a/3DTopia/ldm/modules/diffusionmodules/__init__.py b/3DTopia/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/3DTopia/ldm/modules/diffusionmodules/model.py b/3DTopia/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..533e589a2024f1d7c52093d8c472c3b1b6617e26
--- /dev/null
+++ b/3DTopia/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,835 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+
+from ldm.util import instantiate_from_config
+from ldm.modules.attention import LinearAttention
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+
+ self.conv_out = nn.Conv2d(mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels*ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+ out_channels=tmp_chn, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size//in_size))+1
+ factor_up = 1.+ (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor==1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+ return x
+
+class FirstStagePostProcessor(nn.Module):
+
+ def __init__(self, ch_mult:list, in_channels,
+ pretrained_model:nn.Module=None,
+ reshape=False,
+ n_channels=None,
+ dropout=0.,
+ pretrained_config=None):
+ super().__init__()
+ if pretrained_config is None:
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.pretrained_model = pretrained_model
+ else:
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.instantiate_pretrained(pretrained_config)
+
+ self.do_reshape = reshape
+
+ if n_channels is None:
+ n_channels = self.pretrained_model.encoder.ch
+
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
+ stride=1,padding=1)
+
+ blocks = []
+ downs = []
+ ch_in = n_channels
+ for m in ch_mult:
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
+ ch_in = m * n_channels
+ downs.append(Downsample(ch_in, with_conv=False))
+
+ self.model = nn.ModuleList(blocks)
+ self.downsampler = nn.ModuleList(downs)
+
+
+ def instantiate_pretrained(self, config):
+ model = instantiate_from_config(config)
+ self.pretrained_model = model.eval()
+ # self.pretrained_model.train = False
+ for param in self.pretrained_model.parameters():
+ param.requires_grad = False
+
+
+ @torch.no_grad()
+ def encode_with_pretrained(self,x):
+ c = self.pretrained_model.encode(x)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ return c
+
+ def forward(self,x):
+ z_fs = self.encode_with_pretrained(x)
+ z = self.proj_norm(z_fs)
+ z = self.proj(z)
+ z = nonlinearity(z)
+
+ for submodel, downmodel in zip(self.model,self.downsampler):
+ z = submodel(z,temb=None)
+ z = downmodel(z)
+
+ if self.do_reshape:
+ z = rearrange(z,'b c h w -> b (h w) c')
+ return z
+
diff --git a/3DTopia/ldm/modules/diffusionmodules/openaimodel.py b/3DTopia/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..43edcfbcad0cc3e26d9979734a67b1ca0e593392
--- /dev/null
+++ b/3DTopia/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,965 @@
+from abc import abstractmethod
+from functools import partial
+import math
+from typing import Iterable
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+def convert_module_to_f32(x):
+ pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+class TransposedUpsample(nn.Module):
+ 'Learned 2x upsampling without padding'
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+ def forward(self,x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, ch, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ # if h.shape[1] == 640:
+ # # h[:, :320] = h[:, :320] * 1.4
+ # # print("here")
+ # h = h * 1.2
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
+
+
+class EncoderUNetModel(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+ For usage, see UNet.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ pool="adaptive",
+ *args,
+ **kwargs
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+ self.pool = pool
+ if pool == "adaptive":
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
+ nn.Flatten(),
+ )
+ elif pool == "attention":
+ assert num_head_channels != -1
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ AttentionPool2d(
+ (image_size // ds), ch, num_head_channels, out_channels
+ ),
+ )
+ elif pool == "spatial":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ nn.ReLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ elif pool == "spatial_v2":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ normalization(2048),
+ nn.SiLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ else:
+ raise NotImplementedError(f"Unexpected {pool} pooling")
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+ results = []
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = self.middle_block(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = th.cat(results, axis=-1)
+ return self.out(h)
+ else:
+ h = h.type(x.dtype)
+ return self.out(h)
+
diff --git a/3DTopia/ldm/modules/diffusionmodules/triplane_3daware_unet.py b/3DTopia/ldm/modules/diffusionmodules/triplane_3daware_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1358d5f9e24d47e8e8142c87292374c755fd7e53
--- /dev/null
+++ b/3DTopia/ldm/modules/diffusionmodules/triplane_3daware_unet.py
@@ -0,0 +1,991 @@
+from abc import abstractmethod
+from functools import partial
+import math
+from typing import Iterable
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+def convert_module_to_f32(x):
+ pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+class TransposedUpsample(nn.Module):
+ 'Learned 2x upsampling without padding'
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+ def forward(self,x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ # print("Using 3d aware resblock!")
+
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels * 3, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+ def to3daware(self, triplane):
+ res = triplane.shape[-2]
+ plane1 = triplane[..., :res]
+ plane2 = triplane[..., res:2*res]
+ plane3 = triplane[..., 2*res:3*res]
+
+ x_mp = th.nn.AvgPool2d((res, 1))
+ y_mp = th.nn.AvgPool2d((1, res))
+ x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
+ y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)
+ # for plane1
+ plane21 = x_mp_rep(plane2)
+ plane31 = th.flip(y_mp_rep(plane3), (3,))
+ new_plane1 = th.cat([plane1, plane21, plane31], 1)
+ # for plane2
+ plane12 = y_mp_rep(plane1)
+ plane32 = x_mp_rep(plane3)
+ new_plane2 = th.cat([plane2, plane12, plane32], 1)
+ # for plane3
+ plane13 = th.flip(x_mp_rep(plane1), (2,))
+ plane23 = y_mp_rep(plane2)
+ new_plane3 = th.cat([plane3, plane13, plane23], 1)
+
+ new_plane = th.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
+ return new_plane
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = self.to3daware(h)
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ h = self.to3daware(out_norm(h))
+ h = out_rest(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
+
+
+class EncoderUNetModel(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+ For usage, see UNet.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ pool="adaptive",
+ *args,
+ **kwargs
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+ self.pool = pool
+ if pool == "adaptive":
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
+ nn.Flatten(),
+ )
+ elif pool == "attention":
+ assert num_head_channels != -1
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ AttentionPool2d(
+ (image_size // ds), ch, num_head_channels, out_channels
+ ),
+ )
+ elif pool == "spatial":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ nn.ReLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ elif pool == "spatial_v2":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ normalization(2048),
+ nn.SiLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ else:
+ raise NotImplementedError(f"Unexpected {pool} pooling")
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+ results = []
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = self.middle_block(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = th.cat(results, axis=-1)
+ return self.out(h)
+ else:
+ h = h.type(x.dtype)
+ return self.out(h)
+
diff --git a/3DTopia/ldm/modules/diffusionmodules/triplane_context_crossattention_unet.py b/3DTopia/ldm/modules/diffusionmodules/triplane_context_crossattention_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a846be183fba569d7178e2fd34cc145f9669cec
--- /dev/null
+++ b/3DTopia/ldm/modules/diffusionmodules/triplane_context_crossattention_unet.py
@@ -0,0 +1,1126 @@
+from abc import abstractmethod
+from functools import partial
+import math
+from typing import Iterable
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from inspect import isfunction
+
+from ldm.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+def convert_module_to_f32(x):
+ pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, TriplaneAttentionBlock):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+class TransposedUpsample(nn.Module):
+ 'Learned 2x upsampling without padding'
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+ def forward(self,x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+def exists(val):
+ return val is not None
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ x = x.permute(0, 2, 1)
+ context = context.permute(0, 2, 1)
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -th.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out).permute(0, 2, 1)
+
+class CrossAttentionContext(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+
+ # import pdb; pdb.set_trace()
+
+ h = self.heads
+
+ x = x.permute(0, 2, 1)
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out).permute(0, 2, 1)
+
+
+class TriplaneAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ channels,
+ context_channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ # self.norm = normalization(channels)
+ self.norm1 = normalization(channels)
+ self.norm2 = normalization(channels)
+ self.norm3 = normalization(channels)
+ self.norm4 = normalization(channels)
+ self.norm5 = normalization(channels)
+ # self.norm6 = normalization(context_channels)
+ # self.norm1 = nn.LayerNorm(channels)
+ # self.norm2 = nn.LayerNorm(channels)
+ # self.norm3 = nn.LayerNorm(channels)
+ # self.norm4 = nn.LayerNorm(channels)
+ # self.norm5 = nn.LayerNorm(channels)
+ # self.norm6 = nn.LayerNorm(context_channels)
+
+ self.plane1_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels)
+ self.plane2_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels)
+ self.plane3_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels)
+
+ self.context_ca = CrossAttentionContext(channels, context_channels, self.num_heads, num_head_channels)
+
+ def forward(self, x, context):
+ return checkpoint(self._forward, (x, context), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x, context):
+ b, c, *spatial = x.shape
+ res = x.shape[-2]
+ plane1 = x[..., :res].reshape(b, c, -1)
+ plane2 = x[..., res:res*2].reshape(b, c, -1)
+ plane3 = x[..., 2*res:3*res].reshape(b, c, -1)
+ x = x.reshape(b, c, -1)
+
+ plane1_output = self.plane1_ca(self.norm1(plane1), self.norm4(x))
+ plane2_output = self.plane2_ca(self.norm2(plane2), self.norm4(x))
+ plane3_output = self.plane3_ca(self.norm3(plane3), self.norm4(x))
+
+ h = th.cat([plane1_output, plane2_output, plane3_output], -1)
+
+ h = self.context_ca(self.norm5(h), context=context)
+
+ return (x + h).reshape(b, c, *spatial)
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ layers.append(
+ TriplaneAttentionBlock(
+ ch,
+ context_dim,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ TriplaneAttentionBlock(
+ ch,
+ context_dim,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ layers.append(
+ TriplaneAttentionBlock(
+ ch,
+ context_dim,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, ch, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
+
+
+class EncoderUNetModel(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+ For usage, see UNet.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ pool="adaptive",
+ *args,
+ **kwargs
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+ self.pool = pool
+ if pool == "adaptive":
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
+ nn.Flatten(),
+ )
+ elif pool == "attention":
+ assert num_head_channels != -1
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ AttentionPool2d(
+ (image_size // ds), ch, num_head_channels, out_channels
+ ),
+ )
+ elif pool == "spatial":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ nn.ReLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ elif pool == "spatial_v2":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ normalization(2048),
+ nn.SiLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ else:
+ raise NotImplementedError(f"Unexpected {pool} pooling")
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+ results = []
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = self.middle_block(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = th.cat(results, axis=-1)
+ return self.out(h)
+ else:
+ h = h.type(x.dtype)
+ return self.out(h)
+
diff --git a/3DTopia/ldm/modules/diffusionmodules/triplane_crossattention_unet.py b/3DTopia/ldm/modules/diffusionmodules/triplane_crossattention_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e489ef26704b530453d0492429a260a3f13c4d9a
--- /dev/null
+++ b/3DTopia/ldm/modules/diffusionmodules/triplane_crossattention_unet.py
@@ -0,0 +1,1058 @@
+from abc import abstractmethod
+from functools import partial
+import math
+from typing import Iterable
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from inspect import isfunction
+
+from ldm.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+def convert_module_to_f32(x):
+ pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+class TransposedUpsample(nn.Module):
+ 'Learned 2x upsampling without padding'
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+ def forward(self,x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+def exists(val):
+ return val is not None
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ x = x.permute(0, 2, 1)
+ context = context.permute(0, 2, 1)
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -th.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out).permute(0, 2, 1)
+
+
+class TriplaneAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+
+ self.plane1_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels)
+ self.plane2_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels)
+ self.plane3_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels)
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ res = x.shape[-2]
+ plane1 = x[..., :res].reshape(b, c, -1)
+ plane2 = x[..., res:res*2].reshape(b, c, -1)
+ plane3 = x[..., 2*res:3*res].reshape(b, c, -1)
+ x = x.reshape(b, c, -1)
+
+ plane1_output = self.plane1_ca(self.norm(plane1), self.norm(x))
+ plane2_output = self.plane2_ca(self.norm(plane2), self.norm(x))
+ plane3_output = self.plane3_ca(self.norm(plane3), self.norm(x))
+
+ h = th.cat([plane1_output, plane2_output, plane3_output], -1)
+
+ return (x + h).reshape(b, c, *spatial)
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ layers.append(
+ TriplaneAttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ TriplaneAttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ layers.append(
+ TriplaneAttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, ch, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
+
+
+class EncoderUNetModel(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+ For usage, see UNet.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ pool="adaptive",
+ *args,
+ **kwargs
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+ self.pool = pool
+ if pool == "adaptive":
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
+ nn.Flatten(),
+ )
+ elif pool == "attention":
+ assert num_head_channels != -1
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ AttentionPool2d(
+ (image_size // ds), ch, num_head_channels, out_channels
+ ),
+ )
+ elif pool == "spatial":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ nn.ReLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ elif pool == "spatial_v2":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ normalization(2048),
+ nn.SiLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ else:
+ raise NotImplementedError(f"Unexpected {pool} pooling")
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+ results = []
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = self.middle_block(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = th.cat(results, axis=-1)
+ return self.out(h)
+ else:
+ h = h.type(x.dtype)
+ return self.out(h)
+
diff --git a/3DTopia/ldm/modules/diffusionmodules/util.py b/3DTopia/ldm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fe3bb31158f0184505b044ba1cbffe1ede3375e
--- /dev/null
+++ b/3DTopia/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,305 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from ldm.util import instantiate_from_config
+
+def force_zero_snr(betas):
+ alphas = 1 - betas
+ alphas_bar = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_bar ** (1/2)
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - 1e-6
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+ alphas_bar = alphas_bar_sqrt ** 2
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
+ alphas = torch.cat([alphas_bar[0:1], alphas], 0)
+ betas = 1 - alphas
+ return betas
+
+def shift_schedule(base_betas, shift_scale):
+ alphas = 1 - base_betas
+ alphas_bar = torch.cumprod(alphas, dim=0)
+ snr = alphas_bar / (1 - alphas_bar) # snr(1-ab)=ab; snr-snr*ab=ab; snr=(1+snr)ab; ab=snr/(1+snr)
+ shifted_snr = snr * ((1 / shift_scale) ** 2)
+ shifted_alphas_bar = shifted_snr / (1 + shifted_snr)
+ shifted_alphas = shifted_alphas_bar[1:] / shifted_alphas_bar[:-1]
+ shifted_alphas = torch.cat([shifted_alphas_bar[0:1], shifted_alphas], 0)
+ shifted_betas = 1 - shifted_alphas
+ return shifted_betas
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, shift_scale=None):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ elif schedule == 'linear_force_zero_snr':
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+ betas = force_zero_snr(betas)
+ elif schedule == 'linear_100':
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+ betas = betas[:100]
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+
+ if shift_scale is not None:
+ betas = shift_schedule(betas, shift_scale)
+
+ return betas.numpy()
+
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+ if ddim_discr_method == 'uniform':
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == 'quad':
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+ else:
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
+ return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+ if verbose:
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+ print(f'For the chosen value of eta, which is {eta}, '
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+ return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
+ return embedding
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
\ No newline at end of file
diff --git a/3DTopia/ldm/modules/distributions/__init__.py b/3DTopia/ldm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/3DTopia/ldm/modules/distributions/distributions.py b/3DTopia/ldm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9
--- /dev/null
+++ b/3DTopia/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/3DTopia/ldm/modules/ema.py b/3DTopia/ldm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8c75af43565f6e140287644aaaefa97dd6e67c5
--- /dev/null
+++ b/3DTopia/ldm/modules/ema.py
@@ -0,0 +1,76 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+
+ self.m_name2s_name = {}
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
+ else torch.tensor(-1,dtype=torch.int))
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ #remove as '.'-character is not allowed in buffers
+ s_name = name.replace('.','')
+ self.m_name2s_name.update({name:s_name})
+ self.register_buffer(s_name,p.clone().detach().data)
+
+ self.collected_params = []
+
+ def forward(self,model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/3DTopia/ldm/modules/encoders/__init__.py b/3DTopia/ldm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/3DTopia/ldm/modules/encoders/modules.py b/3DTopia/ldm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..73508813bb3942f914553f4988606d4917f4aaf8
--- /dev/null
+++ b/3DTopia/ldm/modules/encoders/modules.py
@@ -0,0 +1,386 @@
+import torch
+import torch.nn as nn
+from functools import partial
+import clip
+from einops import rearrange, repeat
+from transformers import CLIPTokenizer, CLIPTextModel
+import kornia
+
+from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+
+class ClassEmbedder(nn.Module):
+ def __init__(self, embed_dim, n_classes=1000, key='class'):
+ super().__init__()
+ self.key = key
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+
+ def forward(self, batch, key=None):
+ if key is None:
+ key = self.key
+ # this is for use in crossattn
+ c = batch[key][:, None]
+ c = self.embedding(c)
+ return c
+
+
+class TransformerEmbedder(AbstractEncoder):
+ """Some transformer encoder layers"""
+ def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+ super().__init__()
+ self.device = device
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
+ attn_layers=Encoder(dim=n_embed, depth=n_layer))
+
+ def forward(self, tokens):
+ tokens = tokens.to(self.device) # meh
+ z = self.transformer(tokens, return_embeddings=True)
+ return z
+
+ def encode(self, x):
+ return self(x)
+
+
+class BERTTokenizer(AbstractEncoder):
+ """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
+ def __init__(self, device="cuda", vq_interface=True, max_length=77):
+ super().__init__()
+ from transformers import BertTokenizerFast # TODO: add to reuquirements
+ self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+ self.device = device
+ self.vq_interface = vq_interface
+ self.max_length = max_length
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ return tokens
+
+ @torch.no_grad()
+ def encode(self, text):
+ tokens = self(text)
+ if not self.vq_interface:
+ return tokens
+ return None, None, [None, None, tokens]
+
+ def decode(self, text):
+ return text
+
+
+class BERTEmbedder(AbstractEncoder):
+ """Uses the BERT tokenizr model and add some transformer encoder layers"""
+ def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
+ device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+ super().__init__()
+ self.use_tknz_fn = use_tokenizer
+ if self.use_tknz_fn:
+ self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
+ self.device = device
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
+ attn_layers=Encoder(dim=n_embed, depth=n_layer),
+ emb_dropout=embedding_dropout)
+
+ def forward(self, text):
+ if self.use_tknz_fn:
+ tokens = self.tknz_fn(text)#.to(self.device)
+ else:
+ tokens = text
+ z = self.transformer(tokens, return_embeddings=True)
+ return z
+
+ def encode(self, text):
+ # output of length 77
+ return self(text)
+
+
+class SpatialRescaler(nn.Module):
+ def __init__(self,
+ n_stages=1,
+ method='bilinear',
+ multiplier=0.5,
+ in_channels=3,
+ out_channels=None,
+ bias=False):
+ super().__init__()
+ self.n_stages = n_stages
+ assert self.n_stages >= 0
+ assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
+ self.multiplier = multiplier
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
+ self.remap_output = out_channels is not None
+ if self.remap_output:
+ print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
+ self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
+
+ def forward(self,x):
+ for stage in range(self.n_stages):
+ x = self.interpolator(x, scale_factor=self.multiplier)
+
+
+ if self.remap_output:
+ x = self.channel_mapper(x)
+ return x
+
+ def encode(self, x):
+ return self(x)
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+import hashlib
+import os
+import urllib
+import warnings
+from typing import Any, Union, List
+from pkg_resources import packaging
+
+import torch
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+from tqdm import tqdm
+
+from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+try:
+ from torchvision.transforms import InterpolationMode
+ BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+ BICUBIC = Image.BICUBIC
+
+
+if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
+
+
+__all__ = ["available_models", "load", "tokenize"]
+_tokenizer = _Tokenizer()
+
+def tokenize_with_truncation(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
+ """
+ Returns the tokenized representation of given input string(s)
+
+ Parameters
+ ----------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+
+ truncate: bool
+ Whether to truncate the text in case its encoding is longer than the context length
+
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ if truncate:
+ tokens = tokens[:context_length]
+ tokens[-1] = eot_token
+ else:
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
+
+class FrozenCLIPTextEmbedder(nn.Module):
+ """
+ Uses the CLIP transformer encoder for text.
+ """
+ def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
+ super().__init__()
+ self.model, _ = clip.load(version, jit=False, device="cpu")
+ self.device = device
+ self.max_length = max_length
+ self.n_repeat = n_repeat
+ self.normalize = normalize
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ # tokens = clip.tokenize(text).to(self.device)
+ tokens = tokenize_with_truncation(text, truncate=True).to(self.device)
+ z = self.model.encode_text(tokens)
+ if self.normalize:
+ z = z / torch.linalg.norm(z, dim=1, keepdim=True)
+ return z
+
+ def encode(self, text):
+ z = self(text)
+ if z.ndim==2:
+ z = z[:, None, :]
+ z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
+ return z
+
+
+class FrozenClipImageEmbedder(nn.Module):
+ """
+ Uses the CLIP image encoder.
+ """
+ def __init__(
+ self,
+ model='ViT-L/14',
+ jit=False,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ antialias=False,
+ ):
+ super().__init__()
+ # self.model, _ = clip.load(name=model, device=device, jit=jit)
+ self.model, _ = clip.load(name=model, device=device)
+
+ self.antialias = antialias
+
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',align_corners=True,
+ antialias=self.antialias)
+ # x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+
+ return x
+
+ def forward(self, x):
+ # x is assumed to be in range [-1,1]
+ z = self.model.encode_image(self.preprocess(x))
+ if z.ndim==2:
+ z = z[:, None, :]
+ return z
+
+
+############### OPENCLIP #################
+import open_clip
+
+class OpenClipTextEmbedder(nn.Module):
+ def __init__(
+ self,
+ model='ViT-bigG-14',
+ pretrained='laion2b_s39b_b160k',
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ normalize=True,):
+ super().__init__()
+ self.model, _, _ = open_clip.create_model_and_transforms(model, pretrained=pretrained, device='cpu')
+ self.tokenizer = open_clip.get_tokenizer(model)
+ self.normalize = normalize
+ self.device = device
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tok_text = self.tokenizer(text).to(self.device)
+ z = self.model.encode_text(tok_text)
+ if self.normalize:
+ z = z / torch.linalg.norm(z, dim=1, keepdim=True)
+ return z
+
+ def encode(self, text):
+ z = self(text)
+ if z.ndim==2:
+ z = z[:, None, :]
+ z = repeat(z, 'b 1 d -> b k d', k=1)
+ return z
+
+class OpenClipImageEmbedder(nn.Module):
+ def __init__(
+ self,
+ model='ViT-bigG-14',
+ pretrained='laion2b_s39b_b160k',
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ ):
+ super().__init__()
+ self.model, _, _ = open_clip.create_model_and_transforms(model, pretrained=pretrained, device=device)
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+ def preprocess(self, x):
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',align_corners=True,
+ antialias=self.antialias)
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x):
+ z = self.model.encode_image(self.preprocess(x))
+ if z.ndim==2:
+ z = z[:, None, :]
+ return z
+
+class DinoV2(nn.Module):
+ def __init__(self, model='dinov2_vitb14', ckpt='dino_ckpt/dinov2_vitb14_pretrain.pth'):
+ super().__init__()
+ # device='cuda' if torch.cuda.is_available() else 'cpu'
+ # self.model = torch.hub.load('facebookresearch/dinov2', model)
+ self.model = torch.hub.load('dinov2', model, source='local', pretrained=False)
+ self.model.load_state_dict(torch.load(ckpt))
+ # self.model = self.model.to(device)
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]), persistent=False)
+
+ def preprocess(self, x):
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',align_corners=True,
+ antialias=False)
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x):
+ return self.model.forward_features(self.preprocess(x))['x_norm_patchtokens']
+
+if __name__ == "__main__":
+ from ldm.util import count_params
+ model = FrozenCLIPEmbedder()
+ count_params(model, verbose=True)
\ No newline at end of file
diff --git a/3DTopia/ldm/modules/image_degradation/__init__.py b/3DTopia/ldm/modules/image_degradation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc
--- /dev/null
+++ b/3DTopia/ldm/modules/image_degradation/__init__.py
@@ -0,0 +1,2 @@
+from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
+from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/3DTopia/ldm/modules/image_degradation/bsrgan.py b/3DTopia/ldm/modules/image_degradation/bsrgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b
--- /dev/null
+++ b/3DTopia/ldm/modules/image_degradation/bsrgan.py
@@ -0,0 +1,730 @@
+# -*- coding: utf-8 -*-
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calcualte Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < scipy.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 1.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int):
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(30, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ elif i == 1:
+ image = add_blur(image, sf=sf)
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image":image}
+ return example
+
+
+# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
+ """
+ This is an extended degradation model by combining
+ the degradation models of BSRGAN and Real-ESRGAN
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ use_shuffle: the degradation shuffle
+ use_sharp: sharpening the img
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ if use_sharp:
+ img = add_sharpening(img)
+ hq = img.copy()
+
+ if random.random() < shuffle_prob:
+ shuffle_order = random.sample(range(13), 13)
+ else:
+ shuffle_order = list(range(13))
+ # local shuffle for noise, JPEG is always the last one
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+
+ for i in shuffle_order:
+ if i == 0:
+ img = add_blur(img, sf=sf)
+ elif i == 1:
+ img = add_resize(img, sf=sf)
+ elif i == 2:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 3:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 4:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 5:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ elif i == 6:
+ img = add_JPEG_noise(img)
+ elif i == 7:
+ img = add_blur(img, sf=sf)
+ elif i == 8:
+ img = add_resize(img, sf=sf)
+ elif i == 9:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 10:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 11:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 12:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ else:
+ print('check the shuffle!')
+
+ # resize to desired size
+ img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
+
+ return img, hq
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ print(img)
+ img = util.uint2single(img)
+ print(img)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_lq = deg_fn(img)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
+
+
diff --git a/3DTopia/ldm/modules/image_degradation/bsrgan_light.py b/3DTopia/ldm/modules/image_degradation/bsrgan_light.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e1f823996bf559e9b015ea9aa2b3cd38dd13af1
--- /dev/null
+++ b/3DTopia/ldm/modules/image_degradation/bsrgan_light.py
@@ -0,0 +1,650 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calcualte Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < scipy.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 1.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int):
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+
+ wd2 = wd2/4
+ wd = wd/4
+
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(80, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ # elif i == 1:
+ # image = add_blur(image, sf=sf)
+
+ if i == 0:
+ pass
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.8:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+ #
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image": image}
+ return example
+
+
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_hq = img
+ img_lq = deg_fn(img)["image"]
+ img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+ (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
diff --git a/3DTopia/ldm/modules/image_degradation/utils/test.png b/3DTopia/ldm/modules/image_degradation/utils/test.png
new file mode 100644
index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6
Binary files /dev/null and b/3DTopia/ldm/modules/image_degradation/utils/test.png differ
diff --git a/3DTopia/ldm/modules/image_degradation/utils_image.py b/3DTopia/ldm/modules/image_degradation/utils_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98
--- /dev/null
+++ b/3DTopia/ldm/modules/image_degradation/utils_image.py
@@ -0,0 +1,916 @@
+import os
+import math
+import random
+import numpy as np
+import torch
+import cv2
+from torchvision.utils import make_grid
+from datetime import datetime
+#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def get_timestamp():
+ return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def imshow(x, title=None, cbar=False, figsize=None):
+ plt.figure(figsize=figsize)
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+ if title:
+ plt.title(title)
+ if cbar:
+ plt.colorbar()
+ plt.show()
+
+
+def surf(Z, cmap='rainbow', figsize=None):
+ plt.figure(figsize=figsize)
+ ax3 = plt.axes(projection='3d')
+
+ w, h = Z.shape[:2]
+ xx = np.arange(0,w,1)
+ yy = np.arange(0,h,1)
+ X, Y = np.meshgrid(xx, yy)
+ ax3.plot_surface(X,Y,Z,cmap=cmap)
+ #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+ plt.show()
+
+
+'''
+# --------------------------------------------
+# get image pathes
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+ paths = None # return None if dataroot is None
+ if dataroot is not None:
+ paths = sorted(_get_paths_from_images(dataroot))
+ return paths
+
+
+def _get_paths_from_images(path):
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+ images = []
+ for dirpath, _, fnames in sorted(os.walk(path)):
+ for fname in sorted(fnames):
+ if is_image_file(fname):
+ img_path = os.path.join(dirpath, fname)
+ images.append(img_path)
+ assert images, '{:s} has no valid image file'.format(path)
+ return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+ w, h = img.shape[:2]
+ patches = []
+ if w > p_max and h > p_max:
+ w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
+ h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
+ w1.append(w-p_size)
+ h1.append(h-p_size)
+# print(w1)
+# print(h1)
+ for i in w1:
+ for j in h1:
+ patches.append(img[i:i+p_size, j:j+p_size,:])
+ else:
+ patches.append(img)
+
+ return patches
+
+
+def imssave(imgs, img_path):
+ """
+ imgs: list, N images of size WxHxC
+ """
+ img_name, ext = os.path.splitext(os.path.basename(img_path))
+
+ for i, img in enumerate(imgs):
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+ cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
+ """
+ split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
+ and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
+ will be splitted.
+ Args:
+ original_dataroot:
+ taget_dataroot:
+ p_size: size of small images
+ p_overlap: patch size in training is a good choice
+ p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
+ """
+ paths = get_image_paths(original_dataroot)
+ for img_path in paths:
+ # img_name, ext = os.path.splitext(os.path.basename(img_path))
+ img = imread_uint(img_path, n_channels=n_channels)
+ patches = patches_from_image(img, p_size, p_overlap, p_max)
+ imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
+ #if original_dataroot == taget_dataroot:
+ #del img_path
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def mkdirs(paths):
+ if isinstance(paths, str):
+ mkdir(paths)
+ else:
+ for path in paths:
+ mkdir(path)
+
+
+def mkdir_and_rename(path):
+ if os.path.exists(path):
+ new_name = path + '_archived_' + get_timestamp()
+ print('Path already exists. Rename it to [{:s}]'.format(new_name))
+ os.rename(path, new_name)
+ os.makedirs(path)
+
+
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channles (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+ # input: path
+ # output: HxWx3(RGB or GGG), or HxWx1 (G)
+ if n_channels == 1:
+ img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
+ img = np.expand_dims(img, axis=2) # HxWx1
+ elif n_channels == 3:
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
+ return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channles (BGR)
+# --------------------------------------------
+def read_img(path):
+ # read image by cv2
+ # return: Numpy float32, HWC, BGR, [0,1]
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ # some images have 4 channels
+ if img.shape[2] > 3:
+ img = img[:, :, :3]
+ return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <---> numpy(unit)
+# numpy(single) <---> tensor
+# numpy(unit) <---> tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <---> numpy(unit)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+ return np.float32(img/255.)
+
+
+def single2uint(img):
+
+ return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+ return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+ return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+# --------------------------------------------
+# numpy(unit) (HxWxC or HxW) <---> tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <---> tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+
+ return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ elif img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return img
+
+
+def single2tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+ '''
+ Converts a torch Tensor into an image Numpy array of BGR channel order
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+ '''
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
+ n_dim = tensor.dim()
+ if n_dim == 4:
+ n_img = len(tensor)
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 3:
+ img_np = tensor.numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 2:
+ img_np = tensor.numpy()
+ else:
+ raise TypeError(
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+ if out_type == np.uint8:
+ img_np = (img_np * 255.0).round()
+ # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
+ return img_np.astype(out_type)
+
+
+'''
+# --------------------------------------------
+# Augmentation, flipe and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augmet_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return np.flipud(np.rot90(img))
+ elif mode == 2:
+ return np.flipud(img)
+ elif mode == 3:
+ return np.rot90(img, k=3)
+ elif mode == 4:
+ return np.flipud(np.rot90(img, k=2))
+ elif mode == 5:
+ return np.rot90(img)
+ elif mode == 6:
+ return np.rot90(img, k=2)
+ elif mode == 7:
+ return np.flipud(np.rot90(img, k=3))
+
+
+def augment_img_tensor4(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.rot90(1, [2, 3]).flip([2])
+ elif mode == 2:
+ return img.flip([2])
+ elif mode == 3:
+ return img.rot90(3, [2, 3])
+ elif mode == 4:
+ return img.rot90(2, [2, 3]).flip([2])
+ elif mode == 5:
+ return img.rot90(1, [2, 3])
+ elif mode == 6:
+ return img.rot90(2, [2, 3])
+ elif mode == 7:
+ return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ img_size = img.size()
+ img_np = img.data.cpu().numpy()
+ if len(img_size) == 3:
+ img_np = np.transpose(img_np, (1, 2, 0))
+ elif len(img_size) == 4:
+ img_np = np.transpose(img_np, (2, 3, 1, 0))
+ img_np = augment_img(img_np, mode=mode)
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+ if len(img_size) == 3:
+ img_tensor = img_tensor.permute(2, 0, 1)
+ elif len(img_size) == 4:
+ img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+ return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.transpose(1, 0, 2)
+ elif mode == 2:
+ return img[::-1, :, :]
+ elif mode == 3:
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 4:
+ return img[:, ::-1, :]
+ elif mode == 5:
+ img = img[:, ::-1, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 6:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ return img
+ elif mode == 7:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+ # horizontal flip OR rotate
+ hflip = hflip and random.random() < 0.5
+ vflip = rot and random.random() < 0.5
+ rot90 = rot and random.random() < 0.5
+
+ def _augment(img):
+ if hflip:
+ img = img[:, ::-1, :]
+ if vflip:
+ img = img[::-1, :, :]
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ if img.ndim == 2:
+ H, W = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r]
+ elif img.ndim == 3:
+ H, W, C = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r, :]
+ else:
+ raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+ return img
+
+
+def shave(img_in, border=0):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ h, w = img.shape[:2]
+ img = img[border:h-border, border:w-border]
+ return img
+
+
+'''
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+ '''same as matlab rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+ '''same as matlab ycbcr2rgb
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+ '''bgr version of rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+ img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+ # conversion among BGR, gray and y
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in gray_list]
+ elif in_c == 3 and tar_type == 'y': # BGR to y
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in y_list]
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+ else:
+ return img_list
+
+
+'''
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+ # img1 and img2 have range [0, 255]
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+ '''calculate SSIM
+ the same outputs as MATLAB's
+ img1, img2: [0, 255]
+ '''
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ if img1.ndim == 2:
+ return ssim(img1, img2)
+ elif img1.ndim == 3:
+ if img1.shape[2] == 3:
+ ssims = []
+ for i in range(3):
+ ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+ return np.array(ssims).mean()
+ elif img1.shape[2] == 1:
+ return ssim(np.squeeze(img1), np.squeeze(img2))
+ else:
+ raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+'''
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+'''
+
+
+# matlab 'imresize' function, now only support 'bicubic'
+def cubic(x):
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+ (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ if (scale < 1) and (antialiasing):
+ # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ P = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+ 1, P).expand(out_length, P)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+ # apply cubic kernel
+ if (scale < 1) and (antialiasing):
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, P)
+
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, P - 2)
+ weights = weights.narrow(1, 1, P - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, P - 2)
+ weights = weights.narrow(1, 0, P - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
+def imresize(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: pytorch tensor, CHW or HW [0,1]
+ # output: CHW or HW [0,1] w/o round
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(0)
+ in_C, in_H, in_W = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:, :sym_len_Hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_He:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_Ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_We:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+ return out_2
+
+
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
+def imresize_np(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: Numpy, HWC or HW [0,1]
+ # output: HWC or HW [0,1] w/o round
+ img = torch.from_numpy(img)
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(2)
+
+ in_H, in_W, in_C = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:sym_len_Hs, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[-sym_len_He:, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :sym_len_Ws, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, -sym_len_We:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+
+ return out_2.numpy()
+
+
+if __name__ == '__main__':
+ print('---')
+# img = imread_uint('test.bmp', 3)
+# img = uint2single(img)
+# img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/3DTopia/ldm/modules/losses/__init__.py b/3DTopia/ldm/modules/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..876d7c5bd6e3245ee77feb4c482b7a8143604ad5
--- /dev/null
+++ b/3DTopia/ldm/modules/losses/__init__.py
@@ -0,0 +1 @@
+from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
\ No newline at end of file
diff --git a/3DTopia/ldm/modules/losses/contperceptual.py b/3DTopia/ldm/modules/losses/contperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..672c1e32a1389def02461c0781339681060c540e
--- /dev/null
+++ b/3DTopia/ldm/modules/losses/contperceptual.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+
+from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
+
+
+class LPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_loss="hinge"):
+
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.kl_weight = kl_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ # output log variance
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train",
+ weights=None):
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights*nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ kl_loss = posteriors.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ if self.disc_factor > 0.0:
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+ else:
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
+ "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
+
diff --git a/3DTopia/ldm/modules/losses/vqperceptual.py b/3DTopia/ldm/modules/losses/vqperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..f69981769e4bd5462600458c4fcf26620f7e4306
--- /dev/null
+++ b/3DTopia/ldm/modules/losses/vqperceptual.py
@@ -0,0 +1,167 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from einops import repeat
+
+from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
+from taming.modules.losses.lpips import LPIPS
+from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
+
+
+def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
+ assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
+ loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
+ loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
+ loss_real = (weights * loss_real).sum() / weights.sum()
+ loss_fake = (weights * loss_fake).sum() / weights.sum()
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+def adopt_weight(weight, global_step, threshold=0, value=0.):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+def measure_perplexity(predicted_indices, n_embed):
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
+ encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
+ avg_probs = encodings.mean(0)
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+ cluster_use = torch.sum(avg_probs > 0)
+ return perplexity, cluster_use
+
+def l1(x, y):
+ return torch.abs(x-y)
+
+
+def l2(x, y):
+ return torch.pow((x-y), 2)
+
+
+class VQLPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
+ pixel_loss="l1"):
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ assert perceptual_loss in ["lpips", "clips", "dists"]
+ assert pixel_loss in ["l1", "l2"]
+ self.codebook_weight = codebook_weight
+ self.pixel_weight = pixelloss_weight
+ if perceptual_loss == "lpips":
+ print(f"{self.__class__.__name__}: Running with LPIPS.")
+ self.perceptual_loss = LPIPS().eval()
+ else:
+ raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
+ self.perceptual_weight = perceptual_weight
+
+ if pixel_loss == "l1":
+ self.pixel_loss = l1
+ else:
+ self.pixel_loss = l2
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm,
+ ndf=disc_ndf
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ if disc_loss == "hinge":
+ self.disc_loss = hinge_d_loss
+ elif disc_loss == "vanilla":
+ self.disc_loss = vanilla_d_loss
+ else:
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+ self.n_classes = n_classes
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
+ if not exists(codebook_loss):
+ codebook_loss = torch.tensor([0.]).to(inputs.device)
+ #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ else:
+ p_loss = torch.tensor([0.0])
+
+ nll_loss = rec_loss
+ #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ nll_loss = torch.mean(nll_loss)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/p_loss".format(split): p_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ if predicted_indices is not None:
+ assert self.n_classes is not None
+ with torch.no_grad():
+ perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
+ log[f"{split}/perplexity"] = perplexity
+ log[f"{split}/cluster_usage"] = cluster_usage
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
diff --git a/3DTopia/ldm/modules/x_transformer.py b/3DTopia/ldm/modules/x_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fc15bf9cfe0111a910e7de33d04ffdec3877576
--- /dev/null
+++ b/3DTopia/ldm/modules/x_transformer.py
@@ -0,0 +1,641 @@
+"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from functools import partial
+from inspect import isfunction
+from collections import namedtuple
+from einops import rearrange, repeat, reduce
+
+# constants
+
+DEFAULT_DIM_HEAD = 64
+
+Intermediates = namedtuple('Intermediates', [
+ 'pre_softmax_attn',
+ 'post_softmax_attn'
+])
+
+LayerIntermediates = namedtuple('Intermediates', [
+ 'hiddens',
+ 'attn_intermediates'
+])
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+ def __init__(self, dim, max_seq_len):
+ super().__init__()
+ self.emb = nn.Embedding(max_seq_len, dim)
+ self.init_()
+
+ def init_(self):
+ nn.init.normal_(self.emb.weight, std=0.02)
+
+ def forward(self, x):
+ n = torch.arange(x.shape[1], device=x.device)
+ return self.emb(n)[None, :, :]
+
+
+class FixedPositionalEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer('inv_freq', inv_freq)
+
+ def forward(self, x, seq_dim=1, offset=0):
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
+ sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
+ return emb[None, :, :]
+
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def always(val):
+ def inner(*args, **kwargs):
+ return val
+ return inner
+
+
+def not_equals(val):
+ def inner(x):
+ return x != val
+ return inner
+
+
+def equals(val):
+ def inner(x):
+ return x == val
+ return inner
+
+
+def max_neg_value(tensor):
+ return -torch.finfo(tensor.dtype).max
+
+
+# keyword argument helpers
+
+def pick_and_pop(keys, d):
+ values = list(map(lambda key: d.pop(key), keys))
+ return dict(zip(keys, values))
+
+
+def group_dict_by_key(cond, d):
+ return_val = [dict(), dict()]
+ for key in d.keys():
+ match = bool(cond(key))
+ ind = int(not match)
+ return_val[ind][key] = d[key]
+ return (*return_val,)
+
+
+def string_begins_with(prefix, str):
+ return str.startswith(prefix)
+
+
+def group_by_key_prefix(prefix, d):
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
+
+
+def groupby_prefix_and_trim(prefix, d):
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+ return kwargs_without_prefix, kwargs
+
+
+# classes
+class Scale(nn.Module):
+ def __init__(self, value, fn):
+ super().__init__()
+ self.value = value
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ x, *rest = self.fn(x, **kwargs)
+ return (x * self.value, *rest)
+
+
+class Rezero(nn.Module):
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+ self.g = nn.Parameter(torch.zeros(1))
+
+ def forward(self, x, **kwargs):
+ x, *rest = self.fn(x, **kwargs)
+ return (x * self.g, *rest)
+
+
+class ScaleNorm(nn.Module):
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(1))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim, eps=1e-8):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class Residual(nn.Module):
+ def forward(self, x, residual):
+ return x + residual
+
+
+class GRUGating(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.gru = nn.GRUCell(dim, dim)
+
+ def forward(self, x, residual):
+ gated_output = self.gru(
+ rearrange(x, 'b n d -> (b n) d'),
+ rearrange(residual, 'b n d -> (b n) d')
+ )
+
+ return gated_output.reshape_as(x)
+
+
+# feedforward
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+# attention.
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_head=DEFAULT_DIM_HEAD,
+ heads=8,
+ causal=False,
+ mask=None,
+ talking_heads=False,
+ sparse_topk=None,
+ use_entmax15=False,
+ num_mem_kv=0,
+ dropout=0.,
+ on_attn=False
+ ):
+ super().__init__()
+ if use_entmax15:
+ raise NotImplementedError("Check out entmax activation instead of softmax activation!")
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ self.causal = causal
+ self.mask = mask
+
+ inner_dim = dim_head * heads
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(dim, inner_dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ # talking heads
+ self.talking_heads = talking_heads
+ if talking_heads:
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+
+ # explicit topk sparse attention
+ self.sparse_topk = sparse_topk
+
+ # entmax
+ #self.attn_fn = entmax15 if use_entmax15 else F.softmax
+ self.attn_fn = F.softmax
+
+ # add memory key / values
+ self.num_mem_kv = num_mem_kv
+ if num_mem_kv > 0:
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+
+ # attention on attention
+ self.attn_on_attn = on_attn
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ rel_pos=None,
+ sinusoidal_emb=None,
+ prev_attn=None,
+ mem=None
+ ):
+ b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
+ kv_input = default(context, x)
+
+ q_input = x
+ k_input = kv_input
+ v_input = kv_input
+
+ if exists(mem):
+ k_input = torch.cat((mem, k_input), dim=-2)
+ v_input = torch.cat((mem, v_input), dim=-2)
+
+ if exists(sinusoidal_emb):
+ # in shortformer, the query would start at a position offset depending on the past cached memory
+ offset = k_input.shape[-2] - q_input.shape[-2]
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
+ k_input = k_input + sinusoidal_emb(k_input)
+
+ q = self.to_q(q_input)
+ k = self.to_k(k_input)
+ v = self.to_v(v_input)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
+
+ input_mask = None
+ if any(map(exists, (mask, context_mask))):
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
+ k_mask = q_mask if not exists(context) else context_mask
+ k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
+ q_mask = rearrange(q_mask, 'b i -> b () i ()')
+ k_mask = rearrange(k_mask, 'b j -> b () () j')
+ input_mask = q_mask * k_mask
+
+ if self.num_mem_kv > 0:
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
+ k = torch.cat((mem_k, k), dim=-2)
+ v = torch.cat((mem_v, v), dim=-2)
+ if exists(input_mask):
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
+
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+ mask_value = max_neg_value(dots)
+
+ if exists(prev_attn):
+ dots = dots + prev_attn
+
+ pre_softmax_attn = dots
+
+ if talking_heads:
+ dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
+
+ if exists(rel_pos):
+ dots = rel_pos(dots)
+
+ if exists(input_mask):
+ dots.masked_fill_(~input_mask, mask_value)
+ del input_mask
+
+ if self.causal:
+ i, j = dots.shape[-2:]
+ r = torch.arange(i, device=device)
+ mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
+ mask = F.pad(mask, (j - i, 0), value=False)
+ dots.masked_fill_(mask, mask_value)
+ del mask
+
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
+ mask = dots < vk
+ dots.masked_fill_(mask, mask_value)
+ del mask
+
+ attn = self.attn_fn(dots, dim=-1)
+ post_softmax_attn = attn
+
+ attn = self.dropout(attn)
+
+ if talking_heads:
+ attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+
+ intermediates = Intermediates(
+ pre_softmax_attn=pre_softmax_attn,
+ post_softmax_attn=post_softmax_attn
+ )
+
+ return self.to_out(out), intermediates
+
+
+class AttentionLayers(nn.Module):
+ def __init__(
+ self,
+ dim,
+ depth,
+ heads=8,
+ causal=False,
+ cross_attend=False,
+ only_cross=False,
+ use_scalenorm=False,
+ use_rmsnorm=False,
+ use_rezero=False,
+ rel_pos_num_buckets=32,
+ rel_pos_max_distance=128,
+ position_infused_attn=False,
+ custom_layers=None,
+ sandwich_coef=None,
+ par_ratio=None,
+ residual_attn=False,
+ cross_residual_attn=False,
+ macaron=False,
+ pre_norm=True,
+ gate_residual=False,
+ **kwargs
+ ):
+ super().__init__()
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
+ attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
+
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
+
+ self.dim = dim
+ self.depth = depth
+ self.layers = nn.ModuleList([])
+
+ self.has_pos_emb = position_infused_attn
+ self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
+ self.rotary_pos_emb = always(None)
+
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
+ self.rel_pos = None
+
+ self.pre_norm = pre_norm
+
+ self.residual_attn = residual_attn
+ self.cross_residual_attn = cross_residual_attn
+
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
+ norm_class = RMSNorm if use_rmsnorm else norm_class
+ norm_fn = partial(norm_class, dim)
+
+ norm_fn = nn.Identity if use_rezero else norm_fn
+ branch_fn = Rezero if use_rezero else None
+
+ if cross_attend and not only_cross:
+ default_block = ('a', 'c', 'f')
+ elif cross_attend and only_cross:
+ default_block = ('c', 'f')
+ else:
+ default_block = ('a', 'f')
+
+ if macaron:
+ default_block = ('f',) + default_block
+
+ if exists(custom_layers):
+ layer_types = custom_layers
+ elif exists(par_ratio):
+ par_depth = depth * len(default_block)
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
+ default_block = tuple(filter(not_equals('f'), default_block))
+ par_attn = par_depth // par_ratio
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
+ par_block = default_block + ('f',) * (par_width - len(default_block))
+ par_head = par_block * par_attn
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
+ elif exists(sandwich_coef):
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
+ else:
+ layer_types = default_block * depth
+
+ self.layer_types = layer_types
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
+
+ for layer_type in self.layer_types:
+ if layer_type == 'a':
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
+ elif layer_type == 'c':
+ layer = Attention(dim, heads=heads, **attn_kwargs)
+ elif layer_type == 'f':
+ layer = FeedForward(dim, **ff_kwargs)
+ layer = layer if not macaron else Scale(0.5, layer)
+ else:
+ raise Exception(f'invalid layer type {layer_type}')
+
+ if isinstance(layer, Attention) and exists(branch_fn):
+ layer = branch_fn(layer)
+
+ if gate_residual:
+ residual_fn = GRUGating(dim)
+ else:
+ residual_fn = Residual()
+
+ self.layers.append(nn.ModuleList([
+ norm_fn(),
+ layer,
+ residual_fn
+ ]))
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ mems=None,
+ return_hiddens=False
+ ):
+ hiddens = []
+ intermediates = []
+ prev_attn = None
+ prev_cross_attn = None
+
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
+
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
+ is_last = ind == (len(self.layers) - 1)
+
+ if layer_type == 'a':
+ hiddens.append(x)
+ layer_mem = mems.pop(0)
+
+ residual = x
+
+ if self.pre_norm:
+ x = norm(x)
+
+ if layer_type == 'a':
+ out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
+ prev_attn=prev_attn, mem=layer_mem)
+ elif layer_type == 'c':
+ out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
+ elif layer_type == 'f':
+ out = block(x)
+
+ x = residual_fn(out, residual)
+
+ if layer_type in ('a', 'c'):
+ intermediates.append(inter)
+
+ if layer_type == 'a' and self.residual_attn:
+ prev_attn = inter.pre_softmax_attn
+ elif layer_type == 'c' and self.cross_residual_attn:
+ prev_cross_attn = inter.pre_softmax_attn
+
+ if not self.pre_norm and not is_last:
+ x = norm(x)
+
+ if return_hiddens:
+ intermediates = LayerIntermediates(
+ hiddens=hiddens,
+ attn_intermediates=intermediates
+ )
+
+ return x, intermediates
+
+ return x
+
+
+class Encoder(AttentionLayers):
+ def __init__(self, **kwargs):
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
+ super().__init__(causal=False, **kwargs)
+
+
+
+class TransformerWrapper(nn.Module):
+ def __init__(
+ self,
+ *,
+ num_tokens,
+ max_seq_len,
+ attn_layers,
+ emb_dim=None,
+ max_mem_len=0.,
+ emb_dropout=0.,
+ num_memory_tokens=None,
+ tie_embedding=False,
+ use_pos_emb=True
+ ):
+ super().__init__()
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
+
+ dim = attn_layers.dim
+ emb_dim = default(emb_dim, dim)
+
+ self.max_seq_len = max_seq_len
+ self.max_mem_len = max_mem_len
+ self.num_tokens = num_tokens
+
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+ self.emb_dropout = nn.Dropout(emb_dropout)
+
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+ self.attn_layers = attn_layers
+ self.norm = nn.LayerNorm(dim)
+
+ self.init_()
+
+ self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
+
+ # memory tokens (like [cls]) from Memory Transformers paper
+ num_memory_tokens = default(num_memory_tokens, 0)
+ self.num_memory_tokens = num_memory_tokens
+ if num_memory_tokens > 0:
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+ # let funnel encoder know number of memory tokens, if specified
+ if hasattr(attn_layers, 'num_memory_tokens'):
+ attn_layers.num_memory_tokens = num_memory_tokens
+
+ def init_(self):
+ nn.init.normal_(self.token_emb.weight, std=0.02)
+
+ def forward(
+ self,
+ x,
+ return_embeddings=False,
+ mask=None,
+ return_mems=False,
+ return_attn=False,
+ mems=None,
+ **kwargs
+ ):
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
+ x = self.token_emb(x)
+ x += self.pos_emb(x)
+ x = self.emb_dropout(x)
+
+ x = self.project_emb(x)
+
+ if num_mem > 0:
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
+ x = torch.cat((mem, x), dim=1)
+
+ # auto-handle masking after appending memory tokens
+ if exists(mask):
+ mask = F.pad(mask, (num_mem, 0), value=True)
+
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
+ x = self.norm(x)
+
+ mem, x = x[:, :num_mem], x[:, num_mem:]
+
+ out = self.to_logits(x) if not return_embeddings else x
+
+ if return_mems:
+ hiddens = intermediates.hiddens
+ new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
+ return out, new_mems
+
+ if return_attn:
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+ return out, attn_maps
+
+ return out
+
diff --git a/3DTopia/ldm/util.py b/3DTopia/ldm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ba38853e7a07228cc2c187742b5c45d7359b3f9
--- /dev/null
+++ b/3DTopia/ldm/util.py
@@ -0,0 +1,203 @@
+import importlib
+
+import torch
+import numpy as np
+from collections import abc
+from einops import rearrange
+from functools import partial
+
+import multiprocessing as mp
+from threading import Thread
+from queue import Queue
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
+ nc = int(40 * (wh[0] / 256))
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
+ # create dummy dataset instance
+
+ # run prefetching
+ if idx_to_fn:
+ res = func(data, worker_id=idx)
+ else:
+ res = func(data)
+ Q.put([idx, res])
+ Q.put("Done")
+
+
+def parallel_data_prefetch(
+ func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
+):
+ # if target_data_type not in ["ndarray", "list"]:
+ # raise ValueError(
+ # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
+ # )
+ if isinstance(data, np.ndarray) and target_data_type == "list":
+ raise ValueError("list expected but function got ndarray.")
+ elif isinstance(data, abc.Iterable):
+ if isinstance(data, dict):
+ print(
+ f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
+ )
+ data = list(data.values())
+ if target_data_type == "ndarray":
+ data = np.asarray(data)
+ else:
+ data = list(data)
+ else:
+ raise TypeError(
+ f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
+ )
+
+ if cpu_intensive:
+ Q = mp.Queue(1000)
+ proc = mp.Process
+ else:
+ Q = Queue(1000)
+ proc = Thread
+ # spawn processes
+ if target_data_type == "ndarray":
+ arguments = [
+ [func, Q, part, i, use_worker_id]
+ for i, part in enumerate(np.array_split(data, n_proc))
+ ]
+ else:
+ step = (
+ int(len(data) / n_proc + 1)
+ if len(data) % n_proc != 0
+ else int(len(data) / n_proc)
+ )
+ arguments = [
+ [func, Q, part, i, use_worker_id]
+ for i, part in enumerate(
+ [data[i: i + step] for i in range(0, len(data), step)]
+ )
+ ]
+ processes = []
+ for i in range(n_proc):
+ p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
+ processes += [p]
+
+ # start processes
+ print(f"Start prefetching...")
+ import time
+
+ start = time.time()
+ gather_res = [[] for _ in range(n_proc)]
+ try:
+ for p in processes:
+ p.start()
+
+ k = 0
+ while k < n_proc:
+ # get result
+ res = Q.get()
+ if res == "Done":
+ k += 1
+ else:
+ gather_res[res[0]] = res[1]
+
+ except Exception as e:
+ print("Exception: ", e)
+ for p in processes:
+ p.terminate()
+
+ raise e
+ finally:
+ for p in processes:
+ p.join()
+ print(f"Prefetching complete. [{time.time() - start} sec.]")
+
+ if target_data_type == 'ndarray':
+ if not isinstance(gather_res[0], np.ndarray):
+ return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
+
+ # order outputs
+ return np.concatenate(gather_res, axis=0)
+ elif target_data_type == 'list':
+ out = []
+ for r in gather_res:
+ out.extend(r)
+ return out
+ else:
+ return gather_res
diff --git a/3DTopia/model/auto_regressive.py b/3DTopia/model/auto_regressive.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9de51a3596cf5d226c4bf5bf3745d784c977510
--- /dev/null
+++ b/3DTopia/model/auto_regressive.py
@@ -0,0 +1,412 @@
+import imageio
+import os, math
+import wandb
+import torch
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from utility.initialize import instantiate_from_config
+from taming.modules.util import SOSProvider
+from utility.triplane_renderer.renderer import get_embedder, NeRF, run_network, render_path1, to8b, img2mse, mse2psnr
+import numpy as np
+
+from tqdm import tqdm
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class Net2NetTransformer(pl.LightningModule):
+ def __init__(self,
+ transformer_config,
+ first_stage_config,
+ cond_stage_config,
+ permuter_config=None,
+ ckpt_path=None,
+ ignore_keys=[],
+ first_stage_key="triplane",
+ cond_stage_key="depth",
+ downsample_cond_size=-1,
+ pkeep=1.0,
+ sos_token=0,
+ unconditional=True,
+ learning_rate=1e-4,
+ ):
+ super().__init__()
+ self.be_unconditional = unconditional
+ self.sos_token = sos_token
+ self.first_stage_key = first_stage_key
+ # self.cond_stage_key = cond_stage_key
+ self.init_first_stage_from_ckpt(first_stage_config)
+ # self.init_cond_stage_from_ckpt(cond_stage_config)
+ if permuter_config is None:
+ permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
+ self.permuter = instantiate_from_config(config=permuter_config)
+ self.transformer = instantiate_from_config(config=transformer_config)
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.downsample_cond_size = downsample_cond_size
+ self.pkeep = pkeep
+ self.learning_rate = learning_rate
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ for k in sd.keys():
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ self.print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def init_first_stage_from_ckpt(self, config):
+ model = instantiate_from_config(config)
+ # model = model.eval()
+ # model.train = disabled_train
+
+ self.first_stage_model = model
+
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ self.first_stage_model.vector_quantizer.training = False
+ self.first_stage_model.vector_quantizer.embedding.update = False
+
+ def init_cond_stage_from_ckpt(self, config):
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__" or self.be_unconditional:
+ print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
+ f"Prepending {self.sos_token} as a sos token.")
+ self.be_unconditional = True
+ self.cond_stage_key = self.first_stage_key
+ self.cond_stage_model = SOSProvider(self.sos_token)
+ else:
+ model = instantiate_from_config(config)
+ model = model.eval()
+ model.train = disabled_train
+ self.cond_stage_model = model
+
+ def forward(self, x, c):
+ # one step to produce the logits
+ _, z_indices = self.encode_to_z(x)
+ # _, c_indices = self.encode_to_c(c)
+
+ if self.training and self.pkeep < 1.0:
+ mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
+ device=z_indices.device))
+ mask = mask.round().to(dtype=torch.int64)
+ r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
+ a_indices = mask*z_indices+(1-mask)*r_indices
+ else:
+ a_indices = z_indices
+
+ c_indices = torch.zeros_like(z_indices[:, 0:1]) + self.transformer.config.vocab_size - 1
+ cz_indices = torch.cat((c_indices, a_indices), dim=1)
+
+ # target includes all sequence elements (no need to handle first one
+ # differently because we are conditioning)
+ target = z_indices
+ # make the prediction
+ logits, _ = self.transformer(cz_indices[:, :-1])
+ # cut off conditioning outputs - output i corresponds to p(z_i | z_{ -1:
+ c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
+ quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
+ if len(indices.shape) > 2:
+ indices = indices.view(c.shape[0], -1)
+ return quant_c, indices
+
+ # @torch.no_grad()
+ # def decode_to_img(self, index, zshape):
+ # index = self.permuter(index, reverse=True)
+ # bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
+ # quant_z = self.first_stage_model.quantize.get_codebook_entry(
+ # index.reshape(-1), shape=bhwc)
+ # x = self.first_stage_model.decode(quant_z)
+ # return x
+
+ @torch.no_grad()
+ def decode_to_triplane(self, index, zshape):
+ quant_z = self.first_stage_model.vector_quantizer.dequantize(index)
+ quant_z = quant_z.reshape(zshape[0], zshape[2], zshape[3], zshape[1])
+ quant_z = quant_z.permute(0, 3, 1, 2)
+ z = self.first_stage_model.decode(quant_z)
+ return z
+
+ @torch.no_grad()
+ def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
+ log = dict()
+
+ N = 2
+ if lr_interface:
+ x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
+ else:
+ x, c = self.get_xc(batch, N)
+ x = x.to(device=self.device)
+ # c = c.to(device=self.device)
+ log["inputs"] = self.render_triplane(x, batch)
+
+ quant_z, z_indices = self.encode_to_z(x)
+ # quant_c, c_indices = self.encode_to_c(c)
+ c_indices = torch.zeros_like(z_indices[:, 0:1]) + self.transformer.config.vocab_size - 1
+
+ # create a "half"" sample
+ z_start_indices = z_indices[:,:z_indices.shape[1]//2]
+ index_sample = self.sample(z_start_indices, c_indices,
+ steps=z_indices.shape[1]-z_start_indices.shape[1],
+ temperature=temperature if temperature is not None else 1.0,
+ sample=True,
+ top_k=top_k if top_k is not None else 100,
+ callback=callback if callback is not None else lambda k: None)
+ x_sample = self.first_stage_model.unrollout(self.decode_to_triplane(index_sample, quant_z.shape))
+ log["samples_half"] = self.render_triplane(x_sample, batch)
+
+ # sample
+ z_start_indices = z_indices[:, :0]
+ index_sample = self.sample(z_start_indices, c_indices,
+ steps=z_indices.shape[1],
+ temperature=temperature if temperature is not None else 1.0,
+ sample=True,
+ top_k=top_k if top_k is not None else 100,
+ callback=callback if callback is not None else lambda k: None)
+ x_sample_nopix = self.first_stage_model.unrollout(self.decode_to_triplane(index_sample, quant_z.shape))
+ log["samples_nopix"] = self.render_triplane(x_sample_nopix, batch)
+
+ # # det sample
+ # z_start_indices = z_indices[:, :0]
+ # index_sample = self.sample(z_start_indices, c_indices,
+ # steps=z_indices.shape[1],
+ # sample=False,
+ # callback=callback if callback is not None else lambda k: None)
+ # x_sample_det = self.first_stage_model.unrollout(self.decode_to_triplane(index_sample, quant_z.shape))
+ # log["samples_det"] = self.render_triplane(x_sample_det, batch)
+
+ # reconstruction
+ x_rec = self.first_stage_model.unrollout(self.decode_to_triplane(z_indices, quant_z.shape))
+ # x_rec = self.first_stage_model.unrollout(self.first_stage_model(self.first_stage_model.rollout(x))[0])
+ log["reconstructions"] = self.render_triplane(x_rec, batch)
+
+ # if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
+ # figure_size = (x_rec.shape[2], x_rec.shape[3])
+ # dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
+ # label_for_category_no = dataset.get_textual_label_for_category_no
+ # plotter = dataset.conditional_builders[self.cond_stage_key].plot
+ # log["conditioning"] = torch.zeros_like(log["reconstructions"])
+ # for i in range(quant_c.shape[0]):
+ # log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
+ # log["conditioning_rec"] = log["conditioning"]
+ # elif self.cond_stage_key != "image":
+ # cond_rec = self.cond_stage_model.decode(quant_c)
+ # if self.cond_stage_key == "segmentation":
+ # # get image from segmentation mask
+ # num_classes = cond_rec.shape[1]
+
+ # c = torch.argmax(c, dim=1, keepdim=True)
+ # c = F.one_hot(c, num_classes=num_classes)
+ # c = c.squeeze(1).permute(0, 3, 1, 2).float()
+ # c = self.cond_stage_model.to_rgb(c)
+
+ # cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
+ # cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
+ # cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
+ # cond_rec = self.cond_stage_model.to_rgb(cond_rec)
+ # log["conditioning_rec"] = cond_rec
+ # log["conditioning"] = c
+
+ return log
+
+ def render_triplane(self, triplane, batch):
+ batch_size = triplane.shape[0]
+ rgb_list = []
+ for b in range(batch_size):
+ rgb, cur_psnr_list = self.first_stage_model.render_triplane_eg3d_decoder(
+ triplane[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb = to8b(rgb.detach().cpu().numpy())
+ rgb_list.append(rgb[1])
+
+ return np.stack(rgb_list, 0)
+
+ def get_input(self, key, batch):
+ x = batch[key]
+ # if len(x.shape) == 3:
+ # x = x[..., None]
+ # if len(x.shape) == 4:
+ # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+ # if x.dtype == torch.double:
+ # x = x.float()
+ return x
+
+ def get_xc(self, batch, N=None):
+ x = self.get_input(self.first_stage_key, batch)
+ # c = self.get_input(self.cond_stage_key, batch)
+ if N is not None:
+ x = x[:N]
+ # c = c[:N]
+ return x, None
+
+ def shared_step(self, batch, batch_idx):
+ x, c = self.get_xc(batch)
+ logits, target = self(x, c)
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
+ return loss
+
+ def training_step(self, batch, batch_idx):
+ loss = self.shared_step(batch, batch_idx)
+ self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ loss = self.shared_step(batch, batch_idx)
+ self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ if batch_idx == 0:
+ imgs = self.log_images(batch)
+ for i in range(imgs['inputs'].shape[0]):
+ self.logger.experiment.log({
+ "val/vis/inputs": [wandb.Image(imgs['inputs'][i])],
+ "val/vis/reconstructions": [wandb.Image(imgs['reconstructions'][i])],
+ "val/vis/samples_half": [wandb.Image(imgs['samples_half'][i])],
+ "val/vis/samples_nopix": [wandb.Image(imgs['samples_nopix'][i])],
+ # "val/vis/samples_det": [wandb.Image(imgs['samples_det'][i])],
+ })
+ return loss
+
+ def test_step(self, batch, batch_idx):
+ loss = self.shared_step(batch, batch_idx)
+ self.log("test/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ imgs = self.log_images(batch, temperature=1.8)
+ print("Saved to {}".format(self.logger.log_dir))
+ for i in range(imgs['inputs'].shape[0]):
+ imageio.imwrite(os.path.join(self.logger.log_dir, "inputs_{}_{}.png".format(batch_idx, i)), imgs['inputs'][i])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "reconstructions_{}_{}.png".format(batch_idx, i)), imgs['reconstructions'][i])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "samples_half_{}_{}.png".format(batch_idx, i)), imgs['samples_half'][i])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "samples_nopix_{}_{}.png".format(batch_idx, i)), imgs['samples_nopix'][i])
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "samples_det_{}_{}.png".format(batch_idx, i)), imgs['samples_det'][i])
+ return loss
+
+ def configure_optimizers(self):
+ """
+ Following minGPT:
+ This long function is unfortunately doing something very simple and is being very defensive:
+ We are separating out all parameters of the model into two buckets: those that will experience
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+ We are then returning the PyTorch optimizer object.
+ """
+ # separate out all parameters to those that will and won't experience regularizing weight decay
+ decay = set()
+ no_decay = set()
+ whitelist_weight_modules = (torch.nn.Linear, )
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+ for mn, m in self.transformer.named_modules():
+ for pn, p in m.named_parameters():
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+
+ if pn.endswith('bias'):
+ # all biases will not be decayed
+ no_decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+ # weights of whitelist modules will be weight decayed
+ decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+ # weights of blacklist modules will NOT be weight decayed
+ no_decay.add(fpn)
+
+ # special case the position embedding parameter in the root GPT module as not decayed
+ no_decay.add('pos_emb')
+
+ # validate that we considered every parameter
+ param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
+ inter_params = decay & no_decay
+ union_params = decay | no_decay
+ assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
+ assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+ % (str(param_dict.keys() - union_params), )
+
+ # create the pytorch optimizer object
+ optim_groups = [
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+ ]
+ optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
+ return optimizer
diff --git a/3DTopia/model/sv_vae_triplane.py b/3DTopia/model/sv_vae_triplane.py
new file mode 100644
index 0000000000000000000000000000000000000000..60f9e51edc84df596d3477c25f3aeeb5b8e7150b
--- /dev/null
+++ b/3DTopia/model/sv_vae_triplane.py
@@ -0,0 +1,111 @@
+import os
+import imageio
+import numpy as np
+import torch
+import torchvision
+import torch.nn as nn
+import pytorch_lightning as pl
+import wandb
+
+import lpips
+from pytorch_msssim import SSIM
+
+from utility.initialize import instantiate_from_config
+
+class VAE(pl.LightningModule):
+ def __init__(self, vae_configs, renderer_configs, lr=1e-3, weight_decay=1e-2,
+ kld_weight=1, mse_weight=1, lpips_weight=0.1, ssim_weight=0.1,
+ log_image_freq=50):
+ super().__init__()
+ self.save_hyperparameters()
+
+ self.lr = lr
+ self.weight_decay = weight_decay
+ self.kld_weight = kld_weight
+ self.mse_weight = mse_weight
+ self.lpips_weight = lpips_weight
+ self.ssim_weight = ssim_weight
+ self.log_image_freq = log_image_freq
+
+ self.vae = instantiate_from_config(vae_configs)
+ self.renderer = instantiate_from_config(renderer_configs)
+
+ self.lpips_fn = lpips.LPIPS(net='alex')
+ self.ssim_fn = SSIM(data_range=1, size_average=True, channel=3)
+
+ self.triplane_render_kwargs = {
+ 'depth_resolution': 64,
+ 'disparity_space_sampling': False,
+ 'box_warp': 2.4,
+ 'depth_resolution_importance': 64,
+ 'clamp_mode': 'softplus',
+ 'white_back': True,
+ }
+
+ def forward(self, batch, is_train):
+ encoder_img, input_img, input_ray_o, input_ray_d, \
+ target_img, target_ray_o, target_ray_d = batch
+ grid, mu, logvar = self.vae(encoder_img, is_train)
+
+ cat_ray_o = torch.cat([input_ray_o, target_ray_o], 0)
+ cat_ray_d = torch.cat([input_ray_d, target_ray_d], 0)
+ render_out = self.renderer(torch.cat([grid, grid], 0), cat_ray_o, cat_ray_d, self.triplane_render_kwargs)
+ render_gt = torch.cat([input_img, target_img], 0)
+
+ return render_out['rgb_marched'], render_out['depth_final'], \
+ render_out['weights'], mu, logvar, render_gt
+
+ def calc_loss(self, render, mu, logvar, render_gt):
+ mse = torch.mean((render - render_gt) ** 2)
+ ssim_loss = 1 - self.ssim_fn(render, render_gt)
+ lpips_loss = self.lpips_fn((render * 2) - 1, (render_gt * 2) - 1).mean()
+ kld_loss = -0.5 * torch.mean(torch.mean(1 + logvar - mu.pow(2) - logvar.exp(), 1))
+
+ loss = self.mse_weight * mse + self.ssim_weight * ssim_loss + \
+ self.lpips_weight * lpips_loss + self.kld_weight * kld_loss
+
+ return {
+ 'loss': loss,
+ 'mse': mse,
+ 'ssim': ssim_loss,
+ 'lpips': lpips_loss,
+ 'kld': kld_loss,
+ }
+
+ def log_dict(self, loss_dict, prefix):
+ for k, v in loss_dict.items():
+ self.log(prefix + k, v, on_step=True, logger=True)
+
+ def make_grid(self, render, depth, render_gt):
+ bs = render.shape[0] // 2
+ grid = torchvision.utils.make_grid(
+ torch.stack([render_gt[0], render_gt[bs], render[0], depth[0], render[bs], depth[bs]], 0))
+ grid = (grid.detach().cpu().permute(1, 2, 0) * 255.).numpy().astype(np.uint8)
+ return grid
+
+ def training_step(self, batch, batch_idx):
+ render, depth, weights, mu, logvar, render_gt = self.forward(batch, True)
+ loss_dict = self.calc_loss(render, mu, logvar, render_gt)
+ self.log_dict(loss_dict, 'train/')
+ if batch_idx % self.log_image_freq == 0:
+ self.logger.experiment.log({
+ 'train/vis': [wandb.Image(self.make_grid(
+ render, depth, render_gt
+ ))]
+ })
+ return loss_dict['loss']
+
+ def validation_step(self, batch, batch_idx):
+ render, depth, _, mu, logvar, render_gt = self.forward(batch, False)
+ loss_dict = self.calc_loss(render, mu, logvar, render_gt)
+ self.log_dict(loss_dict, 'val/')
+ if batch_idx % self.log_image_freq == 0:
+ self.logger.experiment.log({
+ 'val/vis': [wandb.Image(self.make_grid(
+ render, depth, render_gt
+ ))]
+ })
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
+ return optimizer
diff --git a/3DTopia/model/triplane_vae.py b/3DTopia/model/triplane_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..43ff1973d3ed49fa3bd07c9527a6650cb3b4b393
--- /dev/null
+++ b/3DTopia/model/triplane_vae.py
@@ -0,0 +1,2656 @@
+import os
+import imageio
+import torch
+import wandb
+import numpy as np
+import pytorch_lightning as pl
+import torch.nn.functional as F
+
+from module.model_2d import Encoder, Decoder, DiagonalGaussianDistribution, Encoder_GroupConv, Decoder_GroupConv, Encoder_GroupConv_LateFusion, Decoder_GroupConv_LateFusion
+from utility.initialize import instantiate_from_config
+from utility.triplane_renderer.renderer import get_embedder, NeRF, run_network, render_path1, to8b, img2mse, mse2psnr
+from utility.triplane_renderer.eg3d_renderer import Renderer_TriPlane
+
+class AutoencoderKL(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ learning_rate=1e-3,
+ ckpt_path=None,
+ ignore_keys=[],
+ colorize_nlabels=None,
+ monitor=None,
+ decoder_ckpt=None,
+ norm=False,
+ renderer_type='nerf',
+ renderer_config=dict(
+ rgbnet_dim=18,
+ rgbnet_width=128,
+ viewpe=0,
+ feape=0
+ ),
+ ):
+ super().__init__()
+ self.save_hyperparameters()
+ self.norm = norm
+ self.renderer_config = renderer_config
+ self.learning_rate = learning_rate
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ # self.loss = instantiate_from_config(lossconfig)
+ self.lossconfig = lossconfig
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ self.decoder_ckpt = decoder_ckpt
+ self.renderer_type = renderer_type
+ # if decoder_ckpt is not None:
+ assert self.renderer_type in ['nerf', 'eg3d']
+ if self.renderer_type == 'nerf':
+ self.triplane_decoder, self.triplane_render_kwargs = self.create_nerf(decoder_ckpt)
+ elif self.renderer_type == 'eg3d':
+ self.triplane_decoder, self.triplane_render_kwargs = self.create_eg3d_decoder(decoder_ckpt)
+ else:
+ raise NotImplementedError
+
+ self.psum = torch.zeros([1])
+ self.psum_sq = torch.zeros([1])
+ self.psum_min = torch.zeros([1])
+ self.psum_max = torch.zeros([1])
+ self.count = 0
+ self.len_dset = 0
+ self.latent_list = []
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def encode(self, x, rollout=False):
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z, unrollout=False):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def unrollout(self, *args, **kwargs):
+ pass
+
+ def loss(self, inputs, reconstructions, posteriors, prefix, batch=None):
+ reconstructions = reconstructions.contiguous()
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions)
+ rec_loss = torch.sum(rec_loss) / rec_loss.shape[0]
+ kl_loss = posteriors.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+ loss = self.lossconfig.rec_weight * rec_loss + self.lossconfig.kl_weight * kl_loss
+
+ ret_dict = {
+ prefix+'mean_rec_loss': torch.abs(inputs.contiguous() - reconstructions.contiguous()).mean().detach(),
+ prefix+'rec_loss': rec_loss,
+ prefix+'kl_loss': kl_loss,
+ prefix+'loss': loss,
+ prefix+'mean': posteriors.mean.mean(),
+ prefix+'logvar': posteriors.logvar.mean(),
+ }
+
+ render_weight = self.lossconfig.get("render_weight", 0)
+ tv_weight = self.lossconfig.get("tv_weight", 0)
+ l1_weight = self.lossconfig.get("l1_weight", 0)
+ latent_tv_weight = self.lossconfig.get("latent_tv_weight", 0)
+ latent_l1_weight = self.lossconfig.get("latent_l1_weight", 0)
+
+ triplane_rec = self.unrollout(reconstructions)
+ if render_weight > 0 and batch is not None:
+ rgb_rendered, target = self.render_triplane_eg3d_decoder_sample_pixel(triplane_rec, batch['batch_rays'], batch['img'])
+ render_loss = ((rgb_rendered - target) ** 2).sum() / rgb_rendered.shape[0] * 256
+ loss += render_weight * render_loss
+ ret_dict[prefix + 'render_loss'] = render_loss
+ if tv_weight > 0:
+ tvloss_y = torch.abs(triplane_rec[:, :, :-1] - triplane_rec[:, :, 1:]).sum() / triplane_rec.shape[0]
+ tvloss_x = torch.abs(triplane_rec[:, :, :, :-1] - triplane_rec[:, :, :, 1:]).sum() / triplane_rec.shape[0]
+ tvloss = tvloss_y + tvloss_x
+ loss += tv_weight * tvloss
+ ret_dict[prefix + 'tv_loss'] = tvloss
+ if l1_weight > 0:
+ l1 = (triplane_rec ** 2).sum() / triplane_rec.shape[0]
+ loss += l1_weight * l1
+ ret_dict[prefix + 'l1_loss'] = l1
+ if latent_tv_weight > 0:
+ latent = posteriors.mean
+ latent_tv_y = torch.abs(latent[:, :, :-1] - latent[:, :, 1:]).sum() / latent.shape[0]
+ latent_tv_x = torch.abs(latent[:, :, :, :-1] - latent[:, :, :, 1:]).sum() / latent.shape[0]
+ latent_tv_loss = latent_tv_y + latent_tv_x
+ loss += latent_tv_loss * latent_tv_weight
+ ret_dict[prefix + 'latent_tv_loss'] = latent_tv_loss
+ ret_dict[prefix + 'latent_max'] = latent.max()
+ ret_dict[prefix + 'latent_min'] = latent.min()
+ if latent_l1_weight > 0:
+ latent = posteriors.mean
+ latent_l1_loss = (latent ** 2).sum() / latent.shape[0]
+ loss += latent_l1_loss * latent_l1_weight
+ ret_dict[prefix + 'latent_l1_loss'] = latent_l1_loss
+
+ return loss, ret_dict
+
+ def training_step(self, batch, batch_idx):
+ # inputs = self.get_input(batch, self.image_key)
+ inputs = batch['triplane']
+ reconstructions, posterior = self(inputs)
+
+ # if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/')
+ # self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ # if optimizer_idx == 1:
+ # # train the discriminator
+ # discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ # last_layer=self.get_last_layer(), split="train")
+
+ # self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ # self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ # return discloss
+
+ def validation_step(self, batch, batch_idx):
+ # # inputs = self.get_input(batch, self.image_key)
+ # inputs = batch['triplane']
+ # reconstructions, posterior = self(inputs, sample_posterior=False)
+ # aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/')
+
+ # # discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+ # # last_layer=self.get_last_layer(), split="val")
+
+ # # self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
+ # self.log_dict(log_dict_ae)
+ # # self.log_dict(log_dict_disc)
+ # return self.log_dict
+
+ inputs = batch['triplane']
+ reconstructions, posterior = self(inputs, sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/')
+ self.log_dict(log_dict_ae)
+
+ assert not self.norm
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+ batch_size = inputs.shape[0]
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if b % 4 == 0 and batch_idx < 10:
+ rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
+ self.logger.experiment.log({
+ "val/vis": [wandb.Image(rgb_all)]
+ })
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ return self.log_dict
+
+ def create_eg3d_decoder(self, decoder_ckpt):
+ triplane_decoder = Renderer_TriPlane(**self.renderer_config)
+ if decoder_ckpt is not None:
+ pretrain_pth = torch.load(decoder_ckpt, map_location='cpu')
+ pretrain_pth = {
+ '.'.join(k.split('.')[1:]): v for k, v in pretrain_pth.items()
+ }
+ triplane_decoder.load_state_dict(pretrain_pth)
+ render_kwargs = {
+ 'depth_resolution': 128,
+ 'disparity_space_sampling': False,
+ 'box_warp': 2.4,
+ 'depth_resolution_importance': 128,
+ 'clamp_mode': 'softplus',
+ 'white_back': True,
+ 'det': True
+ }
+ return triplane_decoder, render_kwargs
+
+ def render_triplane_eg3d_decoder(self, triplane, batch_rays, target):
+ ray_o = batch_rays[:, 0]
+ ray_d = batch_rays[:, 1]
+ psnr_list = []
+ rec_img_list = []
+ res = triplane.shape[-2]
+ for i in range(ray_o.shape[0]):
+ with torch.no_grad():
+ render_out = self.triplane_decoder(triplane.reshape(1, 3, -1, res, res),
+ ray_o[i:i+1], ray_d[i:i+1], self.triplane_render_kwargs, whole_img=True, tvloss=False)
+ rec_img = render_out['rgb_marched'].permute(0, 2, 3, 1)
+ psnr = mse2psnr(img2mse(rec_img[0], target[i]))
+ psnr_list.append(psnr)
+ rec_img_list.append(rec_img)
+ return torch.cat(rec_img_list, 0), psnr_list
+
+ def render_triplane_eg3d_decoder_sample_pixel(self, triplane, batch_rays, target, sample_num=1024):
+ assert batch_rays.shape[1] == 1
+ sel = torch.randint(batch_rays.shape[-2], [sample_num])
+ ray_o = batch_rays[:, 0, 0, sel]
+ ray_d = batch_rays[:, 0, 1, sel]
+ res = triplane.shape[-2]
+ render_out = self.triplane_decoder(triplane.reshape(triplane.shape[0], 3, -1, res, res),
+ ray_o, ray_d, self.triplane_render_kwargs, whole_img=False, tvloss=False)
+ rec_img = render_out['rgb_marched']
+ target = target.reshape(triplane.shape[0], -1, 3)[:, sel, :]
+ return rec_img, target
+
+ def create_nerf(self, decoder_ckpt):
+ # decoder_ckpt = '/mnt/petrelfs/share_data/caoziang/shapenet_triplane_car/003000.tar'
+
+ multires = 10
+ netchunk = 1024*64
+ i_embed = 0
+ perturb = 0
+ raw_noise_std = 0
+
+ triplanechannel=18
+ triplanesize=256
+ chunk=4096
+ num_instance=1
+ batch_size=1
+ use_viewdirs = True
+ white_bkgd = False
+ lrate_decay = 6
+ netdepth=1
+ netwidth=64
+ N_samples = 512
+ N_importance = 0
+ N_rand = 8192
+ multires_views=10
+ precrop_iters = 0
+ precrop_frac = 0.5
+ i_weights=3000
+
+ embed_fn, input_ch = get_embedder(multires, i_embed)
+ embeddirs_fn, input_ch_views = get_embedder(multires_views, i_embed)
+ output_ch = 4
+ skips = [4]
+ model = NeRF(D=netdepth, W=netwidth,
+ input_ch=triplanechannel, size=triplanesize,output_ch=output_ch, skips=skips,
+ input_ch_views=input_ch_views, use_viewdirs=use_viewdirs, num_instance=num_instance)
+
+ network_query_fn = lambda inputs, viewdirs, label,network_fn : \
+ run_network(inputs, viewdirs, network_fn,
+ embed_fn=embed_fn,
+ embeddirs_fn=embeddirs_fn,label=label,
+ netchunk=netchunk)
+
+ ckpt = torch.load(decoder_ckpt)
+ model.load_state_dict(ckpt['network_fn_state_dict'])
+
+ render_kwargs_test = {
+ 'network_query_fn' : network_query_fn,
+ 'perturb' : perturb,
+ 'N_samples' : N_samples,
+ # 'network_fn' : model,
+ 'use_viewdirs' : use_viewdirs,
+ 'white_bkgd' : white_bkgd,
+ 'raw_noise_std' : raw_noise_std,
+ }
+ render_kwargs_test['ndc'] = False
+ render_kwargs_test['lindisp'] = False
+ render_kwargs_test['perturb'] = False
+ render_kwargs_test['raw_noise_std'] = 0.
+
+ return model, render_kwargs_test
+
+ def render_triplane(self, triplane, batch_rays, target, near, far, chunk=4096):
+ self.triplane_decoder.tri_planes.copy_(triplane.detach())
+ self.triplane_render_kwargs['network_fn'] = self.triplane_decoder
+ # print(triplane.device)
+ # print(batch_rays.device)
+ # print(target.device)
+ # print(near.device)
+ # print(far.device)
+ with torch.no_grad():
+ rgb, _, _, psnr_list = \
+ render_path1(batch_rays, chunk, self.triplane_render_kwargs, gt_imgs=target,
+ near=near, far=far, label=torch.Tensor([0]).long().to(triplane.device))
+ return rgb, psnr_list
+
+ def to_rgb(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_triplane(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_triplane"):
+ self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def test_step(self, batch, batch_idx):
+ # inputs = batch['triplane']
+ # reconstructions, posterior = self(inputs, sample_posterior=False)
+ # aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/')
+ # self.log_dict(log_dict_ae)
+
+ # batch_size = inputs.shape[0]
+ # psnr_list = [] # between rec and gt
+ # psnr_input_list = [] # between input and gt
+ # psnr_rec_list = [] # between input and rec
+
+ # mean = torch.Tensor([
+ # 0.2820, 0.4103, -0.2988, 0.1491, 0.4429, -0.3117, 0.2830, 0.4115,
+ # -0.3032, 0.1530, 0.4466, -0.3165, 0.2617, 0.3837, -0.2692, 0.1098,
+ # 0.4101, -0.2922
+ # ]).reshape(1, 18, 1, 1).to(inputs.device)
+ # std = torch.Tensor([
+ # 1.1696, 1.1287, 1.1733, 1.1583, 1.1238, 1.1675, 1.1978, 1.1585, 1.1949,
+ # 1.1660, 1.1576, 1.1998, 1.1987, 1.1546, 1.1930, 1.1724, 1.1450, 1.2027
+ # ]).reshape(1, 18, 1, 1).to(inputs.device)
+
+ # if self.norm:
+ # reconstructions_unnormalize = reconstructions * std + mean
+ # else:
+ # reconstructions_unnormalize = reconstructions
+
+ # for b in range(batch_size):
+ # # rgb_input, cur_psnr_list_input = self.render_triplane(
+ # # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ # # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ # # )
+ # # rgb, cur_psnr_list = self.render_triplane(
+ # # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ # # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ # # )
+
+ # if self.renderer_type == 'nerf':
+ # rgb_input, cur_psnr_list_input = self.render_triplane(
+ # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ # )
+ # rgb, cur_psnr_list = self.render_triplane(
+ # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ # )
+ # elif self.renderer_type == 'eg3d':
+ # rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ # )
+ # rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ # )
+ # else:
+ # raise NotImplementedError
+
+ # cur_psnr_list_rec = []
+ # for i in range(rgb.shape[0]):
+ # cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ # rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ # rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ # rgb = to8b(rgb.detach().cpu().numpy())
+
+ # if batch_idx < 1:
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])
+
+ # psnr_list += cur_psnr_list
+ # psnr_input_list += cur_psnr_list_input
+ # psnr_rec_list += cur_psnr_list_rec
+
+ # self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ # self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ # self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ inputs = batch['triplane']
+ reconstructions, posterior = self(inputs, sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None)
+ self.log_dict(log_dict_ae)
+
+ batch_size = inputs.shape[0]
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+
+ z = posterior.mode()
+ colorize_z = self.to_rgb(z)[0]
+ colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
+ colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]
+ # colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0]
+ # res = inputs.shape[1]
+ # colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0]
+ # colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0]
+ if batch_idx < 10:
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2)
+
+ np_z = z.detach().cpu().numpy()
+ # with open(os.path.join(self.logger.log_dir, "latent_{}.npz".format(batch_idx)), 'wb') as f:
+ # np.save(f, np_z)
+
+ self.latent_list.append(np_z)
+
+ if self.psum.device != z.device:
+ self.psum = self.psum.to(z.device)
+ self.psum_sq = self.psum_sq.to(z.device)
+ self.psum_min = self.psum_min.to(z.device)
+ self.psum_max = self.psum_max.to(z.device)
+ self.psum += z.sum()
+ self.psum_sq += (z ** 2).sum()
+ self.psum_min += z.reshape(-1).min(-1)[0]
+ self.psum_max += z.reshape(-1).max(-1)[0]
+ assert len(z.shape) == 4
+ self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
+ self.len_dset += 1
+
+ if self.norm:
+ assert NotImplementedError
+ else:
+ reconstructions_unnormalize = reconstructions
+
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if batch_idx < 10:
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ # opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ # lr=lr, betas=(0.5, 0.9))
+ # return [opt_ae, opt_disc], []
+ return opt_ae
+
+ def on_test_epoch_end(self):
+ mean = self.psum / self.count
+ mean_min = self.psum_min / self.len_dset
+ mean_max = self.psum_max / self.len_dset
+ var = (self.psum_sq / self.count) - (mean ** 2)
+ std = torch.sqrt(var)
+
+ print("mean min: {}".format(mean_min))
+ print("mean max: {}".format(mean_max))
+ print("mean: {}".format(mean))
+ print("std: {}".format(std))
+
+ latent = np.concatenate(self.latent_list)
+ q75, q25 = np.percentile(latent.reshape(-1), [75 ,25])
+ median = np.median(latent.reshape(-1))
+ iqr = q75 - q25
+ norm_iqr = iqr * 0.7413
+ print("Norm IQR: {}".format(norm_iqr))
+ print("Inverse Norm IQR: {}".format(1/norm_iqr))
+ print("Median: {}".format(median))
+
+
+class AutoencoderKLRollOut(AutoencoderKL):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.psum = torch.zeros([1])
+ self.psum_sq = torch.zeros([1])
+ self.psum_min = torch.zeros([1])
+ self.psum_max = torch.zeros([1])
+ self.count = 0
+ self.len_dset = 0
+
+ def rollout(self, triplane):
+ res = triplane.shape[-1]
+ ch = triplane.shape[1]
+ triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
+ return triplane
+
+ def unrollout(self, triplane):
+ res = triplane.shape[-2]
+ ch = 3 * triplane.shape[1]
+ triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
+ return triplane
+
+ def encode(self, x, rollout=False):
+ if rollout:
+ x = self.rollout(x)
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z, unrollout=False):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ if unrollout:
+ dec = self.unrollout(dec)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def training_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/')
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(inputs, sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/')
+ self.log_dict(log_dict_ae)
+
+ assert not self.norm
+ reconstructions = self.unrollout(reconstructions)
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+ batch_size = inputs.shape[0]
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if b % 4 == 0 and batch_idx < 10:
+ rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
+ self.logger.experiment.log({
+ "val/vis": [wandb.Image(rgb_all)]
+ })
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ return self.log_dict
+
+ def to_rgb(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_triplane(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_triplane"):
+ self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def test_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(inputs, sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/')
+ self.log_dict(log_dict_ae)
+
+ batch_size = inputs.shape[0]
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+
+ z = posterior.mode()
+ colorize_z = self.to_rgb(z)[0]
+ colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
+ colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]
+ # if batch_idx < 1:
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output)
+
+ reconstructions = self.unrollout(reconstructions)
+
+ if self.psum.device != z.device:
+ self.psum = self.psum.to(z.device)
+ self.psum_sq = self.psum_sq.to(z.device)
+ self.psum_min = self.psum_min.to(z.device)
+ self.psum_max = self.psum_max.to(z.device)
+ self.psum += z.sum()
+ self.psum_sq += (z ** 2).sum()
+ self.psum_min += z.reshape(-1).min(-1)[0]
+ self.psum_max += z.reshape(-1).max(-1)[0]
+ assert len(z.shape) == 4
+ self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
+ self.len_dset += 1
+
+ # mean = torch.Tensor([
+ # 0.2820, 0.4103, -0.2988, 0.1491, 0.4429, -0.3117, 0.2830, 0.4115,
+ # -0.3032, 0.1530, 0.4466, -0.3165, 0.2617, 0.3837, -0.2692, 0.1098,
+ # 0.4101, -0.2922
+ # ]).reshape(1, 18, 1, 1).to(inputs.device)
+ # std = torch.Tensor([
+ # 1.1696, 1.1287, 1.1733, 1.1583, 1.1238, 1.1675, 1.1978, 1.1585, 1.1949,
+ # 1.1660, 1.1576, 1.1998, 1.1987, 1.1546, 1.1930, 1.1724, 1.1450, 1.2027
+ # ]).reshape(1, 18, 1, 1).to(inputs.device)
+
+ mean = torch.Tensor([
+ -1.8449, -1.8242, 0.9667, -1.0187, 1.0647, -0.5422, -1.8632, -1.8435,
+ 0.9314, -1.0261, 1.0356, -0.5484, -1.8543, -1.8348, 0.9109, -1.0169,
+ 1.0160, -0.5467
+ ]).reshape(1, 18, 1, 1).to(inputs.device)
+ std = torch.Tensor([
+ 1.7593, 1.6127, 2.7132, 1.5500, 2.7893, 0.7707, 2.1114, 1.9198, 2.6586,
+ 1.8021, 2.5473, 1.0305, 1.7042, 1.7507, 2.4270, 1.4365, 2.2511, 0.8792
+ ]).reshape(1, 18, 1, 1).to(inputs.device)
+
+ if self.norm:
+ reconstructions_unnormalize = reconstructions * std + mean
+ else:
+ reconstructions_unnormalize = reconstructions
+
+ # for b in range(batch_size):
+ # if self.renderer_type == 'nerf':
+ # rgb_input, cur_psnr_list_input = self.render_triplane(
+ # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ # )
+ # rgb, cur_psnr_list = self.render_triplane(
+ # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ # )
+ # elif self.renderer_type == 'eg3d':
+ # rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ # )
+ # rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ # )
+ # else:
+ # raise NotImplementedError
+
+ # cur_psnr_list_rec = []
+ # for i in range(rgb.shape[0]):
+ # cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ # rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ # rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ # rgb = to8b(rgb.detach().cpu().numpy())
+
+ # # if batch_idx < 1:
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])
+
+ # psnr_list += cur_psnr_list
+ # psnr_input_list += cur_psnr_list_input
+ # psnr_rec_list += cur_psnr_list_rec
+
+ # self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ # self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ # self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ def on_test_epoch_end(self):
+ mean = self.psum / self.count
+ mean_min = self.psum_min / self.len_dset
+ mean_max = self.psum_max / self.len_dset
+ var = (self.psum_sq / self.count) - (mean ** 2)
+ std = torch.sqrt(var)
+
+ print("mean min: {}".format(mean_min))
+ print("mean max: {}".format(mean_max))
+ print("mean: {}".format(mean))
+ print("std: {}".format(std))
+
+
+class AutoencoderKLRollOut3DAware(AutoencoderKL):
+ def __init__(self, *args, **kwargs):
+ try:
+ ckpt_path = kwargs['ckpt_path']
+ kwargs['ckpt_path'] = None
+ except:
+ ckpt_path = None
+
+ super().__init__(*args, **kwargs)
+ self.psum = torch.zeros([1])
+ self.psum_sq = torch.zeros([1])
+ self.psum_min = torch.zeros([1])
+ self.psum_max = torch.zeros([1])
+ self.count = 0
+ self.len_dset = 0
+
+ ddconfig = kwargs['ddconfig']
+ ddconfig['z_channels'] *= 3
+ del self.decoder
+ del self.post_quant_conv
+ self.decoder = Decoder(**ddconfig)
+ self.post_quant_conv = torch.nn.Conv2d(kwargs['embed_dim'] * 3, ddconfig["z_channels"], 1)
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path)
+
+ def rollout(self, triplane):
+ res = triplane.shape[-1]
+ ch = triplane.shape[1]
+ triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
+ return triplane
+
+ def to3daware(self, triplane):
+ res = triplane.shape[-2]
+ plane1 = triplane[..., :res]
+ plane2 = triplane[..., res:2*res]
+ plane3 = triplane[..., 2*res:3*res]
+
+ x_mp = torch.nn.MaxPool2d((res, 1))
+ y_mp = torch.nn.MaxPool2d((1, res))
+ x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
+ y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)
+ # for plane1
+ plane21 = x_mp_rep(plane2)
+ plane31 = torch.flip(y_mp_rep(plane3), (3,))
+ new_plane1 = torch.cat([plane1, plane21, plane31], 1)
+ # for plane2
+ plane12 = y_mp_rep(plane1)
+ plane32 = x_mp_rep(plane3)
+ new_plane2 = torch.cat([plane2, plane12, plane32], 1)
+ # for plane3
+ plane13 = torch.flip(x_mp_rep(plane1), (2,))
+ plane23 = y_mp_rep(plane2)
+ new_plane3 = torch.cat([plane3, plane13, plane23], 1)
+
+ new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
+ return new_plane
+
+ def unrollout(self, triplane):
+ res = triplane.shape[-2]
+ ch = 3 * triplane.shape[1]
+ triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
+ return triplane
+
+ def encode(self, x, rollout=False):
+ if rollout:
+ x = self.to3daware(self.rollout(x))
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z, unrollout=False):
+ z = self.to3daware(z)
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ if unrollout:
+ dec = self.unrollout(dec)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def training_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(self.to3daware(inputs))
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None)
+ self.log_dict(log_dict_ae)
+
+ assert not self.norm
+ reconstructions = self.unrollout(reconstructions)
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+ batch_size = inputs.shape[0]
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if b % 4 == 0 and batch_idx < 10:
+ rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
+ self.logger.experiment.log({
+ "val/vis": [wandb.Image(rgb_all)]
+ })
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ return self.log_dict
+
+ def to_rgb(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_triplane(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_triplane"):
+ self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_3daware(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_3daware"):
+ self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def test_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None)
+ self.log_dict(log_dict_ae)
+
+ batch_size = inputs.shape[0]
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+
+ z = posterior.mode()
+ colorize_z = self.to_rgb(z)[0]
+ colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
+ colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]
+ colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0]
+ res = inputs.shape[1]
+ colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0]
+ colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0]
+ if batch_idx < 10:
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2)
+
+ reconstructions = self.unrollout(reconstructions)
+
+ if self.psum.device != z.device:
+ self.psum = self.psum.to(z.device)
+ self.psum_sq = self.psum_sq.to(z.device)
+ self.psum_min = self.psum_min.to(z.device)
+ self.psum_max = self.psum_max.to(z.device)
+ self.psum += z.sum()
+ self.psum_sq += (z ** 2).sum()
+ self.psum_min += z.reshape(-1).min(-1)[0]
+ self.psum_max += z.reshape(-1).max(-1)[0]
+ assert len(z.shape) == 4
+ self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
+ self.len_dset += 1
+
+ if self.norm:
+ assert NotImplementedError
+ else:
+ reconstructions_unnormalize = reconstructions
+
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if batch_idx < 10:
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ def on_test_epoch_end(self):
+ mean = self.psum / self.count
+ mean_min = self.psum_min / self.len_dset
+ mean_max = self.psum_max / self.len_dset
+ var = (self.psum_sq / self.count) - (mean ** 2)
+ std = torch.sqrt(var)
+
+ print("mean min: {}".format(mean_min))
+ print("mean max: {}".format(mean_max))
+ print("mean: {}".format(mean))
+ print("std: {}".format(std))
+
+
+class AutoencoderKLRollOut3DAwareOnlyInput(AutoencoderKL):
+ def __init__(self, *args, **kwargs):
+ try:
+ ckpt_path = kwargs['ckpt_path']
+ kwargs['ckpt_path'] = None
+ except:
+ ckpt_path = None
+
+ super().__init__(*args, **kwargs)
+ self.psum = torch.zeros([1])
+ self.psum_sq = torch.zeros([1])
+ self.psum_min = torch.zeros([1])
+ self.psum_max = torch.zeros([1])
+ self.count = 0
+ self.len_dset = 0
+
+ # ddconfig = kwargs['ddconfig']
+ # ddconfig['z_channels'] *= 3
+ # self.decoder = Decoder(**ddconfig)
+ # self.post_quant_conv = torch.nn.Conv2d(kwargs['embed_dim'] * 3, ddconfig["z_channels"], 1)
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path)
+
+ def rollout(self, triplane):
+ res = triplane.shape[-1]
+ ch = triplane.shape[1]
+ triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
+ return triplane
+
+ def to3daware(self, triplane):
+ res = triplane.shape[-2]
+ plane1 = triplane[..., :res]
+ plane2 = triplane[..., res:2*res]
+ plane3 = triplane[..., 2*res:3*res]
+
+ x_mp = torch.nn.MaxPool2d((res, 1))
+ y_mp = torch.nn.MaxPool2d((1, res))
+ x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
+ y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)
+ # for plane1
+ plane21 = x_mp_rep(plane2)
+ plane31 = torch.flip(y_mp_rep(plane3), (3,))
+ new_plane1 = torch.cat([plane1, plane21, plane31], 1)
+ # for plane2
+ plane12 = y_mp_rep(plane1)
+ plane32 = x_mp_rep(plane3)
+ new_plane2 = torch.cat([plane2, plane12, plane32], 1)
+ # for plane3
+ plane13 = torch.flip(x_mp_rep(plane1), (2,))
+ plane23 = y_mp_rep(plane2)
+ new_plane3 = torch.cat([plane3, plane13, plane23], 1)
+
+ new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
+ return new_plane
+
+ def unrollout(self, triplane):
+ res = triplane.shape[-2]
+ ch = 3 * triplane.shape[1]
+ triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
+ return triplane
+
+ def encode(self, x, rollout=False):
+ if rollout:
+ x = self.to3daware(self.rollout(x))
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z, unrollout=False):
+ # z = self.to3daware(z)
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ if unrollout:
+ dec = self.unrollout(dec)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def training_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(self.to3daware(inputs))
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/')
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/')
+ self.log_dict(log_dict_ae)
+
+ assert not self.norm
+ reconstructions = self.unrollout(reconstructions)
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+ batch_size = inputs.shape[0]
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if b % 4 == 0 and batch_idx < 10:
+ rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
+ self.logger.experiment.log({
+ "val/vis": [wandb.Image(rgb_all)]
+ })
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ return self.log_dict
+
+ def to_rgb(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_triplane(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_triplane"):
+ self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def test_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/')
+ self.log_dict(log_dict_ae)
+
+ batch_size = inputs.shape[0]
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+
+ z = posterior.mode()
+ colorize_z = self.to_rgb(z)[0]
+ colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
+ colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]
+ if batch_idx < 10:
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output)
+
+ reconstructions = self.unrollout(reconstructions)
+
+ if self.psum.device != z.device:
+ self.psum = self.psum.to(z.device)
+ self.psum_sq = self.psum_sq.to(z.device)
+ self.psum_min = self.psum_min.to(z.device)
+ self.psum_max = self.psum_max.to(z.device)
+ self.psum += z.sum()
+ self.psum_sq += (z ** 2).sum()
+ self.psum_min += z.reshape(-1).min(-1)[0]
+ self.psum_max += z.reshape(-1).max(-1)[0]
+ assert len(z.shape) == 4
+ self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
+ self.len_dset += 1
+
+ if self.norm:
+ assert NotImplementedError
+ else:
+ reconstructions_unnormalize = reconstructions
+
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if batch_idx < 10:
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ def on_test_epoch_end(self):
+ mean = self.psum / self.count
+ mean_min = self.psum_min / self.len_dset
+ mean_max = self.psum_max / self.len_dset
+ var = (self.psum_sq / self.count) - (mean ** 2)
+ std = torch.sqrt(var)
+
+ print("mean min: {}".format(mean_min))
+ print("mean max: {}".format(mean_max))
+ print("mean: {}".format(mean))
+ print("std: {}".format(std))
+
+
+class AutoencoderKLRollOut3DAwareMeanPool(AutoencoderKL):
+ def __init__(self, *args, **kwargs):
+ try:
+ ckpt_path = kwargs['ckpt_path']
+ kwargs['ckpt_path'] = None
+ except:
+ ckpt_path = None
+
+ super().__init__(*args, **kwargs)
+ self.psum = torch.zeros([1])
+ self.psum_sq = torch.zeros([1])
+ self.psum_min = torch.zeros([1])
+ self.psum_max = torch.zeros([1])
+ self.count = 0
+ self.len_dset = 0
+
+ ddconfig = kwargs['ddconfig']
+ ddconfig['z_channels'] *= 3
+ self.decoder = Decoder(**ddconfig)
+ self.post_quant_conv = torch.nn.Conv2d(kwargs['embed_dim'] * 3, ddconfig["z_channels"], 1)
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path)
+
+ def rollout(self, triplane):
+ res = triplane.shape[-1]
+ ch = triplane.shape[1]
+ triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
+ return triplane
+
+ def to3daware(self, triplane):
+ res = triplane.shape[-2]
+ plane1 = triplane[..., :res]
+ plane2 = triplane[..., res:2*res]
+ plane3 = triplane[..., 2*res:3*res]
+
+ x_mp = torch.nn.AvgPool2d((res, 1))
+ y_mp = torch.nn.AvgPool2d((1, res))
+ x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
+ y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)
+ # for plane1
+ plane21 = x_mp_rep(plane2)
+ plane31 = torch.flip(y_mp_rep(plane3), (3,))
+ new_plane1 = torch.cat([plane1, plane21, plane31], 1)
+ # for plane2
+ plane12 = y_mp_rep(plane1)
+ plane32 = x_mp_rep(plane3)
+ new_plane2 = torch.cat([plane2, plane12, plane32], 1)
+ # for plane3
+ plane13 = torch.flip(x_mp_rep(plane1), (2,))
+ plane23 = y_mp_rep(plane2)
+ new_plane3 = torch.cat([plane3, plane13, plane23], 1)
+
+ new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
+ return new_plane
+
+ def unrollout(self, triplane):
+ res = triplane.shape[-2]
+ ch = 3 * triplane.shape[1]
+ triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
+ return triplane
+
+ def encode(self, x, rollout=False):
+ if rollout:
+ x = self.to3daware(self.rollout(x))
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z, unrollout=False):
+ z = self.to3daware(z)
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ if unrollout:
+ dec = self.unrollout(dec)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def training_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(self.to3daware(inputs))
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/')
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/')
+ self.log_dict(log_dict_ae)
+
+ assert not self.norm
+ reconstructions = self.unrollout(reconstructions)
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+ batch_size = inputs.shape[0]
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if b % 4 == 0 and batch_idx < 10:
+ rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
+ self.logger.experiment.log({
+ "val/vis": [wandb.Image(rgb_all)]
+ })
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ return self.log_dict
+
+ def to_rgb(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_triplane(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_triplane"):
+ self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_3daware(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_3daware"):
+ self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def test_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/')
+ self.log_dict(log_dict_ae)
+
+ batch_size = inputs.shape[0]
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+
+ z = posterior.mode()
+ colorize_z = self.to_rgb(z)[0]
+ colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
+ colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]
+ colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0]
+ res = inputs.shape[1]
+ colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0]
+ colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0]
+ if batch_idx < 10:
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2)
+
+ reconstructions = self.unrollout(reconstructions)
+
+ if self.psum.device != z.device:
+ self.psum = self.psum.to(z.device)
+ self.psum_sq = self.psum_sq.to(z.device)
+ self.psum_min = self.psum_min.to(z.device)
+ self.psum_max = self.psum_max.to(z.device)
+ self.psum += z.sum()
+ self.psum_sq += (z ** 2).sum()
+ self.psum_min += z.reshape(-1).min(-1)[0]
+ self.psum_max += z.reshape(-1).max(-1)[0]
+ assert len(z.shape) == 4
+ self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
+ self.len_dset += 1
+
+ if self.norm:
+ assert NotImplementedError
+ else:
+ reconstructions_unnormalize = reconstructions
+
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if batch_idx < 10:
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ def on_test_epoch_end(self):
+ mean = self.psum / self.count
+ mean_min = self.psum_min / self.len_dset
+ mean_max = self.psum_max / self.len_dset
+ var = (self.psum_sq / self.count) - (mean ** 2)
+ std = torch.sqrt(var)
+
+ print("mean min: {}".format(mean_min))
+ print("mean max: {}".format(mean_max))
+ print("mean: {}".format(mean))
+ print("std: {}".format(std))
+
+
+class AutoencoderKLGroupConv(AutoencoderKL):
+ def __init__(self, *args, **kwargs):
+ try:
+ ckpt_path = kwargs['ckpt_path']
+ kwargs['ckpt_path'] = None
+ except:
+ ckpt_path = None
+
+ super().__init__(*args, **kwargs)
+ self.latent_list = []
+ self.psum = torch.zeros([1])
+ self.psum_sq = torch.zeros([1])
+ self.psum_min = torch.zeros([1])
+ self.psum_max = torch.zeros([1])
+ self.count = 0
+ self.len_dset = 0
+
+ ddconfig = kwargs['ddconfig']
+ # ddconfig['z_channels'] *= 3
+ del self.decoder
+ del self.encoder
+ self.encoder = Encoder_GroupConv(**ddconfig)
+ self.decoder = Decoder_GroupConv(**ddconfig)
+
+ if "mean" in ddconfig:
+ print("Using mean std!!")
+ self.triplane_mean = torch.Tensor(ddconfig['mean']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float()
+ self.triplane_std = torch.Tensor(ddconfig['std']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float()
+ else:
+ self.triplane_mean = None
+ self.triplane_std = None
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path)
+
+ def rollout(self, triplane):
+ res = triplane.shape[-1]
+ ch = triplane.shape[1]
+ triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
+ return triplane
+
+ def to3daware(self, triplane):
+ res = triplane.shape[-2]
+ plane1 = triplane[..., :res]
+ plane2 = triplane[..., res:2*res]
+ plane3 = triplane[..., 2*res:3*res]
+
+ x_mp = torch.nn.MaxPool2d((res, 1))
+ y_mp = torch.nn.MaxPool2d((1, res))
+ x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
+ y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)
+ # for plane1
+ plane21 = x_mp_rep(plane2)
+ plane31 = torch.flip(y_mp_rep(plane3), (3,))
+ new_plane1 = torch.cat([plane1, plane21, plane31], 1)
+ # for plane2
+ plane12 = y_mp_rep(plane1)
+ plane32 = x_mp_rep(plane3)
+ new_plane2 = torch.cat([plane2, plane12, plane32], 1)
+ # for plane3
+ plane13 = torch.flip(x_mp_rep(plane1), (2,))
+ plane23 = y_mp_rep(plane2)
+ new_plane3 = torch.cat([plane3, plane13, plane23], 1)
+
+ new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
+ return new_plane
+
+ def unrollout(self, triplane):
+ res = triplane.shape[-2]
+ ch = 3 * triplane.shape[1]
+ triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
+ return triplane
+
+ def encode(self, x, rollout=False):
+ if rollout:
+ # x = self.to3daware(self.rollout(x))
+ x = self.rollout(x)
+ if self.triplane_mean is not None:
+ x = (x - self.triplane_mean.to(x.device)) / self.triplane_std.to(x.device)
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z, unrollout=False):
+ # z = self.to3daware(z)
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ if self.triplane_mean is not None:
+ dec = dec * self.triplane_std.to(dec.device) + self.triplane_mean.to(dec.device)
+ if unrollout:
+ dec = self.unrollout(dec)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def training_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(inputs, sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None)
+ self.log_dict(log_dict_ae)
+
+ z = posterior.mode()
+ colorize_z = self.to_rgb(z)[0]
+ assert not self.norm
+ reconstructions = self.unrollout(reconstructions)
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+ batch_size = inputs.shape[0]
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ rgb_input = np.stack([rgb_input[..., 2], rgb_input[..., 1], rgb_input[..., 0]], -1)
+ rgb = np.stack([rgb[..., 2], rgb[..., 1], rgb[..., 0]], -1)
+
+ if b % 2 == 0 and batch_idx < 10:
+ rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
+ self.logger.experiment.log({
+ "val/vis": [wandb.Image(rgb_all)],
+ "val/latent_vis": [wandb.Image(colorize_z)]
+ })
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ return self.log_dict
+
+ def to_rgb(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_triplane(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_triplane"):
+ self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_3daware(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_3daware"):
+ self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def test_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(inputs, sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None)
+ self.log_dict(log_dict_ae)
+
+ batch_size = inputs.shape[0]
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+
+ z = posterior.mode()
+ colorize_z = self.to_rgb(z)[0]
+ colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
+ colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]
+
+ import os
+ import random
+ import string
+ # z_np = z.detach().cpu().numpy()
+ z_np = inputs.detach().cpu().numpy()
+ fname = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + '.npy'
+ with open(os.path.join('/mnt/lustre/hongfangzhou.p/AE3D/tmp', fname), 'wb') as f:
+ np.save(f, z_np)
+
+ # colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0]
+ # res = inputs.shape[1]
+ # colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0]
+ # colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0]
+ if batch_idx < 0:
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2)
+
+ np_z = z.detach().cpu().numpy()
+ # with open(os.path.join(self.logger.log_dir, "latent_{}.npz".format(batch_idx)), 'wb') as f:
+ # np.save(f, np_z)
+
+ self.latent_list.append(np_z)
+
+ reconstructions = self.unrollout(reconstructions)
+
+ if self.psum.device != z.device:
+ self.psum = self.psum.to(z.device)
+ self.psum_sq = self.psum_sq.to(z.device)
+ self.psum_min = self.psum_min.to(z.device)
+ self.psum_max = self.psum_max.to(z.device)
+ self.psum += z.sum()
+ self.psum_sq += (z ** 2).sum()
+ self.psum_min += z.reshape(-1).min(-1)[0]
+ self.psum_max += z.reshape(-1).max(-1)[0]
+ assert len(z.shape) == 4
+ self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
+ self.len_dset += 1
+
+ if self.norm:
+ assert NotImplementedError
+ else:
+ reconstructions_unnormalize = reconstructions
+
+ if True:
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if batch_idx < 10:
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ def on_test_epoch_end(self):
+ mean = self.psum / self.count
+ mean_min = self.psum_min / self.len_dset
+ mean_max = self.psum_max / self.len_dset
+ var = (self.psum_sq / self.count) - (mean ** 2)
+ std = torch.sqrt(var)
+
+ print("mean min: {}".format(mean_min))
+ print("mean max: {}".format(mean_max))
+ print("mean: {}".format(mean))
+ print("std: {}".format(std))
+
+ latent = np.concatenate(self.latent_list)
+ q75, q25 = np.percentile(latent.reshape(-1), [75 ,25])
+ median = np.median(latent.reshape(-1))
+ iqr = q75 - q25
+ norm_iqr = iqr * 0.7413
+ print("Norm IQR: {}".format(norm_iqr))
+ print("Inverse Norm IQR: {}".format(1/norm_iqr))
+ print("Median: {}".format(median))
+
+ def loss(self, inputs, reconstructions, posteriors, prefix, batch=None):
+ reconstructions = reconstructions.contiguous()
+ # rec_loss = torch.abs(inputs.contiguous() - reconstructions)
+ # rec_loss = torch.sum(rec_loss) / rec_loss.shape[0]
+ rec_loss = F.mse_loss(inputs.contiguous(), reconstructions)
+ kl_loss = posteriors.kl()
+ # kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+ kl_loss = kl_loss.mean()
+ loss = self.lossconfig.rec_weight * rec_loss + self.lossconfig.kl_weight * kl_loss
+
+ ret_dict = {
+ prefix+'mean_rec_loss': torch.abs(inputs.contiguous() - reconstructions.contiguous()).mean().detach(),
+ prefix+'rec_loss': rec_loss,
+ prefix+'kl_loss': kl_loss,
+ prefix+'loss': loss,
+ prefix+'mean': posteriors.mean.mean(),
+ prefix+'logvar': posteriors.logvar.mean(),
+ }
+
+
+ latent = posteriors.mean
+ ret_dict[prefix + 'latent_max'] = latent.max()
+ ret_dict[prefix + 'latent_min'] = latent.min()
+
+ render_weight = self.lossconfig.get("render_weight", 0)
+ tv_weight = self.lossconfig.get("tv_weight", 0)
+ l1_weight = self.lossconfig.get("l1_weight", 0)
+ latent_tv_weight = self.lossconfig.get("latent_tv_weight", 0)
+ latent_l1_weight = self.lossconfig.get("latent_l1_weight", 0)
+
+ triplane_rec = self.unrollout(reconstructions)
+ if render_weight > 0 and batch is not None:
+ rgb_rendered, target = self.render_triplane_eg3d_decoder_sample_pixel(triplane_rec, batch['batch_rays'], batch['img'])
+ # render_loss = ((rgb_rendered - target) ** 2).sum() / rgb_rendered.shape[0] * 256
+ render_loss = F.mse_loss(rgb_rendered, target)
+ loss += render_weight * render_loss
+ ret_dict[prefix + 'render_loss'] = render_loss
+ if tv_weight > 0:
+ tvloss_y = F.mse_loss(triplane_rec[:, :, :-1], triplane_rec[:, :, 1:])
+ tvloss_x = F.mse_loss(triplane_rec[:, :, :, :-1], triplane_rec[:, :, :, 1:])
+ tvloss = tvloss_y + tvloss_x
+ loss += tv_weight * tvloss
+ ret_dict[prefix + 'tv_loss'] = tvloss
+ if l1_weight > 0:
+ l1 = (triplane_rec ** 2).mean()
+ loss += l1_weight * l1
+ ret_dict[prefix + 'l1_loss'] = l1
+ if latent_tv_weight > 0:
+ latent = posteriors.mean
+ latent_tv_y = F.mse_loss(latent[:, :, :-1], latent[:, :, 1:])
+ latent_tv_x = F.mse_loss(latent[:, :, :, :-1], latent[:, :, :, 1:])
+ latent_tv_loss = latent_tv_y + latent_tv_x
+ loss += latent_tv_loss * latent_tv_weight
+ ret_dict[prefix + 'latent_tv_loss'] = latent_tv_loss
+ if latent_l1_weight > 0:
+ latent = posteriors.mean
+ latent_l1_loss = (latent ** 2).mean()
+ loss += latent_l1_loss * latent_l1_weight
+ ret_dict[prefix + 'latent_l1_loss'] = latent_l1_loss
+
+ return loss, ret_dict
+
+
+class AutoencoderKLGroupConvLateFusion(AutoencoderKL):
+ def __init__(self, *args, **kwargs):
+ try:
+ ckpt_path = kwargs['ckpt_path']
+ kwargs['ckpt_path'] = None
+ except:
+ ckpt_path = None
+
+ super().__init__(*args, **kwargs)
+ self.latent_list = []
+ self.psum = torch.zeros([1])
+ self.psum_sq = torch.zeros([1])
+ self.psum_min = torch.zeros([1])
+ self.psum_max = torch.zeros([1])
+ self.count = 0
+ self.len_dset = 0
+
+ ddconfig = kwargs['ddconfig']
+ del self.decoder
+ del self.encoder
+ self.encoder = Encoder_GroupConv_LateFusion(**ddconfig)
+ self.decoder = Decoder_GroupConv_LateFusion(**ddconfig)
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path)
+
+ def rollout(self, triplane):
+ res = triplane.shape[-1]
+ ch = triplane.shape[1]
+ triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
+ return triplane
+
+ def to3daware(self, triplane):
+ res = triplane.shape[-2]
+ plane1 = triplane[..., :res]
+ plane2 = triplane[..., res:2*res]
+ plane3 = triplane[..., 2*res:3*res]
+
+ x_mp = torch.nn.MaxPool2d((res, 1))
+ y_mp = torch.nn.MaxPool2d((1, res))
+ x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
+ y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)
+ # for plane1
+ plane21 = x_mp_rep(plane2)
+ plane31 = torch.flip(y_mp_rep(plane3), (3,))
+ new_plane1 = torch.cat([plane1, plane21, plane31], 1)
+ # for plane2
+ plane12 = y_mp_rep(plane1)
+ plane32 = x_mp_rep(plane3)
+ new_plane2 = torch.cat([plane2, plane12, plane32], 1)
+ # for plane3
+ plane13 = torch.flip(x_mp_rep(plane1), (2,))
+ plane23 = y_mp_rep(plane2)
+ new_plane3 = torch.cat([plane3, plane13, plane23], 1)
+
+ new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
+ return new_plane
+
+ def unrollout(self, triplane):
+ res = triplane.shape[-2]
+ ch = 3 * triplane.shape[1]
+ triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
+ return triplane
+
+ def encode(self, x, rollout=False):
+ if rollout:
+ x = self.rollout(x)
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z, unrollout=False):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ if unrollout:
+ dec = self.unrollout(dec)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def training_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(inputs, sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None)
+ self.log_dict(log_dict_ae)
+
+ assert not self.norm
+ reconstructions = self.unrollout(reconstructions)
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+ batch_size = inputs.shape[0]
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if b % 4 == 0 and batch_idx < 10:
+ rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
+ self.logger.experiment.log({
+ "val/vis": [wandb.Image(rgb_all)]
+ })
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ return self.log_dict
+
+ def to_rgb(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_triplane(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_triplane"):
+ self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_3daware(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_3daware"):
+ self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def test_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(inputs, sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None)
+ self.log_dict(log_dict_ae)
+
+ batch_size = inputs.shape[0]
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+
+ z = posterior.mode()
+ colorize_z = self.to_rgb(z)[0]
+ colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
+ colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]
+ # colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0]
+ # res = inputs.shape[1]
+ # colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0]
+ # colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0]
+ if batch_idx < 10:
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input)
+ imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2)
+
+ np_z = z.detach().cpu().numpy()
+ # with open(os.path.join(self.logger.log_dir, "latent_{}.npz".format(batch_idx)), 'wb') as f:
+ # np.save(f, np_z)
+
+ self.latent_list.append(np_z)
+
+ reconstructions = self.unrollout(reconstructions)
+
+ if self.psum.device != z.device:
+ self.psum = self.psum.to(z.device)
+ self.psum_sq = self.psum_sq.to(z.device)
+ self.psum_min = self.psum_min.to(z.device)
+ self.psum_max = self.psum_max.to(z.device)
+ self.psum += z.sum()
+ self.psum_sq += (z ** 2).sum()
+ self.psum_min += z.reshape(-1).min(-1)[0]
+ self.psum_max += z.reshape(-1).max(-1)[0]
+ assert len(z.shape) == 4
+ self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
+ self.len_dset += 1
+
+ if self.norm:
+ assert NotImplementedError
+ else:
+ reconstructions_unnormalize = reconstructions
+
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if batch_idx < 10:
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ def on_test_epoch_end(self):
+ mean = self.psum / self.count
+ mean_min = self.psum_min / self.len_dset
+ mean_max = self.psum_max / self.len_dset
+ var = (self.psum_sq / self.count) - (mean ** 2)
+ std = torch.sqrt(var)
+
+ print("mean min: {}".format(mean_min))
+ print("mean max: {}".format(mean_max))
+ print("mean: {}".format(mean))
+ print("std: {}".format(std))
+
+ latent = np.concatenate(self.latent_list)
+ q75, q25 = np.percentile(latent.reshape(-1), [75 ,25])
+ median = np.median(latent.reshape(-1))
+ iqr = q75 - q25
+ norm_iqr = iqr * 0.7413
+ print("Norm IQR: {}".format(norm_iqr))
+ print("Inverse Norm IQR: {}".format(1/norm_iqr))
+ print("Median: {}".format(median))
+
+
+from module.model_2d import ViTEncoder, ViTDecoder
+
+class AutoencoderVIT(AutoencoderKL):
+ def __init__(self, *args, **kwargs):
+ try:
+ ckpt_path = kwargs['ckpt_path']
+ kwargs['ckpt_path'] = None
+ except:
+ ckpt_path = None
+
+ super().__init__(*args, **kwargs)
+ self.latent_list = []
+ self.psum = torch.zeros([1])
+ self.psum_sq = torch.zeros([1])
+ self.psum_min = torch.zeros([1])
+ self.psum_max = torch.zeros([1])
+ self.count = 0
+ self.len_dset = 0
+
+ ddconfig = kwargs['ddconfig']
+ # ddconfig['z_channels'] *= 3
+ del self.decoder
+ del self.encoder
+ del self.quant_conv
+ del self.post_quant_conv
+
+ assert ddconfig["z_channels"] == 256
+ self.encoder = ViTEncoder(
+ image_size=(256, 256*3),
+ patch_size=(256//32, 256//32),
+ dim=768,
+ depth=12,
+ heads=12,
+ mlp_dim=3072,
+ channels=8)
+ self.decoder = ViTDecoder(
+ image_size=(256, 256*3),
+ patch_size=(256//32, 256//32),
+ dim=768,
+ depth=12,
+ heads=12,
+ mlp_dim=3072,
+ channels=8)
+
+ self.quant_conv = torch.nn.Conv2d(768, 2*self.embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, 768, 1)
+
+ if "mean" in ddconfig:
+ print("Using mean std!!")
+ self.triplane_mean = torch.Tensor(ddconfig['mean']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float()
+ self.triplane_std = torch.Tensor(ddconfig['std']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float()
+ else:
+ self.triplane_mean = None
+ self.triplane_std = None
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path)
+
+ def rollout(self, triplane):
+ res = triplane.shape[-1]
+ ch = triplane.shape[1]
+ triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
+ return triplane
+
+ def to3daware(self, triplane):
+ res = triplane.shape[-2]
+ plane1 = triplane[..., :res]
+ plane2 = triplane[..., res:2*res]
+ plane3 = triplane[..., 2*res:3*res]
+
+ x_mp = torch.nn.MaxPool2d((res, 1))
+ y_mp = torch.nn.MaxPool2d((1, res))
+ x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
+ y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)
+ # for plane1
+ plane21 = x_mp_rep(plane2)
+ plane31 = torch.flip(y_mp_rep(plane3), (3,))
+ new_plane1 = torch.cat([plane1, plane21, plane31], 1)
+ # for plane2
+ plane12 = y_mp_rep(plane1)
+ plane32 = x_mp_rep(plane3)
+ new_plane2 = torch.cat([plane2, plane12, plane32], 1)
+ # for plane3
+ plane13 = torch.flip(x_mp_rep(plane1), (2,))
+ plane23 = y_mp_rep(plane2)
+ new_plane3 = torch.cat([plane3, plane13, plane23], 1)
+
+ new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
+ return new_plane
+
+ def unrollout(self, triplane):
+ res = triplane.shape[-2]
+ ch = 3 * triplane.shape[1]
+ triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
+ return triplane
+
+ def encode(self, x, rollout=False):
+ if rollout:
+ # x = self.to3daware(self.rollout(x))
+ x = self.rollout(x)
+ if self.triplane_mean is not None:
+ x = (x - self.triplane_mean.to(x.device)) / self.triplane_std.to(x.device)
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z, unrollout=False):
+ # z = self.to3daware(z)
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ if self.triplane_mean is not None:
+ dec = dec * self.triplane_std.to(dec.device) + self.triplane_mean.to(dec.device)
+ if unrollout:
+ dec = self.unrollout(dec)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def training_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(inputs, sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None)
+ self.log_dict(log_dict_ae)
+
+ assert not self.norm
+ reconstructions = self.unrollout(reconstructions)
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+ batch_size = inputs.shape[0]
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if b % 4 == 0 and batch_idx < 10:
+ rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
+ self.logger.experiment.log({
+ "val/vis": [wandb.Image(rgb_all)]
+ })
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ return self.log_dict
+
+ def to_rgb(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_triplane(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_triplane"):
+ self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_3daware(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_3daware"):
+ self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def test_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, posterior = self(inputs, sample_posterior=False)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None)
+ self.log_dict(log_dict_ae)
+
+ batch_size = inputs.shape[0]
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+
+ z = posterior.mode()
+ colorize_z = self.to_rgb(z)[0]
+ colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
+ colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]
+
+ import os
+ import random
+ import string
+ # z_np = z.detach().cpu().numpy()
+ z_np = inputs.detach().cpu().numpy()
+ fname = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + '.npy'
+ with open(os.path.join('/mnt/lustre/hongfangzhou.p/AE3D/tmp', fname), 'wb') as f:
+ np.save(f, z_np)
+
+ # colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0]
+ # res = inputs.shape[1]
+ # colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0]
+ # colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0]
+ # if batch_idx < 10:
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1)
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2)
+
+ np_z = z.detach().cpu().numpy()
+ # with open(os.path.join(self.logger.log_dir, "latent_{}.npz".format(batch_idx)), 'wb') as f:
+ # np.save(f, np_z)
+
+ self.latent_list.append(np_z)
+
+ reconstructions = self.unrollout(reconstructions)
+
+ if self.psum.device != z.device:
+ self.psum = self.psum.to(z.device)
+ self.psum_sq = self.psum_sq.to(z.device)
+ self.psum_min = self.psum_min.to(z.device)
+ self.psum_max = self.psum_max.to(z.device)
+ self.psum += z.sum()
+ self.psum_sq += (z ** 2).sum()
+ self.psum_min += z.reshape(-1).min(-1)[0]
+ self.psum_max += z.reshape(-1).max(-1)[0]
+ assert len(z.shape) == 4
+ self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3]
+ self.len_dset += 1
+
+ if self.norm:
+ assert NotImplementedError
+ else:
+ reconstructions_unnormalize = reconstructions
+
+ if True:
+ for b in range(batch_size):
+ if self.renderer_type == 'nerf':
+ rgb_input, cur_psnr_list_input = self.render_triplane(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ rgb, cur_psnr_list = self.render_triplane(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b],
+ batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1)
+ )
+ elif self.renderer_type == 'eg3d':
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ else:
+ raise NotImplementedError
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ # if batch_idx < 10:
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
+ # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ def on_test_epoch_end(self):
+ mean = self.psum / self.count
+ mean_min = self.psum_min / self.len_dset
+ mean_max = self.psum_max / self.len_dset
+ var = (self.psum_sq / self.count) - (mean ** 2)
+ std = torch.sqrt(var)
+
+ print("mean min: {}".format(mean_min))
+ print("mean max: {}".format(mean_max))
+ print("mean: {}".format(mean))
+ print("std: {}".format(std))
+
+ latent = np.concatenate(self.latent_list)
+ q75, q25 = np.percentile(latent.reshape(-1), [75 ,25])
+ median = np.median(latent.reshape(-1))
+ iqr = q75 - q25
+ norm_iqr = iqr * 0.7413
+ print("Norm IQR: {}".format(norm_iqr))
+ print("Inverse Norm IQR: {}".format(1/norm_iqr))
+ print("Median: {}".format(median))
diff --git a/3DTopia/model/triplane_vqvae.py b/3DTopia/model/triplane_vqvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe6be8fae4ad14394eb1f98477f5483e0ae04115
--- /dev/null
+++ b/3DTopia/model/triplane_vqvae.py
@@ -0,0 +1,418 @@
+import os
+import imageio
+import torch
+import wandb
+import numpy as np
+import pytorch_lightning as pl
+import torch.nn.functional as F
+
+from module.model_2d import Encoder, Decoder, DiagonalGaussianDistribution, Encoder_GroupConv, Decoder_GroupConv, Encoder_GroupConv_LateFusion, Decoder_GroupConv_LateFusion
+from utility.initialize import instantiate_from_config
+from utility.triplane_renderer.renderer import get_embedder, NeRF, run_network, render_path1, to8b, img2mse, mse2psnr
+from utility.triplane_renderer.eg3d_renderer import Renderer_TriPlane
+from module.quantise import VectorQuantiser
+from module.quantize_taming import EMAVectorQuantizer, VectorQuantizer2, QuantizeEMAReset
+
+class CVQVAE(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ learning_rate=1e-3,
+ ckpt_path=None,
+ ignore_keys=[],
+ colorize_nlabels=None,
+ monitor=None,
+ decoder_ckpt=None,
+ norm=True,
+ renderer_type='nerf',
+ is_cvqvae=False,
+ renderer_config=dict(
+ rgbnet_dim=18,
+ rgbnet_width=128,
+ viewpe=0,
+ feape=0
+ ),
+ vector_quantizer_config=dict(
+ num_embed=1024,
+ beta=0.25,
+ distance='cos',
+ anchor='closest',
+ first_batch=False,
+ contras_loss=True,
+ )
+ ):
+ super().__init__()
+ self.save_hyperparameters()
+ self.norm = norm
+ self.renderer_config = renderer_config
+ self.learning_rate = learning_rate
+
+ ddconfig['double_z'] = False
+ self.encoder = Encoder_GroupConv(**ddconfig)
+ self.decoder = Decoder_GroupConv(**ddconfig)
+
+ self.lossconfig = lossconfig
+
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+
+ self.decoder_ckpt = decoder_ckpt
+ self.renderer_type = renderer_type
+ if decoder_ckpt is not None:
+ self.triplane_decoder, self.triplane_render_kwargs = self.create_eg3d_decoder(decoder_ckpt)
+
+ vector_quantizer_config['embed_dim'] = embed_dim
+
+ if is_cvqvae:
+ self.vector_quantizer = VectorQuantiser(
+ **vector_quantizer_config
+ )
+ else:
+ self.vector_quantizer = EMAVectorQuantizer(
+ n_embed=vector_quantizer_config['num_embed'],
+ codebook_dim = embed_dim,
+ beta=vector_quantizer_config['beta']
+ )
+ # self.vector_quantizer = VectorQuantizer2(
+ # n_e = vector_quantizer_config['num_embed'],
+ # e_dim = embed_dim,
+ # beta = vector_quantizer_config['beta']
+ # )
+ # self.vector_quantizer = QuantizeEMAReset(
+ # nb_code = vector_quantizer_config['num_embed'],
+ # code_dim = embed_dim,
+ # mu = vector_quantizer_config['beta'],
+ # )
+
+ self.psum = torch.zeros([1])
+ self.psum_sq = torch.zeros([1])
+ self.psum_min = torch.zeros([1])
+ self.psum_max = torch.zeros([1])
+ self.count = 0
+ self.len_dset = 0
+ self.latent_list = []
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=True)
+ print(f"Restored from {path}")
+
+ def encode(self, x, rollout=False):
+ if rollout:
+ x = self.rollout(x)
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ z_q, loss, (perplexity, min_encodings, encoding_indices) = self.vector_quantizer(moments)
+ return z_q, loss, perplexity, encoding_indices
+
+ def decode(self, z, unrollout=False):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ if unrollout:
+ dec = self.unrollout(dec)
+ return dec
+
+ def forward(self, input):
+ z_q, loss, perplexity, encoding_indices = self.encode(input)
+ dec = self.decode(z_q)
+ return dec, loss, perplexity, encoding_indices
+
+ def rollout(self, triplane):
+ res = triplane.shape[-1]
+ ch = triplane.shape[1]
+ triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
+ return triplane
+
+ def to3daware(self, triplane):
+ res = triplane.shape[-2]
+ plane1 = triplane[..., :res]
+ plane2 = triplane[..., res:2*res]
+ plane3 = triplane[..., 2*res:3*res]
+
+ x_mp = torch.nn.MaxPool2d((res, 1))
+ y_mp = torch.nn.MaxPool2d((1, res))
+ x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2)
+ y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2)
+ # for plane1
+ plane21 = x_mp_rep(plane2)
+ plane31 = torch.flip(y_mp_rep(plane3), (3,))
+ new_plane1 = torch.cat([plane1, plane21, plane31], 1)
+ # for plane2
+ plane12 = y_mp_rep(plane1)
+ plane32 = x_mp_rep(plane3)
+ new_plane2 = torch.cat([plane2, plane12, plane32], 1)
+ # for plane3
+ plane13 = torch.flip(x_mp_rep(plane1), (2,))
+ plane23 = y_mp_rep(plane2)
+ new_plane3 = torch.cat([plane3, plane13, plane23], 1)
+
+ new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous()
+ return new_plane
+
+ def unrollout(self, triplane):
+ res = triplane.shape[-2]
+ ch = 3 * triplane.shape[1]
+ triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
+ return triplane
+
+ def training_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, vq_loss, perplexity, encoding_indices = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, vq_loss, prefix='train/', batch=batch)
+ log_dict_ae['train/perplexity'] = perplexity
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, vq_loss, perplexity, encoding_indices = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, vq_loss, prefix='val/', batch=None)
+ log_dict_ae['val/perplexity'] = perplexity
+ self.log_dict(log_dict_ae)
+
+ reconstructions = self.unrollout(reconstructions)
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+ batch_size = inputs.shape[0]
+ for b in range(batch_size):
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if b % 4 == 0 and batch_idx < 10:
+ rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1)
+ self.logger.experiment.log({
+ "val/vis": [wandb.Image(rgb_all)],
+ })
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ return self.log_dict
+
+ def to_rgb(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_triplane(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_triplane"):
+ self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def to_rgb_3daware(self, plane):
+ x = plane.float()
+ if not hasattr(self, "colorize_3daware"):
+ self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware)
+ x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8)
+ return x
+
+ def test_step(self, batch, batch_idx):
+ inputs = self.rollout(batch['triplane'])
+ reconstructions, vq_loss, perplexity, encoding_indices = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, vq_loss, prefix='test/', batch=None)
+ log_dict_ae['test/perplexity'] = perplexity
+ self.log_dict(log_dict_ae)
+
+ batch_size = inputs.shape[0]
+ psnr_list = [] # between rec and gt
+ psnr_input_list = [] # between input and gt
+ psnr_rec_list = [] # between input and rec
+
+ colorize_triplane_input = self.to_rgb_triplane(inputs)[0]
+ colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0]
+
+ reconstructions = self.unrollout(reconstructions)
+
+ if self.norm:
+ assert NotImplementedError
+ else:
+ reconstructions_unnormalize = reconstructions
+
+ if True:
+ for b in range(batch_size):
+ rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder(
+ batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+ rgb, cur_psnr_list = self.render_triplane_eg3d_decoder(
+ reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b],
+ )
+
+ cur_psnr_list_rec = []
+ for i in range(rgb.shape[0]):
+ cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i])))
+
+ rgb_input = to8b(rgb_input.detach().cpu().numpy())
+ rgb_gt = to8b(batch['img'][b].detach().cpu().numpy())
+ rgb = to8b(rgb.detach().cpu().numpy())
+
+ if batch_idx < 10:
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1])
+ imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1])
+
+ psnr_list += cur_psnr_list
+ psnr_input_list += cur_psnr_list_input
+ psnr_rec_list += cur_psnr_list_rec
+
+ self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True)
+ self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True)
+ self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True)
+
+ def on_test_epoch_end(self):
+ mean = self.psum / self.count
+ mean_min = self.psum_min / self.len_dset
+ mean_max = self.psum_max / self.len_dset
+ var = (self.psum_sq / self.count) - (mean ** 2)
+ std = torch.sqrt(var)
+
+ print("mean min: {}".format(mean_min))
+ print("mean max: {}".format(mean_max))
+ print("mean: {}".format(mean))
+ print("std: {}".format(std))
+
+ latent = np.concatenate(self.latent_list)
+ q75, q25 = np.percentile(latent.reshape(-1), [75 ,25])
+ median = np.median(latent.reshape(-1))
+ iqr = q75 - q25
+ norm_iqr = iqr * 0.7413
+ print("Norm IQR: {}".format(norm_iqr))
+ print("Inverse Norm IQR: {}".format(1/norm_iqr))
+ print("Median: {}".format(median))
+
+ def loss(self, inputs, reconstructions, vq_loss, prefix, batch=None):
+ reconstructions = reconstructions.contiguous()
+ rec_loss = F.mse_loss(inputs.contiguous(), reconstructions)
+ loss = self.lossconfig.rec_weight * rec_loss + self.lossconfig.vq_weight * vq_loss
+
+ ret_dict = {
+ prefix+'mean_rec_loss': torch.abs(inputs.contiguous() - reconstructions.contiguous()).mean().detach(),
+ prefix+'rec_loss': rec_loss,
+ prefix+'vq_loss': vq_loss,
+ prefix+'loss': loss,
+ }
+
+ render_weight = self.lossconfig.get("render_weight", 0)
+ tv_weight = self.lossconfig.get("tv_weight", 0)
+ l1_weight = self.lossconfig.get("l1_weight", 0)
+ latent_tv_weight = self.lossconfig.get("latent_tv_weight", 0)
+ latent_l1_weight = self.lossconfig.get("latent_l1_weight", 0)
+
+ triplane_rec = self.unrollout(reconstructions)
+ if render_weight > 0 and batch is not None:
+ rgb_rendered, target = self.render_triplane_eg3d_decoder_sample_pixel(triplane_rec, batch['batch_rays'], batch['img'])
+ render_loss = F.mse(rgb_rendered, target)
+ loss += render_weight * render_loss
+ ret_dict[prefix + 'render_loss'] = render_loss
+ if tv_weight > 0:
+ tvloss_y = torch.abs(triplane_rec[:, :, :-1] - triplane_rec[:, :, 1:]).mean()
+ tvloss_x = torch.abs(triplane_rec[:, :, :, :-1] - triplane_rec[:, :, :, 1:]).mean()
+ tvloss = tvloss_y + tvloss_x
+ loss += tv_weight * tvloss
+ ret_dict[prefix + 'tv_loss'] = tvloss
+ if l1_weight > 0:
+ l1 = (triplane_rec ** 2).mean()
+ loss += l1_weight * l1
+ ret_dict[prefix + 'l1_loss'] = l1
+
+ ret_dict[prefix+'loss'] = loss
+
+ return loss, ret_dict
+
+ def create_eg3d_decoder(self, decoder_ckpt):
+ triplane_decoder = Renderer_TriPlane(**self.renderer_config)
+ pretrain_pth = torch.load(decoder_ckpt, map_location='cpu')
+ pretrain_pth = {
+ '.'.join(k.split('.')[1:]): v for k, v in pretrain_pth.items()
+ }
+ # import pdb; pdb.set_trace()
+ triplane_decoder.load_state_dict(pretrain_pth)
+ render_kwargs = {
+ 'depth_resolution': 128,
+ 'disparity_space_sampling': False,
+ 'box_warp': 2.4,
+ 'depth_resolution_importance': 128,
+ 'clamp_mode': 'softplus',
+ 'white_back': True,
+ 'det': True
+ }
+ return triplane_decoder, render_kwargs
+
+ def render_triplane_eg3d_decoder(self, triplane, batch_rays, target):
+ ray_o = batch_rays[:, 0]
+ ray_d = batch_rays[:, 1]
+ psnr_list = []
+ rec_img_list = []
+ res = triplane.shape[-2]
+ for i in range(ray_o.shape[0]):
+ with torch.no_grad():
+ render_out = self.triplane_decoder(triplane.reshape(1, 3, -1, res, res),
+ ray_o[i:i+1], ray_d[i:i+1], self.triplane_render_kwargs, whole_img=True, tvloss=False)
+ rec_img = render_out['rgb_marched'].permute(0, 2, 3, 1)
+ psnr = mse2psnr(img2mse(rec_img[0], target[i]))
+ psnr_list.append(psnr)
+ rec_img_list.append(rec_img)
+ return torch.cat(rec_img_list, 0), psnr_list
+
+ def render_triplane_eg3d_decoder_sample_pixel(self, triplane, batch_rays, target, sample_num=1024):
+ assert batch_rays.shape[1] == 1
+ sel = torch.randint(batch_rays.shape[-2], [sample_num])
+ ray_o = batch_rays[:, 0, 0, sel]
+ ray_d = batch_rays[:, 0, 1, sel]
+ res = triplane.shape[-2]
+ render_out = self.triplane_decoder(triplane.reshape(triplane.shape[0], 3, -1, res, res),
+ ray_o, ray_d, self.triplane_render_kwargs, whole_img=False, tvloss=False)
+ rec_img = render_out['rgb_marched']
+ target = target.reshape(triplane.shape[0], -1, 3)[:, sel, :]
+ return rec_img, target
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters())+
+ list(self.vector_quantizer.parameters()),
+ lr=lr)
+ return opt_ae
diff --git a/3DTopia/module/model_2d.py b/3DTopia/module/model_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..e501c20953c37314b00c4689b37c2c969a60c47d
--- /dev/null
+++ b/3DTopia/module/model_2d.py
@@ -0,0 +1,2206 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+
+from utility.initialize import instantiate_from_config
+from .nn_2d import LinearAttention
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none", "vanilla_groupconv", "crossattention"], f'attn_type {attn_type} unknown'
+ # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == 'vanilla_groupconv':
+ return AttnBlock_GroupConv(in_channels)
+ elif attn_type == 'crossattention':
+ num_heads = 8
+ return TriplaneAttentionBlock(in_channels, num_heads, in_channels // num_heads, True)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ # print("Working with z of shape {} = {} dimensions.".format(
+ # self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+
+ self.conv_out = nn.Conv2d(mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels*ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+ out_channels=tmp_chn, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size//in_size))+1
+ factor_up = 1.+ (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor==1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+ return x
+
+class FirstStagePostProcessor(nn.Module):
+
+ def __init__(self, ch_mult:list, in_channels,
+ pretrained_model:nn.Module=None,
+ reshape=False,
+ n_channels=None,
+ dropout=0.,
+ pretrained_config=None):
+ super().__init__()
+ if pretrained_config is None:
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.pretrained_model = pretrained_model
+ else:
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.instantiate_pretrained(pretrained_config)
+
+ self.do_reshape = reshape
+
+ if n_channels is None:
+ n_channels = self.pretrained_model.encoder.ch
+
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
+ stride=1,padding=1)
+
+ blocks = []
+ downs = []
+ ch_in = n_channels
+ for m in ch_mult:
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
+ ch_in = m * n_channels
+ downs.append(Downsample(ch_in, with_conv=False))
+
+ self.model = nn.ModuleList(blocks)
+ self.downsampler = nn.ModuleList(downs)
+
+
+ def instantiate_pretrained(self, config):
+ model = instantiate_from_config(config)
+ self.pretrained_model = model.eval()
+ # self.pretrained_model.train = False
+ for param in self.pretrained_model.parameters():
+ param.requires_grad = False
+
+
+ @torch.no_grad()
+ def encode_with_pretrained(self,x):
+ c = self.pretrained_model.encode(x)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ return c
+
+ def forward(self,x):
+ z_fs = self.encode_with_pretrained(x)
+ z = self.proj_norm(z_fs)
+ z = self.proj(z)
+ z = nonlinearity(z)
+
+ for submodel, downmodel in zip(self.model,self.downsampler):
+ z = submodel(z,temb=None)
+ z = downmodel(z)
+
+ if self.do_reshape:
+ z = rearrange(z,'b c h w -> b (h w) c')
+ return z
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+class ResnetBlock_GroupConv(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels * 3, 32 * 3)
+ self.conv1 = torch.nn.Conv2d(in_channels * 3,
+ out_channels * 3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=3)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels * 3, 32 * 3)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels * 3,
+ out_channels * 3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=3)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels * 3,
+ out_channels * 3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=3)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels * 3,
+ out_channels * 3,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=3)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ assert temb is None
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+def rollout(triplane):
+ res = triplane.shape[-1]
+ ch = triplane.shape[1]
+ triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res)
+ return triplane
+
+def unrollout(triplane):
+ res = triplane.shape[-2]
+ ch = 3 * triplane.shape[1]
+ triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res)
+ return triplane
+
+class Upsample_GroupConv(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels * 3,
+ in_channels * 3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=3)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample_GroupConv(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels * 3,
+ in_channels * 3,
+ kernel_size=3,
+ stride=2,
+ padding=0,
+ groups=3)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+class AttnBlock_GroupConv(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+
+ def forward(self, x, temp=None):
+ x = rollout(x)
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return unrollout(x+h_)
+
+
+from torch import nn, einsum
+from inspect import isfunction
+from einops import rearrange, repeat
+
+def exists(val):
+ return val is not None
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ x = x.permute(0, 2, 1)
+ context = context.permute(0, 2, 1)
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out).permute(0, 2, 1)
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+class TriplaneAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+
+ self.plane1_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels)
+ self.plane2_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels)
+ self.plane3_ca = CrossAttention(channels, channels, self.num_heads, num_head_channels)
+
+ def forward(self, x, temp=None):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ x = rollout(x)
+
+ b, c, *spatial = x.shape
+ res = x.shape[-2]
+ plane1 = x[..., :res].reshape(b, c, -1)
+ plane2 = x[..., res:res*2].reshape(b, c, -1)
+ plane3 = x[..., 2*res:3*res].reshape(b, c, -1)
+ x = x.reshape(b, c, -1)
+
+ plane1_output = self.plane1_ca(self.norm(plane1), self.norm(x))
+ plane2_output = self.plane2_ca(self.norm(plane2), self.norm(x))
+ plane3_output = self.plane3_ca(self.norm(plane3), self.norm(x))
+
+ h = torch.cat([plane1_output, plane2_output, plane3_output], -1)
+
+ x = (x + h).reshape(b, c, *spatial)
+
+ return unrollout(x)
+
+
+class Encoder_GroupConv(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False,
+ attn_type="vanilla_groupconv", mid_layers=1,
+ **ignore_kwargs):
+ super().__init__()
+ assert not use_linear_attn
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ # self.conv_in = torch.nn.Conv2d(in_channels,
+ # self.ch,
+ # kernel_size=3,
+ # stride=1,
+ # padding=1)
+ self.conv_in = torch.nn.Conv2d(in_channels * 3,
+ self.ch * 3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=3)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock_GroupConv(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample_GroupConv(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.attn_type = attn_type
+ self.mid = nn.Module()
+ if attn_type == 'crossattention':
+ self.mid.block_1 = nn.ModuleList()
+ for _ in range(mid_layers):
+ self.mid.block_1.append(
+ ResnetBlock_GroupConv(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ )
+ self.mid.block_1.append(
+ make_attn(block_in, attn_type=attn_type)
+ )
+ else:
+ self.mid.block_1 = ResnetBlock_GroupConv(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock_GroupConv(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in * 3, 32 * 3)
+ self.conv_out = torch.nn.Conv2d(block_in * 3,
+ 2*z_channels * 3 if double_z else z_channels * 3,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ x = unrollout(x)
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ if self.attn_type == 'crossattention':
+ for m in self.mid.block_1:
+ h = m(h, temb)
+ else:
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+
+ h = rollout(h)
+
+ return h
+
+class Decoder_GroupConv(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla_groupconv", mid_layers=1, **ignorekwargs):
+ super().__init__()
+ assert not use_linear_attn
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ # print("Working with z of shape {} = {} dimensions.".format(
+ # self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels * 3,
+ block_in * 3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=3)
+
+ # middle
+ self.mid = nn.Module()
+ self.attn_type = attn_type
+ if attn_type == 'crossattention':
+ self.mid.block_1 = nn.ModuleList()
+ for _ in range(mid_layers):
+ self.mid.block_1.append(
+ ResnetBlock_GroupConv(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ )
+ self.mid.block_1.append(
+ make_attn(block_in, attn_type=attn_type)
+ )
+ else:
+ self.mid.block_1 = ResnetBlock_GroupConv(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock_GroupConv(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock_GroupConv(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample_GroupConv(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in * 3, 32 * 3)
+ self.conv_out = torch.nn.Conv2d(block_in * 3,
+ out_ch * 3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=3)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ z = unrollout(z)
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ if self.attn_type == 'crossattention':
+ for m in self.mid.block_1:
+ h = m(h, temb)
+ else:
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+
+ h = rollout(h)
+
+ return h
+
+
+
+# not success attempts
+class CrossAttnFuseBlock_GroupConv(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q0 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k0 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v0 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.q1 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k1 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v1 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.q2 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k2 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v2 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out0 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out1 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out2 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.fuse_out = torch.nn.Conv2d(in_channels * 3,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ x = rollout(x)
+
+ b, c, *spatial = x.shape
+ res = x.shape[-2]
+ plane1 = x[..., :res].reshape(b, c, res, res)
+ plane2 = x[..., res:res*2].reshape(b, c, res, res)
+ plane3 = x[..., 2*res:3*res].reshape(b, c, res, res)
+
+ # h_ = x
+ # h_ = self.norm(h_)
+ # q = self.q(h_)
+ # k = self.k(h_)
+ # v = self.v(h_)
+
+ q0 = self.q0(self.norm(plane2))
+ k0 = self.k0(self.norm(plane2))
+ v0 = self.v0(self.norm(plane2))
+
+ q1 = self.q1(self.norm(plane2))
+ k1 = self.k1(self.norm(plane1))
+ v1 = self.v1(self.norm(plane1))
+
+ q2 = self.q2(self.norm(plane2))
+ k2 = self.k2(self.norm(plane3))
+ v2 = self.v2(self.norm(plane3))
+
+ def compute_attention(q, k, v):
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ return h_
+
+ h0 = compute_attention(q0, k0, v0)
+ h0 = self.proj_out0(h0)
+
+ h1 = compute_attention(q1, k1, v1)
+ h1 = self.proj_out1(h1)
+
+ h2 = compute_attention(q2, k2, v2)
+ h2 = self.proj_out2(h2)
+
+ fuse_out = self.fuse_out(
+ torch.cat([h0, h1, h2], 1)
+ )
+
+ return fuse_out
+
+class CrossAttnDecodeBlock_GroupConv(nn.Module):
+ def __init__(self, in_channels, h, w):
+ super().__init__()
+ self.in_channels = in_channels
+ self.h = h
+ self.w = w
+
+ self.norm = Normalize(in_channels)
+ self.q0 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k0 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v0 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.q1 = torch.nn.Parameter(torch.randn(1, self.in_channels, h, w))
+ self.q1.requires_grad = True
+
+ self.k1 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v1 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.q2 = torch.nn.Parameter(torch.randn(1, self.in_channels, h, w))
+ self.q2.requires_grad = True
+
+ self.k2 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v2 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out0 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out1 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out2 = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.fuse_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ # x = rollout(x)
+
+ b, c, *spatial = x.shape
+ res = x.shape[-2]
+ # plane1 = x[..., :res].reshape(b, c, res, res)
+ # plane2 = x[..., res:res*2].reshape(b, c, res, res)
+ # plane3 = x[..., 2*res:3*res].reshape(b, c, res, res)
+
+ # h_ = x
+ # h_ = self.norm(h_)
+ # q = self.q(h_)
+ # k = self.k(h_)
+ # v = self.v(h_)
+
+ q0 = self.q0(self.norm(x))
+ k0 = self.k0(self.norm(x))
+ v0 = self.v0(self.norm(x))
+
+ q1 = self.q1.repeat(b, 1, 1, 1)
+ k1 = self.k1(self.norm(x))
+ v1 = self.v1(self.norm(x))
+
+ q2 = self.q2.repeat(b, 1, 1, 1)
+ k2 = self.k2(self.norm(x))
+ v2 = self.v2(self.norm(x))
+
+ def compute_attention(q, k, v):
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+ return h_
+
+ h0 = compute_attention(q0, k0, v0)
+ h0 = self.proj_out0(h0)
+
+ h1 = compute_attention(q1, k1, v1)
+ h1 = self.proj_out1(h1)
+
+ h2 = compute_attention(q2, k2, v2)
+ h2 = self.proj_out2(h2)
+
+ fuse_out = self.fuse_out(
+ torch.cat([h1, h0, h2], -1)
+ )
+
+ fuse_out = unrollout(fuse_out)
+
+ return fuse_out
+
+class Encoder_GroupConv_LateFusion(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla_groupconv",
+ **ignore_kwargs):
+ super().__init__()
+ assert not use_linear_attn
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels * 3,
+ self.ch * 3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=3)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock_GroupConv(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample_GroupConv(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock_GroupConv(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock_GroupConv(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # fuse to one plane
+ self.fuse = CrossAttnFuseBlock_GroupConv(block_in)
+
+ # end
+ self.norm_out = Normalize(block_in, 32)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ x = unrollout(x)
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ h = self.fuse(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+
+ # h = rollout(h)
+
+ return h
+
+class Decoder_GroupConv_LateFusion(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla_groupconv", **ignorekwargs):
+ super().__init__()
+ assert not use_linear_attn
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ # print("Working with z of shape {} = {} dimensions.".format(
+ # self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # triplane decoder
+ self.triplane_decoder = CrossAttnDecodeBlock_GroupConv(block_in, curr_res, curr_res)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock_GroupConv(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock_GroupConv(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock_GroupConv(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample_GroupConv(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in * 3, 32 * 3)
+ self.conv_out = torch.nn.Conv2d(block_in * 3,
+ out_ch * 3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=3)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ h = self.triplane_decoder(h)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+
+ h = rollout(h)
+
+ return h
+
+
+# VIT Encoder and Decoder from https://github.com/thuanz123/enhancing-transformers/blob/main/enhancing/modules/stage1/layers.py
+# ------------------------------------------------------------------------------------
+# Enhancing Transformers
+# Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
+# Licensed under the MIT License [see LICENSE for details]
+# ------------------------------------------------------------------------------------
+# Modified from ViT-Pytorch (https://github.com/lucidrains/vit-pytorch)
+# Copyright (c) 2020 Phil Wang. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+import math
+import numpy as np
+from typing import Union, Tuple, List
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size):
+ """
+ grid_size: int or (int, int) of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_size = (grid_size, grid_size) if type(grid_size) != tuple else grid_size
+ grid_h = np.arange(grid_size[0], dtype=np.float32)
+ grid_w = np.arange(grid_size[1], dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+def init_weights(m):
+ if isinstance(m, nn.Linear):
+ # we use xavier_uniform following official JAX ViT:
+ torch.nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+ w = m.weight.data
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+
+class PreNorm(nn.Module):
+ def __init__(self, dim: int, fn: nn.Module) -> None:
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.fn = fn
+
+ def forward(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+ return self.fn(self.norm(x), **kwargs)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim: int, hidden_dim: int) -> None:
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(dim, hidden_dim),
+ nn.Tanh(),
+ nn.Linear(hidden_dim, dim)
+ )
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ return self.net(x)
+
+
+class Attention(nn.Module):
+ def __init__(self, dim: int, heads: int = 8, dim_head: int = 64) -> None:
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+
+ self.attend = nn.Softmax(dim = -1)
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+
+ self.to_out = nn.Linear(inner_dim, dim) if project_out else nn.Identity()
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ qkv = self.to_qkv(x).chunk(3, dim = -1)
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
+
+ attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+ attn = self.attend(attn)
+
+ out = torch.matmul(attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+
+ return self.to_out(out)
+
+
+class Transformer(nn.Module):
+ def __init__(self, dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int) -> None:
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for idx in range(depth):
+ layer = nn.ModuleList([PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head)),
+ PreNorm(dim, FeedForward(dim, mlp_dim))])
+ self.layers.append(layer)
+ self.norm = nn.LayerNorm(dim)
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+
+ return self.norm(x)
+
+
+class ViTEncoder(nn.Module):
+ def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int],
+ dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 3, dim_head: int = 64) -> None:
+ super().__init__()
+ image_height, image_width = image_size if isinstance(image_size, tuple) \
+ else (image_size, image_size)
+ patch_height, patch_width = patch_size if isinstance(patch_size, tuple) \
+ else (patch_size, patch_size)
+
+ assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+ en_pos_embedding = get_2d_sincos_pos_embed(dim, (image_height // patch_height, image_width // patch_width))
+
+ self.num_patches = (image_height // patch_height) * (image_width // patch_width)
+ self.patch_dim = channels * patch_height * patch_width
+
+ self.to_patch_embedding = nn.Sequential(
+ nn.Conv2d(channels, dim, kernel_size=patch_size, stride=patch_size),
+ Rearrange('b c h w -> b (h w) c'),
+ )
+
+ self.patch_height = patch_height
+ self.patch_width = patch_width
+ self.image_height = image_height
+ self.image_width = image_width
+ self.dim = dim
+
+ self.en_pos_embedding = nn.Parameter(torch.from_numpy(en_pos_embedding).float().unsqueeze(0), requires_grad=False)
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
+
+ self.apply(init_weights)
+
+ def forward(self, img: torch.FloatTensor) -> torch.FloatTensor:
+ x = self.to_patch_embedding(img)
+ x = x + self.en_pos_embedding
+ x = self.transformer(x)
+
+ x = Rearrange('b h w c -> b c h w')(x.reshape(-1, self.image_height // self.patch_height, self.image_width // self.patch_width, self.dim))
+
+ return x
+
+
+class ViTDecoder(nn.Module):
+ def __init__(self, image_size: Union[Tuple[int, int], int], patch_size: Union[Tuple[int, int], int],
+ dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 3, dim_head: int = 64) -> None:
+ super().__init__()
+ image_height, image_width = image_size if isinstance(image_size, tuple) \
+ else (image_size, image_size)
+ patch_height, patch_width = patch_size if isinstance(patch_size, tuple) \
+ else (patch_size, patch_size)
+
+ assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+ de_pos_embedding = get_2d_sincos_pos_embed(dim, (image_height // patch_height, image_width // patch_width))
+
+ self.num_patches = (image_height // patch_height) * (image_width // patch_width)
+ self.patch_dim = channels * patch_height * patch_width
+
+ self.patch_height = patch_height
+ self.patch_width = patch_width
+ self.image_height = image_height
+ self.image_width = image_width
+ self.dim = dim
+
+ self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
+ self.de_pos_embedding = nn.Parameter(torch.from_numpy(de_pos_embedding).float().unsqueeze(0), requires_grad=False)
+ self.to_pixel = nn.Sequential(
+ Rearrange('b (h w) c -> b c h w', h=image_height // patch_height),
+ nn.ConvTranspose2d(dim, channels, kernel_size=patch_size, stride=patch_size)
+ )
+
+ self.apply(init_weights)
+
+ def forward(self, token: torch.FloatTensor) -> torch.FloatTensor:
+ token = Rearrange('b c h w -> b (h w) c')(token)
+
+ x = token + self.de_pos_embedding
+ x = self.transformer(x)
+ x = self.to_pixel(x)
+
+ return x
+
+ def get_last_layer(self) -> nn.Parameter:
+ return self.to_pixel[-1].weight
diff --git a/3DTopia/module/nn_2d.py b/3DTopia/module/nn_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..d44cc7f0b0aa6106053343d509c73bf7b466aba0
--- /dev/null
+++ b/3DTopia/module/nn_2d.py
@@ -0,0 +1,546 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+# zero123/zero123/ldm/modules/diffusionmodules/util.py
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+ if ddim_discr_method == 'uniform':
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == 'quad':
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+ else:
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
+ return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+ if verbose:
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+ print(f'For the chosen value of eta, which is {eta}, '
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+ return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
+ return embedding
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
+
+
+# zero123/zero123/ldm/modules/attention.py
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return{el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+ disable_self_attn=False):
+ super().__init__()
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ """
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None,
+ disable_self_attn=False):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
+ disable_self_attn=disable_self_attn)
+ for d in range(depth)]
+ )
+
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ x = self.proj_out(x)
+ return x + x_in
+
+
+def exists(x):
+ return x is not None
diff --git a/3DTopia/module/quantise.py b/3DTopia/module/quantise.py
new file mode 100644
index 0000000000000000000000000000000000000000..07540d6d99f649a6cbd9cb3c8b6a608a77f4e4fa
--- /dev/null
+++ b/3DTopia/module/quantise.py
@@ -0,0 +1,159 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from torch import einsum
+from einops import rearrange
+
+
+class VectorQuantiser(nn.Module):
+ """
+ Improved version over vector quantiser, with the dynamic initialisation
+ for these unoptimised "dead" points.
+ num_embed: number of codebook entry
+ embed_dim: dimensionality of codebook entry
+ beta: weight for the commitment loss
+ distance: distance for looking up the closest code
+ anchor: anchor sampled methods
+ first_batch: if true, the offline version of our model
+ contras_loss: if true, use the contras_loss to further improve the performance
+ """
+ def __init__(self, num_embed, embed_dim, beta, distance='cos',
+ anchor='probrandom', first_batch=False, contras_loss=False):
+ super().__init__()
+
+ self.num_embed = num_embed
+ self.embed_dim = embed_dim
+ self.beta = beta
+ self.distance = distance
+ self.anchor = anchor
+ self.first_batch = first_batch
+ self.contras_loss = contras_loss
+ self.decay = 0.99
+ self.init = False
+
+ self.pool = FeaturePool(self.num_embed, self.embed_dim)
+ self.embedding = nn.Embedding(self.num_embed, self.embed_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.num_embed, 1.0 / self.num_embed)
+ self.register_buffer("embed_prob", torch.zeros(self.num_embed))
+
+
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+ assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits==False, "Only for interface compatible with Gumbel"
+ assert return_logits==False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ z_flattened = z.view(-1, self.embed_dim)
+
+ # clculate the distance
+ if self.distance == 'l2':
+ # l2 distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = - torch.sum(z_flattened.detach() ** 2, dim=1, keepdim=True) - \
+ torch.sum(self.embedding.weight ** 2, dim=1) + \
+ 2 * torch.einsum('bd, dn-> bn', z_flattened.detach(), rearrange(self.embedding.weight, 'n d-> d n'))
+ elif self.distance == 'cos':
+ # cosine distances from z to embeddings e_j
+ normed_z_flattened = F.normalize(z_flattened, dim=1).detach()
+ normed_codebook = F.normalize(self.embedding.weight, dim=1)
+ d = torch.einsum('bd,dn->bn', normed_z_flattened, rearrange(normed_codebook, 'n d -> d n'))
+
+ # encoding
+ sort_distance, indices = d.sort(dim=1)
+ # look up the closest point for the indices
+ encoding_indices = indices[:,-1]
+ encodings = torch.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=z.device)
+ encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
+
+ # quantise and unflatten
+ z_q = torch.matmul(encodings, self.embedding.weight).view(z.shape)
+ # compute loss for embedding
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + torch.mean((z_q - z.detach()) ** 2)
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+ # reshape back to match original input shape
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+ # count
+ # import pdb
+ # pdb.set_trace()
+ avg_probs = torch.mean(encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+ min_encodings = encodings
+
+ # online clustered reinitialisation for unoptimized points
+ if self.training:
+ # calculate the average usage of code entries
+ self.embed_prob.mul_(self.decay).add_(avg_probs, alpha= 1 - self.decay)
+ # running average updates
+ if self.anchor in ['closest', 'random', 'probrandom'] and (not self.init):
+ # closest sampling
+ if self.anchor == 'closest':
+ sort_distance, indices = d.sort(dim=0)
+ random_feat = z_flattened.detach()[indices[-1,:]]
+ # feature pool based random sampling
+ elif self.anchor == 'random':
+ random_feat = self.pool.query(z_flattened.detach())
+ # probabilitical based random sampling
+ elif self.anchor == 'probrandom':
+ norm_distance = F.softmax(d.t(), dim=1)
+ prob = torch.multinomial(norm_distance, num_samples=1).view(-1)
+ random_feat = z_flattened.detach()[prob]
+ # decay parameter based on the average usage
+ decay = torch.exp(-(self.embed_prob*self.num_embed*10)/(1-self.decay)-1e-3).unsqueeze(1).repeat(1, self.embed_dim)
+ self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay
+ if self.first_batch:
+ self.init = True
+ # contrastive loss
+ if self.contras_loss:
+ sort_distance, indices = d.sort(dim=0)
+ dis_pos = sort_distance[-max(1, int(sort_distance.size(0)/self.num_embed)):,:].mean(dim=0, keepdim=True)
+ dis_neg = sort_distance[:int(sort_distance.size(0)*1/2),:]
+ dis = torch.cat([dis_pos, dis_neg], dim=0).t() / 0.07
+ contra_loss = F.cross_entropy(dis, torch.zeros((dis.size(0),), dtype=torch.long, device=dis.device))
+ loss += contra_loss
+
+ return z_q, loss, (perplexity, min_encodings, encoding_indices)
+
+class FeaturePool():
+ """
+ This class implements a feature buffer that stores previously encoded features
+
+ This buffer enables us to initialize the codebook using a history of generated features
+ rather than the ones produced by the latest encoders
+ """
+ def __init__(self, pool_size, dim=64):
+ """
+ Initialize the FeaturePool class
+
+ Parameters:
+ pool_size(int) -- the size of featue buffer
+ """
+ self.pool_size = pool_size
+ if self.pool_size > 0:
+ self.nums_features = 0
+ self.features = (torch.rand((pool_size, dim)) * 2 - 1)/ pool_size
+
+ def query(self, features):
+ """
+ return features from the pool
+ """
+ self.features = self.features.to(features.device)
+ if self.nums_features < self.pool_size:
+ if features.size(0) > self.pool_size: # if the batch size is large enough, directly update the whole codebook
+ random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
+ self.features = features[random_feat_id]
+ self.nums_features = self.pool_size
+ else:
+ # if the mini-batch is not large nuough, just store it for the next update
+ num = self.nums_features + features.size(0)
+ self.features[self.nums_features:num] = features
+ self.nums_features = num
+ else:
+ if features.size(0) > int(self.pool_size):
+ random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
+ self.features = features[random_feat_id]
+ else:
+ random_id = torch.randperm(self.pool_size)
+ self.features[random_id[:features.size(0)]] = features
+
+ return self.features
diff --git a/3DTopia/module/quantize_taming.py b/3DTopia/module/quantize_taming.py
new file mode 100644
index 0000000000000000000000000000000000000000..e018d4c07ed5ddd513ca46ea2f063000bdd6ddc1
--- /dev/null
+++ b/3DTopia/module/quantize_taming.py
@@ -0,0 +1,564 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from torch import einsum
+from einops import rearrange
+
+
+class VectorQuantizer(nn.Module):
+ """
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+ ____________________________________________
+ Discretization bottleneck part of the VQ-VAE.
+ Inputs:
+ - n_e : number of embeddings
+ - e_dim : dimension of embedding
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ _____________________________________________
+ """
+
+ # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
+ # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
+ # used wherever VectorQuantizer has been used before and is additionally
+ # more efficient.
+ def __init__(self, n_e, e_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ def forward(self, z):
+ """
+ Inputs the output of the encoder network z and maps it to a discrete
+ one-hot vector that is the index of the closest embedding vector e_j
+ z (continuous) -> z_q (discrete)
+ z.shape = (batch, channel, height, width)
+ quantization pipeline:
+ 1. get encoder input (B,C,H,W)
+ 2. flatten input to (B*H*W,C)
+ """
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+ torch.matmul(z_flattened, self.embedding.weight.t())
+
+ ## could possible replace this here
+ # #\start...
+ # find closest encodings
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+
+ min_encodings = torch.zeros(
+ min_encoding_indices.shape[0], self.n_e).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # dtype min encodings: torch.float32
+ # min_encodings shape: torch.Size([2048, 512])
+ # min_encoding_indices.shape: torch.Size([2048, 1])
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ #.........\end
+
+ # with:
+ # .........\start
+ #min_encoding_indices = torch.argmin(d, dim=1)
+ #z_q = self.embedding(min_encoding_indices)
+ # ......\end......... (TODO)
+
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ # TODO: check for more easy handling with nn.Embedding
+ min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
+ min_encodings.scatter_(1, indices[:,None], 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantize(nn.Module):
+ """
+ credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
+ Gumbel Softmax trick quantizer
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
+ https://arxiv.org/abs/1611.01144
+ """
+ def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
+ kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
+ remap=None, unknown_index="random"):
+ super().__init__()
+
+ self.embedding_dim = embedding_dim
+ self.n_embed = n_embed
+
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
+ self.embed = nn.Embedding(n_embed, embedding_dim)
+
+ self.use_vqinterface = use_vqinterface
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, return_logits=False):
+ # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
+ hard = self.straight_through if self.training else True
+ temp = self.temperature if temp is None else temp
+
+ logits = self.proj(z)
+ if self.remap is not None:
+ # continue only with used logits
+ full_zeros = torch.zeros_like(logits)
+ logits = logits[:,self.used,...]
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
+ if self.remap is not None:
+ # go back to all entries but unused set to zero
+ full_zeros[:,self.used,...] = soft_one_hot
+ soft_one_hot = full_zeros
+ z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
+
+ ind = soft_one_hot.argmax(dim=1)
+ if self.remap is not None:
+ ind = self.remap_to_used(ind)
+ if self.use_vqinterface:
+ if return_logits:
+ return z_q, diff, (None, None, ind), logits
+ return z_q, diff, (None, None, ind)
+ return z_q, diff, ind
+
+ def get_codebook_entry(self, indices, shape):
+ b, h, w, c = shape
+ assert b*h*w == indices.shape[0]
+ indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
+ if self.remap is not None:
+ indices = self.unmap_to_all(indices)
+ one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
+ z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
+ return z_q
+
+
+class VectorQuantizer2(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+ """
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
+ sane_index_shape=False, legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+ assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits==False, "Only for interface compatible with Gumbel"
+ assert return_logits==False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = 0
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
+ torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0],-1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+class EmbeddingEMA(nn.Module):
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+ super().__init__()
+ self.decay = decay
+ self.eps = eps
+ weight = torch.randn(num_tokens, codebook_dim)
+ self.weight = nn.Parameter(weight, requires_grad = False)
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
+ self.update = True
+
+ def forward(self, embed_id):
+ return F.embedding(embed_id, self.weight)
+
+ def cluster_size_ema_update(self, new_cluster_size):
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+ def embed_avg_ema_update(self, new_embed_avg):
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+ def weight_update(self, num_tokens):
+ n = self.cluster_size.sum()
+ smoothed_cluster_size = (
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+ )
+ #normalize embedding average with smoothed cluster size
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+ self.weight.data.copy_(embed_normalized)
+
+
+class EMAVectorQuantizer(nn.Module):
+ def __init__(self, n_embed, codebook_dim, beta, decay=0.99, eps=1e-5,
+ remap=None, unknown_index="random"):
+ super().__init__()
+ self.codebook_dim = codebook_dim
+ self.num_tokens = n_embed
+ self.beta = beta
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def dequantize(self, ids):
+ return self.embedding(ids)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ #z, 'b c h w -> b h w c'
+ z = rearrange(z, 'b c h w -> b h w c')
+ z_flattened = z.reshape(-1, self.codebook_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
+
+
+ encoding_indices = torch.argmin(d, dim=1)
+
+ z_q = self.embedding(encoding_indices).view(z.shape)
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+ avg_probs = torch.mean(encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+ if self.training and self.embedding.update:
+ #EMA cluster size
+ encodings_sum = encodings.sum(0)
+ self.embedding.cluster_size_ema_update(encodings_sum)
+ #EMA embedding average
+ embed_sum = encodings.transpose(0,1) @ z_flattened
+ self.embedding.embed_avg_ema_update(embed_sum)
+ #normalize embed_avg and update weight
+ self.embedding.weight_update(self.num_tokens)
+
+ # compute loss for embedding
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ #z_q, 'b h w c -> b c h w'
+ z_q = rearrange(z_q, 'b h w c -> b c h w')
+ return z_q, loss, (perplexity, encodings, encoding_indices)
+
+
+class QuantizeEMAReset(nn.Module):
+ def __init__(self, nb_code, code_dim, mu):
+ super().__init__()
+ self.nb_code = nb_code
+ self.code_dim = code_dim
+ self.mu = mu
+ self.reset_codebook()
+
+ def reset_codebook(self):
+ self.init = False
+ self.code_sum = None
+ self.code_count = None
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim).to(device))
+
+ def _tile(self, x):
+ nb_code_x, code_dim = x.shape
+ if nb_code_x < self.nb_code:
+ n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
+ std = 0.01 / np.sqrt(code_dim)
+ out = x.repeat(n_repeats, 1)
+ out = out + torch.randn_like(out) * std
+ else :
+ out = x
+ return out
+
+ def init_codebook(self, x):
+ out = self._tile(x)
+ self.codebook = out[:self.nb_code]
+ self.code_sum = self.codebook.clone()
+ self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
+ self.init = True
+
+ @torch.no_grad()
+ def compute_perplexity(self, code_idx) :
+ # Calculate new centres
+ code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
+ code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
+
+ code_count = code_onehot.sum(dim=-1) # nb_code
+ prob = code_count / torch.sum(code_count)
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
+ return perplexity
+
+ @torch.no_grad()
+ def update_codebook(self, x, code_idx):
+
+ code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
+ code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
+
+ code_sum = torch.matmul(code_onehot, x) # nb_code, w
+ code_count = code_onehot.sum(dim=-1) # nb_code
+
+ out = self._tile(x)
+ code_rand = out[:self.nb_code]
+
+ # Update centres
+ self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum # w, nb_code
+ self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count # nb_code
+
+ usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
+ code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
+
+ self.codebook = usage * code_update + (1 - usage) * code_rand
+ prob = code_count / torch.sum(code_count)
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
+
+
+ return perplexity
+
+ def quantize(self, x):
+ # Calculate latent code x_l
+ k_w = self.codebook.t()
+ distance = torch.sum(x ** 2, dim=-1, keepdim=True) - 2 * torch.matmul(x, k_w) + torch.sum(k_w ** 2, dim=0,
+ keepdim=True) # (N * L, b)
+ _, code_idx = torch.min(distance, dim=-1)
+ return code_idx
+
+ def dequantize(self, code_idx):
+ x = F.embedding(code_idx, self.codebook)
+ return x
+
+ def forward(self, x):
+ N, C, H, W = x.shape
+
+ # Preprocess
+ # x = self.preprocess(x)
+ x = rearrange(x, 'b c h w -> b h w c')
+ x = x.reshape(-1, self.code_dim)
+
+ # Init codebook if not inited
+ if self.training and not self.init:
+ self.init_codebook(x)
+
+ # quantize and dequantize through bottleneck
+ code_idx = self.quantize(x)
+ x_d = self.dequantize(code_idx)
+
+ # Update embeddings
+ if self.training:
+ perplexity = self.update_codebook(x, code_idx)
+ else :
+ perplexity = self.compute_perplexity(code_idx)
+
+ # Loss
+ commit_loss = F.mse_loss(x, x_d.detach())
+
+ # Passthrough
+ x_d = x + (x_d - x).detach()
+
+ # Postprocess
+ x_d = x_d.view(N, H, W, C).permute(0, 3, 1, 2).contiguous()
+
+ return x_d, commit_loss, (perplexity, code_idx, code_idx)
\ No newline at end of file
diff --git a/3DTopia/module/renderer.py b/3DTopia/module/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bff0e703093f86bbe1693960db5574815aca6a56
--- /dev/null
+++ b/3DTopia/module/renderer.py
@@ -0,0 +1,463 @@
+import os
+import math
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# TriPlane Utils
+class MipRayMarcher2(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def run_forward(self, colors, densities, depths, rendering_options):
+ deltas = depths[:, :, 1:] - depths[:, :, :-1]
+ colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
+ densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
+ depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
+
+
+ if rendering_options['clamp_mode'] == 'softplus':
+ densities_mid = F.softplus(densities_mid - 1) # activation bias of -1 makes things initialize better
+ else:
+ assert False, "MipRayMarcher only supports `clamp_mode`=`softplus`!"
+
+ density_delta = densities_mid * deltas
+
+ alpha = 1 - torch.exp(-density_delta)
+
+ alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
+ weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
+
+ composite_rgb = torch.sum(weights * colors_mid, -2)
+ weight_total = weights.sum(2)
+ # composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
+ composite_depth = torch.sum(weights * depths_mid, -2)
+
+ # clip the composite to min/max range of depths
+ composite_depth = torch.nan_to_num(composite_depth, float('inf'))
+ # composite_depth = torch.nan_to_num(composite_depth, 0.)
+ composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
+
+ if rendering_options.get('white_back', False):
+ composite_rgb = composite_rgb + 1 - weight_total
+
+ composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
+
+ return composite_rgb, composite_depth, weights
+
+ def forward(self, colors, densities, depths, rendering_options):
+ composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options)
+
+ return composite_rgb, composite_depth, weights
+
+def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
+ """
+ Left-multiplies MxM @ NxM. Returns NxM.
+ """
+ res = torch.matmul(vectors4, matrix.T)
+ return res
+
+def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
+ """
+ Normalize vector lengths.
+ """
+ return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
+
+def torch_dot(x: torch.Tensor, y: torch.Tensor):
+ """
+ Dot product of two tensors.
+ """
+ return (x * y).sum(-1)
+
+def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
+ """
+ Author: Petr Kellnhofer
+ Intersects rays with the [-1, 1] NDC volume.
+ Returns min and max distance of entry.
+ Returns -1 for no intersection.
+ https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
+ """
+ o_shape = rays_o.shape
+ rays_o = rays_o.detach().reshape(-1, 3)
+ rays_d = rays_d.detach().reshape(-1, 3)
+
+
+ bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
+ bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
+ bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
+ is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
+
+ # Precompute inverse for stability.
+ invdir = 1 / rays_d
+ sign = (invdir < 0).long()
+
+ # Intersect with YZ plane.
+ tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
+ tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
+
+ # Intersect with XZ plane.
+ tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
+ tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
+
+ # Resolve parallel rays.
+ is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
+
+ # Use the shortest intersection.
+ tmin = torch.max(tmin, tymin)
+ tmax = torch.min(tmax, tymax)
+
+ # Intersect with XY plane.
+ tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
+ tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
+
+ # Resolve parallel rays.
+ is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
+
+ # Use the shortest intersection.
+ tmin = torch.max(tmin, tzmin)
+ tmax = torch.min(tmax, tzmax)
+
+ # Mark invalid.
+ tmin[torch.logical_not(is_valid)] = -1
+ tmax[torch.logical_not(is_valid)] = -2
+
+ return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
+
+def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
+ """
+ Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
+ Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch.
+ """
+ # create a tensor of 'num' steps from 0 to 1
+ steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
+
+ # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
+ # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
+ # "cannot statically infer the expected size of a list in this contex", hence the code below
+ for i in range(start.ndim):
+ steps = steps.unsqueeze(-1)
+
+ # the output starts at 'start' and increments until 'stop' in each dimension
+ out = start[None] + steps * (stop - start)[None]
+
+ return out
+
+def generate_planes():
+ """
+ Defines planes by the three vectors that form the "axes" of the
+ plane. Should work with arbitrary number of planes and planes of
+ arbitrary orientation.
+ """
+ return torch.tensor([[[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]],
+ [[1, 0, 0],
+ [0, 0, 1],
+ [0, 1, 0]],
+ [[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]]], dtype=torch.float32)
+
+def project_onto_planes(planes, coordinates):
+ """
+ Does a projection of a 3D point onto a batch of 2D planes,
+ returning 2D plane coordinates.
+ Takes plane axes of shape n_planes, 3, 3
+ # Takes coordinates of shape N, M, 3
+ # returns projections of shape N*n_planes, M, 2
+ """
+
+ # # ORIGINAL
+ # N, M, C = coordinates.shape
+ # xy_coords = coordinates[..., [0, 1]]
+ # xz_coords = coordinates[..., [0, 2]]
+ # zx_coords = coordinates[..., [2, 0]]
+ # return torch.stack([xy_coords, xz_coords, zx_coords], dim=1).reshape(N*3, M, 2)
+
+ # FIXED
+ N, M, _ = coordinates.shape
+ xy_coords = coordinates[..., [0, 1]]
+ yz_coords = coordinates[..., [1, 2]]
+ zx_coords = coordinates[..., [2, 0]]
+ return torch.stack([xy_coords, yz_coords, zx_coords], dim=1).reshape(N*3, M, 2)
+
+def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
+ assert padding_mode == 'zeros'
+ N, n_planes, C, H, W = plane_features.shape
+ _, M, _ = coordinates.shape
+ plane_features = plane_features.view(N*n_planes, C, H, W)
+
+ coordinates = (2/box_warp) * coordinates # TODO: add specific box bounds
+
+ projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
+
+ output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
+ return output_features
+
+def sample_from_3dgrid(grid, coordinates):
+ """
+ Expects coordinates in shape (batch_size, num_points_per_batch, 3)
+ Expects grid in shape (1, channels, H, W, D)
+ (Also works if grid has batch size)
+ Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
+ """
+ batch_size, n_coords, n_dims = coordinates.shape
+ sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1),
+ coordinates.reshape(batch_size, 1, 1, -1, n_dims),
+ mode='bilinear', padding_mode='zeros', align_corners=False)
+ N, C, H, W, D = sampled_features.shape
+ sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C)
+ return sampled_features
+
+class FullyConnectedLayer(nn.Module):
+ def __init__(self,
+ in_features, # Number of input features.
+ out_features, # Number of output features.
+ bias = True, # Apply additive bias before the activation function?
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier = 1, # Learning rate multiplier.
+ bias_init = 0, # Initial value for the additive bias.
+ ):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.activation = activation
+ # self.weight = torch.nn.Parameter(torch.full([out_features, in_features], np.float32(0)))
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
+ self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
+ self.bias_gain = lr_multiplier
+
+ def forward(self, x):
+ w = self.weight.to(x.dtype) * self.weight_gain
+ b = self.bias
+ if b is not None:
+ b = b.to(x.dtype)
+ if self.bias_gain != 1:
+ b = b * self.bias_gain
+
+ if self.activation == 'linear' and b is not None:
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
+ else:
+ x = x.matmul(w.t())
+ x = bias_act.bias_act(x, b, act=self.activation)
+ return x
+
+ def extra_repr(self):
+ return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
+
+class TriPlane_Decoder(nn.Module):
+ def __init__(self, dim=12, width=128):
+ super().__init__()
+ self.net = torch.nn.Sequential(
+ FullyConnectedLayer(dim, width),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(width, width),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(width, 1 + 3)
+ )
+
+ def forward(self, sampled_features):
+ sampled_features = sampled_features.mean(1)
+ x = sampled_features
+
+ N, M, C = x.shape
+ x = x.view(N*M, C)
+
+ x = self.net(x)
+ x = x.view(N, M, -1)
+ rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
+ sigma = x[..., 0:1]
+ return {'rgb': rgb, 'sigma': sigma}
+
+class Renderer_TriPlane(nn.Module):
+ def __init__(self, rgbnet_dim=18, rgbnet_width=128):
+ super(Renderer_TriPlane, self).__init__()
+ self.decoder = TriPlane_Decoder(dim=rgbnet_dim//3, width=rgbnet_width)
+ self.ray_marcher = MipRayMarcher2()
+ self.plane_axes = generate_planes()
+
+ def forward(self, planes, ray_origins, ray_directions, rendering_options, whole_img=False):
+ self.plane_axes = self.plane_axes.to(ray_origins.device)
+
+ ray_start, ray_end = get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
+ is_ray_valid = ray_end > ray_start
+ if torch.any(is_ray_valid).item():
+ ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
+ ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
+ depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
+
+ batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape
+
+ # Coarse Pass
+ sample_coordinates = (ray_origins.unsqueeze(-2) + depths_coarse * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
+ sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
+
+
+ out = self.run_model(planes, self.decoder, sample_coordinates, sample_directions, rendering_options)
+ colors_coarse = out['rgb']
+ densities_coarse = out['sigma']
+ colors_coarse = colors_coarse.reshape(batch_size, num_rays, samples_per_ray, colors_coarse.shape[-1])
+ densities_coarse = densities_coarse.reshape(batch_size, num_rays, samples_per_ray, 1)
+
+ # Fine Pass
+ N_importance = rendering_options['depth_resolution_importance']
+ if N_importance > 0:
+ _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
+
+ depths_fine = self.sample_importance(depths_coarse, weights, N_importance)
+
+ sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, N_importance, -1).reshape(batch_size, -1, 3)
+ sample_coordinates = (ray_origins.unsqueeze(-2) + depths_fine * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
+
+ out = self.run_model(planes, self.decoder, sample_coordinates, sample_directions, rendering_options)
+ colors_fine = out['rgb']
+ densities_fine = out['sigma']
+ colors_fine = colors_fine.reshape(batch_size, num_rays, N_importance, colors_fine.shape[-1])
+ densities_fine = densities_fine.reshape(batch_size, num_rays, N_importance, 1)
+
+ all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
+ depths_fine, colors_fine, densities_fine)
+
+ # Aggregate
+ rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
+ else:
+ rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
+
+ # return rgb_final, depth_final, weights.sum(2)
+ if whole_img:
+ H = W = int(ray_origins.shape[1] ** 0.5)
+ rgb_final = rgb_final.permute(0, 2, 1).reshape(-1, 3, H, W).contiguous()
+ depth_final = depth_final.permute(0, 2, 1).reshape(-1, 1, H, W).contiguous()
+ depth_final = (depth_final - depth_final.min()) / (depth_final.max() - depth_final.min())
+ depth_final = depth_final.repeat(1, 3, 1, 1)
+ # rgb_final = torch.clip(rgb_final, min=0, max=1)
+ rgb_final = (rgb_final + 1) / 2.
+ weights = weights.sum(2).reshape(rgb_final.shape[0], rgb_final.shape[2], rgb_final.shape[3])
+ return {
+ 'rgb_marched': rgb_final,
+ 'depth_final': depth_final,
+ 'weights': weights,
+ }
+ else:
+ rgb_final = (rgb_final + 1) / 2.
+ return {
+ 'rgb_marched': rgb_final,
+ 'depth_final': depth_final,
+ }
+
+ def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
+ sampled_features = sample_from_planes(self.plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
+
+ out = decoder(sampled_features)
+ if options.get('density_noise', 0) > 0:
+ out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
+ return out
+
+ def sort_samples(self, all_depths, all_colors, all_densities):
+ _, indices = torch.sort(all_depths, dim=-2)
+ all_depths = torch.gather(all_depths, -2, indices)
+ all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
+ all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
+ return all_depths, all_colors, all_densities
+
+ def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2):
+ all_depths = torch.cat([depths1, depths2], dim = -2)
+ all_colors = torch.cat([colors1, colors2], dim = -2)
+ all_densities = torch.cat([densities1, densities2], dim = -2)
+
+ _, indices = torch.sort(all_depths, dim=-2)
+ all_depths = torch.gather(all_depths, -2, indices)
+ all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
+ all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
+
+ return all_depths, all_colors, all_densities
+
+ def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
+ """
+ Return depths of approximately uniformly spaced samples along rays.
+ """
+ N, M, _ = ray_origins.shape
+ if disparity_space_sampling:
+ depths_coarse = torch.linspace(0,
+ 1,
+ depth_resolution,
+ device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
+ depth_delta = 1/(depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+ depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
+ else:
+ if type(ray_start) == torch.Tensor:
+ depths_coarse = linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
+ depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
+ else:
+ depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
+ depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+
+ return depths_coarse
+
+ def sample_importance(self, z_vals, weights, N_importance):
+ """
+ Return depths of importance sampled points along rays. See NeRF importance sampling for more.
+ """
+ with torch.no_grad():
+ batch_size, num_rays, samples_per_ray, _ = z_vals.shape
+
+ z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
+ weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher
+
+ # smooth weights
+ weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1)
+ weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
+ weights = weights + 0.01
+
+ z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
+ importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
+ N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
+ return importance_z_vals
+
+ def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
+ """
+ Sample @N_importance samples from @bins with distribution defined by @weights.
+ Inputs:
+ bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
+ weights: (N_rays, N_samples_)
+ N_importance: the number of samples to draw from the distribution
+ det: deterministic or not
+ eps: a small number to prevent division by zero
+ Outputs:
+ samples: the sampled samples
+ """
+ N_rays, N_samples_ = weights.shape
+ weights = weights + eps # prevent division by zero (don't do inplace op!)
+ pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
+ cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
+ cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1)
+ # padded to 0~1 inclusive
+
+ if det:
+ u = torch.linspace(0, 1, N_importance, device=bins.device)
+ u = u.expand(N_rays, N_importance)
+ else:
+ u = torch.rand(N_rays, N_importance, device=bins.device)
+ u = u.contiguous()
+
+ inds = torch.searchsorted(cdf, u, right=True)
+ below = torch.clamp_min(inds-1, 0)
+ above = torch.clamp_max(inds, N_samples_)
+
+ inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
+ cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
+ bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
+
+ denom = cdf_g[...,1]-cdf_g[...,0]
+ denom[denomnchpwq', x)
+ x = x.reshape(b, 3, self.out_channel//3, self.out_reso, self.out_reso).contiguous()
+ return x
+
+
+class SingleImageToTriplaneVAE(nn.Module):
+ def __init__(self, backbone='dino_vits8', input_reso=256, out_reso=128, out_channel=18, z_dim=32,
+ decoder_depth=16, decoder_heads=16, decoder_mlp_dim=1024, decoder_dim_head=64, dropout=0):
+ super().__init__()
+ self.backbone = backbone
+
+ self.input_image_size = input_reso
+ self.out_reso = out_reso
+ self.out_channel = out_channel
+ self.z_dim = z_dim
+
+ self.decoder_depth = decoder_depth
+ self.decoder_heads = decoder_heads
+ self.decoder_mlp_dim = decoder_mlp_dim
+ self.decoder_dim_head = decoder_dim_head
+
+ self.dropout = dropout
+ self.patch_size = 8 if '8' in backbone else 16
+
+ if 'dino' in backbone:
+ self.vit = torch.hub.load('facebookresearch/dino:main', backbone)
+ self.embed_dim = self.vit.embed_dim
+ self.preprocess = None
+ else:
+ raise NotImplementedError
+
+ self.fc_mu = nn.Linear(self.embed_dim, self.z_dim)
+ self.fc_var = nn.Linear(self.embed_dim, self.z_dim)
+
+ self.vit_decoder = TriplaneDecoder((self.input_image_size // self.patch_size) ** 2, self.z_dim,
+ depth=self.decoder_depth, heads=self.decoder_heads, mlp_dim=self.decoder_mlp_dim,
+ out_channel=self.out_channel, out_reso=self.out_reso, dim_head = self.decoder_dim_head, dropout=0)
+
+ def forward(self, x, is_train):
+ assert x.shape[-1] == self.input_image_size
+ bs = x.shape[0]
+ if 'dino' in self.backbone:
+ z = self.vit.get_intermediate_layers(x, n=1)[0][:, 1:] # [bs, 1024, self.embed_dim]
+ else:
+ raise NotImplementedError
+
+ z = z.reshape(-1, z.shape[-1])
+ mu = self.fc_mu(z)
+ logvar = self.fc_var(z)
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn_like(std)
+ if is_train:
+ rep_z = eps * std + mu
+ else:
+ rep_z = eps
+ rep_z = rep_z.reshape(bs, -1, self.z_dim)
+ out = self.vit_decoder(rep_z)
+
+ return out, mu, logvar
diff --git a/3DTopia/sample_stage1.py b/3DTopia/sample_stage1.py
new file mode 100644
index 0000000000000000000000000000000000000000..60a1192df3f429604c28721b94fe9b9eb28e07c5
--- /dev/null
+++ b/3DTopia/sample_stage1.py
@@ -0,0 +1,299 @@
+import os
+import cv2
+import json
+import torch
+import mcubes
+import trimesh
+import argparse
+import numpy as np
+from tqdm import tqdm
+import imageio.v2 as imageio
+import pytorch_lightning as pl
+from omegaconf import OmegaConf
+
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.dpm_solver import DPMSolverSampler
+
+from utility.initialize import instantiate_from_config, get_obj_from_str
+from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes
+from utility.triplane_renderer.renderer import get_rays, to8b
+from safetensors.torch import load_file
+from huggingface_hub import hf_hub_download
+
+import warnings
+warnings.filterwarnings("ignore", category=UserWarning)
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+def add_text(rgb, caption):
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ # org
+ gap = 30
+ org = (gap, gap)
+ # fontScale
+ fontScale = 0.6
+ # Blue color in BGR
+ color = (255, 0, 0)
+ # Line thickness of 2 px
+ thickness = 1
+ break_caption = []
+ for i in range(len(caption) // 30 + 1):
+ break_caption_i = caption[i*30:(i+1)*30]
+ break_caption.append(break_caption_i)
+ for i, bci in enumerate(break_caption):
+ cv2.putText(rgb, bci, (gap, gap*(i+1)), font, fontScale, color, thickness, cv2.LINE_AA)
+ return rgb
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default='configs/default.yaml')
+ parser.add_argument("--ckpt", type=str, default=None)
+ parser.add_argument("--test_folder", type=str, default="stage1")
+ parser.add_argument("--seed", type=int, default=None)
+ parser.add_argument("--sampler", type=str, default="ddpm")
+ parser.add_argument("--samples", type=int, default=1)
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--steps", type=int, default=1000)
+ parser.add_argument("--text", nargs='+', default='a robot')
+ parser.add_argument("--text_file", type=str, default=None)
+ parser.add_argument("--no_video", action='store_true', default=False)
+ parser.add_argument("--render_res", type=int, default=128)
+ parser.add_argument("--no_mcubes", action='store_true', default=False)
+ parser.add_argument("--mcubes_res", type=int, default=128)
+ parser.add_argument("--cfg_scale", type=float, default=1)
+ args = parser.parse_args()
+
+ if args.text is not None:
+ text = [' '.join(args.text),]
+ elif args.text_file is not None:
+ if args.text_file.endswith('.json'):
+ with open(args.text_file, 'r') as f:
+ json_file = json.load(f)
+ text = json_file
+ text = [l.strip('.') for l in text]
+ else:
+ with open(args.text_file, 'r') as f:
+ text = f.readlines()
+ text = [l.strip() for l in text]
+ else:
+ raise NotImplementedError
+
+ print(text)
+
+ configs = OmegaConf.load(args.config)
+ if args.seed is not None:
+ pl.seed_everything(args.seed)
+
+ log_dir = os.path.join('results', args.config.split('/')[-1].split('.')[0], args.test_folder)
+ os.makedirs(log_dir, exist_ok=True)
+
+ if args.ckpt == None:
+ ckpt = hf_hub_download(repo_id="hongfz16/3DTopia", filename="model.safetensors")
+ else:
+ ckpt = args.ckpt
+
+ if ckpt.endswith(".ckpt"):
+ model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(ckpt, map_location='cpu', strict=False, **configs.model.params)
+ elif ckpt.endswith(".safetensors"):
+ model = get_obj_from_str(configs.model["target"])(**configs.model.params)
+ model_ckpt = load_file(ckpt)
+ model.load_state_dict(model_ckpt)
+ else:
+ raise NotImplementedError
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+
+ class DummySampler:
+ def __init__(self, model):
+ self.model = model
+
+ def sample(self, S, batch_size, shape, verbose, conditioning=None, *args, **kwargs):
+ return self.model.sample(
+ conditioning, batch_size, shape=[batch_size, ] + shape, *args, **kwargs
+ ), None
+
+ if args.sampler == 'dpm':
+ raise NotImplementedError
+ # sampler = DPMSolverSampler(model)
+ elif args.sampler == 'plms':
+ raise NotImplementedError
+ # sampler = PLMSSampler(model)
+ elif args.sampler == 'ddim':
+ sampler = DDIMSampler(model)
+ elif args.sampler == 'ddpm':
+ sampler = DummySampler(model)
+ else:
+ raise NotImplementedError
+
+ img_size = configs.model.params.unet_config.params.image_size
+ channels = configs.model.params.unet_config.params.in_channels
+ shape = [channels, img_size, img_size * 3]
+ plane_axes = generate_planes()
+
+ pose_folder = 'assets/sample_data/pose'
+ poses_fname = sorted([os.path.join(pose_folder, f) for f in os.listdir(pose_folder)])
+ batch_rays_list = []
+ H = args.render_res
+ ratio = 512 // H
+ for p in poses_fname:
+ c2w = np.loadtxt(p).reshape(4, 4)
+ c2w[:3, 3] *= 2.2
+ c2w = np.array([
+ [1, 0, 0, 0],
+ [0, 0, -1, 0],
+ [0, 1, 0, 0],
+ [0, 0, 0, 1]
+ ]) @ c2w
+
+ k = np.array([
+ [560 / ratio, 0, H * 0.5],
+ [0, 560 / ratio, H * 0.5],
+ [0, 0, 1]
+ ])
+
+ rays_o, rays_d = get_rays(H, H, torch.Tensor(k), torch.Tensor(c2w[:3, :4]))
+ coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, H-1, H), indexing='ij'), -1)
+ coords = torch.reshape(coords, [-1,2]).long()
+ rays_o = rays_o[coords[:, 0], coords[:, 1]]
+ rays_d = rays_d[coords[:, 0], coords[:, 1]]
+ batch_rays = torch.stack([rays_o, rays_d], 0)
+ batch_rays_list.append(batch_rays)
+ batch_rays_list = torch.stack(batch_rays_list, 0)
+
+ for text_idx, text_i in enumerate(text):
+ text_connect = '_'.join(text_i.split(' '))
+ for s in range(args.samples):
+ batch_size = args.batch_size
+ with torch.no_grad():
+ # with model.ema_scope():
+ noise = None
+ c = model.get_learned_conditioning([text_i])
+ unconditional_c = torch.zeros_like(c)
+ if args.cfg_scale != 1:
+ assert args.sampler == 'ddim'
+ sample, _ = sampler.sample(
+ S=args.steps,
+ batch_size=batch_size,
+ shape=shape,
+ verbose=False,
+ x_T = noise,
+ conditioning = c.repeat(batch_size, 1, 1),
+ unconditional_guidance_scale=args.cfg_scale,
+ unconditional_conditioning=unconditional_c.repeat(batch_size, 1, 1)
+ )
+ else:
+ sample, _ = sampler.sample(
+ S=args.steps,
+ batch_size=batch_size,
+ shape=shape,
+ verbose=False,
+ x_T = noise,
+ conditioning = c.repeat(batch_size, 1, 1),
+ )
+ decode_res = model.decode_first_stage(sample)
+
+ for b in range(batch_size):
+ def render_img(v):
+ rgb_sample, _ = model.first_stage_model.render_triplane_eg3d_decoder(
+ decode_res[b:b+1], batch_rays_list[v:v+1].to(device), torch.zeros(1, H, H, 3).to(device),
+ )
+ rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0]
+ rgb_sample = np.stack(
+ [rgb_sample[..., 2], rgb_sample[..., 1], rgb_sample[..., 0]], -1
+ )
+ # rgb_sample = add_text(rgb_sample, text_i)
+ return rgb_sample
+
+ if not args.no_mcubes:
+ # prepare volumn for marching cube
+ res = args.mcubes_res
+ c_list = torch.linspace(-1.2, 1.2, steps=res)
+ grid_x, grid_y, grid_z = torch.meshgrid(
+ c_list, c_list, c_list, indexing='ij'
+ )
+ coords = torch.stack([grid_x, grid_y, grid_z], -1).to(device)
+ plane_axes = generate_planes()
+ feats = sample_from_planes(
+ plane_axes, decode_res[b:b+1].reshape(1, 3, -1, 256, 256), coords.reshape(1, -1, 3), padding_mode='zeros', box_warp=2.4
+ )
+ fake_dirs = torch.zeros_like(coords)
+ fake_dirs[..., 0] = 1
+ out = model.first_stage_model.triplane_decoder.decoder(feats, fake_dirs)
+ u = out['sigma'].reshape(res, res, res).detach().cpu().numpy()
+ del out
+
+ # marching cube
+ vertices, triangles = mcubes.marching_cubes(u, 10)
+ min_bound = np.array([-1.2, -1.2, -1.2])
+ max_bound = np.array([1.2, 1.2, 1.2])
+ vertices = vertices / (res - 1) * (max_bound - min_bound)[None, :] + min_bound[None, :]
+ pt_vertices = torch.from_numpy(vertices).to(device)
+
+ # extract vertices color
+ res_triplane = 256
+ render_kwargs = {
+ 'depth_resolution': 128,
+ 'disparity_space_sampling': False,
+ 'box_warp': 2.4,
+ 'depth_resolution_importance': 128,
+ 'clamp_mode': 'softplus',
+ 'white_back': True,
+ 'det': True
+ }
+ rays_o_list = [
+ np.array([0, 0, 2]),
+ np.array([0, 0, -2]),
+ np.array([0, 2, 0]),
+ np.array([0, -2, 0]),
+ np.array([2, 0, 0]),
+ np.array([-2, 0, 0]),
+ ]
+ rgb_final = None
+ diff_final = None
+ for rays_o in tqdm(rays_o_list):
+ rays_o = torch.from_numpy(rays_o.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device)
+ rays_d = pt_vertices.reshape(-1, 3) - rays_o
+ rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1)
+ dist = torch.norm(pt_vertices.reshape(-1, 3) - rays_o, dim=-1).cpu().numpy().reshape(-1)
+
+ render_out = model.first_stage_model.triplane_decoder(
+ decode_res[b:b+1].reshape(1, 3, -1, res_triplane, res_triplane),
+ rays_o.unsqueeze(0), rays_d.unsqueeze(0), render_kwargs,
+ whole_img=False, tvloss=False
+ )
+ rgb = render_out['rgb_marched'].reshape(-1, 3).detach().cpu().numpy()
+ depth = render_out['depth_final'].reshape(-1).detach().cpu().numpy()
+ depth_diff = np.abs(dist - depth)
+
+ if rgb_final is None:
+ rgb_final = rgb.copy()
+ diff_final = depth_diff.copy()
+
+ else:
+ ind = diff_final > depth_diff
+ rgb_final[ind] = rgb[ind]
+ diff_final[ind] = depth_diff[ind]
+
+
+ # bgr to rgb
+ rgb_final = np.stack([
+ rgb_final[:, 2], rgb_final[:, 1], rgb_final[:, 0]
+ ], -1)
+
+ # export to ply
+ mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=(rgb_final * 255).astype(np.uint8))
+ trimesh.exchange.export.export_mesh(mesh, os.path.join(log_dir, f"{text_connect}_{s}_{b}.ply"), file_type='ply')
+
+ if not args.no_video:
+ view_num = len(batch_rays_list)
+ video_list = []
+ for v in tqdm(range(view_num//4, view_num//4 * 3, 2)):
+ rgb_sample = render_img(v)
+ video_list.append(rgb_sample)
+ imageio.mimwrite(os.path.join(log_dir, "{}_{}_{}.mp4".format(text_connect, s, b)), np.stack(video_list, 0))
+ else:
+ rgb_sample = render_img(104)
+ imageio.imwrite(os.path.join(log_dir, "{}_{}_{}.jpg".format(text_connect, s, b)), rgb_sample)
+
+if __name__ == '__main__':
+ main()
diff --git a/3DTopia/taming/data/ade20k.py b/3DTopia/taming/data/ade20k.py
new file mode 100644
index 0000000000000000000000000000000000000000..366dae97207dbb8356598d636e14ad084d45bc76
--- /dev/null
+++ b/3DTopia/taming/data/ade20k.py
@@ -0,0 +1,124 @@
+import os
+import numpy as np
+import cv2
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset
+
+from taming.data.sflckr import SegmentationBase # for examples included in repo
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/ade20k_examples.txt",
+ data_root="data/ade20k_images",
+ segmentation_root="data/ade20k_segmentations",
+ size=size, random_crop=random_crop,
+ interpolation=interpolation,
+ n_labels=151, shift_segmentation=False)
+
+
+# With semantic map and scene label
+class ADE20kBase(Dataset):
+ def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
+ self.split = self.get_split()
+ self.n_labels = 151 # unknown + 150
+ self.data_csv = {"train": "data/ade20k_train.txt",
+ "validation": "data/ade20k_test.txt"}[self.split]
+ self.data_root = "data/ade20k_root"
+ with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
+ self.scene_categories = f.read().splitlines()
+ self.scene_categories = dict(line.split() for line in self.scene_categories)
+ with open(self.data_csv, "r") as f:
+ self.image_paths = f.read().splitlines()
+ self._length = len(self.image_paths)
+ self.labels = {
+ "relative_file_path_": [l for l in self.image_paths],
+ "file_path_": [os.path.join(self.data_root, "images", l)
+ for l in self.image_paths],
+ "relative_segmentation_path_": [l.replace(".jpg", ".png")
+ for l in self.image_paths],
+ "segmentation_path_": [os.path.join(self.data_root, "annotations",
+ l.replace(".jpg", ".png"))
+ for l in self.image_paths],
+ "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
+ for l in self.image_paths],
+ }
+
+ size = None if size is not None and size<=0 else size
+ self.size = size
+ if crop_size is None:
+ self.crop_size = size if size is not None else None
+ else:
+ self.crop_size = crop_size
+ if self.size is not None:
+ self.interpolation = interpolation
+ self.interpolation = {
+ "nearest": cv2.INTER_NEAREST,
+ "bilinear": cv2.INTER_LINEAR,
+ "bicubic": cv2.INTER_CUBIC,
+ "area": cv2.INTER_AREA,
+ "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=self.interpolation)
+ self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=cv2.INTER_NEAREST)
+
+ if crop_size is not None:
+ self.center_crop = not random_crop
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
+ self.preprocessor = self.cropper
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = dict((k, self.labels[k][i]) for k in self.labels)
+ image = Image.open(example["file_path_"])
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ if self.size is not None:
+ image = self.image_rescaler(image=image)["image"]
+ segmentation = Image.open(example["segmentation_path_"])
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.size is not None:
+ segmentation = self.segmentation_rescaler(image=segmentation)["image"]
+ if self.size is not None:
+ processed = self.preprocessor(image=image, mask=segmentation)
+ else:
+ processed = {"image": image, "mask": segmentation}
+ example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
+ segmentation = processed["mask"]
+ onehot = np.eye(self.n_labels)[segmentation]
+ example["segmentation"] = onehot
+ return example
+
+
+class ADE20kTrain(ADE20kBase):
+ # default to random_crop=True
+ def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ interpolation=interpolation, crop_size=crop_size)
+
+ def get_split(self):
+ return "train"
+
+
+class ADE20kValidation(ADE20kBase):
+ def get_split(self):
+ return "validation"
+
+
+if __name__ == "__main__":
+ dset = ADE20kValidation()
+ ex = dset[0]
+ for k in ["image", "scene_category", "segmentation"]:
+ print(type(ex[k]))
+ try:
+ print(ex[k].shape)
+ except:
+ print(ex[k])
diff --git a/3DTopia/taming/data/annotated_objects_coco.py b/3DTopia/taming/data/annotated_objects_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..af000ecd943d7b8a85d7eb70195c9ecd10ab5edc
--- /dev/null
+++ b/3DTopia/taming/data/annotated_objects_coco.py
@@ -0,0 +1,139 @@
+import json
+from itertools import chain
+from pathlib import Path
+from typing import Iterable, Dict, List, Callable, Any
+from collections import defaultdict
+
+from tqdm import tqdm
+
+from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
+from taming.data.helper_types import Annotation, ImageDescription, Category
+
+COCO_PATH_STRUCTURE = {
+ 'train': {
+ 'top_level': '',
+ 'instances_annotations': 'annotations/instances_train2017.json',
+ 'stuff_annotations': 'annotations/stuff_train2017.json',
+ 'files': 'train2017'
+ },
+ 'validation': {
+ 'top_level': '',
+ 'instances_annotations': 'annotations/instances_val2017.json',
+ 'stuff_annotations': 'annotations/stuff_val2017.json',
+ 'files': 'val2017'
+ }
+}
+
+
+def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
+ return {
+ str(img['id']): ImageDescription(
+ id=img['id'],
+ license=img.get('license'),
+ file_name=img['file_name'],
+ coco_url=img['coco_url'],
+ original_size=(img['width'], img['height']),
+ date_captured=img.get('date_captured'),
+ flickr_url=img.get('flickr_url')
+ )
+ for img in description_json
+ }
+
+
+def load_categories(category_json: Iterable) -> Dict[str, Category]:
+ return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
+ for cat in category_json if cat['name'] != 'other'}
+
+
+def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
+ category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
+ annotations = defaultdict(list)
+ total = sum(len(a) for a in annotations_json)
+ for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
+ image_id = str(ann['image_id'])
+ if image_id not in image_descriptions:
+ raise ValueError(f'image_id [{image_id}] has no image description.')
+ category_id = ann['category_id']
+ try:
+ category_no = category_no_for_id(str(category_id))
+ except KeyError:
+ continue
+
+ width, height = image_descriptions[image_id].original_size
+ bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
+
+ annotations[image_id].append(
+ Annotation(
+ id=ann['id'],
+ area=bbox[2]*bbox[3], # use bbox area
+ is_group_of=ann['iscrowd'],
+ image_id=ann['image_id'],
+ bbox=bbox,
+ category_id=str(category_id),
+ category_no=category_no
+ )
+ )
+ return dict(annotations)
+
+
+class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
+ def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
+ """
+ @param data_path: is the path to the following folder structure:
+ coco/
+ ├── annotations
+ │ ├── instances_train2017.json
+ │ ├── instances_val2017.json
+ │ ├── stuff_train2017.json
+ │ └── stuff_val2017.json
+ ├── train2017
+ │ ├── 000000000009.jpg
+ │ ├── 000000000025.jpg
+ │ └── ...
+ ├── val2017
+ │ ├── 000000000139.jpg
+ │ ├── 000000000285.jpg
+ │ └── ...
+ @param: split: one of 'train' or 'validation'
+ @param: desired image size (give square images)
+ """
+ super().__init__(**kwargs)
+ self.use_things = use_things
+ self.use_stuff = use_stuff
+
+ with open(self.paths['instances_annotations']) as f:
+ inst_data_json = json.load(f)
+ with open(self.paths['stuff_annotations']) as f:
+ stuff_data_json = json.load(f)
+
+ category_jsons = []
+ annotation_jsons = []
+ if self.use_things:
+ category_jsons.append(inst_data_json['categories'])
+ annotation_jsons.append(inst_data_json['annotations'])
+ if self.use_stuff:
+ category_jsons.append(stuff_data_json['categories'])
+ annotation_jsons.append(stuff_data_json['annotations'])
+
+ self.categories = load_categories(chain(*category_jsons))
+ self.filter_categories()
+ self.setup_category_id_and_number()
+
+ self.image_descriptions = load_image_descriptions(inst_data_json['images'])
+ annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
+ self.annotations = self.filter_object_number(annotations, self.min_object_area,
+ self.min_objects_per_image, self.max_objects_per_image)
+ self.image_ids = list(self.annotations.keys())
+ self.clean_up_annotations_and_image_descriptions()
+
+ def get_path_structure(self) -> Dict[str, str]:
+ if self.split not in COCO_PATH_STRUCTURE:
+ raise ValueError(f'Split [{self.split} does not exist for COCO data.]')
+ return COCO_PATH_STRUCTURE[self.split]
+
+ def get_image_path(self, image_id: str) -> Path:
+ return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ # noinspection PyProtectedMember
+ return self.image_descriptions[image_id]._asdict()
diff --git a/3DTopia/taming/data/annotated_objects_dataset.py b/3DTopia/taming/data/annotated_objects_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..53cc346a1c76289a4964d7dc8a29582172f33dc0
--- /dev/null
+++ b/3DTopia/taming/data/annotated_objects_dataset.py
@@ -0,0 +1,218 @@
+from pathlib import Path
+from typing import Optional, List, Callable, Dict, Any, Union
+import warnings
+
+import PIL.Image as pil_image
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
+from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+from taming.data.conditional_builder.utils import load_object_from_string
+from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
+from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
+ Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
+
+
+class AnnotatedObjectsDataset(Dataset):
+ def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
+ min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
+ crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
+ encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
+ no_object_classes: Optional[int] = None):
+ self.data_path = data_path
+ self.split = split
+ self.keys = keys
+ self.target_image_size = target_image_size
+ self.min_object_area = min_object_area
+ self.min_objects_per_image = min_objects_per_image
+ self.max_objects_per_image = max_objects_per_image
+ self.crop_method = crop_method
+ self.random_flip = random_flip
+ self.no_tokens = no_tokens
+ self.use_group_parameter = use_group_parameter
+ self.encode_crop = encode_crop
+
+ self.annotations = None
+ self.image_descriptions = None
+ self.categories = None
+ self.category_ids = None
+ self.category_number = None
+ self.image_ids = None
+ self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
+ self.paths = self.build_paths(self.data_path)
+ self._conditional_builders = None
+ self.category_allow_list = None
+ if category_allow_list_target:
+ allow_list = load_object_from_string(category_allow_list_target)
+ self.category_allow_list = {name for name, _ in allow_list}
+ self.category_mapping = {}
+ if category_mapping_target:
+ self.category_mapping = load_object_from_string(category_mapping_target)
+ self.no_object_classes = no_object_classes
+
+ def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
+ top_level = Path(top_level)
+ sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
+ for path in sub_paths.values():
+ if not path.exists():
+ raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
+ return sub_paths
+
+ @staticmethod
+ def load_image_from_disk(path: Path) -> Image:
+ return pil_image.open(path).convert('RGB')
+
+ @staticmethod
+ def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool):
+ transform_functions = []
+ if crop_method == 'none':
+ transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
+ elif crop_method == 'center':
+ transform_functions.extend([
+ transforms.Resize(target_image_size),
+ CenterCropReturnCoordinates(target_image_size)
+ ])
+ elif crop_method == 'random-1d':
+ transform_functions.extend([
+ transforms.Resize(target_image_size),
+ RandomCrop1dReturnCoordinates(target_image_size)
+ ])
+ elif crop_method == 'random-2d':
+ transform_functions.extend([
+ Random2dCropReturnCoordinates(target_image_size),
+ transforms.Resize(target_image_size)
+ ])
+ elif crop_method is None:
+ return None
+ else:
+ raise ValueError(f'Received invalid crop method [{crop_method}].')
+ if random_flip:
+ transform_functions.append(RandomHorizontalFlipReturn())
+ transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
+ return transform_functions
+
+ def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor):
+ crop_bbox = None
+ flipped = None
+ for t in self.transform_functions:
+ if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)):
+ crop_bbox, x = t(x)
+ elif isinstance(t, RandomHorizontalFlipReturn):
+ flipped, x = t(x)
+ else:
+ x = t(x)
+ return crop_bbox, flipped, x
+
+ @property
+ def no_classes(self) -> int:
+ return self.no_object_classes if self.no_object_classes else len(self.categories)
+
+ @property
+ def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
+ # cannot set this up in init because no_classes is only known after loading data in init of superclass
+ if self._conditional_builders is None:
+ self._conditional_builders = {
+ 'objects_center_points': ObjectsCenterPointsConditionalBuilder(
+ self.no_classes,
+ self.max_objects_per_image,
+ self.no_tokens,
+ self.encode_crop,
+ self.use_group_parameter,
+ getattr(self, 'use_additional_parameters', False)
+ ),
+ 'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
+ self.no_classes,
+ self.max_objects_per_image,
+ self.no_tokens,
+ self.encode_crop,
+ self.use_group_parameter,
+ getattr(self, 'use_additional_parameters', False)
+ )
+ }
+ return self._conditional_builders
+
+ def filter_categories(self) -> None:
+ if self.category_allow_list:
+ self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
+ if self.category_mapping:
+ self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}
+
+ def setup_category_id_and_number(self) -> None:
+ self.category_ids = list(self.categories.keys())
+ self.category_ids.sort()
+ if '/m/01s55n' in self.category_ids:
+ self.category_ids.remove('/m/01s55n')
+ self.category_ids.append('/m/01s55n')
+ self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
+ if self.category_allow_list is not None and self.category_mapping is None \
+ and len(self.category_ids) != len(self.category_allow_list):
+ warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
+ 'Make sure all names in category_allow_list exist.')
+
+ def clean_up_annotations_and_image_descriptions(self) -> None:
+ image_id_set = set(self.image_ids)
+ self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
+ self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}
+
+ @staticmethod
+ def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
+ min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
+ filtered = {}
+ for image_id, annotations in all_annotations.items():
+ annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
+ if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
+ filtered[image_id] = annotations_with_min_area
+ return filtered
+
+ def __len__(self):
+ return len(self.image_ids)
+
+ def __getitem__(self, n: int) -> Dict[str, Any]:
+ image_id = self.get_image_id(n)
+ sample = self.get_image_description(image_id)
+ sample['annotations'] = self.get_annotation(image_id)
+
+ if 'image' in self.keys:
+ sample['image_path'] = str(self.get_image_path(image_id))
+ sample['image'] = self.load_image_from_disk(sample['image_path'])
+ sample['image'] = convert_pil_to_tensor(sample['image'])
+ sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image'])
+ sample['image'] = sample['image'].permute(1, 2, 0)
+
+ for conditional, builder in self.conditional_builders.items():
+ if conditional in self.keys:
+ sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped'])
+
+ if self.keys:
+ # only return specified keys
+ sample = {key: sample[key] for key in self.keys}
+ return sample
+
+ def get_image_id(self, no: int) -> str:
+ return self.image_ids[no]
+
+ def get_annotation(self, image_id: str) -> str:
+ return self.annotations[image_id]
+
+ def get_textual_label_for_category_id(self, category_id: str) -> str:
+ return self.categories[category_id].name
+
+ def get_textual_label_for_category_no(self, category_no: int) -> str:
+ return self.categories[self.get_category_id(category_no)].name
+
+ def get_category_number(self, category_id: str) -> int:
+ return self.category_number[category_id]
+
+ def get_category_id(self, category_no: int) -> str:
+ return self.category_ids[category_no]
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ raise NotImplementedError()
+
+ def get_path_structure(self):
+ raise NotImplementedError
+
+ def get_image_path(self, image_id: str) -> Path:
+ raise NotImplementedError
diff --git a/3DTopia/taming/data/annotated_objects_open_images.py b/3DTopia/taming/data/annotated_objects_open_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..aede6803d2cef7a74ca784e7907d35fba6c71239
--- /dev/null
+++ b/3DTopia/taming/data/annotated_objects_open_images.py
@@ -0,0 +1,137 @@
+from collections import defaultdict
+from csv import DictReader, reader as TupleReader
+from pathlib import Path
+from typing import Dict, List, Any
+import warnings
+
+from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
+from taming.data.helper_types import Annotation, Category
+from tqdm import tqdm
+
+OPEN_IMAGES_STRUCTURE = {
+ 'train': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'oidv6-train-annotations-bbox.csv',
+ 'file_list': 'train-images-boxable.csv',
+ 'files': 'train'
+ },
+ 'validation': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'validation-annotations-bbox.csv',
+ 'file_list': 'validation-images.csv',
+ 'files': 'validation'
+ },
+ 'test': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'test-annotations-bbox.csv',
+ 'file_list': 'test-images.csv',
+ 'files': 'test'
+ }
+}
+
+
+def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
+ category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
+ annotations: Dict[str, List[Annotation]] = defaultdict(list)
+ with open(descriptor_path) as file:
+ reader = DictReader(file)
+ for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'):
+ width = float(row['XMax']) - float(row['XMin'])
+ height = float(row['YMax']) - float(row['YMin'])
+ area = width * height
+ category_id = row['LabelName']
+ if category_id in category_mapping:
+ category_id = category_mapping[category_id]
+ if area >= min_object_area and category_id in category_no_for_id:
+ annotations[row['ImageID']].append(
+ Annotation(
+ id=i,
+ image_id=row['ImageID'],
+ source=row['Source'],
+ category_id=category_id,
+ category_no=category_no_for_id[category_id],
+ confidence=float(row['Confidence']),
+ bbox=(float(row['XMin']), float(row['YMin']), width, height),
+ area=area,
+ is_occluded=bool(int(row['IsOccluded'])),
+ is_truncated=bool(int(row['IsTruncated'])),
+ is_group_of=bool(int(row['IsGroupOf'])),
+ is_depiction=bool(int(row['IsDepiction'])),
+ is_inside=bool(int(row['IsInside']))
+ )
+ )
+ if 'train' in str(descriptor_path) and i < 14000000:
+ warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].')
+ return dict(annotations)
+
+
+def load_image_ids(csv_path: Path) -> List[str]:
+ with open(csv_path) as file:
+ reader = DictReader(file)
+ return [row['image_name'] for row in reader]
+
+
+def load_categories(csv_path: Path) -> Dict[str, Category]:
+ with open(csv_path) as file:
+ reader = TupleReader(file)
+ return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader}
+
+
+class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset):
+ def __init__(self, use_additional_parameters: bool, **kwargs):
+ """
+ @param data_path: is the path to the following folder structure:
+ open_images/
+ │ oidv6-train-annotations-bbox.csv
+ ├── class-descriptions-boxable.csv
+ ├── oidv6-train-annotations-bbox.csv
+ ├── test
+ │ ├── 000026e7ee790996.jpg
+ │ ├── 000062a39995e348.jpg
+ │ └── ...
+ ├── test-annotations-bbox.csv
+ ├── test-images.csv
+ ├── train
+ │ ├── 000002b66c9c498e.jpg
+ │ ├── 000002b97e5471a0.jpg
+ │ └── ...
+ ├── train-images-boxable.csv
+ ├── validation
+ │ ├── 0001eeaf4aed83f9.jpg
+ │ ├── 0004886b7d043cfd.jpg
+ │ └── ...
+ ├── validation-annotations-bbox.csv
+ └── validation-images.csv
+ @param: split: one of 'train', 'validation' or 'test'
+ @param: desired image size (returns square images)
+ """
+
+ super().__init__(**kwargs)
+ self.use_additional_parameters = use_additional_parameters
+
+ self.categories = load_categories(self.paths['class_descriptions'])
+ self.filter_categories()
+ self.setup_category_id_and_number()
+
+ self.image_descriptions = {}
+ annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
+ self.category_number)
+ self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
+ self.max_objects_per_image)
+ self.image_ids = list(self.annotations.keys())
+ self.clean_up_annotations_and_image_descriptions()
+
+ def get_path_structure(self) -> Dict[str, str]:
+ if self.split not in OPEN_IMAGES_STRUCTURE:
+ raise ValueError(f'Split [{self.split} does not exist for Open Images data.]')
+ return OPEN_IMAGES_STRUCTURE[self.split]
+
+ def get_image_path(self, image_id: str) -> Path:
+ return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ image_path = self.get_image_path(image_id)
+ return {'file_path': str(image_path), 'file_name': image_path.name}
diff --git a/3DTopia/taming/data/base.py b/3DTopia/taming/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e21667df4ce4baa6bb6aad9f8679bd756e2ffdb7
--- /dev/null
+++ b/3DTopia/taming/data/base.py
@@ -0,0 +1,70 @@
+import bisect
+import numpy as np
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset, ConcatDataset
+
+
+class ConcatDatasetWithIndex(ConcatDataset):
+ """Modified from original pytorch code to return dataset idx"""
+ def __getitem__(self, idx):
+ if idx < 0:
+ if -idx > len(self):
+ raise ValueError("absolute value of index should not exceed dataset length")
+ idx = len(self) + idx
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ return self.datasets[dataset_idx][sample_idx], dataset_idx
+
+
+class ImagePaths(Dataset):
+ def __init__(self, paths, size=None, random_crop=False, labels=None):
+ self.size = size
+ self.random_crop = random_crop
+
+ self.labels = dict() if labels is None else labels
+ self.labels["file_path_"] = paths
+ self._length = len(paths)
+
+ if self.size is not None and self.size > 0:
+ self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
+ if not self.random_crop:
+ self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
+ self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
+ else:
+ self.preprocessor = lambda **kwargs: kwargs
+
+ def __len__(self):
+ return self._length
+
+ def preprocess_image(self, image_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
+
+ def __getitem__(self, i):
+ example = dict()
+ example["image"] = self.preprocess_image(self.labels["file_path_"][i])
+ for k in self.labels:
+ example[k] = self.labels[k][i]
+ return example
+
+
+class NumpyPaths(ImagePaths):
+ def preprocess_image(self, image_path):
+ image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
+ image = np.transpose(image, (1,2,0))
+ image = Image.fromarray(image, mode="RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
diff --git a/3DTopia/taming/data/coco.py b/3DTopia/taming/data/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b2f7838448cb63dcf96daffe9470d58566d975a
--- /dev/null
+++ b/3DTopia/taming/data/coco.py
@@ -0,0 +1,176 @@
+import os
+import json
+import albumentations
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from torch.utils.data import Dataset
+
+from taming.data.sflckr import SegmentationBase # for examples included in repo
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/coco_examples.txt",
+ data_root="data/coco_images",
+ segmentation_root="data/coco_segmentations",
+ size=size, random_crop=random_crop,
+ interpolation=interpolation,
+ n_labels=183, shift_segmentation=True)
+
+
+class CocoBase(Dataset):
+ """needed for (image, caption, segmentation) pairs"""
+ def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
+ crop_size=None, force_no_crop=False, given_files=None):
+ self.split = self.get_split()
+ self.size = size
+ if crop_size is None:
+ self.crop_size = size
+ else:
+ self.crop_size = crop_size
+
+ self.onehot = onehot_segmentation # return segmentation as rgb or one hot
+ self.stuffthing = use_stuffthing # include thing in segmentation
+ if self.onehot and not self.stuffthing:
+ raise NotImplemented("One hot mode is only supported for the "
+ "stuffthings version because labels are stored "
+ "a bit different.")
+
+ data_json = datajson
+ with open(data_json) as json_file:
+ self.json_data = json.load(json_file)
+ self.img_id_to_captions = dict()
+ self.img_id_to_filepath = dict()
+ self.img_id_to_segmentation_filepath = dict()
+
+ assert data_json.split("/")[-1] in ["captions_train2017.json",
+ "captions_val2017.json"]
+ if self.stuffthing:
+ self.segmentation_prefix = (
+ "data/cocostuffthings/val2017" if
+ data_json.endswith("captions_val2017.json") else
+ "data/cocostuffthings/train2017")
+ else:
+ self.segmentation_prefix = (
+ "data/coco/annotations/stuff_val2017_pixelmaps" if
+ data_json.endswith("captions_val2017.json") else
+ "data/coco/annotations/stuff_train2017_pixelmaps")
+
+ imagedirs = self.json_data["images"]
+ self.labels = {"image_ids": list()}
+ for imgdir in tqdm(imagedirs, desc="ImgToPath"):
+ self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
+ self.img_id_to_captions[imgdir["id"]] = list()
+ pngfilename = imgdir["file_name"].replace("jpg", "png")
+ self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
+ self.segmentation_prefix, pngfilename)
+ if given_files is not None:
+ if pngfilename in given_files:
+ self.labels["image_ids"].append(imgdir["id"])
+ else:
+ self.labels["image_ids"].append(imgdir["id"])
+
+ capdirs = self.json_data["annotations"]
+ for capdir in tqdm(capdirs, desc="ImgToCaptions"):
+ # there are in average 5 captions per image
+ self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
+
+ self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
+ if self.split=="validation":
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
+ self.preprocessor = albumentations.Compose(
+ [self.rescaler, self.cropper],
+ additional_targets={"segmentation": "image"})
+ if force_no_crop:
+ self.rescaler = albumentations.Resize(height=self.size, width=self.size)
+ self.preprocessor = albumentations.Compose(
+ [self.rescaler],
+ additional_targets={"segmentation": "image"})
+
+ def __len__(self):
+ return len(self.labels["image_ids"])
+
+ def preprocess_image(self, image_path, segmentation_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+
+ segmentation = Image.open(segmentation_path)
+ if not self.onehot and not segmentation.mode == "RGB":
+ segmentation = segmentation.convert("RGB")
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.onehot:
+ assert self.stuffthing
+ # stored in caffe format: unlabeled==255. stuff and thing from
+ # 0-181. to be compatible with the labels in
+ # https://github.com/nightrome/cocostuff/blob/master/labels.txt
+ # we shift stuffthing one to the right and put unlabeled in zero
+ # as long as segmentation is uint8 shifting to right handles the
+ # latter too
+ assert segmentation.dtype == np.uint8
+ segmentation = segmentation + 1
+
+ processed = self.preprocessor(image=image, segmentation=segmentation)
+ image, segmentation = processed["image"], processed["segmentation"]
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ if self.onehot:
+ assert segmentation.dtype == np.uint8
+ # make it one hot
+ n_labels = 183
+ flatseg = np.ravel(segmentation)
+ onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool)
+ onehot[np.arange(flatseg.size), flatseg] = True
+ onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
+ segmentation = onehot
+ else:
+ segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
+ return image, segmentation
+
+ def __getitem__(self, i):
+ img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
+ seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
+ image, segmentation = self.preprocess_image(img_path, seg_path)
+ captions = self.img_id_to_captions[self.labels["image_ids"][i]]
+ # randomly draw one of all available captions per image
+ caption = captions[np.random.randint(0, len(captions))]
+ example = {"image": image,
+ "caption": [str(caption[0])],
+ "segmentation": segmentation,
+ "img_path": img_path,
+ "seg_path": seg_path,
+ "filename_": img_path.split(os.sep)[-1]
+ }
+ return example
+
+
+class CocoImagesAndCaptionsTrain(CocoBase):
+ """returns a pair of (image, caption)"""
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
+ super().__init__(size=size,
+ dataroot="data/coco/train2017",
+ datajson="data/coco/annotations/captions_train2017.json",
+ onehot_segmentation=onehot_segmentation,
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
+
+ def get_split(self):
+ return "train"
+
+
+class CocoImagesAndCaptionsValidation(CocoBase):
+ """returns a pair of (image, caption)"""
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
+ given_files=None):
+ super().__init__(size=size,
+ dataroot="data/coco/val2017",
+ datajson="data/coco/annotations/captions_val2017.json",
+ onehot_segmentation=onehot_segmentation,
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
+ given_files=given_files)
+
+ def get_split(self):
+ return "validation"
diff --git a/3DTopia/taming/data/conditional_builder/objects_bbox.py b/3DTopia/taming/data/conditional_builder/objects_bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..15881e76b7ab2a914df8f2dfe08ae4f0c6c511b5
--- /dev/null
+++ b/3DTopia/taming/data/conditional_builder/objects_bbox.py
@@ -0,0 +1,60 @@
+from itertools import cycle
+from typing import List, Tuple, Callable, Optional
+
+from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
+from more_itertools.recipes import grouper
+from taming.data.image_transforms import convert_pil_to_tensor
+from torch import LongTensor, Tensor
+
+from taming.data.helper_types import BoundingBox, Annotation
+from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
+ pad_list, get_plot_font_size, absolute_bbox
+
+
+class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
+ @property
+ def object_descriptor_length(self) -> int:
+ return 3
+
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
+ object_triples = [
+ (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
+ for ann in annotations
+ ]
+ empty_triple = (self.none, self.none, self.none)
+ object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
+ return object_triples
+
+ def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
+ conditional_list = conditional.tolist()
+ crop_coordinates = None
+ if self.encode_crop:
+ crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
+ conditional_list = conditional_list[:-2]
+ object_triples = grouper(conditional_list, 3)
+ assert conditional.shape[0] == self.embedding_dim
+ return [
+ (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
+ for object_triple in object_triples if object_triple[0] != self.none
+ ], crop_coordinates
+
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
+ plot = pil_image.new('RGB', figure_size, WHITE)
+ draw = pil_img_draw.Draw(plot)
+ font = ImageFont.truetype(
+ "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
+ size=get_plot_font_size(font_size, figure_size)
+ )
+ width, height = plot.size
+ description, crop_coordinates = self.inverse_build(conditional)
+ for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
+ annotation = self.representation_to_annotation(representation)
+ class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
+ bbox = absolute_bbox(bbox, width, height)
+ draw.rectangle(bbox, outline=color, width=line_width)
+ draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
+ if crop_coordinates is not None:
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
diff --git a/3DTopia/taming/data/conditional_builder/objects_center_points.py b/3DTopia/taming/data/conditional_builder/objects_center_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a480329cc47fb38a7b8729d424e092b77d40749
--- /dev/null
+++ b/3DTopia/taming/data/conditional_builder/objects_center_points.py
@@ -0,0 +1,168 @@
+import math
+import random
+import warnings
+from itertools import cycle
+from typing import List, Optional, Tuple, Callable
+
+from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
+from more_itertools.recipes import grouper
+from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
+ additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
+ absolute_bbox, rescale_annotations
+from taming.data.helper_types import BoundingBox, Annotation
+from taming.data.image_transforms import convert_pil_to_tensor
+from torch import LongTensor, Tensor
+
+
+class ObjectsCenterPointsConditionalBuilder:
+ def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
+ use_group_parameter: bool, use_additional_parameters: bool):
+ self.no_object_classes = no_object_classes
+ self.no_max_objects = no_max_objects
+ self.no_tokens = no_tokens
+ self.encode_crop = encode_crop
+ self.no_sections = int(math.sqrt(self.no_tokens))
+ self.use_group_parameter = use_group_parameter
+ self.use_additional_parameters = use_additional_parameters
+
+ @property
+ def none(self) -> int:
+ return self.no_tokens - 1
+
+ @property
+ def object_descriptor_length(self) -> int:
+ return 2
+
+ @property
+ def embedding_dim(self) -> int:
+ extra_length = 2 if self.encode_crop else 0
+ return self.no_max_objects * self.object_descriptor_length + extra_length
+
+ def tokenize_coordinates(self, x: float, y: float) -> int:
+ """
+ Express 2d coordinates with one number.
+ Example: assume self.no_tokens = 16, then no_sections = 4:
+ 0 0 0 0
+ 0 0 # 0
+ 0 0 0 0
+ 0 0 0 x
+ Then the # position corresponds to token 6, the x position to token 15.
+ @param x: float in [0, 1]
+ @param y: float in [0, 1]
+ @return: discrete tokenized coordinate
+ """
+ x_discrete = int(round(x * (self.no_sections - 1)))
+ y_discrete = int(round(y * (self.no_sections - 1)))
+ return y_discrete * self.no_sections + x_discrete
+
+ def coordinates_from_token(self, token: int) -> (float, float):
+ x = token % self.no_sections
+ y = token // self.no_sections
+ return x / (self.no_sections - 1), y / (self.no_sections - 1)
+
+ def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
+ x0, y0 = self.coordinates_from_token(token1)
+ x1, y1 = self.coordinates_from_token(token2)
+ return x0, y0, x1 - x0, y1 - y0
+
+ def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
+ return self.tokenize_coordinates(bbox[0], bbox[1]), \
+ self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])
+
+ def inverse_build(self, conditional: LongTensor) \
+ -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
+ conditional_list = conditional.tolist()
+ crop_coordinates = None
+ if self.encode_crop:
+ crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
+ conditional_list = conditional_list[:-2]
+ table_of_content = grouper(conditional_list, self.object_descriptor_length)
+ assert conditional.shape[0] == self.embedding_dim
+ return [
+ (object_tuple[0], self.coordinates_from_token(object_tuple[1]))
+ for object_tuple in table_of_content if object_tuple[0] != self.none
+ ], crop_coordinates
+
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
+ plot = pil_image.new('RGB', figure_size, WHITE)
+ draw = pil_img_draw.Draw(plot)
+ circle_size = get_circle_size(figure_size)
+ font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
+ size=get_plot_font_size(font_size, figure_size))
+ width, height = plot.size
+ description, crop_coordinates = self.inverse_build(conditional)
+ for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
+ x_abs, y_abs = x * width, y * height
+ ann = self.representation_to_annotation(representation)
+ label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
+ ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
+ draw.ellipse(ellipse_bbox, fill=color, width=0)
+ draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
+ if crop_coordinates is not None:
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
+
+ def object_representation(self, annotation: Annotation) -> int:
+ modifier = 0
+ if self.use_group_parameter:
+ modifier |= 1 * (annotation.is_group_of is True)
+ if self.use_additional_parameters:
+ modifier |= 2 * (annotation.is_occluded is True)
+ modifier |= 4 * (annotation.is_depiction is True)
+ modifier |= 8 * (annotation.is_inside is True)
+ return annotation.category_no + self.no_object_classes * modifier
+
+ def representation_to_annotation(self, representation: int) -> Annotation:
+ category_no = representation % self.no_object_classes
+ modifier = representation // self.no_object_classes
+ # noinspection PyTypeChecker
+ return Annotation(
+ area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
+ category_no=category_no,
+ is_group_of=bool((modifier & 1) * self.use_group_parameter),
+ is_occluded=bool((modifier & 2) * self.use_additional_parameters),
+ is_depiction=bool((modifier & 4) * self.use_additional_parameters),
+ is_inside=bool((modifier & 8) * self.use_additional_parameters)
+ )
+
+ def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
+ return list(self.token_pair_from_bbox(crop_coordinates))
+
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
+ object_tuples = [
+ (self.object_representation(a),
+ self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
+ for a in annotations
+ ]
+ empty_tuple = (self.none, self.none)
+ object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
+ return object_tuples
+
+ def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \
+ -> LongTensor:
+ if len(annotations) == 0:
+ warnings.warn('Did not receive any annotations.')
+ if len(annotations) > self.no_max_objects:
+ warnings.warn('Received more annotations than allowed.')
+ annotations = annotations[:self.no_max_objects]
+
+ if not crop_coordinates:
+ crop_coordinates = FULL_CROP
+
+ random.shuffle(annotations)
+ annotations = filter_annotations(annotations, crop_coordinates)
+ if self.encode_crop:
+ annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
+ if horizontal_flip:
+ crop_coordinates = horizontally_flip_bbox(crop_coordinates)
+ extra = self._crop_encoder(crop_coordinates)
+ else:
+ annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
+ extra = []
+
+ object_tuples = self._make_object_descriptors(annotations)
+ flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
+ assert len(flattened) == self.embedding_dim
+ assert all(0 <= value < self.no_tokens for value in flattened)
+ return LongTensor(flattened)
diff --git a/3DTopia/taming/data/conditional_builder/utils.py b/3DTopia/taming/data/conditional_builder/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0ee175f2e05a80dbc71c22acbecb22dddadbb42
--- /dev/null
+++ b/3DTopia/taming/data/conditional_builder/utils.py
@@ -0,0 +1,105 @@
+import importlib
+from typing import List, Any, Tuple, Optional
+
+from taming.data.helper_types import BoundingBox, Annotation
+
+# source: seaborn, color palette tab10
+COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
+ (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
+BLACK = (0, 0, 0)
+GRAY_75 = (63, 63, 63)
+GRAY_50 = (127, 127, 127)
+GRAY_25 = (191, 191, 191)
+WHITE = (255, 255, 255)
+FULL_CROP = (0., 0., 1., 1.)
+
+
+def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
+ """
+ Give intersection area of two rectangles.
+ @param rectangle1: (x0, y0, w, h) of first rectangle
+ @param rectangle2: (x0, y0, w, h) of second rectangle
+ """
+ rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
+ rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
+ x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
+ y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
+ return x_overlap * y_overlap
+
+
+def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
+ return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
+
+
+def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
+ bbox = relative_bbox
+ bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
+ return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
+
+
+def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
+ return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
+
+
+def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
+ List[Annotation]:
+ def clamp(x: float):
+ return max(min(x, 1.), 0.)
+
+ def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
+ if flip:
+ x0 = 1 - (x0 + w)
+ return x0, y0, w, h
+
+ return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
+
+
+def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
+ return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
+
+
+def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
+ sl = slice(1) if short else slice(None)
+ string = ''
+ if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
+ return string
+ if annotation.is_group_of:
+ string += 'group'[sl] + ','
+ if annotation.is_occluded:
+ string += 'occluded'[sl] + ','
+ if annotation.is_depiction:
+ string += 'depiction'[sl] + ','
+ if annotation.is_inside:
+ string += 'inside'[sl]
+ return '(' + string.strip(",") + ')'
+
+
+def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
+ if font_size is None:
+ font_size = 10
+ if max(figure_size) >= 256:
+ font_size = 12
+ if max(figure_size) >= 512:
+ font_size = 15
+ return font_size
+
+
+def get_circle_size(figure_size: Tuple[int, int]) -> int:
+ circle_size = 2
+ if max(figure_size) >= 256:
+ circle_size = 3
+ if max(figure_size) >= 512:
+ circle_size = 4
+ return circle_size
+
+
+def load_object_from_string(object_string: str) -> Any:
+ """
+ Source: https://stackoverflow.com/a/10773699
+ """
+ module_name, class_name = object_string.rsplit(".", 1)
+ return getattr(importlib.import_module(module_name), class_name)
diff --git a/3DTopia/taming/data/custom.py b/3DTopia/taming/data/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..33f302a4b55ba1e8ec282ec3292b6263c06dfb91
--- /dev/null
+++ b/3DTopia/taming/data/custom.py
@@ -0,0 +1,38 @@
+import os
+import numpy as np
+import albumentations
+from torch.utils.data import Dataset
+
+from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
+
+
+class CustomBase(Dataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.data = None
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ example = self.data[i]
+ return example
+
+
+
+class CustomTrain(CustomBase):
+ def __init__(self, size, training_images_list_file):
+ super().__init__()
+ with open(training_images_list_file, "r") as f:
+ paths = f.read().splitlines()
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+
+
+class CustomTest(CustomBase):
+ def __init__(self, size, test_images_list_file):
+ super().__init__()
+ with open(test_images_list_file, "r") as f:
+ paths = f.read().splitlines()
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+
+
diff --git a/3DTopia/taming/data/faceshq.py b/3DTopia/taming/data/faceshq.py
new file mode 100644
index 0000000000000000000000000000000000000000..6912d04b66a6d464c1078e4b51d5da290f5e767e
--- /dev/null
+++ b/3DTopia/taming/data/faceshq.py
@@ -0,0 +1,134 @@
+import os
+import numpy as np
+import albumentations
+from torch.utils.data import Dataset
+
+from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
+
+
+class FacesBase(Dataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.data = None
+ self.keys = None
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ example = self.data[i]
+ ex = {}
+ if self.keys is not None:
+ for k in self.keys:
+ ex[k] = example[k]
+ else:
+ ex = example
+ return ex
+
+
+class CelebAHQTrain(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/celebahq"
+ with open("data/celebahqtrain.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class CelebAHQValidation(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/celebahq"
+ with open("data/celebahqvalidation.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FFHQTrain(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/ffhq"
+ with open("data/ffhqtrain.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FFHQValidation(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/ffhq"
+ with open("data/ffhqvalidation.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FacesHQTrain(Dataset):
+ # CelebAHQ [0] + FFHQ [1]
+ def __init__(self, size, keys=None, crop_size=None, coord=False):
+ d1 = CelebAHQTrain(size=size, keys=keys)
+ d2 = FFHQTrain(size=size, keys=keys)
+ self.data = ConcatDatasetWithIndex([d1, d2])
+ self.coord = coord
+ if crop_size is not None:
+ self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
+ if self.coord:
+ self.cropper = albumentations.Compose([self.cropper],
+ additional_targets={"coord": "image"})
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ ex, y = self.data[i]
+ if hasattr(self, "cropper"):
+ if not self.coord:
+ out = self.cropper(image=ex["image"])
+ ex["image"] = out["image"]
+ else:
+ h,w,_ = ex["image"].shape
+ coord = np.arange(h*w).reshape(h,w,1)/(h*w)
+ out = self.cropper(image=ex["image"], coord=coord)
+ ex["image"] = out["image"]
+ ex["coord"] = out["coord"]
+ ex["class"] = y
+ return ex
+
+
+class FacesHQValidation(Dataset):
+ # CelebAHQ [0] + FFHQ [1]
+ def __init__(self, size, keys=None, crop_size=None, coord=False):
+ d1 = CelebAHQValidation(size=size, keys=keys)
+ d2 = FFHQValidation(size=size, keys=keys)
+ self.data = ConcatDatasetWithIndex([d1, d2])
+ self.coord = coord
+ if crop_size is not None:
+ self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
+ if self.coord:
+ self.cropper = albumentations.Compose([self.cropper],
+ additional_targets={"coord": "image"})
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ ex, y = self.data[i]
+ if hasattr(self, "cropper"):
+ if not self.coord:
+ out = self.cropper(image=ex["image"])
+ ex["image"] = out["image"]
+ else:
+ h,w,_ = ex["image"].shape
+ coord = np.arange(h*w).reshape(h,w,1)/(h*w)
+ out = self.cropper(image=ex["image"], coord=coord)
+ ex["image"] = out["image"]
+ ex["coord"] = out["coord"]
+ ex["class"] = y
+ return ex
diff --git a/3DTopia/taming/data/helper_types.py b/3DTopia/taming/data/helper_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb51e301da08602cfead5961c4f7e1d89f6aba79
--- /dev/null
+++ b/3DTopia/taming/data/helper_types.py
@@ -0,0 +1,49 @@
+from typing import Dict, Tuple, Optional, NamedTuple, Union
+from PIL.Image import Image as pil_image
+from torch import Tensor
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+Image = Union[Tensor, pil_image]
+BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h
+CropMethodType = Literal['none', 'random', 'center', 'random-2d']
+SplitType = Literal['train', 'validation', 'test']
+
+
+class ImageDescription(NamedTuple):
+ id: int
+ file_name: str
+ original_size: Tuple[int, int] # w, h
+ url: Optional[str] = None
+ license: Optional[int] = None
+ coco_url: Optional[str] = None
+ date_captured: Optional[str] = None
+ flickr_url: Optional[str] = None
+ flickr_id: Optional[str] = None
+ coco_id: Optional[str] = None
+
+
+class Category(NamedTuple):
+ id: str
+ super_category: Optional[str]
+ name: str
+
+
+class Annotation(NamedTuple):
+ area: float
+ image_id: str
+ bbox: BoundingBox
+ category_no: int
+ category_id: str
+ id: Optional[int] = None
+ source: Optional[str] = None
+ confidence: Optional[float] = None
+ is_group_of: Optional[bool] = None
+ is_truncated: Optional[bool] = None
+ is_occluded: Optional[bool] = None
+ is_depiction: Optional[bool] = None
+ is_inside: Optional[bool] = None
+ segmentation: Optional[Dict] = None
diff --git a/3DTopia/taming/data/image_transforms.py b/3DTopia/taming/data/image_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..657ac332174e0ac72f68315271ffbd757b771a0f
--- /dev/null
+++ b/3DTopia/taming/data/image_transforms.py
@@ -0,0 +1,132 @@
+import random
+import warnings
+from typing import Union
+
+import torch
+from torch import Tensor
+from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
+from torchvision.transforms.functional import _get_image_size as get_image_size
+
+from taming.data.helper_types import BoundingBox, Image
+
+pil_to_tensor = PILToTensor()
+
+
+def convert_pil_to_tensor(image: Image) -> Tensor:
+ with warnings.catch_warnings():
+ # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
+ warnings.simplefilter("ignore")
+ return pil_to_tensor(image)
+
+
+class RandomCrop1dReturnCoordinates(RandomCrop):
+ def forward(self, img: Image) -> (BoundingBox, Image):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+
+ Based on:
+ torchvision.transforms.RandomCrop, torchvision 1.7.0
+ """
+ if self.padding is not None:
+ img = F.pad(img, self.padding, self.fill, self.padding_mode)
+
+ width, height = get_image_size(img)
+ # pad the width if needed
+ if self.pad_if_needed and width < self.size[1]:
+ padding = [self.size[1] - width, 0]
+ img = F.pad(img, padding, self.fill, self.padding_mode)
+ # pad the height if needed
+ if self.pad_if_needed and height < self.size[0]:
+ padding = [0, self.size[0] - height]
+ img = F.pad(img, padding, self.fill, self.padding_mode)
+
+ i, j, h, w = self.get_params(img, self.size)
+ bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h
+ return bbox, F.crop(img, i, j, h, w)
+
+
+class Random2dCropReturnCoordinates(torch.nn.Module):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+
+ Based on:
+ torchvision.transforms.RandomCrop, torchvision 1.7.0
+ """
+
+ def __init__(self, min_size: int):
+ super().__init__()
+ self.min_size = min_size
+
+ def forward(self, img: Image) -> (BoundingBox, Image):
+ width, height = get_image_size(img)
+ max_size = min(width, height)
+ if max_size <= self.min_size:
+ size = max_size
+ else:
+ size = random.randint(self.min_size, max_size)
+ top = random.randint(0, height - size)
+ left = random.randint(0, width - size)
+ bbox = left / width, top / height, size / width, size / height
+ return bbox, F.crop(img, top, left, size, size)
+
+
+class CenterCropReturnCoordinates(CenterCrop):
+ @staticmethod
+ def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
+ if width > height:
+ w = height / width
+ h = 1.0
+ x0 = 0.5 - w / 2
+ y0 = 0.
+ else:
+ w = 1.0
+ h = width / height
+ x0 = 0.
+ y0 = 0.5 - h / 2
+ return x0, y0, w, h
+
+ def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+ Based on:
+ torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
+ """
+ width, height = get_image_size(img)
+ return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)
+
+
+class RandomHorizontalFlipReturn(RandomHorizontalFlip):
+ def forward(self, img: Image) -> (bool, Image):
+ """
+ Additionally to flipping, returns a boolean whether it was flipped or not.
+ Args:
+ img (PIL Image or Tensor): Image to be flipped.
+
+ Returns:
+ flipped: whether the image was flipped or not
+ PIL Image or Tensor: Randomly flipped image.
+
+ Based on:
+ torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
+ """
+ if torch.rand(1) < self.p:
+ return True, F.hflip(img)
+ return False, img
diff --git a/3DTopia/taming/data/imagenet.py b/3DTopia/taming/data/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a02ec44ba4af9e993f58c91fa43482a4ecbe54c
--- /dev/null
+++ b/3DTopia/taming/data/imagenet.py
@@ -0,0 +1,558 @@
+import os, tarfile, glob, shutil
+import yaml
+import numpy as np
+from tqdm import tqdm
+from PIL import Image
+import albumentations
+from omegaconf import OmegaConf
+from torch.utils.data import Dataset
+
+from taming.data.base import ImagePaths
+from taming.util import download, retrieve
+import taming.data.utils as bdu
+
+
+def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
+ synsets = []
+ with open(path_to_yaml) as f:
+ di2s = yaml.load(f)
+ for idx in indices:
+ synsets.append(str(di2s[idx]))
+ print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets)))
+ return synsets
+
+
+def str_to_indices(string):
+ """Expects a string in the format '32-123, 256, 280-321'"""
+ assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string)
+ subs = string.split(",")
+ indices = []
+ for sub in subs:
+ subsubs = sub.split("-")
+ assert len(subsubs) > 0
+ if len(subsubs) == 1:
+ indices.append(int(subsubs[0]))
+ else:
+ rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))]
+ indices.extend(rang)
+ return sorted(indices)
+
+
+class ImageNetBase(Dataset):
+ def __init__(self, config=None):
+ self.config = config or OmegaConf.create()
+ if not type(self.config)==dict:
+ self.config = OmegaConf.to_container(self.config)
+ self._prepare()
+ self._prepare_synset_to_human()
+ self._prepare_idx_to_synset()
+ self._load()
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ return self.data[i]
+
+ def _prepare(self):
+ raise NotImplementedError()
+
+ def _filter_relpaths(self, relpaths):
+ ignore = set([
+ "n06596364_9591.JPEG",
+ ])
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
+ if "sub_indices" in self.config:
+ indices = str_to_indices(self.config["sub_indices"])
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
+ files = []
+ for rpath in relpaths:
+ syn = rpath.split("/")[0]
+ if syn in synsets:
+ files.append(rpath)
+ return files
+ else:
+ return relpaths
+
+ def _prepare_synset_to_human(self):
+ SIZE = 2655750
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
+ if (not os.path.exists(self.human_dict) or
+ not os.path.getsize(self.human_dict)==SIZE):
+ download(URL, self.human_dict)
+
+ def _prepare_idx_to_synset(self):
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
+ if (not os.path.exists(self.idx2syn)):
+ download(URL, self.idx2syn)
+
+ def _load(self):
+ with open(self.txt_filelist, "r") as f:
+ self.relpaths = f.read().splitlines()
+ l1 = len(self.relpaths)
+ self.relpaths = self._filter_relpaths(self.relpaths)
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
+
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
+
+ unique_synsets = np.unique(self.synsets)
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
+ self.class_labels = [class_dict[s] for s in self.synsets]
+
+ with open(self.human_dict, "r") as f:
+ human_dict = f.read().splitlines()
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
+
+ self.human_labels = [human_dict[s] for s in self.synsets]
+
+ labels = {
+ "relpath": np.array(self.relpaths),
+ "synsets": np.array(self.synsets),
+ "class_label": np.array(self.class_labels),
+ "human_label": np.array(self.human_labels),
+ }
+ self.data = ImagePaths(self.abspaths,
+ labels=labels,
+ size=retrieve(self.config, "size", default=0),
+ random_crop=self.random_crop)
+
+
+class ImageNetTrain(ImageNetBase):
+ NAME = "ILSVRC2012_train"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
+ FILES = [
+ "ILSVRC2012_img_train.tar",
+ ]
+ SIZES = [
+ 147897477120,
+ ]
+
+ def _prepare(self):
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
+ default=True)
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 1281167
+ if not bdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ print("Extracting sub-tars.")
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
+ for subpath in tqdm(subpaths):
+ subdir = subpath[:-len(".tar")]
+ os.makedirs(subdir, exist_ok=True)
+ with tarfile.open(subpath, "r:") as tar:
+ tar.extractall(path=subdir)
+
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ bdu.mark_prepared(self.root)
+
+
+class ImageNetValidation(ImageNetBase):
+ NAME = "ILSVRC2012_validation"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
+ FILES = [
+ "ILSVRC2012_img_val.tar",
+ "validation_synset.txt",
+ ]
+ SIZES = [
+ 6744924160,
+ 1950000,
+ ]
+
+ def _prepare(self):
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
+ default=False)
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 50000
+ if not bdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ vspath = os.path.join(self.root, self.FILES[1])
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
+ download(self.VS_URL, vspath)
+
+ with open(vspath, "r") as f:
+ synset_dict = f.read().splitlines()
+ synset_dict = dict(line.split() for line in synset_dict)
+
+ print("Reorganizing into synset folders")
+ synsets = np.unique(list(synset_dict.values()))
+ for s in synsets:
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
+ for k, v in synset_dict.items():
+ src = os.path.join(datadir, k)
+ dst = os.path.join(datadir, v)
+ shutil.move(src, dst)
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ bdu.mark_prepared(self.root)
+
+
+def get_preprocessor(size=None, random_crop=False, additional_targets=None,
+ crop_size=None):
+ if size is not None and size > 0:
+ transforms = list()
+ rescaler = albumentations.SmallestMaxSize(max_size = size)
+ transforms.append(rescaler)
+ if not random_crop:
+ cropper = albumentations.CenterCrop(height=size,width=size)
+ transforms.append(cropper)
+ else:
+ cropper = albumentations.RandomCrop(height=size,width=size)
+ transforms.append(cropper)
+ flipper = albumentations.HorizontalFlip()
+ transforms.append(flipper)
+ preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ elif crop_size is not None and crop_size > 0:
+ if not random_crop:
+ cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
+ else:
+ cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
+ transforms = [cropper]
+ preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ else:
+ preprocessor = lambda **kwargs: kwargs
+ return preprocessor
+
+
+def rgba_to_depth(x):
+ assert x.dtype == np.uint8
+ assert len(x.shape) == 3 and x.shape[2] == 4
+ y = x.copy()
+ y.dtype = np.float32
+ y = y.reshape(x.shape[:2])
+ return np.ascontiguousarray(y)
+
+
+class BaseWithDepth(Dataset):
+ DEFAULT_DEPTH_ROOT="data/imagenet_depth"
+
+ def __init__(self, config=None, size=None, random_crop=False,
+ crop_size=None, root=None):
+ self.config = config
+ self.base_dset = self.get_base_dset()
+ self.preprocessor = get_preprocessor(
+ size=size,
+ crop_size=crop_size,
+ random_crop=random_crop,
+ additional_targets={"depth": "image"})
+ self.crop_size = crop_size
+ if self.crop_size is not None:
+ self.rescaler = albumentations.Compose(
+ [albumentations.SmallestMaxSize(max_size = self.crop_size)],
+ additional_targets={"depth": "image"})
+ if root is not None:
+ self.DEFAULT_DEPTH_ROOT = root
+
+ def __len__(self):
+ return len(self.base_dset)
+
+ def preprocess_depth(self, path):
+ rgba = np.array(Image.open(path))
+ depth = rgba_to_depth(rgba)
+ depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
+ depth = 2.0*depth-1.0
+ return depth
+
+ def __getitem__(self, i):
+ e = self.base_dset[i]
+ e["depth"] = self.preprocess_depth(self.get_depth_path(e))
+ # up if necessary
+ h,w,c = e["image"].shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ out = self.rescaler(image=e["image"], depth=e["depth"])
+ e["image"] = out["image"]
+ e["depth"] = out["depth"]
+ transformed = self.preprocessor(image=e["image"], depth=e["depth"])
+ e["image"] = transformed["image"]
+ e["depth"] = transformed["depth"]
+ return e
+
+
+class ImageNetTrainWithDepth(BaseWithDepth):
+ # default to random_crop=True
+ def __init__(self, random_crop=True, sub_indices=None, **kwargs):
+ self.sub_indices = sub_indices
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base_dset(self):
+ if self.sub_indices is None:
+ return ImageNetTrain()
+ else:
+ return ImageNetTrain({"sub_indices": self.sub_indices})
+
+ def get_depth_path(self, e):
+ fid = os.path.splitext(e["relpath"])[0]+".png"
+ fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid)
+ return fid
+
+
+class ImageNetValidationWithDepth(BaseWithDepth):
+ def __init__(self, sub_indices=None, **kwargs):
+ self.sub_indices = sub_indices
+ super().__init__(**kwargs)
+
+ def get_base_dset(self):
+ if self.sub_indices is None:
+ return ImageNetValidation()
+ else:
+ return ImageNetValidation({"sub_indices": self.sub_indices})
+
+ def get_depth_path(self, e):
+ fid = os.path.splitext(e["relpath"])[0]+".png"
+ fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid)
+ return fid
+
+
+class RINTrainWithDepth(ImageNetTrainWithDepth):
+ def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
+ sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ sub_indices=sub_indices, crop_size=crop_size)
+
+
+class RINValidationWithDepth(ImageNetValidationWithDepth):
+ def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
+ sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ sub_indices=sub_indices, crop_size=crop_size)
+
+
+class DRINExamples(Dataset):
+ def __init__(self):
+ self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
+ with open("data/drin_examples.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ self.image_paths = [os.path.join("data/drin_images",
+ relpath) for relpath in relpaths]
+ self.depth_paths = [os.path.join("data/drin_depth",
+ relpath.replace(".JPEG", ".png")) for relpath in relpaths]
+
+ def __len__(self):
+ return len(self.image_paths)
+
+ def preprocess_image(self, image_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
+
+ def preprocess_depth(self, path):
+ rgba = np.array(Image.open(path))
+ depth = rgba_to_depth(rgba)
+ depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
+ depth = 2.0*depth-1.0
+ return depth
+
+ def __getitem__(self, i):
+ e = dict()
+ e["image"] = self.preprocess_image(self.image_paths[i])
+ e["depth"] = self.preprocess_depth(self.depth_paths[i])
+ transformed = self.preprocessor(image=e["image"], depth=e["depth"])
+ e["image"] = transformed["image"]
+ e["depth"] = transformed["depth"]
+ return e
+
+
+def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
+ if factor is None or factor==1:
+ return x
+
+ dtype = x.dtype
+ assert dtype in [np.float32, np.float64]
+ assert x.min() >= -1
+ assert x.max() <= 1
+
+ keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR,
+ "bicubic": Image.BICUBIC}[keepmode]
+
+ lr = (x+1.0)*127.5
+ lr = lr.clip(0,255).astype(np.uint8)
+ lr = Image.fromarray(lr)
+
+ h, w, _ = x.shape
+ nh = h//factor
+ nw = w//factor
+ assert nh > 0 and nw > 0, (nh, nw)
+
+ lr = lr.resize((nw,nh), Image.BICUBIC)
+ if keepshapes:
+ lr = lr.resize((w,h), keepmode)
+ lr = np.array(lr)/127.5-1.0
+ lr = lr.astype(dtype)
+
+ return lr
+
+
+class ImageNetScale(Dataset):
+ def __init__(self, size=None, crop_size=None, random_crop=False,
+ up_factor=None, hr_factor=None, keep_mode="bicubic"):
+ self.base = self.get_base()
+
+ self.size = size
+ self.crop_size = crop_size if crop_size is not None else self.size
+ self.random_crop = random_crop
+ self.up_factor = up_factor
+ self.hr_factor = hr_factor
+ self.keep_mode = keep_mode
+
+ transforms = list()
+
+ if self.size is not None and self.size > 0:
+ rescaler = albumentations.SmallestMaxSize(max_size = self.size)
+ self.rescaler = rescaler
+ transforms.append(rescaler)
+
+ if self.crop_size is not None and self.crop_size > 0:
+ if len(transforms) == 0:
+ self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size)
+
+ if not self.random_crop:
+ cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size)
+ else:
+ cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size)
+ transforms.append(cropper)
+
+ if len(transforms) > 0:
+ if self.up_factor is not None:
+ additional_targets = {"lr": "image"}
+ else:
+ additional_targets = None
+ self.preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ else:
+ self.preprocessor = lambda **kwargs: kwargs
+
+ def __len__(self):
+ return len(self.base)
+
+ def __getitem__(self, i):
+ example = self.base[i]
+ image = example["image"]
+ # adjust resolution
+ image = imscale(image, self.hr_factor, keepshapes=False)
+ h,w,c = image.shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ image = self.rescaler(image=image)["image"]
+ if self.up_factor is None:
+ image = self.preprocessor(image=image)["image"]
+ example["image"] = image
+ else:
+ lr = imscale(image, self.up_factor, keepshapes=True,
+ keepmode=self.keep_mode)
+
+ out = self.preprocessor(image=image, lr=lr)
+ example["image"] = out["image"]
+ example["lr"] = out["lr"]
+
+ return example
+
+class ImageNetScaleTrain(ImageNetScale):
+ def __init__(self, random_crop=True, **kwargs):
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base(self):
+ return ImageNetTrain()
+
+class ImageNetScaleValidation(ImageNetScale):
+ def get_base(self):
+ return ImageNetValidation()
+
+
+from skimage.feature import canny
+from skimage.color import rgb2gray
+
+
+class ImageNetEdges(ImageNetScale):
+ def __init__(self, up_factor=1, **kwargs):
+ super().__init__(up_factor=1, **kwargs)
+
+ def __getitem__(self, i):
+ example = self.base[i]
+ image = example["image"]
+ h,w,c = image.shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ image = self.rescaler(image=image)["image"]
+
+ lr = canny(rgb2gray(image), sigma=2)
+ lr = lr.astype(np.float32)
+ lr = lr[:,:,None][:,:,[0,0,0]]
+
+ out = self.preprocessor(image=image, lr=lr)
+ example["image"] = out["image"]
+ example["lr"] = out["lr"]
+
+ return example
+
+
+class ImageNetEdgesTrain(ImageNetEdges):
+ def __init__(self, random_crop=True, **kwargs):
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base(self):
+ return ImageNetTrain()
+
+class ImageNetEdgesValidation(ImageNetEdges):
+ def get_base(self):
+ return ImageNetValidation()
diff --git a/3DTopia/taming/data/open_images_helper.py b/3DTopia/taming/data/open_images_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..8feb7c6e705fc165d2983303192aaa88f579b243
--- /dev/null
+++ b/3DTopia/taming/data/open_images_helper.py
@@ -0,0 +1,379 @@
+open_images_unify_categories_for_coco = {
+ '/m/03bt1vf': '/m/01g317',
+ '/m/04yx4': '/m/01g317',
+ '/m/05r655': '/m/01g317',
+ '/m/01bl7v': '/m/01g317',
+ '/m/0cnyhnx': '/m/01xq0k1',
+ '/m/01226z': '/m/018xm',
+ '/m/05ctyq': '/m/018xm',
+ '/m/058qzx': '/m/04ctx',
+ '/m/06pcq': '/m/0l515',
+ '/m/03m3pdh': '/m/02crq1',
+ '/m/046dlr': '/m/01x3z',
+ '/m/0h8mzrc': '/m/01x3z',
+}
+
+
+top_300_classes_plus_coco_compatibility = [
+ ('Man', 1060962),
+ ('Clothing', 986610),
+ ('Tree', 748162),
+ ('Woman', 611896),
+ ('Person', 610294),
+ ('Human face', 442948),
+ ('Girl', 175399),
+ ('Building', 162147),
+ ('Car', 159135),
+ ('Plant', 155704),
+ ('Human body', 137073),
+ ('Flower', 133128),
+ ('Window', 127485),
+ ('Human arm', 118380),
+ ('House', 114365),
+ ('Wheel', 111684),
+ ('Suit', 99054),
+ ('Human hair', 98089),
+ ('Human head', 92763),
+ ('Chair', 88624),
+ ('Boy', 79849),
+ ('Table', 73699),
+ ('Jeans', 57200),
+ ('Tire', 55725),
+ ('Skyscraper', 53321),
+ ('Food', 52400),
+ ('Footwear', 50335),
+ ('Dress', 50236),
+ ('Human leg', 47124),
+ ('Toy', 46636),
+ ('Tower', 45605),
+ ('Boat', 43486),
+ ('Land vehicle', 40541),
+ ('Bicycle wheel', 34646),
+ ('Palm tree', 33729),
+ ('Fashion accessory', 32914),
+ ('Glasses', 31940),
+ ('Bicycle', 31409),
+ ('Furniture', 30656),
+ ('Sculpture', 29643),
+ ('Bottle', 27558),
+ ('Dog', 26980),
+ ('Snack', 26796),
+ ('Human hand', 26664),
+ ('Bird', 25791),
+ ('Book', 25415),
+ ('Guitar', 24386),
+ ('Jacket', 23998),
+ ('Poster', 22192),
+ ('Dessert', 21284),
+ ('Baked goods', 20657),
+ ('Drink', 19754),
+ ('Flag', 18588),
+ ('Houseplant', 18205),
+ ('Tableware', 17613),
+ ('Airplane', 17218),
+ ('Door', 17195),
+ ('Sports uniform', 17068),
+ ('Shelf', 16865),
+ ('Drum', 16612),
+ ('Vehicle', 16542),
+ ('Microphone', 15269),
+ ('Street light', 14957),
+ ('Cat', 14879),
+ ('Fruit', 13684),
+ ('Fast food', 13536),
+ ('Animal', 12932),
+ ('Vegetable', 12534),
+ ('Train', 12358),
+ ('Horse', 11948),
+ ('Flowerpot', 11728),
+ ('Motorcycle', 11621),
+ ('Fish', 11517),
+ ('Desk', 11405),
+ ('Helmet', 10996),
+ ('Truck', 10915),
+ ('Bus', 10695),
+ ('Hat', 10532),
+ ('Auto part', 10488),
+ ('Musical instrument', 10303),
+ ('Sunglasses', 10207),
+ ('Picture frame', 10096),
+ ('Sports equipment', 10015),
+ ('Shorts', 9999),
+ ('Wine glass', 9632),
+ ('Duck', 9242),
+ ('Wine', 9032),
+ ('Rose', 8781),
+ ('Tie', 8693),
+ ('Butterfly', 8436),
+ ('Beer', 7978),
+ ('Cabinetry', 7956),
+ ('Laptop', 7907),
+ ('Insect', 7497),
+ ('Goggles', 7363),
+ ('Shirt', 7098),
+ ('Dairy Product', 7021),
+ ('Marine invertebrates', 7014),
+ ('Cattle', 7006),
+ ('Trousers', 6903),
+ ('Van', 6843),
+ ('Billboard', 6777),
+ ('Balloon', 6367),
+ ('Human nose', 6103),
+ ('Tent', 6073),
+ ('Camera', 6014),
+ ('Doll', 6002),
+ ('Coat', 5951),
+ ('Mobile phone', 5758),
+ ('Swimwear', 5729),
+ ('Strawberry', 5691),
+ ('Stairs', 5643),
+ ('Goose', 5599),
+ ('Umbrella', 5536),
+ ('Cake', 5508),
+ ('Sun hat', 5475),
+ ('Bench', 5310),
+ ('Bookcase', 5163),
+ ('Bee', 5140),
+ ('Computer monitor', 5078),
+ ('Hiking equipment', 4983),
+ ('Office building', 4981),
+ ('Coffee cup', 4748),
+ ('Curtain', 4685),
+ ('Plate', 4651),
+ ('Box', 4621),
+ ('Tomato', 4595),
+ ('Coffee table', 4529),
+ ('Office supplies', 4473),
+ ('Maple', 4416),
+ ('Muffin', 4365),
+ ('Cocktail', 4234),
+ ('Castle', 4197),
+ ('Couch', 4134),
+ ('Pumpkin', 3983),
+ ('Computer keyboard', 3960),
+ ('Human mouth', 3926),
+ ('Christmas tree', 3893),
+ ('Mushroom', 3883),
+ ('Swimming pool', 3809),
+ ('Pastry', 3799),
+ ('Lavender (Plant)', 3769),
+ ('Football helmet', 3732),
+ ('Bread', 3648),
+ ('Traffic sign', 3628),
+ ('Common sunflower', 3597),
+ ('Television', 3550),
+ ('Bed', 3525),
+ ('Cookie', 3485),
+ ('Fountain', 3484),
+ ('Paddle', 3447),
+ ('Bicycle helmet', 3429),
+ ('Porch', 3420),
+ ('Deer', 3387),
+ ('Fedora', 3339),
+ ('Canoe', 3338),
+ ('Carnivore', 3266),
+ ('Bowl', 3202),
+ ('Human eye', 3166),
+ ('Ball', 3118),
+ ('Pillow', 3077),
+ ('Salad', 3061),
+ ('Beetle', 3060),
+ ('Orange', 3050),
+ ('Drawer', 2958),
+ ('Platter', 2937),
+ ('Elephant', 2921),
+ ('Seafood', 2921),
+ ('Monkey', 2915),
+ ('Countertop', 2879),
+ ('Watercraft', 2831),
+ ('Helicopter', 2805),
+ ('Kitchen appliance', 2797),
+ ('Personal flotation device', 2781),
+ ('Swan', 2739),
+ ('Lamp', 2711),
+ ('Boot', 2695),
+ ('Bronze sculpture', 2693),
+ ('Chicken', 2677),
+ ('Taxi', 2643),
+ ('Juice', 2615),
+ ('Cowboy hat', 2604),
+ ('Apple', 2600),
+ ('Tin can', 2590),
+ ('Necklace', 2564),
+ ('Ice cream', 2560),
+ ('Human beard', 2539),
+ ('Coin', 2536),
+ ('Candle', 2515),
+ ('Cart', 2512),
+ ('High heels', 2441),
+ ('Weapon', 2433),
+ ('Handbag', 2406),
+ ('Penguin', 2396),
+ ('Rifle', 2352),
+ ('Violin', 2336),
+ ('Skull', 2304),
+ ('Lantern', 2285),
+ ('Scarf', 2269),
+ ('Saucer', 2225),
+ ('Sheep', 2215),
+ ('Vase', 2189),
+ ('Lily', 2180),
+ ('Mug', 2154),
+ ('Parrot', 2140),
+ ('Human ear', 2137),
+ ('Sandal', 2115),
+ ('Lizard', 2100),
+ ('Kitchen & dining room table', 2063),
+ ('Spider', 1977),
+ ('Coffee', 1974),
+ ('Goat', 1926),
+ ('Squirrel', 1922),
+ ('Cello', 1913),
+ ('Sushi', 1881),
+ ('Tortoise', 1876),
+ ('Pizza', 1870),
+ ('Studio couch', 1864),
+ ('Barrel', 1862),
+ ('Cosmetics', 1841),
+ ('Moths and butterflies', 1841),
+ ('Convenience store', 1817),
+ ('Watch', 1792),
+ ('Home appliance', 1786),
+ ('Harbor seal', 1780),
+ ('Luggage and bags', 1756),
+ ('Vehicle registration plate', 1754),
+ ('Shrimp', 1751),
+ ('Jellyfish', 1730),
+ ('French fries', 1723),
+ ('Egg (Food)', 1698),
+ ('Football', 1697),
+ ('Musical keyboard', 1683),
+ ('Falcon', 1674),
+ ('Candy', 1660),
+ ('Medical equipment', 1654),
+ ('Eagle', 1651),
+ ('Dinosaur', 1634),
+ ('Surfboard', 1630),
+ ('Tank', 1628),
+ ('Grape', 1624),
+ ('Lion', 1624),
+ ('Owl', 1622),
+ ('Ski', 1613),
+ ('Waste container', 1606),
+ ('Frog', 1591),
+ ('Sparrow', 1585),
+ ('Rabbit', 1581),
+ ('Pen', 1546),
+ ('Sea lion', 1537),
+ ('Spoon', 1521),
+ ('Sink', 1512),
+ ('Teddy bear', 1507),
+ ('Bull', 1495),
+ ('Sofa bed', 1490),
+ ('Dragonfly', 1479),
+ ('Brassiere', 1478),
+ ('Chest of drawers', 1472),
+ ('Aircraft', 1466),
+ ('Human foot', 1463),
+ ('Pig', 1455),
+ ('Fork', 1454),
+ ('Antelope', 1438),
+ ('Tripod', 1427),
+ ('Tool', 1424),
+ ('Cheese', 1422),
+ ('Lemon', 1397),
+ ('Hamburger', 1393),
+ ('Dolphin', 1390),
+ ('Mirror', 1390),
+ ('Marine mammal', 1387),
+ ('Giraffe', 1385),
+ ('Snake', 1368),
+ ('Gondola', 1364),
+ ('Wheelchair', 1360),
+ ('Piano', 1358),
+ ('Cupboard', 1348),
+ ('Banana', 1345),
+ ('Trumpet', 1335),
+ ('Lighthouse', 1333),
+ ('Invertebrate', 1317),
+ ('Carrot', 1268),
+ ('Sock', 1260),
+ ('Tiger', 1241),
+ ('Camel', 1224),
+ ('Parachute', 1224),
+ ('Bathroom accessory', 1223),
+ ('Earrings', 1221),
+ ('Headphones', 1218),
+ ('Skirt', 1198),
+ ('Skateboard', 1190),
+ ('Sandwich', 1148),
+ ('Saxophone', 1141),
+ ('Goldfish', 1136),
+ ('Stool', 1104),
+ ('Traffic light', 1097),
+ ('Shellfish', 1081),
+ ('Backpack', 1079),
+ ('Sea turtle', 1078),
+ ('Cucumber', 1075),
+ ('Tea', 1051),
+ ('Toilet', 1047),
+ ('Roller skates', 1040),
+ ('Mule', 1039),
+ ('Bust', 1031),
+ ('Broccoli', 1030),
+ ('Crab', 1020),
+ ('Oyster', 1019),
+ ('Cannon', 1012),
+ ('Zebra', 1012),
+ ('French horn', 1008),
+ ('Grapefruit', 998),
+ ('Whiteboard', 997),
+ ('Zucchini', 997),
+ ('Crocodile', 992),
+
+ ('Clock', 960),
+ ('Wall clock', 958),
+
+ ('Doughnut', 869),
+ ('Snail', 868),
+
+ ('Baseball glove', 859),
+
+ ('Panda', 830),
+ ('Tennis racket', 830),
+
+ ('Pear', 652),
+
+ ('Bagel', 617),
+ ('Oven', 616),
+ ('Ladybug', 615),
+ ('Shark', 615),
+ ('Polar bear', 614),
+ ('Ostrich', 609),
+
+ ('Hot dog', 473),
+ ('Microwave oven', 467),
+ ('Fire hydrant', 20),
+ ('Stop sign', 20),
+ ('Parking meter', 20),
+ ('Bear', 20),
+ ('Flying disc', 20),
+ ('Snowboard', 20),
+ ('Tennis ball', 20),
+ ('Kite', 20),
+ ('Baseball bat', 20),
+ ('Kitchen knife', 20),
+ ('Knife', 20),
+ ('Submarine sandwich', 20),
+ ('Computer mouse', 20),
+ ('Remote control', 20),
+ ('Toaster', 20),
+ ('Sink', 20),
+ ('Refrigerator', 20),
+ ('Alarm clock', 20),
+ ('Wall clock', 20),
+ ('Scissors', 20),
+ ('Hair dryer', 20),
+ ('Toothbrush', 20),
+ ('Suitcase', 20)
+]
diff --git a/3DTopia/taming/data/sflckr.py b/3DTopia/taming/data/sflckr.py
new file mode 100644
index 0000000000000000000000000000000000000000..91101be5953b113f1e58376af637e43f366b3dee
--- /dev/null
+++ b/3DTopia/taming/data/sflckr.py
@@ -0,0 +1,91 @@
+import os
+import numpy as np
+import cv2
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset
+
+
+class SegmentationBase(Dataset):
+ def __init__(self,
+ data_csv, data_root, segmentation_root,
+ size=None, random_crop=False, interpolation="bicubic",
+ n_labels=182, shift_segmentation=False,
+ ):
+ self.n_labels = n_labels
+ self.shift_segmentation = shift_segmentation
+ self.data_csv = data_csv
+ self.data_root = data_root
+ self.segmentation_root = segmentation_root
+ with open(self.data_csv, "r") as f:
+ self.image_paths = f.read().splitlines()
+ self._length = len(self.image_paths)
+ self.labels = {
+ "relative_file_path_": [l for l in self.image_paths],
+ "file_path_": [os.path.join(self.data_root, l)
+ for l in self.image_paths],
+ "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
+ for l in self.image_paths]
+ }
+
+ size = None if size is not None and size<=0 else size
+ self.size = size
+ if self.size is not None:
+ self.interpolation = interpolation
+ self.interpolation = {
+ "nearest": cv2.INTER_NEAREST,
+ "bilinear": cv2.INTER_LINEAR,
+ "bicubic": cv2.INTER_CUBIC,
+ "area": cv2.INTER_AREA,
+ "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=self.interpolation)
+ self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=cv2.INTER_NEAREST)
+ self.center_crop = not random_crop
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
+ self.preprocessor = self.cropper
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = dict((k, self.labels[k][i]) for k in self.labels)
+ image = Image.open(example["file_path_"])
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ if self.size is not None:
+ image = self.image_rescaler(image=image)["image"]
+ segmentation = Image.open(example["segmentation_path_"])
+ assert segmentation.mode == "L", segmentation.mode
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.shift_segmentation:
+ # used to support segmentations containing unlabeled==255 label
+ segmentation = segmentation+1
+ if self.size is not None:
+ segmentation = self.segmentation_rescaler(image=segmentation)["image"]
+ if self.size is not None:
+ processed = self.preprocessor(image=image,
+ mask=segmentation
+ )
+ else:
+ processed = {"image": image,
+ "mask": segmentation
+ }
+ example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
+ segmentation = processed["mask"]
+ onehot = np.eye(self.n_labels)[segmentation]
+ example["segmentation"] = onehot
+ return example
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/sflckr_examples.txt",
+ data_root="data/sflckr_images",
+ segmentation_root="data/sflckr_segmentations",
+ size=size, random_crop=random_crop, interpolation=interpolation)
diff --git a/3DTopia/taming/data/utils.py b/3DTopia/taming/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b3c3d53cd2b6c72b481b59834cf809d3735b394
--- /dev/null
+++ b/3DTopia/taming/data/utils.py
@@ -0,0 +1,169 @@
+import collections
+import os
+import tarfile
+import urllib
+import zipfile
+from pathlib import Path
+
+import numpy as np
+import torch
+from taming.data.helper_types import Annotation
+from torch._six import string_classes
+from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
+from tqdm import tqdm
+
+
+def unpack(path):
+ if path.endswith("tar.gz"):
+ with tarfile.open(path, "r:gz") as tar:
+ tar.extractall(path=os.path.split(path)[0])
+ elif path.endswith("tar"):
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=os.path.split(path)[0])
+ elif path.endswith("zip"):
+ with zipfile.ZipFile(path, "r") as f:
+ f.extractall(path=os.path.split(path)[0])
+ else:
+ raise NotImplementedError(
+ "Unknown file extension: {}".format(os.path.splitext(path)[1])
+ )
+
+
+def reporthook(bar):
+ """tqdm progress bar for downloads."""
+
+ def hook(b=1, bsize=1, tsize=None):
+ if tsize is not None:
+ bar.total = tsize
+ bar.update(b * bsize - bar.n)
+
+ return hook
+
+
+def get_root(name):
+ base = "data/"
+ root = os.path.join(base, name)
+ os.makedirs(root, exist_ok=True)
+ return root
+
+
+def is_prepared(root):
+ return Path(root).joinpath(".ready").exists()
+
+
+def mark_prepared(root):
+ Path(root).joinpath(".ready").touch()
+
+
+def prompt_download(file_, source, target_dir, content_dir=None):
+ targetpath = os.path.join(target_dir, file_)
+ while not os.path.exists(targetpath):
+ if content_dir is not None and os.path.exists(
+ os.path.join(target_dir, content_dir)
+ ):
+ break
+ print(
+ "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
+ )
+ if content_dir is not None:
+ print(
+ "Or place its content into '{}'.".format(
+ os.path.join(target_dir, content_dir)
+ )
+ )
+ input("Press Enter when done...")
+ return targetpath
+
+
+def download_url(file_, url, target_dir):
+ targetpath = os.path.join(target_dir, file_)
+ os.makedirs(target_dir, exist_ok=True)
+ with tqdm(
+ unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
+ ) as bar:
+ urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
+ return targetpath
+
+
+def download_urls(urls, target_dir):
+ paths = dict()
+ for fname, url in urls.items():
+ outpath = download_url(fname, url, target_dir)
+ paths[fname] = outpath
+ return paths
+
+
+def quadratic_crop(x, bbox, alpha=1.0):
+ """bbox is xmin, ymin, xmax, ymax"""
+ im_h, im_w = x.shape[:2]
+ bbox = np.array(bbox, dtype=np.float32)
+ bbox = np.clip(bbox, 0, max(im_h, im_w))
+ center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
+ w = bbox[2] - bbox[0]
+ h = bbox[3] - bbox[1]
+ l = int(alpha * max(w, h))
+ l = max(l, 2)
+
+ required_padding = -1 * min(
+ center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
+ )
+ required_padding = int(np.ceil(required_padding))
+ if required_padding > 0:
+ padding = [
+ [required_padding, required_padding],
+ [required_padding, required_padding],
+ ]
+ padding += [[0, 0]] * (len(x.shape) - 2)
+ x = np.pad(x, padding, "reflect")
+ center = center[0] + required_padding, center[1] + required_padding
+ xmin = int(center[0] - l / 2)
+ ymin = int(center[1] - l / 2)
+ return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
+
+
+def custom_collate(batch):
+ r"""source: pytorch 1.9.0, only one modification to original code """
+
+ elem = batch[0]
+ elem_type = type(elem)
+ if isinstance(elem, torch.Tensor):
+ out = None
+ if torch.utils.data.get_worker_info() is not None:
+ # If we're in a background process, concatenate directly into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum([x.numel() for x in batch])
+ storage = elem.storage()._new_shared(numel)
+ out = elem.new(storage)
+ return torch.stack(batch, 0, out=out)
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+ and elem_type.__name__ != 'string_':
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
+ # array of string classes and object
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+ return custom_collate([torch.as_tensor(b) for b in batch])
+ elif elem.shape == (): # scalars
+ return torch.as_tensor(batch)
+ elif isinstance(elem, float):
+ return torch.tensor(batch, dtype=torch.float64)
+ elif isinstance(elem, int):
+ return torch.tensor(batch)
+ elif isinstance(elem, string_classes):
+ return batch
+ elif isinstance(elem, collections.abc.Mapping):
+ return {key: custom_collate([d[key] for d in batch]) for key in elem}
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
+ return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
+ if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
+ return batch # added
+ elif isinstance(elem, collections.abc.Sequence):
+ # check to make sure that the elements in batch have consistent size
+ it = iter(batch)
+ elem_size = len(next(it))
+ if not all(len(elem) == elem_size for elem in it):
+ raise RuntimeError('each element in list of batch should be of equal size')
+ transposed = zip(*batch)
+ return [custom_collate(samples) for samples in transposed]
+
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
diff --git a/3DTopia/taming/lr_scheduler.py b/3DTopia/taming/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..e598ed120159c53da6820a55ad86b89f5c70c82d
--- /dev/null
+++ b/3DTopia/taming/lr_scheduler.py
@@ -0,0 +1,34 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi))
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n):
+ return self.schedule(n)
+
diff --git a/3DTopia/taming/models/cond_transformer.py b/3DTopia/taming/models/cond_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4c63730fa86ac1b92b37af14c14fb696595b1ab
--- /dev/null
+++ b/3DTopia/taming/models/cond_transformer.py
@@ -0,0 +1,352 @@
+import os, math
+import torch
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from main import instantiate_from_config
+from taming.modules.util import SOSProvider
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class Net2NetTransformer(pl.LightningModule):
+ def __init__(self,
+ transformer_config,
+ first_stage_config,
+ cond_stage_config,
+ permuter_config=None,
+ ckpt_path=None,
+ ignore_keys=[],
+ first_stage_key="image",
+ cond_stage_key="depth",
+ downsample_cond_size=-1,
+ pkeep=1.0,
+ sos_token=0,
+ unconditional=False,
+ ):
+ super().__init__()
+ self.be_unconditional = unconditional
+ self.sos_token = sos_token
+ self.first_stage_key = first_stage_key
+ self.cond_stage_key = cond_stage_key
+ self.init_first_stage_from_ckpt(first_stage_config)
+ self.init_cond_stage_from_ckpt(cond_stage_config)
+ if permuter_config is None:
+ permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
+ self.permuter = instantiate_from_config(config=permuter_config)
+ self.transformer = instantiate_from_config(config=transformer_config)
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.downsample_cond_size = downsample_cond_size
+ self.pkeep = pkeep
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ for k in sd.keys():
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ self.print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def init_first_stage_from_ckpt(self, config):
+ model = instantiate_from_config(config)
+ model = model.eval()
+ model.train = disabled_train
+ self.first_stage_model = model
+
+ def init_cond_stage_from_ckpt(self, config):
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__" or self.be_unconditional:
+ print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
+ f"Prepending {self.sos_token} as a sos token.")
+ self.be_unconditional = True
+ self.cond_stage_key = self.first_stage_key
+ self.cond_stage_model = SOSProvider(self.sos_token)
+ else:
+ model = instantiate_from_config(config)
+ model = model.eval()
+ model.train = disabled_train
+ self.cond_stage_model = model
+
+ def forward(self, x, c):
+ # one step to produce the logits
+ _, z_indices = self.encode_to_z(x)
+ _, c_indices = self.encode_to_c(c)
+
+ if self.training and self.pkeep < 1.0:
+ mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
+ device=z_indices.device))
+ mask = mask.round().to(dtype=torch.int64)
+ r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
+ a_indices = mask*z_indices+(1-mask)*r_indices
+ else:
+ a_indices = z_indices
+
+ cz_indices = torch.cat((c_indices, a_indices), dim=1)
+
+ # target includes all sequence elements (no need to handle first one
+ # differently because we are conditioning)
+ target = z_indices
+ # make the prediction
+ logits, _ = self.transformer(cz_indices[:, :-1])
+ # cut off conditioning outputs - output i corresponds to p(z_i | z_{ -1:
+ c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
+ quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
+ if len(indices.shape) > 2:
+ indices = indices.view(c.shape[0], -1)
+ return quant_c, indices
+
+ @torch.no_grad()
+ def decode_to_img(self, index, zshape):
+ index = self.permuter(index, reverse=True)
+ bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
+ quant_z = self.first_stage_model.quantize.get_codebook_entry(
+ index.reshape(-1), shape=bhwc)
+ x = self.first_stage_model.decode(quant_z)
+ return x
+
+ @torch.no_grad()
+ def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
+ log = dict()
+
+ N = 4
+ if lr_interface:
+ x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
+ else:
+ x, c = self.get_xc(batch, N)
+ x = x.to(device=self.device)
+ c = c.to(device=self.device)
+
+ quant_z, z_indices = self.encode_to_z(x)
+ quant_c, c_indices = self.encode_to_c(c)
+
+ # create a "half"" sample
+ z_start_indices = z_indices[:,:z_indices.shape[1]//2]
+ index_sample = self.sample(z_start_indices, c_indices,
+ steps=z_indices.shape[1]-z_start_indices.shape[1],
+ temperature=temperature if temperature is not None else 1.0,
+ sample=True,
+ top_k=top_k if top_k is not None else 100,
+ callback=callback if callback is not None else lambda k: None)
+ x_sample = self.decode_to_img(index_sample, quant_z.shape)
+
+ # sample
+ z_start_indices = z_indices[:, :0]
+ index_sample = self.sample(z_start_indices, c_indices,
+ steps=z_indices.shape[1],
+ temperature=temperature if temperature is not None else 1.0,
+ sample=True,
+ top_k=top_k if top_k is not None else 100,
+ callback=callback if callback is not None else lambda k: None)
+ x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
+
+ # det sample
+ z_start_indices = z_indices[:, :0]
+ index_sample = self.sample(z_start_indices, c_indices,
+ steps=z_indices.shape[1],
+ sample=False,
+ callback=callback if callback is not None else lambda k: None)
+ x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
+
+ # reconstruction
+ x_rec = self.decode_to_img(z_indices, quant_z.shape)
+
+ log["inputs"] = x
+ log["reconstructions"] = x_rec
+
+ if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
+ figure_size = (x_rec.shape[2], x_rec.shape[3])
+ dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
+ label_for_category_no = dataset.get_textual_label_for_category_no
+ plotter = dataset.conditional_builders[self.cond_stage_key].plot
+ log["conditioning"] = torch.zeros_like(log["reconstructions"])
+ for i in range(quant_c.shape[0]):
+ log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
+ log["conditioning_rec"] = log["conditioning"]
+ elif self.cond_stage_key != "image":
+ cond_rec = self.cond_stage_model.decode(quant_c)
+ if self.cond_stage_key == "segmentation":
+ # get image from segmentation mask
+ num_classes = cond_rec.shape[1]
+
+ c = torch.argmax(c, dim=1, keepdim=True)
+ c = F.one_hot(c, num_classes=num_classes)
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
+ c = self.cond_stage_model.to_rgb(c)
+
+ cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
+ cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
+ cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
+ cond_rec = self.cond_stage_model.to_rgb(cond_rec)
+ log["conditioning_rec"] = cond_rec
+ log["conditioning"] = c
+
+ log["samples_half"] = x_sample
+ log["samples_nopix"] = x_sample_nopix
+ log["samples_det"] = x_sample_det
+ return log
+
+ def get_input(self, key, batch):
+ x = batch[key]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ if len(x.shape) == 4:
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+ if x.dtype == torch.double:
+ x = x.float()
+ return x
+
+ def get_xc(self, batch, N=None):
+ x = self.get_input(self.first_stage_key, batch)
+ c = self.get_input(self.cond_stage_key, batch)
+ if N is not None:
+ x = x[:N]
+ c = c[:N]
+ return x, c
+
+ def shared_step(self, batch, batch_idx):
+ x, c = self.get_xc(batch)
+ logits, target = self(x, c)
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
+ return loss
+
+ def training_step(self, batch, batch_idx):
+ loss = self.shared_step(batch, batch_idx)
+ self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ loss = self.shared_step(batch, batch_idx)
+ self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ return loss
+
+ def configure_optimizers(self):
+ """
+ Following minGPT:
+ This long function is unfortunately doing something very simple and is being very defensive:
+ We are separating out all parameters of the model into two buckets: those that will experience
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+ We are then returning the PyTorch optimizer object.
+ """
+ # separate out all parameters to those that will and won't experience regularizing weight decay
+ decay = set()
+ no_decay = set()
+ whitelist_weight_modules = (torch.nn.Linear, )
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+ for mn, m in self.transformer.named_modules():
+ for pn, p in m.named_parameters():
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+
+ if pn.endswith('bias'):
+ # all biases will not be decayed
+ no_decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+ # weights of whitelist modules will be weight decayed
+ decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+ # weights of blacklist modules will NOT be weight decayed
+ no_decay.add(fpn)
+
+ # special case the position embedding parameter in the root GPT module as not decayed
+ no_decay.add('pos_emb')
+
+ # validate that we considered every parameter
+ param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
+ inter_params = decay & no_decay
+ union_params = decay | no_decay
+ assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
+ assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+ % (str(param_dict.keys() - union_params), )
+
+ # create the pytorch optimizer object
+ optim_groups = [
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+ ]
+ optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
+ return optimizer
diff --git a/3DTopia/taming/models/dummy_cond_stage.py b/3DTopia/taming/models/dummy_cond_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e19938078752e09b926a3e749907ee99a258ca0
--- /dev/null
+++ b/3DTopia/taming/models/dummy_cond_stage.py
@@ -0,0 +1,22 @@
+from torch import Tensor
+
+
+class DummyCondStage:
+ def __init__(self, conditional_key):
+ self.conditional_key = conditional_key
+ self.train = None
+
+ def eval(self):
+ return self
+
+ @staticmethod
+ def encode(c: Tensor):
+ return c, None, (None, None, c)
+
+ @staticmethod
+ def decode(c: Tensor):
+ return c
+
+ @staticmethod
+ def to_rgb(c: Tensor):
+ return c
diff --git a/3DTopia/taming/models/vqgan.py b/3DTopia/taming/models/vqgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6950baa5f739111cd64c17235dca8be3a5f8037
--- /dev/null
+++ b/3DTopia/taming/models/vqgan.py
@@ -0,0 +1,404 @@
+import torch
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from main import instantiate_from_config
+
+from taming.modules.diffusionmodules.model import Encoder, Decoder
+from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+from taming.modules.vqvae.quantize import GumbelQuantize
+from taming.modules.vqvae.quantize import EMAVectorQuantizer
+
+class VQModel(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ remap=None,
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
+ ):
+ super().__init__()
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
+ remap=remap, sane_index_shape=sane_index_shape)
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.image_key = image_key
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, emb_loss, info = self.quantize(h)
+ return quant, emb_loss, info
+
+ def decode(self, quant):
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+ def decode_code(self, code_b):
+ quant_b = self.quantize.embed_code(code_b)
+ dec = self.decode(quant_b)
+ return dec
+
+ def forward(self, input):
+ quant, diff, _ = self.encode(input)
+ dec = self.decode(quant)
+ return dec, diff
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+ return x.float()
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+ rec_loss = log_dict_ae["val/rec_loss"]
+ self.log("val/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+ self.log("val/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ def log_images(self, batch, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ xrec, _ = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class VQSegmentationModel(VQModel):
+ def __init__(self, n_labels, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ return opt_ae
+
+ def training_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ total_loss = log_dict_ae["val/total_loss"]
+ self.log("val/total_loss", total_loss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+ return aeloss
+
+ @torch.no_grad()
+ def log_images(self, batch, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ xrec, _ = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ # convert logits to indices
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ return log
+
+
+class VQNoDiscModel(VQModel):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None
+ ):
+ super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
+ ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
+ colorize_nlabels=colorize_nlabels)
+
+ def training_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
+ output = pl.TrainResult(minimize=aeloss)
+ output.log("train/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return output
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
+ rec_loss = log_dict_ae["val/rec_loss"]
+ output = pl.EvalResult(checkpoint_on=rec_loss)
+ output.log("val/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ output.log("val/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ output.log_dict(log_dict_ae)
+
+ return output
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=self.learning_rate, betas=(0.5, 0.9))
+ return optimizer
+
+
+class GumbelVQ(VQModel):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ temperature_scheduler_config,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ kl_weight=1e-8,
+ remap=None,
+ ):
+
+ z_channels = ddconfig["z_channels"]
+ super().__init__(ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=ignore_keys,
+ image_key=image_key,
+ colorize_nlabels=colorize_nlabels,
+ monitor=monitor,
+ )
+
+ self.loss.n_classes = n_embed
+ self.vocab_size = n_embed
+
+ self.quantize = GumbelQuantize(z_channels, embed_dim,
+ n_embed=n_embed,
+ kl_weight=kl_weight, temp_init=1.0,
+ remap=remap)
+
+ self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def temperature_scheduling(self):
+ self.quantize.temperature = self.temperature_scheduler(self.global_step)
+
+ def encode_to_prequant(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode_code(self, code_b):
+ raise NotImplementedError
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ self.temperature_scheduling()
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x, return_pred_indices=True)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+ rec_loss = log_dict_ae["val/rec_loss"]
+ self.log("val/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log("val/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def log_images(self, batch, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ # encode
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, _, _ = self.quantize(h)
+ # decode
+ x_rec = self.decode(quant)
+ log["inputs"] = x
+ log["reconstructions"] = x_rec
+ return log
+
+
+class EMAVQ(VQModel):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ remap=None,
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
+ ):
+ super().__init__(ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=ignore_keys,
+ image_key=image_key,
+ colorize_nlabels=colorize_nlabels,
+ monitor=monitor,
+ )
+ self.quantize = EMAVectorQuantizer(n_embed=n_embed,
+ embedding_dim=embed_dim,
+ beta=0.25,
+ remap=remap)
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ #Remove self.quantize from parameter list since it is updated via EMA
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
\ No newline at end of file
diff --git a/3DTopia/taming/modules/diffusionmodules/model.py b/3DTopia/taming/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a5db6aa2ef915e270f1ae135e4a9918fdd884c
--- /dev/null
+++ b/3DTopia/taming/modules/diffusionmodules/model.py
@@ -0,0 +1,776 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+ def forward(self, x, t=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, **ignore_kwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+ def forward(self, x):
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class VUNet(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ in_channels, c_channels,
+ resolution, z_channels, use_timestep=False, **ignore_kwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(c_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ self.z_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+ def forward(self, x, z):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ z = self.z_in(z)
+ h = torch.cat((h,z),dim=1)
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
diff --git a/3DTopia/taming/modules/discriminator/model.py b/3DTopia/taming/modules/discriminator/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2aaa3110d0a7bcd05de7eca1e45101589ca5af05
--- /dev/null
+++ b/3DTopia/taming/modules/discriminator/model.py
@@ -0,0 +1,67 @@
+import functools
+import torch.nn as nn
+
+
+from taming.modules.util import ActNorm
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
+ elif classname.find('BatchNorm') != -1:
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
+ nn.init.constant_(m.bias.data, 0)
+
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator as in Pix2Pix
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+ """
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+ """Construct a PatchGAN discriminator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ norm_layer -- normalization layer
+ """
+ super(NLayerDiscriminator, self).__init__()
+ if not use_actnorm:
+ norm_layer = nn.BatchNorm2d
+ else:
+ norm_layer = ActNorm
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func != nn.BatchNorm2d
+ else:
+ use_bias = norm_layer != nn.BatchNorm2d
+
+ kw = 4
+ padw = 1
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n_layers, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ sequence += [
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
+ self.main = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.main(input)
diff --git a/3DTopia/taming/modules/losses/__init__.py b/3DTopia/taming/modules/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d09caf9eb805f849a517f1b23503e1a4d6ea1ec5
--- /dev/null
+++ b/3DTopia/taming/modules/losses/__init__.py
@@ -0,0 +1,2 @@
+from taming.modules.losses.vqperceptual import DummyLoss
+
diff --git a/3DTopia/taming/modules/losses/lpips.py b/3DTopia/taming/modules/losses/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7280447694ffc302a7636e7e4d6183408e0aa95
--- /dev/null
+++ b/3DTopia/taming/modules/losses/lpips.py
@@ -0,0 +1,123 @@
+"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
+
+import torch
+import torch.nn as nn
+from torchvision import models
+from collections import namedtuple
+
+from taming.util import get_ckpt_path
+
+
+class LPIPS(nn.Module):
+ # Learned perceptual metric
+ def __init__(self, use_dropout=True):
+ super().__init__()
+ self.scaling_layer = ScalingLayer()
+ self.chns = [64, 128, 256, 512, 512] # vg16 features
+ self.net = vgg16(pretrained=True, requires_grad=False)
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+ self.load_from_pretrained()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def load_from_pretrained(self, name="vgg_lpips"):
+ ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
+
+ @classmethod
+ def from_pretrained(cls, name="vgg_lpips"):
+ if name != "vgg_lpips":
+ raise NotImplementedError
+ model = cls()
+ ckpt = get_ckpt_path(name)
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ return model
+
+ def forward(self, input, target):
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
+ feats0, feats1, diffs = {}, {}, {}
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+ for kk in range(len(self.chns)):
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
+ val = res[0]
+ for l in range(1, len(self.chns)):
+ val += res[l]
+ return val
+
+
+class ScalingLayer(nn.Module):
+ def __init__(self):
+ super(ScalingLayer, self).__init__()
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+ def forward(self, inp):
+ return (inp - self.shift) / self.scale
+
+
+class NetLinLayer(nn.Module):
+ """ A single linear layer which does a 1x1 conv """
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
+ super(NetLinLayer, self).__init__()
+ layers = [nn.Dropout(), ] if (use_dropout) else []
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
+ self.model = nn.Sequential(*layers)
+
+
+class vgg16(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(vgg16, self).__init__()
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(23, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1_2 = h
+ h = self.slice2(h)
+ h_relu2_2 = h
+ h = self.slice3(h)
+ h_relu3_3 = h
+ h = self.slice4(h)
+ h_relu4_3 = h
+ h = self.slice5(h)
+ h_relu5_3 = h
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+ return out
+
+
+def normalize_tensor(x,eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
+ return x/(norm_factor+eps)
+
+
+def spatial_average(x, keepdim=True):
+ return x.mean([2,3],keepdim=keepdim)
+
diff --git a/3DTopia/taming/modules/losses/segmentation.py b/3DTopia/taming/modules/losses/segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ba77deb5159a6307ed2acba9945e4764a4ff0a5
--- /dev/null
+++ b/3DTopia/taming/modules/losses/segmentation.py
@@ -0,0 +1,22 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BCELoss(nn.Module):
+ def forward(self, prediction, target):
+ loss = F.binary_cross_entropy_with_logits(prediction,target)
+ return loss, {}
+
+
+class BCELossWithQuant(nn.Module):
+ def __init__(self, codebook_weight=1.):
+ super().__init__()
+ self.codebook_weight = codebook_weight
+
+ def forward(self, qloss, target, prediction, split):
+ bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
+ loss = bce_loss + self.codebook_weight*qloss
+ return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/bce_loss".format(split): bce_loss.detach().mean(),
+ "{}/quant_loss".format(split): qloss.detach().mean()
+ }
diff --git a/3DTopia/taming/modules/losses/vqperceptual.py b/3DTopia/taming/modules/losses/vqperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2febd445728479d4cd9aacdb2572cb1f1af04db
--- /dev/null
+++ b/3DTopia/taming/modules/losses/vqperceptual.py
@@ -0,0 +1,136 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from taming.modules.losses.lpips import LPIPS
+from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
+
+
+class DummyLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+
+def adopt_weight(weight, global_step, threshold=0, value=0.):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+def hinge_d_loss(logits_real, logits_fake):
+ loss_real = torch.mean(F.relu(1. - logits_real))
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def vanilla_d_loss(logits_real, logits_fake):
+ d_loss = 0.5 * (
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
+ return d_loss
+
+
+class VQLPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_ndf=64, disc_loss="hinge"):
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.codebook_weight = codebook_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm,
+ ndf=disc_ndf
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ if disc_loss == "hinge":
+ self.disc_loss = hinge_d_loss
+ elif disc_loss == "vanilla":
+ self.disc_loss = vanilla_d_loss
+ else:
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train"):
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ else:
+ p_loss = torch.tensor([0.0])
+
+ nll_loss = rec_loss
+ #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ nll_loss = torch.mean(nll_loss)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/p_loss".format(split): p_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
diff --git a/3DTopia/taming/modules/misc/coord.py b/3DTopia/taming/modules/misc/coord.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee69b0c897b6b382ae673622e420f55e494f5b09
--- /dev/null
+++ b/3DTopia/taming/modules/misc/coord.py
@@ -0,0 +1,31 @@
+import torch
+
+class CoordStage(object):
+ def __init__(self, n_embed, down_factor):
+ self.n_embed = n_embed
+ self.down_factor = down_factor
+
+ def eval(self):
+ return self
+
+ def encode(self, c):
+ """fake vqmodel interface"""
+ assert 0.0 <= c.min() and c.max() <= 1.0
+ b,ch,h,w = c.shape
+ assert ch == 1
+
+ c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
+ mode="area")
+ c = c.clamp(0.0, 1.0)
+ c = self.n_embed*c
+ c_quant = c.round()
+ c_ind = c_quant.to(dtype=torch.long)
+
+ info = None, None, c_ind
+ return c_quant, None, info
+
+ def decode(self, c):
+ c = c/self.n_embed
+ c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
+ mode="nearest")
+ return c
diff --git a/3DTopia/taming/modules/transformer/mingpt.py b/3DTopia/taming/modules/transformer/mingpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..d14b7b68117f4b9f297b2929397cd4f55089334c
--- /dev/null
+++ b/3DTopia/taming/modules/transformer/mingpt.py
@@ -0,0 +1,415 @@
+"""
+taken from: https://github.com/karpathy/minGPT/
+GPT model:
+- the initial stem consists of a combination of token encoding and a positional encoding
+- the meat of it is a uniform sequence of Transformer blocks
+ - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
+ - all blocks feed into a central residual pathway similar to resnets
+- the final decoder is a linear projection into a vanilla Softmax classifier
+"""
+
+import math
+import logging
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from transformers import top_k_top_p_filtering
+
+logger = logging.getLogger(__name__)
+
+
+class GPTConfig:
+ """ base GPT config, params common to all GPT versions """
+ embd_pdrop = 0.1
+ resid_pdrop = 0.1
+ attn_pdrop = 0.1
+
+ def __init__(self, vocab_size, block_size, **kwargs):
+ self.vocab_size = vocab_size
+ self.block_size = block_size
+ for k,v in kwargs.items():
+ setattr(self, k, v)
+
+
+class GPT1Config(GPTConfig):
+ """ GPT-1 like network roughly 125M params """
+ n_layer = 12
+ n_head = 12
+ n_embd = 768
+
+
+class CausalSelfAttention(nn.Module):
+ """
+ A vanilla multi-head masked self-attention layer with a projection at the end.
+ It is possible to use torch.nn.MultiheadAttention here but I am including an
+ explicit implementation here to show that there is nothing too scary here.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ assert config.n_embd % config.n_head == 0
+ # key, query, value projections for all heads
+ self.key = nn.Linear(config.n_embd, config.n_embd)
+ self.query = nn.Linear(config.n_embd, config.n_embd)
+ self.value = nn.Linear(config.n_embd, config.n_embd)
+ # regularization
+ self.attn_drop = nn.Dropout(config.attn_pdrop)
+ self.resid_drop = nn.Dropout(config.resid_pdrop)
+ # output projection
+ self.proj = nn.Linear(config.n_embd, config.n_embd)
+ # causal mask to ensure that attention is only applied to the left in the input sequence
+ mask = torch.tril(torch.ones(config.block_size,
+ config.block_size))
+ if hasattr(config, "n_unmasked"):
+ mask[:config.n_unmasked, :config.n_unmasked] = 1
+ self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
+ self.n_head = config.n_head
+
+ def forward(self, x, layer_past=None):
+ B, T, C = x.size()
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+ q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+ v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+ present = torch.stack((k, v))
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ k = torch.cat((past_key, k), dim=-2)
+ v = torch.cat((past_value, v), dim=-2)
+
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+ if layer_past is None:
+ att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
+
+ att = F.softmax(att, dim=-1)
+ att = self.attn_drop(att)
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+ # output projection
+ y = self.resid_drop(self.proj(y))
+ return y, present # TODO: check that this does not break anything
+
+
+class Block(nn.Module):
+ """ an unassuming Transformer block """
+ def __init__(self, config):
+ super().__init__()
+ self.ln1 = nn.LayerNorm(config.n_embd)
+ self.ln2 = nn.LayerNorm(config.n_embd)
+ self.attn = CausalSelfAttention(config)
+ self.mlp = nn.Sequential(
+ nn.Linear(config.n_embd, 4 * config.n_embd),
+ nn.GELU(), # nice
+ nn.Linear(4 * config.n_embd, config.n_embd),
+ nn.Dropout(config.resid_pdrop),
+ )
+
+ def forward(self, x, layer_past=None, return_present=False):
+ # TODO: check that training still works
+ if return_present: assert not self.training
+ # layer past: tuple of length two with B, nh, T, hs
+ attn, present = self.attn(self.ln1(x), layer_past=layer_past)
+
+ x = x + attn
+ x = x + self.mlp(self.ln2(x))
+ if layer_past is not None or return_present:
+ return x, present
+ return x
+
+
+class GPT(nn.Module):
+ """ the full GPT language model, with a context size of block_size """
+ def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
+ super().__init__()
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
+ n_unmasked=n_unmasked)
+ # input embedding stem
+ self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+ self.drop = nn.Dropout(config.embd_pdrop)
+ # transformer
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+ # decoder head
+ self.ln_f = nn.LayerNorm(config.n_embd)
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ self.block_size = config.block_size
+ self.apply(self._init_weights)
+ self.config = config
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
+
+ def get_block_size(self):
+ return self.block_size
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, idx, embeddings=None, targets=None):
+ # forward the GPT model
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
+
+ if embeddings is not None: # prepend explicit embeddings
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
+
+ t = token_embeddings.shape[1]
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
+ x = self.drop(token_embeddings + position_embeddings)
+ x = self.blocks(x)
+ x = self.ln_f(x)
+ logits = self.head(x)
+
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if targets is not None:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+ return logits, loss
+
+ def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
+ # inference only
+ assert not self.training
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
+ if embeddings is not None: # prepend explicit embeddings
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
+
+ if past is not None:
+ assert past_length is not None
+ past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
+ past_shape = list(past.shape)
+ expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
+ assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
+ position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
+ else:
+ position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
+
+ x = self.drop(token_embeddings + position_embeddings)
+ presents = [] # accumulate over layers
+ for i, block in enumerate(self.blocks):
+ x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
+ presents.append(present)
+
+ x = self.ln_f(x)
+ logits = self.head(x)
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if targets is not None:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+ return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
+
+
+class DummyGPT(nn.Module):
+ # for debugging
+ def __init__(self, add_value=1):
+ super().__init__()
+ self.add_value = add_value
+
+ def forward(self, idx):
+ return idx + self.add_value, None
+
+
+class CodeGPT(nn.Module):
+ """Takes in semi-embeddings"""
+ def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
+ super().__init__()
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
+ n_unmasked=n_unmasked)
+ # input embedding stem
+ self.tok_emb = nn.Linear(in_channels, config.n_embd)
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+ self.drop = nn.Dropout(config.embd_pdrop)
+ # transformer
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+ # decoder head
+ self.ln_f = nn.LayerNorm(config.n_embd)
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ self.block_size = config.block_size
+ self.apply(self._init_weights)
+ self.config = config
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
+
+ def get_block_size(self):
+ return self.block_size
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, idx, embeddings=None, targets=None):
+ # forward the GPT model
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
+
+ if embeddings is not None: # prepend explicit embeddings
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
+
+ t = token_embeddings.shape[1]
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
+ x = self.drop(token_embeddings + position_embeddings)
+ x = self.blocks(x)
+ x = self.taming_cinln_f(x)
+ logits = self.head(x)
+
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if targets is not None:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+ return logits, loss
+
+
+
+#### sampling utils
+
+def top_k_logits(logits, k):
+ v, ix = torch.topk(logits, k)
+ out = logits.clone()
+ out[out < v[:, [-1]]] = -float('Inf')
+ return out
+
+@torch.no_grad()
+def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
+ """
+ take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
+ the sequence, feeding the predictions back into the model each time. Clearly the sampling
+ has quadratic complexity unlike an RNN that is only linear, and has a finite context window
+ of block_size, unlike an RNN that has an infinite context window.
+ """
+ block_size = model.get_block_size()
+ model.eval()
+ for k in range(steps):
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
+ logits, _ = model(x_cond)
+ # pluck the logits at the final step and scale by temperature
+ logits = logits[:, -1, :] / temperature
+ # optionally crop probabilities to only the top k options
+ if top_k is not None:
+ logits = top_k_logits(logits, top_k)
+ # apply softmax to convert to probabilities
+ probs = F.softmax(logits, dim=-1)
+ # sample from the distribution or take the most likely
+ if sample:
+ ix = torch.multinomial(probs, num_samples=1)
+ else:
+ _, ix = torch.topk(probs, k=1, dim=-1)
+ # append to the sequence and continue
+ x = torch.cat((x, ix), dim=1)
+
+ return x
+
+
+@torch.no_grad()
+def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
+ top_k=None, top_p=None, callback=None):
+ # x is conditioning
+ sample = x
+ cond_len = x.shape[1]
+ past = None
+ for n in range(steps):
+ if callback is not None:
+ callback(n)
+ logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
+ if past is None:
+ past = [present]
+ else:
+ past.append(present)
+ logits = logits[:, -1, :] / temperature
+ if top_k is not None:
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+
+ probs = F.softmax(logits, dim=-1)
+ if not sample_logits:
+ _, x = torch.topk(probs, k=1, dim=-1)
+ else:
+ x = torch.multinomial(probs, num_samples=1)
+ # append to the sequence and continue
+ sample = torch.cat((sample, x), dim=1)
+ del past
+ sample = sample[:, cond_len:] # cut conditioning off
+ return sample
+
+
+#### clustering utils
+
+class KMeans(nn.Module):
+ def __init__(self, ncluster=512, nc=3, niter=10):
+ super().__init__()
+ self.ncluster = ncluster
+ self.nc = nc
+ self.niter = niter
+ self.shape = (3,32,32)
+ self.register_buffer("C", torch.zeros(self.ncluster,nc))
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
+
+ def is_initialized(self):
+ return self.initialized.item() == 1
+
+ @torch.no_grad()
+ def initialize(self, x):
+ N, D = x.shape
+ assert D == self.nc, D
+ c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
+ for i in range(self.niter):
+ # assign all pixels to the closest codebook element
+ a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
+ # move each codebook element to be the mean of the pixels that assigned to it
+ c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
+ # re-assign any poorly positioned codebook elements
+ nanix = torch.any(torch.isnan(c), dim=1)
+ ndead = nanix.sum().item()
+ print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
+ c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
+
+ self.C.copy_(c)
+ self.initialized.fill_(1)
+
+
+ def forward(self, x, reverse=False, shape=None):
+ if not reverse:
+ # flatten
+ bs,c,h,w = x.shape
+ assert c == self.nc
+ x = x.reshape(bs,c,h*w,1)
+ C = self.C.permute(1,0)
+ C = C.reshape(1,c,1,self.ncluster)
+ a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
+ return a
+ else:
+ # flatten
+ bs, HW = x.shape
+ """
+ c = self.C.reshape( 1, self.nc, 1, self.ncluster)
+ c = c[bs*[0],:,:,:]
+ c = c[:,:,HW*[0],:]
+ x = x.reshape(bs, 1, HW, 1)
+ x = x[:,3*[0],:,:]
+ x = torch.gather(c, dim=3, index=x)
+ """
+ x = self.C[x]
+ x = x.permute(0,2,1)
+ shape = shape if shape is not None else self.shape
+ x = x.reshape(bs, *shape)
+
+ return x
diff --git a/3DTopia/taming/modules/transformer/permuter.py b/3DTopia/taming/modules/transformer/permuter.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d43bb135adde38d94bf18a7e5edaa4523cd95cf
--- /dev/null
+++ b/3DTopia/taming/modules/transformer/permuter.py
@@ -0,0 +1,248 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+class AbstractPermuter(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ def forward(self, x, reverse=False):
+ raise NotImplementedError
+
+
+class Identity(AbstractPermuter):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, reverse=False):
+ return x
+
+
+class Subsample(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ C = 1
+ indices = np.arange(H*W).reshape(C,H,W)
+ while min(H, W) > 1:
+ indices = indices.reshape(C,H//2,2,W//2,2)
+ indices = indices.transpose(0,2,4,1,3)
+ indices = indices.reshape(C*4,H//2, W//2)
+ H = H//2
+ W = W//2
+ C = C*4
+ assert H == W == 1
+ idx = torch.tensor(indices.ravel())
+ self.register_buffer('forward_shuffle_idx',
+ nn.Parameter(idx, requires_grad=False))
+ self.register_buffer('backward_shuffle_idx',
+ nn.Parameter(torch.argsort(idx), requires_grad=False))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+def mortonify(i, j):
+ """(i,j) index to linear morton code"""
+ i = np.uint64(i)
+ j = np.uint64(j)
+
+ z = np.uint(0)
+
+ for pos in range(32):
+ z = (z |
+ ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
+ ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
+ )
+ return z
+
+
+class ZCurve(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
+ idx = np.argsort(reverseidx)
+ idx = torch.tensor(idx)
+ reverseidx = torch.tensor(reverseidx)
+ self.register_buffer('forward_shuffle_idx',
+ idx)
+ self.register_buffer('backward_shuffle_idx',
+ reverseidx)
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class SpiralOut(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ assert H == W
+ size = W
+ indices = np.arange(size*size).reshape(size,size)
+
+ i0 = size//2
+ j0 = size//2-1
+
+ i = i0
+ j = j0
+
+ idx = [indices[i0, j0]]
+ step_mult = 0
+ for c in range(1, size//2+1):
+ step_mult += 1
+ # steps left
+ for k in range(step_mult):
+ i = i - 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step down
+ for k in range(step_mult):
+ i = i
+ j = j + 1
+ idx.append(indices[i, j])
+
+ step_mult += 1
+ if c < size//2:
+ # step right
+ for k in range(step_mult):
+ i = i + 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step up
+ for k in range(step_mult):
+ i = i
+ j = j - 1
+ idx.append(indices[i, j])
+ else:
+ # end reached
+ for k in range(step_mult-1):
+ i = i + 1
+ idx.append(indices[i, j])
+
+ assert len(idx) == size*size
+ idx = torch.tensor(idx)
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class SpiralIn(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ assert H == W
+ size = W
+ indices = np.arange(size*size).reshape(size,size)
+
+ i0 = size//2
+ j0 = size//2-1
+
+ i = i0
+ j = j0
+
+ idx = [indices[i0, j0]]
+ step_mult = 0
+ for c in range(1, size//2+1):
+ step_mult += 1
+ # steps left
+ for k in range(step_mult):
+ i = i - 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step down
+ for k in range(step_mult):
+ i = i
+ j = j + 1
+ idx.append(indices[i, j])
+
+ step_mult += 1
+ if c < size//2:
+ # step right
+ for k in range(step_mult):
+ i = i + 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step up
+ for k in range(step_mult):
+ i = i
+ j = j - 1
+ idx.append(indices[i, j])
+ else:
+ # end reached
+ for k in range(step_mult-1):
+ i = i + 1
+ idx.append(indices[i, j])
+
+ assert len(idx) == size*size
+ idx = idx[::-1]
+ idx = torch.tensor(idx)
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class Random(nn.Module):
+ def __init__(self, H, W):
+ super().__init__()
+ indices = np.random.RandomState(1).permutation(H*W)
+ idx = torch.tensor(indices.ravel())
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class AlternateParsing(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ indices = np.arange(W*H).reshape(H,W)
+ for i in range(1, H, 2):
+ indices[i, :] = indices[i, ::-1]
+ idx = indices.flatten()
+ assert len(idx) == H*W
+ idx = torch.tensor(idx)
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+if __name__ == "__main__":
+ p0 = AlternateParsing(16, 16)
+ print(p0.forward_shuffle_idx)
+ print(p0.backward_shuffle_idx)
+
+ x = torch.randint(0, 768, size=(11, 256))
+ y = p0(x)
+ xre = p0(y, reverse=True)
+ assert torch.equal(x, xre)
+
+ p1 = SpiralOut(2, 2)
+ print(p1.forward_shuffle_idx)
+ print(p1.backward_shuffle_idx)
diff --git a/3DTopia/taming/modules/util.py b/3DTopia/taming/modules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ee16385d8b1342a2d60a5f1aa5cadcfbe934bd8
--- /dev/null
+++ b/3DTopia/taming/modules/util.py
@@ -0,0 +1,130 @@
+import torch
+import torch.nn as nn
+
+
+def count_params(model):
+ total_params = sum(p.numel() for p in model.parameters())
+ return total_params
+
+
+class ActNorm(nn.Module):
+ def __init__(self, num_features, logdet=False, affine=True,
+ allow_reverse_init=False):
+ assert affine
+ super().__init__()
+ self.logdet = logdet
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+ self.allow_reverse_init = allow_reverse_init
+
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
+
+ def initialize(self, input):
+ with torch.no_grad():
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+ mean = (
+ flatten.mean(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+ std = (
+ flatten.std(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+
+ self.loc.data.copy_(-mean)
+ self.scale.data.copy_(1 / (std + 1e-6))
+
+ def forward(self, input, reverse=False):
+ if reverse:
+ return self.reverse(input)
+ if len(input.shape) == 2:
+ input = input[:,:,None,None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ _, _, height, width = input.shape
+
+ if self.training and self.initialized.item() == 0:
+ self.initialize(input)
+ self.initialized.fill_(1)
+
+ h = self.scale * (input + self.loc)
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+
+ if self.logdet:
+ log_abs = torch.log(torch.abs(self.scale))
+ logdet = height*width*torch.sum(log_abs)
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
+ return h, logdet
+
+ return h
+
+ def reverse(self, output):
+ if self.training and self.initialized.item() == 0:
+ if not self.allow_reverse_init:
+ raise RuntimeError(
+ "Initializing ActNorm in reverse direction is "
+ "disabled by default. Use allow_reverse_init=True to enable."
+ )
+ else:
+ self.initialize(output)
+ self.initialized.fill_(1)
+
+ if len(output.shape) == 2:
+ output = output[:,:,None,None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ h = output / self.scale - self.loc
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+ return h
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class Labelator(AbstractEncoder):
+ """Net2Net Interface for Class-Conditional Model"""
+ def __init__(self, n_classes, quantize_interface=True):
+ super().__init__()
+ self.n_classes = n_classes
+ self.quantize_interface = quantize_interface
+
+ def encode(self, c):
+ c = c[:,None]
+ if self.quantize_interface:
+ return c, None, [None, None, c.long()]
+ return c
+
+
+class SOSProvider(AbstractEncoder):
+ # for unconditional training
+ def __init__(self, sos_token, quantize_interface=True):
+ super().__init__()
+ self.sos_token = sos_token
+ self.quantize_interface = quantize_interface
+
+ def encode(self, x):
+ # get batch size from data and replicate sos_token
+ c = torch.ones(x.shape[0], 1)*self.sos_token
+ c = c.long().to(x.device)
+ if self.quantize_interface:
+ return c, None, [None, None, c]
+ return c
diff --git a/3DTopia/taming/modules/vqvae/quantize.py b/3DTopia/taming/modules/vqvae/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..d75544e41fa01bce49dd822b1037963d62f79b51
--- /dev/null
+++ b/3DTopia/taming/modules/vqvae/quantize.py
@@ -0,0 +1,445 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from torch import einsum
+from einops import rearrange
+
+
+class VectorQuantizer(nn.Module):
+ """
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+ ____________________________________________
+ Discretization bottleneck part of the VQ-VAE.
+ Inputs:
+ - n_e : number of embeddings
+ - e_dim : dimension of embedding
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ _____________________________________________
+ """
+
+ # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
+ # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
+ # used wherever VectorQuantizer has been used before and is additionally
+ # more efficient.
+ def __init__(self, n_e, e_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ def forward(self, z):
+ """
+ Inputs the output of the encoder network z and maps it to a discrete
+ one-hot vector that is the index of the closest embedding vector e_j
+ z (continuous) -> z_q (discrete)
+ z.shape = (batch, channel, height, width)
+ quantization pipeline:
+ 1. get encoder input (B,C,H,W)
+ 2. flatten input to (B*H*W,C)
+ """
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+ torch.matmul(z_flattened, self.embedding.weight.t())
+
+ ## could possible replace this here
+ # #\start...
+ # find closest encodings
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+
+ min_encodings = torch.zeros(
+ min_encoding_indices.shape[0], self.n_e).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # dtype min encodings: torch.float32
+ # min_encodings shape: torch.Size([2048, 512])
+ # min_encoding_indices.shape: torch.Size([2048, 1])
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ #.........\end
+
+ # with:
+ # .........\start
+ #min_encoding_indices = torch.argmin(d, dim=1)
+ #z_q = self.embedding(min_encoding_indices)
+ # ......\end......... (TODO)
+
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ # TODO: check for more easy handling with nn.Embedding
+ min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
+ min_encodings.scatter_(1, indices[:,None], 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantize(nn.Module):
+ """
+ credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
+ Gumbel Softmax trick quantizer
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
+ https://arxiv.org/abs/1611.01144
+ """
+ def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
+ kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
+ remap=None, unknown_index="random"):
+ super().__init__()
+
+ self.embedding_dim = embedding_dim
+ self.n_embed = n_embed
+
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
+ self.embed = nn.Embedding(n_embed, embedding_dim)
+
+ self.use_vqinterface = use_vqinterface
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, return_logits=False):
+ # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
+ hard = self.straight_through if self.training else True
+ temp = self.temperature if temp is None else temp
+
+ logits = self.proj(z)
+ if self.remap is not None:
+ # continue only with used logits
+ full_zeros = torch.zeros_like(logits)
+ logits = logits[:,self.used,...]
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
+ if self.remap is not None:
+ # go back to all entries but unused set to zero
+ full_zeros[:,self.used,...] = soft_one_hot
+ soft_one_hot = full_zeros
+ z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
+
+ ind = soft_one_hot.argmax(dim=1)
+ if self.remap is not None:
+ ind = self.remap_to_used(ind)
+ if self.use_vqinterface:
+ if return_logits:
+ return z_q, diff, (None, None, ind), logits
+ return z_q, diff, (None, None, ind)
+ return z_q, diff, ind
+
+ def get_codebook_entry(self, indices, shape):
+ b, h, w, c = shape
+ assert b*h*w == indices.shape[0]
+ indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
+ if self.remap is not None:
+ indices = self.unmap_to_all(indices)
+ one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
+ z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
+ return z_q
+
+
+class VectorQuantizer2(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+ """
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
+ sane_index_shape=False, legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+ assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits==False, "Only for interface compatible with Gumbel"
+ assert return_logits==False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
+ torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0],-1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+class EmbeddingEMA(nn.Module):
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+ super().__init__()
+ self.decay = decay
+ self.eps = eps
+ weight = torch.randn(num_tokens, codebook_dim)
+ self.weight = nn.Parameter(weight, requires_grad = False)
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
+ self.update = True
+
+ def forward(self, embed_id):
+ return F.embedding(embed_id, self.weight)
+
+ def cluster_size_ema_update(self, new_cluster_size):
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+ def embed_avg_ema_update(self, new_embed_avg):
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+ def weight_update(self, num_tokens):
+ n = self.cluster_size.sum()
+ smoothed_cluster_size = (
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+ )
+ #normalize embedding average with smoothed cluster size
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+ self.weight.data.copy_(embed_normalized)
+
+
+class EMAVectorQuantizer(nn.Module):
+ def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
+ remap=None, unknown_index="random"):
+ super().__init__()
+ self.codebook_dim = codebook_dim
+ self.num_tokens = num_tokens
+ self.beta = beta
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ #z, 'b c h w -> b h w c'
+ z = rearrange(z, 'b c h w -> b h w c')
+ z_flattened = z.reshape(-1, self.codebook_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
+
+
+ encoding_indices = torch.argmin(d, dim=1)
+
+ z_q = self.embedding(encoding_indices).view(z.shape)
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+ avg_probs = torch.mean(encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+ if self.training and self.embedding.update:
+ #EMA cluster size
+ encodings_sum = encodings.sum(0)
+ self.embedding.cluster_size_ema_update(encodings_sum)
+ #EMA embedding average
+ embed_sum = encodings.transpose(0,1) @ z_flattened
+ self.embedding.embed_avg_ema_update(embed_sum)
+ #normalize embed_avg and update weight
+ self.embedding.weight_update(self.num_tokens)
+
+ # compute loss for embedding
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ #z_q, 'b h w c -> b c h w'
+ z_q = rearrange(z_q, 'b h w c -> b c h w')
+ return z_q, loss, (perplexity, encodings, encoding_indices)
diff --git a/3DTopia/taming/util.py b/3DTopia/taming/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..06053e5defb87977f9ab07e69bf4da12201de9b7
--- /dev/null
+++ b/3DTopia/taming/util.py
@@ -0,0 +1,157 @@
+import os, hashlib
+import requests
+from tqdm import tqdm
+
+URL_MAP = {
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
+}
+
+CKPT_MAP = {
+ "vgg_lpips": "vgg.pth"
+}
+
+MD5_MAP = {
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
+}
+
+
+def download(url, local_path, chunk_size=1024):
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+ with requests.get(url, stream=True) as r:
+ total_size = int(r.headers.get("content-length", 0))
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+ with open(local_path, "wb") as f:
+ for data in r.iter_content(chunk_size=chunk_size):
+ if data:
+ f.write(data)
+ pbar.update(chunk_size)
+
+
+def md5_hash(path):
+ with open(path, "rb") as f:
+ content = f.read()
+ return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+ assert name in URL_MAP
+ path = os.path.join(root, CKPT_MAP[name])
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+ download(URL_MAP[name], path)
+ md5 = md5_hash(path)
+ assert md5 == MD5_MAP[name], md5
+ return path
+
+
+class KeyNotFoundError(Exception):
+ def __init__(self, cause, keys=None, visited=None):
+ self.cause = cause
+ self.keys = keys
+ self.visited = visited
+ messages = list()
+ if keys is not None:
+ messages.append("Key not found: {}".format(keys))
+ if visited is not None:
+ messages.append("Visited: {}".format(visited))
+ messages.append("Cause:\n{}".format(cause))
+ message = "\n".join(messages)
+ super().__init__(message)
+
+
+def retrieve(
+ list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
+):
+ """Given a nested list or dict return the desired value at key expanding
+ callable nodes if necessary and :attr:`expand` is ``True``. The expansion
+ is done in-place.
+
+ Parameters
+ ----------
+ list_or_dict : list or dict
+ Possibly nested list or dictionary.
+ key : str
+ key/to/value, path like string describing all keys necessary to
+ consider to get to the desired value. List indices can also be
+ passed here.
+ splitval : str
+ String that defines the delimiter between keys of the
+ different depth levels in `key`.
+ default : obj
+ Value returned if :attr:`key` is not found.
+ expand : bool
+ Whether to expand callable nodes on the path or not.
+
+ Returns
+ -------
+ The desired value or if :attr:`default` is not ``None`` and the
+ :attr:`key` is not found returns ``default``.
+
+ Raises
+ ------
+ Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
+ ``None``.
+ """
+
+ keys = key.split(splitval)
+
+ success = True
+ try:
+ visited = []
+ parent = None
+ last_key = None
+ for key in keys:
+ if callable(list_or_dict):
+ if not expand:
+ raise KeyNotFoundError(
+ ValueError(
+ "Trying to get past callable node with expand=False."
+ ),
+ keys=keys,
+ visited=visited,
+ )
+ list_or_dict = list_or_dict()
+ parent[last_key] = list_or_dict
+
+ last_key = key
+ parent = list_or_dict
+
+ try:
+ if isinstance(list_or_dict, dict):
+ list_or_dict = list_or_dict[key]
+ else:
+ list_or_dict = list_or_dict[int(key)]
+ except (KeyError, IndexError, ValueError) as e:
+ raise KeyNotFoundError(e, keys=keys, visited=visited)
+
+ visited += [key]
+ # final expansion of retrieved value
+ if expand and callable(list_or_dict):
+ list_or_dict = list_or_dict()
+ parent[last_key] = list_or_dict
+ except KeyNotFoundError as e:
+ if default is None:
+ raise e
+ else:
+ list_or_dict = default
+ success = False
+
+ if not pass_success:
+ return list_or_dict
+ else:
+ return list_or_dict, success
+
+
+if __name__ == "__main__":
+ config = {"keya": "a",
+ "keyb": "b",
+ "keyc":
+ {"cc1": 1,
+ "cc2": 2,
+ }
+ }
+ from omegaconf import OmegaConf
+ config = OmegaConf.create(config)
+ print(config)
+ retrieve(config, "keya")
+
diff --git a/3DTopia/utility/initialize.py b/3DTopia/utility/initialize.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac572368414198aef0efd6ab3b1d43f5c22b0331
--- /dev/null
+++ b/3DTopia/utility/initialize.py
@@ -0,0 +1,13 @@
+import importlib
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
diff --git a/3DTopia/utility/mcubes_from_latent.py b/3DTopia/utility/mcubes_from_latent.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcbad2a16ae1e3e3473f6f2341f2058699540bcb
--- /dev/null
+++ b/3DTopia/utility/mcubes_from_latent.py
@@ -0,0 +1,160 @@
+import os
+import torch
+import argparse
+import mcubes
+import trimesh
+import numpy as np
+from tqdm import tqdm
+from omegaconf import OmegaConf
+from utility.initialize import instantiate_from_config, get_obj_from_str
+from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes
+
+# load model
+parser = argparse.ArgumentParser()
+parser.add_argument("--config", type=str, default=None, required=True)
+parser.add_argument("--ckpt", type=str, default=None, required=True)
+args = parser.parse_args()
+configs = OmegaConf.load(args.config)
+device = 'cuda'
+vae = get_obj_from_str(configs.model.params.first_stage_config['target'])(**configs.model.params.first_stage_config['params'])
+vae = vae.to(device)
+vae.eval()
+
+model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(args.ckpt, map_location='cpu', strict=False, **configs.model.params)
+model = model.to(device)
+
+def extract_mesh(triplane_fname, save_name=None):
+ latent = torch.from_numpy(np.load(triplane_fname)).to(device)
+ with torch.no_grad():
+ with model.ema_scope():
+ triplane = model.decode_first_stage(latent)
+
+ # prepare volumn for marching cube
+ res = 128
+ c_list = torch.linspace(-1.2, 1.2, steps=res)
+ grid_x, grid_y, grid_z = torch.meshgrid(
+ c_list, c_list, c_list, indexing='ij'
+ )
+ coords = torch.stack([grid_x, grid_y, grid_z], -1).to(device) # 256x256x256x3
+ plane_axes = generate_planes()
+ feats = sample_from_planes(
+ plane_axes, triplane.reshape(1, 3, -1, 256, 256), coords.reshape(1, -1, 3), padding_mode='zeros', box_warp=2.4
+ )
+ fake_dirs = torch.zeros_like(coords)
+ fake_dirs[..., 0] = 1
+ with torch.no_grad():
+ out = vae.triplane_decoder.decoder(feats, fake_dirs)
+ u = out['sigma'].reshape(res, res, res).detach().cpu().numpy()
+ del out
+
+ # marching cube
+ vertices, triangles = mcubes.marching_cubes(u, 8)
+ min_bound = np.array([-1.2, -1.2, -1.2])
+ max_bound = np.array([1.2, 1.2, 1.2])
+ vertices = vertices / (res - 1) * (max_bound - min_bound)[None, :] + min_bound[None, :]
+ pt_vertices = torch.from_numpy(vertices).to(device)
+
+ # extract vertices color
+ res_triplane = 256
+ # rays_d = torch.from_numpy(-vertices / np.sqrt((vertices ** 2).sum(-1)).reshape(-1, 1)).to(device).unsqueeze(0)
+ # rays_o = -rays_d * 2.0
+ render_kwargs = {
+ 'depth_resolution': 128,
+ 'disparity_space_sampling': False,
+ 'box_warp': 2.4,
+ 'depth_resolution_importance': 128,
+ 'clamp_mode': 'softplus',
+ 'white_back': True,
+ 'det': True
+ }
+ # render_out = vae.triplane_decoder(triplane.reshape(1, 3, -1, res_triplane, res_triplane), rays_o, rays_d, render_kwargs, whole_img=False, tvloss=False)
+ # rgb = render_out['rgb_marched'].reshape(-1, 3).detach().cpu().numpy()
+ # rgb = (rgb * 255).astype(np.uint8)
+ rays_o_list = [
+ np.array([0, 0, 2]),
+ np.array([0, 0, -2]),
+ np.array([0, 2, 0]),
+ np.array([0, -2, 0]),
+ np.array([2, 0, 0]),
+ np.array([-2, 0, 0]),
+ ]
+ rgb_final = None
+ diff_final = None
+ for rays_o in tqdm(rays_o_list):
+ rays_o = torch.from_numpy(rays_o.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device)
+ rays_d = pt_vertices.reshape(-1, 3) - rays_o
+ rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1)
+ dist = torch.norm(pt_vertices.reshape(-1, 3) - rays_o, dim=-1).cpu().numpy().reshape(-1)
+
+ # batch_size = 2**14
+ # batch_num = (rays_o.shape[0] // batch_size) + 1
+ # rgb_list = []
+ # depth_diff_list = []
+ # for b in range(batch_num):
+ # cur_rays_o = rays_o[b * batch_size: (b + 1) * batch_size]
+ # cur_rays_d = rays_d[b * batch_size: (b + 1) * batch_size]
+ with torch.no_grad():
+ render_out = vae.triplane_decoder(triplane.reshape(1, 3, -1, res_triplane, res_triplane),
+ rays_o.unsqueeze(0), rays_d.unsqueeze(0), render_kwargs,
+ whole_img=False, tvloss=False)
+ rgb = render_out['rgb_marched'].reshape(-1, 3).detach().cpu().numpy()
+ depth = render_out['depth_final'].reshape(-1).detach().cpu().numpy()
+ depth_diff = np.abs(dist - depth)
+
+ # rgb_list.append(rgb)
+ # depth_diff_list.append(depth_diff)
+
+ # del render_out
+ # torch.cuda.empty_cache()
+
+ # rgb = np.concatenate(rgb_list, 0)
+ # depth_diff = np.concatenate(depth_diff_list, 0)
+
+ if rgb_final is None:
+ rgb_final = rgb.copy()
+ diff_final = depth_diff.copy()
+
+ else:
+ ind = diff_final > depth_diff
+ rgb_final[ind] = rgb[ind]
+ diff_final[ind] = depth_diff[ind]
+
+
+ # bgr to rgb
+ rgb_final = np.stack([
+ rgb_final[:, 2], rgb_final[:, 1], rgb_final[:, 0]
+ ], -1)
+
+ # export to ply
+ mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=(rgb_final * 255).astype(np.uint8))
+ if save_name:
+ trimesh.exchange.export.export_mesh(mesh, save_name, file_type='ply')
+ else:
+ trimesh.exchange.export.export_mesh(mesh, triplane_fname[:-4] + '.ply', file_type='ply')
+
+# load triplane
+# fname = 'log/diff_res32ch8_preprocess_ca_text/sample_mesh_1/sample_16_0.npy'
+# u = np.load(fname)
+# triplane_fname = 'log/diff_res32ch8_preprocess_ca_text/sample_mesh_1/triplane_16_0.npy'
+# folder = 'log/diff_res32ch8_preprocess_ca_text/sample_mesh_opt'
+# folder = 'log/diff_res32ch8_preprocess_ca_text/sample_mesh_opt_simple'
+folder = '/mnt/lustre/hongfangzhou.p/AE3D/log/diff_res32ch8_preprocess_ca_text_new_triplane_96_full_openaimodel_only_cap3d_high_quality_7w/sample_demo_424_prompts_for_demo_30_60_10'
+save_folder = folder + '_extract_mesh'
+os.makedirs(save_folder, exist_ok=True)
+fnames = [f.replace('_sample', 'triplane').replace('mp4', 'npy') for f in os.listdir(folder) if f.startswith('_')]
+prompts = [l.strip() for l in open('test/prompts_for_demo_2.txt', 'r').readlines()][30:60]
+# fnames = [os.path.join(folder, f) for f in os.listdir(folder) if (f.startswith('triplane') and f.endswith('.npy'))]
+fnames = sorted(fnames)
+
+def extract_number(s):
+ return int(s.split('_')[-2])
+
+def extract_id(s):
+ return s.split('_')[-1].split('.')[0]
+
+for fname in fnames:
+ try:
+ print(fname)
+ extract_mesh(os.path.join(folder, fname), os.path.join(save_folder, prompts[extract_number(fname)].replace(' ', '_') + '_' + extract_id(fname) + '.ply'))
+ except Exception as e:
+ print(e)
diff --git a/3DTopia/utility/triplane_renderer/eg3d_renderer.py b/3DTopia/utility/triplane_renderer/eg3d_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..796c47071d2a77d709cf81fba12323155f115b79
--- /dev/null
+++ b/3DTopia/utility/triplane_renderer/eg3d_renderer.py
@@ -0,0 +1,685 @@
+import os
+import math
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# TriPlane Utils
+class MipRayMarcher2(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def run_forward(self, colors, densities, depths, rendering_options):
+ deltas = depths[:, :, 1:] - depths[:, :, :-1]
+ colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
+ densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
+ depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
+
+
+ if rendering_options['clamp_mode'] == 'softplus':
+ densities_mid = F.softplus(densities_mid - 1) # activation bias of -1 makes things initialize better
+ else:
+ assert False, "MipRayMarcher only supports `clamp_mode`=`softplus`!"
+
+ density_delta = densities_mid * deltas
+
+ alpha = 1 - torch.exp(-density_delta)
+
+ alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
+ weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
+
+ composite_rgb = torch.sum(weights * colors_mid, -2)
+ weight_total = weights.sum(2)
+ # composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
+ composite_depth = torch.sum(weights * depths_mid, -2)
+
+ # clip the composite to min/max range of depths
+ composite_depth = torch.nan_to_num(composite_depth, float('inf'))
+ # composite_depth = torch.nan_to_num(composite_depth, 0.)
+ composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
+
+ if rendering_options.get('white_back', False):
+ composite_rgb = composite_rgb + 1 - weight_total
+
+ composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
+
+ return composite_rgb, composite_depth, weights
+
+ def forward(self, colors, densities, depths, rendering_options):
+ composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options)
+
+ return composite_rgb, composite_depth, weights
+
+def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
+ """
+ Left-multiplies MxM @ NxM. Returns NxM.
+ """
+ res = torch.matmul(vectors4, matrix.T)
+ return res
+
+def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
+ """
+ Normalize vector lengths.
+ """
+ return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
+
+def torch_dot(x: torch.Tensor, y: torch.Tensor):
+ """
+ Dot product of two tensors.
+ """
+ return (x * y).sum(-1)
+
+def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
+ """
+ Author: Petr Kellnhofer
+ Intersects rays with the [-1, 1] NDC volume.
+ Returns min and max distance of entry.
+ Returns -1 for no intersection.
+ https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
+ """
+ o_shape = rays_o.shape
+ rays_o = rays_o.detach().reshape(-1, 3)
+ rays_d = rays_d.detach().reshape(-1, 3)
+
+
+ bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
+ bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
+ bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
+ is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
+
+ # Precompute inverse for stability.
+ invdir = 1 / rays_d
+ sign = (invdir < 0).long()
+
+ # Intersect with YZ plane.
+ tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
+ tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
+
+ # Intersect with XZ plane.
+ tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
+ tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
+
+ # Resolve parallel rays.
+ is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
+
+ # Use the shortest intersection.
+ tmin = torch.max(tmin, tymin)
+ tmax = torch.min(tmax, tymax)
+
+ # Intersect with XY plane.
+ tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
+ tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
+
+ # Resolve parallel rays.
+ is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
+
+ # Use the shortest intersection.
+ tmin = torch.max(tmin, tzmin)
+ tmax = torch.min(tmax, tzmax)
+
+ # Mark invalid.
+ tmin[torch.logical_not(is_valid)] = -1
+ tmax[torch.logical_not(is_valid)] = -2
+
+ return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
+
+def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
+ """
+ Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
+ Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch.
+ """
+ # create a tensor of 'num' steps from 0 to 1
+ steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
+
+ # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
+ # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
+ # "cannot statically infer the expected size of a list in this contex", hence the code below
+ for i in range(start.ndim):
+ steps = steps.unsqueeze(-1)
+
+ # the output starts at 'start' and increments until 'stop' in each dimension
+ out = start[None] + steps * (stop - start)[None]
+
+ return out
+
+def generate_planes():
+ """
+ Defines planes by the three vectors that form the "axes" of the
+ plane. Should work with arbitrary number of planes and planes of
+ arbitrary orientation.
+ """
+ return torch.tensor([[[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]],
+ [[1, 0, 0],
+ [0, 0, 1],
+ [0, 1, 0]],
+ [[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]]], dtype=torch.float32)
+
+def project_onto_planes(planes, coordinates):
+ """
+ Does a projection of a 3D point onto a batch of 2D planes,
+ returning 2D plane coordinates.
+ Takes plane axes of shape n_planes, 3, 3
+ # Takes coordinates of shape N, M, 3
+ # returns projections of shape N*n_planes, M, 2
+ """
+
+ # # ORIGINAL
+ # N, M, C = coordinates.shape
+ # xy_coords = coordinates[..., [0, 1]]
+ # xz_coords = coordinates[..., [0, 2]]
+ # zx_coords = coordinates[..., [2, 0]]
+ # return torch.stack([xy_coords, xz_coords, zx_coords], dim=1).reshape(N*3, M, 2)
+
+ # FIXED
+ N, M, _ = coordinates.shape
+ xy_coords = coordinates[..., [0, 1]]
+ yz_coords = coordinates[..., [1, 2]]
+ zx_coords = coordinates[..., [2, 0]]
+ return torch.stack([xy_coords, yz_coords, zx_coords], dim=1).reshape(N*3, M, 2)
+
+def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
+ assert padding_mode == 'zeros'
+ N, n_planes, C, H, W = plane_features.shape
+ _, M, _ = coordinates.shape
+ plane_features = plane_features.view(N*n_planes, C, H, W)
+
+ coordinates = (2/box_warp) * coordinates # TODO: add specific box bounds
+
+ projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
+
+ output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
+ return output_features
+
+def sample_from_3dgrid(grid, coordinates):
+ """
+ Expects coordinates in shape (batch_size, num_points_per_batch, 3)
+ Expects grid in shape (1, channels, H, W, D)
+ (Also works if grid has batch size)
+ Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
+ """
+ batch_size, n_coords, n_dims = coordinates.shape
+ sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1),
+ coordinates.reshape(batch_size, 1, 1, -1, n_dims),
+ mode='bilinear', padding_mode='zeros', align_corners=False)
+ N, C, H, W, D = sampled_features.shape
+ sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C)
+ return sampled_features
+
+class FullyConnectedLayer(nn.Module):
+ def __init__(self,
+ in_features, # Number of input features.
+ out_features, # Number of output features.
+ bias = True, # Apply additive bias before the activation function?
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier = 1, # Learning rate multiplier.
+ bias_init = 0, # Initial value for the additive bias.
+ ):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.activation = activation
+ # self.weight = torch.nn.Parameter(torch.full([out_features, in_features], np.float32(0)))
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
+ self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
+ self.bias_gain = lr_multiplier
+
+ def forward(self, x):
+ w = self.weight.to(x.dtype) * self.weight_gain
+ b = self.bias
+ if b is not None:
+ b = b.to(x.dtype)
+ if self.bias_gain != 1:
+ b = b * self.bias_gain
+
+ if self.activation == 'linear' and b is not None:
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
+ else:
+ x = x.matmul(w.t())
+ x = bias_act.bias_act(x, b, act=self.activation)
+ return x
+
+ def extra_repr(self):
+ return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
+
+
+def positional_encoding(positions, freqs):
+ freq_bands = (2**torch.arange(freqs).float()).to(positions.device) # (F,)
+ pts = (positions[..., None] * freq_bands).reshape(
+ positions.shape[:-1] + (freqs * positions.shape[-1], )) # (..., DF)
+ pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1)
+ return pts
+
+# class TriPlane_Decoder(nn.Module):
+# def __init__(self, dim=12, width=128):
+# super().__init__()
+# self.net = torch.nn.Sequential(
+# FullyConnectedLayer(dim, width),
+# torch.nn.Softplus(),
+# FullyConnectedLayer(width, width),
+# torch.nn.Softplus(),
+# FullyConnectedLayer(width, 1 + 3)
+# )
+
+# def forward(self, sampled_features, viewdir):
+# sampled_features = sampled_features.mean(1)
+# x = sampled_features
+
+# N, M, C = x.shape
+# x = x.view(N*M, C)
+
+# x = self.net(x)
+# x = x.view(N, M, -1)
+# rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
+# sigma = x[..., 0:1]
+# return {'rgb': rgb, 'sigma': sigma}
+
+class TriPlane_Decoder(nn.Module):
+ def __init__(self, dim=12, width=128):
+ super().__init__()
+ self.net = torch.nn.Sequential(
+ FullyConnectedLayer(dim, width),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(width, width),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(width, width),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(width, width),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(width, 1 + 3)
+ )
+
+ # def forward(self, sampled_features, viewdir):
+ # #ipdb.set_trace()
+ # sampled_features = sampled_features.mean(1)
+ # x = sampled_features
+
+ # N, M, C = x.shape
+ # x = x.view(N*M, C)
+
+ # x = self.net(x)
+ # x = x.view(N, M, -1)
+ # rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
+ # sigma = x[..., 0:1]
+ # return {'rgb': rgb, 'sigma': sigma}
+ def forward(self, sampled_features, viewdir):
+ M = sampled_features.shape[-2]
+ batch_size = 256 * 256
+ num_batches = M // batch_size
+ if num_batches * batch_size < M:
+ num_batches += 1
+ res = {
+ 'rgb': [],
+ 'sigma': [],
+ }
+
+ for b in range(num_batches):
+ p = b * batch_size
+ b_sampled_features = sampled_features[:, :, p:p+batch_size]
+ b_res = self._forward(b_sampled_features)
+ res['rgb'].append(b_res['rgb'])
+ res['sigma'].append(b_res['sigma'])
+ res['rgb'] = torch.cat(res['rgb'], -2)
+ res['sigma'] = torch.cat(res['sigma'], -2)
+
+ return res
+
+ def _forward(self, sampled_features):
+ # N, _, M, C = sampled_features.shape
+ sampled_features = sampled_features.mean(1)
+ x = sampled_features
+ N, M, C = x.shape
+ x = x.view(N*M, C)
+
+ x = self.net(x)
+ x = x.view(N, M, -1)
+ rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
+ sigma = x[..., 0:1]
+
+ # assert self.sigma_dim + self.c_dim == C
+ # sigma_features = sampled_features[..., :self.sigma_dim]
+ # rgb_features = sampled_features[..., -self.c_dim:]
+ # sigma_features = sigma_features.permute(0, 2, 1, 3).reshape(N * M, self.sigma_dim * 3)
+ # rgb_features = rgb_features.permute(0, 2, 1, 3).reshape(N * M, self.c_dim * 3)
+
+ # x = torch.cat([self.sigmanet(sigma_features), self.rgbnet(rgb_features)], -1)
+ # x = x.view(N, M, -1)
+ # rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
+ # sigma = x[..., 0:1]
+ return {'rgb': rgb, 'sigma': sigma}
+
+class TriPlane_Decoder_PE(nn.Module):
+ def __init__(self, dim=12, width=128, viewpe=2, feape=2):
+ super().__init__()
+ assert viewpe > 0 and feape > 0
+ self.viewpe = viewpe
+ self.feape = feape
+ # self.densitynet = torch.nn.Sequential(
+ # FullyConnectedLayer(dim + 2*feape*dim, width),
+ # torch.nn.Softplus()
+ # )
+ # self.densityout = FullyConnectedLayer(width, 1)
+ # self.rgbnet = torch.nn.Sequential(
+ # FullyConnectedLayer(width + 3 + 2 * viewpe * 3, width),
+ # torch.nn.Softplus(),
+ # FullyConnectedLayer(width, 3)
+ # )
+ self.net = torch.nn.Sequential(
+ FullyConnectedLayer(dim+2*feape*dim+3+2*viewpe*3, width),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(width, width),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(width, 1 + 3)
+ )
+
+ def forward(self, sampled_features,viewdir):
+ sampled_features = sampled_features.mean(1)
+ x = sampled_features
+ N, M, C = x.shape
+ x = x.view(N*M, C)
+ viewdir = viewdir.view(N*M, 3)
+ x_pe = positional_encoding(x, self.feape)
+ viewdir_pe = positional_encoding(viewdir, self.viewpe)
+
+ x = torch.cat([x, x_pe, viewdir, viewdir_pe], -1)
+ x = self.net(x)
+ x = x.view(N, M, -1)
+ rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
+ sigma = x[..., 0:1]
+ # layer1 = self.densitynet(torch.cat([x, x_pe], -1))
+ # sigma = self.densityout(layer1).view(N, M, 1)
+ # rgb = self.rgbnet(torch.cat([layer1, viewdir, viewdir_pe], -1)).view(N, M, -1)
+ # rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001
+ return {'rgb': rgb, 'sigma': sigma}
+
+class TriPlane_Decoder_Decompose(nn.Module):
+ def __init__(self, sigma_dim=12, c_dim=12, width=128):
+ super().__init__()
+ self.rgbnet = torch.nn.Sequential(
+ FullyConnectedLayer(c_dim * 3, width),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(width, width),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(width, 3)
+ )
+ self.sigmanet = torch.nn.Sequential(
+ FullyConnectedLayer(sigma_dim * 3, width),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(width, width),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(width, 1)
+ )
+ self.sigma_dim = sigma_dim
+ self.c_dim = c_dim
+
+ def forward(self, sampled_features, viewdir):
+ M = sampled_features.shape[-2]
+ batch_size = 256 * 256
+ num_batches = M // batch_size
+ if num_batches * batch_size < M:
+ num_batches += 1
+ res = {
+ 'rgb': [],
+ 'sigma': [],
+ }
+
+ for b in range(num_batches):
+ p = b * batch_size
+ b_sampled_features = sampled_features[:, :, p:p+batch_size]
+ b_res = self._forward(b_sampled_features)
+ res['rgb'].append(b_res['rgb'])
+ res['sigma'].append(b_res['sigma'])
+ res['rgb'] = torch.cat(res['rgb'], -2)
+ res['sigma'] = torch.cat(res['sigma'], -2)
+
+ return res
+
+ def _forward(self, sampled_features):
+ N, _, M, C = sampled_features.shape
+ assert self.sigma_dim + self.c_dim == C
+ sigma_features = sampled_features[..., :self.sigma_dim]
+ rgb_features = sampled_features[..., -self.c_dim:]
+ sigma_features = sigma_features.permute(0, 2, 1, 3).reshape(N * M, self.sigma_dim * 3)
+ rgb_features = rgb_features.permute(0, 2, 1, 3).reshape(N * M, self.c_dim * 3)
+
+ x = torch.cat([self.sigmanet(sigma_features), self.rgbnet(rgb_features)], -1)
+ x = x.view(N, M, -1)
+ rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
+ sigma = x[..., 0:1]
+ return {'rgb': rgb, 'sigma': sigma}
+
+class Renderer_TriPlane(nn.Module):
+ # def __init__(self, rgbnet_dim=18, rgbnet_width=128, viewpe=0, feape=0):
+ # super(Renderer_TriPlane, self).__init__()
+ # if viewpe > 0 or feape > 0:
+ # self.decoder = TriPlane_Decoder_PE(dim=rgbnet_dim//3, width=rgbnet_width, viewpe=viewpe, feape=feape)
+ # else:
+ # self.decoder = TriPlane_Decoder(dim=rgbnet_dim//3, width=rgbnet_width)
+ # self.ray_marcher = MipRayMarcher2()
+ # self.plane_axes = generate_planes()
+
+ def __init__(self, rgbnet_dim=18, rgbnet_width=128, viewpe=0, feape=0, sigma_dim=0, c_dim=0):
+ super(Renderer_TriPlane, self).__init__()
+ if viewpe > 0 and feape > 0:
+ self.decoder = TriPlane_Decoder_PE(dim=rgbnet_dim//3, width=rgbnet_width, viewpe=viewpe, feape=feape)
+ elif sigma_dim > 0 and c_dim > 0:
+ self.decoder = TriPlane_Decoder_Decompose(sigma_dim=sigma_dim, c_dim=c_dim, width=rgbnet_width)
+ else:
+ self.decoder = TriPlane_Decoder(dim=rgbnet_dim, width=rgbnet_width)
+ self.ray_marcher = MipRayMarcher2()
+ self.plane_axes = generate_planes()
+
+ def forward(self, planes, ray_origins, ray_directions, rendering_options, whole_img=False, tvloss=False):
+ self.plane_axes = self.plane_axes.to(ray_origins.device)
+
+ ray_start, ray_end = get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
+ is_ray_valid = ray_end > ray_start
+ if torch.any(is_ray_valid).item():
+ ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
+ ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
+ depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'],
+ rendering_options['det'])
+
+ batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape
+
+ # Coarse Pass
+ sample_coordinates = (ray_origins.unsqueeze(-2) + depths_coarse * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
+ sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
+
+
+ out = self.run_model(planes, self.decoder, sample_coordinates, sample_directions, rendering_options)
+ colors_coarse = out['rgb']
+ densities_coarse = out['sigma']
+ colors_coarse = colors_coarse.reshape(batch_size, num_rays, samples_per_ray, colors_coarse.shape[-1])
+ densities_coarse = densities_coarse.reshape(batch_size, num_rays, samples_per_ray, 1)
+
+ # Fine Pass
+ N_importance = rendering_options['depth_resolution_importance']
+ if N_importance > 0:
+ _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
+
+ depths_fine = self.sample_importance(depths_coarse, weights, N_importance, rendering_options['det'])
+
+ sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, N_importance, -1).reshape(batch_size, -1, 3)
+ sample_coordinates = (ray_origins.unsqueeze(-2) + depths_fine * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
+
+ out = self.run_model(planes, self.decoder, sample_coordinates, sample_directions, rendering_options)
+ colors_fine = out['rgb']
+ densities_fine = out['sigma']
+ colors_fine = colors_fine.reshape(batch_size, num_rays, N_importance, colors_fine.shape[-1])
+ densities_fine = densities_fine.reshape(batch_size, num_rays, N_importance, 1)
+
+ all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
+ depths_fine, colors_fine, densities_fine)
+
+ # Aggregate
+ rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
+ else:
+ rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)
+
+
+ if tvloss:
+ initial_coordinates = torch.rand((batch_size, 1000, 3), device=planes.device) * 2 - 1
+ perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * 0.004
+ all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
+ projected_coordinates = project_onto_planes(self.plane_axes, all_coordinates).unsqueeze(1)
+ N, n_planes, C, H, W = planes.shape
+ _, M, _ = all_coordinates.shape
+ planes = planes.view(N*n_planes, C, H, W)
+ output_features = torch.nn.functional.grid_sample(planes, projected_coordinates.float(), mode='bilinear', padding_mode='zeros', align_corners=False).permute(0, 3, 2, 1).reshape(batch_size, n_planes, M, C)
+ sigma = self.decoder(output_features)['sigma']
+ sigma_initial = sigma[:, :sigma.shape[1]//2]
+ sigma_perturbed = sigma[:, sigma.shape[1]//2:]
+ TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed)
+ else:
+ TVloss = None
+
+ # return rgb_final, depth_final, weights.sum(2)
+ if whole_img:
+ H = W = int(ray_origins.shape[1] ** 0.5)
+ rgb_final = rgb_final.permute(0, 2, 1).reshape(-1, 3, H, W).contiguous()
+ depth_final = depth_final.permute(0, 2, 1).reshape(-1, 1, H, W).contiguous()
+ depth_final = (depth_final - depth_final.min()) / (depth_final.max() - depth_final.min())
+ depth_final = depth_final.repeat(1, 3, 1, 1)
+ # rgb_final = torch.clip(rgb_final, min=0, max=1)
+ rgb_final = (rgb_final + 1) / 2.
+ weights = weights.sum(2).reshape(rgb_final.shape[0], rgb_final.shape[2], rgb_final.shape[3])
+ return {
+ 'rgb_marched': rgb_final,
+ 'depth_final': depth_final,
+ 'weights': weights,
+ 'tvloss': TVloss,
+ }
+ else:
+ rgb_final = (rgb_final + 1) / 2.
+ return {
+ 'rgb_marched': rgb_final,
+ 'depth_final': depth_final,
+ 'tvloss': TVloss,
+ }
+
+ def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
+ sampled_features = sample_from_planes(self.plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
+
+ out = decoder(sampled_features, sample_directions)
+ if options.get('density_noise', 0) > 0:
+ out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
+ return out
+
+ def sort_samples(self, all_depths, all_colors, all_densities):
+ _, indices = torch.sort(all_depths, dim=-2)
+ all_depths = torch.gather(all_depths, -2, indices)
+ all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
+ all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
+ return all_depths, all_colors, all_densities
+
+ def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2):
+ all_depths = torch.cat([depths1, depths2], dim = -2)
+ all_colors = torch.cat([colors1, colors2], dim = -2)
+ all_densities = torch.cat([densities1, densities2], dim = -2)
+
+ _, indices = torch.sort(all_depths, dim=-2)
+ all_depths = torch.gather(all_depths, -2, indices)
+ all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
+ all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
+
+ return all_depths, all_colors, all_densities
+
+ def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False, det=False):
+ """
+ Return depths of approximately uniformly spaced samples along rays.
+ """
+ N, M, _ = ray_origins.shape
+ if disparity_space_sampling:
+ depths_coarse = torch.linspace(0,
+ 1,
+ depth_resolution,
+ device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
+ depth_delta = 1/(depth_resolution - 1)
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+ depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
+ else:
+ if type(ray_start) == torch.Tensor:
+ depths_coarse = linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
+ depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
+ if det:
+ depths_coarse += 0.5 * depth_delta[..., None]
+ else:
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
+ else:
+ depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
+ depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
+ if det:
+ depths_coarse += 0.5 * depth_delta
+ else:
+ depths_coarse += torch.rand_like(depths_coarse) * depth_delta
+
+ return depths_coarse
+
+ def sample_importance(self, z_vals, weights, N_importance, det=False):
+ """
+ Return depths of importance sampled points along rays. See NeRF importance sampling for more.
+ """
+ with torch.no_grad():
+ batch_size, num_rays, samples_per_ray, _ = z_vals.shape
+
+ z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
+ weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher
+
+ # smooth weights
+ weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1)
+ weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
+ weights = weights + 0.01
+
+ z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
+ importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
+ N_importance, det=det).detach().reshape(batch_size, num_rays, N_importance, 1)
+ return importance_z_vals
+
+ def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
+ """
+ Sample @N_importance samples from @bins with distribution defined by @weights.
+ Inputs:
+ bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
+ weights: (N_rays, N_samples_)
+ N_importance: the number of samples to draw from the distribution
+ det: deterministic or not
+ eps: a small number to prevent division by zero
+ Outputs:
+ samples: the sampled samples
+ """
+ N_rays, N_samples_ = weights.shape
+ weights = weights + eps # prevent division by zero (don't do inplace op!)
+ pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
+ cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
+ cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1)
+ # padded to 0~1 inclusive
+
+ if det:
+ u = torch.linspace(0, 1, N_importance, device=bins.device)
+ u = u.expand(N_rays, N_importance)
+ else:
+ u = torch.rand(N_rays, N_importance, device=bins.device)
+ u = u.contiguous()
+
+ inds = torch.searchsorted(cdf, u, right=True)
+ below = torch.clamp_min(inds-1, 0)
+ above = torch.clamp_max(inds, N_samples_)
+
+ inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
+ cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
+ bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
+
+ denom = cdf_g[...,1]-cdf_g[...,0]
+ denom[denom 0.:
+ # get intervals between samples
+ mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
+ upper = torch.cat([mids, z_vals[...,-1:]], -1)
+ lower = torch.cat([z_vals[...,:1], mids], -1)
+ # stratified samples in those intervals
+ t_rand = torch.rand(z_vals.shape)
+
+ # Pytest, overwrite u with numpy's fixed random numbers
+ if pytest:
+ np.random.seed(0)
+ t_rand = np.random.rand(*list(z_vals.shape))
+ t_rand = torch.Tensor(t_rand)
+
+ z_vals = lower + (upper - lower) * t_rand
+
+ pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]
+
+
+# raw = run_network(pts)
+ raw = network_query_fn(pts, viewdirs, label,network_fn)
+ rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
+
+
+ ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}
+ if retraw:
+ ret['raw'] = raw
+ if N_importance > 0:
+ ret['rgb0'] = rgb_map_0
+ ret['disp0'] = disp_map_0
+ ret['acc0'] = acc_map_0
+ ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays]
+
+ for k in ret:
+ if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()):
+ print(f"! [Numerical Error] {k} contains nan or inf.")
+
+ return ret
+
+def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
+ """Transforms model's predictions to semantically meaningful values.
+ Args:
+ raw: [num_rays, num_samples along ray, 4]. Prediction from model.
+ z_vals: [num_rays, num_samples along ray]. Integration time.
+ rays_d: [num_rays, 3]. Direction of each ray.
+ Returns:
+ rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
+ disp_map: [num_rays]. Disparity map. Inverse of depth map.
+ acc_map: [num_rays]. Sum of weights along each ray.
+ weights: [num_rays, num_samples]. Weights assigned to each sampled color.
+ depth_map: [num_rays]. Estimated distance to object.
+ """
+ #ipdb.set_trace()
+ act_ff=nn.Softplus()
+
+ raw2alpha = lambda raw, dists, act_fn=act_ff: 1.-torch.exp(-act_fn(raw)*dists)
+
+ dists = z_vals[...,1:] - z_vals[...,:-1]
+ dists = torch.cat([dists, torch.Tensor([1e10]).to(dists.device).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples]
+
+ dists = dists * torch.norm(rays_d[...,None,:], dim=-1)
+
+ rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3]
+ noise = 0.
+ if raw_noise_std > 0.:
+ noise = torch.randn(raw[...,3].shape) * raw_noise_std
+
+ # Overwrite randomly sampled data if pytest
+ if pytest:
+ np.random.seed(0)
+ noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
+ noise = torch.Tensor(noise)
+ #ipdb.set_trace()
+ alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples]
+
+ #ipdb.set_trace()
+ weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0],alpha.shape[1], 1)).to(alpha.device), 1.-alpha + 1e-10], -1), -1)[:,:, :-1]
+ rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3]
+
+ depth_map = torch.sum(weights * z_vals, -1)
+ disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
+ acc_map = torch.sum(weights, -1)
+
+ if white_bkgd:
+ rgb_map = rgb_map + (1.-acc_map[...,None])
+
+ return rgb_map, disp_map, acc_map, weights, depth_map
+
+
+def get_rays(H, W, K, c2w):
+ i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij'
+ i = i.t()
+ j = j.t()
+ dirs = torch.stack([(i-K[0][2])/K[0][0], (j-K[1][2])/K[1][1], torch.ones_like(i)], -1)
+ # Rotate ray directions from camera frame to the world frame
+ rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
+ # Translate camera frame's origin to the world frame. It is the origin of all rays.
+ rays_o = c2w[:3,-1].expand(rays_d.shape)
+ return rays_o, rays_d
diff --git a/app.py b/app.py
index fde396687d61ef4293c054f93797c08a8999f5b5..38997b1fc168785c5f63bf22b842939aba5de306 100644
--- a/app.py
+++ b/app.py
@@ -1,43 +1,25 @@
-print("Start importing everything.")
-
import pytorch_lightning as pl
-print("pl")
-
import os
import sys
import cv2
import time
import json
-print("Start importing everything.")
import torch
-print("Start importing everything.")
import mcubes
-print("Start importing everything.")
import trimesh
-print("Start importing everything.")
import datetime
import argparse
import subprocess
import numpy as np
-print("Start importing everything.")
import gradio as gr
from tqdm import tqdm
-print("Start importing everything.")
import imageio.v2 as imageio
-print("Start importing everything.")
from omegaconf import OmegaConf
-print("Start importing everything.")
from safetensors.torch import load_file
-print("Start importing everything.")
from huggingface_hub import hf_hub_download
-print("Importing everything done.")
-
-os.system("git clone https://github.com/3DTopia/3DTopia.git")
sys.path.append("3DTopia")
-print("Github clone done.")
-
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler