yujia committed on
Commit d01e027 · 1 Parent(s): 5ab61a5

init concerto demo

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .github/CODE_OF_CONDUCT.md +80 -0
  2. .github/CONTRIBUTING.md +35 -0
  3. .gitignore +29 -0
  4. .gradio/certificate.pem +31 -0
  5. LICENSE +201 -0
  6. README.md +10 -0
  7. app.py +1031 -0
  8. concerto/__init__.py +26 -0
  9. concerto/data.py +84 -0
  10. concerto/model.py +798 -0
  11. concerto/module.py +107 -0
  12. concerto/registry.py +340 -0
  13. concerto/serialization/__init__.py +8 -0
  14. concerto/serialization/default.py +82 -0
  15. concerto/serialization/hilbert.py +318 -0
  16. concerto/serialization/z_order.py +127 -0
  17. concerto/structure.py +159 -0
  18. concerto/transform.py +1224 -0
  19. concerto/utils.py +75 -0
  20. example/pcd/hm3d_00012_kDgLKdMd5X8_2.png +3 -0
  21. example/pcd/hm3d_00113_3goH1WRaCYC.ply +3 -0
  22. example/pcd/hm3d_00113_3goH1WRaCYC.png +3 -0
  23. example/pcd/s3dis_Area2_auditorium1.ply +3 -0
  24. example/pcd/s3dis_Area2_auditorium1.png +3 -0
  25. example/pcd/s3dis_Area4_lobby1.png +3 -0
  26. example/pcd/scannet_0024.ply +3 -0
  27. example/pcd/scannet_0024.png +3 -0
  28. example/pcd/scannet_0603.ply +3 -0
  29. example/pcd/scannet_0603.png +3 -0
  30. example/video/re10k_1.mp4 +3 -0
  31. example/video/re10k_2.mp4 +3 -0
  32. example/video/re10k_3.mp4 +3 -0
  33. example/video/re10k_4.mp4 +3 -0
  34. requirements.txt +33 -0
  35. setup.py +29 -0
  36. vggt/__init__.py +0 -0
  37. vggt/heads/camera_head.py +162 -0
  38. vggt/heads/dpt_head.py +497 -0
  39. vggt/heads/head_act.py +125 -0
  40. vggt/heads/track_head.py +108 -0
  41. vggt/heads/track_modules/__init__.py +5 -0
  42. vggt/heads/track_modules/base_track_predictor.py +209 -0
  43. vggt/heads/track_modules/blocks.py +246 -0
  44. vggt/heads/track_modules/modules.py +218 -0
  45. vggt/heads/track_modules/utils.py +226 -0
  46. vggt/heads/utils.py +108 -0
  47. vggt/layers/__init__.py +11 -0
  48. vggt/layers/attention.py +98 -0
  49. vggt/layers/block.py +259 -0
  50. vggt/layers/drop_path.py +34 -0
.github/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
+ # Code of Conduct
+
+ ## Our Pledge
+
+ In the interest of fostering an open and welcoming environment, we as
+ contributors and maintainers pledge to make participation in our project and
+ our community a harassment-free experience for everyone, regardless of age, body
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
+ level of experience, education, socio-economic status, nationality, personal
+ appearance, race, religion, or sexual identity and orientation.
+
+ ## Our Standards
+
+ Examples of behavior that contributes to creating a positive environment
+ include:
+
+ * Using welcoming and inclusive language
+ * Being respectful of differing viewpoints and experiences
+ * Gracefully accepting constructive criticism
+ * Focusing on what is best for the community
+ * Showing empathy towards other community members
+
+ Examples of unacceptable behavior by participants include:
+
+ * The use of sexualized language or imagery and unwelcome sexual attention or
+ advances
+ * Trolling, insulting/derogatory comments, and personal or political attacks
+ * Public or private harassment
+ * Publishing others' private information, such as a physical or electronic
+ address, without explicit permission
+ * Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+ ## Our Responsibilities
+
+ Project maintainers are responsible for clarifying the standards of acceptable
+ behavior and are expected to take appropriate and fair corrective action in
+ response to any instances of unacceptable behavior.
+
+ Project maintainers have the right and responsibility to remove, edit, or
+ reject comments, commits, code, wiki edits, issues, and other contributions
+ that are not aligned to this Code of Conduct, or to ban temporarily or
+ permanently any contributor for other behaviors that they deem inappropriate,
+ threatening, offensive, or harmful.
+
+ ## Scope
+
+ This Code of Conduct applies within all project spaces, and it also applies when
+ an individual is representing the project or its community in public spaces.
+ Examples of representing a project or community include using an official
+ project e-mail address, posting via an official social media account, or acting
+ as an appointed representative at an online or offline event. Representation of
+ a project may be further defined and clarified by project maintainers.
+
+ This Code of Conduct also applies outside the project spaces when there is a
+ reasonable belief that an individual's behavior may have a negative impact on
+ the project or its community.
+
+ ## Enforcement
+
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
+ reported by contacting the project team at <opensource-conduct@meta.com>. All
+ complaints will be reviewed and investigated and will result in a response that
+ is deemed necessary and appropriate to the circumstances. The project team is
+ obligated to maintain confidentiality with regard to the reporter of an incident.
+ Further details of specific enforcement policies may be posted separately.
+
+ Project maintainers who do not follow or enforce the Code of Conduct in good
+ faith may face temporary or permanent repercussions as determined by other
+ members of the project's leadership.
+
+ ## Attribution
+
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+
+ [homepage]: https://www.contributor-covenant.org
+
+ For answers to common questions about this code of conduct, see
+ https://www.contributor-covenant.org/faq
.github/CONTRIBUTING.md ADDED
@@ -0,0 +1,35 @@
+ # Contributing to "sonata"
+
+ We want to make contributing to this project as easy and transparent as
+ possible.
+
+ ## Pull Requests
+
+ We welcome pull requests.
+
+ 1. Fork the repo and create your branch from `main`.
+ 2. If you've added code that should be tested, add tests.
+ 3. If you've changed APIs, update the documentation in the code.
+ 4. Ensure the test suite passes.
+ 5. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+ ## Contributor License Agreement ("CLA")
+
+ In order to accept your pull request, we need you to submit a CLA. You only need
+ to do this once to work on any of Facebook's open source projects.
+
+ Complete your CLA here: <https://code.facebook.com/cla>
+
+ ## Issues
+
+ We use GitHub issues to track public bugs. Please ensure your description is
+ clear and has sufficient instructions to be able to reproduce the issue.
+
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+ disclosure of security bugs. In those cases, please go through the process
+ outlined on that page and do not file a public issue.
+
+ ## License
+
+ By contributing to "sonata", you agree that your contributions will be licensed under
+ the [LICENSE](../LICENSE) file in the root directory of this source tree.
.gitignore ADDED
@@ -0,0 +1,29 @@
+ image/
+ __pycache__
+ **/build/
+ **/*.egg-info/
+ **/dist/
+ *.so
+ exp
+ weights
+ data
+ log
+ **/ckpt/
+ outputs/
+ .vscode
+ .idea
+ */.DS_Store
+ **/*.out
+ # Dockerfile
+ # **/vggt*
+ vggt/ckpt
+ **/example/video/conference_room.mp4
+ **/example/video/office.mp4
+ # **/example/video/re10k_1.mp4
+ # **/example/video/re10k_2.mp4
+ **/example/pcd/hm3d_00012_kDgLKdMd5X8_2.ply
+ # **/example/pcd/hm3d_00113_3goH1WRaCYC.ply
+ # **/example/pcd/s3dis_Area2_auditorium1.ply
+ **/example/pcd/s3dis_Area4_lobby1.ply
+ # **/example/
+ **/demo_output*
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ title: Concerto
+ emoji: 🎶
+ colorFrom: gray
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 5.35.0
+ app_file: app.py
+ pinned: false
+ ---
app.py ADDED
@@ -0,0 +1,1031 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import shutil
5
+ from datetime import datetime
6
+ import glob
7
+ import gc
8
+ import gradio as gr
9
+ import numpy as np
10
+ import open3d as o3d
11
+ import concerto
12
+ from scipy.spatial.transform import Rotation as R
13
+ import trimesh
14
+ import time
15
+ from typing import List, Tuple
16
+ from pathlib import Path
17
+ from einops import rearrange
18
+ from tqdm import tqdm
19
+ import camtools as ct
20
+ from PIL import Image
21
+ from torchvision import transforms as TF
22
+ try:
23
+ import flash_attn
24
+ except ImportError:
25
+ flash_attn = None
26
+
27
+ from visual_util import predictions_to_glb
28
+ from vggt.models.vggt import VGGT
29
+ from vggt.utils.load_fn import load_and_preprocess_images
30
+ from vggt.utils.pose_enc import pose_encoding_to_extri_intri
31
+ from vggt.utils.geometry import unproject_depth_map_to_point_map
32
+
33
+ device = "cuda" if torch.cuda.is_available() else "cpu"
34
+
35
+ def run_model(target_dir, model) -> dict:
36
+ """
37
+ Run the VGGT model on images in the 'target_dir/images' folder and return predictions.
38
+ """
39
+ print(f"Processing images from {target_dir}")
40
+
41
+ # if not torch.cuda.is_available():
42
+ # raise ValueError("CUDA is not available. Check your environment.")
43
+
44
+ # Move model to device
45
+ model = model.to(device)
46
+ model.eval()
47
+
48
+ # Load and preprocess images
49
+ image_names = glob.glob(os.path.join(target_dir, "images", "*"))
50
+ image_names = sorted(image_names)
51
+ print(f"Found {len(image_names)} images")
52
+ if len(image_names) == 0:
53
+ raise ValueError("No images found. Check your upload.")
54
+
55
+ images = load_and_preprocess_images(image_names).to(device)
56
+ print(f"Preprocessed images shape: {images.shape}")
57
+
58
+ # Run inference
59
+ print("Running inference...")
60
+ with torch.no_grad():
61
+ if device == "cuda":
62
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
63
+ predictions = model(images)
64
+ else:
65
+ predictions = model(images)
66
+
67
+ # Convert pose encoding to extrinsic and intrinsic matrices
68
+ print("Converting pose encoding to extrinsic and intrinsic matrices...")
69
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
70
+ predictions["extrinsic"] = extrinsic
71
+ predictions["intrinsic"] = intrinsic
72
+
73
+ # Convert tensors to numpy
74
+ for key in predictions.keys():
75
+ if isinstance(predictions[key], torch.Tensor):
76
+ predictions[key] = predictions[key].cpu().numpy().squeeze(0) # remove batch dimension
77
+
78
+ # Generate world points from depth map
79
+ print("Computing world points from depth map...")
80
+ depth_map = predictions["depth"] # (S, H, W, 1)
81
+ world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
82
+ predictions["world_points_from_depth"] = world_points
83
+
84
+ # Clean up
85
+ torch.cuda.empty_cache()
86
+ return predictions
87
+
88
+ def handle_uploads(input_file,input_video,conf_thres,frame_slider,prediction_mode,if_TSDF):
89
+ """
90
+ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
91
+ images or extracted frames from video into it. Return (target_dir, image_paths).
92
+ """
93
+ start_time = time.time()
94
+ gc.collect()
95
+ torch.cuda.empty_cache()
96
+
97
+ # Create a unique folder name
98
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
99
+ target_dir = f"demo_output/inputs_{timestamp}"
100
+ target_dir_images = os.path.join(target_dir, "images")
101
+ target_dir_pcds = os.path.join(target_dir, "pcds")
102
+
103
+ # Clean up if somehow that folder already exists
104
+ if os.path.exists(target_dir):
105
+ shutil.rmtree(target_dir)
106
+ os.makedirs(target_dir)
107
+ os.makedirs(target_dir_images)
108
+ os.makedirs(target_dir_pcds)
109
+ # --- Handle video ---
110
+ if input_video is not None:
111
+ print("processing video")
112
+ if isinstance(input_video, dict) and "name" in input_video:
113
+ video_path = input_video["name"]
114
+ else:
115
+ video_path = input_video
116
+
117
+ vs = cv2.VideoCapture(video_path)
118
+ fps = vs.get(cv2.CAP_PROP_FPS)
119
+ frame_interval = int(fps * frame_slider) # one frame every `frame_slider` seconds
120
+
121
+ count = 0
122
+ video_frame_num = 0
123
+ image_paths = []
124
+ while True:
125
+ gotit, frame = vs.read()
126
+ if not gotit:
127
+ break
128
+ count += 1
129
+ if count % frame_interval == 0:
130
+ image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
131
+ cv2.imwrite(image_path, frame)
132
+ image_paths.append(image_path)
133
+ video_frame_num += 1
134
+ # Sort final images for gallery
135
+ image_paths = sorted(image_paths)
136
+ original_points, original_colors, original_normals = parse_frames(target_dir,conf_thres,prediction_mode,if_TSDF)
137
+ if input_file is not None:
138
+ print("processing ply")
139
+ pcd = o3d.io.read_point_cloud(input_file.name)
140
+ pcd.estimate_normals()
141
+ original_points = np.asarray(pcd.points)
142
+ original_colors = np.asarray(pcd.colors)
143
+ original_normals = np.asarray(pcd.normals)
144
+ image_paths = None
145
+ scene_3d = trimesh.Scene()
146
+ point_cloud_data = trimesh.PointCloud(vertices=original_points, colors=original_colors, vertex_normals=original_normals)
147
+ scene_3d.add_geometry(point_cloud_data)
148
+ original_temp = os.path.join(target_dir_pcds,"original.glb")
149
+ scene_3d.export(file_obj=original_temp)
150
+ np.save(os.path.join(target_dir_pcds, f"points.npy"), original_points)
151
+ np.save(os.path.join(target_dir_pcds, f"colors.npy"), original_colors)
152
+ np.save(os.path.join(target_dir_pcds, f"normals.npy"), original_normals)
153
+ end_time = time.time()
154
+ print(f"Files copied to {target_dir}; took {end_time - start_time:.3f} seconds")
155
+ return target_dir, image_paths,original_temp, end_time - start_time
156
+
157
+ def update_gallery_on_upload(input_file,input_video,conf_thres,frame_slider,prediction_mode,TSDF_mode):
158
+ """
159
+ Whenever user uploads or changes files, immediately handle them
160
+ and show in the gallery. Return (target_dir, image_paths).
161
+ If nothing is uploaded, returns "None" and empty list.
162
+ """
163
+ if not input_video and not input_file:
164
+ return None, None, None, None
165
+ if_TSDF = True if TSDF_mode=="Yes" else False
166
+ target_dir, image_paths,original_view, reconstruction_time = handle_uploads(input_file,input_video,conf_thres,frame_slider,prediction_mode,if_TSDF)
167
+ if input_file is not None:
168
+ return original_view, target_dir, [], f"Upload and preprocess complete with {reconstruction_time:.3f} sec. Click \"PCA Generate\" to begin PCA processing."
169
+ if input_video is not None:
170
+ return original_view, target_dir, image_paths, f"Upload and preprocess complete with {reconstruction_time:.3f} sec. Click \"PCA Generate\" to begin PCA processing."
171
+
172
+ def clear_fields():
173
+ """
174
+ Clears the 3D viewer, the stored target_dir, and empties the gallery.
175
+ """
176
+ return None
177
+
178
+ def PCAing_log(is_example, log_output):
179
+ """
180
+ Display a quick log message while waiting.
181
+ """
182
+ if is_example:
183
+ return log_output
184
+ return "Loading for Doing PCA..."
185
+
186
+ def reset_log():
187
+ """
188
+ Reset a quick log message.
189
+ """
190
+ return "A new point cloud file or video is uploading and preprocessing..."
191
+
192
+ def parse_frames(
193
+ target_dir,
194
+ conf_thres=3.0,
195
+ prediction_mode="Pointmap Regression",
196
+ if_TSDF=True,
197
+ ):
198
+ """
199
+ Perform reconstruction using the already-created target_dir/images.
200
+ """
201
+ if not os.path.isdir(target_dir) or target_dir == "None":
202
+ return None, "No valid target directory found. Please upload first.", None, None
203
+
204
+ start_time = time.time()
205
+ gc.collect()
206
+ torch.cuda.empty_cache()
207
+
208
+ # Prepare frame_filter dropdown
209
+ target_dir_images = os.path.join(target_dir, "images")
210
+ target_dir_pcds = os.path.join(target_dir, "pcds")
211
+ all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
212
+ all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
213
+ frame_filter_choices = ["All"] + all_files
214
+
215
+ print("Running run_model...")
216
+ with torch.no_grad():
217
+ predictions = run_model(target_dir, VGGT_model)
218
+
219
+ # Save predictions
220
+ prediction_save_path = os.path.join(target_dir, "predictions.npz")
221
+ np.savez(prediction_save_path, **predictions)
222
+
223
+ # Convert pose encoding to extrinsic and intrinsic matrices
224
+ images = predictions["images"]
225
+ Ts, Ks = predictions["extrinsic"],predictions["intrinsic"]
226
+ Ts = ct.convert.pad_0001(Ts)
227
+ Ts_inv = np.linalg.inv(Ts)
228
+ Cs = np.array([ct.convert.T_to_C(T) for T in Ts]) # (n, 3)
229
+
230
+ # [1, 8, 294, 518, 3]
231
+ world_points = predictions["world_points"]
232
+
233
+ # Compute view direction for each pixel
234
+ # (b n h w c) - (n, 3)
235
+ view_dirs = world_points - rearrange(Cs, "n c -> n 1 1 c")
236
+ view_dirs = rearrange(view_dirs, "n h w c -> (n h w) c")
237
+ view_dirs = view_dirs / np.linalg.norm(view_dirs, axis=-1, keepdims=True)
238
+
239
+ # Extract points and colors
240
+ # [1, 8, 3, 294, 518]
241
+ img_num = world_points.shape[1]
242
+ images = predictions["images"]
243
+ points = rearrange(world_points, "n h w c -> (n h w) c")
244
+ colors = rearrange(images, "n c h w -> (n h w) c")
245
+
246
+ if prediction_mode=="Pointmap Branch":
247
+ world_points_conf = predictions["world_points_conf"]
248
+ conf = world_points_conf.reshape(-1)
249
+ points,Ts_inv,_ = Coord2zup(points, Ts_inv)
250
+ scale = 3 / (points[:, 2].max() - points[:, 2].min())
251
+ points *= scale
252
+ Ts_inv[:, :3, 3] *= scale
253
+
254
+ # Create a point cloud
255
+ pcd = o3d.geometry.PointCloud()
256
+ pcd.points = o3d.utility.Vector3dVector(points)
257
+ pcd.colors = o3d.utility.Vector3dVector(colors)
258
+ pcd.estimate_normals()
259
+ # o3d.io.write_point_cloud("pcd.ply", pcd)
260
+ try:
261
+ pcd, inliers, rotation_matrix, offset = extract_and_align_ground_plane(pcd)
262
+ except Exception as e:
263
+ print(f"cannot find ground, err:{e}")
264
+ # Flip normals so that they always point toward the camera
265
+ # Compute the dot product between the normal and the view direction
266
+ # If the dot product is less than 0, flip the normal
267
+ normals = np.asarray(pcd.normals)
268
+ view_dirs = np.asarray(view_dirs)
269
+ dot_product = np.sum(normals * view_dirs, axis=-1)
270
+ flip_mask = dot_product > 0
271
+ normals[flip_mask] = -normals[flip_mask]
272
+
273
+ # Normalize normals
274
+ normals = normals / np.linalg.norm(normals, axis=-1, keepdims=True)
275
+ pcd.normals = o3d.utility.Vector3dVector(normals)
276
+ if conf_thres == 0.0:
277
+ conf_threshold = 0.0
278
+ else:
279
+ conf_threshold = np.percentile(conf, conf_thres)
280
+ conf_mask = (conf >= conf_threshold) & (conf > 1e-5)
281
+ points = points[conf_mask]
282
+ colors = colors[conf_mask]
283
+ normals = normals[conf_mask]
284
+ elif prediction_mode=="Depthmap Branch":
285
+ # Integrate RGBD images into a TSDF volume and extract a mesh
286
+ # (n, h, w, 3)
287
+ im_colors = rearrange(images, "n c h w -> (n) h w c")
288
+ # (b, n, h, w, 3)
289
+ im_dists = world_points - rearrange(Cs, "n c -> n 1 1 c")
290
+ im_dists = np.linalg.norm(im_dists, axis=-1, keepdims=False)
291
+
292
+ # Convert distance to depth
293
+ im_depths = [] # (n, h, w, c)
294
+ for im_dist, K in zip(im_dists, Ks):
295
+ im_depth = ct.convert.im_distance_to_im_depth(im_dist, K)
296
+ im_depths.append(im_depth)
297
+ im_depths = np.stack(im_depths, axis=0)
298
+ if if_TSDF:
299
+ mesh = integrate_rgbd_to_mesh(
300
+ Ks=Ks,
301
+ Ts=Ts,
302
+ im_depths=im_depths,
303
+ im_colors=im_colors,
304
+ voxel_size=1 / 512,
305
+ )
306
+ rotation_angle = -np.pi / 2
307
+ rotation_axis = np.array([1, 0, 0]) # X axis
308
+ mesh.rotate(
309
+ o3d.geometry.get_rotation_matrix_from_axis_angle(rotation_axis * rotation_angle),
310
+ center=(0,0,0)
311
+ )
312
+ vertices = np.asarray(mesh.vertices)
313
+ scale_factor = 3./(np.max(vertices[:,2])-np.min(vertices[:,2]))
314
+ mesh.scale(scale_factor, center=(0,0,0))
315
+ points = np.asarray(mesh.vertices)
316
+ colors = np.asarray(mesh.vertex_colors) if mesh.has_vertex_colors() else np.zeros_like(vertices)
317
+ if not mesh.has_vertex_normals():
318
+ mesh.compute_vertex_normals()
319
+ normals = np.asarray(mesh.vertex_normals)
320
+ Ts_inv = rotx(Ts_inv, theta=-90)
321
+ Ts_inv[:, :3, 3] *= scale_factor
322
+ pcd = o3d.geometry.PointCloud()
323
+ pcd.points = o3d.utility.Vector3dVector(points)
324
+ pcd.colors = o3d.utility.Vector3dVector(colors)
325
+ pcd.normals = o3d.utility.Vector3dVector(normals)
326
+ else:
327
+ points=[]
328
+ for K, T, im_depth in zip(Ks, Ts, im_depths):
329
+ point = ct.project.im_depth_to_point_cloud(
330
+ im_depth=im_depth,
331
+ K=K,
332
+ T=T,
333
+ to_image=False,
334
+ ignore_invalid=False,
335
+ )
336
+ points.append(point)
337
+ points = np.vstack(points)
338
+ colors = im_colors.reshape(-1,3)
339
+ world_points_conf = predictions["depth_conf"]
340
+ conf = world_points_conf.reshape(-1)
341
+ if conf_thres == 0.0:
342
+ conf_threshold = 0.0
343
+ else:
344
+ conf_threshold = np.percentile(conf, conf_thres)
345
+ conf_mask = (conf >= conf_threshold) & (conf > 1e-5)
346
+ points = points[conf_mask]
347
+ colors = colors[conf_mask]
348
+ points,Ts_inv,_ = Coord2zup(points, Ts_inv)
349
+ scale_factor = 3./(np.max(points[:,2])-np.min(points[:,2]))
350
+ points *= scale_factor
351
+ Ts_inv[:, :3, 3] *= scale_factor
352
+ pcd = o3d.geometry.PointCloud()
353
+ pcd.points = o3d.utility.Vector3dVector(points)
354
+ pcd.colors = o3d.utility.Vector3dVector(colors)
355
+ pcd.estimate_normals()
356
+ try:
357
+ pcd, inliers, rotation_matrix, offset = extract_and_align_ground_plane(pcd)
358
+ except Exception as e:
359
+ print(f"cannot find ground, err:{e}")
360
+ original_points = np.asarray(pcd.points)
361
+ original_colors = np.asarray(pcd.colors)
362
+ original_normals = np.asarray(pcd.normals)
363
+ # Cleanup
364
+ del predictions
365
+ gc.collect()
366
+ torch.cuda.empty_cache()
367
+ end_time = time.time()
368
+ print(f"Total time: {end_time - start_time:.2f} seconds")
369
+ return original_points, original_colors, original_normals
370
+
371
+ def extract_and_align_ground_plane(pcd,
372
+ height_percentile=20,
373
+ ransac_distance_threshold=0.01,
374
+ ransac_n=3,
375
+ ransac_iterations=1000,
376
+ max_angle_degree=40,
377
+ max_trials=6):
378
+ points = np.asarray(pcd.points)
379
+ z_vals = points[:, 2]
380
+ z_thresh = np.percentile(z_vals, height_percentile)
381
+ low_indices = np.where(z_vals <= z_thresh)[0]
382
+
383
+ remaining_indices = low_indices.copy()
384
+
385
+ for trial in range(max_trials):
386
+ if len(remaining_indices) < ransac_n:
387
+ raise ValueError("Not enough points left to fit a plane.")
388
+
389
+ low_pcd = pcd.select_by_index(remaining_indices)
390
+
391
+ plane_model, inliers = low_pcd.segment_plane(
392
+ distance_threshold=ransac_distance_threshold,
393
+ ransac_n=ransac_n,
394
+ num_iterations=ransac_iterations)
395
+ a, b, c, d = plane_model
396
+ normal = np.array([a, b, c])
397
+ normal /= np.linalg.norm(normal)
398
+
399
+ # current_plane_pcd = pcd.select_by_index(remaining_indices[inliers])
400
+ # o3d.io.write_point_cloud("plane.ply",current_plane_pcd)
401
+ # exit()
402
+
403
+ angle = np.arccos(np.clip(np.dot(normal, [0, 0, 1]), -1.0, 1.0)) * 180 / np.pi
404
+ if angle <= max_angle_degree:
405
+ inliers_global = remaining_indices[inliers]
406
+
407
+ target = np.array([0, 0, 1])
408
+ axis = np.cross(normal, target)
409
+ axis_norm = np.linalg.norm(axis)
410
+
411
+ if axis_norm < 1e-6:
412
+ rotation_matrix = np.eye(3)
413
+ else:
414
+ axis /= axis_norm
415
+ rot_angle = np.arccos(np.clip(np.dot(normal, target), -1.0, 1.0))
416
+ rotation = R.from_rotvec(axis * rot_angle)
417
+ rotation_matrix = rotation.as_matrix()
418
+
419
+ rotated_points = points @ rotation_matrix.T
420
+ ground_points_z = rotated_points[inliers_global, 2]
421
+ offset = np.mean(ground_points_z)
422
+ rotated_points[:, 2] -= offset
423
+
424
+ aligned_pcd = o3d.geometry.PointCloud()
425
+ aligned_pcd.points = o3d.utility.Vector3dVector(rotated_points)
426
+ if pcd.has_colors():
427
+ aligned_pcd.colors = pcd.colors
428
+ if pcd.has_normals():
429
+ rotated_normals = np.asarray(pcd.normals) @ rotation_matrix.T
430
+ aligned_pcd.normals = o3d.utility.Vector3dVector(rotated_normals)
431
+
432
+ return aligned_pcd, inliers_global, rotation_matrix, offset
433
+
434
+ else:
435
+ rejected_indices = remaining_indices[inliers]
436
+ remaining_indices = np.setdiff1d(remaining_indices, rejected_indices)
437
+
438
+ raise ValueError("Failed to find a valid ground plane within max trials.")
439
+
440
+ def rotx(x, theta=90):
441
+ """
442
+ Rotate x by theta degrees around the x-axis
443
+ """
444
+ theta = np.deg2rad(theta)
445
+ rot_matrix = np.array(
446
+ [
447
+ [1, 0, 0, 0],
448
+ [0, np.cos(theta), -np.sin(theta), 0],
449
+ [0, np.sin(theta), np.cos(theta), 0],
450
+ [0, 0, 0, 1],
451
+ ]
452
+ )
453
+ return rot_matrix@ x
454
+
455
+
456
+ def Coord2zup(points, extrinsics, normals = None):
457
+ """
458
+ Convert the dust3r coordinate system to the z-up coordinate system
459
+ """
460
+ points = np.concatenate([points, np.ones([points.shape[0], 1])], axis=1).T
461
+ points = rotx(points, -90)[:3].T
462
+ if normals is not None:
463
+ normals = np.concatenate([normals, np.ones([normals.shape[0], 1])], axis=1).T
464
+ normals = rotx(normals, -90)[:3].T
465
+ normals = normals / np.linalg.norm(normals, axis=1, keepdims=True)
466
+ t = np.min(points,axis=0)
467
+ points -= t
468
+ extrinsics = rotx(extrinsics, -90)
469
+ extrinsics[:, :3, 3] -= t.T
470
+ return points, extrinsics, normals
471
+ def integrate_rgbd_to_mesh(
472
+ Ks,
473
+ Ts,
474
+ im_depths,
475
+ im_colors,
476
+ voxel_size,
477
+ bbox=None,
478
+ ):
479
+ """
480
+ Integrate RGBD images into a TSDF volume and extract a mesh.
481
+
482
+ Args:
483
+ Ks: (N, 3, 3) camera intrinsics.
484
+ Ts: (N, 4, 4) camera extrinsics.
485
+ im_depths: (N, H, W) depth images, already in world scale.
486
+ im_colors: (N, H, W, 3) color images, float range in [0, 1].
487
+ voxel_size: TSDF voxel size, in meters, e.g. 3 / 512.
488
+ bbox: Open3D axis-aligned bounding box, for cropping.
489
+
490
+ Per Open3D convention, invalid depth values shall be set to 0.
491
+ """
492
+ num_images = len(Ks)
493
+ if (
494
+ len(Ts) != num_images
495
+ or len(im_depths) != num_images
496
+ or len(im_colors) != num_images
497
+ ):
498
+ raise ValueError("Ks, Ts, im_depths, im_colors must have the same length.")
499
+
500
+ # Constants.
501
+ trunc_voxel_multiplier = 8.0
502
+ sdf_trunc = trunc_voxel_multiplier * voxel_size
503
+
504
+ volume = o3d.pipelines.integration.ScalableTSDFVolume(
505
+ voxel_length=voxel_size,
506
+ sdf_trunc=sdf_trunc,
507
+ color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8,
508
+ )
509
+
510
+ for K, T, im_depth, im_color in tqdm(
511
+ zip(Ks, Ts, im_depths, im_colors),
512
+ total=len(Ks),
513
+ desc="Integrating RGBD frames",
514
+ ):
515
+ # Set invalid depth values to 0, based on bounding box.
516
+ if bbox is not None:
517
+ points = ct.project.im_depth_to_point_cloud(
518
+ im_depth=im_depth,
519
+ K=K,
520
+ T=T,
521
+ to_image=False,
522
+ ignore_invalid=False,
523
+ )
524
+ assert len(points) == im_depth.shape[0] * im_depth.shape[1]
525
+ point_indices_inside_bbox = bbox.get_point_indices_within_bounding_box(
526
+ o3d.utility.Vector3dVector(points)
527
+ )
528
+ point_indices_outside_bbox = np.setdiff1d(
529
+ np.arange(len(points)), point_indices_inside_bbox
530
+ )
531
+ im_depth.ravel()[point_indices_outside_bbox] = 0
532
+
533
+ im_color_uint8 = np.ascontiguousarray((im_color * 255).astype(np.uint8))
534
+ im_depth_uint16 = np.ascontiguousarray((im_depth * 1000).astype(np.uint16))
535
+ im_color_o3d = o3d.geometry.Image(im_color_uint8)
536
+ im_depth_o3d = o3d.geometry.Image(im_depth_uint16)
537
+ im_rgbd_o3d = o3d.geometry.RGBDImage.create_from_color_and_depth(
538
+ im_color_o3d,
539
+ im_depth_o3d,
540
+ depth_scale=1000.0,
541
+ depth_trunc=10.0,
542
+ convert_rgb_to_intensity=False,
543
+ )
544
+ o3d_intrinsic = o3d.camera.PinholeCameraIntrinsic(
545
+ width=im_depth.shape[1],
546
+ height=im_depth.shape[0],
547
+ fx=K[0, 0],
548
+ fy=K[1, 1],
549
+ cx=K[0, 2],
550
+ cy=K[1, 2],
551
+ )
552
+ o3d_extrinsic = T
553
+ volume.integrate(
554
+ im_rgbd_o3d,
555
+ o3d_intrinsic,
556
+ o3d_extrinsic,
557
+ )
558
+
559
+ mesh = volume.extract_triangle_mesh()
560
+ return mesh
561
+
562
+ def get_pca_color(feat, start = 0, brightness=1.25, center=True):
563
+ u, s, v = torch.pca_lowrank(feat, center=center, q=3*(start+1), niter=5)
564
+ projection = feat @ v
565
+ projection = projection[:, 3*start:3*(start+1)] * 0.6 + projection[:, 3*start:3*(start+1)] * 0.4
566
+ min_val = projection.min(dim=-2, keepdim=True)[0]
567
+ max_val = projection.max(dim=-2, keepdim=True)[0]
568
+ div = torch.clamp(max_val - min_val, min=1e-6)
569
+ color = (projection - min_val) / div * brightness
570
+ color = color.clamp(0.0, 1.0)
571
+ return color
572
+
573
+ def Concerto_process(target_dir, original_points, original_colors, original_normals, slider_value, bright_value, model_type):
574
+ gc.collect()
575
+ torch.cuda.empty_cache()
576
+ target_dir_pcds = os.path.join(target_dir, "pcds")
577
+
578
+ point = {"coord": original_points, "color": original_colors, "normal":original_normals}
579
+ original_coord = point["coord"].copy()
580
+ original_color = point["color"].copy()
581
+ point = transform(point)
582
+
583
+ with torch.inference_mode():
584
+ for key in point.keys():
585
+ if isinstance(point[key], torch.Tensor) and device=="cuda":
586
+ point[key] = point[key].cuda(non_blocking=True)
587
+ # model forward:
588
+ concerto_start_time = time.time()
589
+ if model_type =="Concerto":
590
+ point = concerto_model(point)
591
+ elif model_type == "Sonata":
592
+ point = sonata_model(point)
593
+ concerto_end_time = time.time()
594
+ # upcast point feature
595
+ # Point is a structure that contains all the information produced during the forward pass
596
+ for _ in range(2):
597
+ assert "pooling_parent" in point.keys()
598
+ assert "pooling_inverse" in point.keys()
599
+ parent = point.pop("pooling_parent")
600
+ inverse = point.pop("pooling_inverse")
601
+ parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
602
+ point = parent
603
+ while "pooling_parent" in point.keys():
604
+ assert "pooling_inverse" in point.keys()
605
+ parent = point.pop("pooling_parent")
606
+ inverse = point.pop("pooling_inverse")
607
+ parent.feat = point.feat[inverse]
608
+ point = parent
609
+
610
+ # here point is down-sampled by GridSampling in default transform pipeline
611
+ # feature of point cloud in original scale can be acquired by:
612
+ _ = point.feat[point.inverse]
613
+
614
+ # PCA
615
+ point_feat = point.feat.cpu().detach().numpy()
616
+ np.save(os.path.join(target_dir_pcds,"feat.npy"),point_feat)
617
+ pca_start_time = time.time()
618
+ pca_color = get_pca_color(point.feat,start = slider_value, brightness=bright_value, center=True)
619
+ pca_end_time = time.time()
620
+
621
+ # inverse back to original scale before grid sampling
622
+ # point.inverse is acquired from the GridSampling transform
623
+ point_inverse = point.inverse.cpu().detach().numpy()
624
+ np.save(os.path.join(target_dir_pcds,"inverse.npy"),point_inverse)
625
+ original_pca_color = pca_color[point.inverse]
626
+ points = original_coord
627
+ colors = original_pca_color.cpu().detach().numpy()
628
+
629
+ end_time = time.time()
630
+ return points, colors, concerto_end_time - concerto_start_time, pca_end_time - pca_start_time
631
+
632
+ def gradio_demo(target_dir,pca_slider,bright_slider, model_type, if_color=True, if_normal=True):
633
+ target_dir_pcds = os.path.join(target_dir, "pcds")
634
+ if not os.path.isfile(os.path.join(target_dir_pcds,"points.npy")):
635
+ return None, "No point cloud available. Please upload data first."
636
+ original_points = np.load(os.path.join(target_dir_pcds,"points.npy"))
637
+ if if_color:
638
+ original_colors = np.load(os.path.join(target_dir_pcds,"colors.npy"))
639
+ else:
640
+ original_colors = np.zeros_like(original_points)
641
+ if if_normal:
642
+ original_normals = np.load(os.path.join(target_dir_pcds,"normals.npy"))
643
+ else:
644
+ original_normals = np.zeros_like(original_points)
645
+ processed_temp = (os.path.join(target_dir_pcds,"processed.glb"))
646
+ processed_points, processed_colors, concerto_time, pca_time = Concerto_process(target_dir,original_points, original_colors,original_normals, pca_slider, bright_slider, model_type)
647
+ feat_3d = trimesh.Scene()
648
+ feat_data = trimesh.PointCloud(vertices=processed_points, colors=processed_colors, vertex_normals=original_normals)
649
+ feat_3d.add_geometry(feat_data)
650
+ feat_3d.export(processed_temp)
651
+
652
+ return processed_temp, f"Feature visualization process finished with {concerto_time:.3f} seconds using Concerto inference and {pca_time:.3f} seconds using PCA. Updating visualization."
653
+
654
+ def concerto_slider_update(target_dir,pca_slider,bright_slider,is_example,log_output):
655
+ if is_example == "True":
656
+ return None, log_output
657
+ else:
658
+ target_dir_pcds = os.path.join(target_dir, "pcds")
659
+ if os.path.isfile(os.path.join(target_dir_pcds,"feat.npy")):
660
+ feat = np.load(os.path.join(target_dir_pcds,"feat.npy"))
661
+ inverse = np.load(os.path.join(target_dir_pcds,"inverse.npy"))
662
+ feat = torch.tensor(feat, device = device)
663
+ inverse = torch.tensor(inverse, device = device)
664
+ pca_start_time = time.time()
665
+ pca_colors = get_pca_color(feat,start = pca_slider, brightness=bright_slider, center=True)
666
+ processed_colors = pca_colors[inverse].cpu().detach().numpy()
667
+ pca_end_time = time.time()
668
+ pca_time = pca_end_time - pca_start_time
669
+ processed_points = np.load(os.path.join(target_dir_pcds,"points.npy"))
670
+ processed_normals = np.load(os.path.join(target_dir_pcds,"normals.npy"))
671
+ processed_temp = (os.path.join(target_dir_pcds,"processed.glb"))
672
+ feat_3d = trimesh.Scene()
673
+ feat_data = trimesh.PointCloud(vertices=processed_points, colors=processed_colors, vertex_normals=processed_normals)
674
+ feat_3d.add_geometry(feat_data)
675
+ feat_3d.export(processed_temp)
676
+ log_output = f"Feature visualization process finished with{pca_time:.3f} seconds using PCA. Updating visualization."
677
+ else:
678
+ processed_temp = None
679
+ log_output = "No representations saved, please click PCA generate first."
680
+ # processed_temp, log_output = gradio_demo(target_dir,pca_slider,bright_slider)
681
+ return processed_temp, log_output
682
+
683
+ # set random seed
684
+ # (the random seed affects the PCA colors; changing it requires manually re-adjusting the k-means step)
685
+ # (the PCA results presented in the paper were produced with a different CUDA and PyTorch environment)
686
+ concerto.utils.set_seed(53124)
687
+ # Load model
688
+ if device == 'cuda' and flash_attn is not None:
689
+ print("Loading model with Flash Attention on GPU.")
690
+ concerto_model = concerto.load("concerto_large", repo_id="Pointcept/Concerto").to(device)
691
+ sonata_model = concerto.model.load("sonata", repo_id="facebook/sonata").to(device)
692
+ else:
693
+ print("Loading model on CPU or without Flash Attention.")
694
+ custom_config = dict(
695
+ # enc_patch_size=[1024 for _ in range(5)], # reduce patch size if necessary
696
+ enable_flash=False,
697
+ )
698
+ concerto_model = concerto.load(
699
+ "concerto_large", repo_id="Pointcept/Concerto", custom_config=custom_config
700
+ ).to(device)
701
+ sonata_model = concerto.load("sonata", repo_id="facebook/sonata", custom_config=custom_config).to(device)
702
+
703
+ transform = concerto.transform.default()
704
+
705
+ VGGT_model = VGGT().to(device)
706
+ _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
707
+ VGGT_model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
708
+ # VGGT_model.load_state_dict(torch.load("vggt/ckpt/model.pt",weights_only=True))
709
+
710
+ examples_video = [
711
+ # ["example/video/conference_room.mp4", 0.0, 1, "Depthmap Branch", "Yes",0,1.2, "True"],
712
+ # ["example/video/office.mp4", 0.0, 1, "Pointmap Branch", "Yes",2,1.1, "True"],
713
+ ["example/video/re10k_1.mp4", 10.0, 1, "Depthmap Branch", "No",2,1.2, "True"],
714
+ ["example/video/re10k_2.mp4", 30.0, 1, "Depthmap Branch", "Yes",1,1.2, "True"],
715
+ ["example/video/re10k_3.mp4", 10.0, 1, "Depthmap Branch", "Yes",1,1.2, "True"],
716
+ ["example/video/re10k_4.mp4", 10.0, 1, "Depthmap Branch", "Yes",1,1., "True"],
717
+ ]
718
+
719
+ examples_pcd = [
720
+ ["example/pcd/scannet_0024.png","example/pcd/scannet_0024.ply",2,1.2, "True"],
721
+ ["example/pcd/scannet_0603.png","example/pcd/scannet_0603.ply",0,1.2, "True"],
722
+ # ["example/pcd/hm3d_00012_kDgLKdMd5X8_2.png","example/pcd/hm3d_00012_kDgLKdMd5X8_2.ply",0,1.2, "True"],
723
+ ["example/pcd/hm3d_00113_3goH1WRaCYC.png","example/pcd/hm3d_00113_3goH1WRaCYC.ply",0,1.2, "True"],
724
+ ["example/pcd/s3dis_Area2_auditorium1.png","example/pcd/s3dis_Area2_auditorium1.ply",0,1.2, "True"],
725
+ # ["example/pcd/s3dis_Area4_lobby1.png","example/pcd/s3dis_Area4_lobby1.ply",1,1., "True"],
726
+ ]
727
+
728
+ # ["example/pcd/scannetpp_2a1b555966.png","example/pcd/scannetpp_2a1b555966.ply",1,1.1, "True"],
729
+ # ["example/pcd/hm3d_00012_kDgLKdMd5X8_1.png","example/pcd/hm3d_00012_kDgLKdMd5X8_1.ply",0,1.0, "True"],
730
+ # ["example/pcd/s3dis_Area2_conferenceRoom1.png","example/pcd/s3dis_Area2_conferenceRoom1.ply",0,1.2, "True"],
731
+ # ["example/pcd/s3dis_Area4_hallway3.png","example/pcd/s3dis_Area4_hallway3.ply",0,1.2, "True"],
732
+
733
+ with gr.Blocks(
734
+ css="""
735
+ .custom-log * {
736
+ font-style: italic;
737
+ font-size: 22px !important;
738
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
739
+ -webkit-background-clip: text;
740
+ background-clip: text;
741
+ font-weight: bold !important;
742
+ color: transparent !important;
743
+ text-align: center !important;
744
+ width: 800px;
745
+ height: 100px;
746
+ }
747
+
748
+ .example-log * {
749
+ font-style: italic;
750
+ font-size: 16px !important;
751
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
752
+ -webkit-background-clip: text;
753
+ background-clip: text;
754
+ color: transparent !important;
755
+ }
756
+
757
+ .common-markdown * {
758
+ font-size: 22px !important;
759
+ -webkit-background-clip: text;
760
+ background-clip: text;
761
+ font-weight: bold !important;
762
+ color: #0ea5e9 !important;
763
+ text-align: center !important;
764
+ }
765
+
766
+ #big-box {
767
+ border: 3px solid #00bcd4;
768
+ padding: 20px;
769
+ background-color: transparent;
770
+ border-radius: 15px;
771
+ }
772
+
773
+ #my_radio .wrap {
774
+ display: flex;
775
+ flex-wrap: nowrap;
776
+ justify-content: center;
777
+ align-items: center;
778
+ }
779
+
780
+ #my_radio .wrap label {
781
+ display: flex;
782
+ width: 50%;
783
+ justify-content: center;
784
+ align-items: center;
785
+ margin: 0;
786
+ padding: 10px 0;
787
+ box-sizing: border-box;
788
+ }
789
+ """,
790
+ ) as demo:
791
+ gr.HTML(
792
+ """
793
+ <h1>Concerto: Joint 2D-3D Self-Supervised Learning for Emergent Spatial Representations</h1>
794
+ <div style="font-size: 16px; line-height: 1.5;">
795
+ <ol>
796
+ <details style="display:inline;">
797
+ <summary style="display:inline;"><h3>Getting Started:(<strong>Click to expand</strong>)</h3></summary>
798
+ <li><strong>Before You Start:</strong> The model is deployed on CPU, so inference is slow.</li>
799
+ <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Point Cloud" blocks on the left to provide your input. If you upload a video, it will be automatically split into individual frames with the specified frame gap by VGGT.</li>
800
+ <li>
801
+ <strong>[Optional] Adjust Video-Lifted Point Cloud:</strong>
802
+ Before reconstructing the video, you can fine-tune the VGGT lifting process using the options below
803
+ <details style="display:inline;">
804
+ <summary style="display:inline;">(<strong>Click to expand</strong>)</summary>
805
+ <ul>
806
+ <li><em>Frame Gap / N Sec:</em> Adjust the frame interval.</li>
807
+ <li><em>Confidence Threshold:</em> Adjust the point filtering based on confidence levels.</li>
808
+ <li><em>Select Prediction Mode:</em> Choose between "Depthmap Branch" and "Pointmap Branch."</li>
809
+ <li><em>TSDF Integration (Depthmap Branch Mode):</em> Enable TSDF integration to reduce noise in the point cloud when using the "Depthmap Branch" mode. This refinement step takes a long time.</li>
810
+ </ul>
811
+ </details>
812
+ </li>
813
+ <li><strong>PCA Generation:</strong> After reconstruction, click the "PCA Generate" button to start the representation extraction and PCA process.</li>
814
+ <li><strong>Clear:</strong> Click the "Clear" button to reset all content in the blocks.</li>
815
+ <li><strong>Point Cloud Preview:</strong> Your uploaded video or point cloud will be displayed in this block.</li>
816
+ <li><strong>PCA Result:</strong> The PCA point cloud will appear here. You can rotate, drag, and zoom to explore the model, and download the GLB file.</li>
817
+ <li>
818
+ <strong>[Optional] Adjust the Point Cloud Input</strong> (pre-release feature of our next work): use the "Input with Point Cloud Color" and "Input with Point Cloud Normal" checkboxes.
819
+ </li>
820
+ <li>
821
+ <strong>[Optional] Adjust PCA Visualization:</strong>
822
+ Fine-tune the PCA visualization using the options below
823
+ <details style="display:inline;">
824
+ <summary style="display:inline;">(<strong>Click to expand</strong>)</summary>
825
+ <ul>
826
+ <li><em>Model Type:</em> Choose between the Concerto and Sonata models.</li>
827
+ <li><em>PCA Start Dimension:</em> PCA reduces high-dimensional representations to 3D color vectors. Adjust the start dimension to choose which principal components are visualized; increasing it shifts the visualization to later components, which helps when the leading components show little diversity.</li>
828
+ <li><em>PCA Brightness:</em> Adjust the brightness of the PCA visualization results.</li>
829
+ <li><em>Notice:</em> As a linear dimensionality-reduction method, PCA has limitations; the visualization may not fully reflect the quality of the representations.</li>
830
+ </ul>
831
+ </details>
832
+ </li>
833
+ </details>
834
+ </ol>
835
+ </div>
836
+
837
+ """
838
+ )
839
+ _ = gr.Textbox(label="_", visible=False, value="False")
840
+ is_example = gr.Textbox(label="is_example", visible=False, value="False")
841
+ target_dir = gr.Textbox(label="Target Dir", visible=False, value="None")
842
+ preview_imgs = gr.Image(type="filepath",label="Preview Imgs", visible=False, value="None")
843
+ with gr.Row():
844
+ with gr.Column(scale=1,elem_id="big-box"):
845
+ input_file = gr.File(label="Upload Point Cloud", file_types=[".ply"])
846
+ input_video = gr.Video(label="Upload Video", interactive=True)
847
+ image_gallery = gr.Gallery(
848
+ label="Preview",
849
+ columns=4,
850
+ height="300px",
851
+ show_download_button=True,
852
+ object_fit="contain",
853
+ preview=True,
854
+ )
855
+
856
+ frame_slider = gr.Slider(minimum=0.1, maximum=10, value=1, step=0.1,
857
+ label="1 Frame/ N Sec", interactive=True)
858
+ conf_thres = gr.Slider(minimum=0, maximum=100, value=10, step=0.1,
859
+ label="Confidence", interactive=True)
860
+ prediction_mode = gr.Radio(
861
+ ["Depthmap Branch", "Pointmap Branch"],
862
+ label="Select a Prediction Mode",
863
+ value="Depthmap Branch",
864
+ scale=1,
865
+ elem_id="my_radio",
866
+ )
867
+ TSDF_mode = gr.Radio(
868
+ ["Yes", "No"],
869
+ label="TSDF integration under Depthmap Branch mode",
870
+ value="Yes",
871
+ scale=1,
872
+ elem_id="my_radio",
873
+ )
874
+ reconstruction_btn = gr.Button("Video Reconstruct")
875
+ with gr.Column(scale=2):
876
+ log_output = gr.Markdown(
877
+ "Please upload a video or point cloud ply file, then click \"PCA Generate\".", elem_classes=["custom-log"]
878
+ )
879
+ original_view = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5, label="Point Cloud Preview", camera_position = (90,None,None))
880
+ processed_view = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5, label="PCA Result", camera_position = (90,None,None))
881
+ with gr.Row():
882
+ if_color = gr.Checkbox(label="Input with Point Cloud Color", value=True)
883
+ if_normal = gr.Checkbox(label="Input with Point Cloud Normal", value=True)
884
+ model_type = gr.Radio(
885
+ ["Concerto", "Sonata"],
886
+ label="Select a Model Type",
887
+ value="Concerto",
888
+ scale=1,
889
+ elem_id="my_radio",
890
+ )
891
+ pca_slider = gr.Slider(minimum=0, maximum=5, value=0, step=1,
892
+ label="PCA Start Dimension", interactive=True)
893
+ bright_slider = gr.Slider(minimum=0.5, maximum=1.5, value=1.2, step=0.05,
894
+ label="PCA Brightness", interactive=True)
895
+ with gr.Row():
896
+ submit_btn = gr.Button("PCA Generate")
897
+ clear_btn = gr.ClearButton(
898
+ [input_video, input_file, original_view, processed_view, log_output, target_dir, image_gallery],
899
+ scale=1,
900
+ elem_id="my_clear",
901
+ )
902
+
903
+ gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
904
+ with gr.Row():
905
+ def example_video_updated(
906
+ inputs,
907
+ conf_thres,
908
+ frame_slider,
909
+ prediction_mode,
910
+ TSDF_mode,
911
+ pca_slider,
912
+ bright_slider,
913
+ is_example,
914
+ ):
915
+ return inputs,conf_thres,frame_slider,prediction_mode,TSDF_mode,pca_slider,bright_slider,is_example
916
+ gr.Examples(
917
+ examples=examples_video,
918
+ inputs=[
919
+ input_video,
920
+ conf_thres,
921
+ frame_slider,
922
+ prediction_mode,
923
+ TSDF_mode,
924
+ pca_slider,
925
+ bright_slider,
926
+ is_example,
927
+ ],
928
+ outputs=[
929
+ input_video,
930
+ conf_thres,
931
+ frame_slider,
932
+ prediction_mode,
933
+ TSDF_mode,
934
+ pca_slider,
935
+ bright_slider,
936
+ is_example,
937
+ ],
938
+ label = "Video Examples",
939
+ fn=example_video_updated,
940
+ cache_examples=False,
941
+ examples_per_page=50,
942
+ # examples_per_page=2
943
+ )
944
+ with gr.Row():
945
+ def example_file_updated(
946
+ preview_imgs,
947
+ inputs,
948
+ pca_slider,
949
+ bright_slider,
950
+ is_example,
951
+ ):
952
+ return inputs,pca_slider,bright_slider,is_example
953
+ gr.Examples(
954
+ examples=examples_pcd,
955
+ inputs=[
956
+ preview_imgs,
957
+ input_file,
958
+ pca_slider,
959
+ bright_slider,
960
+ is_example,
961
+ ],
962
+ outputs=[
963
+ input_file,
964
+ pca_slider,
965
+ bright_slider,
966
+ is_example,
967
+ ],
968
+ label = "Point Cloud Examples",
969
+ fn=example_file_updated,
970
+ cache_examples=False,
971
+ examples_per_page=50,
972
+ # examples_per_page=2
973
+ )
974
+
975
+ reconstruction_btn.click(
976
+ fn = update_gallery_on_upload,
977
+ inputs = [input_file,input_video,conf_thres,frame_slider,prediction_mode,TSDF_mode],
978
+ outputs = [original_view, target_dir, image_gallery, log_output]
979
+ )
980
+ submit_btn.click(fn=clear_fields, inputs=[], outputs=[processed_view]).then(
981
+ fn=PCAing_log, inputs=[is_example, log_output], outputs=[log_output]
982
+ ).then(
983
+ fn=gradio_demo,
984
+ inputs=[target_dir,pca_slider,bright_slider, model_type, if_color, if_normal],
985
+ outputs=[processed_view,log_output],
986
+ ).then(
987
+ fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
988
+ )
989
+
990
+ pca_slider.change(fn=clear_fields, inputs=[], outputs=[processed_view]).then(
991
+ fn=PCAing_log, inputs=[is_example, log_output], outputs=[log_output]
992
+ ).then(
993
+ fn=concerto_slider_update,
994
+ inputs=[target_dir,pca_slider,bright_slider,is_example,log_output],
995
+ outputs=[processed_view, log_output],
996
+ ).then(
997
+ fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
998
+ )
999
+ bright_slider.change(fn=clear_fields, inputs=[], outputs=[processed_view]).then(
1000
+ fn=PCAing_log, inputs=[is_example, log_output], outputs=[log_output]
1001
+ ).then(
1002
+ fn=concerto_slider_update,
1003
+ inputs=[target_dir,pca_slider,bright_slider,is_example,log_output],
1004
+ outputs=[processed_view, log_output],
1005
+ ).then(
1006
+ fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
1007
+ )
1008
+ model_type.change(fn=clear_fields, inputs=[], outputs=[processed_view]).then(
1009
+ fn=PCAing_log, inputs=[is_example, log_output], outputs=[log_output]
1010
+ ).then(
1011
+ fn=gradio_demo,
1012
+ inputs=[target_dir,pca_slider,bright_slider, model_type, if_color, if_normal],
1013
+ outputs=[processed_view,log_output],
1014
+ ).then(
1015
+ fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
1016
+ )
1017
+
1018
+ input_file.change(fn=reset_log, inputs=[], outputs=[log_output]).then(
1019
+ fn=update_gallery_on_upload,
1020
+ inputs=[input_file,input_video, conf_thres,frame_slider,prediction_mode,TSDF_mode],
1021
+ outputs=[original_view, target_dir, _, log_output],
1022
+ )
1023
+
1024
+ input_video.change(fn=reset_log, inputs=[], outputs=[log_output]).then(
1025
+ fn=update_gallery_on_upload,
1026
+ inputs=[input_file,input_video, conf_thres,frame_slider,prediction_mode,TSDF_mode],
1027
+ outputs=[original_view, target_dir, image_gallery, log_output],
1028
+ )
1029
+
1030
+ if __name__ == "__main__":
1031
+ demo.queue(max_size=20).launch(show_error=True, share=True)
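The "PCA Start Dimension" and "PCA Brightness" controls described in the Getting Started notes above map per-point features to RGB colors. A minimal sketch of that mapping, assuming a NumPy feature matrix and scikit-learn; the function name and defaults are illustrative, not the demo's actual implementation:

```python
# Minimal sketch of PCA-based feature coloring (illustrative, not app.py's code).
import numpy as np
from sklearn.decomposition import PCA

def pca_colors(feat: np.ndarray, start_dim: int = 0, brightness: float = 1.2) -> np.ndarray:
    # Project (N, C) features onto the first `start_dim + 3` principal components.
    comp = PCA(n_components=start_dim + 3).fit_transform(feat)
    rgb = comp[:, start_dim:start_dim + 3]                       # 3 components from `start_dim`
    rgb = (rgb - rgb.min(0)) / (rgb.max(0) - rgb.min(0) + 1e-8)  # normalize to [0, 1]
    return np.clip(rgb * brightness, 0.0, 1.0)                   # brighten and clamp

colors = pca_colors(np.random.rand(1000, 512), start_dim=0, brightness=1.2)
print(colors.shape)  # (1000, 3)
```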
concerto/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from .model import load
17
+
18
+ from . import model
19
+ from . import module
20
+ from . import structure
21
+ from . import data
22
+ from . import transform
23
+ from . import utils
24
+ from . import registry
25
+
26
+ __all__ = ["load", "model", "module", "structure", "transform", "registry", "utils"]
concerto/data.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import numpy as np
18
+ import torch
19
+ from collections.abc import Mapping, Sequence
20
+ from huggingface_hub import hf_hub_download
21
+
22
+
23
+ DATAS = ["sample1", "sample1_high_res", "sample1_dino"]
24
+
25
+
26
+ def load(
27
+ name: str = "sonata",
28
+ download_root: str = None,
29
+ ):
30
+ if name in DATAS:
31
+ print(f"Loading data from HuggingFace: {name} ...")
32
+ data_path = hf_hub_download(
33
+ repo_id="pointcept/demo",
34
+ filename=f"{name}.npz",
35
+ repo_type="dataset",
36
+ revision="main",
37
+ local_dir=download_root or os.path.expanduser("~/.cache/sonata/data"),
38
+ )
39
+ elif os.path.isfile(name):
40
+ print(f"Loading data in local path: {name} ...")
41
+ data_path = name
42
+ else:
43
+ raise RuntimeError(f"Data {name} not found; available models = {DATAS}")
44
+ return dict(np.load(data_path))
45
+
46
+
47
+ from torch.utils.data.dataloader import default_collate
48
+
49
+
50
+ def collate_fn(batch):
51
+ """
52
+ collate function for point clouds which supports dict and list;
53
+ 'coord' is necessary to determine 'offset'
54
+ """
55
+ if not isinstance(batch, Sequence):
56
+ raise TypeError(f"{type(batch)} is not supported.")
57
+
58
+ if isinstance(batch[0], torch.Tensor):
59
+ return torch.cat(list(batch))
60
+ elif isinstance(batch[0], str):
61
+ # str is also a kind of Sequence, judgement should before Sequence
62
+ return list(batch)
63
+ elif isinstance(batch[0], Sequence):
64
+ for data in batch:
65
+ data.append(torch.tensor([data[0].shape[0]]))
66
+ batch = [collate_fn(samples) for samples in zip(*batch)]
67
+ batch[-1] = torch.cumsum(batch[-1], dim=0).int()
68
+ return batch
69
+ elif isinstance(batch[0], Mapping):
70
+ batch = {
71
+ key: (
72
+ collate_fn([d[key] for d in batch])
73
+ if "offset" not in key
74
+ # offset -> bincount -> concat bincount-> concat offset
75
+ else torch.cumsum(
76
+ collate_fn([d[key].diff(prepend=torch.tensor([0])) for d in batch]),
77
+ dim=0,
78
+ )
79
+ )
80
+ for key in batch[0]
81
+ }
82
+ return batch
83
+ else:
84
+ return default_collate(batch)
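A usage sketch of `collate_fn` with dict samples; the keys and shapes are illustrative, but the "offset" handling follows the code above (per-sample counts are re-accumulated into cumulative boundaries):

```python
# Illustrative batching of two point-cloud samples with concerto.data.collate_fn.
import torch
from concerto.data import collate_fn

sample_a = {"coord": torch.rand(100, 3), "feat": torch.rand(100, 6), "offset": torch.tensor([100])}
sample_b = {"coord": torch.rand(60, 3), "feat": torch.rand(60, 6), "offset": torch.tensor([60])}

batch = collate_fn([sample_a, sample_b])
print(batch["coord"].shape)  # torch.Size([160, 3]) -- points are concatenated
print(batch["offset"])       # tensor([100, 160])   -- cumulative sample boundaries
```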
concerto/model.py ADDED
@@ -0,0 +1,798 @@
1
+ """
2
+ Point Transformer - V3 Mode2 - Sonata & Concerto
3
+ Pointcept detached version
4
+
5
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
6
+ Please cite our work if the code is helpful to you.
7
+ """
8
+
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ import os
25
+ from packaging import version
26
+ from huggingface_hub import hf_hub_download, PyTorchModelHubMixin
27
+ from addict import Dict
28
+ import torch
29
+ import torch.nn as nn
30
+ from torch.nn.init import trunc_normal_
31
+ import spconv.pytorch as spconv
32
+ import torch_scatter
33
+ from timm.layers import DropPath
34
+
35
+
36
+ try:
37
+ import flash_attn
38
+ except ImportError:
39
+ flash_attn = None
40
+
41
+ from .structure import Point
42
+ from .module import PointSequential, PointModule
43
+ from .utils import offset2bincount
44
+
45
+ MODELS = [
46
+ "sonata",
47
+ "concerto_large",
48
+ "concerto_base",
49
+ "concerto_small",
50
+ "concerto_large_linear_prob_head_sc",
51
+ ]
52
+
53
+
54
+ class LayerScale(nn.Module):
55
+ def __init__(
56
+ self,
57
+ dim: int,
58
+ init_values: float = 1e-5,
59
+ inplace: bool = False,
60
+ ) -> None:
61
+ super().__init__()
62
+ self.inplace = inplace
63
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
64
+
65
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
66
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
67
+
68
+
69
+ class RPE(torch.nn.Module):
70
+ def __init__(self, patch_size, num_heads):
71
+ super().__init__()
72
+ self.patch_size = patch_size
73
+ self.num_heads = num_heads
74
+ self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
75
+ self.rpe_num = 2 * self.pos_bnd + 1
76
+ self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
77
+ torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)
78
+
79
+ def forward(self, coord):
80
+ idx = (
81
+ coord.clamp(-self.pos_bnd, self.pos_bnd) # clamp into bnd
82
+ + self.pos_bnd # relative position to positive index
83
+ + torch.arange(3, device=coord.device) * self.rpe_num # x, y, z stride
84
+ )
85
+ out = self.rpe_table.index_select(0, idx.reshape(-1))
86
+ out = out.view(idx.shape + (-1,)).sum(3)
87
+ out = out.permute(0, 3, 1, 2) # (N, K, K, H) -> (N, H, K, K)
88
+ return out
89
+
90
+
91
+ class SerializedAttention(PointModule):
92
+ def __init__(
93
+ self,
94
+ channels,
95
+ num_heads,
96
+ patch_size,
97
+ qkv_bias=True,
98
+ qk_scale=None,
99
+ attn_drop=0.0,
100
+ proj_drop=0.0,
101
+ order_index=0,
102
+ enable_rpe=False,
103
+ enable_flash=True,
104
+ upcast_attention=True,
105
+ upcast_softmax=True,
106
+ ):
107
+ super().__init__()
108
+ assert channels % num_heads == 0
109
+ self.channels = channels
110
+ self.num_heads = num_heads
111
+ self.scale = qk_scale or (channels // num_heads) ** -0.5
112
+ self.order_index = order_index
113
+ self.upcast_attention = upcast_attention
114
+ self.upcast_softmax = upcast_softmax
115
+ self.enable_rpe = enable_rpe
116
+ self.enable_flash = enable_flash
117
+ if enable_flash:
118
+ assert (
119
+ enable_rpe is False
120
+ ), "Set enable_rpe to False when enable Flash Attention"
121
+ assert (
122
+ upcast_attention is False
123
+ ), "Set upcast_attention to False when enable Flash Attention"
124
+ assert (
125
+ upcast_softmax is False
126
+ ), "Set upcast_softmax to False when enable Flash Attention"
127
+ assert flash_attn is not None, "Make sure flash_attn is installed."
128
+ self.patch_size = patch_size
129
+ self.attn_drop = attn_drop
130
+ else:
131
+ # when disable flash attention, we still don't want to use mask
132
+ # consequently, patch size will auto set to the
133
+ # min number of patch_size_max and number of points
134
+ self.patch_size_max = patch_size
135
+ self.patch_size = 0
136
+ self.attn_drop = torch.nn.Dropout(attn_drop)
137
+
138
+ self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
139
+ self.proj = torch.nn.Linear(channels, channels)
140
+ self.proj_drop = torch.nn.Dropout(proj_drop)
141
+ self.softmax = torch.nn.Softmax(dim=-1)
142
+ self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None
143
+
144
+ @torch.no_grad()
145
+ def get_rel_pos(self, point, order):
146
+ K = self.patch_size
147
+ rel_pos_key = f"rel_pos_{self.order_index}"
148
+ if rel_pos_key not in point.keys():
149
+ grid_coord = point.grid_coord[order]
150
+ grid_coord = grid_coord.reshape(-1, K, 3)
151
+ point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
152
+ return point[rel_pos_key]
153
+
154
+ @torch.no_grad()
155
+ def get_padding_and_inverse(self, point):
156
+ pad_key = "pad"
157
+ unpad_key = "unpad"
158
+ cu_seqlens_key = "cu_seqlens_key"
159
+ if (
160
+ pad_key not in point.keys()
161
+ or unpad_key not in point.keys()
162
+ or cu_seqlens_key not in point.keys()
163
+ ):
164
+ offset = point.offset
165
+ bincount = offset2bincount(offset)
166
+ bincount_pad = (
167
+ torch.div(
168
+ bincount + self.patch_size - 1,
169
+ self.patch_size,
170
+ rounding_mode="trunc",
171
+ )
172
+ * self.patch_size
173
+ )
174
+ # only pad point when num of points larger than patch_size
175
+ mask_pad = bincount > self.patch_size
176
+ bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
177
+ _offset = nn.functional.pad(offset, (1, 0))
178
+ _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))
179
+ pad = torch.arange(_offset_pad[-1], device=offset.device)
180
+ unpad = torch.arange(_offset[-1], device=offset.device)
181
+ cu_seqlens = []
182
+ for i in range(len(offset)):
183
+ unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
184
+ if bincount[i] != bincount_pad[i]:
185
+ pad[
186
+ _offset_pad[i + 1]
187
+ - self.patch_size
188
+ + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
189
+ ] = pad[
190
+ _offset_pad[i + 1]
191
+ - 2 * self.patch_size
192
+ + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
193
+ - self.patch_size
194
+ ]
195
+ pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
196
+ cu_seqlens.append(
197
+ torch.arange(
198
+ _offset_pad[i],
199
+ _offset_pad[i + 1],
200
+ step=self.patch_size,
201
+ dtype=torch.int32,
202
+ device=offset.device,
203
+ )
204
+ )
205
+ point[pad_key] = pad
206
+ point[unpad_key] = unpad
207
+ point[cu_seqlens_key] = nn.functional.pad(
208
+ torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
209
+ )
210
+ return point[pad_key], point[unpad_key], point[cu_seqlens_key]
211
+
212
+ def forward(self, point):
213
+ if not self.enable_flash:
214
+ self.patch_size = min(
215
+ offset2bincount(point.offset).min().tolist(), self.patch_size_max
216
+ )
217
+
218
+ H = self.num_heads
219
+ K = self.patch_size
220
+ C = self.channels
221
+
222
+ pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)
223
+
224
+ order = point.serialized_order[self.order_index][pad]
225
+ inverse = unpad[point.serialized_inverse[self.order_index]]
226
+
227
+ # padding and reshape feat and batch for serialized point patch
228
+ qkv = self.qkv(point.feat)[order]
229
+
230
+ if not self.enable_flash:
231
+ # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
232
+ q, k, v = (
233
+ qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
234
+ )
235
+ # attn
236
+ if self.upcast_attention:
237
+ q = q.float()
238
+ k = k.float()
239
+ attn = (q * self.scale) @ k.transpose(-2, -1) # (N', H, K, K)
240
+ if self.enable_rpe:
241
+ attn = attn + self.rpe(self.get_rel_pos(point, order))
242
+ if self.upcast_softmax:
243
+ attn = attn.float()
244
+ attn = self.softmax(attn)
245
+ attn = self.attn_drop(attn).to(qkv.dtype)
246
+ feat = (attn @ v).transpose(1, 2).reshape(-1, C)
247
+ else:
248
+ feat = flash_attn.flash_attn_varlen_qkvpacked_func(
249
+ qkv.half().reshape(-1, 3, H, C // H),
250
+ cu_seqlens,
251
+ max_seqlen=self.patch_size,
252
+ dropout_p=self.attn_drop if self.training else 0,
253
+ softmax_scale=self.scale,
254
+ ).reshape(-1, C)
255
+ feat = feat.to(qkv.dtype)
256
+ feat = feat[inverse]
257
+
258
+ # ffn
259
+ feat = self.proj(feat)
260
+ feat = self.proj_drop(feat)
261
+ point.feat = feat
262
+ return point
263
+
264
+
265
+ class MLP(nn.Module):
266
+ def __init__(
267
+ self,
268
+ in_channels,
269
+ hidden_channels=None,
270
+ out_channels=None,
271
+ act_layer=nn.GELU,
272
+ drop=0.0,
273
+ ):
274
+ super().__init__()
275
+ out_channels = out_channels or in_channels
276
+ hidden_channels = hidden_channels or in_channels
277
+ self.fc1 = nn.Linear(in_channels, hidden_channels)
278
+ self.act = act_layer()
279
+ self.fc2 = nn.Linear(hidden_channels, out_channels)
280
+ self.drop = nn.Dropout(drop)
281
+
282
+ def forward(self, x):
283
+ x = self.fc1(x)
284
+ x = self.act(x)
285
+ x = self.drop(x)
286
+ x = self.fc2(x)
287
+ x = self.drop(x)
288
+ return x
289
+
290
+
291
+ class Block(PointModule):
292
+ def __init__(
293
+ self,
294
+ channels,
295
+ num_heads,
296
+ patch_size=48,
297
+ mlp_ratio=4.0,
298
+ qkv_bias=True,
299
+ qk_scale=None,
300
+ attn_drop=0.0,
301
+ proj_drop=0.0,
302
+ drop_path=0.0,
303
+ layer_scale=None,
304
+ norm_layer=nn.LayerNorm,
305
+ act_layer=nn.GELU,
306
+ pre_norm=True,
307
+ order_index=0,
308
+ cpe_indice_key=None,
309
+ enable_rpe=False,
310
+ enable_flash=True,
311
+ upcast_attention=True,
312
+ upcast_softmax=True,
313
+ ):
314
+ super().__init__()
315
+ self.channels = channels
316
+ self.pre_norm = pre_norm
317
+
318
+ self.cpe = PointSequential(
319
+ spconv.SubMConv3d(
320
+ channels,
321
+ channels,
322
+ kernel_size=3,
323
+ bias=True,
324
+ indice_key=cpe_indice_key,
325
+ ),
326
+ nn.Linear(channels, channels),
327
+ norm_layer(channels),
328
+ )
329
+
330
+ self.norm1 = PointSequential(norm_layer(channels))
331
+ self.ls1 = PointSequential(
332
+ LayerScale(channels, init_values=layer_scale)
333
+ if layer_scale is not None
334
+ else nn.Identity()
335
+ )
336
+ self.attn = SerializedAttention(
337
+ channels=channels,
338
+ patch_size=patch_size,
339
+ num_heads=num_heads,
340
+ qkv_bias=qkv_bias,
341
+ qk_scale=qk_scale,
342
+ attn_drop=attn_drop,
343
+ proj_drop=proj_drop,
344
+ order_index=order_index,
345
+ enable_rpe=enable_rpe,
346
+ enable_flash=enable_flash,
347
+ upcast_attention=upcast_attention,
348
+ upcast_softmax=upcast_softmax,
349
+ )
350
+ self.norm2 = PointSequential(norm_layer(channels))
351
+ self.ls2 = PointSequential(
352
+ LayerScale(channels, init_values=layer_scale)
353
+ if layer_scale is not None
354
+ else nn.Identity()
355
+ )
356
+ self.mlp = PointSequential(
357
+ MLP(
358
+ in_channels=channels,
359
+ hidden_channels=int(channels * mlp_ratio),
360
+ out_channels=channels,
361
+ act_layer=act_layer,
362
+ drop=proj_drop,
363
+ )
364
+ )
365
+ self.drop_path = PointSequential(
366
+ DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
367
+ )
368
+
369
+ def forward(self, point: Point):
370
+ shortcut = point.feat
371
+ point = self.cpe(point)
372
+ point.feat = shortcut + point.feat
373
+ shortcut = point.feat
374
+ if self.pre_norm:
375
+ point = self.norm1(point)
376
+ point = self.drop_path(self.ls1(self.attn(point)))
377
+ point.feat = shortcut + point.feat
378
+ if not self.pre_norm:
379
+ point = self.norm1(point)
380
+
381
+ shortcut = point.feat
382
+ if self.pre_norm:
383
+ point = self.norm2(point)
384
+ point = self.drop_path(self.ls2(self.mlp(point)))
385
+ point.feat = shortcut + point.feat
386
+ if not self.pre_norm:
387
+ point = self.norm2(point)
388
+ point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
389
+ return point
390
+
391
+
392
+ class GridPooling(PointModule):
393
+ def __init__(
394
+ self,
395
+ in_channels,
396
+ out_channels,
397
+ stride=2,
398
+ norm_layer=None,
399
+ act_layer=None,
400
+ reduce="max",
401
+ shuffle_orders=True,
402
+ traceable=True, # record parent and cluster
403
+ ):
404
+ super().__init__()
405
+ self.in_channels = in_channels
406
+ self.out_channels = out_channels
407
+
408
+ self.stride = stride
409
+ assert reduce in ["sum", "mean", "min", "max"]
410
+ self.reduce = reduce
411
+ self.shuffle_orders = shuffle_orders
412
+ self.traceable = traceable
413
+
414
+ self.proj = nn.Linear(in_channels, out_channels)
415
+ if norm_layer is not None:
416
+ self.norm = PointSequential(norm_layer(out_channels))
417
+ if act_layer is not None:
418
+ self.act = PointSequential(act_layer())
419
+
420
+ def forward(self, point: Point):
421
+ if "grid_coord" in point.keys():
422
+ grid_coord = point.grid_coord
423
+ elif {"coord", "grid_size"}.issubset(point.keys()):
424
+ grid_coord = torch.div(
425
+ point.coord - point.coord.min(0)[0],
426
+ point.grid_size,
427
+ rounding_mode="trunc",
428
+ ).int()
429
+ else:
430
+ raise AssertionError(
431
+ "[gird_coord] or [coord, grid_size] should be include in the Point"
432
+ )
433
+ grid_coord = torch.div(grid_coord, self.stride, rounding_mode="trunc")
434
+ grid_coord = grid_coord | point.batch.view(-1, 1) << 48
435
+ grid_coord, cluster, counts = torch.unique(
436
+ grid_coord,
437
+ sorted=True,
438
+ return_inverse=True,
439
+ return_counts=True,
440
+ dim=0,
441
+ )
442
+ grid_coord = grid_coord & ((1 << 48) - 1)
443
+ # indices of point sorted by cluster, for torch_scatter.segment_csr
444
+ _, indices = torch.sort(cluster)
445
+ # index pointer for sorted point, for torch_scatter.segment_csr
446
+ idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
447
+ # head_indices of each cluster, for reduce attr e.g. code, batch
448
+ head_indices = indices[idx_ptr[:-1]]
449
+ point_dict = Dict(
450
+ feat=torch_scatter.segment_csr(
451
+ self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
452
+ ),
453
+ coord=torch_scatter.segment_csr(
454
+ point.coord[indices], idx_ptr, reduce="mean"
455
+ ),
456
+ grid_coord=grid_coord,
457
+ batch=point.batch[head_indices],
458
+ )
459
+ if "origin_coord" in point.keys():
460
+ point_dict["origin_coord"] = torch_scatter.segment_csr(
461
+ point.origin_coord[indices], idx_ptr, reduce="mean"
462
+ )
463
+ if "condition" in point.keys():
464
+ point_dict["condition"] = point.condition
465
+ if "context" in point.keys():
466
+ point_dict["context"] = point.context
467
+ if "name" in point.keys():
468
+ point_dict["name"] = point.name
469
+ if "split" in point.keys():
470
+ point_dict["split"] = point.split
471
+ if "color" in point.keys():
472
+ point_dict["color"] = torch_scatter.segment_csr(
473
+ point.color[indices], idx_ptr, reduce="mean"
474
+ )
475
+ if "grid_size" in point.keys():
476
+ point_dict["grid_size"] = point.grid_size * self.stride
477
+
478
+ if self.traceable:
479
+ point_dict["pooling_inverse"] = cluster
480
+ point_dict["pooling_parent"] = point
481
+ point_dict["idx_ptr"] = idx_ptr
482
+ order = point.order
483
+ point = Point(point_dict)
484
+ if self.norm is not None:
485
+ point = self.norm(point)
486
+ if self.act is not None:
487
+ point = self.act(point)
488
+ point.serialization(order=order, shuffle_orders=self.shuffle_orders)
489
+ point.sparsify()
490
+ return point
491
+
492
+
493
+ class GridUnpooling(PointModule):
494
+ def __init__(
495
+ self,
496
+ in_channels,
497
+ skip_channels,
498
+ out_channels,
499
+ norm_layer=None,
500
+ act_layer=None,
501
+ traceable=False, # record parent and cluster
502
+ ):
503
+ super().__init__()
504
+ self.proj = PointSequential(nn.Linear(in_channels, out_channels))
505
+ self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))
506
+
507
+ if norm_layer is not None:
508
+ self.proj.add(norm_layer(out_channels))
509
+ self.proj_skip.add(norm_layer(out_channels))
510
+
511
+ if act_layer is not None:
512
+ self.proj.add(act_layer())
513
+ self.proj_skip.add(act_layer())
514
+
515
+ self.traceable = traceable
516
+
517
+ def forward(self, point):
518
+ assert "pooling_parent" in point.keys()
519
+ assert "pooling_inverse" in point.keys()
520
+ parent = point.pop("pooling_parent")
521
+ inverse = point.pooling_inverse
522
+ feat = point.feat
523
+
524
+ parent = self.proj_skip(parent)
525
+ parent.feat = parent.feat + self.proj(point).feat[inverse]
526
+ parent.sparse_conv_feat = parent.sparse_conv_feat.replace_feature(parent.feat)
527
+
528
+ if self.traceable:
529
+ point.feat = feat
530
+ parent["unpooling_parent"] = point
531
+ return parent
532
+
533
+
534
+ class Embedding(PointModule):
535
+ def __init__(
536
+ self,
537
+ in_channels,
538
+ embed_channels,
539
+ norm_layer=None,
540
+ act_layer=None,
541
+ mask_token=False,
542
+ ):
543
+ super().__init__()
544
+ self.in_channels = in_channels
545
+ self.embed_channels = embed_channels
546
+
547
+ self.stem = PointSequential(linear=nn.Linear(in_channels, embed_channels))
548
+ if norm_layer is not None:
549
+ self.stem.add(norm_layer(embed_channels), name="norm")
550
+ if act_layer is not None:
551
+ self.stem.add(act_layer(), name="act")
552
+
553
+ if mask_token:
554
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_channels))
555
+ else:
556
+ self.mask_token = None
557
+
558
+ def forward(self, point: Point):
559
+ point = self.stem(point)
560
+ if "mask" in point.keys():
561
+ point.feat = torch.where(
562
+ point.mask.unsqueeze(-1),
563
+ self.mask_token.to(point.feat.dtype),
564
+ point.feat,
565
+ )
566
+ return point
567
+
568
+
569
+ class PointTransformerV3(PointModule, PyTorchModelHubMixin):
570
+ def __init__(
571
+ self,
572
+ in_channels=6,
573
+ order=("z", "z-trans"),
574
+ stride=(2, 2, 2, 2),
575
+ enc_depths=(3, 3, 3, 12, 3),
576
+ enc_channels=(48, 96, 192, 384, 512),
577
+ enc_num_head=(3, 6, 12, 24, 32),
578
+ enc_patch_size=(1024, 1024, 1024, 1024, 1024),
579
+ dec_depths=(3, 3, 3, 3),
580
+ dec_channels=(96, 96, 192, 384),
581
+ dec_num_head=(6, 6, 12, 32),
582
+ dec_patch_size=(1024, 1024, 1024, 1024),
583
+ mlp_ratio=4,
584
+ qkv_bias=True,
585
+ qk_scale=None,
586
+ attn_drop=0.0,
587
+ proj_drop=0.0,
588
+ drop_path=0.3,
589
+ layer_scale=None,
590
+ pre_norm=True,
591
+ shuffle_orders=True,
592
+ enable_rpe=False,
593
+ enable_flash=True,
594
+ upcast_attention=False,
595
+ upcast_softmax=False,
596
+ traceable=False,
597
+ mask_token=False,
598
+ enc_mode=False,
599
+ freeze_encoder=False,
600
+ ):
601
+ super().__init__()
602
+ self.num_stages = len(enc_depths)
603
+ self.order = [order] if isinstance(order, str) else order
604
+ self.enc_mode = enc_mode
605
+ self.shuffle_orders = shuffle_orders
606
+ self.freeze_encoder = freeze_encoder
607
+
608
+ assert self.num_stages == len(stride) + 1
609
+ assert self.num_stages == len(enc_depths)
610
+ assert self.num_stages == len(enc_channels)
611
+ assert self.num_stages == len(enc_num_head)
612
+ assert self.num_stages == len(enc_patch_size)
613
+ assert self.enc_mode or self.num_stages == len(dec_depths) + 1
614
+ assert self.enc_mode or self.num_stages == len(dec_channels) + 1
615
+ assert self.enc_mode or self.num_stages == len(dec_num_head) + 1
616
+ assert self.enc_mode or self.num_stages == len(dec_patch_size) + 1
617
+
618
+ # normalization layer
619
+ ln_layer = nn.LayerNorm
620
+ # activation layers
621
+ act_layer = nn.GELU
622
+
623
+ self.embedding = Embedding(
624
+ in_channels=in_channels,
625
+ embed_channels=enc_channels[0],
626
+ norm_layer=ln_layer,
627
+ act_layer=act_layer,
628
+ mask_token=mask_token,
629
+ )
630
+
631
+ # encoder
632
+ enc_drop_path = [
633
+ x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
634
+ ]
635
+ self.enc = PointSequential()
636
+ for s in range(self.num_stages):
637
+ enc_drop_path_ = enc_drop_path[
638
+ sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
639
+ ]
640
+ enc = PointSequential()
641
+ if s > 0:
642
+ enc.add(
643
+ GridPooling(
644
+ in_channels=enc_channels[s - 1],
645
+ out_channels=enc_channels[s],
646
+ stride=stride[s - 1],
647
+ norm_layer=ln_layer,
648
+ act_layer=act_layer,
649
+ ),
650
+ name="down",
651
+ )
652
+ for i in range(enc_depths[s]):
653
+ enc.add(
654
+ Block(
655
+ channels=enc_channels[s],
656
+ num_heads=enc_num_head[s],
657
+ patch_size=enc_patch_size[s],
658
+ mlp_ratio=mlp_ratio,
659
+ qkv_bias=qkv_bias,
660
+ qk_scale=qk_scale,
661
+ attn_drop=attn_drop,
662
+ proj_drop=proj_drop,
663
+ drop_path=enc_drop_path_[i],
664
+ layer_scale=layer_scale,
665
+ norm_layer=ln_layer,
666
+ act_layer=act_layer,
667
+ pre_norm=pre_norm,
668
+ order_index=i % len(self.order),
669
+ cpe_indice_key=f"stage{s}",
670
+ enable_rpe=enable_rpe,
671
+ enable_flash=enable_flash,
672
+ upcast_attention=upcast_attention,
673
+ upcast_softmax=upcast_softmax,
674
+ ),
675
+ name=f"block{i}",
676
+ )
677
+ if len(enc) != 0:
678
+ self.enc.add(module=enc, name=f"enc{s}")
679
+
680
+ # decoder
681
+ if not self.enc_mode:
682
+ dec_drop_path = [
683
+ x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))
684
+ ]
685
+ self.dec = PointSequential()
686
+ dec_channels = list(dec_channels) + [enc_channels[-1]]
687
+ for s in reversed(range(self.num_stages - 1)):
688
+ dec_drop_path_ = dec_drop_path[
689
+ sum(dec_depths[:s]) : sum(dec_depths[: s + 1])
690
+ ]
691
+ dec_drop_path_.reverse()
692
+ dec = PointSequential()
693
+ dec.add(
694
+ GridUnpooling(
695
+ in_channels=dec_channels[s + 1],
696
+ skip_channels=enc_channels[s],
697
+ out_channels=dec_channels[s],
698
+ norm_layer=ln_layer,
699
+ act_layer=act_layer,
700
+ traceable=traceable,
701
+ ),
702
+ name="up",
703
+ )
704
+ for i in range(dec_depths[s]):
705
+ dec.add(
706
+ Block(
707
+ channels=dec_channels[s],
708
+ num_heads=dec_num_head[s],
709
+ patch_size=dec_patch_size[s],
710
+ mlp_ratio=mlp_ratio,
711
+ qkv_bias=qkv_bias,
712
+ qk_scale=qk_scale,
713
+ attn_drop=attn_drop,
714
+ proj_drop=proj_drop,
715
+ drop_path=dec_drop_path_[i],
716
+ layer_scale=layer_scale,
717
+ norm_layer=ln_layer,
718
+ act_layer=act_layer,
719
+ pre_norm=pre_norm,
720
+ order_index=i % len(self.order),
721
+ cpe_indice_key=f"stage{s}",
722
+ enable_rpe=enable_rpe,
723
+ enable_flash=enable_flash,
724
+ upcast_attention=upcast_attention,
725
+ upcast_softmax=upcast_softmax,
726
+ ),
727
+ name=f"block{i}",
728
+ )
729
+ self.dec.add(module=dec, name=f"dec{s}")
730
+ if self.freeze_encoder:
731
+ for p in self.embedding.parameters():
732
+ p.requires_grad = False
733
+ for p in self.enc.parameters():
734
+ p.requires_grad = False
735
+ self.apply(self._init_weights)
736
+
737
+ @staticmethod
738
+ def _init_weights(module):
739
+ if isinstance(module, nn.Linear):
740
+ trunc_normal_(module.weight, std=0.02)
741
+ if module.bias is not None:
742
+ nn.init.zeros_(module.bias)
743
+ elif isinstance(module, spconv.SubMConv3d):
744
+ trunc_normal_(module.weight, std=0.02)
745
+ if module.bias is not None:
746
+ nn.init.zeros_(module.bias)
747
+
748
+ def forward(self, data_dict):
749
+ point = Point(data_dict)
750
+ point = self.embedding(point)
751
+
752
+ point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
753
+ point.sparsify()
754
+
755
+ point = self.enc(point)
756
+ if not self.enc_mode:
757
+ point = self.dec(point)
758
+ return point
759
+
760
+
761
+ def load(
762
+ name: str = "concerto_large",
763
+ repo_id="Pointcept/Concerto",
764
+ download_root: str = None,
765
+ custom_config: dict = None,
766
+ ckpt_only: bool = False,
767
+ ):
768
+ if name in MODELS:
769
+ print(f"Loading checkpoint from HuggingFace: {name} ...")
770
+ ckpt_path = hf_hub_download(
771
+ repo_id=repo_id,
772
+ filename=f"{name}.pth",
773
+ repo_type="model",
774
+ revision="main",
775
+ local_dir=download_root or os.path.expanduser("~/.cache/concerto/ckpt"),
776
+ )
777
+ elif os.path.isfile(name):
778
+ print(f"Loading checkpoint in local path: {name} ...")
779
+ ckpt_path = name
780
+ else:
781
+ raise RuntimeError(f"Model {name} not found; available models = {MODELS}")
782
+
783
+ if version.parse(torch.__version__) >= version.parse("2.4"):
784
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
785
+ else:
786
+ ckpt = torch.load(ckpt_path, map_location="cpu")
787
+ if custom_config is not None:
788
+ for key, value in custom_config.items():
789
+ ckpt["config"][key] = value
790
+
791
+ if ckpt_only:
792
+ return ckpt
793
+
794
+ model = PointTransformerV3(**ckpt["config"])
795
+ model.load_state_dict(ckpt["state_dict"])
796
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
797
+ print(f"Model params: {n_parameters / 1e6:.2f}M")
798
+ return model
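A hedged usage sketch of `load` plus a forward pass. The input keys follow the `Point` conventions used above (coord, a 6-channel feat such as color + normal, grid_size, offset), and `enable_flash=False` routes attention through the vanilla path. The shapes, voxel size, and CPU support (which depends on the installed spconv build) are assumptions, not guarantees:

```python
# Illustrative inference sketch with a Concerto checkpoint (flash attention disabled).
import torch
import concerto

model = concerto.load("concerto_large", custom_config=dict(enable_flash=False)).eval()

n = 2048
point = {
    "coord": torch.rand(n, 3),          # xyz positions (assumed layout)
    "feat": torch.rand(n, 6),           # e.g. color + normal, matching in_channels=6
    "grid_size": torch.tensor([0.02]),  # voxel size used to derive grid_coord (assumed)
    "offset": torch.tensor([n]),        # a single point cloud in the batch
}
with torch.inference_mode():
    out = model(point)                  # returns a Point with per-point features
print(out.feat.shape)
```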
concerto/module.py ADDED
@@ -0,0 +1,107 @@
1
+ """
2
+ Point Modules
3
+ Pointcept detached version
4
+
5
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
6
+ Please cite our work if the code is helpful to you.
7
+ """
8
+
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ import sys
25
+ import torch.nn as nn
26
+ import spconv.pytorch as spconv
27
+ from collections import OrderedDict
28
+
29
+ from .structure import Point
30
+
31
+
32
+ class PointModule(nn.Module):
33
+ r"""PointModule
34
+ placeholder, all module subclass from this will take Point in PointSequential.
35
+ """
36
+
37
+ def __init__(self, *args, **kwargs):
38
+ super().__init__(*args, **kwargs)
39
+
40
+
41
+ class PointSequential(PointModule):
42
+ r"""A sequential container.
43
+ Modules will be added to it in the order they are passed in the constructor.
44
+ Alternatively, an ordered dict of modules can also be passed in.
45
+ """
46
+
47
+ def __init__(self, *args, **kwargs):
48
+ super().__init__()
49
+ if len(args) == 1 and isinstance(args[0], OrderedDict):
50
+ for key, module in args[0].items():
51
+ self.add_module(key, module)
52
+ else:
53
+ for idx, module in enumerate(args):
54
+ self.add_module(str(idx), module)
55
+ for name, module in kwargs.items():
56
+ if sys.version_info < (3, 6):
57
+ raise ValueError("kwargs only supported in py36+")
58
+ if name in self._modules:
59
+ raise ValueError("name exists.")
60
+ self.add_module(name, module)
61
+
62
+ def __getitem__(self, idx):
63
+ if not (-len(self) <= idx < len(self)):
64
+ raise IndexError("index {} is out of range".format(idx))
65
+ if idx < 0:
66
+ idx += len(self)
67
+ it = iter(self._modules.values())
68
+ for i in range(idx):
69
+ next(it)
70
+ return next(it)
71
+
72
+ def __len__(self):
73
+ return len(self._modules)
74
+
75
+ def add(self, module, name=None):
76
+ if name is None:
77
+ name = str(len(self._modules))
78
+ if name in self._modules:
79
+ raise KeyError("name exists")
80
+ self.add_module(name, module)
81
+
82
+ def forward(self, input):
83
+ for k, module in self._modules.items():
84
+ # Point module
85
+ if isinstance(module, PointModule):
86
+ input = module(input)
87
+ # Spconv module
88
+ elif spconv.modules.is_spconv_module(module):
89
+ if isinstance(input, Point):
90
+ input.sparse_conv_feat = module(input.sparse_conv_feat)
91
+ input.feat = input.sparse_conv_feat.features
92
+ else:
93
+ input = module(input)
94
+ # PyTorch module
95
+ else:
96
+ if isinstance(input, Point):
97
+ input.feat = module(input.feat)
98
+ if "sparse_conv_feat" in input.keys():
99
+ input.sparse_conv_feat = input.sparse_conv_feat.replace_feature(
100
+ input.feat
101
+ )
102
+ elif isinstance(input, spconv.SparseConvTensor):
103
+ if input.indices.shape[0] != 0:
104
+ input = input.replace_feature(module(input.features))
105
+ else:
106
+ input = module(input)
107
+ return input
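A short sketch of how `PointSequential` dispatches modules: plain `nn.Module`s are applied to `point.feat`, while `PointModule`s receive the whole `Point`. The probe module and the minimal Point below are illustrative assumptions:

```python
# Illustrative: mixing a plain nn.Linear with a custom PointModule in PointSequential.
import torch
import torch.nn as nn
from concerto.module import PointModule, PointSequential
from concerto.structure import Point

class PrintShape(PointModule):
    def forward(self, point):
        print("feat shape:", tuple(point.feat.shape))  # PointModules see the whole Point
        return point

seq = PointSequential(linear=nn.Linear(6, 32), probe=PrintShape())
point = Point(dict(coord=torch.rand(10, 3), feat=torch.rand(10, 6), offset=torch.tensor([10])))
point = seq(point)  # nn.Linear transforms point.feat; PrintShape then reports (10, 32)
```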
concerto/registry.py ADDED
@@ -0,0 +1,340 @@
1
+ # @lint-ignore-every LICENSELINT
2
+ # Copyright (c) OpenMMLab. All rights reserved.
3
+ import inspect
4
+ import warnings
5
+ from functools import partial
6
+ from collections import abc
7
+
8
+
9
+ def is_seq_of(seq, expected_type, seq_type=None):
10
+ """Check whether it is a sequence of some type.
11
+
12
+ Args:
13
+ seq (Sequence): The sequence to be checked.
14
+ expected_type (type): Expected type of sequence items.
15
+ seq_type (type, optional): Expected sequence type.
16
+
17
+ Returns:
18
+ bool: Whether the sequence is valid.
19
+ """
20
+ if seq_type is None:
21
+ exp_seq_type = abc.Sequence
22
+ else:
23
+ assert isinstance(seq_type, type)
24
+ exp_seq_type = seq_type
25
+ if not isinstance(seq, exp_seq_type):
26
+ return False
27
+ for item in seq:
28
+ if not isinstance(item, expected_type):
29
+ return False
30
+ return True
31
+
32
+
33
+ def build_from_cfg(cfg, registry, default_args=None):
34
+ """Build a module from configs dict.
35
+
36
+ Args:
37
+ cfg (dict): Config dict. It should at least contain the key "type".
38
+ registry (:obj:`Registry`): The registry to search the type from.
39
+ default_args (dict, optional): Default initialization arguments.
40
+
41
+ Returns:
42
+ object: The constructed object.
43
+ """
44
+ if not isinstance(cfg, dict):
45
+ raise TypeError(f"cfg must be a dict, but got {type(cfg)}")
46
+ if "type" not in cfg:
47
+ if default_args is None or "type" not in default_args:
48
+ raise KeyError(
49
+ '`cfg` or `default_args` must contain the key "type", '
50
+ f"but got {cfg}\n{default_args}"
51
+ )
52
+ if not isinstance(registry, Registry):
53
+ raise TypeError(
54
+ "registry must be an mmcv.Registry object, " f"but got {type(registry)}"
55
+ )
56
+ if not (isinstance(default_args, dict) or default_args is None):
57
+ raise TypeError(
58
+ "default_args must be a dict or None, " f"but got {type(default_args)}"
59
+ )
60
+
61
+ args = cfg.copy()
62
+
63
+ if default_args is not None:
64
+ for name, value in default_args.items():
65
+ args.setdefault(name, value)
66
+
67
+ obj_type = args.pop("type")
68
+ if isinstance(obj_type, str):
69
+ obj_cls = registry.get(obj_type)
70
+ if obj_cls is None:
71
+ raise KeyError(f"{obj_type} is not in the {registry.name} registry")
72
+ elif inspect.isclass(obj_type):
73
+ obj_cls = obj_type
74
+ else:
75
+ raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}")
76
+ try:
77
+ return obj_cls(**args)
78
+ except Exception as e:
79
+ # Normal TypeError does not print class name.
80
+ raise type(e)(f"{obj_cls.__name__}: {e}")
81
+
82
+
83
+ class Registry:
84
+ """A registry to map strings to classes.
85
+
86
+ Registered object could be built from registry.
87
+ Example:
88
+ >>> MODELS = Registry('models')
89
+ >>> @MODELS.register_module()
90
+ >>> class ResNet:
91
+ >>> pass
92
+ >>> resnet = MODELS.build(dict(type='ResNet'))
93
+
94
+ Please refer to
95
+ https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
96
+ advanced usage.
97
+
98
+ Args:
99
+ name (str): Registry name.
100
+ build_func(func, optional): Build function to construct instance from
101
+ Registry, func:`build_from_cfg` is used if neither ``parent`` nor
102
+ ``build_func`` is specified. If ``parent`` is specified and
103
+ ``build_func`` is not given, ``build_func`` will be inherited
104
+ from ``parent``. Default: None.
105
+ parent (Registry, optional): Parent registry. The class registered in
106
+ children registry could be built from parent. Default: None.
107
+ scope (str, optional): The scope of registry. It is the key to search
108
+ for children registry. If not specified, scope will be the name of
109
+ the package where class is defined, e.g. mmdet, mmcls, mmseg.
110
+ Default: None.
111
+ """
112
+
113
+ def __init__(self, name, build_func=None, parent=None, scope=None):
114
+ self._name = name
115
+ self._module_dict = dict()
116
+ self._children = dict()
117
+ self._scope = self.infer_scope() if scope is None else scope
118
+
119
+ # self.build_func will be set with the following priority:
120
+ # 1. build_func
121
+ # 2. parent.build_func
122
+ # 3. build_from_cfg
123
+ if build_func is None:
124
+ if parent is not None:
125
+ self.build_func = parent.build_func
126
+ else:
127
+ self.build_func = build_from_cfg
128
+ else:
129
+ self.build_func = build_func
130
+ if parent is not None:
131
+ assert isinstance(parent, Registry)
132
+ parent._add_children(self)
133
+ self.parent = parent
134
+ else:
135
+ self.parent = None
136
+
137
+ def __len__(self):
138
+ return len(self._module_dict)
139
+
140
+ def __contains__(self, key):
141
+ return self.get(key) is not None
142
+
143
+ def __repr__(self):
144
+ format_str = (
145
+ self.__class__.__name__ + f"(name={self._name}, "
146
+ f"items={self._module_dict})"
147
+ )
148
+ return format_str
149
+
150
+ @staticmethod
151
+ def infer_scope():
152
+ """Infer the scope of registry.
153
+
154
+ The name of the package where registry is defined will be returned.
155
+
156
+ Example:
157
+ # in mmdet/models/backbone/resnet.py
158
+ >>> MODELS = Registry('models')
159
+ >>> @MODELS.register_module()
160
+ >>> class ResNet:
161
+ >>> pass
162
+ The scope of ``ResNet`` will be ``mmdet``.
163
+
164
+
165
+ Returns:
166
+ scope (str): The inferred scope name.
167
+ """
168
+ # inspect.stack() traces where this function is called; index 2
169
+ # indicates the frame where `infer_scope()` is called
170
+ filename = inspect.getmodule(inspect.stack()[2][0]).__name__
171
+ split_filename = filename.split(".")
172
+ return split_filename[0]
173
+
174
+ @staticmethod
175
+ def split_scope_key(key):
176
+ """Split scope and key.
177
+
178
+ The first scope will be split from key.
179
+
180
+ Examples:
181
+ >>> Registry.split_scope_key('mmdet.ResNet')
182
+ 'mmdet', 'ResNet'
183
+ >>> Registry.split_scope_key('ResNet')
184
+ None, 'ResNet'
185
+
186
+ Return:
187
+ scope (str, None): The first scope.
188
+ key (str): The remaining key.
189
+ """
190
+ split_index = key.find(".")
191
+ if split_index != -1:
192
+ return key[:split_index], key[split_index + 1 :]
193
+ else:
194
+ return None, key
195
+
196
+ @property
197
+ def name(self):
198
+ return self._name
199
+
200
+ @property
201
+ def scope(self):
202
+ return self._scope
203
+
204
+ @property
205
+ def module_dict(self):
206
+ return self._module_dict
207
+
208
+ @property
209
+ def children(self):
210
+ return self._children
211
+
212
+ def get(self, key):
213
+ """Get the registry record.
214
+
215
+ Args:
216
+ key (str): The class name in string format.
217
+
218
+ Returns:
219
+ class: The corresponding class.
220
+ """
221
+ scope, real_key = self.split_scope_key(key)
222
+ if scope is None or scope == self._scope:
223
+ # get from self
224
+ if real_key in self._module_dict:
225
+ return self._module_dict[real_key]
226
+ else:
227
+ # get from self._children
228
+ if scope in self._children:
229
+ return self._children[scope].get(real_key)
230
+ else:
231
+ # goto root
232
+ parent = self.parent
233
+ while parent.parent is not None:
234
+ parent = parent.parent
235
+ return parent.get(key)
236
+
237
+ def build(self, *args, **kwargs):
238
+ return self.build_func(*args, **kwargs, registry=self)
239
+
240
+ def _add_children(self, registry):
241
+ """Add children for a registry.
242
+
243
+ The ``registry`` will be added as children based on its scope.
244
+ The parent registry could build objects from children registry.
245
+
246
+ Example:
247
+ >>> models = Registry('models')
248
+ >>> mmdet_models = Registry('models', parent=models)
249
+ >>> @mmdet_models.register_module()
250
+ >>> class ResNet:
251
+ >>> pass
252
+ >>> resnet = models.build(dict(type='mmdet.ResNet'))
253
+ """
254
+
255
+ assert isinstance(registry, Registry)
256
+ assert registry.scope is not None
257
+ assert (
258
+ registry.scope not in self.children
259
+ ), f"scope {registry.scope} exists in {self.name} registry"
260
+ self.children[registry.scope] = registry
261
+
262
+ def _register_module(self, module_class, module_name=None, force=False):
263
+ if not inspect.isclass(module_class):
264
+ raise TypeError("module must be a class, " f"but got {type(module_class)}")
265
+
266
+ if module_name is None:
267
+ module_name = module_class.__name__
268
+ if isinstance(module_name, str):
269
+ module_name = [module_name]
270
+ for name in module_name:
271
+ if not force and name in self._module_dict:
272
+ raise KeyError(f"{name} is already registered " f"in {self.name}")
273
+ self._module_dict[name] = module_class
274
+
275
+ def deprecated_register_module(self, cls=None, force=False):
276
+ warnings.warn(
277
+ "The old API of register_module(module, force=False) "
278
+ "is deprecated and will be removed, please use the new API "
279
+ "register_module(name=None, force=False, module=None) instead."
280
+ )
281
+ if cls is None:
282
+ return partial(self.deprecated_register_module, force=force)
283
+ self._register_module(cls, force=force)
284
+ return cls
285
+
286
+ def register_module(self, name=None, force=False, module=None):
287
+ """Register a module.
288
+
289
+ A record will be added to `self._module_dict`, whose key is the class
290
+ name or the specified name, and value is the class itself.
291
+ It can be used as a decorator or a normal function.
292
+
293
+ Example:
294
+ >>> backbones = Registry('backbone')
295
+ >>> @backbones.register_module()
296
+ >>> class ResNet:
297
+ >>> pass
298
+
299
+ >>> backbones = Registry('backbone')
300
+ >>> @backbones.register_module(name='mnet')
301
+ >>> class MobileNet:
302
+ >>> pass
303
+
304
+ >>> backbones = Registry('backbone')
305
+ >>> class ResNet:
306
+ >>> pass
307
+ >>> backbones.register_module(ResNet)
308
+
309
+ Args:
310
+ name (str | None): The module name to be registered. If not
311
+ specified, the class name will be used.
312
+ force (bool, optional): Whether to override an existing class with
313
+ the same name. Default: False.
314
+ module (type): Module class to be registered.
315
+ """
316
+ if not isinstance(force, bool):
317
+ raise TypeError(f"force must be a boolean, but got {type(force)}")
318
+ # NOTE: This is a workaround to be compatible with the old API,
319
+ # while it may introduce unexpected bugs.
320
+ if isinstance(name, type):
321
+ return self.deprecated_register_module(name, force=force)
322
+
323
+ # raise the error ahead of time
324
+ if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
325
+ raise TypeError(
326
+ "name must be either of None, an instance of str or a sequence"
327
+ f" of str, but got {type(name)}"
328
+ )
329
+
330
+ # use it as a normal method: x.register_module(module=SomeClass)
331
+ if module is not None:
332
+ self._register_module(module_class=module, module_name=name, force=force)
333
+ return module
334
+
335
+ # use it as a decorator: @x.register_module()
336
+ def _register(cls):
337
+ self._register_module(module_class=cls, module_name=name, force=force)
338
+ return cls
339
+
340
+ return _register
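A compact, runnable version of the registry usage shown in the docstrings above (the class and config are arbitrary examples):

```python
# Illustrative registry usage: register a class, then build it from a config dict.
from concerto.registry import Registry

BACKBONES = Registry("backbone")

@BACKBONES.register_module()
class TinyBackbone:
    def __init__(self, depth=4):
        self.depth = depth

model = BACKBONES.build(dict(type="TinyBackbone", depth=2))
print(type(model).__name__, model.depth)  # TinyBackbone 2
```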
concerto/serialization/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .default import (
2
+ encode,
3
+ decode,
4
+ z_order_encode,
5
+ z_order_decode,
6
+ hilbert_encode,
7
+ hilbert_decode,
8
+ )
concerto/serialization/default.py ADDED
@@ -0,0 +1,82 @@
1
+ """
2
+ Serialization Encoding
3
+ Pointcept detached version
4
+
5
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
6
+ Please cite our work if the code is helpful to you.
7
+ """
8
+
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ import torch
25
+ from .z_order import xyz2key as z_order_encode_
26
+ from .z_order import key2xyz as z_order_decode_
27
+ from .hilbert import encode as hilbert_encode_
28
+ from .hilbert import decode as hilbert_decode_
29
+
30
+
31
+ @torch.inference_mode()
32
+ def encode(grid_coord, batch=None, depth=16, order="z"):
33
+ assert order in {"z", "z-trans", "hilbert", "hilbert-trans"}
34
+ if order == "z":
35
+ code = z_order_encode(grid_coord, depth=depth)
36
+ elif order == "z-trans":
37
+ code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
38
+ elif order == "hilbert":
39
+ code = hilbert_encode(grid_coord, depth=depth)
40
+ elif order == "hilbert-trans":
41
+ code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
42
+ else:
43
+ raise NotImplementedError
44
+ if batch is not None:
45
+ batch = batch.long()
46
+ code = batch << depth * 3 | code
47
+ return code
48
+
49
+
50
+ @torch.inference_mode()
51
+ def decode(code, depth=16, order="z"):
52
+ assert order in {"z", "hilbert"}
53
+ batch = code >> depth * 3
54
+ code = code & ((1 << depth * 3) - 1)
55
+ if order == "z":
56
+ grid_coord = z_order_decode(code, depth=depth)
57
+ elif order == "hilbert":
58
+ grid_coord = hilbert_decode(code, depth=depth)
59
+ else:
60
+ raise NotImplementedError
61
+ return grid_coord, batch
62
+
63
+
64
+ def z_order_encode(grid_coord: torch.Tensor, depth: int = 16):
65
+ x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long()
66
+ # batch support is intentionally not handled here; batched codes are maintained in the Point class
67
+ code = z_order_encode_(x, y, z, b=None, depth=depth)
68
+ return code
69
+
70
+
71
+ def z_order_decode(code: torch.Tensor, depth):
72
+ x, y, z, _ = z_order_decode_(code, depth=depth)  # key2xyz also returns a batch index
73
+ grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3)
74
+ return grid_coord
75
+
76
+
77
+ def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16):
78
+ return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth)
79
+
80
+
81
+ def hilbert_decode(code: torch.Tensor, depth: int = 16):
82
+ return hilbert_decode_(code, num_dims=3, num_bits=depth)
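A small round-trip sketch of `encode`/`decode` above, assuming `torch` is installed and the package is importable as `concerto`; tensor sizes are arbitrary and not part of this commit.

```python
# Encode voxel indices of two stacked scenes into batch-prefixed Hilbert codes,
# then decode them back.
import torch
from concerto.serialization import encode, decode

grid_coord = torch.randint(0, 2 ** 10, (1024, 3))            # fits depth=10
batch = torch.cat([torch.zeros(512), torch.ones(512)]).long()

code = encode(grid_coord, batch=batch, depth=10, order="hilbert")
coord_back, batch_back = decode(code, depth=10, order="hilbert")

assert torch.equal(coord_back, grid_coord) and torch.equal(batch_back, batch)
order = torch.argsort(code)   # batch-major, Hilbert-curve point ordering
```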
concerto/serialization/hilbert.py ADDED
@@ -0,0 +1,318 @@
1
+ """
2
+ Hilbert Order
3
+ Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve
4
+
5
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu
6
+ Please cite our work if the code is helpful to you.
7
+ """
8
+
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ import torch
25
+
26
+
27
+ def right_shift(binary, k=1, axis=-1):
28
+ """Right shift an array of binary values.
29
+
30
+ Parameters:
31
+ -----------
32
+ binary: An ndarray of binary values.
33
+
34
+ k: The number of bits to shift. Default 1.
35
+
36
+ axis: The axis along which to shift. Default -1.
37
+
38
+ Returns:
39
+ --------
40
+ Returns an ndarray with zero prepended and the ends truncated, along
41
+ whatever axis was specified."""
42
+
43
+ # If we're shifting the whole thing, just return zeros.
44
+ if binary.shape[axis] <= k:
45
+ return torch.zeros_like(binary)
46
+
47
+ # Determine the padding pattern.
48
+ # padding = [(0,0)] * len(binary.shape)
49
+ # padding[axis] = (k,0)
50
+
51
+ # Determine the slicing pattern to eliminate just the last one.
52
+ slicing = [slice(None)] * len(binary.shape)
53
+ slicing[axis] = slice(None, -k)
54
+ shifted = torch.nn.functional.pad(
55
+ binary[tuple(slicing)], (k, 0), mode="constant", value=0
56
+ )
57
+
58
+ return shifted
59
+
60
+
61
+ def binary2gray(binary, axis=-1):
62
+ """Convert an array of binary values into Gray codes.
63
+
64
+ This uses the classic X ^ (X >> 1) trick to compute the Gray code.
65
+
66
+ Parameters:
67
+ -----------
68
+ binary: An ndarray of binary values.
69
+
70
+ axis: The axis along which to compute the gray code. Default=-1.
71
+
72
+ Returns:
73
+ --------
74
+ Returns an ndarray of Gray codes.
75
+ """
76
+ shifted = right_shift(binary, axis=axis)
77
+
78
+ # Do the X ^ (X >> 1) trick.
79
+ gray = torch.logical_xor(binary, shifted)
80
+
81
+ return gray
82
+
83
+
84
+ def gray2binary(gray, axis=-1):
85
+ """Convert an array of Gray codes back into binary values.
86
+
87
+ Parameters:
88
+ -----------
89
+ gray: An ndarray of gray codes.
90
+
91
+ axis: The axis along which to perform Gray decoding. Default=-1.
92
+
93
+ Returns:
94
+ --------
95
+ Returns an ndarray of binary values.
96
+ """
97
+
98
+ # Loop the log2(bits) number of times necessary, with shift and xor.
99
+ shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
100
+ while shift > 0:
101
+ gray = torch.logical_xor(gray, right_shift(gray, shift))
102
+ shift = torch.div(shift, 2, rounding_mode="floor")
103
+ return gray
104
+
105
+
106
+ def encode(locs, num_dims, num_bits):
107
+ """Decode an array of locations in a hypercube into a Hilbert integer.
108
+
109
+ This is a vectorized-ish version of the Hilbert curve implementation by John
110
+ Skilling as described in:
111
+
112
+ Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
113
+ Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
114
+
115
+ Params:
116
+ -------
117
+ locs - An ndarray of locations in a hypercube of num_dims dimensions, in
118
+ which each dimension runs from 0 to 2**num_bits-1. The shape can
119
+ be arbitrary, as long as the last dimension of the same has size
120
+ num_dims.
121
+
122
+ num_dims - The dimensionality of the hypercube. Integer.
123
+
124
+ num_bits - The number of bits for each dimension. Integer.
125
+
126
+ Returns:
127
+ --------
128
+ The output is an ndarray of uint64 integers with the same shape as the
129
+ input, excluding the last dimension, which needs to be num_dims.
130
+ """
131
+
132
+ # Keep around the original shape for later.
133
+ orig_shape = locs.shape
134
+ bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
135
+ bitpack_mask_rev = bitpack_mask.flip(-1)
136
+
137
+ if orig_shape[-1] != num_dims:
138
+ raise ValueError(
139
+ """
140
+ The shape of locs was surprising in that the last dimension was of size
141
+ %d, but num_dims=%d. These need to be equal.
142
+ """
143
+ % (orig_shape[-1], num_dims)
144
+ )
145
+
146
+ if num_dims * num_bits > 63:
147
+ raise ValueError(
148
+ """
149
+ num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
150
+ into an int64. Are you sure you need that many points on your Hilbert
151
+ curve?
152
+ """
153
+ % (num_dims, num_bits, num_dims * num_bits)
154
+ )
155
+
156
+ # Treat the location integers as 64-bit unsigned and then split them up into
157
+ # a sequence of uint8s. Preserve the association by dimension.
158
+ locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)
159
+
160
+ # Now turn these into bits and truncate to num_bits.
161
+ gray = (
162
+ locs_uint8.unsqueeze(-1)
163
+ .bitwise_and(bitpack_mask_rev)
164
+ .ne(0)
165
+ .byte()
166
+ .flatten(-2, -1)[..., -num_bits:]
167
+ )
168
+
169
+ # Run the decoding process the other way.
170
+ # Iterate forwards through the bits.
171
+ for bit in range(0, num_bits):
172
+ # Iterate forwards through the dimensions.
173
+ for dim in range(0, num_dims):
174
+ # Identify which ones have this bit active.
175
+ mask = gray[:, dim, bit]
176
+
177
+ # Where this bit is on, invert the 0 dimension for lower bits.
178
+ gray[:, 0, bit + 1 :] = torch.logical_xor(
179
+ gray[:, 0, bit + 1 :], mask[:, None]
180
+ )
181
+
182
+ # Where the bit is off, exchange the lower bits with the 0 dimension.
183
+ to_flip = torch.logical_and(
184
+ torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
185
+ torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
186
+ )
187
+ gray[:, dim, bit + 1 :] = torch.logical_xor(
188
+ gray[:, dim, bit + 1 :], to_flip
189
+ )
190
+ gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
191
+
192
+ # Now flatten out.
193
+ gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))
194
+
195
+ # Convert Gray back to binary.
196
+ hh_bin = gray2binary(gray)
197
+
198
+ # Pad back out to 64 bits.
199
+ extra_dims = 64 - num_bits * num_dims
200
+ padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)
201
+
202
+ # Convert binary values into uint8s.
203
+ hh_uint8 = (
204
+ (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
205
+ .sum(2)
206
+ .squeeze()
207
+ .type(torch.uint8)
208
+ )
209
+
210
+ # Convert uint8s into uint64s.
211
+ hh_uint64 = hh_uint8.view(torch.int64).squeeze()
212
+
213
+ return hh_uint64
214
+
215
+
216
+ def decode(hilberts, num_dims, num_bits):
217
+ """Decode an array of Hilbert integers into locations in a hypercube.
218
+
219
+ This is a vectorized-ish version of the Hilbert curve implementation by John
220
+ Skilling as described in:
221
+
222
+ Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
223
+ Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
224
+
225
+ Params:
226
+ -------
227
+ hilberts - An ndarray of Hilbert integers. Must be an integer dtype and
228
+ cannot have fewer bits than num_dims * num_bits.
229
+
230
+ num_dims - The dimensionality of the hypercube. Integer.
231
+
232
+ num_bits - The number of bits for each dimension. Integer.
233
+
234
+ Returns:
235
+ --------
236
+ The output is an ndarray of unsigned integers with the same shape as hilberts
237
+ but with an additional dimension of size num_dims.
238
+ """
239
+
240
+ if num_dims * num_bits > 64:
241
+ raise ValueError(
242
+ """
243
+ num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
244
+ into a uint64. Are you sure you need that many points on your Hilbert
245
+ curve?
246
+ """
247
+ % (num_dims, num_bits, num_dims * num_bits)
248
+ )
249
+
250
+ # Handle the case where we got handed a naked integer.
251
+ hilberts = torch.atleast_1d(hilberts)
252
+
253
+ # Keep around the shape for later.
254
+ orig_shape = hilberts.shape
255
+ bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)
256
+ bitpack_mask_rev = bitpack_mask.flip(-1)
257
+
258
+ # Treat each of the hilberts as a sequence of eight uint8s.
259
+ # This treats all of the inputs as uint64 and makes things uniform.
260
+ hh_uint8 = (
261
+ hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)
262
+ )
263
+
264
+ # Turn these lists of uints into lists of bits and then truncate to the size
265
+ # we actually need for using Skilling's procedure.
266
+ hh_bits = (
267
+ hh_uint8.unsqueeze(-1)
268
+ .bitwise_and(bitpack_mask_rev)
269
+ .ne(0)
270
+ .byte()
271
+ .flatten(-2, -1)[:, -num_dims * num_bits :]
272
+ )
273
+
274
+ # Take the sequence of bits and Gray-code it.
275
+ gray = binary2gray(hh_bits)
276
+
277
+ # There has got to be a better way to do this.
278
+ # I could index them differently, but the eventual packbits likes it this way.
279
+ gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)
280
+
281
+ # Iterate backwards through the bits.
282
+ for bit in range(num_bits - 1, -1, -1):
283
+ # Iterate backwards through the dimensions.
284
+ for dim in range(num_dims - 1, -1, -1):
285
+ # Identify which ones have this bit active.
286
+ mask = gray[:, dim, bit]
287
+
288
+ # Where this bit is on, invert the 0 dimension for lower bits.
289
+ gray[:, 0, bit + 1 :] = torch.logical_xor(
290
+ gray[:, 0, bit + 1 :], mask[:, None]
291
+ )
292
+
293
+ # Where the bit is off, exchange the lower bits with the 0 dimension.
294
+ to_flip = torch.logical_and(
295
+ torch.logical_not(mask[:, None]),
296
+ torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
297
+ )
298
+ gray[:, dim, bit + 1 :] = torch.logical_xor(
299
+ gray[:, dim, bit + 1 :], to_flip
300
+ )
301
+ gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
302
+
303
+ # Pad back out to 64 bits.
304
+ extra_dims = 64 - num_bits
305
+ padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0)
306
+
307
+ # Now chop these up into blocks of 8.
308
+ locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))
309
+
310
+ # Take those blocks and turn them into uint8s.
311
+ # from IPython import embed; embed()
312
+ locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8)
313
+
314
+ # Finally, treat these as uint64s.
315
+ flat_locs = locs_uint8.view(torch.int64)
316
+
317
+ # Return them in the expected shape.
318
+ return flat_locs.reshape((*orig_shape, num_dims))
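A quick self-check sketch for the two Skilling transforms above (illustrative only; coordinate values must fit in `num_bits` bits per axis).

```python
# Round-trip 3-D locations through the Hilbert encode/decode pair above.
import torch
from concerto.serialization.hilbert import encode, decode

locs = torch.randint(0, 2 ** 8, (16, 3))          # num_bits = 8 per axis
codes = encode(locs, num_dims=3, num_bits=8)      # one int64 per location
recovered = decode(codes, num_dims=3, num_bits=8)
assert torch.equal(recovered, locs)
```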
concerto/serialization/z_order.py ADDED
@@ -0,0 +1,127 @@
1
+ # @lint-ignore-every LICENSELINT
2
+ # --------------------------------------------------------
3
+ # Octree-based Sparse Convolutional Neural Networks
4
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Written by Peng-Shuai Wang
7
+ # --------------------------------------------------------
8
+
9
+ import torch
10
+ from typing import Optional, Union
11
+
12
+
13
+ class KeyLUT:
14
+ def __init__(self):
15
+ r256 = torch.arange(256, dtype=torch.int64)
16
+ r512 = torch.arange(512, dtype=torch.int64)
17
+ zero = torch.zeros(256, dtype=torch.int64)
18
+ device = torch.device("cpu")
19
+
20
+ self._encode = {
21
+ device: (
22
+ self.xyz2key(r256, zero, zero, 8),
23
+ self.xyz2key(zero, r256, zero, 8),
24
+ self.xyz2key(zero, zero, r256, 8),
25
+ )
26
+ }
27
+ self._decode = {device: self.key2xyz(r512, 9)}
28
+
29
+ def encode_lut(self, device=torch.device("cpu")):
30
+ if device not in self._encode:
31
+ cpu = torch.device("cpu")
32
+ self._encode[device] = tuple(e.to(device) for e in self._encode[cpu])
33
+ return self._encode[device]
34
+
35
+ def decode_lut(self, device=torch.device("cpu")):
36
+ if device not in self._decode:
37
+ cpu = torch.device("cpu")
38
+ self._decode[device] = tuple(e.to(device) for e in self._decode[cpu])
39
+ return self._decode[device]
40
+
41
+ def xyz2key(self, x, y, z, depth):
42
+ key = torch.zeros_like(x)
43
+ for i in range(depth):
44
+ mask = 1 << i
45
+ key = (
46
+ key
47
+ | ((x & mask) << (2 * i + 2))
48
+ | ((y & mask) << (2 * i + 1))
49
+ | ((z & mask) << (2 * i + 0))
50
+ )
51
+ return key
52
+
53
+ def key2xyz(self, key, depth):
54
+ x = torch.zeros_like(key)
55
+ y = torch.zeros_like(key)
56
+ z = torch.zeros_like(key)
57
+ for i in range(depth):
58
+ x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2))
59
+ y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1))
60
+ z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0))
61
+ return x, y, z
62
+
63
+
64
+ _key_lut = KeyLUT()
65
+
66
+
67
+ def xyz2key(
68
+ x: torch.Tensor,
69
+ y: torch.Tensor,
70
+ z: torch.Tensor,
71
+ b: Optional[Union[torch.Tensor, int]] = None,
72
+ depth: int = 16,
73
+ ):
74
+ r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
75
+ based on pre-computed look-up tables. This is much faster than a
77
+ per-bit for-loop implementation.
77
+
78
+ Args:
79
+ x (torch.Tensor): The x coordinate.
80
+ y (torch.Tensor): The y coordinate.
81
+ z (torch.Tensor): The z coordinate.
82
+ b (torch.Tensor or int): The batch index of the coordinates, and should be
83
+ smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
84
+ :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
85
+ depth (int): The depth of the shuffled key; must be smaller than 17.
86
+ """
87
+
88
+ EX, EY, EZ = _key_lut.encode_lut(x.device)
89
+ x, y, z = x.long(), y.long(), z.long()
90
+
91
+ mask = 255 if depth > 8 else (1 << depth) - 1
92
+ key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
93
+ if depth > 8:
94
+ mask = (1 << (depth - 8)) - 1
95
+ key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
96
+ key = key16 << 24 | key
97
+
98
+ if b is not None:
99
+ b = b.long()
100
+ key = b << 48 | key
101
+
102
+ return key
103
+
104
+
105
+ def key2xyz(key: torch.Tensor, depth: int = 16):
106
+ r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
107
+ and the batch index based on pre-computed look up tables.
108
+
109
+ Args:
110
+ key (torch.Tensor): The shuffled key.
111
+ depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
112
+ """
113
+
114
+ DX, DY, DZ = _key_lut.decode_lut(key.device)
115
+ x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key)
116
+
117
+ b = key >> 48
118
+ key = key & ((1 << 48) - 1)
119
+
120
+ n = (depth + 2) // 3
121
+ for i in range(n):
122
+ k = key >> (i * 9) & 511
123
+ x = x | (DX[k] << (i * 3))
124
+ y = y | (DY[k] << (i * 3))
125
+ z = z | (DZ[k] << (i * 3))
126
+
127
+ return x, y, z, b
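A matching sketch for the LUT-based Z-order keys above (illustrative; coordinates must fit in `depth` bits and the batch index stays below 2^15).

```python
# Round-trip x/y/z through the shuffled-key encoder/decoder above.
import torch
from concerto.serialization.z_order import xyz2key, key2xyz

x, y, z = (torch.randint(0, 2 ** 12, (100,)) for _ in range(3))
b = torch.zeros(100, dtype=torch.long)

key = xyz2key(x, y, z, b=b, depth=12)
x2, y2, z2, b2 = key2xyz(key, depth=12)
assert all(torch.equal(a, c) for a, c in ((x2, x), (y2, y), (z2, z), (b2, b)))
```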
concerto/structure.py ADDED
@@ -0,0 +1,159 @@
1
+ """
2
+ Data structure for 3D point cloud
3
+
4
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
5
+ Please cite our work if the code is helpful to you.
6
+ """
7
+
8
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ import torch
24
+ import spconv.pytorch as spconv
25
+ from addict import Dict
26
+
27
+ from .serialization import encode
28
+ from .utils import offset2batch, batch2offset
29
+
30
+
31
+ class Point(Dict):
32
+ """
33
+ Point Structure of Pointcept
34
+
35
+ A Point (point cloud) in Pointcept is a dictionary that contains various properties of
36
+ a batched point cloud. The property with the following names have a specific definition
37
+ as follows:
38
+
39
+ - "coord": original coordinate of point cloud;
40
+ - "grid_coord": grid coordinate for specific grid size (related to GridSampling);
41
+ Point also support the following optional attributes:
42
+ - "offset": if not exist, initialized as batch size is 1;
43
+ - "batch": if not exist, initialized as batch size is 1;
44
+ - "feat": feature of point cloud, default input of model;
45
+ - "grid_size": Grid size of point cloud (related to GridSampling);
46
+ (related to Serialization)
47
+ - "serialized_depth": depth of serialization, 2 ** depth * grid_size describe the maximum of point cloud range;
48
+ - "serialized_code": a list of serialization codes;
49
+ - "serialized_order": a list of serialization order determined by code;
50
+ - "serialized_inverse": a list of inverse mapping determined by code;
51
+ (related to Sparsify: SpConv)
52
+ - "sparse_shape": Sparse shape for Sparse Conv Tensor;
53
+ - "sparse_conv_feat": SparseConvTensor init with information provide by Point;
54
+ """
55
+
56
+ def __init__(self, *args, **kwargs):
57
+ super().__init__(*args, **kwargs)
58
+ # If one of "offset" or "batch" do not exist, generate by the existing one
59
+ if "batch" not in self.keys() and "offset" in self.keys():
60
+ self["batch"] = offset2batch(self.offset)
61
+ elif "offset" not in self.keys() and "batch" in self.keys():
62
+ self["offset"] = batch2offset(self.batch)
63
+
64
+ def serialization(self, order="z", depth=None, shuffle_orders=False):
65
+ """
66
+ Point Cloud Serialization
67
+
68
+ relay on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
69
+ """
70
+ self["order"] = order
71
+ assert "batch" in self.keys()
72
+ if "grid_coord" not in self.keys():
73
+ # if you don't want to operate GridSampling in data augmentation,
74
+ # please add the following augmentation into your pipeline:
75
+ # dict(type="Copy", keys_dict={"grid_size": 0.01}),
76
+ # (adjust `grid_size` to what your want)
77
+ assert {"grid_size", "coord"}.issubset(self.keys())
78
+
79
+ self["grid_coord"] = torch.div(
80
+ self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
81
+ ).int()
82
+
83
+ if depth is None:
84
+ # Adaptive measure the depth of serialization cube (length = 2 ^ depth)
85
+ depth = int(self.grid_coord.max() + 1).bit_length()
86
+ self["serialized_depth"] = depth
87
+ # Maximum bit length for serialization code is 63 (int64)
88
+ assert depth * 3 + len(self.offset).bit_length() <= 63
89
+ # Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position.
90
+ # Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3
91
+ # cube with a grid size of 0.01 meter. We consider it is enough for the current stage.
92
+ # We can unlock the limitation by optimizing the z-order encoding function if necessary.
93
+ assert depth <= 16
94
+
95
+ # The serialization codes are arranged as following structures:
96
+ # [Order1 ([n]),
97
+ # Order2 ([n]),
98
+ # ...
99
+ # OrderN ([n])] (k, n)
100
+ code = [
101
+ encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order
102
+ ]
103
+ code = torch.stack(code)
104
+ order = torch.argsort(code)
105
+ inverse = torch.zeros_like(order).scatter_(
106
+ dim=1,
107
+ index=order,
108
+ src=torch.arange(0, code.shape[1], device=order.device).repeat(
109
+ code.shape[0], 1
110
+ ),
111
+ )
112
+
113
+ if shuffle_orders:
114
+ perm = torch.randperm(code.shape[0])
115
+ code = code[perm]
116
+ order = order[perm]
117
+ inverse = inverse[perm]
118
+
119
+ self["serialized_code"] = code
120
+ self["serialized_order"] = order
121
+ self["serialized_inverse"] = inverse
122
+
123
+ def sparsify(self, pad=96):
124
+ """
125
+ Point Cloud Sparsification
126
+
127
+ Point clouds are inherently sparse; here "sparsify" specifically refers to
128
+ preparing a "spconv.SparseConvTensor" for SpConv.
129
+
130
+ relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
131
+
132
+ pad: padding added to the sparse spatial shape.
133
+ """
134
+ assert {"feat", "batch"}.issubset(self.keys())
135
+ if "grid_coord" not in self.keys():
136
+ # if you don't want to operate GridSampling in data augmentation,
137
+ # please add the following augmentation into your pipeline:
138
+ # dict(type="Copy", keys_dict={"grid_size": 0.01}),
139
+ # (adjust `grid_size` to what you want)
140
+ assert {"grid_size", "coord"}.issubset(self.keys())
141
+ self["grid_coord"] = torch.div(
142
+ self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
143
+ ).int()
144
+ if "sparse_shape" in self.keys():
145
+ sparse_shape = self.sparse_shape
146
+ else:
147
+ sparse_shape = torch.add(
148
+ torch.max(self.grid_coord, dim=0).values, pad
149
+ ).tolist()
150
+ sparse_conv_feat = spconv.SparseConvTensor(
151
+ features=self.feat,
152
+ indices=torch.cat(
153
+ [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1
154
+ ).contiguous(),
155
+ spatial_shape=sparse_shape,
156
+ batch_size=self.batch[-1].tolist() + 1,
157
+ )
158
+ self["sparse_shape"] = sparse_shape
159
+ self["sparse_conv_feat"] = sparse_conv_feat
concerto/transform.py ADDED
@@ -0,0 +1,1224 @@
1
+ """
2
+ 3D point cloud augmentation
3
+
4
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
5
+ Please cite our work if the code is helpful to you.
6
+ """
7
+
8
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ import random
24
+ import numbers
25
+ import scipy
26
+ import scipy.ndimage
27
+ import scipy.interpolate
28
+ import scipy.stats
29
+ import numpy as np
30
+ import torch
31
+ import copy
32
+ from collections.abc import Sequence, Mapping
33
+
34
+ from .registry import Registry
35
+
36
+ TRANSFORMS = Registry("transforms")
37
+
38
+
39
+ def index_operator(data_dict, index, duplicate=False):
40
+ # index selection operator for keys in "index_valid_keys"
41
+ # customize these keys via the "Update" transform in the config
42
+ if "index_valid_keys" not in data_dict:
43
+ data_dict["index_valid_keys"] = [
44
+ "coord",
45
+ "color",
46
+ "normal",
47
+ "strength",
48
+ "segment",
49
+ "instance",
50
+ ]
51
+ if not duplicate:
52
+ for key in data_dict["index_valid_keys"]:
53
+ if key in data_dict:
54
+ data_dict[key] = data_dict[key][index]
55
+ return data_dict
56
+ else:
57
+ data_dict_ = dict()
58
+ for key in data_dict.keys():
59
+ if key in data_dict["index_valid_keys"]:
60
+ data_dict_[key] = data_dict[key][index]
61
+ else:
62
+ data_dict_[key] = data_dict[key]
63
+ return data_dict_
64
+
65
+
66
+ @TRANSFORMS.register_module()
67
+ class Collect(object):
68
+ def __init__(self, keys, offset_keys_dict=None, **kwargs):
69
+ """
70
+ e.g. Collect(keys=["coord"], feat_keys=["coord", "color"])
71
+ """
72
+ if offset_keys_dict is None:
73
+ offset_keys_dict = dict(offset="coord")
74
+ self.keys = keys
75
+ self.offset_keys = offset_keys_dict
76
+ self.kwargs = kwargs
77
+
78
+ def __call__(self, data_dict):
79
+ data = dict()
80
+ if isinstance(self.keys, str):
81
+ self.keys = [self.keys]
82
+ for key in self.keys:
83
+ data[key] = data_dict[key]
84
+ for key, value in self.offset_keys.items():
85
+ data[key] = torch.tensor([data_dict[value].shape[0]])
86
+ for name, keys in self.kwargs.items():
87
+ name = name.replace("_keys", "")
88
+ assert isinstance(keys, Sequence)
89
+ data[name] = torch.cat([data_dict[key].float() for key in keys], dim=1)
90
+ return data
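Sketch of how `Collect` is typically configured; the key names are the conventional Pointcept ones and the data is synthetic, not part of this commit.

```python
import torch
from concerto.transform import Collect

data = dict(coord=torch.rand(100, 3), color=torch.rand(100, 3),
            normal=torch.rand(100, 3))
collect = Collect(keys=("coord",), feat_keys=("coord", "color", "normal"))
out = collect(data)
# out["coord"]: (100, 3), out["offset"]: tensor([100]), out["feat"]: (100, 9)
```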
91
+
92
+
93
+ @TRANSFORMS.register_module()
94
+ class Copy(object):
95
+ def __init__(self, keys_dict=None):
96
+ if keys_dict is None:
97
+ keys_dict = dict(coord="origin_coord", segment="origin_segment")
98
+ self.keys_dict = keys_dict
99
+
100
+ def __call__(self, data_dict):
101
+ for key, value in self.keys_dict.items():
102
+ if isinstance(data_dict[key], np.ndarray):
103
+ data_dict[value] = data_dict[key].copy()
104
+ elif isinstance(data_dict[key], torch.Tensor):
105
+ data_dict[value] = data_dict[key].clone().detach()
106
+ else:
107
+ data_dict[value] = copy.deepcopy(data_dict[key])
108
+ return data_dict
109
+
110
+
111
+ @TRANSFORMS.register_module()
112
+ class Update(object):
113
+ def __init__(self, keys_dict=None):
114
+ if keys_dict is None:
115
+ keys_dict = dict()
116
+ self.keys_dict = keys_dict
117
+
118
+ def __call__(self, data_dict):
119
+ for key, value in self.keys_dict.items():
120
+ data_dict[key] = value
121
+ return data_dict
122
+
123
+
124
+ @TRANSFORMS.register_module()
125
+ class ToTensor(object):
126
+ def __call__(self, data):
127
+ if isinstance(data, torch.Tensor):
128
+ return data
129
+ elif isinstance(data, str):
130
+ # note that str is also a Sequence, so this check must come before the Sequence branch
131
+ return data
132
+ elif isinstance(data, int):
133
+ return torch.LongTensor([data])
134
+ elif isinstance(data, float):
135
+ return torch.FloatTensor([data])
136
+ elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, bool):
137
+ return torch.from_numpy(data)
138
+ elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.integer):
139
+ return torch.from_numpy(data).long()
140
+ elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.floating):
141
+ return torch.from_numpy(data).float()
142
+ elif isinstance(data, Mapping):
143
+ result = {sub_key: self(item) for sub_key, item in data.items()}
144
+ return result
145
+ elif isinstance(data, Sequence):
146
+ result = [self(item) for item in data]
147
+ return result
148
+ else:
149
+ raise TypeError(f"type {type(data)} cannot be converted to tensor.")
150
+
151
+
152
+ @TRANSFORMS.register_module()
153
+ class NormalizeColor(object):
154
+ def __call__(self, data_dict):
155
+ if "color" in data_dict.keys():
156
+ data_dict["color"] = data_dict["color"] / 255
157
+ return data_dict
158
+
159
+
160
+ @TRANSFORMS.register_module()
161
+ class NormalizeCoord(object):
162
+ def __call__(self, data_dict):
163
+ if "coord" in data_dict.keys():
164
+ # modified from pointnet2
165
+ centroid = np.mean(data_dict["coord"], axis=0)
166
+ data_dict["coord"] -= centroid
167
+ m = np.max(np.sqrt(np.sum(data_dict["coord"] ** 2, axis=1)))
168
+ data_dict["coord"] = data_dict["coord"] / m
169
+ return data_dict
170
+
171
+
172
+ @TRANSFORMS.register_module()
173
+ class PositiveShift(object):
174
+ def __call__(self, data_dict):
175
+ if "coord" in data_dict.keys():
176
+ coord_min = np.min(data_dict["coord"], 0)
177
+ data_dict["coord"] -= coord_min
178
+ return data_dict
179
+
180
+
181
+ @TRANSFORMS.register_module()
182
+ class CenterShift(object):
183
+ def __init__(self, apply_z=True):
184
+ self.apply_z = apply_z
185
+
186
+ def __call__(self, data_dict):
187
+ if "coord" in data_dict.keys():
188
+ x_min, y_min, z_min = data_dict["coord"].min(axis=0)
189
+ x_max, y_max, _ = data_dict["coord"].max(axis=0)
190
+ if self.apply_z:
191
+ shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, z_min]
192
+ else:
193
+ shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, 0]
194
+ data_dict["coord"] -= shift
195
+ return data_dict
196
+
197
+
198
+ @TRANSFORMS.register_module()
199
+ class RandomShift(object):
200
+ def __init__(self, shift=((-0.2, 0.2), (-0.2, 0.2), (0, 0))):
201
+ self.shift = shift
202
+
203
+ def __call__(self, data_dict):
204
+ if "coord" in data_dict.keys():
205
+ shift_x = np.random.uniform(self.shift[0][0], self.shift[0][1])
206
+ shift_y = np.random.uniform(self.shift[1][0], self.shift[1][1])
207
+ shift_z = np.random.uniform(self.shift[2][0], self.shift[2][1])
208
+ data_dict["coord"] += [shift_x, shift_y, shift_z]
209
+ return data_dict
210
+
211
+
212
+ @TRANSFORMS.register_module()
213
+ class PointClip(object):
214
+ def __init__(self, point_cloud_range=(-80, -80, -3, 80, 80, 1)):
215
+ self.point_cloud_range = point_cloud_range
216
+
217
+ def __call__(self, data_dict):
218
+ if "coord" in data_dict.keys():
219
+ data_dict["coord"] = np.clip(
220
+ data_dict["coord"],
221
+ a_min=self.point_cloud_range[:3],
222
+ a_max=self.point_cloud_range[3:],
223
+ )
224
+ return data_dict
225
+
226
+
227
+ @TRANSFORMS.register_module()
228
+ class RandomDropout(object):
229
+ def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5):
230
+ """
231
+ dropout_ratio: fraction of points to drop; dropout_application_ratio: probability of applying the dropout.
232
+ """
233
+ self.dropout_ratio = dropout_ratio
234
+ self.dropout_application_ratio = dropout_application_ratio
235
+
236
+ def __call__(self, data_dict):
237
+ if random.random() < self.dropout_application_ratio:
238
+ n = len(data_dict["coord"])
239
+ idx = np.random.choice(n, int(n * (1 - self.dropout_ratio)), replace=False)
240
+ if "sampled_index" in data_dict:
241
+ # for ScanNet data efficient, we need to make sure labeled point is sampled.
242
+ idx = np.unique(np.append(idx, data_dict["sampled_index"]))
243
+ mask = np.zeros_like(data_dict["segment"]).astype(bool)
244
+ mask[data_dict["sampled_index"]] = True
245
+ data_dict["sampled_index"] = np.where(mask[idx])[0]
246
+ data_dict = index_operator(data_dict, idx)
247
+ return data_dict
248
+
249
+
250
+ @TRANSFORMS.register_module()
251
+ class RandomRotate(object):
252
+ def __init__(self, angle=None, center=None, axis="z", always_apply=False, p=0.5):
253
+ self.angle = [-1, 1] if angle is None else angle
254
+ self.axis = axis
255
+ self.always_apply = always_apply
256
+ self.p = p if not self.always_apply else 1
257
+ self.center = center
258
+
259
+ def __call__(self, data_dict):
260
+ if random.random() > self.p:
261
+ return data_dict
262
+ angle = np.random.uniform(self.angle[0], self.angle[1]) * np.pi
263
+ rot_cos, rot_sin = np.cos(angle), np.sin(angle)
264
+ if self.axis == "x":
265
+ rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]])
266
+ elif self.axis == "y":
267
+ rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]])
268
+ elif self.axis == "z":
269
+ rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]])
270
+ else:
271
+ raise NotImplementedError
272
+ if "coord" in data_dict.keys():
273
+ if self.center is None:
274
+ x_min, y_min, z_min = data_dict["coord"].min(axis=0)
275
+ x_max, y_max, z_max = data_dict["coord"].max(axis=0)
276
+ center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2]
277
+ else:
278
+ center = self.center
279
+ data_dict["coord"] -= center
280
+ data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t))
281
+ data_dict["coord"] += center
282
+ if "normal" in data_dict.keys():
283
+ data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t))
284
+ return data_dict
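Sketch of applying `RandomRotate`: note that `angle` is in units of pi, so `[-1, 1]` means a full random rotation about the chosen axis. Values below are illustrative.

```python
import numpy as np
from concerto.transform import RandomRotate

aug = RandomRotate(angle=[-1, 1], axis="z", p=1.0)   # always apply here
data = dict(coord=np.random.rand(1000, 3).astype(np.float32),
            normal=np.random.rand(1000, 3).astype(np.float32))
data = aug(data)   # coord and normal are rotated consistently about z
```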
285
+
286
+
287
+ @TRANSFORMS.register_module()
288
+ class RandomRotateTargetAngle(object):
289
+ def __init__(
290
+ self, angle=(1 / 2, 1, 3 / 2), center=None, axis="z", always_apply=False, p=0.75
291
+ ):
292
+ self.angle = angle
293
+ self.axis = axis
294
+ self.always_apply = always_apply
295
+ self.p = p if not self.always_apply else 1
296
+ self.center = center
297
+
298
+ def __call__(self, data_dict):
299
+ if random.random() > self.p:
300
+ return data_dict
301
+ angle = np.random.choice(self.angle) * np.pi
302
+ rot_cos, rot_sin = np.cos(angle), np.sin(angle)
303
+ if self.axis == "x":
304
+ rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]])
305
+ elif self.axis == "y":
306
+ rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]])
307
+ elif self.axis == "z":
308
+ rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]])
309
+ else:
310
+ raise NotImplementedError
311
+ if "coord" in data_dict.keys():
312
+ if self.center is None:
313
+ x_min, y_min, z_min = data_dict["coord"].min(axis=0)
314
+ x_max, y_max, z_max = data_dict["coord"].max(axis=0)
315
+ center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2]
316
+ else:
317
+ center = self.center
318
+ data_dict["coord"] -= center
319
+ data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t))
320
+ data_dict["coord"] += center
321
+ if "normal" in data_dict.keys():
322
+ data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t))
323
+ return data_dict
324
+
325
+
326
+ @TRANSFORMS.register_module()
327
+ class RandomScale(object):
328
+ def __init__(self, scale=None, anisotropic=False):
329
+ self.scale = scale if scale is not None else [0.95, 1.05]
330
+ self.anisotropic = anisotropic
331
+
332
+ def __call__(self, data_dict):
333
+ if "coord" in data_dict.keys():
334
+ scale = np.random.uniform(
335
+ self.scale[0], self.scale[1], 3 if self.anisotropic else 1
336
+ )
337
+ data_dict["coord"] *= scale
338
+ return data_dict
339
+
340
+
341
+ @TRANSFORMS.register_module()
342
+ class RandomFlip(object):
343
+ def __init__(self, p=0.5):
344
+ self.p = p
345
+
346
+ def __call__(self, data_dict):
347
+ if np.random.rand() < self.p:
348
+ if "coord" in data_dict.keys():
349
+ data_dict["coord"][:, 0] = -data_dict["coord"][:, 0]
350
+ if "normal" in data_dict.keys():
351
+ data_dict["normal"][:, 0] = -data_dict["normal"][:, 0]
352
+ if np.random.rand() < self.p:
353
+ if "coord" in data_dict.keys():
354
+ data_dict["coord"][:, 1] = -data_dict["coord"][:, 1]
355
+ if "normal" in data_dict.keys():
356
+ data_dict["normal"][:, 1] = -data_dict["normal"][:, 1]
357
+ return data_dict
358
+
359
+
360
+ @TRANSFORMS.register_module()
361
+ class RandomJitter(object):
362
+ def __init__(self, sigma=0.01, clip=0.05):
363
+ assert clip > 0
364
+ self.sigma = sigma
365
+ self.clip = clip
366
+
367
+ def __call__(self, data_dict):
368
+ if "coord" in data_dict.keys():
369
+ jitter = np.clip(
370
+ self.sigma * np.random.randn(data_dict["coord"].shape[0], 3),
371
+ -self.clip,
372
+ self.clip,
373
+ )
374
+ data_dict["coord"] += jitter
375
+ return data_dict
376
+
377
+
378
+ @TRANSFORMS.register_module()
379
+ class ClipGaussianJitter(object):
380
+ def __init__(self, scalar=0.02, store_jitter=False):
381
+ self.scalar = scalar
382
+ self.mean = np.mean(3)
383
+ self.cov = np.identity(3)
384
+ self.quantile = 1.96
385
+ self.store_jitter = store_jitter
386
+
387
+ def __call__(self, data_dict):
388
+ if "coord" in data_dict.keys():
389
+ jitter = np.random.multivariate_normal(
390
+ self.mean, self.cov, data_dict["coord"].shape[0]
391
+ )
392
+ jitter = self.scalar * np.clip(jitter / 1.96, -1, 1)
393
+ data_dict["coord"] += jitter
394
+ if self.store_jitter:
395
+ data_dict["jitter"] = jitter
396
+ return data_dict
397
+
398
+
399
+ @TRANSFORMS.register_module()
400
+ class ChromaticAutoContrast(object):
401
+ def __init__(self, p=0.2, blend_factor=None):
402
+ self.p = p
403
+ self.blend_factor = blend_factor
404
+
405
+ def __call__(self, data_dict):
406
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
407
+ lo = np.min(data_dict["color"], 0, keepdims=True)
408
+ hi = np.max(data_dict["color"], 0, keepdims=True)
409
+ scale = 255 / (hi - lo)
410
+ contrast_feat = (data_dict["color"][:, :3] - lo) * scale
411
+ blend_factor = (
412
+ np.random.rand() if self.blend_factor is None else self.blend_factor
413
+ )
414
+ data_dict["color"][:, :3] = (1 - blend_factor) * data_dict["color"][
415
+ :, :3
416
+ ] + blend_factor * contrast_feat
417
+ return data_dict
418
+
419
+
420
+ @TRANSFORMS.register_module()
421
+ class ChromaticTranslation(object):
422
+ def __init__(self, p=0.95, ratio=0.05):
423
+ self.p = p
424
+ self.ratio = ratio
425
+
426
+ def __call__(self, data_dict):
427
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
428
+ tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.ratio
429
+ data_dict["color"][:, :3] = np.clip(tr + data_dict["color"][:, :3], 0, 255)
430
+ return data_dict
431
+
432
+
433
+ @TRANSFORMS.register_module()
434
+ class ChromaticJitter(object):
435
+ def __init__(self, p=0.95, std=0.005):
436
+ self.p = p
437
+ self.std = std
438
+
439
+ def __call__(self, data_dict):
440
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
441
+ noise = np.random.randn(data_dict["color"].shape[0], 3)
442
+ noise *= self.std * 255
443
+ data_dict["color"][:, :3] = np.clip(
444
+ noise + data_dict["color"][:, :3], 0, 255
445
+ )
446
+ return data_dict
447
+
448
+
449
+ @TRANSFORMS.register_module()
450
+ class RandomColorGrayScale(object):
451
+ def __init__(self, p):
452
+ self.p = p
453
+
454
+ @staticmethod
455
+ def rgb_to_grayscale(color, num_output_channels=1):
456
+ if color.shape[-1] < 3:
457
+ raise TypeError(
458
+ "Input color should have at least 3 dimensions, but found {}".format(
459
+ color.shape[-1]
460
+ )
461
+ )
462
+
463
+ if num_output_channels not in (1, 3):
464
+ raise ValueError("num_output_channels should be either 1 or 3")
465
+
466
+ r, g, b = color[..., 0], color[..., 1], color[..., 2]
467
+ gray = (0.2989 * r + 0.587 * g + 0.114 * b).astype(color.dtype)
468
+ gray = np.expand_dims(gray, axis=-1)
469
+
470
+ if num_output_channels == 3:
471
+ gray = np.broadcast_to(gray, color.shape)
472
+
473
+ return gray
474
+
475
+ def __call__(self, data_dict):
476
+ if np.random.rand() < self.p:
477
+ data_dict["color"] = self.rgb_to_grayscale(data_dict["color"], 3)
478
+ return data_dict
479
+
480
+
481
+ @TRANSFORMS.register_module()
482
+ class RandomColorJitter(object):
483
+ """
484
+ Random Color Jitter for 3D point cloud (refer torchvision)
485
+ """
486
+
487
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.95):
488
+ self.brightness = self._check_input(brightness, "brightness")
489
+ self.contrast = self._check_input(contrast, "contrast")
490
+ self.saturation = self._check_input(saturation, "saturation")
491
+ self.hue = self._check_input(
492
+ hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False
493
+ )
494
+ self.p = p
495
+
496
+ @staticmethod
497
+ def _check_input(
498
+ value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True
499
+ ):
500
+ if isinstance(value, numbers.Number):
501
+ if value < 0:
502
+ raise ValueError(
503
+ "If {} is a single number, it must be non negative.".format(name)
504
+ )
505
+ value = [center - float(value), center + float(value)]
506
+ if clip_first_on_zero:
507
+ value[0] = max(value[0], 0.0)
508
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
509
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
510
+ raise ValueError("{} values should be between {}".format(name, bound))
511
+ else:
512
+ raise TypeError(
513
+ "{} should be a single number or a list/tuple with length 2.".format(
514
+ name
515
+ )
516
+ )
517
+
518
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
519
+ # or (0., 0.) for hue, do nothing
520
+ if value[0] == value[1] == center:
521
+ value = None
522
+ return value
523
+
524
+ @staticmethod
525
+ def blend(color1, color2, ratio):
526
+ ratio = float(ratio)
527
+ bound = 255.0
528
+ return (
529
+ (ratio * color1 + (1.0 - ratio) * color2)
530
+ .clip(0, bound)
531
+ .astype(color1.dtype)
532
+ )
533
+
534
+ @staticmethod
535
+ def rgb2hsv(rgb):
536
+ r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
537
+ maxc = np.max(rgb, axis=-1)
538
+ minc = np.min(rgb, axis=-1)
539
+ eqc = maxc == minc
540
+ cr = maxc - minc
541
+ s = cr / (np.ones_like(maxc) * eqc + maxc * (1 - eqc))
542
+ cr_divisor = np.ones_like(maxc) * eqc + cr * (1 - eqc)
543
+ rc = (maxc - r) / cr_divisor
544
+ gc = (maxc - g) / cr_divisor
545
+ bc = (maxc - b) / cr_divisor
546
+
547
+ hr = (maxc == r) * (bc - gc)
548
+ hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
549
+ hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
550
+ h = hr + hg + hb
551
+ h = (h / 6.0 + 1.0) % 1.0
552
+ return np.stack((h, s, maxc), axis=-1)
553
+
554
+ @staticmethod
555
+ def hsv2rgb(hsv):
556
+ h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
557
+ i = np.floor(h * 6.0)
558
+ f = (h * 6.0) - i
559
+ i = i.astype(np.int32)
560
+
561
+ p = np.clip((v * (1.0 - s)), 0.0, 1.0)
562
+ q = np.clip((v * (1.0 - s * f)), 0.0, 1.0)
563
+ t = np.clip((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
564
+ i = i % 6
565
+ mask = np.expand_dims(i, axis=-1) == np.arange(6)
566
+
567
+ a1 = np.stack((v, q, p, p, t, v), axis=-1)
568
+ a2 = np.stack((t, v, v, q, p, p), axis=-1)
569
+ a3 = np.stack((p, p, t, v, v, q), axis=-1)
570
+ a4 = np.stack((a1, a2, a3), axis=-1)
571
+
572
+ return np.einsum("...na, ...nab -> ...nb", mask.astype(hsv.dtype), a4)
573
+
574
+ def adjust_brightness(self, color, brightness_factor):
575
+ if brightness_factor < 0:
576
+ raise ValueError(
577
+ "brightness_factor ({}) is not non-negative.".format(brightness_factor)
578
+ )
579
+
580
+ return self.blend(color, np.zeros_like(color), brightness_factor)
581
+
582
+ def adjust_contrast(self, color, contrast_factor):
583
+ if contrast_factor < 0:
584
+ raise ValueError(
585
+ "contrast_factor ({}) is not non-negative.".format(contrast_factor)
586
+ )
587
+ mean = np.mean(RandomColorGrayScale.rgb_to_grayscale(color))
588
+ return self.blend(color, mean, contrast_factor)
589
+
590
+ def adjust_saturation(self, color, saturation_factor):
591
+ if saturation_factor < 0:
592
+ raise ValueError(
593
+ "saturation_factor ({}) is not non-negative.".format(saturation_factor)
594
+ )
595
+ gray = RandomColorGrayScale.rgb_to_grayscale(color)
596
+ return self.blend(color, gray, saturation_factor)
597
+
598
+ def adjust_hue(self, color, hue_factor):
599
+ if not (-0.5 <= hue_factor <= 0.5):
600
+ raise ValueError(
601
+ "hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor)
602
+ )
603
+ orig_dtype = color.dtype
604
+ hsv = self.rgb2hsv(color / 255.0)
605
+ h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
606
+ h = (h + hue_factor) % 1.0
607
+ hsv = np.stack((h, s, v), axis=-1)
608
+ color_hue_adj = (self.hsv2rgb(hsv) * 255.0).astype(orig_dtype)
609
+ return color_hue_adj
610
+
611
+ @staticmethod
612
+ def get_params(brightness, contrast, saturation, hue):
613
+ fn_idx = torch.randperm(4)
614
+ b = (
615
+ None
616
+ if brightness is None
617
+ else np.random.uniform(brightness[0], brightness[1])
618
+ )
619
+ c = None if contrast is None else np.random.uniform(contrast[0], contrast[1])
620
+ s = (
621
+ None
622
+ if saturation is None
623
+ else np.random.uniform(saturation[0], saturation[1])
624
+ )
625
+ h = None if hue is None else np.random.uniform(hue[0], hue[1])
626
+ return fn_idx, b, c, s, h
627
+
628
+ def __call__(self, data_dict):
629
+ (
630
+ fn_idx,
631
+ brightness_factor,
632
+ contrast_factor,
633
+ saturation_factor,
634
+ hue_factor,
635
+ ) = self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
636
+
637
+ for fn_id in fn_idx:
638
+ if (
639
+ fn_id == 0
640
+ and brightness_factor is not None
641
+ and np.random.rand() < self.p
642
+ ):
643
+ data_dict["color"] = self.adjust_brightness(
644
+ data_dict["color"], brightness_factor
645
+ )
646
+ elif (
647
+ fn_id == 1 and contrast_factor is not None and np.random.rand() < self.p
648
+ ):
649
+ data_dict["color"] = self.adjust_contrast(
650
+ data_dict["color"], contrast_factor
651
+ )
652
+ elif (
653
+ fn_id == 2
654
+ and saturation_factor is not None
655
+ and np.random.rand() < self.p
656
+ ):
657
+ data_dict["color"] = self.adjust_saturation(
658
+ data_dict["color"], saturation_factor
659
+ )
660
+ elif fn_id == 3 and hue_factor is not None and np.random.rand() < self.p:
661
+ data_dict["color"] = self.adjust_hue(data_dict["color"], hue_factor)
662
+ return data_dict
663
+
664
+
665
+ @TRANSFORMS.register_module()
666
+ class HueSaturationTranslation(object):
667
+ @staticmethod
668
+ def rgb_to_hsv(rgb):
669
+ # Translated from source of colorsys.rgb_to_hsv
670
+ # r, g, b should be numpy arrays with values between 0 and 255
671
+ # rgb_to_hsv returns an array of floats between 0.0 and 1.0.
672
+ rgb = rgb.astype("float")
673
+ hsv = np.zeros_like(rgb)
674
+ # in case an RGBA array was passed, just copy the A channel
675
+ hsv[..., 3:] = rgb[..., 3:]
676
+ r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
677
+ maxc = np.max(rgb[..., :3], axis=-1)
678
+ minc = np.min(rgb[..., :3], axis=-1)
679
+ hsv[..., 2] = maxc
680
+ mask = maxc != minc
681
+ hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask]
682
+ rc = np.zeros_like(r)
683
+ gc = np.zeros_like(g)
684
+ bc = np.zeros_like(b)
685
+ rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask]
686
+ gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask]
687
+ bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask]
688
+ hsv[..., 0] = np.select(
689
+ [r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc
690
+ )
691
+ hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0
692
+ return hsv
693
+
694
+ @staticmethod
695
+ def hsv_to_rgb(hsv):
696
+ # Translated from source of colorsys.hsv_to_rgb
697
+ # h,s should be a numpy arrays with values between 0.0 and 1.0
698
+ # v should be a numpy array with values between 0.0 and 255.0
699
+ # hsv_to_rgb returns an array of uints between 0 and 255.
700
+ rgb = np.empty_like(hsv)
701
+ rgb[..., 3:] = hsv[..., 3:]
702
+ h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
703
+ i = (h * 6.0).astype("uint8")
704
+ f = (h * 6.0) - i
705
+ p = v * (1.0 - s)
706
+ q = v * (1.0 - s * f)
707
+ t = v * (1.0 - s * (1.0 - f))
708
+ i = i % 6
709
+ conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5]
710
+ rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v)
711
+ rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t)
712
+ rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p)
713
+ return rgb.astype("uint8")
714
+
715
+ def __init__(self, hue_max=0.5, saturation_max=0.2):
716
+ self.hue_max = hue_max
717
+ self.saturation_max = saturation_max
718
+
719
+ def __call__(self, data_dict):
720
+ if "color" in data_dict.keys():
721
+ # Assume color[:, :3] is rgb
722
+ hsv = HueSaturationTranslation.rgb_to_hsv(data_dict["color"][:, :3])
723
+ hue_val = (np.random.rand() - 0.5) * 2 * self.hue_max
724
+ sat_ratio = 1 + (np.random.rand() - 0.5) * 2 * self.saturation_max
725
+ hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1)
726
+ hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1)
727
+ data_dict["color"][:, :3] = np.clip(
728
+ HueSaturationTranslation.hsv_to_rgb(hsv), 0, 255
729
+ )
730
+ return data_dict
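Sketch: `HueSaturationTranslation` expects colors in the 0-255 range and perturbs hue and saturation jointly; the input below is illustrative.

```python
import numpy as np
from concerto.transform import HueSaturationTranslation

data = dict(color=np.random.randint(0, 256, (500, 3)).astype(np.uint8))
data = HueSaturationTranslation(hue_max=0.5, saturation_max=0.2)(data)
```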
731
+
732
+
733
+ @TRANSFORMS.register_module()
734
+ class RandomColorDrop(object):
735
+ def __init__(self, p=0.2, color_augment=0.0):
736
+ self.p = p
737
+ self.color_augment = color_augment
738
+
739
+ def __call__(self, data_dict):
740
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
741
+ data_dict["color"] *= self.color_augment
742
+ return data_dict
743
+
744
+ def __repr__(self):
745
+ return "RandomColorDrop(color_augment: {}, p: {})".format(
746
+ self.color_augment, self.p
747
+ )
748
+
749
+
750
+ @TRANSFORMS.register_module()
751
+ class ElasticDistortion(object):
752
+ def __init__(self, distortion_params=None):
753
+ self.distortion_params = (
754
+ [[0.2, 0.4], [0.8, 1.6]] if distortion_params is None else distortion_params
755
+ )
756
+
757
+ @staticmethod
758
+ def elastic_distortion(coords, granularity, magnitude):
759
+ """
760
+ Apply elastic distortion on sparse coordinate space.
761
+ coords: numpy array of (number of points, at least 3 spatial dims)
762
+ granularity: size of the noise grid (in same scale[m/cm] as the voxel grid)
763
+ magnitude: noise multiplier
764
+ """
765
+ blurx = np.ones((3, 1, 1, 1)).astype("float32") / 3
766
+ blury = np.ones((1, 3, 1, 1)).astype("float32") / 3
767
+ blurz = np.ones((1, 1, 3, 1)).astype("float32") / 3
768
+ coords_min = coords.min(0)
769
+
770
+ # Create Gaussian noise tensor of the size given by granularity.
771
+ noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3
772
+ noise = np.random.randn(*noise_dim, 3).astype(np.float32)
773
+
774
+ # Smoothing.
775
+ for _ in range(2):
776
+ noise = scipy.ndimage.filters.convolve(
777
+ noise, blurx, mode="constant", cval=0
778
+ )
779
+ noise = scipy.ndimage.filters.convolve(
780
+ noise, blury, mode="constant", cval=0
781
+ )
782
+ noise = scipy.ndimage.filters.convolve(
783
+ noise, blurz, mode="constant", cval=0
784
+ )
785
+
786
+ # Trilinear interpolate noise filters for each spatial dimensions.
787
+ ax = [
788
+ np.linspace(d_min, d_max, d)
789
+ for d_min, d_max, d in zip(
790
+ coords_min - granularity,
791
+ coords_min + granularity * (noise_dim - 2),
792
+ noise_dim,
793
+ )
794
+ ]
795
+ interp = scipy.interpolate.RegularGridInterpolator(
796
+ ax, noise, bounds_error=False, fill_value=0
797
+ )
798
+ coords += interp(coords) * magnitude
799
+ return coords
800
+
801
+ def __call__(self, data_dict):
802
+ if "coord" in data_dict.keys() and self.distortion_params is not None:
803
+ if random.random() < 0.95:
804
+ for granularity, magnitude in self.distortion_params:
805
+ data_dict["coord"] = self.elastic_distortion(
806
+ data_dict["coord"], granularity, magnitude
807
+ )
808
+ return data_dict
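Sketch of `ElasticDistortion`: each `(granularity, magnitude)` pair adds one scale of smoothed noise to the coordinates, and the whole distortion is applied with probability 0.95. Assumes a SciPy version that still provides `scipy.ndimage.filters`; values are illustrative.

```python
import numpy as np
from concerto.transform import ElasticDistortion

aug = ElasticDistortion(distortion_params=[[0.2, 0.4], [0.8, 1.6]])
data = dict(coord=(np.random.rand(2000, 3) * 5).astype(np.float32))
data = aug(data)
```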
809
+
810
+
811
+ @TRANSFORMS.register_module()
812
+ class GridSample(object):
813
+ def __init__(
814
+ self,
815
+ grid_size=0.05,
816
+ hash_type="fnv",
817
+ mode="train",
818
+ return_inverse=False,
819
+ return_grid_coord=False,
820
+ return_min_coord=False,
821
+ return_displacement=False,
822
+ project_displacement=False,
823
+ ):
824
+ self.grid_size = grid_size
825
+ self.hash = self.fnv_hash_vec if hash_type == "fnv" else self.ravel_hash_vec
826
+ assert mode in ["train", "test"]
827
+ self.mode = mode
828
+ self.return_inverse = return_inverse
829
+ self.return_grid_coord = return_grid_coord
830
+ self.return_min_coord = return_min_coord
831
+ self.return_displacement = return_displacement
832
+ self.project_displacement = project_displacement
833
+
834
+ def __call__(self, data_dict):
835
+ assert "coord" in data_dict.keys()
836
+ scaled_coord = data_dict["coord"] / np.array(self.grid_size)
837
+ grid_coord = np.floor(scaled_coord).astype(int)
838
+ min_coord = grid_coord.min(0)
839
+ grid_coord -= min_coord
840
+ scaled_coord -= min_coord
841
+ min_coord = min_coord * np.array(self.grid_size)
842
+ key = self.hash(grid_coord)
843
+ idx_sort = np.argsort(key)
844
+ key_sort = key[idx_sort]
845
+ _, inverse, count = np.unique(key_sort, return_inverse=True, return_counts=True)
846
+ if self.mode == "train": # train mode
847
+ idx_select = (
848
+ np.cumsum(np.insert(count, 0, 0)[0:-1])
849
+ + np.random.randint(0, count.max(), count.size) % count
850
+ )
851
+ idx_unique = idx_sort[idx_select]
852
+ if "sampled_index" in data_dict:
853
+ # for the ScanNet data-efficient benchmark, we need to make sure the labeled points are sampled.
854
+ idx_unique = np.unique(
855
+ np.append(idx_unique, data_dict["sampled_index"])
856
+ )
857
+ mask = np.zeros_like(data_dict["segment"]).astype(bool)
858
+ mask[data_dict["sampled_index"]] = True
859
+ data_dict["sampled_index"] = np.where(mask[idx_unique])[0]
860
+ data_dict = index_operator(data_dict, idx_unique)
861
+ if self.return_inverse:
862
+ data_dict["inverse"] = np.zeros_like(inverse)
863
+ data_dict["inverse"][idx_sort] = inverse
864
+ if self.return_grid_coord:
865
+ data_dict["grid_coord"] = grid_coord[idx_unique]
866
+ data_dict["index_valid_keys"].append("grid_coord")
867
+ if self.return_min_coord:
868
+ data_dict["min_coord"] = min_coord.reshape([1, 3])
869
+ if self.return_displacement:
870
+ displacement = (
871
+ scaled_coord - grid_coord - 0.5
872
+ ) # [0, 1] -> [-0.5, 0.5] displacement to center
873
+ if self.project_displacement:
874
+ displacement = np.sum(
875
+ displacement * data_dict["normal"], axis=-1, keepdims=True
876
+ )
877
+ data_dict["displacement"] = displacement[idx_unique]
878
+ data_dict["index_valid_keys"].append("displacement")
879
+ return data_dict
880
+
881
+ elif self.mode == "test": # test mode
882
+ data_part_list = []
883
+ for i in range(count.max()):
884
+ idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + i % count
885
+ idx_part = idx_sort[idx_select]
886
+ data_part = index_operator(data_dict, idx_part, duplicate=True)
887
+ data_part["index"] = idx_part
888
+ if self.return_inverse:
889
+ data_part["inverse"] = np.zeros_like(inverse)
890
+ data_part["inverse"][idx_sort] = inverse
891
+ if self.return_grid_coord:
892
+ data_part["grid_coord"] = grid_coord[idx_part]
893
+ data_dict["index_valid_keys"].append("grid_coord")
894
+ if self.return_min_coord:
895
+ data_part["min_coord"] = min_coord.reshape([1, 3])
896
+ if self.return_displacement:
897
+ displacement = (
898
+ scaled_coord - grid_coord - 0.5
899
+ ) # [0, 1] -> [-0.5, 0.5] displacement to center
900
+ if self.project_displacement:
901
+ displacement = np.sum(
902
+ displacement * data_dict["normal"], axis=-1, keepdims=True
903
+ )
904
+ data_dict["displacement"] = displacement[idx_part]
905
+ data_dict["index_valid_keys"].append("displacement")
906
+ data_part_list.append(data_part)
907
+ return data_part_list
908
+ else:
909
+ raise NotImplementedError
910
+
911
+ @staticmethod
912
+ def ravel_hash_vec(arr):
913
+ """
914
+ Ravel the coordinates after subtracting the min coordinates.
915
+ """
916
+ assert arr.ndim == 2
917
+ arr = arr.copy()
918
+ arr -= arr.min(0)
919
+ arr = arr.astype(np.uint64, copy=False)
920
+ arr_max = arr.max(0).astype(np.uint64) + 1
921
+
922
+ keys = np.zeros(arr.shape[0], dtype=np.uint64)
923
+ # Fortran style indexing
924
+ for j in range(arr.shape[1] - 1):
925
+ keys += arr[:, j]
926
+ keys *= arr_max[j + 1]
927
+ keys += arr[:, -1]
928
+ return keys
929
+
930
+ @staticmethod
931
+ def fnv_hash_vec(arr):
932
+ """
933
+ FNV-1a hash (64-bit).
934
+ """
935
+ assert arr.ndim == 2
936
+ # Floor first for negative coordinates
937
+ arr = arr.copy()
938
+ arr = arr.astype(np.uint64, copy=False)
939
+ hashed_arr = np.uint64(14695981039346656037) * np.ones(
940
+ arr.shape[0], dtype=np.uint64
941
+ )
942
+ for j in range(arr.shape[1]):
943
+ hashed_arr *= np.uint64(1099511628211)
944
+ hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j])
945
+ return hashed_arr
946
+
947
+
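A hypothetical voxel-downsampling sketch for the transform above. It assumes the `concerto` package is installed and that `index_operator` (defined earlier in transform.py) fills in a default `index_valid_keys` list when it is absent, as in upstream Pointcept; the random inputs are illustrative only.

```python
import numpy as np
from concerto.transform import GridSample

n = 50000
data_dict = {
    "coord": np.random.rand(n, 3).astype(np.float32) * 4.0,           # ~4 m cube
    "color": np.random.randint(0, 256, (n, 3)).astype(np.float32),
}

# Keep one random point per occupied 5 cm voxel (train mode).
sampler = GridSample(grid_size=0.05, hash_type="fnv", mode="train",
                     return_grid_coord=True, return_inverse=True)
out = sampler(data_dict)

print(out["coord"].shape)       # far fewer than 50000 points
print(out["grid_coord"].dtype)  # integer voxel indices
print(out["inverse"].shape)     # (50000,) mapping from original points to kept voxels
```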
948
+ @TRANSFORMS.register_module()
949
+ class SphereCrop(object):
950
+ def __init__(self, point_max=80000, sample_rate=None, mode="random"):
951
+ self.point_max = point_max
952
+ self.sample_rate = sample_rate
953
+ assert mode in ["random", "center", "all"]
954
+ self.mode = mode
955
+
956
+ def __call__(self, data_dict):
957
+ point_max = (
958
+ int(self.sample_rate * data_dict["coord"].shape[0])
959
+ if self.sample_rate is not None
960
+ else self.point_max
961
+ )
962
+
963
+ assert "coord" in data_dict.keys()
964
+ if data_dict["coord"].shape[0] > point_max:
965
+ if self.mode == "random":
966
+ center = data_dict["coord"][
967
+ np.random.randint(data_dict["coord"].shape[0])
968
+ ]
969
+ elif self.mode == "center":
970
+ center = data_dict["coord"][data_dict["coord"].shape[0] // 2]
971
+ else:
972
+ raise NotImplementedError
973
+ idx_crop = np.argsort(np.sum(np.square(data_dict["coord"] - center), 1))[
974
+ :point_max
975
+ ]
976
+ data_dict = index_operator(data_dict, idx_crop)
977
+ return data_dict
978
+
979
+
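And a quick sketch of the crop above (same assumption about `index_operator` defaults as in the `GridSample` example):

```python
import numpy as np
from concerto.transform import SphereCrop

data_dict = {"coord": np.random.rand(200000, 3).astype(np.float32)}

crop = SphereCrop(point_max=80000, mode="random")
out = crop(data_dict)

# The 80000 points closest to a randomly chosen center are kept.
print(out["coord"].shape[0])  # 80000
```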
980
+ @TRANSFORMS.register_module()
981
+ class ShufflePoint(object):
982
+ def __call__(self, data_dict):
983
+ assert "coord" in data_dict.keys()
984
+ shuffle_index = np.arange(data_dict["coord"].shape[0])
985
+ np.random.shuffle(shuffle_index)
986
+ data_dict = index_operator(data_dict, shuffle_index)
987
+ return data_dict
988
+
989
+
990
+ @TRANSFORMS.register_module()
991
+ class CropBoundary(object):
992
+ def __call__(self, data_dict):
993
+ assert "segment" in data_dict
994
+ segment = data_dict["segment"].flatten()
995
+ mask = (segment != 0) * (segment != 1)
996
+ data_dict = index_operator(data_dict, mask)
997
+ return data_dict
998
+
999
+
1000
+ @TRANSFORMS.register_module()
1001
+ class ContrastiveViewsGenerator(object):
1002
+ def __init__(
1003
+ self,
1004
+ view_keys=("coord", "color", "normal", "origin_coord"),
1005
+ view_trans_cfg=None,
1006
+ ):
1007
+ self.view_keys = view_keys
1008
+ self.view_trans = Compose(view_trans_cfg)
1009
+
1010
+ def __call__(self, data_dict):
1011
+ view1_dict = dict()
1012
+ view2_dict = dict()
1013
+ for key in self.view_keys:
1014
+ view1_dict[key] = data_dict[key].copy()
1015
+ view2_dict[key] = data_dict[key].copy()
1016
+ view1_dict = self.view_trans(view1_dict)
1017
+ view2_dict = self.view_trans(view2_dict)
1018
+ for key, value in view1_dict.items():
1019
+ data_dict["view1_" + key] = value
1020
+ for key, value in view2_dict.items():
1021
+ data_dict["view2_" + key] = value
1022
+ return data_dict
1023
+
1024
+
1025
+ @TRANSFORMS.register_module()
1026
+ class MultiViewGenerator(object):
1027
+ def __init__(
1028
+ self,
1029
+ global_view_num=2,
1030
+ global_view_scale=(0.4, 1.0),
1031
+ local_view_num=4,
1032
+ local_view_scale=(0.1, 0.4),
1033
+ global_shared_transform=None,
1034
+ global_transform=None,
1035
+ local_transform=None,
1036
+ max_size=65536,
1037
+ center_height_scale=(0, 1),
1038
+ shared_global_view=False,
1039
+ view_keys=("coord", "origin_coord", "color", "normal"),
1040
+ ):
1041
+ self.global_view_num = global_view_num
1042
+ self.global_view_scale = global_view_scale
1043
+ self.local_view_num = local_view_num
1044
+ self.local_view_scale = local_view_scale
1045
+ self.global_shared_transform = Compose(global_shared_transform)
1046
+ self.global_transform = Compose(global_transform)
1047
+ self.local_transform = Compose(local_transform)
1048
+ self.max_size = max_size
1049
+ self.center_height_scale = center_height_scale
1050
+ self.shared_global_view = shared_global_view
1051
+ self.view_keys = view_keys
1052
+ assert "coord" in view_keys
1053
+
1054
+ def get_view(self, point, center, scale):
1055
+ coord = point["coord"]
1056
+ max_size = min(self.max_size, coord.shape[0])
1057
+ size = int(np.random.uniform(*scale) * max_size)
1058
+ index = np.argsort(np.sum(np.square(coord - center), axis=-1))[:size]
1059
+ view = dict(index=index)
1060
+ for key in point.keys():
1061
+ if key in self.view_keys:
1062
+ view[key] = point[key][index]
1063
+
1064
+ if "index_valid_keys" in point.keys():
1065
+ # inherit index_valid_keys from point
1066
+ view["index_valid_keys"] = point["index_valid_keys"]
1067
+ return view
1068
+
1069
+ def __call__(self, data_dict):
1070
+ coord = data_dict["coord"]
1071
+ point = self.global_shared_transform(copy.deepcopy(data_dict))
1072
+ z_min = coord[:, 2].min()
1073
+ z_max = coord[:, 2].max()
1074
+ z_min_ = z_min + (z_max - z_min) * self.center_height_scale[0]
1075
+ z_max_ = z_min + (z_max - z_min) * self.center_height_scale[1]
1076
+ center_mask = np.logical_and(coord[:, 2] >= z_min_, coord[:, 2] <= z_max_)
1077
+ # get major global view
1078
+ major_center = coord[np.random.choice(np.where(center_mask)[0])]
1079
+ major_view = self.get_view(point, major_center, self.global_view_scale)
1080
+ major_coord = major_view["coord"]
1081
+ # get global views: restrict the centers of the remaining global views to lie within the major global view
1082
+ if not self.shared_global_view:
1083
+ global_views = [
1084
+ self.get_view(
1085
+ point=point,
1086
+ center=major_coord[np.random.randint(major_coord.shape[0])],
1087
+ scale=self.global_view_scale,
1088
+ )
1089
+ for _ in range(self.global_view_num - 1)
1090
+ ]
1091
+ else:
1092
+ global_views = [
1093
+ {key: value.copy() for key, value in major_view.items()}
1094
+ for _ in range(self.global_view_num - 1)
1095
+ ]
1096
+
1097
+ global_views = [major_view] + global_views
1098
+
1099
+ # get local views: restrict the centers of the local views to lie within the major global view
1100
+ cover_mask = np.zeros_like(major_view["index"], dtype=bool)
1101
+ local_views = []
1102
+ for i in range(self.local_view_num):
1103
+ if sum(~cover_mask) == 0:
1104
+ # reset cover mask if all points are sampled
1105
+ cover_mask[:] = False
1106
+ local_view = self.get_view(
1107
+ point=data_dict,
1108
+ center=major_coord[np.random.choice(np.where(~cover_mask)[0])],
1109
+ scale=self.local_view_scale,
1110
+ )
1111
+ local_views.append(local_view)
1112
+ cover_mask[np.isin(major_view["index"], local_view["index"])] = True
1113
+
1114
+ # augmentation and concat
1115
+ view_dict = {}
1116
+ for global_view in global_views:
1117
+ global_view.pop("index")
1118
+ global_view = self.global_transform(global_view)
1119
+ for key in self.view_keys:
1120
+ if f"global_{key}" in view_dict.keys():
1121
+ view_dict[f"global_{key}"].append(global_view[key])
1122
+ else:
1123
+ view_dict[f"global_{key}"] = [global_view[key]]
1124
+ view_dict["global_offset"] = np.cumsum(
1125
+ [data.shape[0] for data in view_dict["global_coord"]]
1126
+ )
1127
+ for local_view in local_views:
1128
+ local_view.pop("index")
1129
+ local_view = self.local_transform(local_view)
1130
+ for key in self.view_keys:
1131
+ if f"local_{key}" in view_dict.keys():
1132
+ view_dict[f"local_{key}"].append(local_view[key])
1133
+ else:
1134
+ view_dict[f"local_{key}"] = [local_view[key]]
1135
+ view_dict["local_offset"] = np.cumsum(
1136
+ [data.shape[0] for data in view_dict["local_coord"]]
1137
+ )
1138
+ for key in view_dict.keys():
1139
+ if "offset" not in key:
1140
+ view_dict[key] = np.concatenate(view_dict[key], axis=0)
1141
+ data_dict.update(view_dict)
1142
+ return data_dict
1143
+
1144
+
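A hypothetical wiring of the generator above. The nested transform configs only reuse names that appear in this file (`CenterShift` is assumed to be registered earlier in transform.py, as in the `default()` pipeline below); field values are illustrative, not a training recipe.

```python
import numpy as np
from concerto.transform import MultiViewGenerator

generator = MultiViewGenerator(
    global_view_num=2,
    local_view_num=4,
    view_keys=("coord", "color"),
    global_shared_transform=None,                               # Compose(None) acts as identity
    global_transform=[dict(type="CenterShift", apply_z=True)],  # per-view augmentation configs
    local_transform=[dict(type="CenterShift", apply_z=True)],
)

n = 20000
data_dict = {
    "coord": np.random.rand(n, 3).astype(np.float32) * 5.0,
    "color": np.random.rand(n, 3).astype(np.float32),
}
out = generator(data_dict)

# Crops are concatenated along the point dimension; offsets mark view boundaries.
print(out["global_coord"].shape, out["global_offset"])
print(out["local_coord"].shape, out["local_offset"])
```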
1145
+ @TRANSFORMS.register_module()
1146
+ class InstanceParser(object):
1147
+ def __init__(self, segment_ignore_index=(-1, 0, 1), instance_ignore_index=-1):
1148
+ self.segment_ignore_index = segment_ignore_index
1149
+ self.instance_ignore_index = instance_ignore_index
1150
+
1151
+ def __call__(self, data_dict):
1152
+ coord = data_dict["coord"]
1153
+ segment = data_dict["segment"]
1154
+ instance = data_dict["instance"]
1155
+ mask = ~np.in1d(segment, self.segment_ignore_index)
1156
+ # mapping ignored instance to ignore index
1157
+ instance[~mask] = self.instance_ignore_index
1158
+ # re-index the remaining instances
1159
+ unique, inverse = np.unique(instance[mask], return_inverse=True)
1160
+ instance_num = len(unique)
1161
+ instance[mask] = inverse
1162
+ # init instance information
1163
+ centroid = np.ones((coord.shape[0], 3)) * self.instance_ignore_index
1164
+ bbox = np.ones((instance_num, 8)) * self.instance_ignore_index
1165
+ vacancy = [
1166
+ index for index in self.segment_ignore_index if index >= 0
1167
+ ] # vacated class indices
1168
+
1169
+ for instance_id in range(instance_num):
1170
+ mask_ = instance == instance_id
1171
+ coord_ = coord[mask_]
1172
+ bbox_min = coord_.min(0)
1173
+ bbox_max = coord_.max(0)
1174
+ bbox_centroid = coord_.mean(0)
1175
+ bbox_center = (bbox_max + bbox_min) / 2
1176
+ bbox_size = bbox_max - bbox_min
1177
+ bbox_theta = np.zeros(1, dtype=coord_.dtype)
1178
+ bbox_class = np.array([segment[mask_][0]], dtype=coord_.dtype)
1179
+ # shift the class index down to fill the indices vacated by segment_ignore_index
1180
+ bbox_class -= np.greater(bbox_class, vacancy).sum()
1181
+
1182
+ centroid[mask_] = bbox_centroid
1183
+ bbox[instance_id] = np.concatenate(
1184
+ [bbox_center, bbox_size, bbox_theta, bbox_class]
1185
+ ) # 3 + 3 + 1 + 1 = 8
1186
+ data_dict["instance"] = instance
1187
+ data_dict["instance_centroid"] = centroid
1188
+ data_dict["bbox"] = bbox
1189
+ return data_dict
1190
+
1191
+
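Each row of `data_dict["bbox"]` produced above packs one instance as `[center(3), size(3), theta(1), class(1)]`. A hypothetical helper to split that 8-dim encoding back into its parts:

```python
import numpy as np

def split_bbox(bbox: np.ndarray):
    """Split an (N, 8) bbox array into center (N, 3), size (N, 3), theta (N, 1), class (N, 1)."""
    center, size, theta, cls = np.split(bbox, [3, 6, 7], axis=-1)
    return center, size, theta, cls.astype(int)
```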
1192
+ class Compose(object):
1193
+ def __init__(self, cfg=None):
1194
+ self.cfg = cfg if cfg is not None else []
1195
+ self.transforms = []
1196
+ for t_cfg in self.cfg:
1197
+ self.transforms.append(TRANSFORMS.build(t_cfg))
1198
+
1199
+ def __call__(self, data_dict):
1200
+ for t in self.transforms:
1201
+ data_dict = t(data_dict)
1202
+ return data_dict
1203
+
1204
+
1205
+ def default():
1206
+ config = [
1207
+ dict(type="CenterShift", apply_z=True),
1208
+ dict(
1209
+ type="GridSample",
1210
+ grid_size=0.02,
1211
+ hash_type="fnv",
1212
+ mode="train",
1213
+ return_grid_coord=True,
1214
+ return_inverse=True,
1215
+ ),
1216
+ dict(type="NormalizeColor"),
1217
+ dict(type="ToTensor"),
1218
+ dict(
1219
+ type="Collect",
1220
+ keys=("coord", "grid_coord", "color", "inverse"),
1221
+ feat_keys=("coord", "color", "normal"),
1222
+ ),
1223
+ ]
1224
+ return Compose(config)
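A minimal end-to-end sketch of this default inference pipeline (assumes the `concerto` package is installed; the exact keys returned depend on the `Collect` transform defined earlier in transform.py, so the output comment below is indicative rather than exact):

```python
import numpy as np
from concerto.transform import default

pipeline = default()

n = 100000
data_dict = {
    "coord": np.random.rand(n, 3).astype(np.float32) * 3.0,
    "color": np.random.randint(0, 256, (n, 3)).astype(np.float32),  # 0-255, rescaled by NormalizeColor
    "normal": np.random.randn(n, 3).astype(np.float32),             # feat = concat(coord, color, normal)
}
point = pipeline(data_dict)

# After CenterShift -> GridSample(0.02) -> NormalizeColor -> ToTensor -> Collect,
# `point` should hold torch tensors such as "coord", "grid_coord", "color", "inverse" and "feat".
print({k: tuple(v.shape) for k, v in point.items() if hasattr(v, "shape")})
```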
concerto/utils.py ADDED
@@ -0,0 +1,75 @@
1
+ """
2
+ General utils
3
+
4
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
5
+ Please cite our work if the code is helpful to you.
6
+ """
7
+
8
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ import os
24
+ import random
25
+ import numpy as np
26
+ import torch
27
+ import torch.backends.cudnn as cudnn
28
+ from datetime import datetime
29
+
30
+
31
+ @torch.no_grad()
32
+ def offset2bincount(offset):
33
+ return torch.diff(
34
+ offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long)
35
+ )
36
+
37
+
38
+ @torch.no_grad()
39
+ def bincount2offset(bincount):
40
+ return torch.cumsum(bincount, dim=0)
41
+
42
+
43
+ @torch.no_grad()
44
+ def offset2batch(offset):
45
+ bincount = offset2bincount(offset)
46
+ return torch.arange(
47
+ len(bincount), device=offset.device, dtype=torch.long
48
+ ).repeat_interleave(bincount)
49
+
50
+
51
+ @torch.no_grad()
52
+ def batch2offset(batch):
53
+ return torch.cumsum(batch.bincount(), dim=0).long()
54
+
55
+
56
+ def get_random_seed():
57
+ seed = (
58
+ os.getpid()
59
+ + int(datetime.now().strftime("%S%f"))
60
+ + int.from_bytes(os.urandom(2), "big")
61
+ )
62
+ return seed
63
+
64
+
65
+ def set_seed(seed=None):
66
+ if seed is None:
67
+ seed = get_random_seed()
68
+ random.seed(seed)
69
+ np.random.seed(seed)
70
+ torch.manual_seed(seed)
71
+ torch.cuda.manual_seed(seed)
72
+ torch.cuda.manual_seed_all(seed)
73
+ cudnn.benchmark = False
74
+ cudnn.deterministic = True
75
+ os.environ["PYTHONHASHSEED"] = str(seed)
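A quick round-trip check for the offset/batch helpers above (three point clouds with 3, 4 and 5 points batched together):

```python
import torch
from concerto.utils import offset2bincount, offset2batch, batch2offset

offset = torch.tensor([3, 7, 12])          # cumulative point counts of the three clouds

print(offset2bincount(offset))             # tensor([3, 4, 5])
print(offset2batch(offset))                # tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])
print(batch2offset(offset2batch(offset)))  # tensor([3, 7, 12]) -- inverse mapping
```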
example/pcd/hm3d_00012_kDgLKdMd5X8_2.png ADDED

Git LFS Details

  • SHA256: 1b1b6a1d26f7ce82794352f3218560af3be5de7b10ffc4a66c4ed0114e33383e
  • Pointer size: 131 Bytes
  • Size of remote file: 937 kB
example/pcd/hm3d_00113_3goH1WRaCYC.ply ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5997349536508f246ebb488d8a532929e8d6bbe5c07ac9c4bef99960f17caa5e
3
+ size 34904372
example/pcd/hm3d_00113_3goH1WRaCYC.png ADDED

Git LFS Details

  • SHA256: dc6f8b029c9a91ea51198a0ef1d880abd49e89000f4d56e74a8ba749fbbc27a0
  • Pointer size: 131 Bytes
  • Size of remote file: 834 kB
example/pcd/s3dis_Area2_auditorium1.ply ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e75d63230786cf90bb81c77c376a3eba3856096edca75f4bb63c0535a14a1785
3
+ size 66995588
example/pcd/s3dis_Area2_auditorium1.png ADDED

Git LFS Details

  • SHA256: ec45fdd97a6fe4471a9c67b8e3bdb6b0221057b218c03fd22dbfa2455e6ae813
  • Pointer size: 132 Bytes
  • Size of remote file: 1.1 MB
example/pcd/s3dis_Area4_lobby1.png ADDED

Git LFS Details

  • SHA256: 9860aeac9f9fc78c326fbe7df02fdf7c02f7624928cae668e08ec89774c52958
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
example/pcd/scannet_0024.ply ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53becbb97f65dd8b3b45b369a6045c4c7c8c0662f6197fb670fbd21f4f1f7635
3
+ size 4376723
example/pcd/scannet_0024.png ADDED

Git LFS Details

  • SHA256: ecf31bd143cb3451c0f81063dbc60a68e881caed3a24f7903c94892d0a56372e
  • Pointer size: 131 Bytes
  • Size of remote file: 641 kB
example/pcd/scannet_0603.ply ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc15cf8c98ec7a228db65174256da0b5b15c05aabdf62405e9070c17963ea4f5
3
+ size 9694393
example/pcd/scannet_0603.png ADDED

Git LFS Details

  • SHA256: 95686e01bd0b8d69f1ca7be4b974406bafaae8e9fd443586263819e3baca852b
  • Pointer size: 131 Bytes
  • Size of remote file: 765 kB
example/video/re10k_1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb32bf8a7e69dfa3a241623453dc2150b411a7544ed6c03827afc6acff9b812f
3
+ size 1435843
example/video/re10k_2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb80da2e8cd828ae6d438193a6e98be1ad974b2a0ce9145ef483442284b58f6f
3
+ size 12798788
example/video/re10k_3.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95adb5f43a53ee86da5d3de4d4d21d4bf939e90dbc8c4104f2cddfcd17cef155
3
+ size 4275086
example/video/re10k_4.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16f1622452362e2e67fb954494b40f958366720be6c69c29a7f35af5485211a3
3
+ size 1978088
requirements.txt ADDED
@@ -0,0 +1,33 @@
1
+ # Base Python packages
2
+ numpy<=1.26.4
3
+ scipy
4
+ addict
5
+ timm
6
+ psutil
7
+ huggingface_hub
8
+ opencv-python-headless
9
+ einops
10
+
11
+ # PyTorch 2.5.0 (CPU build); the find-links index below serves PyG extension wheels (e.g. torch-scatter) built against torch 2.5.0+cpu
12
+ --find-links https://data.pyg.org/whl/torch-2.5.0+cpu.html
13
+ torch==2.5.0
14
+ torchvision==0.20.0
15
+ torchaudio==2.5.0
16
+
17
+ # Extra pip dependencies that depend on PyTorch
18
+ torch-scatter
19
+ spconv
20
+
21
+ # flash-attn
22
+
23
+ # visualization
24
+ trimesh
25
+ camtools
26
+ open3d
27
+ matplotlib
28
+
29
+ # Optional: If you need `ninja` as a build tool
30
+ ninja
31
+
32
+ # # Concerto (local package)
33
+ # -e .
setup.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ import pkg_resources
18
+ from setuptools import setup, find_packages
19
+
20
+
21
+ setup(
22
+ name="concerto",
23
+ py_modules=["concerto"],
24
+ version="1.0",
25
+ description="",
26
+ author="Yujia Zhang",
27
+ packages=find_packages(exclude=["demo*"]),
28
+ include_package_data=True,
29
+ )
vggt/__init__.py ADDED
File without changes
vggt/heads/camera_head.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from vggt.layers import Mlp
15
+ from vggt.layers.block import Block
16
+ from vggt.heads.head_act import activate_pose
17
+
18
+
19
+ class CameraHead(nn.Module):
20
+ """
21
+ CameraHead predicts camera parameters from token representations using iterative refinement.
22
+
23
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ dim_in: int = 2048,
29
+ trunk_depth: int = 4,
30
+ pose_encoding_type: str = "absT_quaR_FoV",
31
+ num_heads: int = 16,
32
+ mlp_ratio: int = 4,
33
+ init_values: float = 0.01,
34
+ trans_act: str = "linear",
35
+ quat_act: str = "linear",
36
+ fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
37
+ ):
38
+ super().__init__()
39
+
40
+ if pose_encoding_type == "absT_quaR_FoV":
41
+ self.target_dim = 9
42
+ else:
43
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
44
+
45
+ self.trans_act = trans_act
46
+ self.quat_act = quat_act
47
+ self.fl_act = fl_act
48
+ self.trunk_depth = trunk_depth
49
+
50
+ # Build the trunk using a sequence of transformer blocks.
51
+ self.trunk = nn.Sequential(
52
+ *[
53
+ Block(
54
+ dim=dim_in,
55
+ num_heads=num_heads,
56
+ mlp_ratio=mlp_ratio,
57
+ init_values=init_values,
58
+ )
59
+ for _ in range(trunk_depth)
60
+ ]
61
+ )
62
+
63
+ # Normalizations for camera token and trunk output.
64
+ self.token_norm = nn.LayerNorm(dim_in)
65
+ self.trunk_norm = nn.LayerNorm(dim_in)
66
+
67
+ # Learnable empty camera pose token.
68
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
69
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
70
+
71
+ # Module for producing modulation parameters: shift, scale, and a gate.
72
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
73
+
74
+ # Adaptive layer normalization without affine parameters.
75
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
76
+ self.pose_branch = Mlp(
77
+ in_features=dim_in,
78
+ hidden_features=dim_in // 2,
79
+ out_features=self.target_dim,
80
+ drop=0,
81
+ )
82
+
83
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
84
+ """
85
+ Forward pass to predict camera parameters.
86
+
87
+ Args:
88
+ aggregated_tokens_list (list): List of token tensors from the network;
89
+ the last tensor is used for prediction.
90
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
91
+
92
+ Returns:
93
+ list: A list of predicted camera encodings (post-activation) from each iteration.
94
+ """
95
+ # Use tokens from the last block for camera prediction.
96
+ tokens = aggregated_tokens_list[-1]
97
+
98
+ # Extract the camera tokens
99
+ pose_tokens = tokens[:, :, 0]
100
+ pose_tokens = self.token_norm(pose_tokens)
101
+
102
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
103
+ return pred_pose_enc_list
104
+
105
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
106
+ """
107
+ Iteratively refine camera pose predictions.
108
+
109
+ Args:
110
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
111
+ num_iterations (int): Number of refinement iterations.
112
+
113
+ Returns:
114
+ list: List of activated camera encodings from each iteration.
115
+ """
116
+ B, S, C = pose_tokens.shape # S is expected to be 1.
117
+ pred_pose_enc = None
118
+ pred_pose_enc_list = []
119
+
120
+ for _ in range(num_iterations):
121
+ # Use a learned empty pose for the first iteration.
122
+ if pred_pose_enc is None:
123
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
124
+ else:
125
+ # Detach the previous prediction to avoid backprop through time.
126
+ pred_pose_enc = pred_pose_enc.detach()
127
+ module_input = self.embed_pose(pred_pose_enc)
128
+
129
+ # Generate modulation parameters and split them into shift, scale, and gate components.
130
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
131
+
132
+ # Adaptive layer normalization and modulation.
133
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
134
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
135
+
136
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
137
+ # Compute the delta update for the pose encoding.
138
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
139
+
140
+ if pred_pose_enc is None:
141
+ pred_pose_enc = pred_pose_enc_delta
142
+ else:
143
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
144
+
145
+ # Apply final activation functions for translation, quaternion, and field-of-view.
146
+ activated_pose = activate_pose(
147
+ pred_pose_enc,
148
+ trans_act=self.trans_act,
149
+ quat_act=self.quat_act,
150
+ fl_act=self.fl_act,
151
+ )
152
+ pred_pose_enc_list.append(activated_pose)
153
+
154
+ return pred_pose_enc_list
155
+
156
+
157
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
158
+ """
159
+ Modulate the input tensor using scaling and shifting parameters.
160
+ """
161
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
162
+ return x * (1 + scale) + shift
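A shape-level smoke test for the head above (assumes this repo's `vggt` package is importable; the random token tensor stands in for the aggregator output, where token index 0 of each frame is the camera token):

```python
import torch
from vggt.heads.camera_head import CameraHead

head = CameraHead(dim_in=2048, trunk_depth=4)

B, S, num_tokens, C = 1, 2, 5, 2048      # batch of 1, 2 frames, 5 tokens per frame
aggregated_tokens_list = [torch.randn(B, S, num_tokens, C)]

with torch.no_grad():
    pred_list = head(aggregated_tokens_list, num_iterations=4)

# One prediction per refinement step, each of shape (B, S, 9): absT (3) + quaR (4) + FoV (2).
print(len(pred_list), pred_list[-1].shape)
```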
vggt/heads/dpt_head.py ADDED
@@ -0,0 +1,497 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Inspired by https://github.com/DepthAnything/Depth-Anything-V2
9
+
10
+
11
+ import os
12
+ from typing import List, Dict, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from .head_act import activate_head
18
+ from .utils import create_uv_grid, position_grid_to_embed
19
+
20
+
21
+ class DPTHead(nn.Module):
22
+ """
23
+ DPT Head for dense prediction tasks.
24
+
25
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
26
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
27
+ backbone and produces dense predictions by fusing multi-scale features.
28
+
29
+ Args:
30
+ dim_in (int): Input dimension (channels).
31
+ patch_size (int, optional): Patch size. Default is 14.
32
+ output_dim (int, optional): Number of output channels. Default is 4.
33
+ activation (str, optional): Activation type. Default is "inv_log".
34
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
35
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
36
+ out_channels (List[int], optional): Output channels for each intermediate layer.
37
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
38
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
39
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
40
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ dim_in: int,
46
+ patch_size: int = 14,
47
+ output_dim: int = 4,
48
+ activation: str = "inv_log",
49
+ conf_activation: str = "expp1",
50
+ features: int = 256,
51
+ out_channels: List[int] = [256, 512, 1024, 1024],
52
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
53
+ pos_embed: bool = True,
54
+ feature_only: bool = False,
55
+ down_ratio: int = 1,
56
+ ) -> None:
57
+ super(DPTHead, self).__init__()
58
+ self.patch_size = patch_size
59
+ self.activation = activation
60
+ self.conf_activation = conf_activation
61
+ self.pos_embed = pos_embed
62
+ self.feature_only = feature_only
63
+ self.down_ratio = down_ratio
64
+ self.intermediate_layer_idx = intermediate_layer_idx
65
+
66
+ self.norm = nn.LayerNorm(dim_in)
67
+
68
+ # Projection layers for each output channel from tokens.
69
+ self.projects = nn.ModuleList(
70
+ [
71
+ nn.Conv2d(
72
+ in_channels=dim_in,
73
+ out_channels=oc,
74
+ kernel_size=1,
75
+ stride=1,
76
+ padding=0,
77
+ )
78
+ for oc in out_channels
79
+ ]
80
+ )
81
+
82
+ # Resize layers for upsampling feature maps.
83
+ self.resize_layers = nn.ModuleList(
84
+ [
85
+ nn.ConvTranspose2d(
86
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
87
+ ),
88
+ nn.ConvTranspose2d(
89
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
90
+ ),
91
+ nn.Identity(),
92
+ nn.Conv2d(
93
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
94
+ ),
95
+ ]
96
+ )
97
+
98
+ self.scratch = _make_scratch(
99
+ out_channels,
100
+ features,
101
+ expand=False,
102
+ )
103
+
104
+ # Attach additional modules to scratch.
105
+ self.scratch.stem_transpose = None
106
+ self.scratch.refinenet1 = _make_fusion_block(features)
107
+ self.scratch.refinenet2 = _make_fusion_block(features)
108
+ self.scratch.refinenet3 = _make_fusion_block(features)
109
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
110
+
111
+ head_features_1 = features
112
+ head_features_2 = 32
113
+
114
+ if feature_only:
115
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
116
+ else:
117
+ self.scratch.output_conv1 = nn.Conv2d(
118
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
119
+ )
120
+ conv2_in_channels = head_features_1 // 2
121
+
122
+ self.scratch.output_conv2 = nn.Sequential(
123
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
124
+ nn.ReLU(inplace=True),
125
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
126
+ )
127
+
128
+ def forward(
129
+ self,
130
+ aggregated_tokens_list: List[torch.Tensor],
131
+ images: torch.Tensor,
132
+ patch_start_idx: int,
133
+ frames_chunk_size: int = 8,
134
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
135
+ """
136
+ Forward pass through the DPT head, supports processing by chunking frames.
137
+ Args:
138
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
139
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
140
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
141
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
142
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
143
+ If None or larger than S, all frames are processed at once. Default: 8.
144
+
145
+ Returns:
146
+ Tensor or Tuple[Tensor, Tensor]:
147
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
148
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
149
+ """
150
+ B, S, _, H, W = images.shape
151
+
152
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
153
+ if frames_chunk_size is None or frames_chunk_size >= S:
154
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
155
+
156
+ # Otherwise, process frames in chunks to manage memory usage
157
+ assert frames_chunk_size > 0
158
+
159
+ # Process frames in batches
160
+ all_preds = []
161
+ all_conf = []
162
+
163
+ for frames_start_idx in range(0, S, frames_chunk_size):
164
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
165
+
166
+ # Process batch of frames
167
+ if self.feature_only:
168
+ chunk_output = self._forward_impl(
169
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
170
+ )
171
+ all_preds.append(chunk_output)
172
+ else:
173
+ chunk_preds, chunk_conf = self._forward_impl(
174
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
175
+ )
176
+ all_preds.append(chunk_preds)
177
+ all_conf.append(chunk_conf)
178
+
179
+ # Concatenate results along the sequence dimension
180
+ if self.feature_only:
181
+ return torch.cat(all_preds, dim=1)
182
+ else:
183
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
184
+
185
+ def _forward_impl(
186
+ self,
187
+ aggregated_tokens_list: List[torch.Tensor],
188
+ images: torch.Tensor,
189
+ patch_start_idx: int,
190
+ frames_start_idx: int = None,
191
+ frames_end_idx: int = None,
192
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
193
+ """
194
+ Implementation of the forward pass through the DPT head.
195
+
196
+ This method processes a specific chunk of frames from the sequence.
197
+
198
+ Args:
199
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
200
+ images (Tensor): Input images with shape [B, S, 3, H, W].
201
+ patch_start_idx (int): Starting index for patch tokens.
202
+ frames_start_idx (int, optional): Starting index for frames to process.
203
+ frames_end_idx (int, optional): Ending index for frames to process.
204
+
205
+ Returns:
206
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
207
+ """
208
+ if frames_start_idx is not None and frames_end_idx is not None:
209
+ images = images[:, frames_start_idx:frames_end_idx]
210
+
211
+ B, S, _, H, W = images.shape
212
+
213
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
214
+
215
+ out = []
216
+ dpt_idx = 0
217
+
218
+ for layer_idx in self.intermediate_layer_idx:
219
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
220
+
221
+ # Select frames if processing a chunk
222
+ if frames_start_idx is not None and frames_end_idx is not None:
223
+ x = x[:, frames_start_idx:frames_end_idx]
224
+
225
+ x = x.view(B * S, -1, x.shape[-1])
226
+
227
+ x = self.norm(x)
228
+
229
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
230
+
231
+ x = self.projects[dpt_idx](x)
232
+ if self.pos_embed:
233
+ x = self._apply_pos_embed(x, W, H)
234
+ x = self.resize_layers[dpt_idx](x)
235
+
236
+ out.append(x)
237
+ dpt_idx += 1
238
+
239
+ # Fuse features from multiple layers.
240
+ out = self.scratch_forward(out)
241
+ # Interpolate fused output to match target image resolution.
242
+ out = custom_interpolate(
243
+ out,
244
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
245
+ mode="bilinear",
246
+ align_corners=True,
247
+ )
248
+
249
+ if self.pos_embed:
250
+ out = self._apply_pos_embed(out, W, H)
251
+
252
+ if self.feature_only:
253
+ return out.view(B, S, *out.shape[1:])
254
+
255
+ out = self.scratch.output_conv2(out)
256
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
257
+
258
+ preds = preds.view(B, S, *preds.shape[1:])
259
+ conf = conf.view(B, S, *conf.shape[1:])
260
+ return preds, conf
261
+
262
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
263
+ """
264
+ Apply positional embedding to tensor x.
265
+ """
266
+ patch_w = x.shape[-1]
267
+ patch_h = x.shape[-2]
268
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
269
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
270
+ pos_embed = pos_embed * ratio
271
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
272
+ return x + pos_embed
273
+
274
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
275
+ """
276
+ Forward pass through the fusion blocks.
277
+
278
+ Args:
279
+ features (List[Tensor]): List of feature maps from different layers.
280
+
281
+ Returns:
282
+ Tensor: Fused feature map.
283
+ """
284
+ layer_1, layer_2, layer_3, layer_4 = features
285
+
286
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
287
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
288
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
289
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
290
+
291
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
292
+ del layer_4_rn, layer_4
293
+
294
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
295
+ del layer_3_rn, layer_3
296
+
297
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
298
+ del layer_2_rn, layer_2
299
+
300
+ out = self.scratch.refinenet1(out, layer_1_rn)
301
+ del layer_1_rn, layer_1
302
+
303
+ out = self.scratch.output_conv1(out)
304
+ return out
305
+
306
+
307
+ ################################################################################
308
+ # Modules
309
+ ################################################################################
310
+
311
+
312
+ def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
313
+ return FeatureFusionBlock(
314
+ features,
315
+ nn.ReLU(inplace=True),
316
+ deconv=False,
317
+ bn=False,
318
+ expand=False,
319
+ align_corners=True,
320
+ size=size,
321
+ has_residual=has_residual,
322
+ groups=groups,
323
+ )
324
+
325
+
326
+ def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
327
+ scratch = nn.Module()
328
+ out_shape1 = out_shape
329
+ out_shape2 = out_shape
330
+ out_shape3 = out_shape
331
+ if len(in_shape) >= 4:
332
+ out_shape4 = out_shape
333
+
334
+ if expand:
335
+ out_shape1 = out_shape
336
+ out_shape2 = out_shape * 2
337
+ out_shape3 = out_shape * 4
338
+ if len(in_shape) >= 4:
339
+ out_shape4 = out_shape * 8
340
+
341
+ scratch.layer1_rn = nn.Conv2d(
342
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
343
+ )
344
+ scratch.layer2_rn = nn.Conv2d(
345
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
346
+ )
347
+ scratch.layer3_rn = nn.Conv2d(
348
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
349
+ )
350
+ if len(in_shape) >= 4:
351
+ scratch.layer4_rn = nn.Conv2d(
352
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
353
+ )
354
+ return scratch
355
+
356
+
357
+ class ResidualConvUnit(nn.Module):
358
+ """Residual convolution module."""
359
+
360
+ def __init__(self, features, activation, bn, groups=1):
361
+ """Init.
362
+
363
+ Args:
364
+ features (int): number of features
365
+ """
366
+ super().__init__()
367
+
368
+ self.bn = bn
369
+ self.groups = groups
370
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
371
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
372
+
373
+ self.norm1 = None
374
+ self.norm2 = None
375
+
376
+ self.activation = activation
377
+ self.skip_add = nn.quantized.FloatFunctional()
378
+
379
+ def forward(self, x):
380
+ """Forward pass.
381
+
382
+ Args:
383
+ x (tensor): input
384
+
385
+ Returns:
386
+ tensor: output
387
+ """
388
+
389
+ out = self.activation(x)
390
+ out = self.conv1(out)
391
+ if self.norm1 is not None:
392
+ out = self.norm1(out)
393
+
394
+ out = self.activation(out)
395
+ out = self.conv2(out)
396
+ if self.norm2 is not None:
397
+ out = self.norm2(out)
398
+
399
+ return self.skip_add.add(out, x)
400
+
401
+
402
+ class FeatureFusionBlock(nn.Module):
403
+ """Feature fusion block."""
404
+
405
+ def __init__(
406
+ self,
407
+ features,
408
+ activation,
409
+ deconv=False,
410
+ bn=False,
411
+ expand=False,
412
+ align_corners=True,
413
+ size=None,
414
+ has_residual=True,
415
+ groups=1,
416
+ ):
417
+ """Init.
418
+
419
+ Args:
420
+ features (int): number of features
421
+ """
422
+ super(FeatureFusionBlock, self).__init__()
423
+
424
+ self.deconv = deconv
425
+ self.align_corners = align_corners
426
+ self.groups = groups
427
+ self.expand = expand
428
+ out_features = features
429
+ if self.expand:
430
+ out_features = features // 2
431
+
432
+ self.out_conv = nn.Conv2d(
433
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
434
+ )
435
+
436
+ if has_residual:
437
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
438
+
439
+ self.has_residual = has_residual
440
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
441
+
442
+ self.skip_add = nn.quantized.FloatFunctional()
443
+ self.size = size
444
+
445
+ def forward(self, *xs, size=None):
446
+ """Forward pass.
447
+
448
+ Returns:
449
+ tensor: output
450
+ """
451
+ output = xs[0]
452
+
453
+ if self.has_residual:
454
+ res = self.resConfUnit1(xs[1])
455
+ output = self.skip_add.add(output, res)
456
+
457
+ output = self.resConfUnit2(output)
458
+
459
+ if (size is None) and (self.size is None):
460
+ modifier = {"scale_factor": 2}
461
+ elif size is None:
462
+ modifier = {"size": self.size}
463
+ else:
464
+ modifier = {"size": size}
465
+
466
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
467
+ output = self.out_conv(output)
468
+
469
+ return output
470
+
471
+
472
+ def custom_interpolate(
473
+ x: torch.Tensor,
474
+ size: Tuple[int, int] = None,
475
+ scale_factor: float = None,
476
+ mode: str = "bilinear",
477
+ align_corners: bool = True,
478
+ ) -> torch.Tensor:
479
+ """
480
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
481
+ """
482
+ if size is None:
483
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
484
+
485
+ INT_MAX = 1610612736
486
+
487
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
488
+
489
+ if input_elements > INT_MAX:
490
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
491
+ interpolated_chunks = [
492
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
493
+ ]
494
+ x = torch.cat(interpolated_chunks, dim=0)
495
+ return x.contiguous()
496
+ else:
497
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
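A small sanity check of `custom_interpolate` (for inputs below the element threshold it simply defers to `nn.functional.interpolate`; for very large inputs it chunks along the batch dimension first):

```python
import torch
from vggt.heads.dpt_head import custom_interpolate

# A 37x37 patch grid (e.g. a 518x518 image with patch_size 14) upsampled back to pixel resolution.
x = torch.randn(2, 64, 37, 37)
y = custom_interpolate(x, size=(518, 518), mode="bilinear", align_corners=True)
print(y.shape)  # torch.Size([2, 64, 518, 518])
```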
vggt/heads/head_act.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
13
+ """
14
+ Activate pose parameters with specified activation functions.
15
+
16
+ Args:
17
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
18
+ trans_act: Activation type for translation component
19
+ quat_act: Activation type for quaternion component
20
+ fl_act: Activation type for focal length component
21
+
22
+ Returns:
23
+ Activated pose parameters tensor
24
+ """
25
+ T = pred_pose_enc[..., :3]
26
+ quat = pred_pose_enc[..., 3:7]
27
+ fl = pred_pose_enc[..., 7:] # or fov
28
+
29
+ T = base_pose_act(T, trans_act)
30
+ quat = base_pose_act(quat, quat_act)
31
+ fl = base_pose_act(fl, fl_act) # or fov
32
+
33
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
34
+
35
+ return pred_pose_enc
36
+
37
+
38
+ def base_pose_act(pose_enc, act_type="linear"):
39
+ """
40
+ Apply basic activation function to pose parameters.
41
+
42
+ Args:
43
+ pose_enc: Tensor containing encoded pose parameters
44
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
45
+
46
+ Returns:
47
+ Activated pose parameters
48
+ """
49
+ if act_type == "linear":
50
+ return pose_enc
51
+ elif act_type == "inv_log":
52
+ return inverse_log_transform(pose_enc)
53
+ elif act_type == "exp":
54
+ return torch.exp(pose_enc)
55
+ elif act_type == "relu":
56
+ return F.relu(pose_enc)
57
+ else:
58
+ raise ValueError(f"Unknown act_type: {act_type}")
59
+
60
+
61
+ def activate_head(out, activation="norm_exp", conf_activation="expp1"):
62
+ """
63
+ Process network output to extract 3D points and confidence values.
64
+
65
+ Args:
66
+ out: Network output tensor (B, C, H, W)
67
+ activation: Activation type for 3D points
68
+ conf_activation: Activation type for confidence values
69
+
70
+ Returns:
71
+ Tuple of (3D points tensor, confidence tensor)
72
+ """
73
+ # Move channels from last dim to the 4th dimension => (B, H, W, C)
74
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
75
+
76
+ # Split into xyz (first C-1 channels) and confidence (last channel)
77
+ xyz = fmap[:, :, :, :-1]
78
+ conf = fmap[:, :, :, -1]
79
+
80
+ if activation == "norm_exp":
81
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
82
+ xyz_normed = xyz / d
83
+ pts3d = xyz_normed * torch.expm1(d)
84
+ elif activation == "norm":
85
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
86
+ elif activation == "exp":
87
+ pts3d = torch.exp(xyz)
88
+ elif activation == "relu":
89
+ pts3d = F.relu(xyz)
90
+ elif activation == "inv_log":
91
+ pts3d = inverse_log_transform(xyz)
92
+ elif activation == "xy_inv_log":
93
+ xy, z = xyz.split([2, 1], dim=-1)
94
+ z = inverse_log_transform(z)
95
+ pts3d = torch.cat([xy * z, z], dim=-1)
96
+ elif activation == "sigmoid":
97
+ pts3d = torch.sigmoid(xyz)
98
+ elif activation == "linear":
99
+ pts3d = xyz
100
+ else:
101
+ raise ValueError(f"Unknown activation: {activation}")
102
+
103
+ if conf_activation == "expp1":
104
+ conf_out = 1 + conf.exp()
105
+ elif conf_activation == "expp0":
106
+ conf_out = conf.exp()
107
+ elif conf_activation == "sigmoid":
108
+ conf_out = torch.sigmoid(conf)
109
+ else:
110
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
111
+
112
+ return pts3d, conf_out
113
+
114
+
115
+ def inverse_log_transform(y):
116
+ """
117
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
118
+
119
+ Args:
120
+ y: Input tensor
121
+
122
+ Returns:
123
+ Transformed tensor
124
+ """
125
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
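Two quick checks for the activations above: `inverse_log_transform` undoes a signed log1p compression, and `activate_pose` keeps the 9-dim layout (translation 3, quaternion 4, focal/FoV 2) while activating each chunk:

```python
import torch
from vggt.heads.head_act import inverse_log_transform, activate_pose

x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
y = torch.sign(x) * torch.log1p(torch.abs(x))        # forward "log transform"
print(torch.allclose(inverse_log_transform(y), x))   # True

pose_enc = torch.randn(1, 1, 9)
activated = activate_pose(pose_enc, trans_act="linear", quat_act="linear", fl_act="relu")
print(activated.shape)                               # torch.Size([1, 1, 9])
```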
vggt/heads/track_head.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch.nn as nn
8
+ from .dpt_head import DPTHead
9
+ from .track_modules.base_track_predictor import BaseTrackerPredictor
10
+
11
+
12
+ class TrackHead(nn.Module):
13
+ """
14
+ Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
15
+ The tracking is performed iteratively, refining predictions over multiple iterations.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ dim_in,
21
+ patch_size=14,
22
+ features=128,
23
+ iters=4,
24
+ predict_conf=True,
25
+ stride=2,
26
+ corr_levels=7,
27
+ corr_radius=4,
28
+ hidden_size=384,
29
+ ):
30
+ """
31
+ Initialize the TrackHead module.
32
+
33
+ Args:
34
+ dim_in (int): Input dimension of tokens from the backbone.
35
+ patch_size (int): Size of image patches used in the vision transformer.
36
+ features (int): Number of feature channels in the feature extractor output.
37
+ iters (int): Number of refinement iterations for tracking predictions.
38
+ predict_conf (bool): Whether to predict confidence scores for tracked points.
39
+ stride (int): Stride value for the tracker predictor.
40
+ corr_levels (int): Number of correlation pyramid levels
41
+ corr_radius (int): Radius for correlation computation, controlling the search area.
42
+ hidden_size (int): Size of hidden layers in the tracker network.
43
+ """
44
+ super().__init__()
45
+
46
+ self.patch_size = patch_size
47
+
48
+ # Feature extractor based on DPT architecture
49
+ # Processes tokens into feature maps for tracking
50
+ self.feature_extractor = DPTHead(
51
+ dim_in=dim_in,
52
+ patch_size=patch_size,
53
+ features=features,
54
+ feature_only=True, # Only output features, no activation
55
+ down_ratio=2, # Reduces spatial dimensions by factor of 2
56
+ pos_embed=False,
57
+ )
58
+
59
+ # Tracker module that predicts point trajectories
60
+ # Takes feature maps and predicts coordinates and visibility
61
+ self.tracker = BaseTrackerPredictor(
62
+ latent_dim=features, # Match the output_dim of feature extractor
63
+ predict_conf=predict_conf,
64
+ stride=stride,
65
+ corr_levels=corr_levels,
66
+ corr_radius=corr_radius,
67
+ hidden_size=hidden_size,
68
+ )
69
+
70
+ self.iters = iters
71
+
72
+ def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
73
+ """
74
+ Forward pass of the TrackHead.
75
+
76
+ Args:
77
+ aggregated_tokens_list (list): List of aggregated tokens from the backbone.
78
+ images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
79
+ B = batch size, S = sequence length.
80
+ patch_start_idx (int): Starting index for patch tokens.
81
+ query_points (torch.Tensor, optional): Initial query points to track.
82
+ If None, points are initialized by the tracker.
83
+ iters (int, optional): Number of refinement iterations. If None, uses self.iters.
84
+
85
+ Returns:
86
+ tuple:
87
+ - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
88
+ - vis_scores (torch.Tensor): Visibility scores for tracked points.
89
+ - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
90
+ """
91
+ B, S, _, H, W = images.shape
92
+
93
+ # Extract features from tokens
94
+ # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
95
+ feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
96
+
97
+ # Use default iterations if not specified
98
+ if iters is None:
99
+ iters = self.iters
100
+
101
+ # Perform tracking using the extracted features
102
+ coord_preds, vis_scores, conf_scores = self.tracker(
103
+ query_points=query_points,
104
+ fmaps=feature_maps,
105
+ iters=iters,
106
+ )
107
+
108
+ return coord_preds, vis_scores, conf_scores
vggt/heads/track_modules/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
vggt/heads/track_modules/base_track_predictor.py ADDED
@@ -0,0 +1,209 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange, repeat
10
+
11
+
12
+ from .blocks import EfficientUpdateFormer, CorrBlock
13
+ from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
14
+ from .modules import Mlp
15
+
16
+
17
+ class BaseTrackerPredictor(nn.Module):
18
+ def __init__(
19
+ self,
20
+ stride=1,
21
+ corr_levels=5,
22
+ corr_radius=4,
23
+ latent_dim=128,
24
+ hidden_size=384,
25
+ use_spaceatt=True,
26
+ depth=6,
27
+ max_scale=518,
28
+ predict_conf=True,
29
+ ):
30
+ super(BaseTrackerPredictor, self).__init__()
31
+ """
32
+ The base template to create a track predictor
33
+
34
+ Modified from https://github.com/facebookresearch/co-tracker/
35
+ and https://github.com/facebookresearch/vggsfm
36
+ """
37
+
38
+ self.stride = stride
39
+ self.latent_dim = latent_dim
40
+ self.corr_levels = corr_levels
41
+ self.corr_radius = corr_radius
42
+ self.hidden_size = hidden_size
43
+ self.max_scale = max_scale
44
+ self.predict_conf = predict_conf
45
+
46
+ self.flows_emb_dim = latent_dim // 2
47
+
48
+ self.corr_mlp = Mlp(
49
+ in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
50
+ hidden_features=self.hidden_size,
51
+ out_features=self.latent_dim,
52
+ )
53
+
54
+ self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
55
+
56
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
57
+
58
+ space_depth = depth if use_spaceatt else 0
59
+ time_depth = depth
60
+
61
+ self.updateformer = EfficientUpdateFormer(
62
+ space_depth=space_depth,
63
+ time_depth=time_depth,
64
+ input_dim=self.transformer_dim,
65
+ hidden_size=self.hidden_size,
66
+ output_dim=self.latent_dim + 2,
67
+ mlp_ratio=4.0,
68
+ add_space_attn=use_spaceatt,
69
+ )
70
+
71
+ self.fmap_norm = nn.LayerNorm(self.latent_dim)
72
+ self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
73
+
74
+ # A linear layer to update track feats at each iteration
75
+ self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
76
+
77
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
78
+
79
+ if predict_conf:
80
+ self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
81
+
82
+ def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
83
+ """
84
+ query_points: B x N x 2, the number of batches, tracks, and xy
85
+ fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
86
+ note HH and WW is the size of feature maps instead of original images
87
+ """
88
+ B, N, D = query_points.shape
89
+ B, S, C, HH, WW = fmaps.shape
90
+
91
+ assert D == 2, "Input points must be 2D coordinates"
92
+
93
+ # apply a layernorm to fmaps here
94
+ fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
95
+ fmaps = fmaps.permute(0, 1, 4, 2, 3)
96
+
97
+ # Scale the input query_points because we may downsample the images
98
+ # by down_ratio or self.stride
99
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
100
+ # its query_points should be query_points/4
101
+ if down_ratio > 1:
102
+ query_points = query_points / float(down_ratio)
103
+
104
+ query_points = query_points / float(self.stride)
105
+
106
+ # Init with coords as the query points
107
+ # It means the search will start from the position of query points at the reference frames
108
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
109
+
110
+ # Sample/extract the features of the query points in the query frame
111
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
112
+
113
+ # init track feats by query feats
114
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
115
+ # back up the init coords
116
+ coords_backup = coords.clone()
117
+
118
+ fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
119
+
120
+ coord_preds = []
121
+
122
+ # Iterative Refinement
123
+ for _ in range(iters):
124
+ # Detach the gradients from the last iteration
125
+ # (in my experience, not very important for performance)
126
+ coords = coords.detach()
127
+
128
+ fcorrs = fcorr_fn.corr_sample(track_feats, coords)
129
+
130
+ corr_dim = fcorrs.shape[3]
131
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
132
+ fcorrs_ = self.corr_mlp(fcorrs_)
133
+
134
+ # Movement of current coords relative to query points
135
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
136
+
137
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
138
+
139
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
140
+ flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
141
+
142
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
143
+
144
+ # Concatenate them as the input for the transformers
145
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
146
+
147
+ # 2D positional embed
148
+ # TODO: this can be much simplified
149
+ pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
150
+ sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
151
+
152
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
153
+
154
+ x = transformer_input + sampled_pos_emb
155
+
156
+ # Add the query ref token to the track feats
157
+ query_ref_token = torch.cat(
158
+ [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
159
+ )
160
+ x = x + query_ref_token.to(x.device).to(x.dtype)
161
+
162
+ # B, N, S, C
163
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
164
+
165
+ # Compute the delta coordinates and delta track features
166
+ delta, _ = self.updateformer(x)
167
+
168
+ # BN, S, C
169
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
170
+ delta_coords_ = delta[:, :, :2]
171
+ delta_feats_ = delta[:, :, 2:]
172
+
173
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
174
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
175
+
176
+ # Update the track features
177
+ track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
178
+
179
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
180
+
181
+ # B x S x N x 2
182
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
183
+
184
+ # Force coord0 as query
185
+ # because we assume the query points should not be changed
186
+ coords[:, 0] = coords_backup[:, 0]
187
+
188
+ # The predicted tracks are in the original image scale
189
+ if down_ratio > 1:
190
+ coord_preds.append(coords * self.stride * down_ratio)
191
+ else:
192
+ coord_preds.append(coords * self.stride)
193
+
194
+ # B, S, N
195
+ vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
196
+ if apply_sigmoid:
197
+ vis_e = torch.sigmoid(vis_e)
198
+
199
+ if self.predict_conf:
200
+ conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
201
+ if apply_sigmoid:
202
+ conf_e = torch.sigmoid(conf_e)
203
+ else:
204
+ conf_e = None
205
+
206
+ if return_feat:
207
+ return coord_preds, vis_e, track_feats, query_track_feat, conf_e
208
+ else:
209
+ return coord_preds, vis_e, conf_e
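
A quick way to sanity-check BaseTrackerPredictor is a toy forward pass. The import path follows this diff; the sizes below (4 frames, 8 tracks, 32x32 feature maps, 2 refinement iterations) are illustrative assumptions, not the settings VGGT uses.

    import torch
    from vggt.heads.track_modules.base_track_predictor import BaseTrackerPredictor

    predictor = BaseTrackerPredictor()          # defaults: latent_dim=128, corr_levels=5
    B, S, N = 1, 4, 8                           # toy batch / frames / tracks (assumed)
    fmaps = torch.randn(B, S, 128, 32, 32)      # channel dim must equal latent_dim
    query_points = torch.rand(B, N, 2) * 31     # xy in feature-map pixel coordinates

    coord_preds, vis, conf = predictor(query_points, fmaps, iters=2)
    print(len(coord_preds), coord_preds[-1].shape)  # 2 torch.Size([1, 4, 8, 2])
    print(vis.shape, conf.shape)                    # torch.Size([1, 4, 8]) for both

The returned list holds one coordinate estimate per refinement iteration, with the query frame's coordinates pinned to the input points.
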
vggt/heads/track_modules/blocks.py ADDED
@@ -0,0 +1,246 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Modified from https://github.com/facebookresearch/co-tracker/
9
+
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from .utils import bilinear_sampler
16
+ from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
17
+
18
+
19
+ class EfficientUpdateFormer(nn.Module):
20
+ """
21
+ Transformer model that updates track estimates.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ space_depth=6,
27
+ time_depth=6,
28
+ input_dim=320,
29
+ hidden_size=384,
30
+ num_heads=8,
31
+ output_dim=130,
32
+ mlp_ratio=4.0,
33
+ add_space_attn=True,
34
+ num_virtual_tracks=64,
35
+ ):
36
+ super().__init__()
37
+
38
+ self.out_channels = 2
39
+ self.num_heads = num_heads
40
+ self.hidden_size = hidden_size
41
+ self.add_space_attn = add_space_attn
42
+
43
+ # Add input LayerNorm before linear projection
44
+ self.input_norm = nn.LayerNorm(input_dim)
45
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
46
+
47
+ # Add output LayerNorm before final projection
48
+ self.output_norm = nn.LayerNorm(hidden_size)
49
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
50
+ self.num_virtual_tracks = num_virtual_tracks
51
+
52
+ if self.add_space_attn:
53
+ self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
54
+ else:
55
+ self.virual_tracks = None
56
+
57
+ self.time_blocks = nn.ModuleList(
58
+ [
59
+ AttnBlock(
60
+ hidden_size,
61
+ num_heads,
62
+ mlp_ratio=mlp_ratio,
63
+ attn_class=nn.MultiheadAttention,
64
+ )
65
+ for _ in range(time_depth)
66
+ ]
67
+ )
68
+
69
+ if add_space_attn:
70
+ self.space_virtual_blocks = nn.ModuleList(
71
+ [
72
+ AttnBlock(
73
+ hidden_size,
74
+ num_heads,
75
+ mlp_ratio=mlp_ratio,
76
+ attn_class=nn.MultiheadAttention,
77
+ )
78
+ for _ in range(space_depth)
79
+ ]
80
+ )
81
+ self.space_point2virtual_blocks = nn.ModuleList(
82
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
83
+ )
84
+ self.space_virtual2point_blocks = nn.ModuleList(
85
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
86
+ )
87
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
88
+ self.initialize_weights()
89
+
90
+ def initialize_weights(self):
91
+ def _basic_init(module):
92
+ if isinstance(module, nn.Linear):
93
+ torch.nn.init.xavier_uniform_(module.weight)
94
+ if module.bias is not None:
95
+ nn.init.constant_(module.bias, 0)
96
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
97
+
98
+ self.apply(_basic_init)
99
+
100
+ def forward(self, input_tensor, mask=None):
101
+ # Apply input LayerNorm
102
+ input_tensor = self.input_norm(input_tensor)
103
+ tokens = self.input_transform(input_tensor)
104
+
105
+ init_tokens = tokens
106
+
107
+ B, _, T, _ = tokens.shape
108
+
109
+ if self.add_space_attn:
110
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
111
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
112
+
113
+ _, N, _, _ = tokens.shape
114
+
115
+ j = 0
116
+ for i in range(len(self.time_blocks)):
117
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
118
+
119
+ time_tokens = self.time_blocks[i](time_tokens)
120
+
121
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
122
+ if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
123
+ space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
124
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
125
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
126
+
127
+ virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
128
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
129
+ point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
130
+
131
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
132
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
133
+ j += 1
134
+
135
+ if self.add_space_attn:
136
+ tokens = tokens[:, : N - self.num_virtual_tracks]
137
+
138
+ tokens = tokens + init_tokens
139
+
140
+ # Apply output LayerNorm before final projection
141
+ tokens = self.output_norm(tokens)
142
+ flow = self.flow_head(tokens)
143
+
144
+ return flow, None
145
+
146
+
147
+ class CorrBlock:
148
+ def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
149
+ """
150
+ Build a pyramid of feature maps from the input.
151
+
152
+ fmaps: Tensor (B, S, C, H, W)
153
+ num_levels: number of pyramid levels (each downsampled by factor 2)
154
+ radius: search radius for sampling correlation
155
+ multiple_track_feats: if True, split the target features per pyramid level
156
+ padding_mode: passed to grid_sample / bilinear_sampler
157
+ """
158
+ B, S, C, H, W = fmaps.shape
159
+ self.S, self.C, self.H, self.W = S, C, H, W
160
+ self.num_levels = num_levels
161
+ self.radius = radius
162
+ self.padding_mode = padding_mode
163
+ self.multiple_track_feats = multiple_track_feats
164
+
165
+ # Build pyramid: each level is half the spatial resolution of the previous
166
+ self.fmaps_pyramid = [fmaps] # level 0 is full resolution
167
+ current_fmaps = fmaps
168
+ for i in range(num_levels - 1):
169
+ B, S, C, H, W = current_fmaps.shape
170
+ # Merge batch & sequence dimensions
171
+ current_fmaps = current_fmaps.reshape(B * S, C, H, W)
172
+ # Avg pool down by factor 2
173
+ current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
174
+ _, _, H_new, W_new = current_fmaps.shape
175
+ current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
176
+ self.fmaps_pyramid.append(current_fmaps)
177
+
178
+ # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
179
+ # This grid is added to the (scaled) coordinate centroids.
180
+ r = self.radius
181
+ dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
182
+ dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
183
+ # delta: for every (dy,dx) displacement (i.e. Δx, Δy)
184
+ self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
185
+
186
+ def corr_sample(self, targets, coords):
187
+ """
188
+ Instead of storing the entire correlation pyramid, we compute each level's correlation
189
+ volume, sample it immediately, then discard it. This saves GPU memory.
190
+
191
+ Args:
192
+ targets: Tensor (B, S, N, C) — features for the current targets.
193
+ coords: Tensor (B, S, N, 2) — coordinates at full resolution.
194
+
195
+ Returns:
196
+ Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
197
+ """
198
+ B, S, N, C = targets.shape
199
+
200
+ # If you have multiple track features, split them per level.
201
+ if self.multiple_track_feats:
202
+ targets_split = torch.split(targets, C // self.num_levels, dim=-1)
203
+
204
+ out_pyramid = []
205
+ for i, fmaps in enumerate(self.fmaps_pyramid):
206
+ # Get current spatial resolution H, W for this pyramid level.
207
+ B, S, C, H, W = fmaps.shape
208
+ # Reshape feature maps for correlation computation:
209
+ # fmap2s: (B, S, C, H*W)
210
+ fmap2s = fmaps.view(B, S, C, H * W)
211
+ # Choose appropriate target features.
212
+ fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
213
+
214
+ # Compute correlation directly
215
+ corrs = compute_corr_level(fmap1, fmap2s, C)
216
+ corrs = corrs.view(B, S, N, H, W)
217
+
218
+ # Prepare sampling grid:
219
+ # Scale down the coordinates for the current level.
220
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
221
+ # Make sure our precomputed delta grid is on the same device/dtype.
222
+ delta_lvl = self.delta.to(coords.device).to(coords.dtype)
223
+ # Now the grid for grid_sample is:
224
+ # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
225
+ coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
226
+
227
+ # Sample from the correlation volume using bilinear interpolation.
228
+ # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
229
+ corrs_sampled = bilinear_sampler(
230
+ corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
231
+ )
232
+ # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
233
+ corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
234
+ out_pyramid.append(corrs_sampled)
235
+
236
+ # Concatenate all levels along the last dimension.
237
+ out = torch.cat(out_pyramid, dim=-1).contiguous()
238
+ return out
239
+
240
+
241
+ def compute_corr_level(fmap1, fmap2s, C):
242
+ # fmap1: (B, S, N, C)
243
+ # fmap2s: (B, S, C, H*W)
244
+ corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
245
+ corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
246
+ return corrs / math.sqrt(C)
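
The memory-saving trick in CorrBlock (build each correlation level on the fly, sample it, discard it) can be exercised in isolation; the sizes here are arbitrary toy values chosen only for illustration.

    import torch
    from vggt.heads.track_modules.blocks import CorrBlock

    B, S, N, C, H, W = 1, 3, 5, 32, 16, 16      # toy sizes (assumed)
    fmaps = torch.randn(B, S, C, H, W)
    corr = CorrBlock(fmaps, num_levels=3, radius=3)

    targets = torch.randn(B, S, N, C)           # per-track target features
    coords = torch.rand(B, S, N, 2) * (H - 1)   # xy at full feature resolution
    out = corr.corr_sample(targets, coords)
    print(out.shape)                            # (1, 3, 5, 3 * 7 * 7) = (1, 3, 5, 147)

Each pyramid level contributes a (2*radius+1)^2 window of sampled correlations, concatenated along the last dimension.
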
vggt/heads/track_modules/modules.py ADDED
@@ -0,0 +1,218 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from functools import partial
12
+ from typing import Callable
13
+ import collections
14
+ from torch import Tensor
15
+ from itertools import repeat
16
+
17
+
18
+ # From PyTorch internals
19
+ def _ntuple(n):
20
+ def parse(x):
21
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
22
+ return tuple(x)
23
+ return tuple(repeat(x, n))
24
+
25
+ return parse
26
+
27
+
28
+ def exists(val):
29
+ return val is not None
30
+
31
+
32
+ def default(val, d):
33
+ return val if exists(val) else d
34
+
35
+
36
+ to_2tuple = _ntuple(2)
37
+
38
+
39
+ class ResidualBlock(nn.Module):
40
+ """
41
+ ResidualBlock: construct a block of two conv layers with residual connections
42
+ """
43
+
44
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
45
+ super(ResidualBlock, self).__init__()
46
+
47
+ self.conv1 = nn.Conv2d(
48
+ in_planes,
49
+ planes,
50
+ kernel_size=kernel_size,
51
+ padding=1,
52
+ stride=stride,
53
+ padding_mode="zeros",
54
+ )
55
+ self.conv2 = nn.Conv2d(
56
+ planes,
57
+ planes,
58
+ kernel_size=kernel_size,
59
+ padding=1,
60
+ padding_mode="zeros",
61
+ )
62
+ self.relu = nn.ReLU(inplace=True)
63
+
64
+ num_groups = planes // 8
65
+
66
+ if norm_fn == "group":
67
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
68
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
69
+ if not stride == 1:
70
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
71
+
72
+ elif norm_fn == "batch":
73
+ self.norm1 = nn.BatchNorm2d(planes)
74
+ self.norm2 = nn.BatchNorm2d(planes)
75
+ if not stride == 1:
76
+ self.norm3 = nn.BatchNorm2d(planes)
77
+
78
+ elif norm_fn == "instance":
79
+ self.norm1 = nn.InstanceNorm2d(planes)
80
+ self.norm2 = nn.InstanceNorm2d(planes)
81
+ if not stride == 1:
82
+ self.norm3 = nn.InstanceNorm2d(planes)
83
+
84
+ elif norm_fn == "none":
85
+ self.norm1 = nn.Sequential()
86
+ self.norm2 = nn.Sequential()
87
+ if not stride == 1:
88
+ self.norm3 = nn.Sequential()
89
+ else:
90
+ raise NotImplementedError
91
+
92
+ if stride == 1:
93
+ self.downsample = None
94
+ else:
95
+ self.downsample = nn.Sequential(
96
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
97
+ self.norm3,
98
+ )
99
+
100
+ def forward(self, x):
101
+ y = x
102
+ y = self.relu(self.norm1(self.conv1(y)))
103
+ y = self.relu(self.norm2(self.conv2(y)))
104
+
105
+ if self.downsample is not None:
106
+ x = self.downsample(x)
107
+
108
+ return self.relu(x + y)
109
+
110
+
111
+ class Mlp(nn.Module):
112
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
113
+
114
+ def __init__(
115
+ self,
116
+ in_features,
117
+ hidden_features=None,
118
+ out_features=None,
119
+ act_layer=nn.GELU,
120
+ norm_layer=None,
121
+ bias=True,
122
+ drop=0.0,
123
+ use_conv=False,
124
+ ):
125
+ super().__init__()
126
+ out_features = out_features or in_features
127
+ hidden_features = hidden_features or in_features
128
+ bias = to_2tuple(bias)
129
+ drop_probs = to_2tuple(drop)
130
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
131
+
132
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
133
+ self.act = act_layer()
134
+ self.drop1 = nn.Dropout(drop_probs[0])
135
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
136
+ self.drop2 = nn.Dropout(drop_probs[1])
137
+
138
+ def forward(self, x):
139
+ x = self.fc1(x)
140
+ x = self.act(x)
141
+ x = self.drop1(x)
142
+ x = self.fc2(x)
143
+ x = self.drop2(x)
144
+ return x
145
+
146
+
147
+ class AttnBlock(nn.Module):
148
+ def __init__(
149
+ self,
150
+ hidden_size,
151
+ num_heads,
152
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
153
+ mlp_ratio=4.0,
154
+ **block_kwargs
155
+ ):
156
+ """
157
+ Self attention block
158
+ """
159
+ super().__init__()
160
+
161
+ self.norm1 = nn.LayerNorm(hidden_size)
162
+ self.norm2 = nn.LayerNorm(hidden_size)
163
+
164
+ self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
165
+
166
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
167
+
168
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
169
+
170
+ def forward(self, x, mask=None):
171
+ # Prepare the mask for PyTorch's attention (it expects a different format)
172
+ # attn_mask = mask if mask is not None else None
173
+ # Normalize before attention
174
+ x = self.norm1(x)
175
+
176
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
177
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
178
+
179
+ attn_output, _ = self.attn(x, x, x)
180
+
181
+ # Add & Norm
182
+ x = x + attn_output
183
+ x = x + self.mlp(self.norm2(x))
184
+ return x
185
+
186
+
187
+ class CrossAttnBlock(nn.Module):
188
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
189
+ """
190
+ Cross attention block
191
+ """
192
+ super().__init__()
193
+
194
+ self.norm1 = nn.LayerNorm(hidden_size)
195
+ self.norm_context = nn.LayerNorm(hidden_size)
196
+ self.norm2 = nn.LayerNorm(hidden_size)
197
+
198
+ self.cross_attn = nn.MultiheadAttention(
199
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
200
+ )
201
+
202
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
203
+
204
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
205
+
206
+ def forward(self, x, context, mask=None):
207
+ # Normalize inputs
208
+ x = self.norm1(x)
209
+ context = self.norm_context(context)
210
+
211
+ # Apply cross attention
212
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
213
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
214
+
215
+ # Add & Norm
216
+ x = x + attn_output
217
+ x = x + self.mlp(self.norm2(x))
218
+ return x
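
The blocks above are plain pre-norm attention/MLP residual units built around nn.MultiheadAttention; a minimal shape check, with toy sizes chosen for illustration, might look like this:

    import torch
    from vggt.heads.track_modules.modules import AttnBlock, CrossAttnBlock

    x = torch.randn(2, 10, 64)                  # (batch, tokens, hidden), toy sizes
    ctx = torch.randn(2, 7, 64)                 # (batch, context tokens, hidden)

    self_block = AttnBlock(hidden_size=64, num_heads=4)
    cross_block = CrossAttnBlock(hidden_size=64, context_dim=64, num_heads=4)

    print(self_block(x).shape)                  # torch.Size([2, 10, 64])
    print(cross_block(x, ctx).shape)            # torch.Size([2, 10, 64])
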
vggt/heads/track_modules/utils.py ADDED
@@ -0,0 +1,226 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from https://github.com/facebookresearch/vggsfm
8
+ # and https://github.com/facebookresearch/co-tracker/tree/main
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from typing import Optional, Tuple, Union
16
+
17
+
18
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
19
+ """
20
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
21
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
22
+ Args:
23
+ - embed_dim: The embedding dimension.
24
+ - grid_size: The grid size.
25
+ Returns:
26
+ - pos_embed: The generated 2D positional embedding.
27
+ """
28
+ if isinstance(grid_size, tuple):
29
+ grid_size_h, grid_size_w = grid_size
30
+ else:
31
+ grid_size_h = grid_size_w = grid_size
32
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
33
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
34
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
35
+ grid = torch.stack(grid, dim=0)
36
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
37
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
38
+ if return_grid:
39
+ return (
40
+ pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
41
+ grid,
42
+ )
43
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
44
+
45
+
46
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
47
+ """
48
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
49
+
50
+ Args:
51
+ - embed_dim: The embedding dimension.
52
+ - grid: The grid to generate the embedding from.
53
+
54
+ Returns:
55
+ - emb: The generated 2D positional embedding.
56
+ """
57
+ assert embed_dim % 2 == 0
58
+
59
+ # use half of dimensions to encode grid_h
60
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
61
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
62
+
63
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
64
+ return emb
65
+
66
+
67
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
68
+ """
69
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
70
+
71
+ Args:
72
+ - embed_dim: The embedding dimension.
73
+ - pos: The position to generate the embedding from.
74
+
75
+ Returns:
76
+ - emb: The generated 1D positional embedding.
77
+ """
78
+ assert embed_dim % 2 == 0
79
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
80
+ omega /= embed_dim / 2.0
81
+ omega = 1.0 / 10000**omega # (D/2,)
82
+
83
+ pos = pos.reshape(-1) # (M,)
84
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
85
+
86
+ emb_sin = torch.sin(out) # (M, D/2)
87
+ emb_cos = torch.cos(out) # (M, D/2)
88
+
89
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
90
+ return emb[None].float()
91
+
92
+
93
+ def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
94
+ """
95
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
96
+
97
+ Args:
98
+ - xy: The coordinates to generate the embedding from.
99
+ - C: The size of the embedding.
100
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
101
+
102
+ Returns:
103
+ - pe: The generated 2D positional embedding.
104
+ """
105
+ B, N, D = xy.shape
106
+ assert D == 2
107
+
108
+ x = xy[:, :, 0:1]
109
+ y = xy[:, :, 1:2]
110
+ div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
111
+
112
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
113
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
114
+
115
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
116
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
117
+
118
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
119
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
120
+
121
+ pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, 2*C)
122
+ if cat_coords:
123
+ pe = torch.cat([xy, pe], dim=2) # (B, N, 2*C+2)
124
+ return pe
125
+
126
+
127
+ def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
128
+ r"""Sample a tensor using bilinear interpolation
129
+
130
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
131
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
132
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
133
+ convention.
134
+
135
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
136
+ :math:`B` is the batch size, :math:`C` is the number of channels,
137
+ :math:`H` is the height of the image, and :math:`W` is the width of the
138
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
139
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
140
+
141
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
142
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
143
+ that in this case the order of the components is slightly different
144
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
145
+
146
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
147
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
148
+ left-most image pixel :math:`W-1` to the center of the right-most
149
+ pixel.
150
+
151
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
152
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
153
+ the left-most pixel :math:`W` to the right edge of the right-most
154
+ pixel.
155
+
156
+ Similar conventions apply to the :math:`y` for the range
157
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
158
+ :math:`[0,T-1]` and :math:`[0,T]`.
159
+
160
+ Args:
161
+ input (Tensor): batch of input images.
162
+ coords (Tensor): batch of coordinates.
163
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
164
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
165
+
166
+ Returns:
167
+ Tensor: sampled points.
168
+ """
169
+ coords = coords.detach().clone()
170
+ ############################################################
171
+ # IMPORTANT:
172
+ coords = coords.to(input.device).to(input.dtype)
173
+ ############################################################
174
+
175
+ sizes = input.shape[2:]
176
+
177
+ assert len(sizes) in [2, 3]
178
+
179
+ if len(sizes) == 3:
180
+ # t x y -> x y t to match dimensions T H W in grid_sample
181
+ coords = coords[..., [1, 2, 0]]
182
+
183
+ if align_corners:
184
+ scale = torch.tensor(
185
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
186
+ )
187
+ else:
188
+ scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
189
+
190
+ coords.mul_(scale) # coords = coords * scale
191
+ coords.sub_(1) # coords = coords - 1
192
+
193
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
194
+
195
+
196
+ def sample_features4d(input, coords):
197
+ r"""Sample spatial features
198
+
199
+ `sample_features4d(input, coords)` samples the spatial features
200
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
201
+
202
+ The field is sampled at coordinates :attr:`coords` using bilinear
203
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
204
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
205
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
206
+
207
+ The output tensor has one feature per point, and has shape :math:`(B,
208
+ R, C)`.
209
+
210
+ Args:
211
+ input (Tensor): spatial features.
212
+ coords (Tensor): points.
213
+
214
+ Returns:
215
+ Tensor: sampled features.
216
+ """
217
+
218
+ B, _, _, _ = input.shape
219
+
220
+ # B R 2 -> B R 1 2
221
+ coords = coords.unsqueeze(2)
222
+
223
+ # B C R 1
224
+ feats = bilinear_sampler(input, coords)
225
+
226
+ return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
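
The embedding and sampling helpers above are shape-preserving in predictable ways; a short sketch with assumed toy inputs:

    import torch
    from vggt.heads.track_modules.utils import (
        get_2d_embedding, get_2d_sincos_pos_embed, sample_features4d)

    xy = torch.rand(2, 6, 2) * 100              # toy coordinates (assumed)
    print(get_2d_embedding(xy, C=64).shape)     # (2, 6, 130): sin/cos per axis + raw xy

    pos = get_2d_sincos_pos_embed(embed_dim=128, grid_size=(8, 12))
    print(pos.shape)                            # torch.Size([1, 128, 8, 12])

    feat_map = torch.randn(2, 32, 16, 16)       # (B, C, H, W)
    pts = torch.rand(2, 5, 2) * 15              # (B, R, 2) in pixel coordinates
    print(sample_features4d(feat_map, pts).shape)  # torch.Size([2, 5, 32])
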
vggt/heads/utils.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
12
+ """
13
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
14
+
15
+ Args:
16
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
17
+ embed_dim: Output channel dimension for embeddings
18
+
19
+ Returns:
20
+ Tensor of shape (H, W, embed_dim) with positional embeddings
21
+ """
22
+ H, W, grid_dim = pos_grid.shape
23
+ assert grid_dim == 2
24
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
25
+
26
+ # Process x and y coordinates separately
27
+ emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
28
+ emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
29
+
30
+ # Combine and reshape
31
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
32
+
33
+ return emb.view(H, W, embed_dim) # [H, W, D]
34
+
35
+
36
+ def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
37
+ """
38
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
39
+
40
+ Args:
41
+ - embed_dim: The embedding dimension.
42
+ - pos: The position to generate the embedding from.
43
+
44
+ Returns:
45
+ - emb: The generated 1D positional embedding.
46
+ """
47
+ assert embed_dim % 2 == 0
48
+ omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
49
+ omega /= embed_dim / 2.0
50
+ omega = 1.0 / omega_0**omega # (D/2,)
51
+
52
+ pos = pos.reshape(-1) # (M,)
53
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
54
+
55
+ emb_sin = torch.sin(out) # (M, D/2)
56
+ emb_cos = torch.cos(out) # (M, D/2)
57
+
58
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
59
+ return emb.float()
60
+
61
+
62
+ # Inspired by https://github.com/microsoft/moge
63
+
64
+
65
+ def create_uv_grid(
66
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
67
+ ) -> torch.Tensor:
68
+ """
69
+ Create a normalized UV grid of shape (width, height, 2).
70
+
71
+ The grid spans horizontally and vertically according to an aspect ratio,
72
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
73
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
74
+
75
+ Args:
76
+ width (int): Number of points horizontally.
77
+ height (int): Number of points vertically.
78
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
79
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
80
+ device (torch.device, optional): Device on which the tensor is created.
81
+
82
+ Returns:
83
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
84
+ """
85
+ # Derive aspect ratio if not explicitly provided
86
+ if aspect_ratio is None:
87
+ aspect_ratio = float(width) / float(height)
88
+
89
+ # Compute normalized spans for X and Y
90
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
91
+ span_x = aspect_ratio / diag_factor
92
+ span_y = 1.0 / diag_factor
93
+
94
+ # Establish the linspace boundaries
95
+ left_x = -span_x * (width - 1) / width
96
+ right_x = span_x * (width - 1) / width
97
+ top_y = -span_y * (height - 1) / height
98
+ bottom_y = span_y * (height - 1) / height
99
+
100
+ # Generate 1D coordinates
101
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
102
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
103
+
104
+ # Create 2D meshgrid (width x height) and stack into UV
105
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
106
+ uv_grid = torch.stack((uu, vv), dim=-1)
107
+
108
+ return uv_grid
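
A small sketch of how the two helpers compose (a normalized UV grid fed into the sinusoidal embedding); the grid size is an arbitrary assumption for illustration:

    import torch
    from vggt.heads.utils import create_uv_grid, position_grid_to_embed

    uv = create_uv_grid(width=16, height=16, dtype=torch.float32)
    print(uv.shape, float(uv.abs().max()) < 1.0)   # torch.Size([16, 16, 2]) True

    emb = position_grid_to_embed(uv, embed_dim=64)
    print(emb.shape)                               # torch.Size([16, 16, 64])
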
vggt/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
vggt/layers/attention.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+
18
+ XFORMERS_AVAILABLE = False
19
+
20
+
21
+ class Attention(nn.Module):
22
+ def __init__(
23
+ self,
24
+ dim: int,
25
+ num_heads: int = 8,
26
+ qkv_bias: bool = True,
27
+ proj_bias: bool = True,
28
+ attn_drop: float = 0.0,
29
+ proj_drop: float = 0.0,
30
+ norm_layer: nn.Module = nn.LayerNorm,
31
+ qk_norm: bool = False,
32
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
33
+ rope=None,
34
+ ) -> None:
35
+ super().__init__()
36
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
37
+ self.num_heads = num_heads
38
+ self.head_dim = dim // num_heads
39
+ self.scale = self.head_dim**-0.5
40
+ self.fused_attn = fused_attn
41
+
42
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
43
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
44
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+ self.rope = rope
49
+
50
+ def forward(self, x: Tensor, pos=None) -> Tensor:
51
+ B, N, C = x.shape
52
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
53
+ q, k, v = qkv.unbind(0)
54
+ q, k = self.q_norm(q), self.k_norm(k)
55
+
56
+ if self.rope is not None:
57
+ q = self.rope(q, pos)
58
+ k = self.rope(k, pos)
59
+
60
+ if self.fused_attn:
61
+ x = F.scaled_dot_product_attention(
62
+ q,
63
+ k,
64
+ v,
65
+ dropout_p=self.attn_drop.p if self.training else 0.0,
66
+ )
67
+ else:
68
+ q = q * self.scale
69
+ attn = q @ k.transpose(-2, -1)
70
+ attn = attn.softmax(dim=-1)
71
+ attn = self.attn_drop(attn)
72
+ x = attn @ v
73
+
74
+ x = x.transpose(1, 2).reshape(B, N, C)
75
+ x = self.proj(x)
76
+ x = self.proj_drop(x)
77
+ return x
78
+
79
+
80
+ class MemEffAttention(Attention):
81
+ def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
82
+ assert pos is None
83
+ if not XFORMERS_AVAILABLE:
84
+ if attn_bias is not None:
85
+ raise AssertionError("xFormers is required for using nested tensors")
86
+ return super().forward(x)
87
+
88
+ B, N, C = x.shape
89
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
90
+
91
+ q, k, v = unbind(qkv, 2)
92
+
93
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
94
+ x = x.reshape([B, N, C])
95
+
96
+ x = self.proj(x)
97
+ x = self.proj_drop(x)
98
+ return x
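
Attention here is standard multi-head self-attention with optional QK-norm and a fused/naive switch; both paths return the input shape. The sizes below are illustrative toy values.

    import torch
    from vggt.layers.attention import Attention

    x = torch.randn(2, 10, 64)                   # (batch, tokens, dim), toy sizes
    fused = Attention(dim=64, num_heads=4, qk_norm=True)
    naive = Attention(dim=64, num_heads=4, fused_attn=False)  # explicit softmax(QK^T/sqrt(d))V

    print(fused(x).shape)                        # torch.Size([2, 10, 64])
    print(naive(x).shape)                        # torch.Size([2, 10, 64])
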
vggt/layers/block.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ XFORMERS_AVAILABLE = False
25
+
26
+
27
+ class Block(nn.Module):
28
+ def __init__(
29
+ self,
30
+ dim: int,
31
+ num_heads: int,
32
+ mlp_ratio: float = 4.0,
33
+ qkv_bias: bool = True,
34
+ proj_bias: bool = True,
35
+ ffn_bias: bool = True,
36
+ drop: float = 0.0,
37
+ attn_drop: float = 0.0,
38
+ init_values=None,
39
+ drop_path: float = 0.0,
40
+ act_layer: Callable[..., nn.Module] = nn.GELU,
41
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
42
+ attn_class: Callable[..., nn.Module] = Attention,
43
+ ffn_layer: Callable[..., nn.Module] = Mlp,
44
+ qk_norm: bool = False,
45
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
46
+ rope=None,
47
+ ) -> None:
48
+ super().__init__()
49
+
50
+ self.norm1 = norm_layer(dim)
51
+
52
+ self.attn = attn_class(
53
+ dim,
54
+ num_heads=num_heads,
55
+ qkv_bias=qkv_bias,
56
+ proj_bias=proj_bias,
57
+ attn_drop=attn_drop,
58
+ proj_drop=drop,
59
+ qk_norm=qk_norm,
60
+ fused_attn=fused_attn,
61
+ rope=rope,
62
+ )
63
+
64
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
65
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
66
+
67
+ self.norm2 = norm_layer(dim)
68
+ mlp_hidden_dim = int(dim * mlp_ratio)
69
+ self.mlp = ffn_layer(
70
+ in_features=dim,
71
+ hidden_features=mlp_hidden_dim,
72
+ act_layer=act_layer,
73
+ drop=drop,
74
+ bias=ffn_bias,
75
+ )
76
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
77
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
78
+
79
+ self.sample_drop_ratio = drop_path
80
+
81
+ def forward(self, x: Tensor, pos=None) -> Tensor:
82
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
83
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
84
+
85
+ def ffn_residual_func(x: Tensor) -> Tensor:
86
+ return self.ls2(self.mlp(self.norm2(x)))
87
+
88
+ if self.training and self.sample_drop_ratio > 0.1:
89
+ # the overhead is compensated only for a drop path rate larger than 0.1
90
+ x = drop_add_residual_stochastic_depth(
91
+ x,
92
+ pos=pos,
93
+ residual_func=attn_residual_func,
94
+ sample_drop_ratio=self.sample_drop_ratio,
95
+ )
96
+ x = drop_add_residual_stochastic_depth(
97
+ x,
98
+ residual_func=ffn_residual_func,
99
+ sample_drop_ratio=self.sample_drop_ratio,
100
+ )
101
+ elif self.training and self.sample_drop_ratio > 0.0:
102
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
103
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104
+ else:
105
+ x = x + attn_residual_func(x, pos=pos)
106
+ x = x + ffn_residual_func(x)
107
+ return x
108
+
109
+
110
+ def drop_add_residual_stochastic_depth(
111
+ x: Tensor,
112
+ residual_func: Callable[[Tensor], Tensor],
113
+ sample_drop_ratio: float = 0.0,
114
+ pos=None,
115
+ ) -> Tensor:
116
+ # 1) extract subset using permutation
117
+ b, n, d = x.shape
118
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
119
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
120
+ x_subset = x[brange]
121
+
122
+ # 2) apply residual_func to get residual
123
+ if pos is not None:
124
+ # if necessary, apply rope to the subset
125
+ pos = pos[brange]
126
+ residual = residual_func(x_subset, pos=pos)
127
+ else:
128
+ residual = residual_func(x_subset)
129
+
130
+ x_flat = x.flatten(1)
131
+ residual = residual.flatten(1)
132
+
133
+ residual_scale_factor = b / sample_subset_size
134
+
135
+ # 3) add the residual
136
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
137
+ return x_plus_residual.view_as(x)
138
+
139
+
140
+ def get_branges_scales(x, sample_drop_ratio=0.0):
141
+ b, n, d = x.shape
142
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
143
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
144
+ residual_scale_factor = b / sample_subset_size
145
+ return brange, residual_scale_factor
146
+
147
+
148
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
149
+ if scaling_vector is None:
150
+ x_flat = x.flatten(1)
151
+ residual = residual.flatten(1)
152
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
153
+ else:
154
+ x_plus_residual = scaled_index_add(
155
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
156
+ )
157
+ return x_plus_residual
158
+
159
+
160
+ attn_bias_cache: Dict[Tuple, Any] = {}
161
+
162
+
163
+ def get_attn_bias_and_cat(x_list, branges=None):
164
+ """
165
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
166
+ """
167
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
168
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
169
+ if all_shapes not in attn_bias_cache.keys():
170
+ seqlens = []
171
+ for b, x in zip(batch_sizes, x_list):
172
+ for _ in range(b):
173
+ seqlens.append(x.shape[1])
174
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
175
+ attn_bias._batch_sizes = batch_sizes
176
+ attn_bias_cache[all_shapes] = attn_bias
177
+
178
+ if branges is not None:
179
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
180
+ else:
181
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
182
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
183
+
184
+ return attn_bias_cache[all_shapes], cat_tensors
185
+
186
+
187
+ def drop_add_residual_stochastic_depth_list(
188
+ x_list: List[Tensor],
189
+ residual_func: Callable[[Tensor, Any], Tensor],
190
+ sample_drop_ratio: float = 0.0,
191
+ scaling_vector=None,
192
+ ) -> Tensor:
193
+ # 1) generate random set of indices for dropping samples in the batch
194
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
195
+ branges = [s[0] for s in branges_scales]
196
+ residual_scale_factors = [s[1] for s in branges_scales]
197
+
198
+ # 2) get attention bias and index+concat the tensors
199
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
200
+
201
+ # 3) apply residual_func to get residual, and split the result
202
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
203
+
204
+ outputs = []
205
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
206
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
207
+ return outputs
208
+
209
+
210
+ class NestedTensorBlock(Block):
211
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
212
+ """
213
+ x_list contains a list of tensors to nest together and run
214
+ """
215
+ assert isinstance(self.attn, MemEffAttention)
216
+
217
+ if self.training and self.sample_drop_ratio > 0.0:
218
+
219
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
220
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
221
+
222
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
223
+ return self.mlp(self.norm2(x))
224
+
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=attn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
230
+ )
231
+ x_list = drop_add_residual_stochastic_depth_list(
232
+ x_list,
233
+ residual_func=ffn_residual_func,
234
+ sample_drop_ratio=self.sample_drop_ratio,
235
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
236
+ )
237
+ return x_list
238
+ else:
239
+
240
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
241
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
242
+
243
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
244
+ return self.ls2(self.mlp(self.norm2(x)))
245
+
246
+ attn_bias, x = get_attn_bias_and_cat(x_list)
247
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
248
+ x = x + ffn_residual_func(x)
249
+ return attn_bias.split(x)
250
+
251
+ def forward(self, x_or_x_list):
252
+ if isinstance(x_or_x_list, Tensor):
253
+ return super().forward(x_or_x_list)
254
+ elif isinstance(x_or_x_list, list):
255
+ if not XFORMERS_AVAILABLE:
256
+ raise AssertionError("xFormers is required for using nested tensors")
257
+ return self.forward_nested(x_or_x_list)
258
+ else:
259
+ raise AssertionError
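
With the default init_values=None and drop_path=0.0, Block reduces to the plain pre-norm residual form x + attn(norm(x)) followed by x + mlp(norm(x)). A toy forward pass, assuming the sibling modules it imports (layer_scale.py, mlp.py) ship with the rest of this commit as the truncated file list suggests:

    import torch
    from vggt.layers.block import Block

    blk = Block(dim=64, num_heads=4)             # toy width/heads (assumed)
    x = torch.randn(2, 10, 64)
    print(blk(x).shape)                          # torch.Size([2, 10, 64])
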
vggt/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
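
DropPath zeroes whole samples on the residual branch during training and rescales the survivors by 1/keep_prob, so the expected value is unchanged; at eval time it is the identity. A small check with assumed toy values:

    import torch
    from vggt.layers.drop_path import DropPath

    dp = DropPath(drop_prob=0.2)
    x = torch.ones(8, 4, 16)

    dp.train()
    print(dp(x)[:, 0, 0])            # each entry is either 0.0 or 1/0.8 = 1.25

    dp.eval()
    print(torch.equal(dp(x), x))     # True: identity at inference time
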