Jahnavibh commited on
Commit
74db07f
·
1 Parent(s): e2ab956

Upload 133 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. One-2-3-45-master 2/.DS_Store +0 -0
  2. One-2-3-45-master 2/.gitattributes +35 -0
  3. One-2-3-45-master 2/.gitignore +11 -0
  4. One-2-3-45-master 2/LICENSE +201 -0
  5. One-2-3-45-master 2/README.md +221 -0
  6. One-2-3-45-master 2/configs/sd-objaverse-finetune-c_concat-256.yaml +117 -0
  7. One-2-3-45-master 2/demo/.DS_Store +0 -0
  8. One-2-3-45-master 2/demo/.gitattributes +36 -0
  9. One-2-3-45-master 2/demo/.gitignore +4 -0
  10. One-2-3-45-master 2/demo/app.py +639 -0
  11. One-2-3-45-master 2/demo/demo_tmp/.gitignore +1 -0
  12. One-2-3-45-master 2/demo/demo_tmp/.gitkeep +0 -0
  13. One-2-3-45-master 2/demo/instructions_12345.md +10 -0
  14. One-2-3-45-master 2/demo/memora/.gitattributes +35 -0
  15. One-2-3-45-master 2/demo/memora/README.md +12 -0
  16. One-2-3-45-master 2/demo/style.css +33 -0
  17. One-2-3-45-master 2/download_ckpt.py +30 -0
  18. One-2-3-45-master 2/elevation_estimate/.gitignore +3 -0
  19. One-2-3-45-master 2/elevation_estimate/__init__.py +0 -0
  20. One-2-3-45-master 2/elevation_estimate/estimate_wild_imgs.py +10 -0
  21. One-2-3-45-master 2/elevation_estimate/loftr/__init__.py +2 -0
  22. One-2-3-45-master 2/elevation_estimate/loftr/backbone/__init__.py +11 -0
  23. One-2-3-45-master 2/elevation_estimate/loftr/backbone/resnet_fpn.py +199 -0
  24. One-2-3-45-master 2/elevation_estimate/loftr/loftr.py +81 -0
  25. One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/__init__.py +2 -0
  26. One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/fine_preprocess.py +59 -0
  27. One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/linear_attention.py +81 -0
  28. One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/transformer.py +101 -0
  29. One-2-3-45-master 2/elevation_estimate/loftr/utils/coarse_matching.py +261 -0
  30. One-2-3-45-master 2/elevation_estimate/loftr/utils/cvpr_ds_config.py +50 -0
  31. One-2-3-45-master 2/elevation_estimate/loftr/utils/fine_matching.py +74 -0
  32. One-2-3-45-master 2/elevation_estimate/loftr/utils/geometry.py +54 -0
  33. One-2-3-45-master 2/elevation_estimate/loftr/utils/position_encoding.py +42 -0
  34. One-2-3-45-master 2/elevation_estimate/loftr/utils/supervision.py +151 -0
  35. One-2-3-45-master 2/elevation_estimate/pyproject.toml +7 -0
  36. One-2-3-45-master 2/elevation_estimate/utils/__init__.py +0 -0
  37. One-2-3-45-master 2/elevation_estimate/utils/elev_est_api.py +205 -0
  38. One-2-3-45-master 2/elevation_estimate/utils/plotting.py +154 -0
  39. One-2-3-45-master 2/elevation_estimate/utils/plt_utils.py +318 -0
  40. One-2-3-45-master 2/elevation_estimate/utils/utils3d.py +62 -0
  41. One-2-3-45-master 2/elevation_estimate/utils/weights/.gitkeep +0 -0
  42. One-2-3-45-master 2/example.ipynb +0 -0
  43. One-2-3-45-master 2/ldm/data/__init__.py +0 -0
  44. One-2-3-45-master 2/ldm/data/base.py +40 -0
  45. One-2-3-45-master 2/ldm/data/coco.py +253 -0
  46. One-2-3-45-master 2/ldm/data/dummy.py +34 -0
  47. One-2-3-45-master 2/ldm/data/imagenet.py +394 -0
  48. One-2-3-45-master 2/ldm/data/inpainting/__init__.py +0 -0
  49. One-2-3-45-master 2/ldm/data/inpainting/synthetic_mask.py +166 -0
  50. One-2-3-45-master 2/ldm/data/laion.py +537 -0
One-2-3-45-master 2/.DS_Store ADDED
Binary file (6.15 kB). View file
 
One-2-3-45-master 2/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
One-2-3-45-master 2/.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ exp/
3
+ src/
4
+ *.DS_Store
5
+ *.ipynb
6
+ *.egg-info/
7
+ *.ckpt
8
+ *.pth
9
+
10
+ !example.ipynb
11
+ !reconstruction/exp
One-2-3-45-master 2/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
One-2-3-45-master 2/README.md ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center" width="100%">
2
+ <img src="https://github.com/Dustinpro/Dustinpro/assets/23076389/0fbdb69a-0fb4-4b42-b9da-e0b28532bdfd" width="80%" height="80%">
3
+ </p>
4
+
5
+
6
+ <p align="center">
7
+ [<a href="https://arxiv.org/pdf/2306.16928.pdf"><strong>Paper</strong></a>]
8
+ [<a href="http://one-2-3-45.com"><strong>Project</strong></a>]
9
+ [<a href="https://huggingface.co/spaces/One-2-3-45/One-2-3-45"><strong>Demo</strong></a>]
10
+ [<a href="#citation"><strong>BibTeX</strong></a>]
11
+ </p>
12
+
13
+ <p align="center">
14
+ <a href="https://huggingface.co/spaces/One-2-3-45/One-2-3-45">
15
+ <img alt="Hugging Face Spaces" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space_of_the_Week_%F0%9F%94%A5-blue">
16
+ </a>
17
+ </p>
18
+
19
+ One-2-3-45 rethinks how to leverage 2D diffusion models for 3D AIGC and introduces a novel forward-only paradigm that avoids the time-consuming optimization.
20
+
21
+ https://github.com/One-2-3-45/One-2-3-45/assets/16759292/a81d6e32-8d29-43a5-b044-b5112b9f9664
22
+
23
+
24
+
25
+ https://github.com/One-2-3-45/One-2-3-45/assets/16759292/5ecd45ef-8fd3-4643-af4c-fac3050a0428
26
+
27
+
28
+ ## News
29
+ **[09/21/2023]**
30
+ One-2-3-45 is accepted by NeurIPS 2023. See you in New Orleans!
31
+
32
+ **[09/11/2023]**
33
+ Training code released.
34
+
35
+ **[08/18/2023]**
36
+ Inference code released.
37
+
38
+ **[07/24/2023]**
39
+ Our demo reached the HuggingFace top 4 trending and was featured in 🤗 Spaces of the Week 🔥! Special thanks to HuggingFace 🤗 for sponsoring this demo!!
40
+
41
+ **[07/11/2023]**
42
+ [Online interactive demo](https://huggingface.co/spaces/One-2-3-45/One-2-3-45) released! Explore it and create your own 3D models in just 45 seconds!
43
+
44
+ **[06/29/2023]**
45
+ Check out our [paper](https://arxiv.org/pdf/2306.16928.pdf). [[X](https://twitter.com/_akhaliq/status/1674617785119305728)]
46
+
47
+ ## Installation
48
+ Hardware requirement: an NVIDIA GPU with memory >=18GB (_e.g._, RTX 3090 or A10). Tested on Ubuntu.
49
+
50
+ We offer two ways to setup the environment:
51
+
52
+ ### Traditional Installation
53
+ <details>
54
+ <summary>Step 1: Install Debian packages. </summary>
55
+
56
+ ```bash
57
+ sudo apt update && sudo apt install git-lfs libsparsehash-dev build-essential
58
+ ```
59
+ </details>
60
+
61
+ <details>
62
+ <summary>Step 2: Create and activate a conda environment. </summary>
63
+
64
+ ```bash
65
+ conda create -n One2345 python=3.10
66
+ conda activate One2345
67
+ ```
68
+ </details>
69
+
70
+ <details>
71
+ <summary>Step 3: Clone the repository to the local machine. </summary>
72
+
73
+ ```bash
74
+ # Make sure you have git-lfs installed.
75
+ git lfs install
76
+ git clone https://github.com/One-2-3-45/One-2-3-45
77
+ cd One-2-3-45
78
+ ```
79
+ </details>
80
+
81
+ <details>
82
+ <summary>Step 4: Install project dependencies using pip. </summary>
83
+
84
+ ```bash
85
+ # Ensure that the installed CUDA version matches the torch's cuda version.
86
+ # Example: CUDA 11.8 installation
87
+ wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
88
+ sudo sh cuda_11.8.0_520.61.05_linux.run
89
+ export PATH="/usr/local/cuda-11.8/bin:$PATH"
90
+ export LD_LIBRARY_PATH="/usr/local/cuda-11.8/lib64:$LD_LIBRARY_PATH"
91
+ # Install PyTorch 2.0
92
+ pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
93
+ # Install dependencies
94
+ pip install -r requirements.txt
95
+ # Install inplace_abn and torchsparse
96
+ export TORCH_CUDA_ARCH_LIST="7.0;7.2;8.0;8.6+PTX" # CUDA architectures. Modify according to your hardware.
97
+ export IABN_FORCE_CUDA=1
98
+ pip install inplace_abn
99
+ FORCE_CUDA=1 pip install --no-cache-dir git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0
100
+ ```
101
+ </details>
102
+
103
+ <details>
104
+ <summary>Step 5: Download model checkpoints. </summary>
105
+
106
+ ```bash
107
+ python download_ckpt.py
108
+ ```
109
+ </details>
110
+
111
+
112
+ ### Installation by Docker Images
113
+ <details>
114
+ <summary>Option 1: Pull and Play (environment and checkpoints). (~22.3G)</summary>
115
+
116
+ ```bash
117
+ # Pull the Docker image that contains the full repository.
118
+ docker pull chaoxu98/one2345:demo_1.0
119
+ # An interactive demo will be launched automatically upon running the container.
120
+ # This will provide a public URL like XXXXXXX.gradio.live
121
+ docker run --name One-2-3-45_demo --gpus all -it chaoxu98/one2345:demo_1.0
122
+ ```
123
+ </details>
124
+
125
+ <details>
126
+ <summary>Option 2: Environment Only. (~7.3G)</summary>
127
+
128
+ ```bash
129
+ # Pull the Docker image that installed all project dependencies.
130
+ docker pull chaoxu98/one2345:1.0
131
+ # Start a Docker container named One2345.
132
+ docker run --name One-2-3-45 --gpus all -it chaoxu98/one2345:1.0
133
+ # Get a bash shell in the container.
134
+ docker exec -it One-2-3-45 /bin/bash
135
+ # Clone the repository to the local machine.
136
+ git clone https://github.com/One-2-3-45/One-2-3-45
137
+ cd One-2-3-45
138
+ # Download model checkpoints.
139
+ python download_ckpt.py
140
+ # Refer to getting started for inference.
141
+ ```
142
+ </details>
143
+
144
+ ## Getting Started (Inference)
145
+
146
+ First-time running will take longer time to compile the models.
147
+
148
+ Expected time cost per image: 40s on an NVIDIA A6000.
149
+ ```bash
150
+ # 1. Script
151
+ python run.py --img_path PATH_TO_INPUT_IMG --half_precision
152
+
153
+ # 2. Interactive demo (Gradio) with a friendly web interface
154
+ # An URL will be provided in the output
155
+ # (Local: 127.0.0.1:7860; Public: XXXXXXX.gradio.live)
156
+ cd demo/
157
+ python app.py
158
+
159
+ # 3. Jupyter Notebook
160
+ example.ipynb
161
+ ```
162
+
163
+ ## Training Your Own Model
164
+
165
+ ### Data Preparation
166
+ We use Objaverse-LVIS dataset for training and render the selected shapes (with CC-BY license) into 2D images with Blender.
167
+ #### Download the training images.
168
+ Download all One2345.zip.part-* files (5 files in total) from <a href="https://huggingface.co/datasets/One-2-3-45/training_data/tree/main">here</a> and then cat them into a single .zip file using the following command:
169
+ ```bash
170
+ cat One2345.zip.part-* > One2345.zip
171
+ ```
172
+
173
+ #### Unzip the training images zip file.
174
+ Unzip the zip file into a folder specified by yourself (`YOUR_BASE_FOLDER`) with the following command:
175
+
176
+ ```bash
177
+ unzip One2345.zip -d YOUR_BASE_FOLDER
178
+ ```
179
+
180
+ #### Download meta files.
181
+
182
+ Download `One2345_training_pose.json` and `lvis_split_cc_by.json` from <a href="https://huggingface.co/datasets/One-2-3-45/training_data/tree/main">here</a> and put them into the same folder as the training images (`YOUR_BASE_FOLDER`).
183
+
184
+ Your file structure should look like this:
185
+ ```
186
+ # One2345 is your base folder used in the previous steps
187
+
188
+ One2345
189
+ ├── One2345_training_pose.json
190
+ ├── lvis_split_cc_by.json
191
+ └── zero12345_narrow
192
+ ├── 000-000
193
+ ├── 000-001
194
+ ├── 000-002
195
+ ...
196
+ └── 000-159
197
+
198
+ ```
199
+
200
+ ### Training
201
+ Specify the `trainpath`, `valpath`, and `testpath` in the config file `./reconstruction/confs/one2345_lod_train.conf` to be `YOUR_BASE_FOLDER` used in data preparation steps and run the following command:
202
+ ```bash
203
+ cd reconstruction
204
+ python exp_runner_generic_blender_train.py --mode train --conf confs/one2345_lod_train.conf
205
+ ```
206
+ Experiment logs and checkpoints will be saved in `./reconstruction/exp/`.
207
+
208
+ ## Citation
209
+
210
+ If you find our code helpful, please cite our paper:
211
+
212
+ ```
213
+ @misc{liu2023one2345,
214
+ title={One-2-3-45: Any Single Image to 3D Mesh in 45 Seconds without Per-Shape Optimization},
215
+ author={Minghua Liu and Chao Xu and Haian Jin and Linghao Chen and Mukund Varma T and Zexiang Xu and Hao Su},
216
+ year={2023},
217
+ eprint={2306.16928},
218
+ archivePrefix={arXiv},
219
+ primaryClass={cs.CV}
220
+ }
221
+ ```
One-2-3-45-master 2/configs/sd-objaverse-finetune-c_concat-256.yaml ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "image_target"
11
+ cond_stage_key: "image_cond"
12
+ image_size: 32
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: hybrid
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+
19
+ scheduler_config: # 10000 warmup steps
20
+ target: ldm.lr_scheduler.LambdaLinearScheduler
21
+ params:
22
+ warm_up_steps: [ 100 ]
23
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
24
+ f_start: [ 1.e-6 ]
25
+ f_max: [ 1. ]
26
+ f_min: [ 1. ]
27
+
28
+ unet_config:
29
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
30
+ params:
31
+ image_size: 32 # unused
32
+ in_channels: 8
33
+ out_channels: 4
34
+ model_channels: 320
35
+ attention_resolutions: [ 4, 2, 1 ]
36
+ num_res_blocks: 2
37
+ channel_mult: [ 1, 2, 4, 4 ]
38
+ num_heads: 8
39
+ use_spatial_transformer: True
40
+ transformer_depth: 1
41
+ context_dim: 768
42
+ use_checkpoint: True
43
+ legacy: False
44
+
45
+ first_stage_config:
46
+ target: ldm.models.autoencoder.AutoencoderKL
47
+ params:
48
+ embed_dim: 4
49
+ monitor: val/rec_loss
50
+ ddconfig:
51
+ double_z: true
52
+ z_channels: 4
53
+ resolution: 256
54
+ in_channels: 3
55
+ out_ch: 3
56
+ ch: 128
57
+ ch_mult:
58
+ - 1
59
+ - 2
60
+ - 4
61
+ - 4
62
+ num_res_blocks: 2
63
+ attn_resolutions: []
64
+ dropout: 0.0
65
+ lossconfig:
66
+ target: torch.nn.Identity
67
+
68
+ cond_stage_config:
69
+ target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder
70
+
71
+
72
+ data:
73
+ target: ldm.data.simple.ObjaverseDataModuleFromConfig
74
+ params:
75
+ root_dir: 'views_whole_sphere'
76
+ batch_size: 192
77
+ num_workers: 16
78
+ total_view: 4
79
+ train:
80
+ validation: False
81
+ image_transforms:
82
+ size: 256
83
+
84
+ validation:
85
+ validation: True
86
+ image_transforms:
87
+ size: 256
88
+
89
+
90
+ lightning:
91
+ find_unused_parameters: false
92
+ metrics_over_trainsteps_checkpoint: True
93
+ modelcheckpoint:
94
+ params:
95
+ every_n_train_steps: 5000
96
+ callbacks:
97
+ image_logger:
98
+ target: main.ImageLogger
99
+ params:
100
+ batch_frequency: 500
101
+ max_images: 32
102
+ increase_log_steps: False
103
+ log_first_step: True
104
+ log_images_kwargs:
105
+ use_ema_scope: False
106
+ inpaint: False
107
+ plot_progressive_rows: False
108
+ plot_diffusion_rows: False
109
+ N: 32
110
+ unconditional_guidance_scale: 3.0
111
+ unconditional_guidance_label: [""]
112
+
113
+ trainer:
114
+ benchmark: True
115
+ val_check_interval: 5000000 # really sorry
116
+ num_sanity_val_steps: 0
117
+ accumulate_grad_batches: 1
One-2-3-45-master 2/demo/.DS_Store ADDED
Binary file (6.15 kB). View file
 
One-2-3-45-master 2/demo/.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
One-2-3-45-master 2/demo/.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ weights/
2
+ data/
3
+ *.ipynb
4
+ demo_examples_*
One-2-3-45-master 2/demo/app.py ADDED
@@ -0,0 +1,639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import shutil
4
+ import torch
5
+ import fire
6
+ import gradio as gr
7
+ import numpy as np
8
+ import cv2
9
+ from PIL import Image
10
+ import plotly.graph_objects as go
11
+ from functools import partial
12
+ import trimesh
13
+ import tempfile
14
+ from rembg import remove
15
+
16
+ code_dir = "../"
17
+ sys.path.append(code_dir)
18
+ from utils.zero123_utils import init_model, predict_stage1_gradio, zero123_infer
19
+ from utils.sam_utils import sam_init, sam_out_nosave
20
+ from utils.utils import image_preprocess_nosave, gen_poses
21
+ from elevation_estimate.estimate_wild_imgs import estimate_elev
22
+
23
+ _GPU_INDEX = 0
24
+ _HALF_PRECISION = True
25
+ _MESH_RESOLUTION = 256
26
+
27
+ _TITLE = '''One-2-3-45: Any Single Image to 3D Mesh in 45 Seconds without Per-Shape Optimization'''
28
+ _DESCRIPTION = '''
29
+ <div>
30
+ <a style="display:inline-block" href="http://one-2-3-45.com"><img src="https://img.shields.io/badge/Project_Homepage-f9f7f7?logo="></a>
31
+ <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2306.16928"><img src="https://img.shields.io/badge/2306.16928-f9f7f7?logo="></a>
32
+ <a style="display:inline-block; margin-left: .5em" href='https://github.com/One-2-3-45/One-2-3-45'><img src='https://img.shields.io/github/stars/One-2-3-45/One-2-3-45?style=social' /></a>
33
+ </div>
34
+ We reconstruct a 3D textured mesh from a single image by initially predicting multi-view images and then lifting them to 3D.
35
+ '''
36
+ _USER_GUIDE = "Please upload an image in the block above (or choose an example above) and click **Run Generation**."
37
+ _BBOX_1 = "Predicting bounding box for the input image..."
38
+ _BBOX_2 = "Bounding box adjusted. Continue adjusting or **Run Generation**."
39
+ _BBOX_3 = "Bounding box predicted. Adjust it using sliders or **Run Generation**."
40
+ _SAM = "Preprocessing the input image... (safety check, SAM segmentation, *etc*.)"
41
+ _GEN_1 = "Predicting multi-view images... (may take \~13 seconds) <br> Images will be shown in the bottom right blocks."
42
+ _GEN_2 = "Predicting nearby views and generating mesh... (may take \~33 seconds) <br> Mesh will be shown on the right."
43
+ _DONE = "Done! Mesh is shown on the right. <br> If it is not satisfactory, please select **Retry view** checkboxes for inaccurate views and click **Regenerate selected view(s)** at the bottom."
44
+ _REGEN_1 = "Selected view(s) are regenerated. You can click **Regenerate nearby views and mesh**. <br> Alternatively, if the regenerated view(s) are still not satisfactory, you can repeat the previous step (select the view and regenerate)."
45
+ _REGEN_2 = "Regeneration done. Mesh is shown on the right."
46
+
47
+
48
+ def calc_cam_cone_pts_3d(polar_deg, azimuth_deg, radius_m, fov_deg):
49
+ '''
50
+ :param polar_deg (float).
51
+ :param azimuth_deg (float).
52
+ :param radius_m (float).
53
+ :param fov_deg (float).
54
+ :return (5, 3) array of float with (x, y, z).
55
+ '''
56
+ polar_rad = np.deg2rad(polar_deg)
57
+ azimuth_rad = np.deg2rad(azimuth_deg)
58
+ fov_rad = np.deg2rad(fov_deg)
59
+ polar_rad = -polar_rad # NOTE: Inverse of how used_x relates to x.
60
+
61
+ # Camera pose center:
62
+ cam_x = radius_m * np.cos(azimuth_rad) * np.cos(polar_rad)
63
+ cam_y = radius_m * np.sin(azimuth_rad) * np.cos(polar_rad)
64
+ cam_z = radius_m * np.sin(polar_rad)
65
+
66
+ # Obtain four corners of camera frustum, assuming it is looking at origin.
67
+ # First, obtain camera extrinsics (rotation matrix only):
68
+ camera_R = np.array([[np.cos(azimuth_rad) * np.cos(polar_rad),
69
+ -np.sin(azimuth_rad),
70
+ -np.cos(azimuth_rad) * np.sin(polar_rad)],
71
+ [np.sin(azimuth_rad) * np.cos(polar_rad),
72
+ np.cos(azimuth_rad),
73
+ -np.sin(azimuth_rad) * np.sin(polar_rad)],
74
+ [np.sin(polar_rad),
75
+ 0.0,
76
+ np.cos(polar_rad)]])
77
+
78
+ # Multiply by corners in camera space to obtain go to space:
79
+ corn1 = [-1.0, np.tan(fov_rad / 2.0), np.tan(fov_rad / 2.0)]
80
+ corn2 = [-1.0, -np.tan(fov_rad / 2.0), np.tan(fov_rad / 2.0)]
81
+ corn3 = [-1.0, -np.tan(fov_rad / 2.0), -np.tan(fov_rad / 2.0)]
82
+ corn4 = [-1.0, np.tan(fov_rad / 2.0), -np.tan(fov_rad / 2.0)]
83
+ corn1 = np.dot(camera_R, corn1)
84
+ corn2 = np.dot(camera_R, corn2)
85
+ corn3 = np.dot(camera_R, corn3)
86
+ corn4 = np.dot(camera_R, corn4)
87
+
88
+ # Now attach as offset to actual 3D camera position:
89
+ corn1 = np.array(corn1) / np.linalg.norm(corn1, ord=2)
90
+ corn_x1 = cam_x + corn1[0]
91
+ corn_y1 = cam_y + corn1[1]
92
+ corn_z1 = cam_z + corn1[2]
93
+ corn2 = np.array(corn2) / np.linalg.norm(corn2, ord=2)
94
+ corn_x2 = cam_x + corn2[0]
95
+ corn_y2 = cam_y + corn2[1]
96
+ corn_z2 = cam_z + corn2[2]
97
+ corn3 = np.array(corn3) / np.linalg.norm(corn3, ord=2)
98
+ corn_x3 = cam_x + corn3[0]
99
+ corn_y3 = cam_y + corn3[1]
100
+ corn_z3 = cam_z + corn3[2]
101
+ corn4 = np.array(corn4) / np.linalg.norm(corn4, ord=2)
102
+ corn_x4 = cam_x + corn4[0]
103
+ corn_y4 = cam_y + corn4[1]
104
+ corn_z4 = cam_z + corn4[2]
105
+
106
+ xs = [cam_x, corn_x1, corn_x2, corn_x3, corn_x4]
107
+ ys = [cam_y, corn_y1, corn_y2, corn_y3, corn_y4]
108
+ zs = [cam_z, corn_z1, corn_z2, corn_z3, corn_z4]
109
+
110
+ return np.array([xs, ys, zs]).T
111
+
112
+ class CameraVisualizer:
113
+ def __init__(self, gradio_plot):
114
+ self._gradio_plot = gradio_plot
115
+ self._fig = None
116
+ self._polar = 0.0
117
+ self._azimuth = 0.0
118
+ self._radius = 0.0
119
+ self._raw_image = None
120
+ self._8bit_image = None
121
+ self._image_colorscale = None
122
+
123
+ def encode_image(self, raw_image, elev=90):
124
+ '''
125
+ :param raw_image (H, W, 3) array of uint8 in [0, 255].
126
+ '''
127
+ # https://stackoverflow.com/questions/60685749/python-plotly-how-to-add-an-image-to-a-3d-scatter-plot
128
+
129
+ dum_img = Image.fromarray(np.ones((3, 3, 3), dtype='uint8')).convert('P', palette='WEB')
130
+ idx_to_color = np.array(dum_img.getpalette()).reshape((-1, 3))
131
+
132
+ self._raw_image = raw_image
133
+ self._8bit_image = Image.fromarray(raw_image).convert('P', palette='WEB', dither=None)
134
+ # self._8bit_image = Image.fromarray(raw_image.clip(0, 254)).convert(
135
+ # 'P', palette='WEB', dither=None)
136
+ self._image_colorscale = [
137
+ [i / 255.0, 'rgb({}, {}, {})'.format(*rgb)] for i, rgb in enumerate(idx_to_color)]
138
+ self._elev = elev
139
+ # return self.update_figure()
140
+
141
+ def update_figure(self):
142
+ fig = go.Figure()
143
+
144
+ if self._raw_image is not None:
145
+ (H, W, C) = self._raw_image.shape
146
+
147
+ x = np.zeros((H, W))
148
+ (y, z) = np.meshgrid(np.linspace(-1.0, 1.0, W), np.linspace(1.0, -1.0, H) * H / W)
149
+
150
+ angle_deg = self._elev-90
151
+ angle = np.radians(90-self._elev)
152
+ rotation_matrix = np.array([
153
+ [np.cos(angle), 0, np.sin(angle)],
154
+ [0, 1, 0],
155
+ [-np.sin(angle), 0, np.cos(angle)]
156
+ ])
157
+ # Assuming x, y, z are the original 3D coordinates of the image
158
+ coordinates = np.stack((x, y, z), axis=-1) # Combine x, y, z into a single array
159
+ # Apply the rotation matrix
160
+ rotated_coordinates = np.matmul(coordinates, rotation_matrix)
161
+ # Extract the new x, y, z coordinates from the rotated coordinates
162
+ x, y, z = rotated_coordinates[..., 0], rotated_coordinates[..., 1], rotated_coordinates[..., 2]
163
+
164
+ fig.add_trace(go.Surface(
165
+ x=x, y=y, z=z,
166
+ surfacecolor=self._8bit_image,
167
+ cmin=0,
168
+ cmax=255,
169
+ colorscale=self._image_colorscale,
170
+ showscale=False,
171
+ lighting_diffuse=1.0,
172
+ lighting_ambient=1.0,
173
+ lighting_fresnel=1.0,
174
+ lighting_roughness=1.0,
175
+ lighting_specular=0.3))
176
+
177
+ scene_bounds = 3.5
178
+ base_radius = 2.5
179
+ zoom_scale = 1.5 # Note that input radius offset is in [-0.5, 0.5].
180
+ fov_deg = 50.0
181
+ edges = [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4), (4, 1)]
182
+
183
+ input_cone = calc_cam_cone_pts_3d(
184
+ angle_deg, 0.0, base_radius, fov_deg) # (5, 3).
185
+ output_cone = calc_cam_cone_pts_3d(
186
+ self._polar, self._azimuth, base_radius + self._radius * zoom_scale, fov_deg) # (5, 3).
187
+ output_cones = []
188
+ for i in range(1,4):
189
+ output_cones.append(calc_cam_cone_pts_3d(
190
+ angle_deg, i*90, base_radius + self._radius * zoom_scale, fov_deg))
191
+ delta_deg = 30 if angle_deg <= -15 else -30
192
+ for i in range(4):
193
+ output_cones.append(calc_cam_cone_pts_3d(
194
+ angle_deg+delta_deg, 30+i*90, base_radius + self._radius * zoom_scale, fov_deg))
195
+
196
+ cones = [(input_cone, 'rgb(174, 54, 75)', 'Input view (Predicted view 1)')]
197
+ for i in range(len(output_cones)):
198
+ cones.append((output_cones[i], 'rgb(32, 77, 125)', f'Predicted view {i+2}'))
199
+
200
+ for idx, (cone, clr, legend) in enumerate(cones):
201
+
202
+ for (i, edge) in enumerate(edges):
203
+ (x1, x2) = (cone[edge[0], 0], cone[edge[1], 0])
204
+ (y1, y2) = (cone[edge[0], 1], cone[edge[1], 1])
205
+ (z1, z2) = (cone[edge[0], 2], cone[edge[1], 2])
206
+ fig.add_trace(go.Scatter3d(
207
+ x=[x1, x2], y=[y1, y2], z=[z1, z2], mode='lines',
208
+ line=dict(color=clr, width=3),
209
+ name=legend, showlegend=(i == 1) and (idx <= 1)))
210
+
211
+ # Add label.
212
+ if cone[0, 2] <= base_radius / 2.0:
213
+ fig.add_trace(go.Scatter3d(
214
+ x=[cone[0, 0]], y=[cone[0, 1]], z=[cone[0, 2] - 0.05], showlegend=False,
215
+ mode='text', text=legend, textposition='bottom center'))
216
+ else:
217
+ fig.add_trace(go.Scatter3d(
218
+ x=[cone[0, 0]], y=[cone[0, 1]], z=[cone[0, 2] + 0.05], showlegend=False,
219
+ mode='text', text=legend, textposition='top center'))
220
+
221
+ # look at center of scene
222
+ fig.update_layout(
223
+ # width=640,
224
+ # height=480,
225
+ # height=400,
226
+ height=450,
227
+ autosize=True,
228
+ hovermode=False,
229
+ margin=go.layout.Margin(l=0, r=0, b=0, t=0),
230
+ showlegend=False,
231
+ legend=dict(
232
+ yanchor='bottom',
233
+ y=0.01,
234
+ xanchor='right',
235
+ x=0.99,
236
+ ),
237
+ scene=dict(
238
+ aspectmode='manual',
239
+ aspectratio=dict(x=1, y=1, z=1.0),
240
+ camera=dict(
241
+ eye=dict(x=base_radius - 1.6, y=0.0, z=0.6),
242
+ center=dict(x=0.0, y=0.0, z=0.0),
243
+ up=dict(x=0.0, y=0.0, z=1.0)),
244
+ xaxis_title='',
245
+ yaxis_title='',
246
+ zaxis_title='',
247
+ xaxis=dict(
248
+ range=[-scene_bounds, scene_bounds],
249
+ showticklabels=False,
250
+ showgrid=True,
251
+ zeroline=False,
252
+ showbackground=True,
253
+ showspikes=False,
254
+ showline=False,
255
+ ticks=''),
256
+ yaxis=dict(
257
+ range=[-scene_bounds, scene_bounds],
258
+ showticklabels=False,
259
+ showgrid=True,
260
+ zeroline=False,
261
+ showbackground=True,
262
+ showspikes=False,
263
+ showline=False,
264
+ ticks=''),
265
+ zaxis=dict(
266
+ range=[-scene_bounds, scene_bounds],
267
+ showticklabels=False,
268
+ showgrid=True,
269
+ zeroline=False,
270
+ showbackground=True,
271
+ showspikes=False,
272
+ showline=False,
273
+ ticks='')))
274
+
275
+ self._fig = fig
276
+ return fig
277
+
278
+
279
+ def stage1_run(models, device, cam_vis, tmp_dir,
280
+ input_im, scale, ddim_steps, elev=None, rerun_all=[],
281
+ *btn_retrys):
282
+ is_rerun = True if cam_vis is None else False
283
+ model = models['turncam']
284
+
285
+ stage1_dir = os.path.join(tmp_dir, "stage1_8")
286
+ if not is_rerun:
287
+ os.makedirs(stage1_dir, exist_ok=True)
288
+ output_ims = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=list(range(4)), device=device, ddim_steps=ddim_steps, scale=scale)
289
+ stage2_steps = 50 # ddim_steps
290
+ zero123_infer(model, tmp_dir, indices=[0], device=device, ddim_steps=stage2_steps, scale=scale)
291
+ try:
292
+ elev_output = estimate_elev(tmp_dir)
293
+ except:
294
+ print("Failed to estimate polar angle")
295
+ elev_output = 90
296
+ print("Estimated polar angle:", elev_output)
297
+ gen_poses(tmp_dir, elev_output)
298
+ show_in_im1 = np.asarray(input_im, dtype=np.uint8)
299
+ cam_vis.encode_image(show_in_im1, elev=elev_output)
300
+ new_fig = cam_vis.update_figure()
301
+
302
+ flag_lower_cam = elev_output <= 75
303
+ if flag_lower_cam:
304
+ output_ims_2 = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=list(range(4,8)), device=device, ddim_steps=ddim_steps, scale=scale)
305
+ else:
306
+ output_ims_2 = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=list(range(8,12)), device=device, ddim_steps=ddim_steps, scale=scale)
307
+ torch.cuda.empty_cache()
308
+ return (90-elev_output, new_fig, *output_ims, *output_ims_2)
309
+ else:
310
+ rerun_idx = [i for i in range(len(btn_retrys)) if btn_retrys[i]]
311
+ if 90-int(elev["label"]) > 75:
312
+ rerun_idx_in = [i if i < 4 else i+4 for i in rerun_idx]
313
+ else:
314
+ rerun_idx_in = rerun_idx
315
+ for idx in rerun_idx_in:
316
+ if idx not in rerun_all:
317
+ rerun_all.append(idx)
318
+ print("rerun_idx", rerun_all)
319
+ output_ims = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=rerun_idx_in, device=device, ddim_steps=ddim_steps, scale=scale)
320
+ outputs = [gr.update(visible=True)] * 8
321
+ for idx, view_idx in enumerate(rerun_idx):
322
+ outputs[view_idx] = output_ims[idx]
323
+ reset = [gr.update(value=False)] * 8
324
+ torch.cuda.empty_cache()
325
+ return (rerun_all, *reset, *outputs)
326
+
327
+ def stage2_run(models, device, tmp_dir,
328
+ elev, scale, is_glb=False, rerun_all=[], stage2_steps=50):
329
+ flag_lower_cam = 90-int(elev["label"]) <= 75
330
+ is_rerun = True if rerun_all else False
331
+ model = models['turncam']
332
+ if not is_rerun:
333
+ if flag_lower_cam:
334
+ zero123_infer(model, tmp_dir, indices=list(range(1,8)), device=device, ddim_steps=stage2_steps, scale=scale)
335
+ else:
336
+ zero123_infer(model, tmp_dir, indices=list(range(1,4))+list(range(8,12)), device=device, ddim_steps=stage2_steps, scale=scale)
337
+ else:
338
+ print("rerun_idx", rerun_all)
339
+ zero123_infer(model, tmp_dir, indices=rerun_all, device=device, ddim_steps=stage2_steps, scale=scale)
340
+
341
+ dataset = tmp_dir
342
+ main_dir_path = os.path.dirname(__file__)
343
+ torch.cuda.empty_cache()
344
+ os.chdir(os.path.join(code_dir, 'reconstruction/'))
345
+
346
+ bash_script = f'CUDA_VISIBLE_DEVICES={_GPU_INDEX} python exp_runner_generic_blender_val.py \
347
+ --specific_dataset_name {dataset} \
348
+ --mode export_mesh \
349
+ --conf confs/one2345_lod0_val_demo.conf \
350
+ --resolution {_MESH_RESOLUTION}'
351
+ print(bash_script)
352
+ os.system(bash_script)
353
+ os.chdir(main_dir_path)
354
+
355
+ ply_path = os.path.join(tmp_dir, f"mesh.ply")
356
+ mesh_ext = ".glb" if is_glb else ".obj"
357
+ mesh_path = os.path.join(tmp_dir, f"mesh{mesh_ext}")
358
+ # Read the textured mesh from .ply file
359
+ mesh = trimesh.load_mesh(ply_path)
360
+ rotation_matrix = trimesh.transformations.rotation_matrix(np.pi/2, [1, 0, 0])
361
+ mesh.apply_transform(rotation_matrix)
362
+ rotation_matrix = trimesh.transformations.rotation_matrix(np.pi, [0, 0, 1])
363
+ mesh.apply_transform(rotation_matrix)
364
+ # flip x
365
+ mesh.vertices[:, 0] = -mesh.vertices[:, 0]
366
+ mesh.faces = np.fliplr(mesh.faces)
367
+ # Export the mesh as .obj file with colors
368
+ if not is_glb:
369
+ mesh.export(mesh_path, file_type='obj', include_color=True)
370
+ else:
371
+ mesh.export(mesh_path, file_type='glb')
372
+ torch.cuda.empty_cache()
373
+
374
+ if not is_rerun:
375
+ return (mesh_path)
376
+ else:
377
+ return (mesh_path, gr.update(value=[]), gr.update(visible=False), gr.update(visible=False))
378
+
379
+ def nsfw_check(models, raw_im, device='cuda'):
380
+ safety_checker_input = models['clip_fe'](raw_im, return_tensors='pt').to(device)
381
+ (_, has_nsfw_concept) = models['nsfw'](
382
+ images=np.ones((1, 3)), clip_input=safety_checker_input.pixel_values)
383
+ del safety_checker_input
384
+ if np.any(has_nsfw_concept):
385
+ print('NSFW content detected.')
386
+ return Image.open("unsafe.png")
387
+ else:
388
+ print('Safety check passed.')
389
+ return False
390
+
391
+ def preprocess_run(predictor, models, raw_im, lower_contrast, *bbox_sliders):
392
+ raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS)
393
+ check_results = nsfw_check(models, raw_im, device=predictor.device)
394
+ if check_results:
395
+ return check_results
396
+ image_sam = sam_out_nosave(predictor, raw_im.convert("RGB"), *bbox_sliders)
397
+ input_256 = image_preprocess_nosave(image_sam, lower_contrast=lower_contrast, rescale=True)
398
+ torch.cuda.empty_cache()
399
+ return input_256
400
+
401
+ def on_coords_slider(image, x_min, y_min, x_max, y_max, color=(88, 191, 131, 255)):
402
+ """Draw a bounding box annotation for an image."""
403
+ print("Slider adjusted, drawing bbox...")
404
+ image.thumbnail([512, 512], Image.Resampling.LANCZOS)
405
+ image_size = image.size
406
+ if max(image_size) > 224:
407
+ image.thumbnail([224, 224], Image.Resampling.LANCZOS)
408
+ shrink_ratio = max(image.size) / max(image_size)
409
+ x_min = int(x_min * shrink_ratio)
410
+ y_min = int(y_min * shrink_ratio)
411
+ x_max = int(x_max * shrink_ratio)
412
+ y_max = int(y_max * shrink_ratio)
413
+ image = cv2.cvtColor(np.array(image), cv2.COLOR_RGBA2BGRA)
414
+ image = cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, int(max(max(image.shape) / 400*2, 2)))
415
+ return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) # image[:, :, ::-1]
416
+
417
+ def init_bbox(image):
418
+ image.thumbnail([512, 512], Image.Resampling.LANCZOS)
419
+ width, height = image.size
420
+ image_rem = image.convert('RGBA')
421
+ image_nobg = remove(image_rem, alpha_matting=True)
422
+ arr = np.asarray(image_nobg)[:,:,-1]
423
+ x_nonzero = np.nonzero(arr.sum(axis=0))
424
+ y_nonzero = np.nonzero(arr.sum(axis=1))
425
+ x_min = int(x_nonzero[0].min())
426
+ y_min = int(y_nonzero[0].min())
427
+ x_max = int(x_nonzero[0].max())
428
+ y_max = int(y_nonzero[0].max())
429
+ image_mini = image.copy()
430
+ image_mini.thumbnail([224, 224], Image.Resampling.LANCZOS)
431
+ shrink_ratio = max(image_mini.size) / max(width, height)
432
+ x_min_shrink = int(x_min * shrink_ratio)
433
+ y_min_shrink = int(y_min * shrink_ratio)
434
+ x_max_shrink = int(x_max * shrink_ratio)
435
+ y_max_shrink = int(y_max * shrink_ratio)
436
+
437
+ return [on_coords_slider(image_mini, x_min_shrink, y_min_shrink, x_max_shrink, y_max_shrink),
438
+ gr.update(value=x_min, maximum=width),
439
+ gr.update(value=y_min, maximum=height),
440
+ gr.update(value=x_max, maximum=width),
441
+ gr.update(value=y_max, maximum=height)]
442
+
443
+
444
+ def run_demo(
445
+ device_idx=_GPU_INDEX,
446
+ ckpt='zero123-xl.ckpt'):
447
+
448
+ device = f"cuda:{device_idx}" if torch.cuda.is_available() else "cpu"
449
+ models = init_model(device, os.path.join(code_dir, 'zero123-xl.ckpt'), half_precision=_HALF_PRECISION)
450
+
451
+ # init sam model
452
+ predictor = sam_init(device_idx)
453
+
454
+ with open('instructions_12345.md', 'r') as f:
455
+ article = f.read()
456
+
457
+ # NOTE: Examples must match inputs
458
+ example_folder = os.path.join(os.path.dirname(__file__), 'demo_examples')
459
+ example_fns = os.listdir(example_folder)
460
+ example_fns.sort()
461
+ examples_full = [os.path.join(example_folder, x) for x in example_fns if x.endswith('.png')]
462
+
463
+ # Compose demo layout & data flow.
464
+ with gr.Blocks(title=_TITLE, css="style.css") as demo:
465
+ with gr.Row():
466
+ with gr.Column(scale=1):
467
+ gr.Markdown('# ' + _TITLE)
468
+ with gr.Column(scale=0):
469
+ gr.DuplicateButton(value='Duplicate Space for private use',
470
+ elem_id='duplicate-button')
471
+ gr.Markdown(_DESCRIPTION)
472
+
473
+ with gr.Row(variant='panel'):
474
+ with gr.Column(scale=1.2):
475
+ image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image', tool=None)
476
+
477
+ gr.Examples(
478
+ examples=examples_full, # NOTE: elements must match inputs list!
479
+ inputs=[image_block],
480
+ outputs=[image_block],
481
+ cache_examples=False,
482
+ label='Examples (click one of the images below to start)',
483
+ examples_per_page=40
484
+ )
485
+ preprocess_chk = gr.Checkbox(
486
+ False, label='Reduce image contrast (mitigate shadows on the backside)')
487
+ with gr.Accordion('Advanced options', open=False):
488
+ scale_slider = gr.Slider(0, 30, value=3, step=1,
489
+ label='Diffusion guidance scale')
490
+ steps_slider = gr.Slider(5, 200, value=75, step=5,
491
+ label='Number of diffusion inference steps')
492
+ glb_chk = gr.Checkbox(
493
+ False, label='Export the mesh in .glb format')
494
+
495
+ run_btn = gr.Button('Run Generation', variant='primary', interactive=False)
496
+ guide_text = gr.Markdown(_USER_GUIDE, visible=True)
497
+
498
+ with gr.Column(scale=.8):
499
+ with gr.Row():
500
+ bbox_block = gr.Image(type='pil', label="Bounding box", height=290, interactive=False)
501
+ sam_block = gr.Image(type='pil', label="SAM output", interactive=False)
502
+ max_width = max_height = 256
503
+ with gr.Row():
504
+ x_min_slider = gr.Slider(label="X min", interactive=True, value=0, minimum=0, maximum=max_width, step=1)
505
+ y_min_slider = gr.Slider(label="Y min", interactive=True, value=0, minimum=0, maximum=max_height, step=1)
506
+ with gr.Row():
507
+ x_max_slider = gr.Slider(label="X max", interactive=True, value=max_width, minimum=0, maximum=max_width, step=1)
508
+ y_max_slider = gr.Slider(label="Y max", interactive=True, value=max_height, minimum=0, maximum=max_height, step=1)
509
+ bbox_sliders = [x_min_slider, y_min_slider, x_max_slider, y_max_slider]
510
+
511
+ mesh_output = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="One-2-3-45's Textured Mesh", elem_id="model-3d-out")
512
+
513
+ with gr.Row(variant='panel'):
514
+ with gr.Column(scale=0.85):
515
+ elev_output = gr.Label(label='Estimated elevation (degree, w.r.t. the horizontal plane)')
516
+ vis_output = gr.Plot(label='Camera poses of the input view (red) and predicted views (blue)', elem_id="plot-out")
517
+
518
+ with gr.Column(scale=1.15):
519
+ gr.Markdown('Predicted multi-view images')
520
+ with gr.Row():
521
+ view_1 = gr.Image(interactive=False, height=200, show_label=False)
522
+ view_2 = gr.Image(interactive=False, height=200, show_label=False)
523
+ view_3 = gr.Image(interactive=False, height=200, show_label=False)
524
+ view_4 = gr.Image(interactive=False, height=200, show_label=False)
525
+ with gr.Row():
526
+ btn_retry_1 = gr.Checkbox(label='Retry view 1')
527
+ btn_retry_2 = gr.Checkbox(label='Retry view 2')
528
+ btn_retry_3 = gr.Checkbox(label='Retry view 3')
529
+ btn_retry_4 = gr.Checkbox(label='Retry view 4')
530
+ with gr.Row():
531
+ view_5 = gr.Image(interactive=False, height=200, show_label=False)
532
+ view_6 = gr.Image(interactive=False, height=200, show_label=False)
533
+ view_7 = gr.Image(interactive=False, height=200, show_label=False)
534
+ view_8 = gr.Image(interactive=False, height=200, show_label=False)
535
+ with gr.Row():
536
+ btn_retry_5 = gr.Checkbox(label='Retry view 5')
537
+ btn_retry_6 = gr.Checkbox(label='Retry view 6')
538
+ btn_retry_7 = gr.Checkbox(label='Retry view 7')
539
+ btn_retry_8 = gr.Checkbox(label='Retry view 8')
540
+ with gr.Row():
541
+ regen_view_btn = gr.Button('1. Regenerate selected view(s)', variant='secondary', visible=False)
542
+ regen_mesh_btn = gr.Button('2. Regenerate nearby views and mesh', variant='secondary', visible=False)
543
+
544
+ gr.Markdown(article)
545
+ gr.HTML("""
546
+ <div class="footer">
547
+ <p>
548
+ One-2-3-45 Demo by <a style="text-decoration:none" href="https://chaoxu.xyz" target="_blank">Chao Xu</a>
549
+ </p>
550
+ </div>
551
+ """)
552
+
553
+ update_guide = lambda GUIDE_TEXT: gr.update(value=GUIDE_TEXT)
554
+
555
+ views = [view_1, view_2, view_3, view_4, view_5, view_6, view_7, view_8]
556
+ btn_retrys = [btn_retry_1, btn_retry_2, btn_retry_3, btn_retry_4, btn_retry_5, btn_retry_6, btn_retry_7, btn_retry_8]
557
+
558
+ rerun_idx = gr.State([])
559
+ tmp_dir = gr.State('./demo_tmp/tmp_dir')
560
+
561
+ def refresh(tmp_dir):
562
+ if os.path.exists(tmp_dir):
563
+ shutil.rmtree(tmp_dir)
564
+ tmp_dir = tempfile.TemporaryDirectory(dir=os.path.join(os.path.dirname(__file__), 'demo_tmp'))
565
+ print("create tmp_dir", tmp_dir.name)
566
+ clear = [gr.update(value=[])] + [None] * 5 + [gr.update(visible=False)] * 2 + [None] * 8 + [gr.update(value=False)] * 8
567
+ return (tmp_dir.name, *clear)
568
+
569
+ placeholder = gr.Image(visible=False)
570
+ tmp_func = lambda x: False if not x else gr.update(visible=False)
571
+ disable_func = lambda x: gr.update(interactive=False)
572
+ enable_func = lambda x: gr.update(interactive=True)
573
+ image_block.change(disable_func, inputs=run_btn, outputs=run_btn, queue=False
574
+ ).success(fn=refresh,
575
+ inputs=[tmp_dir],
576
+ outputs=[tmp_dir, rerun_idx, bbox_block, sam_block, elev_output, vis_output, mesh_output, regen_view_btn, regen_mesh_btn, *views, *btn_retrys],
577
+ queue=False
578
+ ).success(fn=tmp_func, inputs=[image_block], outputs=[placeholder], queue=False
579
+ ).success(fn=partial(update_guide, _BBOX_1), outputs=[guide_text], queue=False
580
+ ).success(fn=init_bbox,
581
+ inputs=[image_block],
582
+ outputs=[bbox_block, *bbox_sliders], queue=False
583
+ ).success(fn=partial(update_guide, _BBOX_3), outputs=[guide_text], queue=False
584
+ ).success(enable_func, inputs=run_btn, outputs=run_btn, queue=False)
585
+
586
+
587
+ for bbox_slider in bbox_sliders:
588
+ bbox_slider.release(fn=on_coords_slider,
589
+ inputs=[image_block, *bbox_sliders],
590
+ outputs=[bbox_block],
591
+ queue=False
592
+ ).success(fn=partial(update_guide, _BBOX_2), outputs=[guide_text], queue=False)
593
+
594
+ cam_vis = CameraVisualizer(vis_output)
595
+
596
+ # Define the function to be called when any of the btn_retry buttons are clicked
597
+ def on_retry_button_click(*btn_retrys):
598
+ any_checked = any([btn_retry for btn_retry in btn_retrys])
599
+ print('any_checked:', any_checked, [btn_retry for btn_retry in btn_retrys])
600
+ if any_checked:
601
+ return (gr.update(visible=True), gr.update(visible=True))
602
+ else:
603
+ return (gr.update(), gr.update())
604
+ # make regen_btn visible when any of the btn_retry is checked
605
+ for btn_retry in btn_retrys:
606
+ # Add the event handlers to the btn_retry buttons
607
+ btn_retry.change(fn=on_retry_button_click, inputs=[*btn_retrys], outputs=[regen_view_btn, regen_mesh_btn], queue=False)
608
+
609
+
610
+ run_btn.click(fn=partial(update_guide, _SAM), outputs=[guide_text], queue=False
611
+ ).success(fn=partial(preprocess_run, predictor, models),
612
+ inputs=[image_block, preprocess_chk, *bbox_sliders],
613
+ outputs=[sam_block]
614
+ ).success(fn=partial(update_guide, _GEN_1), outputs=[guide_text], queue=False
615
+ ).success(fn=partial(stage1_run, models, device, cam_vis),
616
+ inputs=[tmp_dir, sam_block, scale_slider, steps_slider],
617
+ outputs=[elev_output, vis_output, *views]
618
+ ).success(fn=partial(update_guide, _GEN_2), outputs=[guide_text], queue=False
619
+ ).success(fn=partial(stage2_run, models, device),
620
+ inputs=[tmp_dir, elev_output, scale_slider, glb_chk],
621
+ outputs=[mesh_output]
622
+ ).success(fn=partial(update_guide, _DONE), outputs=[guide_text], queue=False)
623
+
624
+
625
+ regen_view_btn.click(fn=partial(stage1_run, models, device, None),
626
+ inputs=[tmp_dir, sam_block, scale_slider, steps_slider, elev_output, rerun_idx, *btn_retrys],
627
+ outputs=[rerun_idx, *btn_retrys, *views]
628
+ ).success(fn=partial(update_guide, _REGEN_1), outputs=[guide_text], queue=False)
629
+ regen_mesh_btn.click(fn=partial(stage2_run, models, device),
630
+ inputs=[tmp_dir, elev_output, scale_slider, glb_chk, rerun_idx],
631
+ outputs=[mesh_output, rerun_idx, regen_view_btn, regen_mesh_btn]
632
+ ).success(fn=partial(update_guide, _REGEN_2), outputs=[guide_text], queue=False)
633
+
634
+
635
+ demo.queue().launch(share=True, max_threads=80) # auth=("admin", os.environ['PASSWD'])
636
+
637
+
638
+ if __name__ == '__main__':
639
+ fire.Fire(run_demo)
One-2-3-45-master 2/demo/demo_tmp/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ tmp*
One-2-3-45-master 2/demo/demo_tmp/.gitkeep ADDED
File without changes
One-2-3-45-master 2/demo/instructions_12345.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Tuning Tips:
2
+
3
+ 1. The multi-view prediction module (Zero123) operates probabilistically. If some of the predicted views are not satisfactory, you may select and regenerate them.
4
+
5
+ 2. In “advanced options”, you can tune two parameters as in other common diffusion models:
6
+ - Diffusion Guidance Scale determines how much you want the model to respect the input information (input image + viewpoints). Increasing the scale typically results in better adherence, less diversity, and also higher image distortion.
7
+
8
+ - Number of diffusion inference steps controls the number of diffusion steps applied to generate each image. Generally, a higher value yields better results but with diminishing returns.
9
+
10
+ Enjoy creating your 3D asset!
One-2-3-45-master 2/demo/memora/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
One-2-3-45-master 2/demo/memora/README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Memora
3
+ emoji: 🐨
4
+ colorFrom: purple
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 3.47.1
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
One-2-3-45-master 2/demo/style.css ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #model-3d-out {
2
+ height: 400px;
3
+ }
4
+
5
+ #plot-out {
6
+ height: 450px;
7
+ }
8
+
9
+ #duplicate-button {
10
+ margin-left: auto;
11
+ color: #fff;
12
+ background: #1565c0;
13
+ }
14
+
15
+ .footer {
16
+ margin-bottom: 45px;
17
+ margin-top: 10px;
18
+ text-align: center;
19
+ border-bottom: 1px solid #e5e5e5;
20
+ }
21
+ .footer>p {
22
+ font-size: .8rem;
23
+ display: inline-block;
24
+ padding: 0 10px;
25
+ transform: translateY(10px);
26
+ background: white;
27
+ }
28
+ .dark .footer {
29
+ border-color: #303030;
30
+ }
31
+ .dark .footer>p {
32
+ background: #0b0f19;
33
+ }
One-2-3-45-master 2/download_ckpt.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import urllib.request
2
+ from tqdm import tqdm
3
+
4
+ def download_checkpoint(url, save_path):
5
+ try:
6
+ with urllib.request.urlopen(url) as response, open(save_path, 'wb') as file:
7
+ file_size = int(response.info().get('Content-Length', -1))
8
+ chunk_size = 8192
9
+ num_chunks = file_size // chunk_size if file_size > chunk_size else 1
10
+
11
+ with tqdm(total=file_size, unit='B', unit_scale=True, desc='Downloading', ncols=100) as pbar:
12
+ for chunk in iter(lambda: response.read(chunk_size), b''):
13
+ file.write(chunk)
14
+ pbar.update(len(chunk))
15
+
16
+ print(f"Checkpoint downloaded and saved to: {save_path}")
17
+ except Exception as e:
18
+ print(f"Error downloading checkpoint: {e}")
19
+
20
+ if __name__ == "__main__":
21
+ ckpts = {
22
+ "sam_vit_h_4b8939.pth": "https://huggingface.co/One-2-3-45/code/resolve/main/sam_vit_h_4b8939.pth",
23
+ "zero123-xl.ckpt": "https://huggingface.co/One-2-3-45/code/resolve/main/zero123-xl.ckpt",
24
+ "elevation_estimate/utils/weights/indoor_ds_new.ckpt" : "https://huggingface.co/One-2-3-45/code/resolve/main/one2345_elev_est/tools/weights/indoor_ds_new.ckpt",
25
+ "reconstruction/exp/lod0/checkpoints/ckpt_215000.pth": "https://huggingface.co/One-2-3-45/code/resolve/main/SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_215000.pth"
26
+ }
27
+ for ckpt_name, ckpt_url in ckpts.items():
28
+ print(f"Downloading checkpoint: {ckpt_name}")
29
+ download_checkpoint(ckpt_url, ckpt_name)
30
+
One-2-3-45-master 2/elevation_estimate/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ build/
2
+ .idea/
3
+ *.egg-info/
One-2-3-45-master 2/elevation_estimate/__init__.py ADDED
File without changes
One-2-3-45-master 2/elevation_estimate/estimate_wild_imgs.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ from .utils.elev_est_api import elev_est_api
3
+
4
+ def estimate_elev(root_dir):
5
+ img_dir = osp.join(root_dir, "stage2_8")
6
+ img_paths = []
7
+ for i in range(4):
8
+ img_paths.append(f"{img_dir}/0_{i}.png")
9
+ elev = elev_est_api(img_paths)
10
+ return elev
One-2-3-45-master 2/elevation_estimate/loftr/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .loftr import LoFTR
2
+ from .utils.cvpr_ds_config import default_cfg
One-2-3-45-master 2/elevation_estimate/loftr/backbone/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4
2
+
3
+
4
+ def build_backbone(config):
5
+ if config['backbone_type'] == 'ResNetFPN':
6
+ if config['resolution'] == (8, 2):
7
+ return ResNetFPN_8_2(config['resnetfpn'])
8
+ elif config['resolution'] == (16, 4):
9
+ return ResNetFPN_16_4(config['resnetfpn'])
10
+ else:
11
+ raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
One-2-3-45-master 2/elevation_estimate/loftr/backbone/resnet_fpn.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def conv1x1(in_planes, out_planes, stride=1):
6
+ """1x1 convolution without padding"""
7
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
8
+
9
+
10
+ def conv3x3(in_planes, out_planes, stride=1):
11
+ """3x3 convolution with padding"""
12
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
13
+
14
+
15
+ class BasicBlock(nn.Module):
16
+ def __init__(self, in_planes, planes, stride=1):
17
+ super().__init__()
18
+ self.conv1 = conv3x3(in_planes, planes, stride)
19
+ self.conv2 = conv3x3(planes, planes)
20
+ self.bn1 = nn.BatchNorm2d(planes)
21
+ self.bn2 = nn.BatchNorm2d(planes)
22
+ self.relu = nn.ReLU(inplace=True)
23
+
24
+ if stride == 1:
25
+ self.downsample = None
26
+ else:
27
+ self.downsample = nn.Sequential(
28
+ conv1x1(in_planes, planes, stride=stride),
29
+ nn.BatchNorm2d(planes)
30
+ )
31
+
32
+ def forward(self, x):
33
+ y = x
34
+ y = self.relu(self.bn1(self.conv1(y)))
35
+ y = self.bn2(self.conv2(y))
36
+
37
+ if self.downsample is not None:
38
+ x = self.downsample(x)
39
+
40
+ return self.relu(x+y)
41
+
42
+
43
+ class ResNetFPN_8_2(nn.Module):
44
+ """
45
+ ResNet+FPN, output resolution are 1/8 and 1/2.
46
+ Each block has 2 layers.
47
+ """
48
+
49
+ def __init__(self, config):
50
+ super().__init__()
51
+ # Config
52
+ block = BasicBlock
53
+ initial_dim = config['initial_dim']
54
+ block_dims = config['block_dims']
55
+
56
+ # Class Variable
57
+ self.in_planes = initial_dim
58
+
59
+ # Networks
60
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
61
+ self.bn1 = nn.BatchNorm2d(initial_dim)
62
+ self.relu = nn.ReLU(inplace=True)
63
+
64
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
65
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
66
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
67
+
68
+ # 3. FPN upsample
69
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
70
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
71
+ self.layer2_outconv2 = nn.Sequential(
72
+ conv3x3(block_dims[2], block_dims[2]),
73
+ nn.BatchNorm2d(block_dims[2]),
74
+ nn.LeakyReLU(),
75
+ conv3x3(block_dims[2], block_dims[1]),
76
+ )
77
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
78
+ self.layer1_outconv2 = nn.Sequential(
79
+ conv3x3(block_dims[1], block_dims[1]),
80
+ nn.BatchNorm2d(block_dims[1]),
81
+ nn.LeakyReLU(),
82
+ conv3x3(block_dims[1], block_dims[0]),
83
+ )
84
+
85
+ for m in self.modules():
86
+ if isinstance(m, nn.Conv2d):
87
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
88
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
89
+ nn.init.constant_(m.weight, 1)
90
+ nn.init.constant_(m.bias, 0)
91
+
92
+ def _make_layer(self, block, dim, stride=1):
93
+ layer1 = block(self.in_planes, dim, stride=stride)
94
+ layer2 = block(dim, dim, stride=1)
95
+ layers = (layer1, layer2)
96
+
97
+ self.in_planes = dim
98
+ return nn.Sequential(*layers)
99
+
100
+ def forward(self, x):
101
+ # ResNet Backbone
102
+ x0 = self.relu(self.bn1(self.conv1(x)))
103
+ x1 = self.layer1(x0) # 1/2
104
+ x2 = self.layer2(x1) # 1/4
105
+ x3 = self.layer3(x2) # 1/8
106
+
107
+ # FPN
108
+ x3_out = self.layer3_outconv(x3)
109
+
110
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
111
+ x2_out = self.layer2_outconv(x2)
112
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
113
+
114
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
115
+ x1_out = self.layer1_outconv(x1)
116
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
117
+
118
+ return [x3_out, x1_out]
119
+
120
+
121
+ class ResNetFPN_16_4(nn.Module):
122
+ """
123
+ ResNet+FPN, output resolution are 1/16 and 1/4.
124
+ Each block has 2 layers.
125
+ """
126
+
127
+ def __init__(self, config):
128
+ super().__init__()
129
+ # Config
130
+ block = BasicBlock
131
+ initial_dim = config['initial_dim']
132
+ block_dims = config['block_dims']
133
+
134
+ # Class Variable
135
+ self.in_planes = initial_dim
136
+
137
+ # Networks
138
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
139
+ self.bn1 = nn.BatchNorm2d(initial_dim)
140
+ self.relu = nn.ReLU(inplace=True)
141
+
142
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
143
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
144
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
145
+ self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16
146
+
147
+ # 3. FPN upsample
148
+ self.layer4_outconv = conv1x1(block_dims[3], block_dims[3])
149
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[3])
150
+ self.layer3_outconv2 = nn.Sequential(
151
+ conv3x3(block_dims[3], block_dims[3]),
152
+ nn.BatchNorm2d(block_dims[3]),
153
+ nn.LeakyReLU(),
154
+ conv3x3(block_dims[3], block_dims[2]),
155
+ )
156
+
157
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
158
+ self.layer2_outconv2 = nn.Sequential(
159
+ conv3x3(block_dims[2], block_dims[2]),
160
+ nn.BatchNorm2d(block_dims[2]),
161
+ nn.LeakyReLU(),
162
+ conv3x3(block_dims[2], block_dims[1]),
163
+ )
164
+
165
+ for m in self.modules():
166
+ if isinstance(m, nn.Conv2d):
167
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
168
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
169
+ nn.init.constant_(m.weight, 1)
170
+ nn.init.constant_(m.bias, 0)
171
+
172
+ def _make_layer(self, block, dim, stride=1):
173
+ layer1 = block(self.in_planes, dim, stride=stride)
174
+ layer2 = block(dim, dim, stride=1)
175
+ layers = (layer1, layer2)
176
+
177
+ self.in_planes = dim
178
+ return nn.Sequential(*layers)
179
+
180
+ def forward(self, x):
181
+ # ResNet Backbone
182
+ x0 = self.relu(self.bn1(self.conv1(x)))
183
+ x1 = self.layer1(x0) # 1/2
184
+ x2 = self.layer2(x1) # 1/4
185
+ x3 = self.layer3(x2) # 1/8
186
+ x4 = self.layer4(x3) # 1/16
187
+
188
+ # FPN
189
+ x4_out = self.layer4_outconv(x4)
190
+
191
+ x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True)
192
+ x3_out = self.layer3_outconv(x3)
193
+ x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
194
+
195
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
196
+ x2_out = self.layer2_outconv(x2)
197
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
198
+
199
+ return [x4_out, x2_out]
One-2-3-45-master 2/elevation_estimate/loftr/loftr.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops.einops import rearrange
4
+
5
+ from .backbone import build_backbone
6
+ from .utils.position_encoding import PositionEncodingSine
7
+ from .loftr_module import LocalFeatureTransformer, FinePreprocess
8
+ from .utils.coarse_matching import CoarseMatching
9
+ from .utils.fine_matching import FineMatching
10
+
11
+
12
+ class LoFTR(nn.Module):
13
+ def __init__(self, config):
14
+ super().__init__()
15
+ # Misc
16
+ self.config = config
17
+
18
+ # Modules
19
+ self.backbone = build_backbone(config)
20
+ self.pos_encoding = PositionEncodingSine(
21
+ config['coarse']['d_model'],
22
+ temp_bug_fix=config['coarse']['temp_bug_fix'])
23
+ self.loftr_coarse = LocalFeatureTransformer(config['coarse'])
24
+ self.coarse_matching = CoarseMatching(config['match_coarse'])
25
+ self.fine_preprocess = FinePreprocess(config)
26
+ self.loftr_fine = LocalFeatureTransformer(config["fine"])
27
+ self.fine_matching = FineMatching()
28
+
29
+ def forward(self, data):
30
+ """
31
+ Update:
32
+ data (dict): {
33
+ 'image0': (torch.Tensor): (N, 1, H, W)
34
+ 'image1': (torch.Tensor): (N, 1, H, W)
35
+ 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
36
+ 'mask1'(optional) : (torch.Tensor): (N, H, W)
37
+ }
38
+ """
39
+ # 1. Local Feature CNN
40
+ data.update({
41
+ 'bs': data['image0'].size(0),
42
+ 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
43
+ })
44
+
45
+ if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
46
+ feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
47
+ (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
48
+ else: # handle different input shapes
49
+ (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])
50
+
51
+ data.update({
52
+ 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
53
+ 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
54
+ })
55
+
56
+ # 2. coarse-level loftr module
57
+ # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
58
+ feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
59
+ feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')
60
+
61
+ mask_c0 = mask_c1 = None # mask is useful in training
62
+ if 'mask0' in data:
63
+ mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
64
+ feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
65
+
66
+ # 3. match coarse-level
67
+ self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)
68
+
69
+ # 4. fine-level refinement
70
+ feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
71
+ if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
72
+ feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
73
+
74
+ # 5. match fine-level
75
+ self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
76
+
77
+ def load_state_dict(self, state_dict, *args, **kwargs):
78
+ for k in list(state_dict.keys()):
79
+ if k.startswith('matcher.'):
80
+ state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
81
+ return super().load_state_dict(state_dict, *args, **kwargs)
One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .transformer import LocalFeatureTransformer
2
+ from .fine_preprocess import FinePreprocess
One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/fine_preprocess.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops.einops import rearrange, repeat
5
+
6
+
7
+ class FinePreprocess(nn.Module):
8
+ def __init__(self, config):
9
+ super().__init__()
10
+
11
+ self.config = config
12
+ self.cat_c_feat = config['fine_concat_coarse_feat']
13
+ self.W = self.config['fine_window_size']
14
+
15
+ d_model_c = self.config['coarse']['d_model']
16
+ d_model_f = self.config['fine']['d_model']
17
+ self.d_model_f = d_model_f
18
+ if self.cat_c_feat:
19
+ self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
20
+ self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
21
+
22
+ self._reset_parameters()
23
+
24
+ def _reset_parameters(self):
25
+ for p in self.parameters():
26
+ if p.dim() > 1:
27
+ nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")
28
+
29
+ def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
30
+ W = self.W
31
+ stride = data['hw0_f'][0] // data['hw0_c'][0]
32
+
33
+ data.update({'W': W})
34
+ if data['b_ids'].shape[0] == 0:
35
+ feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
36
+ feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
37
+ return feat0, feat1
38
+
39
+ # 1. unfold(crop) all local windows
40
+ feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
41
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
42
+ feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
43
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
44
+
45
+ # 2. select only the predicted matches
46
+ feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
47
+ feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
48
+
49
+ # option: use coarse-level loftr feature as context: concat and linear
50
+ if self.cat_c_feat:
51
+ feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
52
+ feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
53
+ feat_cf_win = self.merge_feat(torch.cat([
54
+ torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
55
+ repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
56
+ ], -1))
57
+ feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
58
+
59
+ return feat_f0_unfold, feat_f1_unfold
One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/linear_attention.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
3
+ Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
4
+ """
5
+
6
+ import torch
7
+ from torch.nn import Module, Dropout
8
+
9
+
10
+ def elu_feature_map(x):
11
+ return torch.nn.functional.elu(x) + 1
12
+
13
+
14
+ class LinearAttention(Module):
15
+ def __init__(self, eps=1e-6):
16
+ super().__init__()
17
+ self.feature_map = elu_feature_map
18
+ self.eps = eps
19
+
20
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
21
+ """ Multi-Head linear attention proposed in "Transformers are RNNs"
22
+ Args:
23
+ queries: [N, L, H, D]
24
+ keys: [N, S, H, D]
25
+ values: [N, S, H, D]
26
+ q_mask: [N, L]
27
+ kv_mask: [N, S]
28
+ Returns:
29
+ queried_values: (N, L, H, D)
30
+ """
31
+ Q = self.feature_map(queries)
32
+ K = self.feature_map(keys)
33
+
34
+ # set padded position to zero
35
+ if q_mask is not None:
36
+ Q = Q * q_mask[:, :, None, None]
37
+ if kv_mask is not None:
38
+ K = K * kv_mask[:, :, None, None]
39
+ values = values * kv_mask[:, :, None, None]
40
+
41
+ v_length = values.size(1)
42
+ values = values / v_length # prevent fp16 overflow
43
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
44
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
45
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
46
+
47
+ return queried_values.contiguous()
48
+
49
+
50
+ class FullAttention(Module):
51
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
52
+ super().__init__()
53
+ self.use_dropout = use_dropout
54
+ self.dropout = Dropout(attention_dropout)
55
+
56
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
57
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
58
+ Args:
59
+ queries: [N, L, H, D]
60
+ keys: [N, S, H, D]
61
+ values: [N, S, H, D]
62
+ q_mask: [N, L]
63
+ kv_mask: [N, S]
64
+ Returns:
65
+ queried_values: (N, L, H, D)
66
+ """
67
+
68
+ # Compute the unnormalized attention and apply the masks
69
+ QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
70
+ if kv_mask is not None:
71
+ QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
72
+
73
+ # Compute the attention and the weighted average
74
+ softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
75
+ A = torch.softmax(softmax_temp * QK, dim=2)
76
+ if self.use_dropout:
77
+ A = self.dropout(A)
78
+
79
+ queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
80
+
81
+ return queried_values.contiguous()
One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/transformer.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import torch
3
+ import torch.nn as nn
4
+ from .linear_attention import LinearAttention, FullAttention
5
+
6
+
7
+ class LoFTREncoderLayer(nn.Module):
8
+ def __init__(self,
9
+ d_model,
10
+ nhead,
11
+ attention='linear'):
12
+ super(LoFTREncoderLayer, self).__init__()
13
+
14
+ self.dim = d_model // nhead
15
+ self.nhead = nhead
16
+
17
+ # multi-head attention
18
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
19
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
20
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
21
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
22
+ self.merge = nn.Linear(d_model, d_model, bias=False)
23
+
24
+ # feed-forward network
25
+ self.mlp = nn.Sequential(
26
+ nn.Linear(d_model*2, d_model*2, bias=False),
27
+ nn.ReLU(True),
28
+ nn.Linear(d_model*2, d_model, bias=False),
29
+ )
30
+
31
+ # norm and dropout
32
+ self.norm1 = nn.LayerNorm(d_model)
33
+ self.norm2 = nn.LayerNorm(d_model)
34
+
35
+ def forward(self, x, source, x_mask=None, source_mask=None):
36
+ """
37
+ Args:
38
+ x (torch.Tensor): [N, L, C]
39
+ source (torch.Tensor): [N, S, C]
40
+ x_mask (torch.Tensor): [N, L] (optional)
41
+ source_mask (torch.Tensor): [N, S] (optional)
42
+ """
43
+ bs = x.size(0)
44
+ query, key, value = x, source, source
45
+
46
+ # multi-head attention
47
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
48
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
49
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
50
+ message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
51
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
52
+ message = self.norm1(message)
53
+
54
+ # feed-forward network
55
+ message = self.mlp(torch.cat([x, message], dim=2))
56
+ message = self.norm2(message)
57
+
58
+ return x + message
59
+
60
+
61
+ class LocalFeatureTransformer(nn.Module):
62
+ """A Local Feature Transformer (LoFTR) module."""
63
+
64
+ def __init__(self, config):
65
+ super(LocalFeatureTransformer, self).__init__()
66
+
67
+ self.config = config
68
+ self.d_model = config['d_model']
69
+ self.nhead = config['nhead']
70
+ self.layer_names = config['layer_names']
71
+ encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
72
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
73
+ self._reset_parameters()
74
+
75
+ def _reset_parameters(self):
76
+ for p in self.parameters():
77
+ if p.dim() > 1:
78
+ nn.init.xavier_uniform_(p)
79
+
80
+ def forward(self, feat0, feat1, mask0=None, mask1=None):
81
+ """
82
+ Args:
83
+ feat0 (torch.Tensor): [N, L, C]
84
+ feat1 (torch.Tensor): [N, S, C]
85
+ mask0 (torch.Tensor): [N, L] (optional)
86
+ mask1 (torch.Tensor): [N, S] (optional)
87
+ """
88
+
89
+ assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
90
+
91
+ for layer, name in zip(self.layers, self.layer_names):
92
+ if name == 'self':
93
+ feat0 = layer(feat0, feat0, mask0, mask0)
94
+ feat1 = layer(feat1, feat1, mask1, mask1)
95
+ elif name == 'cross':
96
+ feat0 = layer(feat0, feat1, mask0, mask1)
97
+ feat1 = layer(feat1, feat0, mask1, mask0)
98
+ else:
99
+ raise KeyError
100
+
101
+ return feat0, feat1
One-2-3-45-master 2/elevation_estimate/loftr/utils/coarse_matching.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops.einops import rearrange
5
+
6
+ INF = 1e9
7
+
8
+ def mask_border(m, b: int, v):
9
+ """ Mask borders with value
10
+ Args:
11
+ m (torch.Tensor): [N, H0, W0, H1, W1]
12
+ b (int)
13
+ v (m.dtype)
14
+ """
15
+ if b <= 0:
16
+ return
17
+
18
+ m[:, :b] = v
19
+ m[:, :, :b] = v
20
+ m[:, :, :, :b] = v
21
+ m[:, :, :, :, :b] = v
22
+ m[:, -b:] = v
23
+ m[:, :, -b:] = v
24
+ m[:, :, :, -b:] = v
25
+ m[:, :, :, :, -b:] = v
26
+
27
+
28
+ def mask_border_with_padding(m, bd, v, p_m0, p_m1):
29
+ if bd <= 0:
30
+ return
31
+
32
+ m[:, :bd] = v
33
+ m[:, :, :bd] = v
34
+ m[:, :, :, :bd] = v
35
+ m[:, :, :, :, :bd] = v
36
+
37
+ h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
38
+ h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
39
+ for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
40
+ m[b_idx, h0 - bd:] = v
41
+ m[b_idx, :, w0 - bd:] = v
42
+ m[b_idx, :, :, h1 - bd:] = v
43
+ m[b_idx, :, :, :, w1 - bd:] = v
44
+
45
+
46
+ def compute_max_candidates(p_m0, p_m1):
47
+ """Compute the max candidates of all pairs within a batch
48
+
49
+ Args:
50
+ p_m0, p_m1 (torch.Tensor): padded masks
51
+ """
52
+ h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
53
+ h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
54
+ max_cand = torch.sum(
55
+ torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
56
+ return max_cand
57
+
58
+
59
+ class CoarseMatching(nn.Module):
60
+ def __init__(self, config):
61
+ super().__init__()
62
+ self.config = config
63
+ # general config
64
+ self.thr = config['thr']
65
+ self.border_rm = config['border_rm']
66
+ # -- # for trainig fine-level LoFTR
67
+ self.train_coarse_percent = config['train_coarse_percent']
68
+ self.train_pad_num_gt_min = config['train_pad_num_gt_min']
69
+
70
+ # we provide 2 options for differentiable matching
71
+ self.match_type = config['match_type']
72
+ if self.match_type == 'dual_softmax':
73
+ self.temperature = config['dsmax_temperature']
74
+ elif self.match_type == 'sinkhorn':
75
+ try:
76
+ from .superglue import log_optimal_transport
77
+ except ImportError:
78
+ raise ImportError("download superglue.py first!")
79
+ self.log_optimal_transport = log_optimal_transport
80
+ self.bin_score = nn.Parameter(
81
+ torch.tensor(config['skh_init_bin_score'], requires_grad=True))
82
+ self.skh_iters = config['skh_iters']
83
+ self.skh_prefilter = config['skh_prefilter']
84
+ else:
85
+ raise NotImplementedError()
86
+
87
+ def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
88
+ """
89
+ Args:
90
+ feat0 (torch.Tensor): [N, L, C]
91
+ feat1 (torch.Tensor): [N, S, C]
92
+ data (dict)
93
+ mask_c0 (torch.Tensor): [N, L] (optional)
94
+ mask_c1 (torch.Tensor): [N, S] (optional)
95
+ Update:
96
+ data (dict): {
97
+ 'b_ids' (torch.Tensor): [M'],
98
+ 'i_ids' (torch.Tensor): [M'],
99
+ 'j_ids' (torch.Tensor): [M'],
100
+ 'gt_mask' (torch.Tensor): [M'],
101
+ 'mkpts0_c' (torch.Tensor): [M, 2],
102
+ 'mkpts1_c' (torch.Tensor): [M, 2],
103
+ 'mconf' (torch.Tensor): [M]}
104
+ NOTE: M' != M during training.
105
+ """
106
+ N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
107
+
108
+ # normalize
109
+ feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
110
+ [feat_c0, feat_c1])
111
+
112
+ if self.match_type == 'dual_softmax':
113
+ sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
114
+ feat_c1) / self.temperature
115
+ if mask_c0 is not None:
116
+ sim_matrix.masked_fill_(
117
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
118
+ -INF)
119
+ conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
120
+
121
+ elif self.match_type == 'sinkhorn':
122
+ # sinkhorn, dustbin included
123
+ sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
124
+ if mask_c0 is not None:
125
+ sim_matrix[:, :L, :S].masked_fill_(
126
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
127
+ -INF)
128
+
129
+ # build uniform prior & use sinkhorn
130
+ log_assign_matrix = self.log_optimal_transport(
131
+ sim_matrix, self.bin_score, self.skh_iters)
132
+ assign_matrix = log_assign_matrix.exp()
133
+ conf_matrix = assign_matrix[:, :-1, :-1]
134
+
135
+ # filter prediction with dustbin score (only in evaluation mode)
136
+ if not self.training and self.skh_prefilter:
137
+ filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1] # [N, L]
138
+ filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1] # [N, S]
139
+ conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
140
+ conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0
141
+
142
+ if self.config['sparse_spvs']:
143
+ data.update({'conf_matrix_with_bin': assign_matrix.clone()})
144
+
145
+ data.update({'conf_matrix': conf_matrix})
146
+
147
+ # predict coarse matches from conf_matrix
148
+ data.update(**self.get_coarse_match(conf_matrix, data))
149
+
150
+ @torch.no_grad()
151
+ def get_coarse_match(self, conf_matrix, data):
152
+ """
153
+ Args:
154
+ conf_matrix (torch.Tensor): [N, L, S]
155
+ data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
156
+ Returns:
157
+ coarse_matches (dict): {
158
+ 'b_ids' (torch.Tensor): [M'],
159
+ 'i_ids' (torch.Tensor): [M'],
160
+ 'j_ids' (torch.Tensor): [M'],
161
+ 'gt_mask' (torch.Tensor): [M'],
162
+ 'm_bids' (torch.Tensor): [M],
163
+ 'mkpts0_c' (torch.Tensor): [M, 2],
164
+ 'mkpts1_c' (torch.Tensor): [M, 2],
165
+ 'mconf' (torch.Tensor): [M]}
166
+ """
167
+ axes_lengths = {
168
+ 'h0c': data['hw0_c'][0],
169
+ 'w0c': data['hw0_c'][1],
170
+ 'h1c': data['hw1_c'][0],
171
+ 'w1c': data['hw1_c'][1]
172
+ }
173
+ _device = conf_matrix.device
174
+ # 1. confidence thresholding
175
+ mask = conf_matrix > self.thr
176
+ mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
177
+ **axes_lengths)
178
+ if 'mask0' not in data:
179
+ mask_border(mask, self.border_rm, False)
180
+ else:
181
+ mask_border_with_padding(mask, self.border_rm, False,
182
+ data['mask0'], data['mask1'])
183
+ mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
184
+ **axes_lengths)
185
+
186
+ # 2. mutual nearest
187
+ mask = mask \
188
+ * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
189
+ * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
190
+
191
+ # 3. find all valid coarse matches
192
+ # this only works when at most one `True` in each row
193
+ mask_v, all_j_ids = mask.max(dim=2)
194
+ b_ids, i_ids = torch.where(mask_v)
195
+ j_ids = all_j_ids[b_ids, i_ids]
196
+ mconf = conf_matrix[b_ids, i_ids, j_ids]
197
+
198
+ # 4. Random sampling of training samples for fine-level LoFTR
199
+ # (optional) pad samples with gt coarse-level matches
200
+ if self.training:
201
+ # NOTE:
202
+ # The sampling is performed across all pairs in a batch without manually balancing
203
+ # #samples for fine-level increases w.r.t. batch_size
204
+ if 'mask0' not in data:
205
+ num_candidates_max = mask.size(0) * max(
206
+ mask.size(1), mask.size(2))
207
+ else:
208
+ num_candidates_max = compute_max_candidates(
209
+ data['mask0'], data['mask1'])
210
+ num_matches_train = int(num_candidates_max *
211
+ self.train_coarse_percent)
212
+ num_matches_pred = len(b_ids)
213
+ assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
214
+
215
+ # pred_indices is to select from prediction
216
+ if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
217
+ pred_indices = torch.arange(num_matches_pred, device=_device)
218
+ else:
219
+ pred_indices = torch.randint(
220
+ num_matches_pred,
221
+ (num_matches_train - self.train_pad_num_gt_min, ),
222
+ device=_device)
223
+
224
+ # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
225
+ gt_pad_indices = torch.randint(
226
+ len(data['spv_b_ids']),
227
+ (max(num_matches_train - num_matches_pred,
228
+ self.train_pad_num_gt_min), ),
229
+ device=_device)
230
+ mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero
231
+
232
+ b_ids, i_ids, j_ids, mconf = map(
233
+ lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
234
+ dim=0),
235
+ *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
236
+ [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
237
+
238
+ # These matches select patches that feed into fine-level network
239
+ coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
240
+
241
+ # 4. Update with matches in original image resolution
242
+ scale = data['hw0_i'][0] / data['hw0_c'][0]
243
+ scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
244
+ scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
245
+ mkpts0_c = torch.stack(
246
+ [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
247
+ dim=1) * scale0
248
+ mkpts1_c = torch.stack(
249
+ [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
250
+ dim=1) * scale1
251
+
252
+ # These matches is the current prediction (for visualization)
253
+ coarse_matches.update({
254
+ 'gt_mask': mconf == 0,
255
+ 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches
256
+ 'mkpts0_c': mkpts0_c[mconf != 0],
257
+ 'mkpts1_c': mkpts1_c[mconf != 0],
258
+ 'mconf': mconf[mconf != 0]
259
+ })
260
+
261
+ return coarse_matches
One-2-3-45-master 2/elevation_estimate/loftr/utils/cvpr_ds_config.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from yacs.config import CfgNode as CN
2
+
3
+
4
+ def lower_config(yacs_cfg):
5
+ if not isinstance(yacs_cfg, CN):
6
+ return yacs_cfg
7
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
8
+
9
+
10
+ _CN = CN()
11
+ _CN.BACKBONE_TYPE = 'ResNetFPN'
12
+ _CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
13
+ _CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
14
+ _CN.FINE_CONCAT_COARSE_FEAT = True
15
+
16
+ # 1. LoFTR-backbone (local feature CNN) config
17
+ _CN.RESNETFPN = CN()
18
+ _CN.RESNETFPN.INITIAL_DIM = 128
19
+ _CN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
20
+
21
+ # 2. LoFTR-coarse module config
22
+ _CN.COARSE = CN()
23
+ _CN.COARSE.D_MODEL = 256
24
+ _CN.COARSE.D_FFN = 256
25
+ _CN.COARSE.NHEAD = 8
26
+ _CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
27
+ _CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
28
+ _CN.COARSE.TEMP_BUG_FIX = False
29
+
30
+ # 3. Coarse-Matching config
31
+ _CN.MATCH_COARSE = CN()
32
+ _CN.MATCH_COARSE.THR = 0.2
33
+ _CN.MATCH_COARSE.BORDER_RM = 2
34
+ _CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn']
35
+ _CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
36
+ _CN.MATCH_COARSE.SKH_ITERS = 3
37
+ _CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
38
+ _CN.MATCH_COARSE.SKH_PREFILTER = True
39
+ _CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4 # training tricks: save GPU memory
40
+ _CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
41
+
42
+ # 4. LoFTR-fine module config
43
+ _CN.FINE = CN()
44
+ _CN.FINE.D_MODEL = 128
45
+ _CN.FINE.D_FFN = 128
46
+ _CN.FINE.NHEAD = 8
47
+ _CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1
48
+ _CN.FINE.ATTENTION = 'linear'
49
+
50
+ default_cfg = lower_config(_CN)
One-2-3-45-master 2/elevation_estimate/loftr/utils/fine_matching.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from kornia.geometry.subpix import dsnt
6
+ from kornia.utils.grid import create_meshgrid
7
+
8
+
9
+ class FineMatching(nn.Module):
10
+ """FineMatching with s2d paradigm"""
11
+
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def forward(self, feat_f0, feat_f1, data):
16
+ """
17
+ Args:
18
+ feat0 (torch.Tensor): [M, WW, C]
19
+ feat1 (torch.Tensor): [M, WW, C]
20
+ data (dict)
21
+ Update:
22
+ data (dict):{
23
+ 'expec_f' (torch.Tensor): [M, 3],
24
+ 'mkpts0_f' (torch.Tensor): [M, 2],
25
+ 'mkpts1_f' (torch.Tensor): [M, 2]}
26
+ """
27
+ M, WW, C = feat_f0.shape
28
+ W = int(math.sqrt(WW))
29
+ scale = data['hw0_i'][0] / data['hw0_f'][0]
30
+ self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
31
+
32
+ # corner case: if no coarse matches found
33
+ if M == 0:
34
+ assert self.training == False, "M is always >0, when training, see coarse_matching.py"
35
+ # logger.warning('No matches found in coarse-level.')
36
+ data.update({
37
+ 'expec_f': torch.empty(0, 3, device=feat_f0.device),
38
+ 'mkpts0_f': data['mkpts0_c'],
39
+ 'mkpts1_f': data['mkpts1_c'],
40
+ })
41
+ return
42
+
43
+ feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :]
44
+ sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
45
+ softmax_temp = 1. / C**.5
46
+ heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
47
+
48
+ # compute coordinates from heatmap
49
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
50
+ grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2]
51
+
52
+ # compute std over <x, y>
53
+ var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
54
+ std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability
55
+
56
+ # for fine-level supervision
57
+ data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
58
+
59
+ # compute absolute kpt coords
60
+ self.get_fine_match(coords_normalized, data)
61
+
62
+ @torch.no_grad()
63
+ def get_fine_match(self, coords_normed, data):
64
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
65
+
66
+ # mkpts0_f and mkpts1_f
67
+ mkpts0_f = data['mkpts0_c']
68
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
69
+ mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])]
70
+
71
+ data.update({
72
+ "mkpts0_f": mkpts0_f,
73
+ "mkpts1_f": mkpts1_f
74
+ })
One-2-3-45-master 2/elevation_estimate/loftr/utils/geometry.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ @torch.no_grad()
5
+ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
6
+ """ Warp kpts0 from I0 to I1 with depth, K and Rt
7
+ Also check covisibility and depth consistency.
8
+ Depth is consistent if relative error < 0.2 (hard-coded).
9
+
10
+ Args:
11
+ kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
12
+ depth0 (torch.Tensor): [N, H, W],
13
+ depth1 (torch.Tensor): [N, H, W],
14
+ T_0to1 (torch.Tensor): [N, 3, 4],
15
+ K0 (torch.Tensor): [N, 3, 3],
16
+ K1 (torch.Tensor): [N, 3, 3],
17
+ Returns:
18
+ calculable_mask (torch.Tensor): [N, L]
19
+ warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
20
+ """
21
+ kpts0_long = kpts0.round().long()
22
+
23
+ # Sample depth, get calculable_mask on depth != 0
24
+ kpts0_depth = torch.stack(
25
+ [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
26
+ ) # (N, L)
27
+ nonzero_mask = kpts0_depth != 0
28
+
29
+ # Unproject
30
+ kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
31
+ kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
32
+
33
+ # Rigid Transform
34
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
35
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
36
+
37
+ # Project
38
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
39
+ w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
40
+
41
+ # Covisible Check
42
+ h, w = depth1.shape[1:3]
43
+ covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
44
+ (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
45
+ w_kpts0_long = w_kpts0.long()
46
+ w_kpts0_long[~covisible_mask, :] = 0
47
+
48
+ w_kpts0_depth = torch.stack(
49
+ [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
50
+ ) # (N, L)
51
+ consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
52
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
53
+
54
+ return valid_mask, w_kpts0
One-2-3-45-master 2/elevation_estimate/loftr/utils/position_encoding.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class PositionEncodingSine(nn.Module):
7
+ """
8
+ This is a sinusoidal position encoding that generalized to 2-dimensional images
9
+ """
10
+
11
+ def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
12
+ """
13
+ Args:
14
+ max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
15
+ temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
16
+ the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
17
+ on the final performance. For now, we keep both impls for backward compatability.
18
+ We will remove the buggy impl after re-training all variants of our released models.
19
+ """
20
+ super().__init__()
21
+
22
+ pe = torch.zeros((d_model, *max_shape))
23
+ y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
24
+ x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
25
+ if temp_bug_fix:
26
+ div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
27
+ else: # a buggy implementation (for backward compatability only)
28
+ div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
29
+ div_term = div_term[:, None, None] # [C//4, 1, 1]
30
+ pe[0::4, :, :] = torch.sin(x_position * div_term)
31
+ pe[1::4, :, :] = torch.cos(x_position * div_term)
32
+ pe[2::4, :, :] = torch.sin(y_position * div_term)
33
+ pe[3::4, :, :] = torch.cos(y_position * div_term)
34
+
35
+ self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
36
+
37
+ def forward(self, x):
38
+ """
39
+ Args:
40
+ x: [N, C, H, W]
41
+ """
42
+ return x + self.pe[:, :, :x.size(2), :x.size(3)]
One-2-3-45-master 2/elevation_estimate/loftr/utils/supervision.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import log
2
+ from loguru import logger
3
+
4
+ import torch
5
+ from einops import repeat
6
+ from kornia.utils import create_meshgrid
7
+
8
+ from .geometry import warp_kpts
9
+
10
+ ############## ↓ Coarse-Level supervision ↓ ##############
11
+
12
+
13
+ @torch.no_grad()
14
+ def mask_pts_at_padded_regions(grid_pt, mask):
15
+ """For megadepth dataset, zero-padding exists in images"""
16
+ mask = repeat(mask, 'n h w -> n (h w) c', c=2)
17
+ grid_pt[~mask.bool()] = 0
18
+ return grid_pt
19
+
20
+
21
+ @torch.no_grad()
22
+ def spvs_coarse(data, config):
23
+ """
24
+ Update:
25
+ data (dict): {
26
+ "conf_matrix_gt": [N, hw0, hw1],
27
+ 'spv_b_ids': [M]
28
+ 'spv_i_ids': [M]
29
+ 'spv_j_ids': [M]
30
+ 'spv_w_pt0_i': [N, hw0, 2], in original image resolution
31
+ 'spv_pt1_i': [N, hw1, 2], in original image resolution
32
+ }
33
+
34
+ NOTE:
35
+ - for scannet dataset, there're 3 kinds of resolution {i, c, f}
36
+ - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f}
37
+ """
38
+ # 1. misc
39
+ device = data['image0'].device
40
+ N, _, H0, W0 = data['image0'].shape
41
+ _, _, H1, W1 = data['image1'].shape
42
+ scale = config['LOFTR']['RESOLUTION'][0]
43
+ scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
44
+ scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale
45
+ h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
46
+
47
+ # 2. warp grids
48
+ # create kpts in meshgrid and resize them to image resolution
49
+ grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2]
50
+ grid_pt0_i = scale0 * grid_pt0_c
51
+ grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
52
+ grid_pt1_i = scale1 * grid_pt1_c
53
+
54
+ # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
55
+ if 'mask0' in data:
56
+ grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
57
+ grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])
58
+
59
+ # warp kpts bi-directionally and resize them to coarse-level resolution
60
+ # (no depth consistency check, since it leads to worse results experimentally)
61
+ # (unhandled edge case: points with 0-depth will be warped to the left-up corner)
62
+ _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
63
+ _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
64
+ w_pt0_c = w_pt0_i / scale1
65
+ w_pt1_c = w_pt1_i / scale0
66
+
67
+ # 3. check if mutual nearest neighbor
68
+ w_pt0_c_round = w_pt0_c[:, :, :].round().long()
69
+ nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1
70
+ w_pt1_c_round = w_pt1_c[:, :, :].round().long()
71
+ nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0
72
+
73
+ # corner case: out of boundary
74
+ def out_bound_mask(pt, w, h):
75
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
76
+ nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
77
+ nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
78
+
79
+ loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
80
+ correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
81
+ correct_0to1[:, 0] = False # ignore the top-left corner
82
+
83
+ # 4. construct a gt conf_matrix
84
+ conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
85
+ b_ids, i_ids = torch.where(correct_0to1 != 0)
86
+ j_ids = nearest_index1[b_ids, i_ids]
87
+
88
+ conf_matrix_gt[b_ids, i_ids, j_ids] = 1
89
+ data.update({'conf_matrix_gt': conf_matrix_gt})
90
+
91
+ # 5. save coarse matches(gt) for training fine level
92
+ if len(b_ids) == 0:
93
+ logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
94
+ # this won't affect fine-level loss calculation
95
+ b_ids = torch.tensor([0], device=device)
96
+ i_ids = torch.tensor([0], device=device)
97
+ j_ids = torch.tensor([0], device=device)
98
+
99
+ data.update({
100
+ 'spv_b_ids': b_ids,
101
+ 'spv_i_ids': i_ids,
102
+ 'spv_j_ids': j_ids
103
+ })
104
+
105
+ # 6. save intermediate results (for fast fine-level computation)
106
+ data.update({
107
+ 'spv_w_pt0_i': w_pt0_i,
108
+ 'spv_pt1_i': grid_pt1_i
109
+ })
110
+
111
+
112
+ def compute_supervision_coarse(data, config):
113
+ assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!"
114
+ data_source = data['dataset_name'][0]
115
+ if data_source.lower() in ['scannet', 'megadepth']:
116
+ spvs_coarse(data, config)
117
+ else:
118
+ raise ValueError(f'Unknown data source: {data_source}')
119
+
120
+
121
+ ############## ↓ Fine-Level supervision ↓ ##############
122
+
123
+ @torch.no_grad()
124
+ def spvs_fine(data, config):
125
+ """
126
+ Update:
127
+ data (dict):{
128
+ "expec_f_gt": [M, 2]}
129
+ """
130
+ # 1. misc
131
+ # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
132
+ w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
133
+ scale = config['LOFTR']['RESOLUTION'][1]
134
+ radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2
135
+
136
+ # 2. get coarse prediction
137
+ b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
138
+
139
+ # 3. compute gt
140
+ scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale
141
+ # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
142
+ expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2]
143
+ data.update({"expec_f_gt": expec_f_gt})
144
+
145
+
146
+ def compute_supervision_fine(data, config):
147
+ data_source = data['dataset_name'][0]
148
+ if data_source.lower() in ['scannet', 'megadepth']:
149
+ spvs_fine(data, config)
150
+ else:
151
+ raise NotImplementedError
One-2-3-45-master 2/elevation_estimate/pyproject.toml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "elevation_estimate"
3
+ version = "0.1"
4
+
5
+ [tool.setuptools.packages.find]
6
+ exclude = ["configs", "tests"] # empty by default
7
+ namespaces = false # true by default
One-2-3-45-master 2/elevation_estimate/utils/__init__.py ADDED
File without changes
One-2-3-45-master 2/elevation_estimate/utils/elev_est_api.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import os.path as osp
5
+ import imageio
6
+ from copy import deepcopy
7
+
8
+ import loguru
9
+ import torch
10
+ import matplotlib.cm as cm
11
+ import matplotlib.pyplot as plt
12
+
13
+ from ..loftr import LoFTR, default_cfg
14
+ from . import plt_utils
15
+ from .plotting import make_matching_figure
16
+ from .utils3d import rect_to_img, canonical_to_camera, calc_pose
17
+
18
+
19
+ class ElevEstHelper:
20
+ _feature_matcher = None
21
+
22
+ @classmethod
23
+ def get_feature_matcher(cls):
24
+ if cls._feature_matcher is None:
25
+ loguru.logger.info("Loading feature matcher...")
26
+ _default_cfg = deepcopy(default_cfg)
27
+ _default_cfg['coarse']['temp_bug_fix'] = True # set to False when using the old ckpt
28
+ matcher = LoFTR(config=_default_cfg)
29
+ current_dir = os.path.dirname(os.path.abspath(__file__))
30
+ ckpt_path = os.path.join(current_dir, "weights/indoor_ds_new.ckpt")
31
+ if not osp.exists(ckpt_path):
32
+ loguru.logger.info("Downloading feature matcher...")
33
+ os.makedirs("weights", exist_ok=True)
34
+ import gdown
35
+ gdown.cached_download(url="https://drive.google.com/uc?id=19s3QvcCWQ6g-N1PrYlDCg-2mOJZ3kkgS",
36
+ path=ckpt_path)
37
+ matcher.load_state_dict(torch.load(ckpt_path)['state_dict'])
38
+ matcher = matcher.eval().cuda()
39
+ cls._feature_matcher = matcher
40
+ return cls._feature_matcher
41
+
42
+
43
+ def mask_out_bkgd(img_path, dbg=False):
44
+ img = imageio.imread_v2(img_path)
45
+ if img.shape[-1] == 4:
46
+ fg_mask = img[:, :, :3]
47
+ else:
48
+ loguru.logger.info("Image has no alpha channel, using thresholding to mask out background")
49
+ fg_mask = ~(img > 245).all(axis=-1)
50
+ if dbg:
51
+ plt.imshow(plt_utils.vis_mask(img, fg_mask.astype(np.uint8), color=[0, 255, 0]))
52
+ plt.show()
53
+ return fg_mask
54
+
55
+
56
+ def get_feature_matching(img_paths, dbg=False):
57
+ assert len(img_paths) == 4
58
+ matcher = ElevEstHelper.get_feature_matcher()
59
+ feature_matching = {}
60
+ masks = []
61
+ for i in range(4):
62
+ mask = mask_out_bkgd(img_paths[i], dbg=dbg)
63
+ masks.append(mask)
64
+ for i in range(0, 4):
65
+ for j in range(i + 1, 4):
66
+ img0_pth = img_paths[i]
67
+ img1_pth = img_paths[j]
68
+ mask0 = masks[i]
69
+ mask1 = masks[j]
70
+ img0_raw = cv2.imread(img0_pth, cv2.IMREAD_GRAYSCALE)
71
+ img1_raw = cv2.imread(img1_pth, cv2.IMREAD_GRAYSCALE)
72
+ original_shape = img0_raw.shape
73
+ img0_raw_resized = cv2.resize(img0_raw, (480, 480))
74
+ img1_raw_resized = cv2.resize(img1_raw, (480, 480))
75
+
76
+ img0 = torch.from_numpy(img0_raw_resized)[None][None].cuda() / 255.
77
+ img1 = torch.from_numpy(img1_raw_resized)[None][None].cuda() / 255.
78
+ batch = {'image0': img0, 'image1': img1}
79
+
80
+ # Inference with LoFTR and get prediction
81
+ with torch.no_grad():
82
+ matcher(batch)
83
+ mkpts0 = batch['mkpts0_f'].cpu().numpy()
84
+ mkpts1 = batch['mkpts1_f'].cpu().numpy()
85
+ mconf = batch['mconf'].cpu().numpy()
86
+ mkpts0[:, 0] = mkpts0[:, 0] * original_shape[1] / 480
87
+ mkpts0[:, 1] = mkpts0[:, 1] * original_shape[0] / 480
88
+ mkpts1[:, 0] = mkpts1[:, 0] * original_shape[1] / 480
89
+ mkpts1[:, 1] = mkpts1[:, 1] * original_shape[0] / 480
90
+ keep0 = mask0[mkpts0[:, 1].astype(int), mkpts1[:, 0].astype(int)]
91
+ keep1 = mask1[mkpts1[:, 1].astype(int), mkpts1[:, 0].astype(int)]
92
+ keep = np.logical_and(keep0, keep1)
93
+ mkpts0 = mkpts0[keep]
94
+ mkpts1 = mkpts1[keep]
95
+ mconf = mconf[keep]
96
+ if dbg:
97
+ # Draw visualization
98
+ color = cm.jet(mconf)
99
+ text = [
100
+ 'LoFTR',
101
+ 'Matches: {}'.format(len(mkpts0)),
102
+ ]
103
+ fig = make_matching_figure(img0_raw, img1_raw, mkpts0, mkpts1, color, text=text)
104
+ fig.show()
105
+ feature_matching[f"{i}_{j}"] = np.concatenate([mkpts0, mkpts1, mconf[:, None]], axis=1)
106
+
107
+ return feature_matching
108
+
109
+
110
+ def gen_pose_hypothesis(center_elevation):
111
+ elevations = np.radians(
112
+ [center_elevation, center_elevation - 10, center_elevation + 10, center_elevation, center_elevation]) # 45~120
113
+ azimuths = np.radians([30, 30, 30, 20, 40])
114
+ input_poses = calc_pose(elevations, azimuths, len(azimuths))
115
+ input_poses = input_poses[1:]
116
+ input_poses[..., 1] *= -1
117
+ input_poses[..., 2] *= -1
118
+ return input_poses
119
+
120
+
121
+ def ba_error_general(K, matches, poses):
122
+ projmat0 = K @ poses[0].inverse()[:3, :4]
123
+ projmat1 = K @ poses[1].inverse()[:3, :4]
124
+ match_01 = matches[0]
125
+ pts0 = match_01[:, :2]
126
+ pts1 = match_01[:, 2:4]
127
+ Xref = cv2.triangulatePoints(projmat0.cpu().numpy(), projmat1.cpu().numpy(),
128
+ pts0.cpu().numpy().T, pts1.cpu().numpy().T)
129
+ Xref = Xref[:3] / Xref[3:]
130
+ Xref = Xref.T
131
+ Xref = torch.from_numpy(Xref).cuda().float()
132
+ reproj_error = 0
133
+ for match, cp in zip(matches[1:], poses[2:]):
134
+ dist = (torch.norm(match_01[:, :2][:, None, :] - match[:, :2][None, :, :], dim=-1))
135
+ if dist.numel() > 0:
136
+ # print("dist.shape", dist.shape)
137
+ m0to2_index = dist.argmin(1)
138
+ keep = dist[torch.arange(match_01.shape[0]), m0to2_index] < 1
139
+ if keep.sum() > 0:
140
+ xref_in2 = rect_to_img(K, canonical_to_camera(Xref, cp.inverse()))
141
+ reproj_error2 = torch.norm(match[m0to2_index][keep][:, 2:4] - xref_in2[keep], dim=-1)
142
+ conf02 = match[m0to2_index][keep][:, -1]
143
+ reproj_error += (reproj_error2 * conf02).sum() / (conf02.sum())
144
+
145
+ return reproj_error
146
+
147
+
148
+ def find_optim_elev(elevs, nimgs, matches, K, dbg=False):
149
+ errs = []
150
+ for elev in elevs:
151
+ err = 0
152
+ cam_poses = gen_pose_hypothesis(elev)
153
+ for start in range(nimgs - 1):
154
+ batch_matches, batch_poses = [], []
155
+ for i in range(start, nimgs + start):
156
+ ci = i % nimgs
157
+ batch_poses.append(cam_poses[ci])
158
+ for j in range(nimgs - 1):
159
+ key = f"{start}_{(start + j + 1) % nimgs}"
160
+ match = matches[key]
161
+ batch_matches.append(match)
162
+ err += ba_error_general(K, batch_matches, batch_poses)
163
+ errs.append(err)
164
+ errs = torch.tensor(errs)
165
+ if dbg:
166
+ plt.plot(elevs, errs)
167
+ plt.show()
168
+ optim_elev = elevs[torch.argmin(errs)].item()
169
+ return optim_elev
170
+
171
+
172
+ def get_elev_est(feature_matching, min_elev=30, max_elev=150, K=None, dbg=False):
173
+ flag = True
174
+ matches = {}
175
+ for i in range(4):
176
+ for j in range(i + 1, 4):
177
+ match_ij = feature_matching[f"{i}_{j}"]
178
+ if len(match_ij) == 0:
179
+ flag = False
180
+ match_ji = np.concatenate([match_ij[:, 2:4], match_ij[:, 0:2], match_ij[:, 4:5]], axis=1)
181
+ matches[f"{i}_{j}"] = torch.from_numpy(match_ij).float().cuda()
182
+ matches[f"{j}_{i}"] = torch.from_numpy(match_ji).float().cuda()
183
+ if not flag:
184
+ loguru.logger.info("0 matches, could not estimate elevation")
185
+ return None
186
+ interval = 10
187
+ elevs = np.arange(min_elev, max_elev, interval)
188
+ optim_elev1 = find_optim_elev(elevs, 4, matches, K)
189
+
190
+ elevs = np.arange(optim_elev1 - 10, optim_elev1 + 10, 1)
191
+ optim_elev2 = find_optim_elev(elevs, 4, matches, K)
192
+
193
+ return optim_elev2
194
+
195
+
196
+ def elev_est_api(img_paths, min_elev=30, max_elev=150, K=None, dbg=False):
197
+ feature_matching = get_feature_matching(img_paths, dbg=dbg)
198
+ if K is None:
199
+ loguru.logger.warning("K is not provided, using default K")
200
+ K = np.array([[280.0, 0, 128.0],
201
+ [0, 280.0, 128.0],
202
+ [0, 0, 1]])
203
+ K = torch.from_numpy(K).cuda().float()
204
+ elev = get_elev_est(feature_matching, min_elev, max_elev, K, dbg=dbg)
205
+ return elev
One-2-3-45-master 2/elevation_estimate/utils/plotting.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import bisect
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib
5
+
6
+
7
+ def _compute_conf_thresh(data):
8
+ dataset_name = data['dataset_name'][0].lower()
9
+ if dataset_name == 'scannet':
10
+ thr = 5e-4
11
+ elif dataset_name == 'megadepth':
12
+ thr = 1e-4
13
+ else:
14
+ raise ValueError(f'Unknown dataset: {dataset_name}')
15
+ return thr
16
+
17
+
18
+ # --- VISUALIZATION --- #
19
+
20
+ def make_matching_figure(
21
+ img0, img1, mkpts0, mkpts1, color,
22
+ kpts0=None, kpts1=None, text=[], dpi=75, path=None):
23
+ # draw image pair
24
+ assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
25
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
26
+ axes[0].imshow(img0, cmap='gray')
27
+ axes[1].imshow(img1, cmap='gray')
28
+ for i in range(2): # clear all frames
29
+ axes[i].get_yaxis().set_ticks([])
30
+ axes[i].get_xaxis().set_ticks([])
31
+ for spine in axes[i].spines.values():
32
+ spine.set_visible(False)
33
+ plt.tight_layout(pad=1)
34
+
35
+ if kpts0 is not None:
36
+ assert kpts1 is not None
37
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
38
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
39
+
40
+ # draw matches
41
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
42
+ fig.canvas.draw()
43
+ transFigure = fig.transFigure.inverted()
44
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
45
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
46
+ fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
47
+ (fkpts0[i, 1], fkpts1[i, 1]),
48
+ transform=fig.transFigure, c=color[i], linewidth=1)
49
+ for i in range(len(mkpts0))]
50
+
51
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
52
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
53
+
54
+ # put txts
55
+ txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
56
+ fig.text(
57
+ 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
58
+ fontsize=15, va='top', ha='left', color=txt_color)
59
+
60
+ # save or return figure
61
+ if path:
62
+ plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
63
+ plt.close()
64
+ else:
65
+ return fig
66
+
67
+
68
+ def _make_evaluation_figure(data, b_id, alpha='dynamic'):
69
+ b_mask = data['m_bids'] == b_id
70
+ conf_thr = _compute_conf_thresh(data)
71
+
72
+ img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
73
+ img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
74
+ kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
75
+ kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()
76
+
77
+ # for megadepth, we visualize matches on the resized image
78
+ if 'scale0' in data:
79
+ kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
80
+ kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]]
81
+
82
+ epi_errs = data['epi_errs'][b_mask].cpu().numpy()
83
+ correct_mask = epi_errs < conf_thr
84
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
85
+ n_correct = np.sum(correct_mask)
86
+ n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
87
+ recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
88
+ # recall might be larger than 1, since the calculation of conf_matrix_gt
89
+ # uses groundtruth depths and camera poses, but epipolar distance is used here.
90
+
91
+ # matching info
92
+ if alpha == 'dynamic':
93
+ alpha = dynamic_alpha(len(correct_mask))
94
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
95
+
96
+ text = [
97
+ f'#Matches {len(kpts0)}',
98
+ f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
99
+ f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
100
+ ]
101
+
102
+ # make the figure
103
+ figure = make_matching_figure(img0, img1, kpts0, kpts1,
104
+ color, text=text)
105
+ return figure
106
+
107
+ def _make_confidence_figure(data, b_id):
108
+ # TODO: Implement confidence figure
109
+ raise NotImplementedError()
110
+
111
+
112
+ def make_matching_figures(data, config, mode='evaluation'):
113
+ """ Make matching figures for a batch.
114
+
115
+ Args:
116
+ data (Dict): a batch updated by PL_LoFTR.
117
+ config (Dict): matcher config
118
+ Returns:
119
+ figures (Dict[str, List[plt.figure]]
120
+ """
121
+ assert mode in ['evaluation', 'confidence'] # 'confidence'
122
+ figures = {mode: []}
123
+ for b_id in range(data['image0'].size(0)):
124
+ if mode == 'evaluation':
125
+ fig = _make_evaluation_figure(
126
+ data, b_id,
127
+ alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
128
+ elif mode == 'confidence':
129
+ fig = _make_confidence_figure(data, b_id)
130
+ else:
131
+ raise ValueError(f'Unknown plot mode: {mode}')
132
+ figures[mode].append(fig)
133
+ return figures
134
+
135
+
136
+ def dynamic_alpha(n_matches,
137
+ milestones=[0, 300, 1000, 2000],
138
+ alphas=[1.0, 0.8, 0.4, 0.2]):
139
+ if n_matches == 0:
140
+ return 1.0
141
+ ranges = list(zip(alphas, alphas[1:] + [None]))
142
+ loc = bisect.bisect_right(milestones, n_matches) - 1
143
+ _range = ranges[loc]
144
+ if _range[1] is None:
145
+ return _range[0]
146
+ return _range[1] + (milestones[loc + 1] - n_matches) / (
147
+ milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
148
+
149
+
150
+ def error_colormap(err, thr, alpha=1.0):
151
+ assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
152
+ x = 1 - np.clip(err / (thr * 2), 0, 1)
153
+ return np.clip(
154
+ np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
One-2-3-45-master 2/elevation_estimate/utils/plt_utils.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import os
3
+ import matplotlib.pyplot as plt
4
+ import torch
5
+ import cv2
6
+ import math
7
+
8
+ import numpy as np
9
+ import tqdm
10
+ from cv2 import findContours
11
+ from dl_ext.primitive import safe_zip
12
+ from dl_ext.timer import EvalTime
13
+
14
+
15
+ def plot_confidence(confidence):
16
+ n = len(confidence)
17
+ plt.plot(np.arange(n), confidence)
18
+ plt.show()
19
+
20
+
21
+ def image_grid(
22
+ images,
23
+ rows=None,
24
+ cols=None,
25
+ fill: bool = True,
26
+ show_axes: bool = False,
27
+ rgb=None,
28
+ show=True,
29
+ label=None,
30
+ **kwargs
31
+ ):
32
+ """
33
+ A util function for plotting a grid of images.
34
+ Args:
35
+ images: (N, H, W, 4) array of RGBA images
36
+ rows: number of rows in the grid
37
+ cols: number of columns in the grid
38
+ fill: boolean indicating if the space between images should be filled
39
+ show_axes: boolean indicating if the axes of the plots should be visible
40
+ rgb: boolean, If True, only RGB channels are plotted.
41
+ If False, only the alpha channel is plotted.
42
+ Returns:
43
+ None
44
+ """
45
+ evaltime = EvalTime(disable=True)
46
+ evaltime('')
47
+ if isinstance(images, torch.Tensor):
48
+ images = images.detach().cpu()
49
+ if len(images[0].shape) == 2:
50
+ rgb = False
51
+ if images[0].shape[-1] == 2:
52
+ # flow
53
+ images = [flow_to_image(im) for im in images]
54
+ if (rows is None) != (cols is None):
55
+ raise ValueError("Specify either both rows and cols or neither.")
56
+
57
+ if rows is None:
58
+ rows = int(len(images) ** 0.5)
59
+ cols = math.ceil(len(images) / rows)
60
+
61
+ gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
62
+ if len(images) < 50:
63
+ figsize = (10, 10)
64
+ else:
65
+ figsize = (15, 15)
66
+ evaltime('0.5')
67
+ plt.figure(figsize=figsize)
68
+ # fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=figsize)
69
+ if label:
70
+ # fig.suptitle(label, fontsize=30)
71
+ plt.suptitle(label, fontsize=30)
72
+ # bleed = 0
73
+ # fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))
74
+ evaltime('subplots')
75
+
76
+ # for i, (ax, im) in enumerate(tqdm.tqdm(zip(axarr.ravel(), images), leave=True, total=len(images))):
77
+ for i in range(len(images)):
78
+ # evaltime(f'{i} begin')
79
+ plt.subplot(rows, cols, i + 1)
80
+ if rgb:
81
+ # only render RGB channels
82
+ plt.imshow(images[i][..., :3], **kwargs)
83
+ # ax.imshow(im[..., :3], **kwargs)
84
+ else:
85
+ # only render Alpha channel
86
+ plt.imshow(images[i], **kwargs)
87
+ # ax.imshow(im, **kwargs)
88
+ if not show_axes:
89
+ plt.axis('off')
90
+ # ax.set_axis_off()
91
+ # ax.set_title(f'{i}')
92
+ plt.title(f'{i}')
93
+ # evaltime(f'{i} end')
94
+ evaltime('2')
95
+ if show:
96
+ plt.show()
97
+ # return fig
98
+
99
+
100
+ def depth_grid(
101
+ depths,
102
+ rows=None,
103
+ cols=None,
104
+ fill: bool = True,
105
+ show_axes: bool = False,
106
+ ):
107
+ """
108
+ A util function for plotting a grid of images.
109
+ Args:
110
+ images: (N, H, W, 4) array of RGBA images
111
+ rows: number of rows in the grid
112
+ cols: number of columns in the grid
113
+ fill: boolean indicating if the space between images should be filled
114
+ show_axes: boolean indicating if the axes of the plots should be visible
115
+ rgb: boolean, If True, only RGB channels are plotted.
116
+ If False, only the alpha channel is plotted.
117
+ Returns:
118
+ None
119
+ """
120
+ if (rows is None) != (cols is None):
121
+ raise ValueError("Specify either both rows and cols or neither.")
122
+
123
+ if rows is None:
124
+ rows = len(depths)
125
+ cols = 1
126
+
127
+ gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
128
+ fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9))
129
+ bleed = 0
130
+ fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))
131
+
132
+ for ax, im in zip(axarr.ravel(), depths):
133
+ ax.imshow(im)
134
+ if not show_axes:
135
+ ax.set_axis_off()
136
+ plt.show()
137
+
138
+
139
+ def hover_masks_on_imgs(images, masks):
140
+ masks = np.array(masks)
141
+ new_imgs = []
142
+ tids = list(range(1, masks.max() + 1))
143
+ colors = colormap(rgb=True, lighten=True)
144
+ for im, mask in tqdm.tqdm(safe_zip(images, masks), total=len(images)):
145
+ for tid in tids:
146
+ im = vis_mask(
147
+ im,
148
+ (mask == tid).astype(np.uint8),
149
+ color=colors[tid],
150
+ alpha=0.5,
151
+ border_alpha=0.5,
152
+ border_color=[255, 255, 255],
153
+ border_thick=3)
154
+ new_imgs.append(im)
155
+ return new_imgs
156
+
157
+
158
+ def vis_mask(img,
159
+ mask,
160
+ color=[255, 255, 255],
161
+ alpha=0.4,
162
+ show_border=True,
163
+ border_alpha=0.5,
164
+ border_thick=1,
165
+ border_color=None):
166
+ """Visualizes a single binary mask."""
167
+ if isinstance(mask, torch.Tensor):
168
+ from anypose.utils.pn_utils import to_array
169
+ mask = to_array(mask > 0).astype(np.uint8)
170
+ img = img.astype(np.float32)
171
+ idx = np.nonzero(mask)
172
+
173
+ img[idx[0], idx[1], :] *= 1.0 - alpha
174
+ img[idx[0], idx[1], :] += [alpha * x for x in color]
175
+
176
+ if show_border:
177
+ contours, _ = findContours(
178
+ mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
179
+ # contours = [c for c in contours if c.shape[0] > 10]
180
+ if border_color is None:
181
+ border_color = color
182
+ if not isinstance(border_color, list):
183
+ border_color = border_color.tolist()
184
+ if border_alpha < 1:
185
+ with_border = img.copy()
186
+ cv2.drawContours(with_border, contours, -1, border_color,
187
+ border_thick, cv2.LINE_AA)
188
+ img = (1 - border_alpha) * img + border_alpha * with_border
189
+ else:
190
+ cv2.drawContours(img, contours, -1, border_color, border_thick,
191
+ cv2.LINE_AA)
192
+
193
+ return img.astype(np.uint8)
194
+
195
+
196
+ def colormap(rgb=False, lighten=True):
197
+ """Copied from Detectron codebase."""
198
+ color_list = np.array(
199
+ [
200
+ 0.000, 0.447, 0.741,
201
+ 0.850, 0.325, 0.098,
202
+ 0.929, 0.694, 0.125,
203
+ 0.494, 0.184, 0.556,
204
+ 0.466, 0.674, 0.188,
205
+ 0.301, 0.745, 0.933,
206
+ 0.635, 0.078, 0.184,
207
+ 0.300, 0.300, 0.300,
208
+ 0.600, 0.600, 0.600,
209
+ 1.000, 0.000, 0.000,
210
+ 1.000, 0.500, 0.000,
211
+ 0.749, 0.749, 0.000,
212
+ 0.000, 1.000, 0.000,
213
+ 0.000, 0.000, 1.000,
214
+ 0.667, 0.000, 1.000,
215
+ 0.333, 0.333, 0.000,
216
+ 0.333, 0.667, 0.000,
217
+ 0.333, 1.000, 0.000,
218
+ 0.667, 0.333, 0.000,
219
+ 0.667, 0.667, 0.000,
220
+ 0.667, 1.000, 0.000,
221
+ 1.000, 0.333, 0.000,
222
+ 1.000, 0.667, 0.000,
223
+ 1.000, 1.000, 0.000,
224
+ 0.000, 0.333, 0.500,
225
+ 0.000, 0.667, 0.500,
226
+ 0.000, 1.000, 0.500,
227
+ 0.333, 0.000, 0.500,
228
+ 0.333, 0.333, 0.500,
229
+ 0.333, 0.667, 0.500,
230
+ 0.333, 1.000, 0.500,
231
+ 0.667, 0.000, 0.500,
232
+ 0.667, 0.333, 0.500,
233
+ 0.667, 0.667, 0.500,
234
+ 0.667, 1.000, 0.500,
235
+ 1.000, 0.000, 0.500,
236
+ 1.000, 0.333, 0.500,
237
+ 1.000, 0.667, 0.500,
238
+ 1.000, 1.000, 0.500,
239
+ 0.000, 0.333, 1.000,
240
+ 0.000, 0.667, 1.000,
241
+ 0.000, 1.000, 1.000,
242
+ 0.333, 0.000, 1.000,
243
+ 0.333, 0.333, 1.000,
244
+ 0.333, 0.667, 1.000,
245
+ 0.333, 1.000, 1.000,
246
+ 0.667, 0.000, 1.000,
247
+ 0.667, 0.333, 1.000,
248
+ 0.667, 0.667, 1.000,
249
+ 0.667, 1.000, 1.000,
250
+ 1.000, 0.000, 1.000,
251
+ 1.000, 0.333, 1.000,
252
+ 1.000, 0.667, 1.000,
253
+ 0.167, 0.000, 0.000,
254
+ 0.333, 0.000, 0.000,
255
+ 0.500, 0.000, 0.000,
256
+ 0.667, 0.000, 0.000,
257
+ 0.833, 0.000, 0.000,
258
+ 1.000, 0.000, 0.000,
259
+ 0.000, 0.167, 0.000,
260
+ 0.000, 0.333, 0.000,
261
+ 0.000, 0.500, 0.000,
262
+ 0.000, 0.667, 0.000,
263
+ 0.000, 0.833, 0.000,
264
+ 0.000, 1.000, 0.000,
265
+ 0.000, 0.000, 0.167,
266
+ 0.000, 0.000, 0.333,
267
+ 0.000, 0.000, 0.500,
268
+ 0.000, 0.000, 0.667,
269
+ 0.000, 0.000, 0.833,
270
+ 0.000, 0.000, 1.000,
271
+ 0.000, 0.000, 0.000,
272
+ 0.143, 0.143, 0.143,
273
+ 0.286, 0.286, 0.286,
274
+ 0.429, 0.429, 0.429,
275
+ 0.571, 0.571, 0.571,
276
+ 0.714, 0.714, 0.714,
277
+ 0.857, 0.857, 0.857,
278
+ 1.000, 1.000, 1.000
279
+ ]
280
+ ).astype(np.float32)
281
+ color_list = color_list.reshape((-1, 3))
282
+ if not rgb:
283
+ color_list = color_list[:, ::-1]
284
+
285
+ if lighten:
286
+ # Make all the colors a little lighter / whiter. This is copied
287
+ # from the detectron visualization code (search for 'w_ratio').
288
+ w_ratio = 0.4
289
+ color_list = (color_list * (1 - w_ratio) + w_ratio)
290
+ return color_list * 255
291
+
292
+
293
+ def vis_layer_mask(masks, save_path=None):
294
+ masks = torch.as_tensor(masks)
295
+ tids = masks.unique().tolist()
296
+ tids.remove(0)
297
+ for tid in tqdm.tqdm(tids):
298
+ show = save_path is None
299
+ image_grid(masks == tid, label=f'{tid}', show=show)
300
+ if save_path:
301
+ os.makedirs(osp.dirname(save_path), exist_ok=True)
302
+ plt.savefig(save_path % tid)
303
+ plt.close('all')
304
+
305
+
306
+ def show(x, **kwargs):
307
+ if isinstance(x, torch.Tensor):
308
+ x = x.detach().cpu()
309
+ plt.imshow(x, **kwargs)
310
+ plt.show()
311
+
312
+
313
+ def vis_title(rgb, text, shift_y=30):
314
+ tmp = rgb.copy()
315
+ shift_x = rgb.shape[1] // 2
316
+ cv2.putText(tmp, text,
317
+ (shift_x, shift_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=2, lineType=cv2.LINE_AA)
318
+ return tmp
One-2-3-45-master 2/elevation_estimate/utils/utils3d.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ def cart_to_hom(pts):
6
+ """
7
+ :param pts: (N, 3 or 2)
8
+ :return pts_hom: (N, 4 or 3)
9
+ """
10
+ if isinstance(pts, np.ndarray):
11
+ pts_hom = np.concatenate((pts, np.ones([*pts.shape[:-1], 1], dtype=np.float32)), -1)
12
+ else:
13
+ ones = torch.ones([*pts.shape[:-1], 1], dtype=torch.float32, device=pts.device)
14
+ pts_hom = torch.cat((pts, ones), dim=-1)
15
+ return pts_hom
16
+
17
+
18
+ def hom_to_cart(pts):
19
+ return pts[..., :-1] / pts[..., -1:]
20
+
21
+
22
+ def canonical_to_camera(pts, pose):
23
+ pts = cart_to_hom(pts)
24
+ pts = pts @ pose.transpose(-1, -2)
25
+ pts = hom_to_cart(pts)
26
+ return pts
27
+
28
+
29
+ def rect_to_img(K, pts_rect):
30
+ from dl_ext.vision_ext.datasets.kitti.structures import Calibration
31
+ pts_2d_hom = pts_rect @ K.t()
32
+ pts_img = Calibration.hom_to_cart(pts_2d_hom)
33
+ return pts_img
34
+
35
+
36
+ def calc_pose(phis, thetas, size, radius=1.2):
37
+ import torch
38
+ def normalize(vectors):
39
+ return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)
40
+
41
+ device = torch.device('cuda')
42
+ thetas = torch.FloatTensor(thetas).to(device)
43
+ phis = torch.FloatTensor(phis).to(device)
44
+
45
+ centers = torch.stack([
46
+ radius * torch.sin(thetas) * torch.sin(phis),
47
+ -radius * torch.cos(thetas) * torch.sin(phis),
48
+ radius * torch.cos(phis),
49
+ ], dim=-1) # [B, 3]
50
+
51
+ # lookat
52
+ forward_vector = normalize(centers).squeeze(0)
53
+ up_vector = torch.FloatTensor([0, 0, 1]).to(device).unsqueeze(0).repeat(size, 1)
54
+ right_vector = normalize(torch.cross(up_vector, forward_vector, dim=-1))
55
+ if right_vector.pow(2).sum() < 0.01:
56
+ right_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)
57
+ up_vector = normalize(torch.cross(forward_vector, right_vector, dim=-1))
58
+
59
+ poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
60
+ poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
61
+ poses[:, :3, 3] = centers
62
+ return poses
One-2-3-45-master 2/elevation_estimate/utils/weights/.gitkeep ADDED
File without changes
One-2-3-45-master 2/example.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
One-2-3-45-master 2/ldm/data/__init__.py ADDED
File without changes
One-2-3-45-master 2/ldm/data/base.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from abc import abstractmethod
4
+ from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
5
+
6
+
7
+ class Txt2ImgIterableBaseDataset(IterableDataset):
8
+ '''
9
+ Define an interface to make the IterableDatasets for text2img data chainable
10
+ '''
11
+ def __init__(self, num_records=0, valid_ids=None, size=256):
12
+ super().__init__()
13
+ self.num_records = num_records
14
+ self.valid_ids = valid_ids
15
+ self.sample_ids = valid_ids
16
+ self.size = size
17
+
18
+ print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
19
+
20
+ def __len__(self):
21
+ return self.num_records
22
+
23
+ @abstractmethod
24
+ def __iter__(self):
25
+ pass
26
+
27
+
28
+ class PRNGMixin(object):
29
+ """
30
+ Adds a prng property which is a numpy RandomState which gets
31
+ reinitialized whenever the pid changes to avoid synchronized sampling
32
+ behavior when used in conjunction with multiprocessing.
33
+ """
34
+ @property
35
+ def prng(self):
36
+ currentpid = os.getpid()
37
+ if getattr(self, "_initpid", None) != currentpid:
38
+ self._initpid = currentpid
39
+ self._prng = np.random.RandomState()
40
+ return self._prng
One-2-3-45-master 2/ldm/data/coco.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import albumentations
4
+ import numpy as np
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ from torch.utils.data import Dataset
8
+ from abc import abstractmethod
9
+
10
+
11
+ class CocoBase(Dataset):
12
+ """needed for (image, caption, segmentation) pairs"""
13
+ def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
14
+ crop_size=None, force_no_crop=False, given_files=None, use_segmentation=True,crop_type=None):
15
+ self.split = self.get_split()
16
+ self.size = size
17
+ if crop_size is None:
18
+ self.crop_size = size
19
+ else:
20
+ self.crop_size = crop_size
21
+
22
+ assert crop_type in [None, 'random', 'center']
23
+ self.crop_type = crop_type
24
+ self.use_segmenation = use_segmentation
25
+ self.onehot = onehot_segmentation # return segmentation as rgb or one hot
26
+ self.stuffthing = use_stuffthing # include thing in segmentation
27
+ if self.onehot and not self.stuffthing:
28
+ raise NotImplemented("One hot mode is only supported for the "
29
+ "stuffthings version because labels are stored "
30
+ "a bit different.")
31
+
32
+ data_json = datajson
33
+ with open(data_json) as json_file:
34
+ self.json_data = json.load(json_file)
35
+ self.img_id_to_captions = dict()
36
+ self.img_id_to_filepath = dict()
37
+ self.img_id_to_segmentation_filepath = dict()
38
+
39
+ assert data_json.split("/")[-1] in [f"captions_train{self.year()}.json",
40
+ f"captions_val{self.year()}.json"]
41
+ # TODO currently hardcoded paths, would be better to follow logic in
42
+ # cocstuff pixelmaps
43
+ if self.use_segmenation:
44
+ if self.stuffthing:
45
+ self.segmentation_prefix = (
46
+ f"data/cocostuffthings/val{self.year()}" if
47
+ data_json.endswith(f"captions_val{self.year()}.json") else
48
+ f"data/cocostuffthings/train{self.year()}")
49
+ else:
50
+ self.segmentation_prefix = (
51
+ f"data/coco/annotations/stuff_val{self.year()}_pixelmaps" if
52
+ data_json.endswith(f"captions_val{self.year()}.json") else
53
+ f"data/coco/annotations/stuff_train{self.year()}_pixelmaps")
54
+
55
+ imagedirs = self.json_data["images"]
56
+ self.labels = {"image_ids": list()}
57
+ for imgdir in tqdm(imagedirs, desc="ImgToPath"):
58
+ self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
59
+ self.img_id_to_captions[imgdir["id"]] = list()
60
+ pngfilename = imgdir["file_name"].replace("jpg", "png")
61
+ if self.use_segmenation:
62
+ self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
63
+ self.segmentation_prefix, pngfilename)
64
+ if given_files is not None:
65
+ if pngfilename in given_files:
66
+ self.labels["image_ids"].append(imgdir["id"])
67
+ else:
68
+ self.labels["image_ids"].append(imgdir["id"])
69
+
70
+ capdirs = self.json_data["annotations"]
71
+ for capdir in tqdm(capdirs, desc="ImgToCaptions"):
72
+ # there are in average 5 captions per image
73
+ #self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
74
+ self.img_id_to_captions[capdir["image_id"]].append(capdir["caption"])
75
+
76
+ self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
77
+ if self.split=="validation":
78
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
79
+ else:
80
+ # default option for train is random crop
81
+ if self.crop_type in [None, 'random']:
82
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
83
+ else:
84
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
85
+ self.preprocessor = albumentations.Compose(
86
+ [self.rescaler, self.cropper],
87
+ additional_targets={"segmentation": "image"})
88
+ if force_no_crop:
89
+ self.rescaler = albumentations.Resize(height=self.size, width=self.size)
90
+ self.preprocessor = albumentations.Compose(
91
+ [self.rescaler],
92
+ additional_targets={"segmentation": "image"})
93
+
94
+ @abstractmethod
95
+ def year(self):
96
+ raise NotImplementedError()
97
+
98
+ def __len__(self):
99
+ return len(self.labels["image_ids"])
100
+
101
+ def preprocess_image(self, image_path, segmentation_path=None):
102
+ image = Image.open(image_path)
103
+ if not image.mode == "RGB":
104
+ image = image.convert("RGB")
105
+ image = np.array(image).astype(np.uint8)
106
+ if segmentation_path:
107
+ segmentation = Image.open(segmentation_path)
108
+ if not self.onehot and not segmentation.mode == "RGB":
109
+ segmentation = segmentation.convert("RGB")
110
+ segmentation = np.array(segmentation).astype(np.uint8)
111
+ if self.onehot:
112
+ assert self.stuffthing
113
+ # stored in caffe format: unlabeled==255. stuff and thing from
114
+ # 0-181. to be compatible with the labels in
115
+ # https://github.com/nightrome/cocostuff/blob/master/labels.txt
116
+ # we shift stuffthing one to the right and put unlabeled in zero
117
+ # as long as segmentation is uint8 shifting to right handles the
118
+ # latter too
119
+ assert segmentation.dtype == np.uint8
120
+ segmentation = segmentation + 1
121
+
122
+ processed = self.preprocessor(image=image, segmentation=segmentation)
123
+
124
+ image, segmentation = processed["image"], processed["segmentation"]
125
+ else:
126
+ image = self.preprocessor(image=image,)['image']
127
+
128
+ image = (image / 127.5 - 1.0).astype(np.float32)
129
+ if segmentation_path:
130
+ if self.onehot:
131
+ assert segmentation.dtype == np.uint8
132
+ # make it one hot
133
+ n_labels = 183
134
+ flatseg = np.ravel(segmentation)
135
+ onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool)
136
+ onehot[np.arange(flatseg.size), flatseg] = True
137
+ onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
138
+ segmentation = onehot
139
+ else:
140
+ segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
141
+ return image, segmentation
142
+ else:
143
+ return image
144
+
145
+ def __getitem__(self, i):
146
+ img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
147
+ if self.use_segmenation:
148
+ seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
149
+ image, segmentation = self.preprocess_image(img_path, seg_path)
150
+ else:
151
+ image = self.preprocess_image(img_path)
152
+ captions = self.img_id_to_captions[self.labels["image_ids"][i]]
153
+ # randomly draw one of all available captions per image
154
+ caption = captions[np.random.randint(0, len(captions))]
155
+ example = {"image": image,
156
+ #"caption": [str(caption[0])],
157
+ "caption": caption,
158
+ "img_path": img_path,
159
+ "filename_": img_path.split(os.sep)[-1]
160
+ }
161
+ if self.use_segmenation:
162
+ example.update({"seg_path": seg_path, 'segmentation': segmentation})
163
+ return example
164
+
165
+
166
+ class CocoImagesAndCaptionsTrain2017(CocoBase):
167
+ """returns a pair of (image, caption)"""
168
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,):
169
+ super().__init__(size=size,
170
+ dataroot="data/coco/train2017",
171
+ datajson="data/coco/annotations/captions_train2017.json",
172
+ onehot_segmentation=onehot_segmentation,
173
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
174
+
175
+ def get_split(self):
176
+ return "train"
177
+
178
+ def year(self):
179
+ return '2017'
180
+
181
+
182
+ class CocoImagesAndCaptionsValidation2017(CocoBase):
183
+ """returns a pair of (image, caption)"""
184
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
185
+ given_files=None):
186
+ super().__init__(size=size,
187
+ dataroot="data/coco/val2017",
188
+ datajson="data/coco/annotations/captions_val2017.json",
189
+ onehot_segmentation=onehot_segmentation,
190
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
191
+ given_files=given_files)
192
+
193
+ def get_split(self):
194
+ return "validation"
195
+
196
+ def year(self):
197
+ return '2017'
198
+
199
+
200
+
201
+ class CocoImagesAndCaptionsTrain2014(CocoBase):
202
+ """returns a pair of (image, caption)"""
203
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,crop_type='random'):
204
+ super().__init__(size=size,
205
+ dataroot="data/coco/train2014",
206
+ datajson="data/coco/annotations2014/annotations/captions_train2014.json",
207
+ onehot_segmentation=onehot_segmentation,
208
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
209
+ use_segmentation=False,
210
+ crop_type=crop_type)
211
+
212
+ def get_split(self):
213
+ return "train"
214
+
215
+ def year(self):
216
+ return '2014'
217
+
218
+ class CocoImagesAndCaptionsValidation2014(CocoBase):
219
+ """returns a pair of (image, caption)"""
220
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
221
+ given_files=None,crop_type='center',**kwargs):
222
+ super().__init__(size=size,
223
+ dataroot="data/coco/val2014",
224
+ datajson="data/coco/annotations2014/annotations/captions_val2014.json",
225
+ onehot_segmentation=onehot_segmentation,
226
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
227
+ given_files=given_files,
228
+ use_segmentation=False,
229
+ crop_type=crop_type)
230
+
231
+ def get_split(self):
232
+ return "validation"
233
+
234
+ def year(self):
235
+ return '2014'
236
+
237
+ if __name__ == '__main__':
238
+ with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file:
239
+ json_data = json.load(json_file)
240
+ capdirs = json_data["annotations"]
241
+ import pudb; pudb.set_trace()
242
+ #d2 = CocoImagesAndCaptionsTrain2014(size=256)
243
+ d2 = CocoImagesAndCaptionsValidation2014(size=256)
244
+ print("constructed dataset.")
245
+ print(f"length of {d2.__class__.__name__}: {len(d2)}")
246
+
247
+ ex2 = d2[0]
248
+ # ex3 = d3[0]
249
+ # print(ex1["image"].shape)
250
+ print(ex2["image"].shape)
251
+ # print(ex3["image"].shape)
252
+ # print(ex1["segmentation"].shape)
253
+ print(ex2["caption"].__class__.__name__)
One-2-3-45-master 2/ldm/data/dummy.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import string
4
+ from torch.utils.data import Dataset, Subset
5
+
6
+ class DummyData(Dataset):
7
+ def __init__(self, length, size):
8
+ self.length = length
9
+ self.size = size
10
+
11
+ def __len__(self):
12
+ return self.length
13
+
14
+ def __getitem__(self, i):
15
+ x = np.random.randn(*self.size)
16
+ letters = string.ascii_lowercase
17
+ y = ''.join(random.choice(string.ascii_lowercase) for i in range(10))
18
+ return {"jpg": x, "txt": y}
19
+
20
+
21
+ class DummyDataWithEmbeddings(Dataset):
22
+ def __init__(self, length, size, emb_size):
23
+ self.length = length
24
+ self.size = size
25
+ self.emb_size = emb_size
26
+
27
+ def __len__(self):
28
+ return self.length
29
+
30
+ def __getitem__(self, i):
31
+ x = np.random.randn(*self.size)
32
+ y = np.random.randn(*self.emb_size).astype(np.float32)
33
+ return {"jpg": x, "txt": y}
34
+
One-2-3-45-master 2/ldm/data/imagenet.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, yaml, pickle, shutil, tarfile, glob
2
+ import cv2
3
+ import albumentations
4
+ import PIL
5
+ import numpy as np
6
+ import torchvision.transforms.functional as TF
7
+ from omegaconf import OmegaConf
8
+ from functools import partial
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+ from torch.utils.data import Dataset, Subset
12
+
13
+ import taming.data.utils as tdu
14
+ from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
15
+ from taming.data.imagenet import ImagePaths
16
+
17
+ from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
18
+
19
+
20
+ def synset2idx(path_to_yaml="data/index_synset.yaml"):
21
+ with open(path_to_yaml) as f:
22
+ di2s = yaml.load(f)
23
+ return dict((v,k) for k,v in di2s.items())
24
+
25
+
26
+ class ImageNetBase(Dataset):
27
+ def __init__(self, config=None):
28
+ self.config = config or OmegaConf.create()
29
+ if not type(self.config)==dict:
30
+ self.config = OmegaConf.to_container(self.config)
31
+ self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
32
+ self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
33
+ self._prepare()
34
+ self._prepare_synset_to_human()
35
+ self._prepare_idx_to_synset()
36
+ self._prepare_human_to_integer_label()
37
+ self._load()
38
+
39
+ def __len__(self):
40
+ return len(self.data)
41
+
42
+ def __getitem__(self, i):
43
+ return self.data[i]
44
+
45
+ def _prepare(self):
46
+ raise NotImplementedError()
47
+
48
+ def _filter_relpaths(self, relpaths):
49
+ ignore = set([
50
+ "n06596364_9591.JPEG",
51
+ ])
52
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
53
+ if "sub_indices" in self.config:
54
+ indices = str_to_indices(self.config["sub_indices"])
55
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
56
+ self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
57
+ files = []
58
+ for rpath in relpaths:
59
+ syn = rpath.split("/")[0]
60
+ if syn in synsets:
61
+ files.append(rpath)
62
+ return files
63
+ else:
64
+ return relpaths
65
+
66
+ def _prepare_synset_to_human(self):
67
+ SIZE = 2655750
68
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
69
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
70
+ if (not os.path.exists(self.human_dict) or
71
+ not os.path.getsize(self.human_dict)==SIZE):
72
+ download(URL, self.human_dict)
73
+
74
+ def _prepare_idx_to_synset(self):
75
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
76
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
77
+ if (not os.path.exists(self.idx2syn)):
78
+ download(URL, self.idx2syn)
79
+
80
+ def _prepare_human_to_integer_label(self):
81
+ URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
82
+ self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
83
+ if (not os.path.exists(self.human2integer)):
84
+ download(URL, self.human2integer)
85
+ with open(self.human2integer, "r") as f:
86
+ lines = f.read().splitlines()
87
+ assert len(lines) == 1000
88
+ self.human2integer_dict = dict()
89
+ for line in lines:
90
+ value, key = line.split(":")
91
+ self.human2integer_dict[key] = int(value)
92
+
93
+ def _load(self):
94
+ with open(self.txt_filelist, "r") as f:
95
+ self.relpaths = f.read().splitlines()
96
+ l1 = len(self.relpaths)
97
+ self.relpaths = self._filter_relpaths(self.relpaths)
98
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
99
+
100
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
101
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
102
+
103
+ unique_synsets = np.unique(self.synsets)
104
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
105
+ if not self.keep_orig_class_label:
106
+ self.class_labels = [class_dict[s] for s in self.synsets]
107
+ else:
108
+ self.class_labels = [self.synset2idx[s] for s in self.synsets]
109
+
110
+ with open(self.human_dict, "r") as f:
111
+ human_dict = f.read().splitlines()
112
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
113
+
114
+ self.human_labels = [human_dict[s] for s in self.synsets]
115
+
116
+ labels = {
117
+ "relpath": np.array(self.relpaths),
118
+ "synsets": np.array(self.synsets),
119
+ "class_label": np.array(self.class_labels),
120
+ "human_label": np.array(self.human_labels),
121
+ }
122
+
123
+ if self.process_images:
124
+ self.size = retrieve(self.config, "size", default=256)
125
+ self.data = ImagePaths(self.abspaths,
126
+ labels=labels,
127
+ size=self.size,
128
+ random_crop=self.random_crop,
129
+ )
130
+ else:
131
+ self.data = self.abspaths
132
+
133
+
134
+ class ImageNetTrain(ImageNetBase):
135
+ NAME = "ILSVRC2012_train"
136
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
137
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
138
+ FILES = [
139
+ "ILSVRC2012_img_train.tar",
140
+ ]
141
+ SIZES = [
142
+ 147897477120,
143
+ ]
144
+
145
+ def __init__(self, process_images=True, data_root=None, **kwargs):
146
+ self.process_images = process_images
147
+ self.data_root = data_root
148
+ super().__init__(**kwargs)
149
+
150
+ def _prepare(self):
151
+ if self.data_root:
152
+ self.root = os.path.join(self.data_root, self.NAME)
153
+ else:
154
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
155
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
156
+
157
+ self.datadir = os.path.join(self.root, "data")
158
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
159
+ self.expected_length = 1281167
160
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
161
+ default=True)
162
+ if not tdu.is_prepared(self.root):
163
+ # prep
164
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
165
+
166
+ datadir = self.datadir
167
+ if not os.path.exists(datadir):
168
+ path = os.path.join(self.root, self.FILES[0])
169
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
170
+ import academictorrents as at
171
+ atpath = at.get(self.AT_HASH, datastore=self.root)
172
+ assert atpath == path
173
+
174
+ print("Extracting {} to {}".format(path, datadir))
175
+ os.makedirs(datadir, exist_ok=True)
176
+ with tarfile.open(path, "r:") as tar:
177
+ tar.extractall(path=datadir)
178
+
179
+ print("Extracting sub-tars.")
180
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
181
+ for subpath in tqdm(subpaths):
182
+ subdir = subpath[:-len(".tar")]
183
+ os.makedirs(subdir, exist_ok=True)
184
+ with tarfile.open(subpath, "r:") as tar:
185
+ tar.extractall(path=subdir)
186
+
187
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
188
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
189
+ filelist = sorted(filelist)
190
+ filelist = "\n".join(filelist)+"\n"
191
+ with open(self.txt_filelist, "w") as f:
192
+ f.write(filelist)
193
+
194
+ tdu.mark_prepared(self.root)
195
+
196
+
197
+ class ImageNetValidation(ImageNetBase):
198
+ NAME = "ILSVRC2012_validation"
199
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
200
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
201
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
202
+ FILES = [
203
+ "ILSVRC2012_img_val.tar",
204
+ "validation_synset.txt",
205
+ ]
206
+ SIZES = [
207
+ 6744924160,
208
+ 1950000,
209
+ ]
210
+
211
+ def __init__(self, process_images=True, data_root=None, **kwargs):
212
+ self.data_root = data_root
213
+ self.process_images = process_images
214
+ super().__init__(**kwargs)
215
+
216
+ def _prepare(self):
217
+ if self.data_root:
218
+ self.root = os.path.join(self.data_root, self.NAME)
219
+ else:
220
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
221
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
222
+ self.datadir = os.path.join(self.root, "data")
223
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
224
+ self.expected_length = 50000
225
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
226
+ default=False)
227
+ if not tdu.is_prepared(self.root):
228
+ # prep
229
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
230
+
231
+ datadir = self.datadir
232
+ if not os.path.exists(datadir):
233
+ path = os.path.join(self.root, self.FILES[0])
234
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
235
+ import academictorrents as at
236
+ atpath = at.get(self.AT_HASH, datastore=self.root)
237
+ assert atpath == path
238
+
239
+ print("Extracting {} to {}".format(path, datadir))
240
+ os.makedirs(datadir, exist_ok=True)
241
+ with tarfile.open(path, "r:") as tar:
242
+ tar.extractall(path=datadir)
243
+
244
+ vspath = os.path.join(self.root, self.FILES[1])
245
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
246
+ download(self.VS_URL, vspath)
247
+
248
+ with open(vspath, "r") as f:
249
+ synset_dict = f.read().splitlines()
250
+ synset_dict = dict(line.split() for line in synset_dict)
251
+
252
+ print("Reorganizing into synset folders")
253
+ synsets = np.unique(list(synset_dict.values()))
254
+ for s in synsets:
255
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
256
+ for k, v in synset_dict.items():
257
+ src = os.path.join(datadir, k)
258
+ dst = os.path.join(datadir, v)
259
+ shutil.move(src, dst)
260
+
261
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
262
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
263
+ filelist = sorted(filelist)
264
+ filelist = "\n".join(filelist)+"\n"
265
+ with open(self.txt_filelist, "w") as f:
266
+ f.write(filelist)
267
+
268
+ tdu.mark_prepared(self.root)
269
+
270
+
271
+
272
+ class ImageNetSR(Dataset):
273
+ def __init__(self, size=None,
274
+ degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
275
+ random_crop=True):
276
+ """
277
+ Imagenet Superresolution Dataloader
278
+ Performs following ops in order:
279
+ 1. crops a crop of size s from image either as random or center crop
280
+ 2. resizes crop to size with cv2.area_interpolation
281
+ 3. degrades resized crop with degradation_fn
282
+
283
+ :param size: resizing to size after cropping
284
+ :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
285
+ :param downscale_f: Low Resolution Downsample factor
286
+ :param min_crop_f: determines crop size s,
287
+ where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
288
+ :param max_crop_f: ""
289
+ :param data_root:
290
+ :param random_crop:
291
+ """
292
+ self.base = self.get_base()
293
+ assert size
294
+ assert (size / downscale_f).is_integer()
295
+ self.size = size
296
+ self.LR_size = int(size / downscale_f)
297
+ self.min_crop_f = min_crop_f
298
+ self.max_crop_f = max_crop_f
299
+ assert(max_crop_f <= 1.)
300
+ self.center_crop = not random_crop
301
+
302
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
303
+
304
+ self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
305
+
306
+ if degradation == "bsrgan":
307
+ self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
308
+
309
+ elif degradation == "bsrgan_light":
310
+ self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
311
+
312
+ else:
313
+ interpolation_fn = {
314
+ "cv_nearest": cv2.INTER_NEAREST,
315
+ "cv_bilinear": cv2.INTER_LINEAR,
316
+ "cv_bicubic": cv2.INTER_CUBIC,
317
+ "cv_area": cv2.INTER_AREA,
318
+ "cv_lanczos": cv2.INTER_LANCZOS4,
319
+ "pil_nearest": PIL.Image.NEAREST,
320
+ "pil_bilinear": PIL.Image.BILINEAR,
321
+ "pil_bicubic": PIL.Image.BICUBIC,
322
+ "pil_box": PIL.Image.BOX,
323
+ "pil_hamming": PIL.Image.HAMMING,
324
+ "pil_lanczos": PIL.Image.LANCZOS,
325
+ }[degradation]
326
+
327
+ self.pil_interpolation = degradation.startswith("pil_")
328
+
329
+ if self.pil_interpolation:
330
+ self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
331
+
332
+ else:
333
+ self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
334
+ interpolation=interpolation_fn)
335
+
336
+ def __len__(self):
337
+ return len(self.base)
338
+
339
+ def __getitem__(self, i):
340
+ example = self.base[i]
341
+ image = Image.open(example["file_path_"])
342
+
343
+ if not image.mode == "RGB":
344
+ image = image.convert("RGB")
345
+
346
+ image = np.array(image).astype(np.uint8)
347
+
348
+ min_side_len = min(image.shape[:2])
349
+ crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
350
+ crop_side_len = int(crop_side_len)
351
+
352
+ if self.center_crop:
353
+ self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
354
+
355
+ else:
356
+ self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
357
+
358
+ image = self.cropper(image=image)["image"]
359
+ image = self.image_rescaler(image=image)["image"]
360
+
361
+ if self.pil_interpolation:
362
+ image_pil = PIL.Image.fromarray(image)
363
+ LR_image = self.degradation_process(image_pil)
364
+ LR_image = np.array(LR_image).astype(np.uint8)
365
+
366
+ else:
367
+ LR_image = self.degradation_process(image=image)["image"]
368
+
369
+ example["image"] = (image/127.5 - 1.0).astype(np.float32)
370
+ example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
371
+ example["caption"] = example["human_label"] # dummy caption
372
+ return example
373
+
374
+
375
+ class ImageNetSRTrain(ImageNetSR):
376
+ def __init__(self, **kwargs):
377
+ super().__init__(**kwargs)
378
+
379
+ def get_base(self):
380
+ with open("data/imagenet_train_hr_indices.p", "rb") as f:
381
+ indices = pickle.load(f)
382
+ dset = ImageNetTrain(process_images=False,)
383
+ return Subset(dset, indices)
384
+
385
+
386
+ class ImageNetSRValidation(ImageNetSR):
387
+ def __init__(self, **kwargs):
388
+ super().__init__(**kwargs)
389
+
390
+ def get_base(self):
391
+ with open("data/imagenet_val_hr_indices.p", "rb") as f:
392
+ indices = pickle.load(f)
393
+ dset = ImageNetValidation(process_images=False,)
394
+ return Subset(dset, indices)
One-2-3-45-master 2/ldm/data/inpainting/__init__.py ADDED
File without changes
One-2-3-45-master 2/ldm/data/inpainting/synthetic_mask.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageDraw
2
+ import numpy as np
3
+
4
+ settings = {
5
+ "256narrow": {
6
+ "p_irr": 1,
7
+ "min_n_irr": 4,
8
+ "max_n_irr": 50,
9
+ "max_l_irr": 40,
10
+ "max_w_irr": 10,
11
+ "min_n_box": None,
12
+ "max_n_box": None,
13
+ "min_s_box": None,
14
+ "max_s_box": None,
15
+ "marg": None,
16
+ },
17
+ "256train": {
18
+ "p_irr": 0.5,
19
+ "min_n_irr": 1,
20
+ "max_n_irr": 5,
21
+ "max_l_irr": 200,
22
+ "max_w_irr": 100,
23
+ "min_n_box": 1,
24
+ "max_n_box": 4,
25
+ "min_s_box": 30,
26
+ "max_s_box": 150,
27
+ "marg": 10,
28
+ },
29
+ "512train": { # TODO: experimental
30
+ "p_irr": 0.5,
31
+ "min_n_irr": 1,
32
+ "max_n_irr": 5,
33
+ "max_l_irr": 450,
34
+ "max_w_irr": 250,
35
+ "min_n_box": 1,
36
+ "max_n_box": 4,
37
+ "min_s_box": 30,
38
+ "max_s_box": 300,
39
+ "marg": 10,
40
+ },
41
+ "512train-large": { # TODO: experimental
42
+ "p_irr": 0.5,
43
+ "min_n_irr": 1,
44
+ "max_n_irr": 5,
45
+ "max_l_irr": 450,
46
+ "max_w_irr": 400,
47
+ "min_n_box": 1,
48
+ "max_n_box": 4,
49
+ "min_s_box": 75,
50
+ "max_s_box": 450,
51
+ "marg": 10,
52
+ },
53
+ }
54
+
55
+
56
+ def gen_segment_mask(mask, start, end, brush_width):
57
+ mask = mask > 0
58
+ mask = (255 * mask).astype(np.uint8)
59
+ mask = Image.fromarray(mask)
60
+ draw = ImageDraw.Draw(mask)
61
+ draw.line([start, end], fill=255, width=brush_width, joint="curve")
62
+ mask = np.array(mask) / 255
63
+ return mask
64
+
65
+
66
+ def gen_box_mask(mask, masked):
67
+ x_0, y_0, w, h = masked
68
+ mask[y_0:y_0 + h, x_0:x_0 + w] = 1
69
+ return mask
70
+
71
+
72
+ def gen_round_mask(mask, masked, radius):
73
+ x_0, y_0, w, h = masked
74
+ xy = [(x_0, y_0), (x_0 + w, y_0 + w)]
75
+
76
+ mask = mask > 0
77
+ mask = (255 * mask).astype(np.uint8)
78
+ mask = Image.fromarray(mask)
79
+ draw = ImageDraw.Draw(mask)
80
+ draw.rounded_rectangle(xy, radius=radius, fill=255)
81
+ mask = np.array(mask) / 255
82
+ return mask
83
+
84
+
85
+ def gen_large_mask(prng, img_h, img_w,
86
+ marg, p_irr, min_n_irr, max_n_irr, max_l_irr, max_w_irr,
87
+ min_n_box, max_n_box, min_s_box, max_s_box):
88
+ """
89
+ img_h: int, an image height
90
+ img_w: int, an image width
91
+ marg: int, a margin for a box starting coordinate
92
+ p_irr: float, 0 <= p_irr <= 1, a probability of a polygonal chain mask
93
+
94
+ min_n_irr: int, min number of segments
95
+ max_n_irr: int, max number of segments
96
+ max_l_irr: max length of a segment in polygonal chain
97
+ max_w_irr: max width of a segment in polygonal chain
98
+
99
+ min_n_box: int, min bound for the number of box primitives
100
+ max_n_box: int, max bound for the number of box primitives
101
+ min_s_box: int, min length of a box side
102
+ max_s_box: int, max length of a box side
103
+ """
104
+
105
+ mask = np.zeros((img_h, img_w))
106
+ uniform = prng.randint
107
+
108
+ if np.random.uniform(0, 1) < p_irr: # generate polygonal chain
109
+ n = uniform(min_n_irr, max_n_irr) # sample number of segments
110
+
111
+ for _ in range(n):
112
+ y = uniform(0, img_h) # sample a starting point
113
+ x = uniform(0, img_w)
114
+
115
+ a = uniform(0, 360) # sample angle
116
+ l = uniform(10, max_l_irr) # sample segment length
117
+ w = uniform(5, max_w_irr) # sample a segment width
118
+
119
+ # draw segment starting from (x,y) to (x_,y_) using brush of width w
120
+ x_ = x + l * np.sin(a)
121
+ y_ = y + l * np.cos(a)
122
+
123
+ mask = gen_segment_mask(mask, start=(x, y), end=(x_, y_), brush_width=w)
124
+ x, y = x_, y_
125
+ else: # generate Box masks
126
+ n = uniform(min_n_box, max_n_box) # sample number of rectangles
127
+
128
+ for _ in range(n):
129
+ h = uniform(min_s_box, max_s_box) # sample box shape
130
+ w = uniform(min_s_box, max_s_box)
131
+
132
+ x_0 = uniform(marg, img_w - marg - w) # sample upper-left coordinates of box
133
+ y_0 = uniform(marg, img_h - marg - h)
134
+
135
+ if np.random.uniform(0, 1) < 0.5:
136
+ mask = gen_box_mask(mask, masked=(x_0, y_0, w, h))
137
+ else:
138
+ r = uniform(0, 60) # sample radius
139
+ mask = gen_round_mask(mask, masked=(x_0, y_0, w, h), radius=r)
140
+ return mask
141
+
142
+
143
+ make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"])
144
+ make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"])
145
+ make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"])
146
+ make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"])
147
+
148
+
149
+ MASK_MODES = {
150
+ "256train": make_lama_mask,
151
+ "256narrow": make_narrow_lama_mask,
152
+ "512train": make_512_lama_mask,
153
+ "512train-large": make_512_lama_mask_large
154
+ }
155
+
156
+ if __name__ == "__main__":
157
+ import sys
158
+
159
+ out = sys.argv[1]
160
+
161
+ prng = np.random.RandomState(1)
162
+ kwargs = settings["256train"]
163
+ mask = gen_large_mask(prng, 256, 256, **kwargs)
164
+ mask = (255 * mask).astype(np.uint8)
165
+ mask = Image.fromarray(mask)
166
+ mask.save(out)
One-2-3-45-master 2/ldm/data/laion.py ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import webdataset as wds
2
+ import kornia
3
+ from PIL import Image
4
+ import io
5
+ import os
6
+ import torchvision
7
+ from PIL import Image
8
+ import glob
9
+ import random
10
+ import numpy as np
11
+ import pytorch_lightning as pl
12
+ from tqdm import tqdm
13
+ from omegaconf import OmegaConf
14
+ from einops import rearrange
15
+ import torch
16
+ from webdataset.handlers import warn_and_continue
17
+
18
+
19
+ from ldm.util import instantiate_from_config
20
+ from ldm.data.inpainting.synthetic_mask import gen_large_mask, MASK_MODES
21
+ from ldm.data.base import PRNGMixin
22
+
23
+
24
+ class DataWithWings(torch.utils.data.IterableDataset):
25
+ def __init__(self, min_size, transform=None, target_transform=None):
26
+ self.min_size = min_size
27
+ self.transform = transform if transform is not None else nn.Identity()
28
+ self.target_transform = target_transform if target_transform is not None else nn.Identity()
29
+ self.kv = OnDiskKV(file='/home/ubuntu/laion5B-watermark-safety-ordered', key_format='q', value_format='ee')
30
+ self.kv_aesthetic = OnDiskKV(file='/home/ubuntu/laion5B-aesthetic-tags-kv', key_format='q', value_format='e')
31
+ self.pwatermark_threshold = 0.8
32
+ self.punsafe_threshold = 0.5
33
+ self.aesthetic_threshold = 5.
34
+ self.total_samples = 0
35
+ self.samples = 0
36
+ location = 'pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-data/{000000..231349}.tar -'
37
+
38
+ self.inner_dataset = wds.DataPipeline(
39
+ wds.ResampledShards(location),
40
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
41
+ wds.shuffle(1000, handler=wds.warn_and_continue),
42
+ wds.decode('pilrgb', handler=wds.warn_and_continue),
43
+ wds.map(self._add_tags, handler=wds.ignore_and_continue),
44
+ wds.select(self._filter_predicate),
45
+ wds.map_dict(jpg=self.transform, txt=self.target_transform, punsafe=self._punsafe_to_class, handler=wds.warn_and_continue),
46
+ wds.to_tuple('jpg', 'txt', 'punsafe', handler=wds.warn_and_continue),
47
+ )
48
+
49
+ @staticmethod
50
+ def _compute_hash(url, text):
51
+ if url is None:
52
+ url = ''
53
+ if text is None:
54
+ text = ''
55
+ total = (url + text).encode('utf-8')
56
+ return mmh3.hash64(total)[0]
57
+
58
+ def _add_tags(self, x):
59
+ hsh = self._compute_hash(x['json']['url'], x['txt'])
60
+ pwatermark, punsafe = self.kv[hsh]
61
+ aesthetic = self.kv_aesthetic[hsh][0]
62
+ return {**x, 'pwatermark': pwatermark, 'punsafe': punsafe, 'aesthetic': aesthetic}
63
+
64
+ def _punsafe_to_class(self, punsafe):
65
+ return torch.tensor(punsafe >= self.punsafe_threshold).long()
66
+
67
+ def _filter_predicate(self, x):
68
+ try:
69
+ return x['pwatermark'] < self.pwatermark_threshold and x['aesthetic'] >= self.aesthetic_threshold and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
70
+ except:
71
+ return False
72
+
73
+ def __iter__(self):
74
+ return iter(self.inner_dataset)
75
+
76
+
77
+ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
78
+ """Take a list of samples (as dictionary) and create a batch, preserving the keys.
79
+ If `tensors` is True, `ndarray` objects are combined into
80
+ tensor batches.
81
+ :param dict samples: list of samples
82
+ :param bool tensors: whether to turn lists of ndarrays into a single ndarray
83
+ :returns: single sample consisting of a batch
84
+ :rtype: dict
85
+ """
86
+ keys = set.intersection(*[set(sample.keys()) for sample in samples])
87
+ batched = {key: [] for key in keys}
88
+
89
+ for s in samples:
90
+ [batched[key].append(s[key]) for key in batched]
91
+
92
+ result = {}
93
+ for key in batched:
94
+ if isinstance(batched[key][0], (int, float)):
95
+ if combine_scalars:
96
+ result[key] = np.array(list(batched[key]))
97
+ elif isinstance(batched[key][0], torch.Tensor):
98
+ if combine_tensors:
99
+ result[key] = torch.stack(list(batched[key]))
100
+ elif isinstance(batched[key][0], np.ndarray):
101
+ if combine_tensors:
102
+ result[key] = np.array(list(batched[key]))
103
+ else:
104
+ result[key] = list(batched[key])
105
+ return result
106
+
107
+
108
+ class WebDataModuleFromConfig(pl.LightningDataModule):
109
+ def __init__(self, tar_base, batch_size, train=None, validation=None,
110
+ test=None, num_workers=4, multinode=True, min_size=None,
111
+ max_pwatermark=1.0,
112
+ **kwargs):
113
+ super().__init__(self)
114
+ print(f'Setting tar base to {tar_base}')
115
+ self.tar_base = tar_base
116
+ self.batch_size = batch_size
117
+ self.num_workers = num_workers
118
+ self.train = train
119
+ self.validation = validation
120
+ self.test = test
121
+ self.multinode = multinode
122
+ self.min_size = min_size # filter out very small images
123
+ self.max_pwatermark = max_pwatermark # filter out watermarked images
124
+
125
+ def make_loader(self, dataset_config, train=True):
126
+ if 'image_transforms' in dataset_config:
127
+ image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
128
+ else:
129
+ image_transforms = []
130
+
131
+ image_transforms.extend([torchvision.transforms.ToTensor(),
132
+ torchvision.transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
133
+ image_transforms = torchvision.transforms.Compose(image_transforms)
134
+
135
+ if 'transforms' in dataset_config:
136
+ transforms_config = OmegaConf.to_container(dataset_config.transforms)
137
+ else:
138
+ transforms_config = dict()
139
+
140
+ transform_dict = {dkey: load_partial_from_config(transforms_config[dkey])
141
+ if transforms_config[dkey] != 'identity' else identity
142
+ for dkey in transforms_config}
143
+ img_key = dataset_config.get('image_key', 'jpeg')
144
+ transform_dict.update({img_key: image_transforms})
145
+
146
+ if 'postprocess' in dataset_config:
147
+ postprocess = instantiate_from_config(dataset_config['postprocess'])
148
+ else:
149
+ postprocess = None
150
+
151
+ shuffle = dataset_config.get('shuffle', 0)
152
+ shardshuffle = shuffle > 0
153
+
154
+ nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only
155
+
156
+ if self.tar_base == "__improvedaesthetic__":
157
+ print("## Warning, loading the same improved aesthetic dataset "
158
+ "for all splits and ignoring shards parameter.")
159
+ tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
160
+ else:
161
+ tars = os.path.join(self.tar_base, dataset_config.shards)
162
+
163
+ dset = wds.WebDataset(
164
+ tars,
165
+ nodesplitter=nodesplitter,
166
+ shardshuffle=shardshuffle,
167
+ handler=wds.warn_and_continue).repeat().shuffle(shuffle)
168
+ print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
169
+
170
+ dset = (dset
171
+ .select(self.filter_keys)
172
+ .decode('pil', handler=wds.warn_and_continue)
173
+ .select(self.filter_size)
174
+ .map_dict(**transform_dict, handler=wds.warn_and_continue)
175
+ )
176
+ if postprocess is not None:
177
+ dset = dset.map(postprocess)
178
+ dset = (dset
179
+ .batched(self.batch_size, partial=False,
180
+ collation_fn=dict_collation_fn)
181
+ )
182
+
183
+ loader = wds.WebLoader(dset, batch_size=None, shuffle=False,
184
+ num_workers=self.num_workers)
185
+
186
+ return loader
187
+
188
+ def filter_size(self, x):
189
+ try:
190
+ valid = True
191
+ if self.min_size is not None and self.min_size > 1:
192
+ try:
193
+ valid = valid and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
194
+ except Exception:
195
+ valid = False
196
+ if self.max_pwatermark is not None and self.max_pwatermark < 1.0:
197
+ try:
198
+ valid = valid and x['json']['pwatermark'] <= self.max_pwatermark
199
+ except Exception:
200
+ valid = False
201
+ return valid
202
+ except Exception:
203
+ return False
204
+
205
+ def filter_keys(self, x):
206
+ try:
207
+ return ("jpg" in x) and ("txt" in x)
208
+ except Exception:
209
+ return False
210
+
211
+ def train_dataloader(self):
212
+ return self.make_loader(self.train)
213
+
214
+ def val_dataloader(self):
215
+ return self.make_loader(self.validation, train=False)
216
+
217
+ def test_dataloader(self):
218
+ return self.make_loader(self.test, train=False)
219
+
220
+
221
+ from ldm.modules.image_degradation import degradation_fn_bsr_light
222
+ import cv2
223
+
224
+ class AddLR(object):
225
+ def __init__(self, factor, output_size, initial_size=None, image_key="jpg"):
226
+ self.factor = factor
227
+ self.output_size = output_size
228
+ self.image_key = image_key
229
+ self.initial_size = initial_size
230
+
231
+ def pt2np(self, x):
232
+ x = ((x+1.0)*127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy()
233
+ return x
234
+
235
+ def np2pt(self, x):
236
+ x = torch.from_numpy(x)/127.5-1.0
237
+ return x
238
+
239
+ def __call__(self, sample):
240
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
241
+ x = self.pt2np(sample[self.image_key])
242
+ if self.initial_size is not None:
243
+ x = cv2.resize(x, (self.initial_size, self.initial_size), interpolation=2)
244
+ x = degradation_fn_bsr_light(x, sf=self.factor)['image']
245
+ x = cv2.resize(x, (self.output_size, self.output_size), interpolation=2)
246
+ x = self.np2pt(x)
247
+ sample['lr'] = x
248
+ return sample
249
+
250
+ class AddBW(object):
251
+ def __init__(self, image_key="jpg"):
252
+ self.image_key = image_key
253
+
254
+ def pt2np(self, x):
255
+ x = ((x+1.0)*127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy()
256
+ return x
257
+
258
+ def np2pt(self, x):
259
+ x = torch.from_numpy(x)/127.5-1.0
260
+ return x
261
+
262
+ def __call__(self, sample):
263
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
264
+ x = sample[self.image_key]
265
+ w = torch.rand(3, device=x.device)
266
+ w /= w.sum()
267
+ out = torch.einsum('hwc,c->hw', x, w)
268
+
269
+ # Keep as 3ch so we can pass to encoder, also we might want to add hints
270
+ sample['lr'] = out.unsqueeze(-1).tile(1,1,3)
271
+ return sample
272
+
273
+ class AddMask(PRNGMixin):
274
+ def __init__(self, mode="512train", p_drop=0.):
275
+ super().__init__()
276
+ assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
277
+ self.make_mask = MASK_MODES[mode]
278
+ self.p_drop = p_drop
279
+
280
+ def __call__(self, sample):
281
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
282
+ x = sample['jpg']
283
+ mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
284
+ if self.prng.choice(2, p=[1 - self.p_drop, self.p_drop]):
285
+ mask = np.ones_like(mask)
286
+ mask[mask < 0.5] = 0
287
+ mask[mask > 0.5] = 1
288
+ mask = torch.from_numpy(mask[..., None])
289
+ sample['mask'] = mask
290
+ sample['masked_image'] = x * (mask < 0.5)
291
+ return sample
292
+
293
+
294
+ class AddEdge(PRNGMixin):
295
+ def __init__(self, mode="512train", mask_edges=True):
296
+ super().__init__()
297
+ assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
298
+ self.make_mask = MASK_MODES[mode]
299
+ self.n_down_choices = [0]
300
+ self.sigma_choices = [1, 2]
301
+ self.mask_edges = mask_edges
302
+
303
+ @torch.no_grad()
304
+ def __call__(self, sample):
305
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
306
+ x = sample['jpg']
307
+
308
+ mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
309
+ mask[mask < 0.5] = 0
310
+ mask[mask > 0.5] = 1
311
+ mask = torch.from_numpy(mask[..., None])
312
+ sample['mask'] = mask
313
+
314
+ n_down_idx = self.prng.choice(len(self.n_down_choices))
315
+ sigma_idx = self.prng.choice(len(self.sigma_choices))
316
+
317
+ n_choices = len(self.n_down_choices)*len(self.sigma_choices)
318
+ raveled_idx = np.ravel_multi_index((n_down_idx, sigma_idx),
319
+ (len(self.n_down_choices), len(self.sigma_choices)))
320
+ normalized_idx = raveled_idx/max(1, n_choices-1)
321
+
322
+ n_down = self.n_down_choices[n_down_idx]
323
+ sigma = self.sigma_choices[sigma_idx]
324
+
325
+ kernel_size = 4*sigma+1
326
+ kernel_size = (kernel_size, kernel_size)
327
+ sigma = (sigma, sigma)
328
+ canny = kornia.filters.Canny(
329
+ low_threshold=0.1,
330
+ high_threshold=0.2,
331
+ kernel_size=kernel_size,
332
+ sigma=sigma,
333
+ hysteresis=True,
334
+ )
335
+ y = (x+1.0)/2.0 # in 01
336
+ y = y.unsqueeze(0).permute(0, 3, 1, 2).contiguous()
337
+
338
+ # down
339
+ for i_down in range(n_down):
340
+ size = min(y.shape[-2], y.shape[-1])//2
341
+ y = kornia.geometry.transform.resize(y, size, antialias=True)
342
+
343
+ # edge
344
+ _, y = canny(y)
345
+
346
+ if n_down > 0:
347
+ size = x.shape[0], x.shape[1]
348
+ y = kornia.geometry.transform.resize(y, size, interpolation="nearest")
349
+
350
+ y = y.permute(0, 2, 3, 1)[0].expand(-1, -1, 3).contiguous()
351
+ y = y*2.0-1.0
352
+
353
+ if self.mask_edges:
354
+ sample['masked_image'] = y * (mask < 0.5)
355
+ else:
356
+ sample['masked_image'] = y
357
+ sample['mask'] = torch.zeros_like(sample['mask'])
358
+
359
+ # concat normalized idx
360
+ sample['smoothing_strength'] = torch.ones_like(sample['mask'])*normalized_idx
361
+
362
+ return sample
363
+
364
+
365
+ def example00():
366
+ url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
367
+ dataset = wds.WebDataset(url)
368
+ example = next(iter(dataset))
369
+ for k in example:
370
+ print(k, type(example[k]))
371
+
372
+ print(example["__key__"])
373
+ for k in ["json", "txt"]:
374
+ print(example[k].decode())
375
+
376
+ image = Image.open(io.BytesIO(example["jpg"]))
377
+ outdir = "tmp"
378
+ os.makedirs(outdir, exist_ok=True)
379
+ image.save(os.path.join(outdir, example["__key__"] + ".png"))
380
+
381
+
382
+ def load_example(example):
383
+ return {
384
+ "key": example["__key__"],
385
+ "image": Image.open(io.BytesIO(example["jpg"])),
386
+ "text": example["txt"].decode(),
387
+ }
388
+
389
+
390
+ for i, example in tqdm(enumerate(dataset)):
391
+ ex = load_example(example)
392
+ print(ex["image"].size, ex["text"])
393
+ if i >= 100:
394
+ break
395
+
396
+
397
+ def example01():
398
+ # the first laion shards contain ~10k examples each
399
+ url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/{000000..000002}.tar -"
400
+
401
+ batch_size = 3
402
+ shuffle_buffer = 10000
403
+ dset = wds.WebDataset(
404
+ url,
405
+ nodesplitter=wds.shardlists.split_by_node,
406
+ shardshuffle=True,
407
+ )
408
+ dset = (dset
409
+ .shuffle(shuffle_buffer, initial=shuffle_buffer)
410
+ .decode('pil', handler=warn_and_continue)
411
+ .batched(batch_size, partial=False,
412
+ collation_fn=dict_collation_fn)
413
+ )
414
+
415
+ num_workers = 2
416
+ loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=num_workers)
417
+
418
+ batch_sizes = list()
419
+ keys_per_epoch = list()
420
+ for epoch in range(5):
421
+ keys = list()
422
+ for batch in tqdm(loader):
423
+ batch_sizes.append(len(batch["__key__"]))
424
+ keys.append(batch["__key__"])
425
+
426
+ for bs in batch_sizes:
427
+ assert bs==batch_size
428
+ print(f"{len(batch_sizes)} batches of size {batch_size}.")
429
+ batch_sizes = list()
430
+
431
+ keys_per_epoch.append(keys)
432
+ for i_batch in [0, 1, -1]:
433
+ print(f"Batch {i_batch} of epoch {epoch}:")
434
+ print(keys[i_batch])
435
+ print("next epoch.")
436
+
437
+
438
+ def example02():
439
+ from omegaconf import OmegaConf
440
+ from torch.utils.data.distributed import DistributedSampler
441
+ from torch.utils.data import IterableDataset
442
+ from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
443
+ from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
444
+
445
+ #config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
446
+ #config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
447
+ config = OmegaConf.load("configs/stable-diffusion/txt2img-v2-clip-encoder-improved_aesthetics-256.yaml")
448
+ datamod = WebDataModuleFromConfig(**config["data"]["params"])
449
+ dataloader = datamod.train_dataloader()
450
+
451
+ for batch in dataloader:
452
+ print(batch.keys())
453
+ print(batch["jpg"].shape)
454
+ break
455
+
456
+
457
+ def example03():
458
+ # improved aesthetics
459
+ tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
460
+ dataset = wds.WebDataset(tars)
461
+
462
+ def filter_keys(x):
463
+ try:
464
+ return ("jpg" in x) and ("txt" in x)
465
+ except Exception:
466
+ return False
467
+
468
+ def filter_size(x):
469
+ try:
470
+ return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512
471
+ except Exception:
472
+ return False
473
+
474
+ def filter_watermark(x):
475
+ try:
476
+ return x['json']['pwatermark'] < 0.5
477
+ except Exception:
478
+ return False
479
+
480
+ dataset = (dataset
481
+ .select(filter_keys)
482
+ .decode('pil', handler=wds.warn_and_continue))
483
+ n_save = 20
484
+ n_total = 0
485
+ n_large = 0
486
+ n_large_nowm = 0
487
+ for i, example in enumerate(dataset):
488
+ n_total += 1
489
+ if filter_size(example):
490
+ n_large += 1
491
+ if filter_watermark(example):
492
+ n_large_nowm += 1
493
+ if n_large_nowm < n_save+1:
494
+ image = example["jpg"]
495
+ image.save(os.path.join("tmp", f"{n_large_nowm-1:06}.png"))
496
+
497
+ if i%500 == 0:
498
+ print(i)
499
+ print(f"Large: {n_large}/{n_total} | {n_large/n_total*100:.2f}%")
500
+ if n_large > 0:
501
+ print(f"No Watermark: {n_large_nowm}/{n_large} | {n_large_nowm/n_large*100:.2f}%")
502
+
503
+
504
+
505
+ def example04():
506
+ # improved aesthetics
507
+ for i_shard in range(60208)[::-1]:
508
+ print(i_shard)
509
+ tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{:06}.tar -".format(i_shard)
510
+ dataset = wds.WebDataset(tars)
511
+
512
+ def filter_keys(x):
513
+ try:
514
+ return ("jpg" in x) and ("txt" in x)
515
+ except Exception:
516
+ return False
517
+
518
+ def filter_size(x):
519
+ try:
520
+ return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512
521
+ except Exception:
522
+ return False
523
+
524
+ dataset = (dataset
525
+ .select(filter_keys)
526
+ .decode('pil', handler=wds.warn_and_continue))
527
+ try:
528
+ example = next(iter(dataset))
529
+ except Exception:
530
+ print(f"Error @ {i_shard}")
531
+
532
+
533
+ if __name__ == "__main__":
534
+ #example01()
535
+ #example02()
536
+ example03()
537
+ #example04()