Migrated from GitHub
Browse files- CONTRIBUTING.md +34 -0
- LICENSE +228 -0
- acc_configs/gpu1.yaml +15 -0
- acc_configs/gpu4.yaml +15 -0
- acc_configs/gpu6.yaml +15 -0
- acc_configs/gpu8.yaml +15 -0
- assets/teaser.jpg +0 -0
- blender_scripts/render_objaverse.py +537 -0
- core/__init__.py +0 -0
- core/attention.py +85 -0
- core/gs.py +198 -0
- core/models.py +195 -0
- core/options.py +128 -0
- core/provider_objaverse_4d.py +254 -0
- core/provider_objaverse_4d_interp.py +283 -0
- core/unet.py +428 -0
- core/utils.py +109 -0
- data_test/000000_fg.mp4 +0 -0
- data_test/000070_fg.mp4 +0 -0
- data_test/000370_fg.mp4 +0 -0
- data_test/blooming_rose_fg.mp4 +0 -0
- data_test/cat_king_fg.mp4 +0 -0
- data_test/dancing_robot_fg.mp4 +0 -0
- data_test/lifting1_fg.mp4 +0 -0
- data_test/monster-with-melting-candle_fg.mp4 +0 -0
- data_test/otter-on-surfboard_fg.mp4 +0 -0
- data_test/sighing_frog_fg.mp4 +0 -0
- environment.yml +32 -0
- infer_3d.py +233 -0
- infer_4d.py +392 -0
- main.py +274 -0
- mvdream/mv_unet.py +1005 -0
- mvdream/pipeline_mvdream.py +559 -0
- readme.md +77 -0
- requirements.txt +20 -0
CONTRIBUTING.md
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Developer Certificate of Origin
|
2 |
+
Version 1.1
|
3 |
+
|
4 |
+
Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
|
5 |
+
|
6 |
+
Everyone is permitted to copy and distribute verbatim copies of this
|
7 |
+
license document, but changing it is not allowed.
|
8 |
+
|
9 |
+
|
10 |
+
Developer's Certificate of Origin 1.1
|
11 |
+
|
12 |
+
By making a contribution to this project, I certify that:
|
13 |
+
|
14 |
+
(a) The contribution was created in whole or in part by me and I
|
15 |
+
have the right to submit it under the open source license
|
16 |
+
indicated in the file; or
|
17 |
+
|
18 |
+
(b) The contribution is based upon previous work that, to the best
|
19 |
+
of my knowledge, is covered under an appropriate open source
|
20 |
+
license and I have the right under that license to submit that
|
21 |
+
work with modifications, whether created in whole or in part
|
22 |
+
by me, under the same open source license (unless I am
|
23 |
+
permitted to submit under a different license), as indicated
|
24 |
+
in the file; or
|
25 |
+
|
26 |
+
(c) The contribution was provided directly to me by some other
|
27 |
+
person who certified (a), (b) or (c) and I have not modified
|
28 |
+
it.
|
29 |
+
|
30 |
+
(d) I understand and agree that this project and the contribution
|
31 |
+
are public and that a record of the contribution (including all
|
32 |
+
personal information I submit with it, including my sign-off) is
|
33 |
+
maintained indefinitely and may be redistributed consistent with
|
34 |
+
this project or the open source license(s) involved.
|
LICENSE
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
Apache License
|
3 |
+
Version 2.0, January 2004
|
4 |
+
http://www.apache.org/licenses/
|
5 |
+
|
6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
|
10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
12 |
+
|
13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
14 |
+
the copyright owner that is granting the License.
|
15 |
+
|
16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
17 |
+
other entities that control, are controlled by, or are under common
|
18 |
+
control with that entity. For the purposes of this definition,
|
19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
20 |
+
direction or management of such entity, whether by contract or
|
21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
23 |
+
|
24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
25 |
+
exercising permissions granted by this License.
|
26 |
+
|
27 |
+
"Source" form shall mean the preferred form for making modifications,
|
28 |
+
including but not limited to software source code, documentation
|
29 |
+
source, and configuration files.
|
30 |
+
|
31 |
+
"Object" form shall mean any form resulting from mechanical
|
32 |
+
transformation or translation of a Source form, including but
|
33 |
+
not limited to compiled object code, generated documentation,
|
34 |
+
and conversions to other media types.
|
35 |
+
|
36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
37 |
+
Object form, made available under the License, as indicated by a
|
38 |
+
copyright notice that is included in or attached to the work
|
39 |
+
(an example is provided in the Appendix below).
|
40 |
+
|
41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
42 |
+
form, that is based on (or derived from) the Work and for which the
|
43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
45 |
+
of this License, Derivative Works shall not include works that remain
|
46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
47 |
+
the Work and Derivative Works thereof.
|
48 |
+
|
49 |
+
"Contribution" shall mean any work of authorship, including
|
50 |
+
the original version of the Work and any modifications or additions
|
51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
55 |
+
means any form of electronic, verbal, or written communication sent
|
56 |
+
to the Licensor or its representatives, including but not limited to
|
57 |
+
communication on electronic mailing lists, source code control systems,
|
58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
60 |
+
excluding communication that is conspicuously marked or otherwise
|
61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
62 |
+
|
63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
65 |
+
subsequently incorporated within the Work.
|
66 |
+
|
67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
72 |
+
Work and such Derivative Works in Source or Object form.
|
73 |
+
|
74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
77 |
+
(except as stated in this section) patent license to make, have made,
|
78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
79 |
+
where such license applies only to those patent claims licensable
|
80 |
+
by such Contributor that are necessarily infringed by their
|
81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
83 |
+
institute patent litigation against any entity (including a
|
84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
85 |
+
or a Contribution incorporated within the Work constitutes direct
|
86 |
+
or contributory patent infringement, then any patent licenses
|
87 |
+
granted to You under this License for that Work shall terminate
|
88 |
+
as of the date such litigation is filed.
|
89 |
+
|
90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
91 |
+
Work or Derivative Works thereof in any medium, with or without
|
92 |
+
modifications, and in Source or Object form, provided that You
|
93 |
+
meet the following conditions:
|
94 |
+
|
95 |
+
(a) You must give any other recipients of the Work or
|
96 |
+
Derivative Works a copy of this License; and
|
97 |
+
|
98 |
+
(b) You must cause any modified files to carry prominent notices
|
99 |
+
stating that You changed the files; and
|
100 |
+
|
101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
102 |
+
that You distribute, all copyright, patent, trademark, and
|
103 |
+
attribution notices from the Source form of the Work,
|
104 |
+
excluding those notices that do not pertain to any part of
|
105 |
+
the Derivative Works; and
|
106 |
+
|
107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
108 |
+
distribution, then any Derivative Works that You distribute must
|
109 |
+
include a readable copy of the attribution notices contained
|
110 |
+
within such NOTICE file, excluding those notices that do not
|
111 |
+
pertain to any part of the Derivative Works, in at least one
|
112 |
+
of the following places: within a NOTICE text file distributed
|
113 |
+
as part of the Derivative Works; within the Source form or
|
114 |
+
documentation, if provided along with the Derivative Works; or,
|
115 |
+
within a display generated by the Derivative Works, if and
|
116 |
+
wherever such third-party notices normally appear. The contents
|
117 |
+
of the NOTICE file are for informational purposes only and
|
118 |
+
do not modify the License. You may add Your own attribution
|
119 |
+
notices within Derivative Works that You distribute, alongside
|
120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
121 |
+
that such additional attribution notices cannot be construed
|
122 |
+
as modifying the License.
|
123 |
+
|
124 |
+
You may add Your own copyright statement to Your modifications and
|
125 |
+
may provide additional or different license terms and conditions
|
126 |
+
for use, reproduction, or distribution of Your modifications, or
|
127 |
+
for any such Derivative Works as a whole, provided Your use,
|
128 |
+
reproduction, and distribution of the Work otherwise complies with
|
129 |
+
the conditions stated in this License.
|
130 |
+
|
131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
133 |
+
by You to the Licensor shall be under the terms and conditions of
|
134 |
+
this License, without any additional terms or conditions.
|
135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
136 |
+
the terms of any separate license agreement you may have executed
|
137 |
+
with Licensor regarding such Contributions.
|
138 |
+
|
139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
141 |
+
except as required for reasonable and customary use in describing the
|
142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
143 |
+
|
144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
145 |
+
agreed to in writing, Licensor provides the Work (and each
|
146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
148 |
+
implied, including, without limitation, any warranties or conditions
|
149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
151 |
+
appropriateness of using or redistributing the Work and assume any
|
152 |
+
risks associated with Your exercise of permissions under this License.
|
153 |
+
|
154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
155 |
+
whether in tort (including negligence), contract, or otherwise,
|
156 |
+
unless required by applicable law (such as deliberate and grossly
|
157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
158 |
+
liable to You for damages, including any direct, indirect, special,
|
159 |
+
incidental, or consequential damages of any character arising as a
|
160 |
+
result of this License or out of the use or inability to use the
|
161 |
+
Work (including but not limited to damages for loss of goodwill,
|
162 |
+
work stoppage, computer failure or malfunction, or any and all
|
163 |
+
other commercial damages or losses), even if such Contributor
|
164 |
+
has been advised of the possibility of such damages.
|
165 |
+
|
166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
169 |
+
or other liability obligations and/or rights consistent with this
|
170 |
+
License. However, in accepting such obligations, You may act only
|
171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
172 |
+
of any other Contributor, and only if You agree to indemnify,
|
173 |
+
defend, and hold each Contributor harmless for any liability
|
174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
175 |
+
of your accepting any such warranty or additional liability.
|
176 |
+
|
177 |
+
END OF TERMS AND CONDITIONS
|
178 |
+
|
179 |
+
Copyright 2024 NVIDIA Corporation
|
180 |
+
|
181 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
182 |
+
you may not use this file except in compliance with the License.
|
183 |
+
You may obtain a copy of the License at
|
184 |
+
|
185 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
186 |
+
|
187 |
+
Unless required by applicable law or agreed to in writing, software
|
188 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
189 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
190 |
+
See the License for the specific language governing permissions and
|
191 |
+
limitations under the License.
|
192 |
+
|
193 |
+
|
194 |
+
PORTIONS LICENSED AS FOLLOWS
|
195 |
+
|
196 |
+
> core/utils.py
|
197 |
+
> mvdream/mv_unet.py
|
198 |
+
> mvdream/pipeline_mvdream.py
|
199 |
+
|
200 |
+
MIT License
|
201 |
+
|
202 |
+
Copyright (c) 2024 3D Topia
|
203 |
+
|
204 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
205 |
+
of this software and associated documentation files (the "Software"), to deal
|
206 |
+
in the Software without restriction, including without limitation the rights
|
207 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
208 |
+
copies of the Software, and to permit persons to whom the Software is
|
209 |
+
furnished to do so, subject to the following conditions:
|
210 |
+
|
211 |
+
The above copyright notice and this permission notice shall be included in all
|
212 |
+
copies or substantial portions of the Software.
|
213 |
+
|
214 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
215 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
216 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
217 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
218 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
219 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
220 |
+
SOFTWARE.
|
221 |
+
|
222 |
+
|
223 |
+
> core/attention.py
|
224 |
+
|
225 |
+
Copyright (c) Meta Platforms, Inc. and affiliates.
|
226 |
+
|
227 |
+
This source code is licensed under the Apache License, Version 2.0
|
228 |
+
found in the LICENSE file in the root directory of this source tree.
|
acc_configs/gpu1.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: 'NO'
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
machine_rank: 0
|
6 |
+
main_training_function: main
|
7 |
+
mixed_precision: bf16
|
8 |
+
num_machines: 1
|
9 |
+
num_processes: 1
|
10 |
+
rdzv_backend: static
|
11 |
+
same_network: true
|
12 |
+
tpu_env: []
|
13 |
+
tpu_use_cluster: false
|
14 |
+
tpu_use_sudo: false
|
15 |
+
use_cpu: false
|
acc_configs/gpu4.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: MULTI_GPU
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
machine_rank: 0
|
6 |
+
main_training_function: main
|
7 |
+
mixed_precision: fp16
|
8 |
+
num_machines: 1
|
9 |
+
num_processes: 4
|
10 |
+
rdzv_backend: static
|
11 |
+
same_network: true
|
12 |
+
tpu_env: []
|
13 |
+
tpu_use_cluster: false
|
14 |
+
tpu_use_sudo: false
|
15 |
+
use_cpu: false
|
acc_configs/gpu6.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: MULTI_GPU
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
machine_rank: 0
|
6 |
+
main_training_function: main
|
7 |
+
mixed_precision: fp16
|
8 |
+
num_machines: 1
|
9 |
+
num_processes: 6
|
10 |
+
rdzv_backend: static
|
11 |
+
same_network: true
|
12 |
+
tpu_env: []
|
13 |
+
tpu_use_cluster: false
|
14 |
+
tpu_use_sudo: false
|
15 |
+
use_cpu: false
|
acc_configs/gpu8.yaml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: MULTI_GPU
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
machine_rank: 0
|
6 |
+
main_training_function: main
|
7 |
+
mixed_precision: bf16
|
8 |
+
num_machines: 1
|
9 |
+
num_processes: 8
|
10 |
+
rdzv_backend: static
|
11 |
+
same_network: true
|
12 |
+
tpu_env: []
|
13 |
+
tpu_use_cluster: false
|
14 |
+
tpu_use_sudo: false
|
15 |
+
use_cpu: false
|
assets/teaser.jpg
ADDED
blender_scripts/render_objaverse.py
ADDED
@@ -0,0 +1,537 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import argparse, sys, os, math, re
|
17 |
+
import bpy
|
18 |
+
from mathutils import Vector, Matrix
|
19 |
+
import numpy as np
|
20 |
+
import cv2
|
21 |
+
import signal
|
22 |
+
from contextlib import contextmanager
|
23 |
+
from loguru import logger
|
24 |
+
from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Set, Tuple
|
25 |
+
import random
|
26 |
+
class TimeoutException(Exception): pass
|
27 |
+
|
28 |
+
logger.info('Rendering started.')
|
29 |
+
|
30 |
+
@contextmanager
|
31 |
+
def time_limit(seconds):
|
32 |
+
def signal_handler(signum, frame):
|
33 |
+
raise TimeoutException("Timed out!")
|
34 |
+
signal.signal(signal.SIGALRM, signal_handler)
|
35 |
+
signal.alarm(seconds)
|
36 |
+
try:
|
37 |
+
yield
|
38 |
+
finally:
|
39 |
+
signal.alarm(0)
|
40 |
+
|
41 |
+
parser = argparse.ArgumentParser(description='Renders given obj file by rotation a camera around it.')
|
42 |
+
parser.add_argument(
|
43 |
+
'--seed', type=int, default=0,
|
44 |
+
help='number of views to be rendered')
|
45 |
+
parser.add_argument(
|
46 |
+
'--views', type=int, default=4,
|
47 |
+
help='number of views to be rendered')
|
48 |
+
parser.add_argument(
|
49 |
+
'obj', type=str,
|
50 |
+
help='Path to the obj file to be rendered.')
|
51 |
+
parser.add_argument(
|
52 |
+
'--output_folder', type=str, default='/tmp',
|
53 |
+
help='The path the output will be dumped to.')
|
54 |
+
parser.add_argument(
|
55 |
+
'--scale', type=float, default=1,
|
56 |
+
help='Scaling factor applied to model. Depends on size of mesh.')
|
57 |
+
parser.add_argument(
|
58 |
+
'--format', type=str, default='PNG',
|
59 |
+
help='Format of files generated. Either PNG or OPEN_EXR')
|
60 |
+
|
61 |
+
parser.add_argument(
|
62 |
+
'--resolution', type=int, default=512,
|
63 |
+
help='Resolution of the images.')
|
64 |
+
parser.add_argument(
|
65 |
+
'--engine', type=str, default='CYCLES',
|
66 |
+
help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...')
|
67 |
+
parser.add_argument(
|
68 |
+
'--gpu', type=int, default=0,
|
69 |
+
help='gpu.')
|
70 |
+
parser.add_argument(
|
71 |
+
'--animation_idx', type=int, default=0,
|
72 |
+
help='The index of animation')
|
73 |
+
|
74 |
+
parser.add_argument(
|
75 |
+
'--camera_option', type=str, default='fixed',
|
76 |
+
help='Camera Options')
|
77 |
+
parser.add_argument(
|
78 |
+
'--fixed_animation_length', type=int, default=-1,
|
79 |
+
help='Set animation length to fixed number of framnes')
|
80 |
+
parser.add_argument(
|
81 |
+
'--step_angle', type=int, default=3,
|
82 |
+
help='Angle in degree for each step camera rotation')
|
83 |
+
parser.add_argument(
|
84 |
+
'--downsample', type=int, default=1,
|
85 |
+
help='Downsample ratio. No downsample by default')
|
86 |
+
|
87 |
+
argv = sys.argv[sys.argv.index("--") + 1:]
|
88 |
+
args = parser.parse_args(argv)
|
89 |
+
|
90 |
+
|
91 |
+
model_identifier = os.path.split(args.obj)[1].split('.')[0]
|
92 |
+
synset_idx = args.obj.split('/')[-2]
|
93 |
+
|
94 |
+
save_root = os.path.join(os.path.abspath(args.output_folder), synset_idx, model_identifier, f'{args.animation_idx:03d}')
|
95 |
+
|
96 |
+
# Set up rendering
|
97 |
+
context = bpy.context
|
98 |
+
scene = bpy.context.scene
|
99 |
+
render = bpy.context.scene.render
|
100 |
+
|
101 |
+
render.engine = args.engine# 'BLENDER_EEVEE'
|
102 |
+
render.image_settings.color_mode = 'RGBA' # ('RGB', 'RGBA', ...)
|
103 |
+
render.image_settings.file_format = args.format # ('PNG', 'OPEN_EXR', 'JPEG, ...)
|
104 |
+
render.resolution_x = args.resolution
|
105 |
+
render.resolution_y = args.resolution
|
106 |
+
render.resolution_percentage = 100
|
107 |
+
bpy.context.scene.cycles.filter_width = 0.01
|
108 |
+
bpy.context.scene.render.film_transparent = True
|
109 |
+
render_depth_normal = False
|
110 |
+
bpy.context.scene.cycles.device = 'GPU'
|
111 |
+
bpy.context.scene.cycles.diffuse_bounces = 1
|
112 |
+
bpy.context.scene.cycles.glossy_bounces = 1
|
113 |
+
bpy.context.scene.cycles.transparent_max_bounces = 1
|
114 |
+
bpy.context.scene.cycles.transmission_bounces = 1
|
115 |
+
bpy.context.scene.cycles.samples = 16
|
116 |
+
bpy.context.scene.cycles.use_denoising = True
|
117 |
+
bpy.context.scene.cycles.denoiser = 'OPTIX'
|
118 |
+
bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
|
119 |
+
bpy.context.scene.cycles.device = 'GPU'
|
120 |
+
|
121 |
+
|
122 |
+
def enable_cuda_devices():
|
123 |
+
prefs = bpy.context.preferences
|
124 |
+
cprefs = prefs.addons['cycles'].preferences
|
125 |
+
cprefs.get_devices()
|
126 |
+
# Attempt to set GPU device types if available
|
127 |
+
for compute_device_type in ('CUDA', 'OPENCL', 'NONE'):
|
128 |
+
try:
|
129 |
+
cprefs.compute_device_type = compute_device_type
|
130 |
+
print("Compute device selected: {0}".format(compute_device_type))
|
131 |
+
break
|
132 |
+
except TypeError:
|
133 |
+
pass
|
134 |
+
|
135 |
+
# Any CUDA/OPENCL devices?
|
136 |
+
acceleratedTypes = ['CUDA', 'OPENCL', 'OPTIX']
|
137 |
+
acceleratedTypes = ['CUDA', 'OPENCL']
|
138 |
+
accelerated = any(device.type in acceleratedTypes for device in cprefs.devices)
|
139 |
+
print('Accelerated render = {0}'.format(accelerated))
|
140 |
+
|
141 |
+
# If we have CUDA/OPENCL devices, enable only them, otherwise enable
|
142 |
+
# all devices (assumed to be CPU)
|
143 |
+
print(cprefs.devices)
|
144 |
+
for idx, device in enumerate(cprefs.devices):
|
145 |
+
device.use = (not accelerated or device.type in acceleratedTypes)# and idx == args.gpu
|
146 |
+
print('Device enabled ({type}) = {enabled}'.format(type=device.type, enabled=device.use))
|
147 |
+
return accelerated
|
148 |
+
|
149 |
+
enable_cuda_devices()
|
150 |
+
context.active_object.select_set(True)
|
151 |
+
bpy.ops.object.delete()
|
152 |
+
|
153 |
+
# Import textured mesh
|
154 |
+
bpy.ops.object.select_all(action='DESELECT')
|
155 |
+
|
156 |
+
try:
|
157 |
+
with time_limit(1000):
|
158 |
+
imported_object = bpy.ops.import_scene.gltf(filepath=args.obj, merge_vertices=True, guess_original_bind_pose=False, bone_heuristic="TEMPERANCE")
|
159 |
+
except TimeoutException as e:
|
160 |
+
print("Timed out finished!")
|
161 |
+
exit()
|
162 |
+
|
163 |
+
|
164 |
+
# count animated frames
|
165 |
+
animation_names = []
|
166 |
+
ending_frame_list = {}
|
167 |
+
for k in bpy.data.actions.keys():
|
168 |
+
matched_obj_name = ''
|
169 |
+
for obj in bpy.context.selected_objects:
|
170 |
+
if '_'+obj.name in k and len(obj.name) > len(matched_obj_name):
|
171 |
+
matched_obj_name = obj.name
|
172 |
+
a_name = k.replace('_'+matched_obj_name, '')
|
173 |
+
a = bpy.data.actions[k]
|
174 |
+
frame_start, frame_end = map(int, a.frame_range)
|
175 |
+
logger.info(f'{k} | frame start: {frame_start}, frame end: {frame_end} | fps: {bpy.context.scene.render.fps}')
|
176 |
+
if a_name not in animation_names:
|
177 |
+
animation_names.append(a_name)
|
178 |
+
ending_frame_list[a_name] = frame_end
|
179 |
+
else:
|
180 |
+
ending_frame_list[a_name] = max(frame_end, ending_frame_list[a_name])
|
181 |
+
|
182 |
+
|
183 |
+
|
184 |
+
selected_a_name = animation_names[args.animation_idx]
|
185 |
+
max_frame = ending_frame_list[selected_a_name]
|
186 |
+
for obj in bpy.context.selected_objects:
|
187 |
+
if obj.animation_data is not None:
|
188 |
+
obj_a_name = selected_a_name+'_'+obj.name
|
189 |
+
if obj_a_name in bpy.data.actions:
|
190 |
+
print('Found ', obj_a_name)
|
191 |
+
obj.animation_data.action = bpy.data.actions[obj_a_name]
|
192 |
+
else:
|
193 |
+
print('Miss ', obj_a_name)
|
194 |
+
|
195 |
+
num_frames = args.fixed_animation_length if args.fixed_animation_length != -1 else max_frame
|
196 |
+
num_frames = num_frames // args.downsample
|
197 |
+
|
198 |
+
if num_frames == 0:
|
199 |
+
print("No animation!")
|
200 |
+
exit()
|
201 |
+
|
202 |
+
# from https://github.com/allenai/objaverse-xl/blob/main/scripts/rendering/blender_script.py
|
203 |
+
def get_3x4_RT_matrix_from_blender(cam: bpy.types.Object):
|
204 |
+
"""Returns the 3x4 RT matrix from the given camera.
|
205 |
+
|
206 |
+
Taken from Zero123, which in turn was taken from
|
207 |
+
https://github.com/panmari/stanford-shapenet-renderer/blob/master/render_blender.py
|
208 |
+
|
209 |
+
Args:
|
210 |
+
cam (bpy.types.Object): The camera object.
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
Matrix: The 3x4 RT matrix from the given camera.
|
214 |
+
"""
|
215 |
+
# Use matrix_world instead to account for all constraints
|
216 |
+
location, rotation = cam.matrix_world.decompose()[0:2]
|
217 |
+
R_world2bcam = rotation.to_matrix().transposed()
|
218 |
+
|
219 |
+
# Use location from matrix_world to account for constraints:
|
220 |
+
T_world2bcam = -1 * R_world2bcam @ location
|
221 |
+
|
222 |
+
# put into 3x4 matrix
|
223 |
+
RT = Matrix(
|
224 |
+
(
|
225 |
+
R_world2bcam[0][:] + (T_world2bcam[0],),
|
226 |
+
R_world2bcam[1][:] + (T_world2bcam[1],),
|
227 |
+
R_world2bcam[2][:] + (T_world2bcam[2],),
|
228 |
+
)
|
229 |
+
)
|
230 |
+
return RT
|
231 |
+
def _create_light(
|
232 |
+
name: str,
|
233 |
+
light_type: Literal["POINT", "SUN", "SPOT", "AREA"],
|
234 |
+
location: Tuple[float, float, float],
|
235 |
+
rotation: Tuple[float, float, float],
|
236 |
+
energy: float,
|
237 |
+
use_shadow: bool = False,
|
238 |
+
specular_factor: float = 1.0,
|
239 |
+
):
|
240 |
+
"""Creates a light object.
|
241 |
+
|
242 |
+
Args:
|
243 |
+
name (str): Name of the light object.
|
244 |
+
light_type (Literal["POINT", "SUN", "SPOT", "AREA"]): Type of the light.
|
245 |
+
location (Tuple[float, float, float]): Location of the light.
|
246 |
+
rotation (Tuple[float, float, float]): Rotation of the light.
|
247 |
+
energy (float): Energy of the light.
|
248 |
+
use_shadow (bool, optional): Whether to use shadows. Defaults to False.
|
249 |
+
specular_factor (float, optional): Specular factor of the light. Defaults to 1.0.
|
250 |
+
|
251 |
+
Returns:
|
252 |
+
bpy.types.Object: The light object.
|
253 |
+
"""
|
254 |
+
|
255 |
+
light_data = bpy.data.lights.new(name=name, type=light_type)
|
256 |
+
light_object = bpy.data.objects.new(name, light_data)
|
257 |
+
bpy.context.collection.objects.link(light_object)
|
258 |
+
light_object.location = location
|
259 |
+
light_object.rotation_euler = rotation
|
260 |
+
light_data.use_shadow = use_shadow
|
261 |
+
light_data.specular_factor = specular_factor
|
262 |
+
light_data.energy = energy
|
263 |
+
return light_object
|
264 |
+
|
265 |
+
|
266 |
+
def randomize_lighting() -> Dict[str, bpy.types.Object]:
|
267 |
+
"""Randomizes the lighting in the scene.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
Dict[str, bpy.types.Object]: Dictionary of the lights in the scene. The keys are
|
271 |
+
"key_light", "fill_light", "rim_light", and "bottom_light".
|
272 |
+
"""
|
273 |
+
|
274 |
+
# Clear existing lights
|
275 |
+
bpy.ops.object.select_all(action="DESELECT")
|
276 |
+
bpy.ops.object.select_by_type(type="LIGHT")
|
277 |
+
bpy.ops.object.delete()
|
278 |
+
|
279 |
+
# Create key light
|
280 |
+
key_light = _create_light(
|
281 |
+
name="Key_Light",
|
282 |
+
light_type="SUN",
|
283 |
+
location=(0, 0, 0),
|
284 |
+
rotation=(0.785398, 0, -0.785398),
|
285 |
+
# energy=random.choice([3, 4, 5]),
|
286 |
+
energy=4,
|
287 |
+
)
|
288 |
+
|
289 |
+
# Create fill light
|
290 |
+
fill_light = _create_light(
|
291 |
+
name="Fill_Light",
|
292 |
+
light_type="SUN",
|
293 |
+
location=(0, 0, 0),
|
294 |
+
rotation=(0.785398, 0, 2.35619),
|
295 |
+
# energy=random.choice([2, 3, 4]),
|
296 |
+
energy=3,
|
297 |
+
)
|
298 |
+
|
299 |
+
# Create rim light
|
300 |
+
rim_light = _create_light(
|
301 |
+
name="Rim_Light",
|
302 |
+
light_type="SUN",
|
303 |
+
location=(0, 0, 0),
|
304 |
+
rotation=(-0.785398, 0, -3.92699),
|
305 |
+
# energy=random.choice([3, 4, 5]),
|
306 |
+
energy=4,
|
307 |
+
)
|
308 |
+
|
309 |
+
# Create bottom light
|
310 |
+
bottom_light = _create_light(
|
311 |
+
name="Bottom_Light",
|
312 |
+
light_type="SUN",
|
313 |
+
location=(0, 0, 0),
|
314 |
+
rotation=(3.14159, 0, 0),
|
315 |
+
# energy=random.choice([1, 2, 3]),
|
316 |
+
energy=2,
|
317 |
+
)
|
318 |
+
|
319 |
+
return dict(
|
320 |
+
key_light=key_light,
|
321 |
+
fill_light=fill_light,
|
322 |
+
rim_light=rim_light,
|
323 |
+
bottom_light=bottom_light,
|
324 |
+
)
|
325 |
+
|
326 |
+
def scene_bbox(
|
327 |
+
single_obj = None, ignore_matrix = False
|
328 |
+
):
|
329 |
+
"""Returns the bounding box of the scene.
|
330 |
+
|
331 |
+
Taken from Shap-E rendering script
|
332 |
+
(https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
|
333 |
+
|
334 |
+
Args:
|
335 |
+
single_obj (Optional[bpy.types.Object], optional): If not None, only computes
|
336 |
+
the bounding box for the given object. Defaults to None.
|
337 |
+
ignore_matrix (bool, optional): Whether to ignore the object's matrix. Defaults
|
338 |
+
to False.
|
339 |
+
|
340 |
+
Raises:
|
341 |
+
RuntimeError: If there are no objects in the scene.
|
342 |
+
|
343 |
+
Returns:
|
344 |
+
Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
|
345 |
+
"""
|
346 |
+
bbox_min = (math.inf,) * 3
|
347 |
+
bbox_max = (-math.inf,) * 3
|
348 |
+
found = False
|
349 |
+
for i in range(num_frames):
|
350 |
+
bpy.context.scene.frame_set(i * args.downsample)
|
351 |
+
for obj in get_scene_meshes() if single_obj is None else [single_obj]:
|
352 |
+
found = True
|
353 |
+
for coord in obj.bound_box:
|
354 |
+
coord = Vector(coord)
|
355 |
+
if not ignore_matrix:
|
356 |
+
coord = obj.matrix_world @ coord
|
357 |
+
bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
|
358 |
+
bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
|
359 |
+
|
360 |
+
if not found:
|
361 |
+
raise RuntimeError("no objects in scene to compute bounding box for")
|
362 |
+
|
363 |
+
return Vector(bbox_min), Vector(bbox_max)
|
364 |
+
|
365 |
+
def get_scene_meshes():
|
366 |
+
"""Returns all meshes in the scene.
|
367 |
+
|
368 |
+
Yields:
|
369 |
+
Generator[bpy.types.Object, None, None]: Generator of all meshes in the scene.
|
370 |
+
"""
|
371 |
+
for obj in bpy.context.scene.objects.values():
|
372 |
+
if isinstance(obj.data, (bpy.types.Mesh)):
|
373 |
+
yield obj
|
374 |
+
|
375 |
+
def get_scene_root_objects():
|
376 |
+
"""Returns all root objects in the scene.
|
377 |
+
|
378 |
+
Yields:
|
379 |
+
Generator[bpy.types.Object, None, None]: Generator of all root objects in the
|
380 |
+
scene.
|
381 |
+
"""
|
382 |
+
for obj in bpy.context.scene.objects.values():
|
383 |
+
if not obj.parent:
|
384 |
+
yield obj
|
385 |
+
|
386 |
+
def normalize_scene():
|
387 |
+
"""Normalizes the scene by scaling and translating it to fit in a unit cube centered
|
388 |
+
at the origin.
|
389 |
+
|
390 |
+
Mostly taken from the Point-E / Shap-E rendering script
|
391 |
+
(https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
|
392 |
+
but fix for multiple root objects: (see bug report here:
|
393 |
+
https://github.com/openai/shap-e/pull/60).
|
394 |
+
|
395 |
+
Returns:
|
396 |
+
None
|
397 |
+
"""
|
398 |
+
if len(list(get_scene_root_objects())) > 1:
|
399 |
+
# create an empty object to be used as a parent for all root objects
|
400 |
+
parent_empty = bpy.data.objects.new("ParentEmpty", None)
|
401 |
+
bpy.context.scene.collection.objects.link(parent_empty)
|
402 |
+
|
403 |
+
# parent all root objects to the empty object
|
404 |
+
for obj in get_scene_root_objects():
|
405 |
+
if obj != parent_empty:
|
406 |
+
obj.parent = parent_empty
|
407 |
+
|
408 |
+
bbox_min, bbox_max = scene_bbox()
|
409 |
+
scale = 1 / max(bbox_max - bbox_min)
|
410 |
+
logger.info(f"Scale: {scale}")
|
411 |
+
for obj in get_scene_root_objects():
|
412 |
+
obj.scale = obj.scale * scale
|
413 |
+
|
414 |
+
# Apply scale to matrix_world.
|
415 |
+
bpy.context.view_layer.update()
|
416 |
+
bbox_min, bbox_max = scene_bbox()
|
417 |
+
offset = -(bbox_min + bbox_max) / 2
|
418 |
+
for obj in get_scene_root_objects():
|
419 |
+
obj.matrix_world.translation += offset
|
420 |
+
bpy.ops.object.select_all(action="DESELECT")
|
421 |
+
|
422 |
+
# unparent the camera
|
423 |
+
bpy.data.objects["Camera"].parent = None
|
424 |
+
|
425 |
+
normalize_scene()
|
426 |
+
|
427 |
+
randomize_lighting()
|
428 |
+
|
429 |
+
# Place camera
|
430 |
+
cam = scene.objects['Camera']
|
431 |
+
cam.location = (0, 1.5, 0) # radius equals to 1
|
432 |
+
cam.data.lens = 35
|
433 |
+
cam.data.sensor_width = 32
|
434 |
+
|
435 |
+
cam_constraint = cam.constraints.new(type='TRACK_TO')
|
436 |
+
cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'
|
437 |
+
cam_constraint.up_axis = 'UP_Y'
|
438 |
+
|
439 |
+
cam_empty = bpy.data.objects.new("Empty", None)
|
440 |
+
cam_empty.location = (0, 0, 0)
|
441 |
+
cam.parent = cam_empty
|
442 |
+
|
443 |
+
scene.collection.objects.link(cam_empty)
|
444 |
+
context.view_layer.objects.active = cam_empty
|
445 |
+
cam_constraint.target = cam_empty
|
446 |
+
|
447 |
+
stepsize = 360.0 / args.views
|
448 |
+
rotation_mode = 'XYZ'
|
449 |
+
|
450 |
+
|
451 |
+
np.random.seed(args.seed)
|
452 |
+
|
453 |
+
if args.camera_option == "fixed":
|
454 |
+
for scene in bpy.data.scenes:
|
455 |
+
scene.cycles.device = 'GPU'
|
456 |
+
|
457 |
+
elevation_angle = 0.
|
458 |
+
rotation_angle = 0.
|
459 |
+
|
460 |
+
for view_idx in range(args.views):
|
461 |
+
img_folder = os.path.join(save_root, f'{view_idx:03d}', 'img')
|
462 |
+
mask_folder = os.path.join(save_root, f'{view_idx:03d}', 'mask')
|
463 |
+
camera_folder = os.path.join(save_root, f'{view_idx:03d}', 'camera')
|
464 |
+
|
465 |
+
os.makedirs(img_folder, exist_ok=True)
|
466 |
+
os.makedirs(mask_folder, exist_ok=True)
|
467 |
+
os.makedirs(camera_folder, exist_ok=True)
|
468 |
+
|
469 |
+
np.save(os.path.join(camera_folder, 'rotation'), np.array([rotation_angle + view_idx * stepsize for _ in range(num_frames)]))
|
470 |
+
np.save(os.path.join(camera_folder, 'elevation'), np.array([elevation_angle for _ in range(num_frames)]))
|
471 |
+
|
472 |
+
cam_empty.rotation_euler[2] = math.radians(rotation_angle + view_idx * stepsize)
|
473 |
+
cam_empty.rotation_euler[0] = math.radians(elevation_angle)
|
474 |
+
|
475 |
+
# save camera RT matrix
|
476 |
+
rt_matrix = get_3x4_RT_matrix_from_blender(cam)
|
477 |
+
rt_matrix_path = os.path.join(camera_folder, "rt_matrix.npy")
|
478 |
+
np.save(rt_matrix_path, rt_matrix)
|
479 |
+
for i in range(0, num_frames):
|
480 |
+
bpy.context.scene.frame_set(i * args.downsample)
|
481 |
+
render_file_path = os.path.join(img_folder,'%03d.png' % (i))
|
482 |
+
scene.render.filepath = render_file_path
|
483 |
+
bpy.ops.render.render(write_still=True)
|
484 |
+
|
485 |
+
for i in range(0, num_frames):
|
486 |
+
img = cv2.imread(os.path.join(img_folder, '%03d.png' % (i)), cv2.IMREAD_UNCHANGED)
|
487 |
+
mask = img[:, :, 3:4] / 255.0
|
488 |
+
white_img = img[:, :, :3] * mask + np.ones_like(img[:, :, :3]) * (1 - mask) * 255
|
489 |
+
white_img = np.clip(white_img, 0, 255)
|
490 |
+
cv2.imwrite(os.path.join(img_folder, '%03d.jpg' % (i)), white_img)
|
491 |
+
cv2.imwrite(os.path.join(mask_folder, '%03d.png'%(i)), img[:, :, 3])
|
492 |
+
os.system('rm %s'%(os.path.join(img_folder, '%03d.png' % (i))))
|
493 |
+
|
494 |
+
elif args.camera_option == "random":
|
495 |
+
for scene in bpy.data.scenes:
|
496 |
+
scene.cycles.device = 'GPU'
|
497 |
+
|
498 |
+
for view_idx in range(args.views):
|
499 |
+
elevation_angle = np.random.rand(1) * 35 - 5 # [-5, 30]
|
500 |
+
rotation_angle = np.random.rand(1) * 360
|
501 |
+
|
502 |
+
img_folder = os.path.join(save_root, f'{view_idx:03d}', 'img')
|
503 |
+
mask_folder = os.path.join(save_root, f'{view_idx:03d}', 'mask')
|
504 |
+
camera_folder = os.path.join(save_root, f'{view_idx:03d}', 'camera')
|
505 |
+
|
506 |
+
os.makedirs(img_folder, exist_ok=True)
|
507 |
+
os.makedirs(mask_folder, exist_ok=True)
|
508 |
+
os.makedirs(camera_folder, exist_ok=True)
|
509 |
+
|
510 |
+
np.save(os.path.join(camera_folder, 'rotation'), np.array([rotation_angle for _ in range(num_frames)]))
|
511 |
+
np.save(os.path.join(camera_folder, 'elevation'), np.array([elevation_angle for _ in range(num_frames)]))
|
512 |
+
|
513 |
+
cam_empty.rotation_euler[2] = math.radians(rotation_angle)
|
514 |
+
cam_empty.rotation_euler[0] = math.radians(elevation_angle)
|
515 |
+
|
516 |
+
# save camera RT matrix
|
517 |
+
rt_matrix = get_3x4_RT_matrix_from_blender(cam)
|
518 |
+
rt_matrix_path = os.path.join(camera_folder, "rt_matrix.npy")
|
519 |
+
np.save(rt_matrix_path, rt_matrix)
|
520 |
+
|
521 |
+
for i in range(0, num_frames):
|
522 |
+
bpy.context.scene.frame_set(i * args.downsample)
|
523 |
+
render_file_path = os.path.join(img_folder,'%03d.png' % (i))
|
524 |
+
scene.render.filepath = render_file_path
|
525 |
+
bpy.ops.render.render(write_still=True)
|
526 |
+
|
527 |
+
for i in range(0, num_frames):
|
528 |
+
img = cv2.imread(os.path.join(img_folder, '%03d.png' % (i)), cv2.IMREAD_UNCHANGED)
|
529 |
+
mask = img[:, :, 3:4] / 255.0
|
530 |
+
white_img = img[:, :, :3] * mask + np.ones_like(img[:, :, :3]) * (1 - mask) * 255
|
531 |
+
white_img = np.clip(white_img, 0, 255)
|
532 |
+
cv2.imwrite(os.path.join(img_folder, '%03d.jpg' % (i)), white_img)
|
533 |
+
cv2.imwrite(os.path.join(mask_folder, '%03d.png'%(i)), img[:, :, 3])
|
534 |
+
os.system('rm %s'%(os.path.join(img_folder, '%03d.png' % (i))))
|
535 |
+
|
536 |
+
else:
|
537 |
+
raise NotImplemented
|
core/__init__.py
ADDED
File without changes
|
core/attention.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
import os
|
11 |
+
import warnings
|
12 |
+
|
13 |
+
from torch import Tensor
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
17 |
+
try:
|
18 |
+
if XFORMERS_ENABLED:
|
19 |
+
from xformers.ops import memory_efficient_attention, unbind
|
20 |
+
|
21 |
+
XFORMERS_AVAILABLE = True
|
22 |
+
# warnings.warn("xFormers is available (Attention)")
|
23 |
+
else:
|
24 |
+
warnings.warn("xFormers is disabled (Attention)")
|
25 |
+
raise ImportError
|
26 |
+
except ImportError:
|
27 |
+
XFORMERS_AVAILABLE = False
|
28 |
+
warnings.warn("xFormers is not available (Attention)")
|
29 |
+
|
30 |
+
|
31 |
+
class Attention(nn.Module):
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
dim: int,
|
35 |
+
num_heads: int = 8,
|
36 |
+
qkv_bias: bool = False,
|
37 |
+
proj_bias: bool = True,
|
38 |
+
attn_drop: float = 0.0,
|
39 |
+
proj_drop: float = 0.0,
|
40 |
+
) -> None:
|
41 |
+
super().__init__()
|
42 |
+
self.num_heads = num_heads
|
43 |
+
head_dim = dim // num_heads
|
44 |
+
self.scale = head_dim**-0.5
|
45 |
+
|
46 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
47 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
48 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
49 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
50 |
+
|
51 |
+
def forward(self, x: Tensor) -> Tensor:
|
52 |
+
B, N, C = x.shape
|
53 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
54 |
+
|
55 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
56 |
+
attn = q @ k.transpose(-2, -1)
|
57 |
+
|
58 |
+
attn = attn.softmax(dim=-1)
|
59 |
+
attn = self.attn_drop(attn)
|
60 |
+
|
61 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
62 |
+
x = self.proj(x)
|
63 |
+
x = self.proj_drop(x)
|
64 |
+
return x
|
65 |
+
|
66 |
+
|
67 |
+
class MemEffAttention(Attention):
|
68 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
69 |
+
if not XFORMERS_AVAILABLE:
|
70 |
+
if attn_bias is not None:
|
71 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
72 |
+
return super().forward(x)
|
73 |
+
|
74 |
+
B, N, C = x.shape
|
75 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
76 |
+
|
77 |
+
q, k, v = unbind(qkv, 2)
|
78 |
+
|
79 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
80 |
+
|
81 |
+
x = x.reshape([B, N, C])
|
82 |
+
|
83 |
+
x = self.proj(x)
|
84 |
+
x = self.proj_drop(x)
|
85 |
+
return x
|
core/gs.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
|
22 |
+
from core.options import Options
|
23 |
+
|
24 |
+
import kiui
|
25 |
+
|
26 |
+
from gsplat.rendering import rasterization
|
27 |
+
|
28 |
+
class GaussianRenderer:
|
29 |
+
def __init__(self, opt: Options):
|
30 |
+
|
31 |
+
self.opt = opt
|
32 |
+
self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")
|
33 |
+
|
34 |
+
# intrinsics
|
35 |
+
self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
|
36 |
+
self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
|
37 |
+
self.proj_matrix[0, 0] = 1 / self.tan_half_fov
|
38 |
+
self.proj_matrix[1, 1] = 1 / self.tan_half_fov
|
39 |
+
self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
|
40 |
+
self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
|
41 |
+
self.proj_matrix[2, 3] = 1
|
42 |
+
|
43 |
+
f = self.opt.output_size / (2 * self.tan_half_fov)
|
44 |
+
self.K = torch.tensor([[f, 0., self.opt.output_size/2.], [0., f, self.opt.output_size/2.], [0., 0., 1.]], dtype=torch.float32, device="cuda")
|
45 |
+
|
46 |
+
def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None):
|
47 |
+
# gaussians: [B, N, 14]
|
48 |
+
# cam_view, cam_view_proj: [B, V, 4, 4]
|
49 |
+
# cam_pos: [B, V, 3]
|
50 |
+
|
51 |
+
device = gaussians.device
|
52 |
+
B, V = cam_view.shape[:2]
|
53 |
+
|
54 |
+
# loop of loop...
|
55 |
+
images = []
|
56 |
+
alphas = []
|
57 |
+
for b in range(B):
|
58 |
+
|
59 |
+
# pos, opacity, scale, rotation, shs
|
60 |
+
means3D = gaussians[b, :, 0:3].contiguous().float()
|
61 |
+
opacity = gaussians[b, :, 3:4].contiguous().float()
|
62 |
+
scales = gaussians[b, :, 4:7].contiguous().float()
|
63 |
+
rotations = gaussians[b, :, 7:11].contiguous().float()
|
64 |
+
rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 3]
|
65 |
+
|
66 |
+
# render novel views
|
67 |
+
view_matrix = cam_view[b].float()
|
68 |
+
view_proj_matrix = cam_view_proj[b].float()
|
69 |
+
campos = cam_pos[b].float()
|
70 |
+
|
71 |
+
viewmat = view_matrix.transpose(2, 1) # [V, 4, 4]
|
72 |
+
|
73 |
+
|
74 |
+
rendered_image_all, rendered_alpha_all, info = rasterization(
|
75 |
+
means=means3D,
|
76 |
+
quats=rotations,
|
77 |
+
scales=scales,
|
78 |
+
opacities=opacity.squeeze(-1),
|
79 |
+
colors=rgbs,
|
80 |
+
viewmats=viewmat,
|
81 |
+
Ks=torch.stack([self.K for _ in range(V)]),
|
82 |
+
width=self.opt.output_size,
|
83 |
+
height=self.opt.output_size,
|
84 |
+
near_plane=self.opt.znear,
|
85 |
+
far_plane=self.opt.zfar,
|
86 |
+
packed=False,
|
87 |
+
backgrounds=torch.stack([self.bg_color for _ in range(V)]) if self.bg_color is not None else None,
|
88 |
+
render_mode="RGB",
|
89 |
+
)
|
90 |
+
for rendered_image, rendered_alpha in zip(rendered_image_all, rendered_alpha_all):
|
91 |
+
|
92 |
+
rendered_image = rendered_image.permute(2, 0, 1)
|
93 |
+
rendered_image = rendered_image.clamp(0, 1)
|
94 |
+
|
95 |
+
rendered_alpha = rendered_alpha.permute(2, 0, 1)
|
96 |
+
|
97 |
+
images.append(rendered_image)
|
98 |
+
alphas.append(rendered_alpha)
|
99 |
+
|
100 |
+
images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size, self.opt.output_size)
|
101 |
+
alphas = torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size)
|
102 |
+
|
103 |
+
return {
|
104 |
+
"image": images, # [B, V, 3, H, W]
|
105 |
+
"alpha": alphas, # [B, V, 1, H, W]
|
106 |
+
}
|
107 |
+
|
108 |
+
|
109 |
+
def save_ply(self, gaussians, path, compatible=True):
|
110 |
+
# gaussians: [B, N, 14]
|
111 |
+
# compatible: save pre-activated gaussians as in the original paper
|
112 |
+
|
113 |
+
assert gaussians.shape[0] == 1, 'only support batch size 1'
|
114 |
+
|
115 |
+
from plyfile import PlyData, PlyElement
|
116 |
+
|
117 |
+
means3D = gaussians[0, :, 0:3].contiguous().float()
|
118 |
+
opacity = gaussians[0, :, 3:4].contiguous().float()
|
119 |
+
scales = gaussians[0, :, 4:7].contiguous().float()
|
120 |
+
rotations = gaussians[0, :, 7:11].contiguous().float()
|
121 |
+
shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3]
|
122 |
+
|
123 |
+
# prune by opacity
|
124 |
+
mask = opacity.squeeze(-1) >= 0.005
|
125 |
+
means3D = means3D[mask]
|
126 |
+
opacity = opacity[mask]
|
127 |
+
scales = scales[mask]
|
128 |
+
rotations = rotations[mask]
|
129 |
+
shs = shs[mask]
|
130 |
+
|
131 |
+
# invert activation to make it compatible with the original ply format
|
132 |
+
if compatible:
|
133 |
+
opacity = kiui.op.inverse_sigmoid(opacity)
|
134 |
+
scales = torch.log(scales + 1e-8)
|
135 |
+
shs = (shs - 0.5) / 0.28209479177387814
|
136 |
+
|
137 |
+
xyzs = means3D.detach().cpu().numpy()
|
138 |
+
f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
|
139 |
+
opacities = opacity.detach().cpu().numpy()
|
140 |
+
scales = scales.detach().cpu().numpy()
|
141 |
+
rotations = rotations.detach().cpu().numpy()
|
142 |
+
|
143 |
+
l = ['x', 'y', 'z']
|
144 |
+
# All channels except the 3 DC
|
145 |
+
for i in range(f_dc.shape[1]):
|
146 |
+
l.append('f_dc_{}'.format(i))
|
147 |
+
l.append('opacity')
|
148 |
+
for i in range(scales.shape[1]):
|
149 |
+
l.append('scale_{}'.format(i))
|
150 |
+
for i in range(rotations.shape[1]):
|
151 |
+
l.append('rot_{}'.format(i))
|
152 |
+
|
153 |
+
dtype_full = [(attribute, 'f4') for attribute in l]
|
154 |
+
|
155 |
+
elements = np.empty(xyzs.shape[0], dtype=dtype_full)
|
156 |
+
attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)
|
157 |
+
elements[:] = list(map(tuple, attributes))
|
158 |
+
el = PlyElement.describe(elements, 'vertex')
|
159 |
+
|
160 |
+
PlyData([el]).write(path)
|
161 |
+
|
162 |
+
def load_ply(self, path, compatible=True):
|
163 |
+
|
164 |
+
from plyfile import PlyData, PlyElement
|
165 |
+
|
166 |
+
plydata = PlyData.read(path)
|
167 |
+
|
168 |
+
xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
|
169 |
+
np.asarray(plydata.elements[0]["y"]),
|
170 |
+
np.asarray(plydata.elements[0]["z"])), axis=1)
|
171 |
+
print("Number of points at loading : ", xyz.shape[0])
|
172 |
+
|
173 |
+
opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
|
174 |
+
|
175 |
+
shs = np.zeros((xyz.shape[0], 3))
|
176 |
+
shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
|
177 |
+
shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"])
|
178 |
+
shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"])
|
179 |
+
|
180 |
+
scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
|
181 |
+
scales = np.zeros((xyz.shape[0], len(scale_names)))
|
182 |
+
for idx, attr_name in enumerate(scale_names):
|
183 |
+
scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
184 |
+
|
185 |
+
rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")]
|
186 |
+
rots = np.zeros((xyz.shape[0], len(rot_names)))
|
187 |
+
for idx, attr_name in enumerate(rot_names):
|
188 |
+
rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
189 |
+
|
190 |
+
gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1)
|
191 |
+
gaussians = torch.from_numpy(gaussians).float() # cpu
|
192 |
+
|
193 |
+
if compatible:
|
194 |
+
gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4])
|
195 |
+
gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7])
|
196 |
+
gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5
|
197 |
+
|
198 |
+
return gaussians
|
core/models.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
import kiui
|
22 |
+
from kiui.lpips import LPIPS
|
23 |
+
|
24 |
+
from core.unet import UNet
|
25 |
+
from core.options import Options
|
26 |
+
from core.gs import GaussianRenderer
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
class LGM(nn.Module):
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
opt: Options,
|
34 |
+
):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.opt = opt
|
38 |
+
|
39 |
+
# unet
|
40 |
+
self.unet = UNet(
|
41 |
+
9, 14 * self.opt.gaussian_perpixel,
|
42 |
+
down_channels=self.opt.down_channels,
|
43 |
+
down_attention=self.opt.down_attention,
|
44 |
+
mid_attention=self.opt.mid_attention,
|
45 |
+
up_channels=self.opt.up_channels,
|
46 |
+
up_attention=self.opt.up_attention,
|
47 |
+
num_views=self.opt.num_input_views,
|
48 |
+
num_frames=self.opt.num_frames,
|
49 |
+
use_temp_attn=self.opt.use_temp_attn
|
50 |
+
)
|
51 |
+
|
52 |
+
# last conv
|
53 |
+
self.conv = nn.Conv2d(14 * self.opt.gaussian_perpixel, 14 * self.opt.gaussian_perpixel, kernel_size=1) # NOTE: maybe remove it if train again
|
54 |
+
|
55 |
+
# Gaussian Renderer
|
56 |
+
self.gs = GaussianRenderer(opt)
|
57 |
+
|
58 |
+
# activations...
|
59 |
+
self.pos_act = lambda x: x.clamp(-1, 1)
|
60 |
+
self.scale_act = lambda x: 0.1 * F.softplus(x)
|
61 |
+
self.opacity_act = lambda x: torch.sigmoid(x)
|
62 |
+
self.rot_act = lambda x: F.normalize(x, dim=-1)
|
63 |
+
self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5 # NOTE: may use sigmoid if train again
|
64 |
+
|
65 |
+
# LPIPS loss
|
66 |
+
if self.opt.lambda_lpips > 0:
|
67 |
+
self.lpips_loss = LPIPS(net='vgg')
|
68 |
+
self.lpips_loss.requires_grad_(False)
|
69 |
+
|
70 |
+
|
71 |
+
def state_dict(self, **kwargs):
|
72 |
+
# remove lpips_loss
|
73 |
+
state_dict = super().state_dict(**kwargs)
|
74 |
+
for k in list(state_dict.keys()):
|
75 |
+
if 'lpips_loss' in k:
|
76 |
+
del state_dict[k]
|
77 |
+
return state_dict
|
78 |
+
|
79 |
+
|
80 |
+
def prepare_default_rays(self, device, elevation=0):
|
81 |
+
|
82 |
+
from kiui.cam import orbit_camera
|
83 |
+
from core.utils import get_rays
|
84 |
+
|
85 |
+
cam_poses = np.stack([
|
86 |
+
orbit_camera(elevation, 0, radius=self.opt.cam_radius),
|
87 |
+
orbit_camera(elevation, 90, radius=self.opt.cam_radius),
|
88 |
+
orbit_camera(elevation, 180, radius=self.opt.cam_radius),
|
89 |
+
orbit_camera(elevation, 270, radius=self.opt.cam_radius),
|
90 |
+
], axis=0) # [4, 4, 4]
|
91 |
+
cam_poses = torch.from_numpy(cam_poses)
|
92 |
+
|
93 |
+
rays_embeddings = []
|
94 |
+
for i in range(cam_poses.shape[0]):
|
95 |
+
rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3]
|
96 |
+
rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
|
97 |
+
rays_embeddings.append(rays_plucker)
|
98 |
+
|
99 |
+
## visualize rays for plotting figure
|
100 |
+
# kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True)
|
101 |
+
|
102 |
+
rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w]
|
103 |
+
|
104 |
+
return rays_embeddings
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
+
def forward_gaussians(self, images):
|
109 |
+
# images: [B, T, 4, 9, H, W]
|
110 |
+
# return: Gaussians: [B, dim_t]
|
111 |
+
|
112 |
+
B, TV, C, H, W = images.shape
|
113 |
+
T = self.opt.num_frames
|
114 |
+
V = TV // T
|
115 |
+
images = images.view(B*T*V, C, H, W)
|
116 |
+
|
117 |
+
x = self.unet(images) # [B*4, 14, h, w]
|
118 |
+
x = self.conv(x) # [B*4, 14, h, w]
|
119 |
+
|
120 |
+
x = x.reshape(B*T, V, 14 * self.opt.gaussian_perpixel, self.opt.splat_size, self.opt.splat_size)
|
121 |
+
|
122 |
+
x = x.permute(0, 1, 3, 4, 2).reshape(B*T, -1, 14).contiguous()
|
123 |
+
|
124 |
+
pos = self.pos_act(x[..., 0:3]) # [B, N, 3]
|
125 |
+
opacity = self.opacity_act(x[..., 3:4])
|
126 |
+
scale = self.scale_act(x[..., 4:7])
|
127 |
+
rotation = self.rot_act(x[..., 7:11])
|
128 |
+
rgbs = self.rgb_act(x[..., 11:])
|
129 |
+
|
130 |
+
gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, T, N, 14]
|
131 |
+
|
132 |
+
return gaussians
|
133 |
+
|
134 |
+
|
135 |
+
def forward(self, data, step_ratio=1):
|
136 |
+
# data: output of the dataloader
|
137 |
+
# return: loss
|
138 |
+
|
139 |
+
results = {}
|
140 |
+
loss = 0
|
141 |
+
|
142 |
+
images = data['input'] # [B, Tx4, 9, h, W], input features
|
143 |
+
|
144 |
+
B, TV, C, H, W = images.shape
|
145 |
+
T = self.opt.num_frames
|
146 |
+
|
147 |
+
# use the first view to predict gaussians
|
148 |
+
gaussians = self.forward_gaussians(images) # [B * T, N, 14]
|
149 |
+
|
150 |
+
results['gaussians'] = gaussians
|
151 |
+
|
152 |
+
# always use white bg
|
153 |
+
bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)
|
154 |
+
|
155 |
+
# use the other views for rendering and supervision
|
156 |
+
data['cam_view'] = data['cam_view'].reshape(B*T, -1, *data['cam_view'].shape[2:])
|
157 |
+
data['cam_view_proj'] = data['cam_view_proj'].reshape(B*T, -1, *data['cam_view_proj'].shape[2:])
|
158 |
+
data['cam_pos'] = data['cam_pos'].reshape(B*T, -1, *data['cam_pos'].shape[2:])
|
159 |
+
|
160 |
+
results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color)
|
161 |
+
pred_images = results['image'] # [B*T, V, C, output_size, output_size]
|
162 |
+
pred_alphas = results['alpha'] # [B*T, V, 1, output_size, output_size]
|
163 |
+
|
164 |
+
results['images_pred'] = pred_images
|
165 |
+
results['alphas_pred'] = pred_alphas
|
166 |
+
|
167 |
+
|
168 |
+
data['images_output'] = data['images_output'].reshape(B*T, -1, *data['images_output'].shape[2:])
|
169 |
+
data['masks_output'] = data['masks_output'].reshape(B*T, -1, *data['masks_output'].shape[2:])
|
170 |
+
|
171 |
+
gt_images = data['images_output'] # [B*T, V, 3, output_size, output_size], ground-truth novel views
|
172 |
+
gt_masks = data['masks_output'] # [B*T, V, 1, output_size, output_size], ground-truth masks
|
173 |
+
|
174 |
+
gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)
|
175 |
+
|
176 |
+
loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks)
|
177 |
+
loss = loss + loss_mse
|
178 |
+
|
179 |
+
if self.opt.lambda_lpips > 0:
|
180 |
+
loss_lpips = self.lpips_loss(
|
181 |
+
# downsampled to at most 256 to reduce memory cost
|
182 |
+
F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
|
183 |
+
F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
|
184 |
+
).mean()
|
185 |
+
results['loss_lpips'] = loss_lpips
|
186 |
+
loss = loss + self.opt.lambda_lpips * loss_lpips
|
187 |
+
|
188 |
+
results['loss'] = loss
|
189 |
+
|
190 |
+
# metric
|
191 |
+
with torch.no_grad():
|
192 |
+
psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))
|
193 |
+
results['psnr'] = psnr
|
194 |
+
|
195 |
+
return results
|
core/options.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import tyro
|
17 |
+
from dataclasses import dataclass
|
18 |
+
from typing import Tuple, Literal, Dict, Optional
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class Options:
|
23 |
+
### model
|
24 |
+
# Unet image input size
|
25 |
+
input_size: int = 256
|
26 |
+
# Unet definition
|
27 |
+
down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024)
|
28 |
+
down_attention: Tuple[bool, ...] = (False, False, False, True, True, True)
|
29 |
+
mid_attention: bool = True
|
30 |
+
up_channels: Tuple[int, ...] = (1024, 1024, 512, 256)
|
31 |
+
up_attention: Tuple[bool, ...] = (True, True, True, False)
|
32 |
+
# Unet output size, dependent on the input_size and U-Net structure!
|
33 |
+
splat_size: int = 64
|
34 |
+
# gaussian render size
|
35 |
+
output_size: int = 256
|
36 |
+
|
37 |
+
### dataset
|
38 |
+
# data mode (only support s3 now)
|
39 |
+
data_mode: str = '4d'
|
40 |
+
# fovy of the dataset
|
41 |
+
fovy: float = 49.1
|
42 |
+
# camera near plane
|
43 |
+
znear: float = 0.5
|
44 |
+
# camera far plane
|
45 |
+
zfar: float = 2.5
|
46 |
+
# number of all views (input + output)
|
47 |
+
num_views: int = 12
|
48 |
+
# number of views
|
49 |
+
num_input_views: int = 4
|
50 |
+
# camera radius
|
51 |
+
cam_radius: float = 1.5 # to better use [-1, 1]^3 space
|
52 |
+
# num workers
|
53 |
+
num_workers: int = 16
|
54 |
+
datalist: str=''
|
55 |
+
|
56 |
+
### training
|
57 |
+
# workspace
|
58 |
+
workspace: str = './workspace'
|
59 |
+
# resume
|
60 |
+
resume: Optional[str] = None
|
61 |
+
# batch size (per-GPU)
|
62 |
+
batch_size: int = 8
|
63 |
+
# gradient accumulation
|
64 |
+
gradient_accumulation_steps: int = 1
|
65 |
+
# training epochs
|
66 |
+
num_epochs: int = 30
|
67 |
+
# lpips loss weight
|
68 |
+
lambda_lpips: float = 1.0
|
69 |
+
# gradient clip
|
70 |
+
gradient_clip: float = 1.0
|
71 |
+
# mixed precision
|
72 |
+
mixed_precision: str = 'bf16'
|
73 |
+
# learning rate
|
74 |
+
lr: float = 4e-4
|
75 |
+
# augmentation prob for grid distortion
|
76 |
+
prob_grid_distortion: float = 0.5
|
77 |
+
# augmentation prob for camera jitter
|
78 |
+
prob_cam_jitter: float = 0.5
|
79 |
+
# number of gaussians per pixel
|
80 |
+
gaussian_perpixel: int = 1
|
81 |
+
|
82 |
+
### testing
|
83 |
+
# test image path
|
84 |
+
test_path: Optional[str] = None
|
85 |
+
|
86 |
+
### misc
|
87 |
+
# nvdiffrast backend setting
|
88 |
+
force_cuda_rast: bool = False
|
89 |
+
# render fancy video with gaussian scaling effect
|
90 |
+
fancy_video: bool = False
|
91 |
+
|
92 |
+
# 4D
|
93 |
+
num_frames: int = 8
|
94 |
+
use_temp_attn: bool = True
|
95 |
+
shuffle_input: bool = True
|
96 |
+
|
97 |
+
# s3
|
98 |
+
sample_by_anim: bool = True
|
99 |
+
|
100 |
+
# interp
|
101 |
+
interpresume: Optional[str] = None
|
102 |
+
interpolate_rate: int = 3
|
103 |
+
|
104 |
+
|
105 |
+
# all the default settings
|
106 |
+
config_defaults: Dict[str, Options] = {}
|
107 |
+
config_doc: Dict[str, str] = {}
|
108 |
+
|
109 |
+
config_doc['lrm'] = 'the default settings for LGM'
|
110 |
+
config_defaults['lrm'] = Options()
|
111 |
+
|
112 |
+
|
113 |
+
config_doc['big'] = 'big model with higher resolution Gaussians'
|
114 |
+
config_defaults['big'] = Options(
|
115 |
+
input_size=256,
|
116 |
+
up_channels=(1024, 1024, 512, 256, 128), # one more decoder
|
117 |
+
up_attention=(True, True, True, False, False),
|
118 |
+
splat_size=128,
|
119 |
+
output_size=512, # render & supervise Gaussians at a higher resolution.
|
120 |
+
batch_size=1,
|
121 |
+
num_views=8,
|
122 |
+
gradient_accumulation_steps=1,
|
123 |
+
mixed_precision='bf16',
|
124 |
+
resume='pretrained/model_fp16_fixrot.safetensors',
|
125 |
+
)
|
126 |
+
|
127 |
+
|
128 |
+
AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc)
|
core/provider_objaverse_4d.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import os
|
17 |
+
import cv2
|
18 |
+
import random
|
19 |
+
import numpy as np
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch.nn.functional as F
|
24 |
+
import torchvision.transforms.functional as TF
|
25 |
+
from torch.utils.data import Dataset
|
26 |
+
|
27 |
+
import kiui
|
28 |
+
from core.options import Options
|
29 |
+
from core.utils import get_rays, grid_distortion, orbit_camera_jitter
|
30 |
+
|
31 |
+
from kiui.cam import orbit_camera
|
32 |
+
|
33 |
+
import tarfile
|
34 |
+
from io import BytesIO
|
35 |
+
|
36 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
37 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
38 |
+
|
39 |
+
|
40 |
+
def load_np_array_from_tar(tar, path):
|
41 |
+
array_file = BytesIO()
|
42 |
+
array_file.write(tar.extractfile(path).read())
|
43 |
+
array_file.seek(0)
|
44 |
+
return np.load(array_file)
|
45 |
+
|
46 |
+
|
47 |
+
class ObjaverseDataset(Dataset):
|
48 |
+
|
49 |
+
def _warn(self):
|
50 |
+
raise NotImplementedError('this dataset is just an example and cannot be used directly, you should modify it to your own setting! (search keyword TODO)')
|
51 |
+
|
52 |
+
def __init__(self, opt: Options, training=True, evaluating=False):
|
53 |
+
|
54 |
+
self.opt = opt
|
55 |
+
self.training = training
|
56 |
+
self.evaluating = evaluating
|
57 |
+
|
58 |
+
self.items = []
|
59 |
+
with open(self.opt.datalist, 'r') as f:
|
60 |
+
for line in f.readlines():
|
61 |
+
self.items.append(line.strip())
|
62 |
+
|
63 |
+
|
64 |
+
anim_map = {}
|
65 |
+
for x in self.items:
|
66 |
+
k = x.split('-')[1]
|
67 |
+
if k in anim_map:
|
68 |
+
anim_map[k] += '|'+x
|
69 |
+
else:
|
70 |
+
anim_map[k] = x
|
71 |
+
self.items = list(anim_map.values())
|
72 |
+
|
73 |
+
|
74 |
+
# naive split
|
75 |
+
if self.training:
|
76 |
+
self.items = self.items[:-self.opt.batch_size]
|
77 |
+
elif self.evaluating:
|
78 |
+
self.items = self.items[::1000]
|
79 |
+
else:
|
80 |
+
self.items = self.items[-self.opt.batch_size:]
|
81 |
+
|
82 |
+
|
83 |
+
# default camera intrinsics
|
84 |
+
self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
|
85 |
+
self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
|
86 |
+
self.proj_matrix[0, 0] = 1 / self.tan_half_fov
|
87 |
+
self.proj_matrix[1, 1] = 1 / self.tan_half_fov
|
88 |
+
self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear)
|
89 |
+
self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear)
|
90 |
+
self.proj_matrix[2, 3] = 1
|
91 |
+
|
92 |
+
def __len__(self):
|
93 |
+
return len(self.items)
|
94 |
+
|
95 |
+
def _get_batch(self, idx):
|
96 |
+
if self.training:
|
97 |
+
uid = random.choice(self.items[idx].split('|'))
|
98 |
+
else:
|
99 |
+
uid = self.items[idx].split('|')[0]
|
100 |
+
|
101 |
+
results = {}
|
102 |
+
|
103 |
+
# load num_views images
|
104 |
+
images = []
|
105 |
+
masks = []
|
106 |
+
cam_poses = []
|
107 |
+
|
108 |
+
if self.training and self.opt.shuffle_input:
|
109 |
+
vids = np.random.permutation(np.arange(32, 48))[:self.opt.num_input_views].tolist() + np.random.permutation(32).tolist()
|
110 |
+
else:
|
111 |
+
vids = np.arange(32, 48, 4).tolist() + np.arange(32).tolist()
|
112 |
+
|
113 |
+
|
114 |
+
random_tar_name = 'random_clip/' + uid
|
115 |
+
fixed_16_tar_name = 'fixed_16_clip/' + uid
|
116 |
+
|
117 |
+
local_random_tar_name = os.environ["DATA_HOME"] + random_tar_name.replace('/', '-')
|
118 |
+
local_fixed_16_tar_name = os.environ["DATA_HOME"] + fixed_16_tar_name.replace('/', '-')
|
119 |
+
|
120 |
+
tar_random = tarfile.open(local_random_tar_name)
|
121 |
+
tar_fixed = tarfile.open(local_fixed_16_tar_name)
|
122 |
+
|
123 |
+
|
124 |
+
T = self.opt.num_frames
|
125 |
+
for t_idx in range(T):
|
126 |
+
t = t_idx
|
127 |
+
vid_cnt = 0
|
128 |
+
for vid in vids:
|
129 |
+
if vid >= 32:
|
130 |
+
vid = vid % 32
|
131 |
+
tar = tar_fixed
|
132 |
+
else:
|
133 |
+
tar = tar_random
|
134 |
+
|
135 |
+
image_path = os.path.join('.', f'{vid:03d}/img', f'{t:03d}.jpg')
|
136 |
+
mask_path = os.path.join('.', f'{vid:03d}/mask', f'{t:03d}.png')
|
137 |
+
|
138 |
+
elevation_path = os.path.join('.', f'{vid:03d}/camera', f'elevation.npy')
|
139 |
+
rotation_path = os.path.join('.', f'{vid:03d}/camera', f'rotation.npy')
|
140 |
+
|
141 |
+
image = np.frombuffer(tar.extractfile(image_path).read(), np.uint8)
|
142 |
+
image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1]
|
143 |
+
|
144 |
+
azi = load_np_array_from_tar(tar, rotation_path)[t, None]
|
145 |
+
elevation = load_np_array_from_tar(tar, elevation_path)[t, None] * -1 # to align with pretrained LGM
|
146 |
+
azi = float(azi)
|
147 |
+
elevation = float(elevation)
|
148 |
+
c2w = torch.from_numpy(orbit_camera(elevation, azi, radius=1.5, opengl=True))
|
149 |
+
|
150 |
+
image = image.permute(2, 0, 1) # [4, 512, 512]
|
151 |
+
|
152 |
+
mask = np.frombuffer(tar.extractfile(mask_path).read(), np.uint8)
|
153 |
+
mask = torch.from_numpy(cv2.imdecode(mask, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255).unsqueeze(0) # [512, 512, 4] in [0, 1]
|
154 |
+
|
155 |
+
image = F.interpolate(image.unsqueeze(0), size=(512, 512), mode='nearest').squeeze(0)
|
156 |
+
mask = F.interpolate(mask.unsqueeze(0), size=(512, 512), mode='nearest').squeeze(0)
|
157 |
+
|
158 |
+
image = image[:3] * mask + (1 - mask) # [3, 512, 512], to white bg
|
159 |
+
image = image[[2,1,0]].contiguous() # bgr to rgb
|
160 |
+
|
161 |
+
images.append(image)
|
162 |
+
masks.append(mask.squeeze(0))
|
163 |
+
cam_poses.append(c2w)
|
164 |
+
|
165 |
+
vid_cnt += 1
|
166 |
+
if vid_cnt == self.opt.num_views:
|
167 |
+
break
|
168 |
+
|
169 |
+
if vid_cnt < self.opt.num_views:
|
170 |
+
print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!')
|
171 |
+
n = self.opt.num_views - vid_cnt
|
172 |
+
images = images + [images[-1]] * n
|
173 |
+
masks = masks + [masks[-1]] * n
|
174 |
+
cam_poses = cam_poses + [cam_poses[-1]] * n
|
175 |
+
|
176 |
+
images = torch.stack(images, dim=0) # [V, C, H, W]
|
177 |
+
masks = torch.stack(masks, dim=0) # [V, H, W]
|
178 |
+
cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4]
|
179 |
+
|
180 |
+
# normalized camera feats as in paper (transform the first pose to a fixed position)
|
181 |
+
transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0])
|
182 |
+
cam_poses = transform.unsqueeze(0) @ cam_poses # [V, 4, 4]
|
183 |
+
|
184 |
+
images_input = F.interpolate(images.reshape(T, self.opt.num_views, *images.shape[1:])[:, :self.opt.num_input_views].reshape(-1, *images.shape[1:]).clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W]
|
185 |
+
cam_poses_input = cam_poses.reshape(T, self.opt.num_views, *cam_poses.shape[1:])[:, :self.opt.num_input_views].reshape(-1, *cam_poses.shape[1:]).clone()
|
186 |
+
|
187 |
+
# data augmentation
|
188 |
+
if self.training:
|
189 |
+
images_input = images_input.reshape(T, self.opt.num_input_views, *images_input.shape[1:])
|
190 |
+
cam_poses_input = cam_poses_input.reshape(T, self.opt.num_input_views, *cam_poses.shape[1:])
|
191 |
+
|
192 |
+
# apply random grid distortion to simulate 3D inconsistency
|
193 |
+
if random.random() < self.opt.prob_grid_distortion:
|
194 |
+
for t in range(T):
|
195 |
+
images_input[t, 1:] = grid_distortion(images_input[t, 1:])
|
196 |
+
# apply camera jittering (only to input!)
|
197 |
+
if random.random() < self.opt.prob_cam_jitter:
|
198 |
+
for t in range(T):
|
199 |
+
cam_poses_input[t, 1:] = orbit_camera_jitter(cam_poses_input[t, 1:])
|
200 |
+
|
201 |
+
images_input = images_input.reshape(-1, *images_input.shape[2:])
|
202 |
+
cam_poses_input = cam_poses_input.reshape(-1, *cam_poses.shape[1:])
|
203 |
+
|
204 |
+
# masking other views
|
205 |
+
images_input = images_input.reshape(T, self.opt.num_input_views, *images_input.shape[1:])
|
206 |
+
images_input[1:, 1:] = images_input[0:1, 1:]
|
207 |
+
images_input = images_input.reshape(-1, *images_input.shape[2:])
|
208 |
+
|
209 |
+
cam_poses_input = cam_poses_input.reshape(T, self.opt.num_input_views, *cam_poses.shape[1:])
|
210 |
+
cam_poses_input[1:, 1:] = cam_poses_input[0:1, 1:]
|
211 |
+
cam_poses_input = cam_poses_input.reshape(-1, *cam_poses.shape[1:])
|
212 |
+
|
213 |
+
images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
214 |
+
|
215 |
+
# resize render ground-truth images, range still in [0, 1]
|
216 |
+
results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, C, output_size, output_size]
|
217 |
+
results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, 1, output_size, output_size]
|
218 |
+
|
219 |
+
# build rays for input views
|
220 |
+
rays_embeddings = []
|
221 |
+
for i in range(self.opt.num_input_views * T):
|
222 |
+
rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3]
|
223 |
+
rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
|
224 |
+
rays_embeddings.append(rays_plucker)
|
225 |
+
|
226 |
+
|
227 |
+
rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w]
|
228 |
+
|
229 |
+
final_input = torch.cat([images_input, rays_embeddings], dim=1) # [V=4, 9, H, W]
|
230 |
+
results['input'] = final_input
|
231 |
+
|
232 |
+
# opengl to colmap camera for gaussian renderer
|
233 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
234 |
+
|
235 |
+
# cameras needed by gaussian rasterizer
|
236 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
237 |
+
cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
|
238 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
239 |
+
|
240 |
+
results['cam_view'] = cam_view
|
241 |
+
results['cam_view_proj'] = cam_view_proj
|
242 |
+
results['cam_pos'] = cam_pos
|
243 |
+
|
244 |
+
return results
|
245 |
+
|
246 |
+
def __getitem__(self, idx):
|
247 |
+
while True:
|
248 |
+
try:
|
249 |
+
results = self._get_batch(idx)
|
250 |
+
break
|
251 |
+
except Exception as e:
|
252 |
+
print(f"{e}")
|
253 |
+
idx = random.randint(0, len(self.items) - 1)
|
254 |
+
return results
|
core/provider_objaverse_4d_interp.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import os
|
17 |
+
import cv2
|
18 |
+
import random
|
19 |
+
import numpy as np
|
20 |
+
import copy
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
import torchvision.transforms.functional as TF
|
26 |
+
from torch.utils.data import Dataset
|
27 |
+
|
28 |
+
import kiui
|
29 |
+
from core.options import Options
|
30 |
+
from core.utils import get_rays, grid_distortion, orbit_camera_jitter
|
31 |
+
|
32 |
+
from kiui.cam import orbit_camera
|
33 |
+
|
34 |
+
import tarfile
|
35 |
+
from io import BytesIO
|
36 |
+
|
37 |
+
|
38 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
39 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
40 |
+
|
41 |
+
|
42 |
+
def load_np_array_from_tar(tar, path):
|
43 |
+
array_file = BytesIO()
|
44 |
+
array_file.write(tar.extractfile(path).read())
|
45 |
+
array_file.seek(0)
|
46 |
+
return np.load(array_file)
|
47 |
+
|
48 |
+
def interpolate_tensors(tensor):
|
49 |
+
# Extract the first and last tensors along the first dimension (B)
|
50 |
+
start_tensor = tensor[0] # shape [4, 3, 256, 256]
|
51 |
+
end_tensor = tensor[-1] # shape [4, 3, 256, 256]
|
52 |
+
tensor_interp = copy.deepcopy(tensor)
|
53 |
+
|
54 |
+
# Iterate over the range from 1 to second-last index
|
55 |
+
for i in range(1, tensor.size(0) - 1):
|
56 |
+
# Calculate the weight for interpolation
|
57 |
+
|
58 |
+
weight = (i - 0) / (tensor.size(0) - 1)
|
59 |
+
# Interpolate between start_tensor and end_tensor
|
60 |
+
tensor_interp[i] = torch.lerp(start_tensor, end_tensor, weight)
|
61 |
+
|
62 |
+
|
63 |
+
return tensor_interp
|
64 |
+
|
65 |
+
class ObjaverseDataset(Dataset):
|
66 |
+
|
67 |
+
def _warn(self):
|
68 |
+
raise NotImplementedError('this dataset is just an example and cannot be used directly, you should modify it to your own setting! (search keyword TODO)')
|
69 |
+
|
70 |
+
def __init__(self, opt: Options, training=True, evaluating=False):
|
71 |
+
|
72 |
+
self.opt = opt
|
73 |
+
self.training = training
|
74 |
+
self.evaluating = evaluating
|
75 |
+
|
76 |
+
self.items = []
|
77 |
+
with open(self.opt.datalist, 'r') as f:
|
78 |
+
for line in f.readlines():
|
79 |
+
self.items.append(line.strip())
|
80 |
+
|
81 |
+
anim_map = {}
|
82 |
+
for x in self.items:
|
83 |
+
k = x.split('-')[1]
|
84 |
+
if k in anim_map:
|
85 |
+
anim_map[k] += '|'+x
|
86 |
+
else:
|
87 |
+
anim_map[k] = x
|
88 |
+
self.items = list(anim_map.values())
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
# naive split
|
93 |
+
if self.training:
|
94 |
+
self.items = self.items[:-self.opt.batch_size]
|
95 |
+
elif self.evaluating:
|
96 |
+
self.items = self.items[::1000]
|
97 |
+
else:
|
98 |
+
self.items = self.items[-self.opt.batch_size:]
|
99 |
+
|
100 |
+
|
101 |
+
# default camera intrinsics
|
102 |
+
self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
|
103 |
+
self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
|
104 |
+
self.proj_matrix[0, 0] = 1 / self.tan_half_fov
|
105 |
+
self.proj_matrix[1, 1] = 1 / self.tan_half_fov
|
106 |
+
self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear)
|
107 |
+
self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear)
|
108 |
+
self.proj_matrix[2, 3] = 1
|
109 |
+
|
110 |
+
|
111 |
+
def __len__(self):
|
112 |
+
return len(self.items)
|
113 |
+
|
114 |
+
def _get_batch(self, idx):
|
115 |
+
# uid = self.items[idx]
|
116 |
+
if self.training:
|
117 |
+
uid = random.choice(self.items[idx].split('|'))
|
118 |
+
else:
|
119 |
+
uid = self.items[idx].split('|')[0]
|
120 |
+
|
121 |
+
results = {}
|
122 |
+
|
123 |
+
# load num_views images
|
124 |
+
images = []
|
125 |
+
masks = []
|
126 |
+
cam_poses = []
|
127 |
+
|
128 |
+
if self.training and self.opt.shuffle_input:
|
129 |
+
vids = np.random.permutation(np.arange(32, 48))[:self.opt.num_input_views].tolist() + np.random.permutation(32).tolist()
|
130 |
+
else:
|
131 |
+
vids = np.arange(32, 48, 4).tolist() + np.arange(32).tolist()
|
132 |
+
|
133 |
+
random_tar_name = 'random_24fps/' + uid
|
134 |
+
fixed_16_tar_name = 'fixed_16_24fps/' + uid
|
135 |
+
|
136 |
+
local_random_tar_name = os.environ["DATA_HOME"] + random_tar_name.replace('/', '-')
|
137 |
+
local_fixed_16_tar_name = os.environ["DATA_HOME"] + fixed_16_tar_name.replace('/', '-')
|
138 |
+
|
139 |
+
tar_random = tarfile.open(local_random_tar_name)
|
140 |
+
tar_fixed = tarfile.open(local_fixed_16_tar_name)
|
141 |
+
|
142 |
+
max_T = 24
|
143 |
+
|
144 |
+
T = self.opt.num_frames
|
145 |
+
|
146 |
+
start_frame = np.random.randint(max_T - T)
|
147 |
+
|
148 |
+
for t_idx in range(T):
|
149 |
+
t = start_frame + t_idx
|
150 |
+
vid_cnt = 0
|
151 |
+
for vid in vids:
|
152 |
+
if vid >= 32:
|
153 |
+
vid = vid % 32
|
154 |
+
tar = tar_fixed
|
155 |
+
else:
|
156 |
+
tar = tar_random
|
157 |
+
|
158 |
+
image_path = os.path.join('.', f'{vid:03d}/img', f'{t:03d}.jpg')
|
159 |
+
mask_path = os.path.join('.', f'{vid:03d}/mask', f'{t:03d}.png')
|
160 |
+
|
161 |
+
elevation_path = os.path.join('.', f'{vid:03d}/camera', f'elevation.npy')
|
162 |
+
rotation_path = os.path.join('.', f'{vid:03d}/camera', f'rotation.npy')
|
163 |
+
|
164 |
+
try :
|
165 |
+
image = np.frombuffer(tar.extractfile(image_path).read(), np.uint8)
|
166 |
+
image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1]
|
167 |
+
|
168 |
+
azi = load_np_array_from_tar(tar, rotation_path)[t, None]
|
169 |
+
elevation = load_np_array_from_tar(tar, elevation_path)[t, None] * -1 # to align with pretrained LGM
|
170 |
+
azi = float(azi)
|
171 |
+
elevation = float(elevation)
|
172 |
+
c2w = torch.from_numpy(orbit_camera(elevation, azi, radius=1.5, opengl=True))
|
173 |
+
|
174 |
+
image = image.permute(2, 0, 1) # [4, 512, 512]
|
175 |
+
|
176 |
+
mask = np.frombuffer(tar.extractfile(mask_path).read(), np.uint8)
|
177 |
+
mask = torch.from_numpy(cv2.imdecode(mask, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255).unsqueeze(0) # [512, 512, 4] in [0, 1]
|
178 |
+
except:
|
179 |
+
|
180 |
+
return self.__getitem__(idx - 1)
|
181 |
+
image = F.interpolate(image.unsqueeze(0), size=(512, 512), mode='nearest').squeeze(0)
|
182 |
+
mask = F.interpolate(mask.unsqueeze(0), size=(512, 512), mode='nearest').squeeze(0)
|
183 |
+
|
184 |
+
image = image[:3] * mask + (1 - mask) # [3, 512, 512], to white bg
|
185 |
+
image = image[[2,1,0]].contiguous() # bgr to rgb
|
186 |
+
|
187 |
+
images.append(image)
|
188 |
+
masks.append(mask.squeeze(0))
|
189 |
+
cam_poses.append(c2w)
|
190 |
+
|
191 |
+
vid_cnt += 1
|
192 |
+
if vid_cnt == self.opt.num_views:
|
193 |
+
break
|
194 |
+
|
195 |
+
if vid_cnt < self.opt.num_views:
|
196 |
+
print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!')
|
197 |
+
n = self.opt.num_views - vid_cnt
|
198 |
+
images = images + [images[-1]] * n
|
199 |
+
masks = masks + [masks[-1]] * n
|
200 |
+
cam_poses = cam_poses + [cam_poses[-1]] * n
|
201 |
+
|
202 |
+
images = torch.stack(images, dim=0) # [V, C, H, W]
|
203 |
+
masks = torch.stack(masks, dim=0) # [V, H, W]
|
204 |
+
cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4]
|
205 |
+
|
206 |
+
# normalized camera feats as in paper (transform the first pose to a fixed position)
|
207 |
+
transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0])
|
208 |
+
cam_poses = transform.unsqueeze(0) @ cam_poses # [V, 4, 4]
|
209 |
+
|
210 |
+
images_input = F.interpolate(images.reshape(T, self.opt.num_views, *images.shape[1:])[:, :self.opt.num_input_views].reshape(-1, *images.shape[1:]).clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W]
|
211 |
+
cam_poses_input = cam_poses.reshape(T, self.opt.num_views, *cam_poses.shape[1:])[:, :self.opt.num_input_views].reshape(-1, *cam_poses.shape[1:]).clone()
|
212 |
+
|
213 |
+
# data augmentation
|
214 |
+
if self.training:
|
215 |
+
images_input = images_input.reshape(T, self.opt.num_input_views, *images_input.shape[1:])
|
216 |
+
cam_poses_input = cam_poses_input.reshape(T, self.opt.num_input_views, *cam_poses.shape[1:])
|
217 |
+
|
218 |
+
# apply random grid distortion to simulate 3D inconsistency
|
219 |
+
if random.random() < self.opt.prob_grid_distortion:
|
220 |
+
for t in range(T):
|
221 |
+
images_input[t, 1:] = grid_distortion(images_input[t, 1:])
|
222 |
+
# apply camera jittering (only to input!)
|
223 |
+
if random.random() < self.opt.prob_cam_jitter:
|
224 |
+
for t in range(T):
|
225 |
+
cam_poses_input[t, 1:] = orbit_camera_jitter(cam_poses_input[t, 1:])
|
226 |
+
|
227 |
+
images_input = images_input.reshape(-1, *images_input.shape[2:])
|
228 |
+
cam_poses_input = cam_poses_input.reshape(-1, *cam_poses.shape[1:])
|
229 |
+
|
230 |
+
# masking other views
|
231 |
+
images_input = images_input.reshape(T, self.opt.num_input_views, *images_input.shape[1:])
|
232 |
+
|
233 |
+
images_input_interp = interpolate_tensors(images_input)
|
234 |
+
|
235 |
+
images_input[1:-1, :] = images_input_interp[1:-1, :]
|
236 |
+
images_input = images_input.reshape(-1, *images_input.shape[2:])
|
237 |
+
|
238 |
+
cam_poses_input = cam_poses_input.reshape(T, self.opt.num_input_views, *cam_poses.shape[1:])
|
239 |
+
cam_poses_input[1:, 1:] = cam_poses_input[0:1, 1:]
|
240 |
+
cam_poses_input = cam_poses_input.reshape(-1, *cam_poses.shape[1:])
|
241 |
+
|
242 |
+
images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
243 |
+
|
244 |
+
# resize render ground-truth images, range still in [0, 1]
|
245 |
+
results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, C, output_size, output_size]
|
246 |
+
results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, 1, output_size, output_size]
|
247 |
+
|
248 |
+
# build rays for input views
|
249 |
+
rays_embeddings = []
|
250 |
+
for i in range(self.opt.num_input_views * T):
|
251 |
+
rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3]
|
252 |
+
rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
|
253 |
+
rays_embeddings.append(rays_plucker)
|
254 |
+
|
255 |
+
|
256 |
+
rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w]
|
257 |
+
|
258 |
+
final_input = torch.cat([images_input, rays_embeddings], dim=1) # [V=4, 9, H, W]
|
259 |
+
results['input'] = final_input
|
260 |
+
|
261 |
+
# opengl to colmap camera for gaussian renderer
|
262 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
263 |
+
|
264 |
+
# cameras needed by gaussian rasterizer
|
265 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
266 |
+
cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
|
267 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
268 |
+
|
269 |
+
results['cam_view'] = cam_view
|
270 |
+
results['cam_view_proj'] = cam_view_proj
|
271 |
+
results['cam_pos'] = cam_pos
|
272 |
+
|
273 |
+
return results
|
274 |
+
|
275 |
+
def __getitem__(self, idx):
|
276 |
+
while True:
|
277 |
+
try:
|
278 |
+
results = self._get_batch(idx)
|
279 |
+
break
|
280 |
+
except Exception as e:
|
281 |
+
# print(f"{e}")
|
282 |
+
idx = random.randint(0, len(self.items) - 1)
|
283 |
+
return results
|
core/unet.py
ADDED
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from typing import Tuple, Literal
|
22 |
+
from functools import partial
|
23 |
+
|
24 |
+
from core.attention import MemEffAttention
|
25 |
+
|
26 |
+
|
27 |
+
class MVAttention(nn.Module):
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
dim: int,
|
31 |
+
num_heads: int = 8,
|
32 |
+
qkv_bias: bool = False,
|
33 |
+
proj_bias: bool = True,
|
34 |
+
attn_drop: float = 0.0,
|
35 |
+
proj_drop: float = 0.0,
|
36 |
+
groups: int = 32,
|
37 |
+
eps: float = 1e-5,
|
38 |
+
residual: bool = True,
|
39 |
+
skip_scale: float = 1,
|
40 |
+
num_views: int = 4,
|
41 |
+
num_frames: int = 8
|
42 |
+
):
|
43 |
+
super().__init__()
|
44 |
+
|
45 |
+
self.residual = residual
|
46 |
+
self.skip_scale = skip_scale
|
47 |
+
self.num_views = num_views
|
48 |
+
self.num_frames = num_frames
|
49 |
+
|
50 |
+
self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True)
|
51 |
+
self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
# x: [B*T*V, C, H, W]
|
55 |
+
BTV, C, H, W = x.shape
|
56 |
+
BT = BTV // self.num_views # assert BV % self.num_views == 0
|
57 |
+
|
58 |
+
res = x
|
59 |
+
x = self.norm(x)
|
60 |
+
|
61 |
+
x = x.reshape(BT, self.num_views, C, H, W).permute(0, 1, 3, 4, 2).contiguous().reshape(BT, -1, C).contiguous()
|
62 |
+
x = self.attn(x)
|
63 |
+
x = x.reshape(BT, self.num_views, H, W, C).permute(0, 1, 4, 2, 3).contiguous().reshape(BTV, C, H, W).contiguous()
|
64 |
+
|
65 |
+
if self.residual:
|
66 |
+
x = (x + res) * self.skip_scale
|
67 |
+
return x
|
68 |
+
|
69 |
+
|
70 |
+
class TempAttention(nn.Module):
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
dim: int,
|
74 |
+
num_heads: int = 8,
|
75 |
+
qkv_bias: bool = False,
|
76 |
+
proj_bias: bool = True,
|
77 |
+
attn_drop: float = 0.0,
|
78 |
+
proj_drop: float = 0.0,
|
79 |
+
groups: int = 32,
|
80 |
+
eps: float = 1e-5,
|
81 |
+
residual: bool = True,
|
82 |
+
skip_scale: float = 1,
|
83 |
+
num_views: int = 4,
|
84 |
+
num_frames: int = 8
|
85 |
+
):
|
86 |
+
super().__init__()
|
87 |
+
|
88 |
+
self.residual = residual
|
89 |
+
self.skip_scale = skip_scale
|
90 |
+
self.num_views = num_views
|
91 |
+
self.num_frames = num_frames
|
92 |
+
|
93 |
+
self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True)
|
94 |
+
self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop)
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
# x: [B*T*V, C, H, W]
|
98 |
+
BTV, C, H, W = x.shape
|
99 |
+
BV = BTV // self.num_frames # assert BV % self.num_views == 0
|
100 |
+
B = BV // self.num_views
|
101 |
+
|
102 |
+
res = x
|
103 |
+
x = self.norm(x)
|
104 |
+
|
105 |
+
# BTV -> BVT
|
106 |
+
x = x.reshape(B, self.num_frames, self.num_views, C, H, W).permute(0, 2, 1, 3, 4, 5).contiguous()
|
107 |
+
|
108 |
+
x = x.reshape(BV, self.num_frames, C, H, W).permute(0, 1, 3, 4, 2).contiguous().reshape(BV, -1, C).contiguous().contiguous()
|
109 |
+
x = self.attn(x)
|
110 |
+
x = x.reshape(BV, self.num_frames, H, W, C).permute(0, 1, 4, 2, 3).contiguous().reshape(BTV, C, H, W).contiguous().contiguous()
|
111 |
+
|
112 |
+
# BVT -> BTV
|
113 |
+
x = x.reshape(B, self.num_views, self.num_frames, C, H, W).permute(0, 2, 1, 3, 4, 5).contiguous().reshape(BTV, C, H, W).contiguous()
|
114 |
+
|
115 |
+
if self.residual:
|
116 |
+
x = (x + res) * self.skip_scale
|
117 |
+
return x
|
118 |
+
|
119 |
+
|
120 |
+
class ResnetBlock(nn.Module):
|
121 |
+
def __init__(
|
122 |
+
self,
|
123 |
+
in_channels: int,
|
124 |
+
out_channels: int,
|
125 |
+
resample: Literal['default', 'up', 'down'] = 'default',
|
126 |
+
groups: int = 32,
|
127 |
+
eps: float = 1e-5,
|
128 |
+
skip_scale: float = 1, # multiplied to output
|
129 |
+
):
|
130 |
+
super().__init__()
|
131 |
+
|
132 |
+
self.in_channels = in_channels
|
133 |
+
self.out_channels = out_channels
|
134 |
+
self.skip_scale = skip_scale
|
135 |
+
|
136 |
+
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
137 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
138 |
+
|
139 |
+
self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
|
140 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
141 |
+
|
142 |
+
self.act = F.silu
|
143 |
+
|
144 |
+
self.resample = None
|
145 |
+
if resample == 'up':
|
146 |
+
self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
|
147 |
+
elif resample == 'down':
|
148 |
+
self.resample = nn.AvgPool2d(kernel_size=2, stride=2)
|
149 |
+
|
150 |
+
self.shortcut = nn.Identity()
|
151 |
+
if self.in_channels != self.out_channels:
|
152 |
+
self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
|
153 |
+
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
res = x
|
157 |
+
|
158 |
+
x = self.norm1(x)
|
159 |
+
x = self.act(x)
|
160 |
+
|
161 |
+
if self.resample:
|
162 |
+
res = self.resample(res)
|
163 |
+
x = self.resample(x)
|
164 |
+
|
165 |
+
x = self.conv1(x)
|
166 |
+
x = self.norm2(x)
|
167 |
+
x = self.act(x)
|
168 |
+
x = self.conv2(x)
|
169 |
+
|
170 |
+
x = (x + self.shortcut(res)) * self.skip_scale
|
171 |
+
|
172 |
+
return x
|
173 |
+
|
174 |
+
class DownBlock(nn.Module):
|
175 |
+
def __init__(
|
176 |
+
self,
|
177 |
+
in_channels: int,
|
178 |
+
out_channels: int,
|
179 |
+
num_layers: int = 1,
|
180 |
+
downsample: bool = True,
|
181 |
+
attention: bool = True,
|
182 |
+
attention_heads: int = 16,
|
183 |
+
skip_scale: float = 1,
|
184 |
+
num_views: int = 4,
|
185 |
+
num_frames: int = 8,
|
186 |
+
use_temp_attn=True
|
187 |
+
):
|
188 |
+
super().__init__()
|
189 |
+
|
190 |
+
nets = []
|
191 |
+
attns = []
|
192 |
+
t_attns = []
|
193 |
+
for i in range(num_layers):
|
194 |
+
in_channels = in_channels if i == 0 else out_channels
|
195 |
+
nets.append(ResnetBlock(in_channels, out_channels, skip_scale=skip_scale))
|
196 |
+
if attention:
|
197 |
+
attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale, num_views=num_views, num_frames=num_frames))
|
198 |
+
t_attns.append(TempAttention(out_channels, attention_heads, skip_scale=skip_scale, num_views=num_views, num_frames=num_frames) if use_temp_attn else None)
|
199 |
+
else:
|
200 |
+
attns.append(None)
|
201 |
+
t_attns.append(None)
|
202 |
+
self.nets = nn.ModuleList(nets)
|
203 |
+
self.attns = nn.ModuleList(attns)
|
204 |
+
self.t_attns = nn.ModuleList(t_attns)
|
205 |
+
|
206 |
+
self.downsample = None
|
207 |
+
if downsample:
|
208 |
+
self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
|
209 |
+
|
210 |
+
def forward(self, x):
|
211 |
+
xs = []
|
212 |
+
|
213 |
+
for attn, t_attn, net in zip(self.attns, self.t_attns, self.nets):
|
214 |
+
x = net(x)
|
215 |
+
if attn:
|
216 |
+
x = attn(x)
|
217 |
+
if t_attn:
|
218 |
+
x = t_attn(x)
|
219 |
+
xs.append(x)
|
220 |
+
|
221 |
+
if self.downsample:
|
222 |
+
x = self.downsample(x)
|
223 |
+
xs.append(x)
|
224 |
+
|
225 |
+
return x, xs
|
226 |
+
|
227 |
+
|
228 |
+
class MidBlock(nn.Module):
|
229 |
+
def __init__(
|
230 |
+
self,
|
231 |
+
in_channels: int,
|
232 |
+
num_layers: int = 1,
|
233 |
+
attention: bool = True,
|
234 |
+
attention_heads: int = 16,
|
235 |
+
skip_scale: float = 1,
|
236 |
+
num_views: int = 4,
|
237 |
+
num_frames: int = 8,
|
238 |
+
use_temp_attn=True
|
239 |
+
):
|
240 |
+
super().__init__()
|
241 |
+
|
242 |
+
nets = []
|
243 |
+
attns = []
|
244 |
+
t_attns = []
|
245 |
+
# first layer
|
246 |
+
nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
|
247 |
+
# more layers
|
248 |
+
for i in range(num_layers):
|
249 |
+
nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
|
250 |
+
if attention:
|
251 |
+
attns.append(MVAttention(in_channels, attention_heads, skip_scale=skip_scale, num_views=num_views, num_frames=num_frames))
|
252 |
+
t_attns.append(TempAttention(in_channels, attention_heads, skip_scale=skip_scale, num_views=num_views, num_frames=num_frames) if use_temp_attn else None)
|
253 |
+
else:
|
254 |
+
attns.append(None)
|
255 |
+
t_attns.append(None)
|
256 |
+
self.nets = nn.ModuleList(nets)
|
257 |
+
self.attns = nn.ModuleList(attns)
|
258 |
+
self.t_attns = nn.ModuleList(t_attns)
|
259 |
+
|
260 |
+
def forward(self, x):
|
261 |
+
x = self.nets[0](x)
|
262 |
+
for attn, t_attn,net in zip(self.attns, self.t_attns, self.nets[1:]):
|
263 |
+
if attn:
|
264 |
+
x = attn(x)
|
265 |
+
if t_attn:
|
266 |
+
x = t_attn(x)
|
267 |
+
x = net(x)
|
268 |
+
return x
|
269 |
+
|
270 |
+
|
271 |
+
class UpBlock(nn.Module):
|
272 |
+
def __init__(
|
273 |
+
self,
|
274 |
+
in_channels: int,
|
275 |
+
prev_out_channels: int,
|
276 |
+
out_channels: int,
|
277 |
+
num_layers: int = 1,
|
278 |
+
upsample: bool = True,
|
279 |
+
attention: bool = True,
|
280 |
+
attention_heads: int = 16,
|
281 |
+
skip_scale: float = 1,
|
282 |
+
num_views: int = 4,
|
283 |
+
num_frames: int = 8,
|
284 |
+
use_temp_attn=True
|
285 |
+
):
|
286 |
+
super().__init__()
|
287 |
+
|
288 |
+
nets = []
|
289 |
+
attns = []
|
290 |
+
t_attns = []
|
291 |
+
for i in range(num_layers):
|
292 |
+
cin = in_channels if i == 0 else out_channels
|
293 |
+
cskip = prev_out_channels if (i == num_layers - 1) else out_channels
|
294 |
+
|
295 |
+
nets.append(ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale))
|
296 |
+
if attention:
|
297 |
+
attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale, num_views=num_views, num_frames=num_frames))
|
298 |
+
t_attns.append(TempAttention(out_channels, attention_heads, skip_scale=skip_scale, num_views=num_views, num_frames=num_frames) if use_temp_attn else None)
|
299 |
+
else:
|
300 |
+
attns.append(None)
|
301 |
+
t_attns.append(None)
|
302 |
+
self.nets = nn.ModuleList(nets)
|
303 |
+
self.attns = nn.ModuleList(attns)
|
304 |
+
self.t_attns = nn.ModuleList(t_attns)
|
305 |
+
|
306 |
+
self.upsample = None
|
307 |
+
if upsample:
|
308 |
+
self.upsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
309 |
+
|
310 |
+
def forward(self, x, xs):
|
311 |
+
|
312 |
+
for attn, t_attn, net in zip(self.attns, self.t_attns, self.nets):
|
313 |
+
res_x = xs[-1]
|
314 |
+
xs = xs[:-1]
|
315 |
+
x = torch.cat([x, res_x], dim=1)
|
316 |
+
x = net(x)
|
317 |
+
if attn:
|
318 |
+
x = attn(x)
|
319 |
+
if t_attn:
|
320 |
+
x = t_attn(x)
|
321 |
+
|
322 |
+
if self.upsample:
|
323 |
+
x = F.interpolate(x, scale_factor=2.0, mode='nearest')
|
324 |
+
x = self.upsample(x)
|
325 |
+
|
326 |
+
return x
|
327 |
+
|
328 |
+
|
329 |
+
# it could be asymmetric!
|
330 |
+
class UNet(nn.Module):
|
331 |
+
def __init__(
|
332 |
+
self,
|
333 |
+
in_channels: int = 3,
|
334 |
+
out_channels: int = 3,
|
335 |
+
down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
|
336 |
+
down_attention: Tuple[bool, ...] = (False, False, False, True, True),
|
337 |
+
mid_attention: bool = True,
|
338 |
+
up_channels: Tuple[int, ...] = (1024, 512, 256),
|
339 |
+
up_attention: Tuple[bool, ...] = (True, True, False),
|
340 |
+
layers_per_block: int = 2,
|
341 |
+
skip_scale: float = np.sqrt(0.5),
|
342 |
+
num_views: int = 4,
|
343 |
+
num_frames: int = 8,
|
344 |
+
use_temp_attn: bool = True
|
345 |
+
):
|
346 |
+
super().__init__()
|
347 |
+
|
348 |
+
# first
|
349 |
+
self.conv_in = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1)
|
350 |
+
|
351 |
+
# down
|
352 |
+
down_blocks = []
|
353 |
+
cout = down_channels[0]
|
354 |
+
for i in range(len(down_channels)):
|
355 |
+
cin = cout
|
356 |
+
cout = down_channels[i]
|
357 |
+
|
358 |
+
down_blocks.append(DownBlock(
|
359 |
+
cin, cout,
|
360 |
+
num_layers=layers_per_block,
|
361 |
+
downsample=(i != len(down_channels) - 1), # not final layer
|
362 |
+
attention=down_attention[i],
|
363 |
+
skip_scale=skip_scale,
|
364 |
+
num_views=num_views,
|
365 |
+
num_frames=num_frames,
|
366 |
+
use_temp_attn=use_temp_attn
|
367 |
+
))
|
368 |
+
self.down_blocks = nn.ModuleList(down_blocks)
|
369 |
+
|
370 |
+
# mid
|
371 |
+
self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale, num_views=num_views, num_frames=num_frames, use_temp_attn=use_temp_attn)
|
372 |
+
|
373 |
+
# up
|
374 |
+
up_blocks = []
|
375 |
+
cout = up_channels[0]
|
376 |
+
for i in range(len(up_channels)):
|
377 |
+
cin = cout
|
378 |
+
cout = up_channels[i]
|
379 |
+
cskip = down_channels[max(-2 - i, -len(down_channels))] # for assymetric
|
380 |
+
|
381 |
+
up_blocks.append(UpBlock(
|
382 |
+
cin, cskip, cout,
|
383 |
+
num_layers=layers_per_block + 1, # one more layer for up
|
384 |
+
upsample=(i != len(up_channels) - 1), # not final layer
|
385 |
+
attention=up_attention[i],
|
386 |
+
skip_scale=skip_scale,
|
387 |
+
num_views=num_views,
|
388 |
+
num_frames=num_frames,
|
389 |
+
use_temp_attn=use_temp_attn
|
390 |
+
))
|
391 |
+
self.up_blocks = nn.ModuleList(up_blocks)
|
392 |
+
|
393 |
+
# last
|
394 |
+
self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=32, eps=1e-5)
|
395 |
+
self.conv_out = nn.Conv2d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
|
396 |
+
|
397 |
+
|
398 |
+
def forward(self, x, return_mid_feature=False):
|
399 |
+
# x: [B, Cin, H, W]
|
400 |
+
|
401 |
+
# first
|
402 |
+
x = self.conv_in(x)
|
403 |
+
|
404 |
+
# down
|
405 |
+
xss = [x]
|
406 |
+
for block in self.down_blocks:
|
407 |
+
x, xs = block(x)
|
408 |
+
xss.extend(xs)
|
409 |
+
|
410 |
+
# mid
|
411 |
+
x = self.mid_block(x)
|
412 |
+
mid_feature = (x, xss)
|
413 |
+
|
414 |
+
# up
|
415 |
+
for block in self.up_blocks:
|
416 |
+
xs = xss[-len(block.nets):]
|
417 |
+
xss = xss[:-len(block.nets)]
|
418 |
+
x = block(x, xs)
|
419 |
+
|
420 |
+
# last
|
421 |
+
x = self.norm_out(x)
|
422 |
+
x = F.silu(x)
|
423 |
+
x = self.conv_out(x) # [B, Cout, H', W']
|
424 |
+
|
425 |
+
if return_mid_feature:
|
426 |
+
return x, *mid_feature
|
427 |
+
|
428 |
+
return x
|
core/utils.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
import roma
|
8 |
+
from kiui.op import safe_normalize
|
9 |
+
|
10 |
+
def get_rays(pose, h, w, fovy, opengl=True):
|
11 |
+
|
12 |
+
x, y = torch.meshgrid(
|
13 |
+
torch.arange(w, device=pose.device),
|
14 |
+
torch.arange(h, device=pose.device),
|
15 |
+
indexing="xy",
|
16 |
+
)
|
17 |
+
x = x.flatten()
|
18 |
+
y = y.flatten()
|
19 |
+
|
20 |
+
cx = w * 0.5
|
21 |
+
cy = h * 0.5
|
22 |
+
|
23 |
+
focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))
|
24 |
+
|
25 |
+
camera_dirs = F.pad(
|
26 |
+
torch.stack(
|
27 |
+
[
|
28 |
+
(x - cx + 0.5) / focal,
|
29 |
+
(y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
|
30 |
+
],
|
31 |
+
dim=-1,
|
32 |
+
),
|
33 |
+
(0, 1),
|
34 |
+
value=(-1.0 if opengl else 1.0),
|
35 |
+
) # [hw, 3]
|
36 |
+
|
37 |
+
rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
|
38 |
+
rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
|
39 |
+
|
40 |
+
rays_o = rays_o.view(h, w, 3)
|
41 |
+
rays_d = safe_normalize(rays_d).view(h, w, 3)
|
42 |
+
|
43 |
+
return rays_o, rays_d
|
44 |
+
|
45 |
+
def orbit_camera_jitter(poses, strength=0.1):
|
46 |
+
# poses: [B, 4, 4], assume orbit camera in opengl format
|
47 |
+
# random orbital rotate
|
48 |
+
|
49 |
+
B = poses.shape[0]
|
50 |
+
rotvec_x = poses[:, :3, 1] * strength * np.pi * (torch.rand(B, 1, device=poses.device) * 2 - 1)
|
51 |
+
rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * (torch.rand(B, 1, device=poses.device) * 2 - 1)
|
52 |
+
|
53 |
+
rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y)
|
54 |
+
R = rot @ poses[:, :3, :3]
|
55 |
+
T = rot @ poses[:, :3, 3:]
|
56 |
+
|
57 |
+
new_poses = poses.clone()
|
58 |
+
new_poses[:, :3, :3] = R
|
59 |
+
new_poses[:, :3, 3:] = T
|
60 |
+
|
61 |
+
return new_poses
|
62 |
+
|
63 |
+
def grid_distortion(images, strength=0.5):
|
64 |
+
# images: [B, C, H, W]
|
65 |
+
# num_steps: int, grid resolution for distortion
|
66 |
+
# strength: float in [0, 1], strength of distortion
|
67 |
+
|
68 |
+
B, C, H, W = images.shape
|
69 |
+
|
70 |
+
num_steps = np.random.randint(8, 17)
|
71 |
+
grid_steps = torch.linspace(-1, 1, num_steps)
|
72 |
+
|
73 |
+
# have to loop batch...
|
74 |
+
grids = []
|
75 |
+
for b in range(B):
|
76 |
+
# construct displacement
|
77 |
+
x_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive
|
78 |
+
x_steps = (x_steps + strength * (torch.rand_like(x_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb
|
79 |
+
x_steps = (x_steps * W).long() # [num_steps]
|
80 |
+
x_steps[0] = 0
|
81 |
+
x_steps[-1] = W
|
82 |
+
xs = []
|
83 |
+
for i in range(num_steps - 1):
|
84 |
+
xs.append(torch.linspace(grid_steps[i], grid_steps[i + 1], x_steps[i + 1] - x_steps[i]))
|
85 |
+
xs = torch.cat(xs, dim=0) # [W]
|
86 |
+
|
87 |
+
y_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive
|
88 |
+
y_steps = (y_steps + strength * (torch.rand_like(y_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb
|
89 |
+
y_steps = (y_steps * H).long() # [num_steps]
|
90 |
+
y_steps[0] = 0
|
91 |
+
y_steps[-1] = H
|
92 |
+
ys = []
|
93 |
+
for i in range(num_steps - 1):
|
94 |
+
ys.append(torch.linspace(grid_steps[i], grid_steps[i + 1], y_steps[i + 1] - y_steps[i]))
|
95 |
+
ys = torch.cat(ys, dim=0) # [H]
|
96 |
+
|
97 |
+
# construct grid
|
98 |
+
grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W]
|
99 |
+
grid = torch.stack([grid_x, grid_y], dim=-1) # [H, W, 2]
|
100 |
+
|
101 |
+
grids.append(grid)
|
102 |
+
|
103 |
+
grids = torch.stack(grids, dim=0).to(images.device) # [B, H, W, 2]
|
104 |
+
|
105 |
+
# grid sample
|
106 |
+
images = F.grid_sample(images, grids, align_corners=False)
|
107 |
+
|
108 |
+
return images
|
109 |
+
|
data_test/000000_fg.mp4
ADDED
Binary file (226 kB). View file
|
|
data_test/000070_fg.mp4
ADDED
Binary file (291 kB). View file
|
|
data_test/000370_fg.mp4
ADDED
Binary file (273 kB). View file
|
|
data_test/blooming_rose_fg.mp4
ADDED
Binary file (83.5 kB). View file
|
|
data_test/cat_king_fg.mp4
ADDED
Binary file (266 kB). View file
|
|
data_test/dancing_robot_fg.mp4
ADDED
Binary file (64.2 kB). View file
|
|
data_test/lifting1_fg.mp4
ADDED
Binary file (311 kB). View file
|
|
data_test/monster-with-melting-candle_fg.mp4
ADDED
Binary file (365 kB). View file
|
|
data_test/otter-on-surfboard_fg.mp4
ADDED
Binary file (305 kB). View file
|
|
data_test/sighing_frog_fg.mp4
ADDED
Binary file (90.1 kB). View file
|
|
environment.yml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: l4gm
|
2 |
+
channels:
|
3 |
+
- pyg
|
4 |
+
- nvidia/label/cuda-12.1.0
|
5 |
+
- pytorch
|
6 |
+
- conda-forge
|
7 |
+
- xformers
|
8 |
+
dependencies:
|
9 |
+
- python=3.10
|
10 |
+
- pytorch=2.5.1
|
11 |
+
- pytorch-cuda=12.1
|
12 |
+
- torchvision
|
13 |
+
- xformers
|
14 |
+
- cuda
|
15 |
+
- cuda-nvcc
|
16 |
+
- numpy<2.0.0
|
17 |
+
- scipy
|
18 |
+
- rich
|
19 |
+
- pip
|
20 |
+
- setuptools
|
21 |
+
- ninja
|
22 |
+
- tqdm
|
23 |
+
- ray-default
|
24 |
+
- flatten-dict
|
25 |
+
- gcc_linux-64=11
|
26 |
+
- gxx_linux-64=11
|
27 |
+
- opencv
|
28 |
+
- transformers
|
29 |
+
- einops
|
30 |
+
- pip:
|
31 |
+
- -r requirements.txt
|
32 |
+
- git+https://github.com/nerfstudio-project/gsplat.git@v1.4.0
|
infer_3d.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import imageio.v3 as iio
|
17 |
+
import cv2
|
18 |
+
import numpy as np
|
19 |
+
import imageio
|
20 |
+
|
21 |
+
import os
|
22 |
+
import tyro
|
23 |
+
import glob
|
24 |
+
import imageio
|
25 |
+
import numpy as np
|
26 |
+
import torch
|
27 |
+
import torch.nn as nn
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import torchvision.transforms.functional as TF
|
30 |
+
from safetensors.torch import load_file
|
31 |
+
import time
|
32 |
+
|
33 |
+
import kiui
|
34 |
+
from kiui.cam import orbit_camera
|
35 |
+
|
36 |
+
from core.options import AllConfigs, Options
|
37 |
+
from core.models import LGM
|
38 |
+
from mvdream.pipeline_mvdream import MVDreamPipeline
|
39 |
+
|
40 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
41 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
42 |
+
|
43 |
+
opt = tyro.cli(AllConfigs)
|
44 |
+
|
45 |
+
# model
|
46 |
+
model = LGM(opt)
|
47 |
+
|
48 |
+
# resume pretrained checkpoint
|
49 |
+
if opt.resume is not None:
|
50 |
+
if opt.resume.endswith('safetensors'):
|
51 |
+
ckpt = load_file(opt.resume, device='cpu')
|
52 |
+
else:
|
53 |
+
ckpt = torch.load(opt.resume, map_location='cpu')
|
54 |
+
model.load_state_dict(ckpt, strict=False)
|
55 |
+
print(f'[INFO] Loaded checkpoint from {opt.resume}')
|
56 |
+
else:
|
57 |
+
print(f'[WARN] model randomly initialized, are you sure?')
|
58 |
+
|
59 |
+
# device
|
60 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
61 |
+
model = model.half().to(device)
|
62 |
+
model.eval()
|
63 |
+
|
64 |
+
bg_color = torch.tensor([255, 255, 255], dtype=torch.float32, device="cuda") / 255.
|
65 |
+
|
66 |
+
rays_embeddings = model.prepare_default_rays(device)
|
67 |
+
|
68 |
+
tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
|
69 |
+
proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
|
70 |
+
proj_matrix[0, 0] = 1 / tan_half_fov
|
71 |
+
proj_matrix[1, 1] = 1 / tan_half_fov
|
72 |
+
proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
|
73 |
+
proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
|
74 |
+
proj_matrix[2, 3] = 1
|
75 |
+
|
76 |
+
# load image dream
|
77 |
+
pipe = MVDreamPipeline.from_pretrained(
|
78 |
+
"ashawkey/imagedream-ipmv-diffusers", # remote weights
|
79 |
+
torch_dtype=torch.float16,
|
80 |
+
trust_remote_code=True,
|
81 |
+
# local_files_only=True,
|
82 |
+
)
|
83 |
+
pipe = pipe.to(device)
|
84 |
+
|
85 |
+
|
86 |
+
def process_eval_video(video_path, T):
|
87 |
+
frames = iio.imread(video_path)
|
88 |
+
frames = [frames[x] for x in range(frames.shape[0])]
|
89 |
+
V = opt.num_input_views
|
90 |
+
img_TV = []
|
91 |
+
for t in range(T):
|
92 |
+
|
93 |
+
img = frames[t]
|
94 |
+
|
95 |
+
img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
|
96 |
+
img = img.astype(np.float32) / 255.0
|
97 |
+
|
98 |
+
img_V = []
|
99 |
+
for v in range(V):
|
100 |
+
img_V.append(img)
|
101 |
+
img_TV.append(np.stack(img_V, axis=0))
|
102 |
+
|
103 |
+
return np.stack(img_TV, axis=0)
|
104 |
+
|
105 |
+
|
106 |
+
# process function
|
107 |
+
def process(opt: Options, path):
|
108 |
+
name = os.path.splitext(os.path.basename(path))[0]
|
109 |
+
print(f'[INFO] Processing {path} --> {name}')
|
110 |
+
os.makedirs(opt.workspace, exist_ok=True)
|
111 |
+
|
112 |
+
ref_video = process_eval_video(path, opt.num_frames) # [TV, 512, 512, 3]
|
113 |
+
|
114 |
+
|
115 |
+
end_time = time.time()
|
116 |
+
|
117 |
+
cv2.imwrite(os.path.join(opt.workspace, f'{name}_orig.png'), ref_video[0,0][..., ::-1] * 255)
|
118 |
+
|
119 |
+
mv_image = pipe('', ref_video[0,0], guidance_scale=5, num_inference_steps=30, elevation=0)
|
120 |
+
for v in range(4):
|
121 |
+
cv2.imwrite(os.path.join(opt.workspace, f'{name}_mv_{(v-1)%4:03d}.png'), mv_image[v][..., ::-1] * 255)
|
122 |
+
mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32
|
123 |
+
|
124 |
+
|
125 |
+
# generate gaussians
|
126 |
+
input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
|
127 |
+
input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
|
128 |
+
input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
129 |
+
|
130 |
+
input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
|
131 |
+
|
132 |
+
with torch.no_grad():
|
133 |
+
with torch.autocast(device_type='cuda', dtype=torch.float16):
|
134 |
+
gaussians_all_frame = model.forward_gaussians(input_image)
|
135 |
+
|
136 |
+
B, T, V = 1, gaussians_all_frame.shape[0]//opt.batch_size, opt.num_views
|
137 |
+
gaussians_all_frame = gaussians_all_frame.reshape(B, T, *gaussians_all_frame.shape[1:])
|
138 |
+
|
139 |
+
# align azimuth
|
140 |
+
best_azi = 0
|
141 |
+
best_diff = 1e8
|
142 |
+
for v, azi in enumerate(np.arange(-180, 180, 1)):
|
143 |
+
gaussians = gaussians_all_frame[:, 0]
|
144 |
+
|
145 |
+
cam_poses = torch.from_numpy(orbit_camera(0, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
|
146 |
+
|
147 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
148 |
+
|
149 |
+
# cameras needed by gaussian rasterizer
|
150 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
151 |
+
cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
|
152 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
153 |
+
|
154 |
+
result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)
|
155 |
+
image = result['image']
|
156 |
+
alpha = result['alpha']
|
157 |
+
|
158 |
+
image = image.squeeze(1).permute(0,2,3,1).squeeze(0).contiguous().float().cpu().numpy()
|
159 |
+
image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
|
160 |
+
|
161 |
+
diff = np.mean((image- ref_video[0,0]) ** 2)
|
162 |
+
|
163 |
+
if diff < best_diff:
|
164 |
+
best_diff = diff
|
165 |
+
best_azi = azi
|
166 |
+
|
167 |
+
print("Best aligned azimuth: ", best_azi)
|
168 |
+
|
169 |
+
mv_image = []
|
170 |
+
for v, azi in enumerate(np.arange(0, 360, 90)):
|
171 |
+
gaussians = gaussians_all_frame[:, 0]
|
172 |
+
|
173 |
+
cam_poses = torch.from_numpy(orbit_camera(0, azi + best_azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
|
174 |
+
|
175 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
176 |
+
|
177 |
+
# cameras needed by gaussian rasterizer
|
178 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
179 |
+
cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
|
180 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
181 |
+
|
182 |
+
scale = 1
|
183 |
+
|
184 |
+
result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)
|
185 |
+
image = result['image']
|
186 |
+
alpha = result['alpha']
|
187 |
+
|
188 |
+
imageio.imwrite(os.path.join(opt.workspace, f'{name}_{v:03d}.png'), (image.squeeze(1).permute(0,2,3,1).squeeze(0).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
|
189 |
+
|
190 |
+
if azi in [0, 90, 180, 270]:
|
191 |
+
rendered_image = image.squeeze(1)
|
192 |
+
rendered_image = F.interpolate(rendered_image, (256, 256))
|
193 |
+
rendered_image = rendered_image.permute(0,2,3,1).contiguous().float().cpu().numpy()
|
194 |
+
mv_image.append(rendered_image)
|
195 |
+
mv_image = np.concatenate(mv_image, axis=0)
|
196 |
+
print(f"Generate 3D takes {time.time()-end_time} s")
|
197 |
+
|
198 |
+
images = []
|
199 |
+
azimuth = np.arange(0, 360, 4, dtype=np.int32)
|
200 |
+
elevation = 0
|
201 |
+
for azi in azimuth:
|
202 |
+
gaussians = gaussians_all_frame[:, 0]
|
203 |
+
|
204 |
+
cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
|
205 |
+
|
206 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
207 |
+
|
208 |
+
# cameras needed by gaussian rasterizer
|
209 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
210 |
+
cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
|
211 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
212 |
+
|
213 |
+
scale = 1
|
214 |
+
|
215 |
+
image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)['image']
|
216 |
+
images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
|
217 |
+
|
218 |
+
images = np.concatenate(images, axis=0)
|
219 |
+
imageio.mimwrite(os.path.join(opt.workspace, f'{name}.mp4'), images, fps=30)
|
220 |
+
|
221 |
+
|
222 |
+
torch.cuda.empty_cache()
|
223 |
+
|
224 |
+
|
225 |
+
|
226 |
+
assert opt.test_path is not None
|
227 |
+
if os.path.isdir(opt.test_path):
|
228 |
+
file_paths = glob.glob(os.path.join(opt.test_path, "*"))
|
229 |
+
else:
|
230 |
+
file_paths = [opt.test_path]
|
231 |
+
|
232 |
+
for path in sorted(file_paths):
|
233 |
+
process(opt, path)
|
infer_4d.py
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import imageio.v3 as iio
|
17 |
+
import cv2
|
18 |
+
import numpy as np
|
19 |
+
import imageio
|
20 |
+
|
21 |
+
from copy import deepcopy
|
22 |
+
import os
|
23 |
+
import tyro
|
24 |
+
import glob
|
25 |
+
import imageio
|
26 |
+
import numpy as np
|
27 |
+
import tqdm
|
28 |
+
import torch
|
29 |
+
import torch.nn as nn
|
30 |
+
import torch.nn.functional as F
|
31 |
+
import torchvision.transforms.functional as TF
|
32 |
+
from safetensors.torch import load_file
|
33 |
+
|
34 |
+
import kiui
|
35 |
+
from kiui.cam import orbit_camera
|
36 |
+
|
37 |
+
from core.options import AllConfigs, Options
|
38 |
+
from core.models import LGM
|
39 |
+
import time
|
40 |
+
|
41 |
+
from core.utils import get_rays, grid_distortion, orbit_camera_jitter
|
42 |
+
|
43 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
44 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
45 |
+
|
46 |
+
|
47 |
+
USE_INTERPOLATION = True # set to false to disable interpolation
|
48 |
+
MAX_RUNS = 100
|
49 |
+
VIDEO_FPS = 30
|
50 |
+
|
51 |
+
opt = tyro.cli(AllConfigs)
|
52 |
+
|
53 |
+
# model
|
54 |
+
model = LGM(opt)
|
55 |
+
|
56 |
+
# resume pretrained checkpoint
|
57 |
+
if opt.resume is not None:
|
58 |
+
if opt.resume.endswith('safetensors'):
|
59 |
+
ckpt = load_file(opt.resume, device='cpu')
|
60 |
+
else:
|
61 |
+
ckpt = torch.load(opt.resume, map_location='cpu')
|
62 |
+
model.load_state_dict(ckpt, strict=False)
|
63 |
+
print(f'[INFO] Loaded checkpoint from {opt.resume}')
|
64 |
+
else:
|
65 |
+
print(f'[WARN] model randomly initialized, are you sure?')
|
66 |
+
|
67 |
+
# device
|
68 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
69 |
+
model = model.half().to(device)
|
70 |
+
model.eval()
|
71 |
+
|
72 |
+
bg_color = torch.tensor([255, 255, 255], dtype=torch.float32, device="cuda") / 255.
|
73 |
+
|
74 |
+
|
75 |
+
rays_embeddings = model.prepare_default_rays(device)
|
76 |
+
rays_embeddings = torch.cat([rays_embeddings for _ in range(opt.num_frames)])
|
77 |
+
|
78 |
+
|
79 |
+
interp_opt = deepcopy(opt)
|
80 |
+
interp_opt.num_frames = 4
|
81 |
+
model_interp = LGM(interp_opt)
|
82 |
+
# resume pretrained checkpoint
|
83 |
+
if interp_opt.interpresume is not None:
|
84 |
+
if interp_opt.interpresume.endswith('safetensors'):
|
85 |
+
ckpt = load_file(interp_opt.interpresume, device='cpu')
|
86 |
+
else:
|
87 |
+
ckpt = torch.load(interp_opt.interpresume, map_location='cpu')
|
88 |
+
model_interp.load_state_dict(ckpt, strict=False)
|
89 |
+
print(f'[INFO] Loaded Interp checkpoint from {interp_opt.interpresume}')
|
90 |
+
else:
|
91 |
+
print(f'[WARN] model_interp randomly initialized, are you sure?')
|
92 |
+
|
93 |
+
# device
|
94 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
95 |
+
model_interp = model_interp.half().to(device)
|
96 |
+
model_interp.eval()
|
97 |
+
|
98 |
+
|
99 |
+
interp_rays_embeddings = model_interp.prepare_default_rays(device)
|
100 |
+
interp_rays_embeddings = torch.cat([interp_rays_embeddings for _ in range(interp_opt.num_frames)])
|
101 |
+
|
102 |
+
tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
|
103 |
+
proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
|
104 |
+
proj_matrix[0, 0] = 1 / tan_half_fov
|
105 |
+
proj_matrix[1, 1] = 1 / tan_half_fov
|
106 |
+
proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
|
107 |
+
proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
|
108 |
+
proj_matrix[2, 3] = 1
|
109 |
+
|
110 |
+
def interpolate_tensors(tensor):
|
111 |
+
# Extract the first and last tensors along the first dimension (B)
|
112 |
+
start_tensor = tensor[0] # shape [4, 3, 256, 256]
|
113 |
+
end_tensor = tensor[-1] # shape [4, 3, 256, 256]
|
114 |
+
tensor_interp = deepcopy(tensor)
|
115 |
+
|
116 |
+
# Iterate over the range from 1 to second-last index
|
117 |
+
|
118 |
+
for i in range(1, tensor.shape[0] - 1):
|
119 |
+
# Calculate the weight for interpolation
|
120 |
+
|
121 |
+
weight = (i - 0) / (tensor.shape[0] - 1)
|
122 |
+
# Interpolate between start_tensor and end_tensor
|
123 |
+
tensor_interp[i] = torch.lerp(start_tensor, end_tensor, weight)
|
124 |
+
|
125 |
+
|
126 |
+
return tensor_interp
|
127 |
+
|
128 |
+
def process_eval_video(frames, video_path, T, start_t=0, downsample_rate=1):
|
129 |
+
L = frames.shape[0]
|
130 |
+
vid_name =video_path.split('/')[-1].split('.')[0]
|
131 |
+
total_frames = L//downsample_rate
|
132 |
+
print(f'{start_t} / {total_frames}')
|
133 |
+
frames = [frames[x] for x in range(frames.shape[0])]
|
134 |
+
V = opt.num_input_views
|
135 |
+
img_TV = []
|
136 |
+
for t in range(T):
|
137 |
+
t += start_t
|
138 |
+
t = min(t, L//downsample_rate-1)
|
139 |
+
t*=downsample_rate
|
140 |
+
|
141 |
+
img = frames[t]
|
142 |
+
|
143 |
+
img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
|
144 |
+
img = img.astype(np.float32) / 255.0
|
145 |
+
|
146 |
+
img_V = []
|
147 |
+
for v in range(V):
|
148 |
+
img_V.append(img)
|
149 |
+
img_TV.append(np.stack(img_V, axis=0))
|
150 |
+
|
151 |
+
return np.stack(img_TV, axis=0), L//downsample_rate- start_t
|
152 |
+
|
153 |
+
def load_mv_img(name, img_dir):
|
154 |
+
img_list = []
|
155 |
+
for v in range(4):
|
156 |
+
img = kiui.read_image(os.path.join(img_dir, name + f'_{v:03d}.png'), mode='uint8')
|
157 |
+
img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
|
158 |
+
img = img / 255.
|
159 |
+
img_list.append(img)
|
160 |
+
return np.stack(img_list, axis=0)
|
161 |
+
|
162 |
+
|
163 |
+
|
164 |
+
# process function
|
165 |
+
def process(opt: Options, path):
|
166 |
+
name = os.path.splitext(os.path.basename(path))[0]
|
167 |
+
print(f'[INFO] Processing {path} --> {name}')
|
168 |
+
os.makedirs(opt.workspace, exist_ok=True)
|
169 |
+
frames = iio.imread(path)
|
170 |
+
img_dir = opt.workspace
|
171 |
+
mv_image = load_mv_img(name, img_dir)
|
172 |
+
|
173 |
+
print(iio.immeta(path))
|
174 |
+
FPS = int(iio.immeta(path)['fps'])
|
175 |
+
downsample_rate = FPS // 15 if FPS > 15 else 1 # default reconstruction fps 15
|
176 |
+
|
177 |
+
|
178 |
+
|
179 |
+
with torch.inference_mode():
|
180 |
+
with torch.autocast(device_type='cuda', dtype=torch.float16):
|
181 |
+
start_t = 0
|
182 |
+
gaussians_all_frame_all_run = []
|
183 |
+
gaussians_all_frame_all_run_w_interp = []
|
184 |
+
for run_idx in range(MAX_RUNS):
|
185 |
+
ref_video, end_t = process_eval_video(frames, path, opt.num_frames, start_t, downsample_rate=downsample_rate)
|
186 |
+
ref_video[:, 1:] = mv_image[None, 1:] # repeat
|
187 |
+
input_image = torch.from_numpy(ref_video).reshape([-1, *ref_video.shape[2:]]).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
|
188 |
+
input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
|
189 |
+
input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
190 |
+
input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
|
191 |
+
|
192 |
+
end_time = time.time()
|
193 |
+
|
194 |
+
gaussians_all_frame = model.forward_gaussians(input_image)
|
195 |
+
print(f"Forward pass takes {time.time()-end_time} s")
|
196 |
+
|
197 |
+
B, T, V = 1, gaussians_all_frame.shape[0]//opt.batch_size, opt.num_views
|
198 |
+
gaussians_all_frame = gaussians_all_frame.reshape(B, T, *gaussians_all_frame.shape[1:])
|
199 |
+
|
200 |
+
if run_idx > 0:
|
201 |
+
gaussians_all_frame_wo_inter = gaussians_all_frame[:, 1:max(end_t, 1)]
|
202 |
+
else:
|
203 |
+
gaussians_all_frame_wo_inter = gaussians_all_frame
|
204 |
+
|
205 |
+
if gaussians_all_frame_wo_inter.shape[1] > 0 and USE_INTERPOLATION:
|
206 |
+
# render multiview video
|
207 |
+
render_img_TV = []
|
208 |
+
for t in range(gaussians_all_frame.shape[1]):
|
209 |
+
render_img_V = []
|
210 |
+
for v, azi in enumerate(np.arange(0, 360, 90)):
|
211 |
+
|
212 |
+
gaussians = gaussians_all_frame[:, t]
|
213 |
+
|
214 |
+
cam_poses = torch.from_numpy(orbit_camera(0, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
|
215 |
+
|
216 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
217 |
+
|
218 |
+
# cameras needed by gaussian rasterizer
|
219 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
220 |
+
cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
|
221 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
222 |
+
|
223 |
+
rendered_image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)['image']
|
224 |
+
rendered_image = rendered_image.squeeze(1)
|
225 |
+
rendered_image = F.interpolate(rendered_image, (256, 256))
|
226 |
+
rendered_image = rendered_image.permute(0,2,3,1).contiguous().float().cpu().numpy() # B H W C
|
227 |
+
|
228 |
+
render_img_V.append(rendered_image)
|
229 |
+
render_img_V = np.concatenate(render_img_V, axis=0) # V H W C
|
230 |
+
render_img_TV.append(render_img_V)
|
231 |
+
render_img_TV = np.stack(render_img_TV, axis=0) # T V H W C
|
232 |
+
ref_video = np.concatenate([np.stack([ref_video[ttt] for _ in range(opt.interpolate_rate)], 0) for ttt in range(ref_video.shape[0])], 0)
|
233 |
+
|
234 |
+
|
235 |
+
for tt in range(gaussians_all_frame_wo_inter.shape[1] -1 ):
|
236 |
+
|
237 |
+
curr_ref_video = deepcopy( ref_video[ tt * opt.interpolate_rate: tt * opt.interpolate_rate + interp_opt.num_frames ])
|
238 |
+
curr_ref_video[0, 1:] = render_img_TV[tt, 1:]
|
239 |
+
|
240 |
+
curr_ref_video[-1, 1:] = render_img_TV[tt+1, 1:]
|
241 |
+
|
242 |
+
|
243 |
+
curr_ref_video = torch.from_numpy(curr_ref_video).float().to(
|
244 |
+
device) # [4, 3, 256, 256]
|
245 |
+
|
246 |
+
images_input_interp = interpolate_tensors(curr_ref_video)
|
247 |
+
|
248 |
+
curr_ref_video[1:-1, :] = images_input_interp[1:-1, :]
|
249 |
+
|
250 |
+
input_image_interp = curr_ref_video.reshape([-1, *curr_ref_video.shape[2:]]).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
|
251 |
+
input_image_interp = F.interpolate(input_image_interp, size=(interp_opt.input_size, interp_opt.input_size), mode='bilinear',
|
252 |
+
align_corners=False)
|
253 |
+
input_image_interp = TF.normalize(input_image_interp, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
|
254 |
+
|
255 |
+
input_image_interp = torch.cat([input_image_interp, interp_rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
|
256 |
+
|
257 |
+
end_time = time.time()
|
258 |
+
gaussians_interp_all_frame = model_interp.forward_gaussians(input_image_interp)
|
259 |
+
print(f"Interpolate forward pass takes {time.time()-end_time} s")
|
260 |
+
|
261 |
+
B, T, V = 1, gaussians_interp_all_frame.shape[0] // opt.batch_size, opt.num_views
|
262 |
+
gaussians_interp_all_frame = gaussians_interp_all_frame.reshape(B, T, *gaussians_interp_all_frame.shape[1:])
|
263 |
+
|
264 |
+
if tt > 0:
|
265 |
+
gaussians_interp_all_frame = gaussians_interp_all_frame[:, 1:]
|
266 |
+
|
267 |
+
gaussians_all_frame_all_run_w_interp.append(gaussians_interp_all_frame)
|
268 |
+
|
269 |
+
|
270 |
+
|
271 |
+
gaussians_all_frame_all_run.append(gaussians_all_frame_wo_inter)
|
272 |
+
start_t += opt.num_frames -1
|
273 |
+
|
274 |
+
mv_image = []
|
275 |
+
for v, azi in enumerate(np.arange(0, 360, 90)):
|
276 |
+
gaussians = gaussians_all_frame_wo_inter[:, -1]
|
277 |
+
cam_poses = torch.from_numpy(orbit_camera(0, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
|
278 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
279 |
+
# cameras needed by gaussian rasterizer
|
280 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
281 |
+
cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
|
282 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
283 |
+
|
284 |
+
rendered_image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)['image']
|
285 |
+
rendered_image = rendered_image.squeeze(1)
|
286 |
+
rendered_image = F.interpolate(rendered_image, (256, 256))
|
287 |
+
rendered_image = rendered_image.permute(0,2,3,1).contiguous().float().cpu().numpy()
|
288 |
+
mv_image.append(rendered_image)
|
289 |
+
mv_image = np.concatenate(mv_image, axis=0)
|
290 |
+
elif gaussians_all_frame_wo_inter.shape[1] > 0:
|
291 |
+
gaussians_all_frame_all_run.append(gaussians_all_frame_wo_inter)
|
292 |
+
start_t += opt.num_frames -1
|
293 |
+
else:
|
294 |
+
break
|
295 |
+
|
296 |
+
gaussians_all_frame_wo_interp = torch.cat(gaussians_all_frame_all_run, dim=1)
|
297 |
+
if USE_INTERPOLATION:
|
298 |
+
gaussians_all_frame_w_interp = torch.cat(gaussians_all_frame_all_run_w_interp, dim=1)
|
299 |
+
|
300 |
+
if USE_INTERPOLATION:
|
301 |
+
zip_dump = zip(["wo_interp", "w_interp"], [gaussians_all_frame_wo_interp, gaussians_all_frame_w_interp])
|
302 |
+
else:
|
303 |
+
zip_dump = zip(["wo_interp"], [gaussians_all_frame_wo_interp])
|
304 |
+
|
305 |
+
for sv_name, gaussians_all_frame in zip_dump:
|
306 |
+
if sv_name == "w_interp":
|
307 |
+
ANIM_FPS = FPS / downsample_rate * gaussians_all_frame_w_interp.shape[1] / gaussians_all_frame_wo_interp.shape[1]
|
308 |
+
else:
|
309 |
+
ANIM_FPS = FPS / downsample_rate
|
310 |
+
print(f"{sv_name} | input video fps: {FPS} | downsample rate: {downsample_rate} | animation fps: {ANIM_FPS} | output video fps: {VIDEO_FPS}")
|
311 |
+
render_img_TV = []
|
312 |
+
for t in range(gaussians_all_frame.shape[1]):
|
313 |
+
render_img_V = []
|
314 |
+
for v, azi in enumerate(np.arange(0, 360, 90)):
|
315 |
+
|
316 |
+
gaussians = gaussians_all_frame[:, t]
|
317 |
+
|
318 |
+
cam_poses = torch.from_numpy(orbit_camera(0, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
|
319 |
+
|
320 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
321 |
+
|
322 |
+
# cameras needed by gaussian rasterizer
|
323 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
324 |
+
cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
|
325 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
326 |
+
|
327 |
+
result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)
|
328 |
+
image = result['image']
|
329 |
+
alpha = result['alpha']
|
330 |
+
|
331 |
+
render_img_V.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
|
332 |
+
render_img_V = np.concatenate(render_img_V, axis=2)
|
333 |
+
render_img_TV.append(render_img_V)
|
334 |
+
render_img_TV = np.concatenate(render_img_TV, axis=0)
|
335 |
+
|
336 |
+
|
337 |
+
images = []
|
338 |
+
azimuth = np.arange(0, 360, 1*30/VIDEO_FPS, dtype=np.int32)
|
339 |
+
elevation = 0
|
340 |
+
t = 0
|
341 |
+
delta_t = ANIM_FPS / VIDEO_FPS
|
342 |
+
for azi in azimuth:
|
343 |
+
if azi in [0, 90, 180, 270]:
|
344 |
+
cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
|
345 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
346 |
+
|
347 |
+
# cameras needed by gaussian rasterizer
|
348 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
349 |
+
cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
|
350 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
351 |
+
|
352 |
+
for _ in range(45):
|
353 |
+
gaussians = gaussians_all_frame[:, int(t) % gaussians_all_frame.shape[1]]
|
354 |
+
t += delta_t
|
355 |
+
image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)['image']
|
356 |
+
images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
|
357 |
+
else:
|
358 |
+
gaussians = gaussians_all_frame[:, int(t) % gaussians_all_frame.shape[1]]
|
359 |
+
t += delta_t
|
360 |
+
|
361 |
+
cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
|
362 |
+
|
363 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
364 |
+
|
365 |
+
# cameras needed by gaussian rasterizer
|
366 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
367 |
+
cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
|
368 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
369 |
+
|
370 |
+
image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)['image']
|
371 |
+
images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
|
372 |
+
|
373 |
+
images = np.concatenate(images, axis=0)
|
374 |
+
|
375 |
+
torch.cuda.empty_cache()
|
376 |
+
|
377 |
+
|
378 |
+
imageio.mimwrite(os.path.join(opt.workspace, f'{sv_name}_{name}_fixed.mp4'), render_img_TV, fps=ANIM_FPS)
|
379 |
+
print("Fixed video saved.")
|
380 |
+
imageio.mimwrite(os.path.join(opt.workspace, f'{sv_name}_{name}.mp4'), images, fps=VIDEO_FPS)
|
381 |
+
print("Stop video saved.")
|
382 |
+
|
383 |
+
|
384 |
+
assert opt.test_path is not None
|
385 |
+
|
386 |
+
if os.path.isdir(opt.test_path):
|
387 |
+
file_paths = glob.glob(os.path.join(opt.test_path, "*"))
|
388 |
+
else:
|
389 |
+
file_paths = [opt.test_path]
|
390 |
+
|
391 |
+
for path in sorted(file_paths):
|
392 |
+
process(opt, path)
|
main.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import tyro
|
17 |
+
import time
|
18 |
+
import random
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from core.options import AllConfigs
|
22 |
+
from core.models import LGM
|
23 |
+
|
24 |
+
from accelerate import Accelerator, DistributedDataParallelKwargs
|
25 |
+
from safetensors.torch import load_file
|
26 |
+
|
27 |
+
import kiui
|
28 |
+
from PIL import Image
|
29 |
+
|
30 |
+
import json
|
31 |
+
import os
|
32 |
+
import numpy as np
|
33 |
+
import imageio
|
34 |
+
|
35 |
+
def main():
|
36 |
+
opt = tyro.cli(AllConfigs)
|
37 |
+
|
38 |
+
accelerator = Accelerator(
|
39 |
+
mixed_precision=opt.mixed_precision,
|
40 |
+
gradient_accumulation_steps=opt.gradient_accumulation_steps,
|
41 |
+
# kwargs_handlers=[ddp_kwargs],
|
42 |
+
)
|
43 |
+
if accelerator.is_main_process:
|
44 |
+
print(opt)
|
45 |
+
|
46 |
+
# model
|
47 |
+
model = LGM(opt)
|
48 |
+
|
49 |
+
epoch_start = 0
|
50 |
+
if os.path.exists(f'{opt.workspace}/model.safetensors') and os.path.exists(f'{opt.workspace}/metadata.json'):
|
51 |
+
opt.resume = f'{opt.workspace}/model.safetensors'
|
52 |
+
with open(f'{opt.workspace}/metadata.json', 'r') as f:
|
53 |
+
dc = json.load(f)
|
54 |
+
epoch_start = dc['epoch'] + 1
|
55 |
+
|
56 |
+
|
57 |
+
# resume
|
58 |
+
if opt.resume is not None and opt.resume != 'None':
|
59 |
+
if opt.resume.endswith('safetensors'):
|
60 |
+
ckpt = load_file(opt.resume, device='cpu')
|
61 |
+
else:
|
62 |
+
ckpt = torch.load(opt.resume, map_location='cpu')
|
63 |
+
|
64 |
+
# tolerant load (only load matching shapes)
|
65 |
+
# model.load_state_dict(ckpt, strict=False)
|
66 |
+
state_dict = model.state_dict()
|
67 |
+
for k, v in ckpt.items():
|
68 |
+
if k in state_dict:
|
69 |
+
if state_dict[k].shape == v.shape:
|
70 |
+
state_dict[k].copy_(v)
|
71 |
+
else:
|
72 |
+
accelerator.print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
|
73 |
+
else:
|
74 |
+
accelerator.print(f'[WARN] unexpected param {k}: {v.shape}')
|
75 |
+
|
76 |
+
# data
|
77 |
+
if opt.data_mode == '4d':
|
78 |
+
from core.provider_objaverse_4d import ObjaverseDataset as Dataset
|
79 |
+
elif opt.data_mode == '4d_interp':
|
80 |
+
from core.provider_objaverse_4d_interp import ObjaverseDataset as Dataset
|
81 |
+
else:
|
82 |
+
raise NotImplementedError
|
83 |
+
|
84 |
+
train_dataset = Dataset(opt, training=True)
|
85 |
+
train_dataloader = torch.utils.data.DataLoader(
|
86 |
+
train_dataset,
|
87 |
+
batch_size=opt.batch_size,
|
88 |
+
shuffle=True,
|
89 |
+
num_workers=opt.num_workers,
|
90 |
+
pin_memory=True,
|
91 |
+
drop_last=True,
|
92 |
+
)
|
93 |
+
|
94 |
+
test_dataset = Dataset(opt, training=False)
|
95 |
+
test_dataloader = torch.utils.data.DataLoader(
|
96 |
+
test_dataset,
|
97 |
+
batch_size=opt.batch_size,
|
98 |
+
shuffle=False,
|
99 |
+
num_workers=0,
|
100 |
+
pin_memory=True,
|
101 |
+
drop_last=False,
|
102 |
+
)
|
103 |
+
|
104 |
+
# optimizer
|
105 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95))
|
106 |
+
|
107 |
+
# scheduler (per-iteration)
|
108 |
+
total_steps = opt.num_epochs * len(train_dataloader)
|
109 |
+
pct_start = 3000 / total_steps
|
110 |
+
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=total_steps, pct_start=pct_start)
|
111 |
+
|
112 |
+
if epoch_start > 0:
|
113 |
+
optimizer.load_state_dict(torch.load(os.path.join(opt.workspace, 'optimizer.pth'), map_location='cpu'))
|
114 |
+
scheduler.load_state_dict(torch.load(os.path.join(opt.workspace, 'scheduler.pth')))
|
115 |
+
|
116 |
+
# accelerate
|
117 |
+
model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare(
|
118 |
+
model, optimizer, train_dataloader, test_dataloader, scheduler
|
119 |
+
)
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
# loop
|
124 |
+
os.makedirs(opt.workspace, exist_ok=True)
|
125 |
+
end_time = time.time()
|
126 |
+
for epoch in range(epoch_start, opt.num_epochs):
|
127 |
+
# train
|
128 |
+
model.train()
|
129 |
+
total_loss = 0
|
130 |
+
total_psnr = 0
|
131 |
+
for i, data in enumerate(train_dataloader):
|
132 |
+
with accelerator.accumulate(model):
|
133 |
+
|
134 |
+
optimizer.zero_grad()
|
135 |
+
|
136 |
+
step_ratio = (epoch + i / len(train_dataloader)) / opt.num_epochs
|
137 |
+
|
138 |
+
out = model(data, step_ratio)
|
139 |
+
loss = out['loss']
|
140 |
+
psnr = out['psnr']
|
141 |
+
accelerator.backward(loss)
|
142 |
+
|
143 |
+
# gradient clipping
|
144 |
+
if accelerator.sync_gradients:
|
145 |
+
accelerator.clip_grad_norm_(model.parameters(), opt.gradient_clip)
|
146 |
+
|
147 |
+
optimizer.step()
|
148 |
+
scheduler.step()
|
149 |
+
|
150 |
+
total_loss += loss.detach()
|
151 |
+
total_psnr += psnr.detach()
|
152 |
+
|
153 |
+
if accelerator.is_main_process:
|
154 |
+
# logging
|
155 |
+
if i % 10 == 0:
|
156 |
+
mem_free, mem_total = torch.cuda.mem_get_info()
|
157 |
+
print(f"[INFO] {i}/{len(train_dataloader)} mem: {(mem_total-mem_free)/1024**3:.2f}/{mem_total/1024**3:.2f}G lr: {scheduler.get_last_lr()[0]:.7f} step_ratio: {step_ratio:.4f} loss: {loss.item():.6f} time: {time.time() - end_time:.6f}")
|
158 |
+
end_time = time.time()
|
159 |
+
|
160 |
+
# save log images
|
161 |
+
if i % 500 == 0:
|
162 |
+
if '4d' in opt.data_mode:
|
163 |
+
B, T, V = opt.batch_size, opt.num_frames, opt.num_views
|
164 |
+
|
165 |
+
gt_images = data['images_output'].reshape(B, T, V, *data['images_output'].shape[2:]).detach() # [B, V, 3, output_size, output_size]
|
166 |
+
pred_images = out['images_pred'].reshape(B, T, V, *out['images_pred'].shape[2:]).detach() # [B, V, 3, output_size, output_size]
|
167 |
+
|
168 |
+
train_gt_images = []
|
169 |
+
train_pred_images = []
|
170 |
+
for t in range(T):
|
171 |
+
train_gt_images_V = []
|
172 |
+
train_pred_images_V = []
|
173 |
+
for v in range(V):
|
174 |
+
train_gt_images_V.append((gt_images[:, t, v].permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
|
175 |
+
train_pred_images_V.append((pred_images[:, t, v].permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
|
176 |
+
train_gt_images.append(np.concatenate(train_gt_images_V, axis=2))
|
177 |
+
train_pred_images.append(np.concatenate(train_pred_images_V, axis=2))
|
178 |
+
train_gt_images = np.concatenate(train_gt_images, axis=0)
|
179 |
+
train_pred_images = np.concatenate(train_pred_images, axis=0)
|
180 |
+
imageio.mimwrite(f'{opt.workspace}/train_gt_images_{epoch}_{i}.mp4', train_gt_images, fps=8)
|
181 |
+
imageio.mimwrite(f'{opt.workspace}/train_pred_images_{epoch}_{i}.mp4', train_pred_images, fps=8)
|
182 |
+
|
183 |
+
|
184 |
+
elif '3d' in opt.data_mode:
|
185 |
+
gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
|
186 |
+
gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3]
|
187 |
+
kiui.write_image(f'{opt.workspace}/train_gt_images_{epoch}_{i}.jpg', gt_images)
|
188 |
+
|
189 |
+
pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
|
190 |
+
pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3)
|
191 |
+
kiui.write_image(f'{opt.workspace}/train_pred_images_{epoch}_{i}.jpg', pred_images)
|
192 |
+
else:
|
193 |
+
raise NotImplementedError
|
194 |
+
|
195 |
+
|
196 |
+
total_loss = accelerator.gather_for_metrics(total_loss).mean()
|
197 |
+
total_psnr = accelerator.gather_for_metrics(total_psnr).mean()
|
198 |
+
if accelerator.is_main_process:
|
199 |
+
total_loss /= len(train_dataloader)
|
200 |
+
total_psnr /= len(train_dataloader)
|
201 |
+
accelerator.print(f"[train] epoch: {epoch} loss: {total_loss.item():.6f} psnr: {total_psnr.item():.4f}")
|
202 |
+
|
203 |
+
# checkpoint
|
204 |
+
accelerator.wait_for_everyone()
|
205 |
+
accelerator.save_model(model, opt.workspace)
|
206 |
+
accelerator.save_model(model, os.path.join(opt.workspace, 'backup'))
|
207 |
+
if accelerator.is_main_process:
|
208 |
+
torch.save(optimizer.state_dict(), os.path.join(opt.workspace, 'optimizer.pth'))
|
209 |
+
torch.save(scheduler.state_dict(), os.path.join(opt.workspace, 'scheduler.pth'))
|
210 |
+
with open(f'{opt.workspace}/metadata.json', 'w') as f:
|
211 |
+
json.dump({'epoch': epoch}, f)
|
212 |
+
|
213 |
+
torch.save(optimizer.state_dict(), os.path.join(opt.workspace, 'backup', 'optimizer.pth'))
|
214 |
+
torch.save(scheduler.state_dict(), os.path.join(opt.workspace, 'backup', 'scheduler.pth'))
|
215 |
+
with open(f'{opt.workspace}/backup/metadata.json', 'w') as f:
|
216 |
+
json.dump({'epoch': epoch}, f)
|
217 |
+
|
218 |
+
|
219 |
+
# eval
|
220 |
+
with torch.no_grad():
|
221 |
+
model.eval()
|
222 |
+
total_psnr = 0
|
223 |
+
for i, data in enumerate(test_dataloader):
|
224 |
+
|
225 |
+
out = model(data)
|
226 |
+
|
227 |
+
psnr = out['psnr']
|
228 |
+
total_psnr += psnr.detach()
|
229 |
+
|
230 |
+
# save some images
|
231 |
+
if accelerator.is_main_process:
|
232 |
+
if '4d' in opt.data_mode:
|
233 |
+
B, T, V = opt.batch_size, opt.num_frames, opt.num_views
|
234 |
+
|
235 |
+
gt_images = data['images_output'].reshape(-1, T, V, *data['images_output'].shape[2:]).detach() # [B, V, 3, output_size, output_size]
|
236 |
+
pred_images = out['images_pred'].reshape(-1, T, V, *out['images_pred'].shape[2:]).detach() # [B, V, 3, output_size, output_size]
|
237 |
+
|
238 |
+
eval_gt_images = []
|
239 |
+
eval_pred_images = []
|
240 |
+
for t in range(T):
|
241 |
+
eval_gt_images_V = []
|
242 |
+
eval_pred_images_V = []
|
243 |
+
for v in range(V):
|
244 |
+
eval_gt_images_V.append((gt_images[:, t, v].permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
|
245 |
+
eval_pred_images_V.append((pred_images[:, t, v].permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
|
246 |
+
eval_gt_images.append(np.concatenate(eval_gt_images_V, axis=2))
|
247 |
+
eval_pred_images.append(np.concatenate(eval_pred_images_V, axis=2))
|
248 |
+
eval_gt_images = np.concatenate(eval_gt_images, axis=0)
|
249 |
+
eval_pred_images = np.concatenate(eval_pred_images, axis=0)
|
250 |
+
imageio.mimwrite(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.mp4', eval_gt_images, fps=8)
|
251 |
+
imageio.mimwrite(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.mp4', eval_pred_images, fps=8)
|
252 |
+
|
253 |
+
elif '3d' in opt.data_mode:
|
254 |
+
gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
|
255 |
+
gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3]
|
256 |
+
kiui.write_image(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.jpg', gt_images)
|
257 |
+
|
258 |
+
pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
|
259 |
+
pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3)
|
260 |
+
kiui.write_image(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.jpg', pred_images)
|
261 |
+
else:
|
262 |
+
raise NotImplementedError
|
263 |
+
|
264 |
+
torch.cuda.empty_cache()
|
265 |
+
|
266 |
+
total_psnr = accelerator.gather_for_metrics(total_psnr).mean()
|
267 |
+
if accelerator.is_main_process:
|
268 |
+
total_psnr /= len(test_dataloader)
|
269 |
+
accelerator.print(f"[eval] epoch: {epoch} psnr: {psnr:.4f}")
|
270 |
+
|
271 |
+
|
272 |
+
|
273 |
+
if __name__ == "__main__":
|
274 |
+
main()
|
mvdream/mv_unet.py
ADDED
@@ -0,0 +1,1005 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
from inspect import isfunction
|
4 |
+
from typing import Optional, Any, List
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
from diffusers.configuration_utils import ConfigMixin
|
12 |
+
from diffusers.models.modeling_utils import ModelMixin
|
13 |
+
|
14 |
+
# require xformers!
|
15 |
+
import xformers
|
16 |
+
import xformers.ops
|
17 |
+
|
18 |
+
from kiui.cam import orbit_camera
|
19 |
+
|
20 |
+
def get_camera(
|
21 |
+
num_frames, elevation=0, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
|
22 |
+
):
|
23 |
+
angle_gap = azimuth_span / num_frames
|
24 |
+
cameras = []
|
25 |
+
for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
|
26 |
+
|
27 |
+
pose = orbit_camera(elevation, azimuth, radius=1) # [4, 4]
|
28 |
+
|
29 |
+
# opengl to blender
|
30 |
+
if blender_coord:
|
31 |
+
pose[2] *= -1
|
32 |
+
pose[[1, 2]] = pose[[2, 1]]
|
33 |
+
|
34 |
+
cameras.append(pose.flatten())
|
35 |
+
|
36 |
+
if extra_view:
|
37 |
+
cameras.append(np.zeros_like(cameras[0]))
|
38 |
+
|
39 |
+
return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
|
40 |
+
|
41 |
+
|
42 |
+
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
43 |
+
"""
|
44 |
+
Create sinusoidal timestep embeddings.
|
45 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
46 |
+
These may be fractional.
|
47 |
+
:param dim: the dimension of the output.
|
48 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
49 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
50 |
+
"""
|
51 |
+
if not repeat_only:
|
52 |
+
half = dim // 2
|
53 |
+
freqs = torch.exp(
|
54 |
+
-math.log(max_period)
|
55 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
56 |
+
/ half
|
57 |
+
).to(device=timesteps.device)
|
58 |
+
args = timesteps[:, None] * freqs[None]
|
59 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
60 |
+
if dim % 2:
|
61 |
+
embedding = torch.cat(
|
62 |
+
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
63 |
+
)
|
64 |
+
else:
|
65 |
+
embedding = repeat(timesteps, "b -> b d", d=dim)
|
66 |
+
# import pdb; pdb.set_trace()
|
67 |
+
return embedding
|
68 |
+
|
69 |
+
|
70 |
+
def zero_module(module):
|
71 |
+
"""
|
72 |
+
Zero out the parameters of a module and return it.
|
73 |
+
"""
|
74 |
+
for p in module.parameters():
|
75 |
+
p.detach().zero_()
|
76 |
+
return module
|
77 |
+
|
78 |
+
|
79 |
+
def conv_nd(dims, *args, **kwargs):
|
80 |
+
"""
|
81 |
+
Create a 1D, 2D, or 3D convolution module.
|
82 |
+
"""
|
83 |
+
if dims == 1:
|
84 |
+
return nn.Conv1d(*args, **kwargs)
|
85 |
+
elif dims == 2:
|
86 |
+
return nn.Conv2d(*args, **kwargs)
|
87 |
+
elif dims == 3:
|
88 |
+
return nn.Conv3d(*args, **kwargs)
|
89 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
90 |
+
|
91 |
+
|
92 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
93 |
+
"""
|
94 |
+
Create a 1D, 2D, or 3D average pooling module.
|
95 |
+
"""
|
96 |
+
if dims == 1:
|
97 |
+
return nn.AvgPool1d(*args, **kwargs)
|
98 |
+
elif dims == 2:
|
99 |
+
return nn.AvgPool2d(*args, **kwargs)
|
100 |
+
elif dims == 3:
|
101 |
+
return nn.AvgPool3d(*args, **kwargs)
|
102 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
103 |
+
|
104 |
+
|
105 |
+
def default(val, d):
|
106 |
+
if val is not None:
|
107 |
+
return val
|
108 |
+
return d() if isfunction(d) else d
|
109 |
+
|
110 |
+
|
111 |
+
class GEGLU(nn.Module):
|
112 |
+
def __init__(self, dim_in, dim_out):
|
113 |
+
super().__init__()
|
114 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
115 |
+
|
116 |
+
def forward(self, x):
|
117 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
118 |
+
return x * F.gelu(gate)
|
119 |
+
|
120 |
+
|
121 |
+
class FeedForward(nn.Module):
|
122 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
123 |
+
super().__init__()
|
124 |
+
inner_dim = int(dim * mult)
|
125 |
+
dim_out = default(dim_out, dim)
|
126 |
+
project_in = (
|
127 |
+
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
128 |
+
if not glu
|
129 |
+
else GEGLU(dim, inner_dim)
|
130 |
+
)
|
131 |
+
|
132 |
+
self.net = nn.Sequential(
|
133 |
+
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
134 |
+
)
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
return self.net(x)
|
138 |
+
|
139 |
+
|
140 |
+
class MemoryEfficientCrossAttention(nn.Module):
|
141 |
+
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
142 |
+
def __init__(
|
143 |
+
self,
|
144 |
+
query_dim,
|
145 |
+
context_dim=None,
|
146 |
+
heads=8,
|
147 |
+
dim_head=64,
|
148 |
+
dropout=0.0,
|
149 |
+
ip_dim=0,
|
150 |
+
ip_weight=1,
|
151 |
+
):
|
152 |
+
super().__init__()
|
153 |
+
|
154 |
+
inner_dim = dim_head * heads
|
155 |
+
context_dim = default(context_dim, query_dim)
|
156 |
+
|
157 |
+
self.heads = heads
|
158 |
+
self.dim_head = dim_head
|
159 |
+
|
160 |
+
self.ip_dim = ip_dim
|
161 |
+
self.ip_weight = ip_weight
|
162 |
+
|
163 |
+
if self.ip_dim > 0:
|
164 |
+
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
165 |
+
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
166 |
+
|
167 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
168 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
169 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
170 |
+
|
171 |
+
self.to_out = nn.Sequential(
|
172 |
+
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
173 |
+
)
|
174 |
+
self.attention_op: Optional[Any] = None
|
175 |
+
|
176 |
+
def forward(self, x, context=None):
|
177 |
+
q = self.to_q(x)
|
178 |
+
context = default(context, x)
|
179 |
+
|
180 |
+
if self.ip_dim > 0:
|
181 |
+
# context: [B, 77 + 16(ip), 1024]
|
182 |
+
token_len = context.shape[1]
|
183 |
+
context_ip = context[:, -self.ip_dim :, :]
|
184 |
+
k_ip = self.to_k_ip(context_ip)
|
185 |
+
v_ip = self.to_v_ip(context_ip)
|
186 |
+
context = context[:, : (token_len - self.ip_dim), :]
|
187 |
+
|
188 |
+
k = self.to_k(context)
|
189 |
+
v = self.to_v(context)
|
190 |
+
|
191 |
+
b, _, _ = q.shape
|
192 |
+
q, k, v = map(
|
193 |
+
lambda t: t.unsqueeze(3)
|
194 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
195 |
+
.permute(0, 2, 1, 3)
|
196 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
197 |
+
.contiguous(),
|
198 |
+
(q, k, v),
|
199 |
+
)
|
200 |
+
|
201 |
+
# actually compute the attention, what we cannot get enough of
|
202 |
+
out = xformers.ops.memory_efficient_attention(
|
203 |
+
q, k, v, attn_bias=None, op=self.attention_op
|
204 |
+
)
|
205 |
+
|
206 |
+
if self.ip_dim > 0:
|
207 |
+
k_ip, v_ip = map(
|
208 |
+
lambda t: t.unsqueeze(3)
|
209 |
+
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
210 |
+
.permute(0, 2, 1, 3)
|
211 |
+
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
212 |
+
.contiguous(),
|
213 |
+
(k_ip, v_ip),
|
214 |
+
)
|
215 |
+
# actually compute the attention, what we cannot get enough of
|
216 |
+
out_ip = xformers.ops.memory_efficient_attention(
|
217 |
+
q, k_ip, v_ip, attn_bias=None, op=self.attention_op
|
218 |
+
)
|
219 |
+
out = out + self.ip_weight * out_ip
|
220 |
+
|
221 |
+
out = (
|
222 |
+
out.unsqueeze(0)
|
223 |
+
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
224 |
+
.permute(0, 2, 1, 3)
|
225 |
+
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
226 |
+
)
|
227 |
+
return self.to_out(out)
|
228 |
+
|
229 |
+
|
230 |
+
class BasicTransformerBlock3D(nn.Module):
|
231 |
+
|
232 |
+
def __init__(
|
233 |
+
self,
|
234 |
+
dim,
|
235 |
+
n_heads,
|
236 |
+
d_head,
|
237 |
+
context_dim,
|
238 |
+
dropout=0.0,
|
239 |
+
gated_ff=True,
|
240 |
+
ip_dim=0,
|
241 |
+
ip_weight=1,
|
242 |
+
):
|
243 |
+
super().__init__()
|
244 |
+
|
245 |
+
self.attn1 = MemoryEfficientCrossAttention(
|
246 |
+
query_dim=dim,
|
247 |
+
context_dim=None, # self-attention
|
248 |
+
heads=n_heads,
|
249 |
+
dim_head=d_head,
|
250 |
+
dropout=dropout,
|
251 |
+
)
|
252 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
253 |
+
self.attn2 = MemoryEfficientCrossAttention(
|
254 |
+
query_dim=dim,
|
255 |
+
context_dim=context_dim,
|
256 |
+
heads=n_heads,
|
257 |
+
dim_head=d_head,
|
258 |
+
dropout=dropout,
|
259 |
+
# ip only applies to cross-attention
|
260 |
+
ip_dim=ip_dim,
|
261 |
+
ip_weight=ip_weight,
|
262 |
+
)
|
263 |
+
self.norm1 = nn.LayerNorm(dim)
|
264 |
+
self.norm2 = nn.LayerNorm(dim)
|
265 |
+
self.norm3 = nn.LayerNorm(dim)
|
266 |
+
|
267 |
+
def forward(self, x, context=None, num_frames=1):
|
268 |
+
x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
|
269 |
+
x = self.attn1(self.norm1(x), context=None) + x
|
270 |
+
x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
|
271 |
+
x = self.attn2(self.norm2(x), context=context) + x
|
272 |
+
x = self.ff(self.norm3(x)) + x
|
273 |
+
return x
|
274 |
+
|
275 |
+
|
276 |
+
class SpatialTransformer3D(nn.Module):
|
277 |
+
|
278 |
+
def __init__(
|
279 |
+
self,
|
280 |
+
in_channels,
|
281 |
+
n_heads,
|
282 |
+
d_head,
|
283 |
+
context_dim, # cross attention input dim
|
284 |
+
depth=1,
|
285 |
+
dropout=0.0,
|
286 |
+
ip_dim=0,
|
287 |
+
ip_weight=1,
|
288 |
+
):
|
289 |
+
super().__init__()
|
290 |
+
|
291 |
+
if not isinstance(context_dim, list):
|
292 |
+
context_dim = [context_dim]
|
293 |
+
|
294 |
+
self.in_channels = in_channels
|
295 |
+
|
296 |
+
inner_dim = n_heads * d_head
|
297 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
298 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
299 |
+
|
300 |
+
self.transformer_blocks = nn.ModuleList(
|
301 |
+
[
|
302 |
+
BasicTransformerBlock3D(
|
303 |
+
inner_dim,
|
304 |
+
n_heads,
|
305 |
+
d_head,
|
306 |
+
context_dim=context_dim[d],
|
307 |
+
dropout=dropout,
|
308 |
+
ip_dim=ip_dim,
|
309 |
+
ip_weight=ip_weight,
|
310 |
+
)
|
311 |
+
for d in range(depth)
|
312 |
+
]
|
313 |
+
)
|
314 |
+
|
315 |
+
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
316 |
+
|
317 |
+
|
318 |
+
def forward(self, x, context=None, num_frames=1):
|
319 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
320 |
+
if not isinstance(context, list):
|
321 |
+
context = [context]
|
322 |
+
b, c, h, w = x.shape
|
323 |
+
x_in = x
|
324 |
+
x = self.norm(x)
|
325 |
+
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
326 |
+
x = self.proj_in(x)
|
327 |
+
for i, block in enumerate(self.transformer_blocks):
|
328 |
+
x = block(x, context=context[i], num_frames=num_frames)
|
329 |
+
x = self.proj_out(x)
|
330 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
331 |
+
|
332 |
+
return x + x_in
|
333 |
+
|
334 |
+
|
335 |
+
class PerceiverAttention(nn.Module):
|
336 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
337 |
+
super().__init__()
|
338 |
+
self.scale = dim_head ** -0.5
|
339 |
+
self.dim_head = dim_head
|
340 |
+
self.heads = heads
|
341 |
+
inner_dim = dim_head * heads
|
342 |
+
|
343 |
+
self.norm1 = nn.LayerNorm(dim)
|
344 |
+
self.norm2 = nn.LayerNorm(dim)
|
345 |
+
|
346 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
347 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
348 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
349 |
+
|
350 |
+
def forward(self, x, latents):
|
351 |
+
"""
|
352 |
+
Args:
|
353 |
+
x (torch.Tensor): image features
|
354 |
+
shape (b, n1, D)
|
355 |
+
latent (torch.Tensor): latent features
|
356 |
+
shape (b, n2, D)
|
357 |
+
"""
|
358 |
+
x = self.norm1(x)
|
359 |
+
latents = self.norm2(latents)
|
360 |
+
|
361 |
+
b, l, _ = latents.shape
|
362 |
+
|
363 |
+
q = self.to_q(latents)
|
364 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
365 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
366 |
+
|
367 |
+
q, k, v = map(
|
368 |
+
lambda t: t.reshape(b, t.shape[1], self.heads, -1)
|
369 |
+
.transpose(1, 2)
|
370 |
+
.reshape(b, self.heads, t.shape[1], -1)
|
371 |
+
.contiguous(),
|
372 |
+
(q, k, v),
|
373 |
+
)
|
374 |
+
|
375 |
+
# attention
|
376 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
377 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
378 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
379 |
+
out = weight @ v
|
380 |
+
|
381 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
382 |
+
|
383 |
+
return self.to_out(out)
|
384 |
+
|
385 |
+
|
386 |
+
class Resampler(nn.Module):
|
387 |
+
def __init__(
|
388 |
+
self,
|
389 |
+
dim=1024,
|
390 |
+
depth=8,
|
391 |
+
dim_head=64,
|
392 |
+
heads=16,
|
393 |
+
num_queries=8,
|
394 |
+
embedding_dim=768,
|
395 |
+
output_dim=1024,
|
396 |
+
ff_mult=4,
|
397 |
+
):
|
398 |
+
super().__init__()
|
399 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
|
400 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
401 |
+
self.proj_out = nn.Linear(dim, output_dim)
|
402 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
403 |
+
|
404 |
+
self.layers = nn.ModuleList([])
|
405 |
+
for _ in range(depth):
|
406 |
+
self.layers.append(
|
407 |
+
nn.ModuleList(
|
408 |
+
[
|
409 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
410 |
+
nn.Sequential(
|
411 |
+
nn.LayerNorm(dim),
|
412 |
+
nn.Linear(dim, dim * ff_mult, bias=False),
|
413 |
+
nn.GELU(),
|
414 |
+
nn.Linear(dim * ff_mult, dim, bias=False),
|
415 |
+
)
|
416 |
+
]
|
417 |
+
)
|
418 |
+
)
|
419 |
+
|
420 |
+
def forward(self, x):
|
421 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
422 |
+
x = self.proj_in(x)
|
423 |
+
for attn, ff in self.layers:
|
424 |
+
latents = attn(x, latents) + latents
|
425 |
+
latents = ff(latents) + latents
|
426 |
+
|
427 |
+
latents = self.proj_out(latents)
|
428 |
+
return self.norm_out(latents)
|
429 |
+
|
430 |
+
|
431 |
+
class CondSequential(nn.Sequential):
|
432 |
+
"""
|
433 |
+
A sequential module that passes timestep embeddings to the children that
|
434 |
+
support it as an extra input.
|
435 |
+
"""
|
436 |
+
|
437 |
+
def forward(self, x, emb, context=None, num_frames=1):
|
438 |
+
for layer in self:
|
439 |
+
if isinstance(layer, ResBlock):
|
440 |
+
x = layer(x, emb)
|
441 |
+
elif isinstance(layer, SpatialTransformer3D):
|
442 |
+
x = layer(x, context, num_frames=num_frames)
|
443 |
+
else:
|
444 |
+
x = layer(x)
|
445 |
+
return x
|
446 |
+
|
447 |
+
|
448 |
+
class Upsample(nn.Module):
|
449 |
+
"""
|
450 |
+
An upsampling layer with an optional convolution.
|
451 |
+
:param channels: channels in the inputs and outputs.
|
452 |
+
:param use_conv: a bool determining if a convolution is applied.
|
453 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
454 |
+
upsampling occurs in the inner-two dimensions.
|
455 |
+
"""
|
456 |
+
|
457 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
458 |
+
super().__init__()
|
459 |
+
self.channels = channels
|
460 |
+
self.out_channels = out_channels or channels
|
461 |
+
self.use_conv = use_conv
|
462 |
+
self.dims = dims
|
463 |
+
if use_conv:
|
464 |
+
self.conv = conv_nd(
|
465 |
+
dims, self.channels, self.out_channels, 3, padding=padding
|
466 |
+
)
|
467 |
+
|
468 |
+
def forward(self, x):
|
469 |
+
assert x.shape[1] == self.channels
|
470 |
+
if self.dims == 3:
|
471 |
+
x = F.interpolate(
|
472 |
+
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
473 |
+
)
|
474 |
+
else:
|
475 |
+
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
476 |
+
if self.use_conv:
|
477 |
+
x = self.conv(x)
|
478 |
+
return x
|
479 |
+
|
480 |
+
|
481 |
+
class Downsample(nn.Module):
|
482 |
+
"""
|
483 |
+
A downsampling layer with an optional convolution.
|
484 |
+
:param channels: channels in the inputs and outputs.
|
485 |
+
:param use_conv: a bool determining if a convolution is applied.
|
486 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
487 |
+
downsampling occurs in the inner-two dimensions.
|
488 |
+
"""
|
489 |
+
|
490 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
491 |
+
super().__init__()
|
492 |
+
self.channels = channels
|
493 |
+
self.out_channels = out_channels or channels
|
494 |
+
self.use_conv = use_conv
|
495 |
+
self.dims = dims
|
496 |
+
stride = 2 if dims != 3 else (1, 2, 2)
|
497 |
+
if use_conv:
|
498 |
+
self.op = conv_nd(
|
499 |
+
dims,
|
500 |
+
self.channels,
|
501 |
+
self.out_channels,
|
502 |
+
3,
|
503 |
+
stride=stride,
|
504 |
+
padding=padding,
|
505 |
+
)
|
506 |
+
else:
|
507 |
+
assert self.channels == self.out_channels
|
508 |
+
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
509 |
+
|
510 |
+
def forward(self, x):
|
511 |
+
assert x.shape[1] == self.channels
|
512 |
+
return self.op(x)
|
513 |
+
|
514 |
+
|
515 |
+
class ResBlock(nn.Module):
|
516 |
+
"""
|
517 |
+
A residual block that can optionally change the number of channels.
|
518 |
+
:param channels: the number of input channels.
|
519 |
+
:param emb_channels: the number of timestep embedding channels.
|
520 |
+
:param dropout: the rate of dropout.
|
521 |
+
:param out_channels: if specified, the number of out channels.
|
522 |
+
:param use_conv: if True and out_channels is specified, use a spatial
|
523 |
+
convolution instead of a smaller 1x1 convolution to change the
|
524 |
+
channels in the skip connection.
|
525 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
526 |
+
:param up: if True, use this block for upsampling.
|
527 |
+
:param down: if True, use this block for downsampling.
|
528 |
+
"""
|
529 |
+
|
530 |
+
def __init__(
|
531 |
+
self,
|
532 |
+
channels,
|
533 |
+
emb_channels,
|
534 |
+
dropout,
|
535 |
+
out_channels=None,
|
536 |
+
use_conv=False,
|
537 |
+
use_scale_shift_norm=False,
|
538 |
+
dims=2,
|
539 |
+
up=False,
|
540 |
+
down=False,
|
541 |
+
):
|
542 |
+
super().__init__()
|
543 |
+
self.channels = channels
|
544 |
+
self.emb_channels = emb_channels
|
545 |
+
self.dropout = dropout
|
546 |
+
self.out_channels = out_channels or channels
|
547 |
+
self.use_conv = use_conv
|
548 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
549 |
+
|
550 |
+
self.in_layers = nn.Sequential(
|
551 |
+
nn.GroupNorm(32, channels),
|
552 |
+
nn.SiLU(),
|
553 |
+
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
554 |
+
)
|
555 |
+
|
556 |
+
self.updown = up or down
|
557 |
+
|
558 |
+
if up:
|
559 |
+
self.h_upd = Upsample(channels, False, dims)
|
560 |
+
self.x_upd = Upsample(channels, False, dims)
|
561 |
+
elif down:
|
562 |
+
self.h_upd = Downsample(channels, False, dims)
|
563 |
+
self.x_upd = Downsample(channels, False, dims)
|
564 |
+
else:
|
565 |
+
self.h_upd = self.x_upd = nn.Identity()
|
566 |
+
|
567 |
+
self.emb_layers = nn.Sequential(
|
568 |
+
nn.SiLU(),
|
569 |
+
nn.Linear(
|
570 |
+
emb_channels,
|
571 |
+
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
572 |
+
),
|
573 |
+
)
|
574 |
+
self.out_layers = nn.Sequential(
|
575 |
+
nn.GroupNorm(32, self.out_channels),
|
576 |
+
nn.SiLU(),
|
577 |
+
nn.Dropout(p=dropout),
|
578 |
+
zero_module(
|
579 |
+
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
580 |
+
),
|
581 |
+
)
|
582 |
+
|
583 |
+
if self.out_channels == channels:
|
584 |
+
self.skip_connection = nn.Identity()
|
585 |
+
elif use_conv:
|
586 |
+
self.skip_connection = conv_nd(
|
587 |
+
dims, channels, self.out_channels, 3, padding=1
|
588 |
+
)
|
589 |
+
else:
|
590 |
+
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
591 |
+
|
592 |
+
def forward(self, x, emb):
|
593 |
+
if self.updown:
|
594 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
595 |
+
h = in_rest(x)
|
596 |
+
h = self.h_upd(h)
|
597 |
+
x = self.x_upd(x)
|
598 |
+
h = in_conv(h)
|
599 |
+
else:
|
600 |
+
h = self.in_layers(x)
|
601 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
602 |
+
while len(emb_out.shape) < len(h.shape):
|
603 |
+
emb_out = emb_out[..., None]
|
604 |
+
if self.use_scale_shift_norm:
|
605 |
+
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
606 |
+
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
607 |
+
h = out_norm(h) * (1 + scale) + shift
|
608 |
+
h = out_rest(h)
|
609 |
+
else:
|
610 |
+
h = h + emb_out
|
611 |
+
h = self.out_layers(h)
|
612 |
+
return self.skip_connection(x) + h
|
613 |
+
|
614 |
+
|
615 |
+
class MultiViewUNetModel(ModelMixin, ConfigMixin):
|
616 |
+
"""
|
617 |
+
The full multi-view UNet model with attention, timestep embedding and camera embedding.
|
618 |
+
:param in_channels: channels in the input Tensor.
|
619 |
+
:param model_channels: base channel count for the model.
|
620 |
+
:param out_channels: channels in the output Tensor.
|
621 |
+
:param num_res_blocks: number of residual blocks per downsample.
|
622 |
+
:param attention_resolutions: a collection of downsample rates at which
|
623 |
+
attention will take place. May be a set, list, or tuple.
|
624 |
+
For example, if this contains 4, then at 4x downsampling, attention
|
625 |
+
will be used.
|
626 |
+
:param dropout: the dropout probability.
|
627 |
+
:param channel_mult: channel multiplier for each level of the UNet.
|
628 |
+
:param conv_resample: if True, use learned convolutions for upsampling and
|
629 |
+
downsampling.
|
630 |
+
:param dims: determines if the signal is 1D, 2D, or 3D.
|
631 |
+
:param num_classes: if specified (as an int), then this model will be
|
632 |
+
class-conditional with `num_classes` classes.
|
633 |
+
:param num_heads: the number of attention heads in each attention layer.
|
634 |
+
:param num_heads_channels: if specified, ignore num_heads and instead use
|
635 |
+
a fixed channel width per attention head.
|
636 |
+
:param num_heads_upsample: works with num_heads to set a different number
|
637 |
+
of heads for upsampling. Deprecated.
|
638 |
+
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
639 |
+
:param resblock_updown: use residual blocks for up/downsampling.
|
640 |
+
:param use_new_attention_order: use a different attention pattern for potentially
|
641 |
+
increased efficiency.
|
642 |
+
:param camera_dim: dimensionality of camera input.
|
643 |
+
"""
|
644 |
+
|
645 |
+
def __init__(
|
646 |
+
self,
|
647 |
+
image_size,
|
648 |
+
in_channels,
|
649 |
+
model_channels,
|
650 |
+
out_channels,
|
651 |
+
num_res_blocks,
|
652 |
+
attention_resolutions,
|
653 |
+
dropout=0,
|
654 |
+
channel_mult=(1, 2, 4, 8),
|
655 |
+
conv_resample=True,
|
656 |
+
dims=2,
|
657 |
+
num_classes=None,
|
658 |
+
num_heads=-1,
|
659 |
+
num_head_channels=-1,
|
660 |
+
num_heads_upsample=-1,
|
661 |
+
use_scale_shift_norm=False,
|
662 |
+
resblock_updown=False,
|
663 |
+
transformer_depth=1,
|
664 |
+
context_dim=None,
|
665 |
+
n_embed=None,
|
666 |
+
num_attention_blocks=None,
|
667 |
+
adm_in_channels=None,
|
668 |
+
camera_dim=None,
|
669 |
+
ip_dim=0, # imagedream uses ip_dim > 0
|
670 |
+
ip_weight=1.0,
|
671 |
+
**kwargs,
|
672 |
+
):
|
673 |
+
super().__init__()
|
674 |
+
assert context_dim is not None
|
675 |
+
|
676 |
+
if num_heads_upsample == -1:
|
677 |
+
num_heads_upsample = num_heads
|
678 |
+
|
679 |
+
if num_heads == -1:
|
680 |
+
assert (
|
681 |
+
num_head_channels != -1
|
682 |
+
), "Either num_heads or num_head_channels has to be set"
|
683 |
+
|
684 |
+
if num_head_channels == -1:
|
685 |
+
assert (
|
686 |
+
num_heads != -1
|
687 |
+
), "Either num_heads or num_head_channels has to be set"
|
688 |
+
|
689 |
+
self.image_size = image_size
|
690 |
+
self.in_channels = in_channels
|
691 |
+
self.model_channels = model_channels
|
692 |
+
self.out_channels = out_channels
|
693 |
+
if isinstance(num_res_blocks, int):
|
694 |
+
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
695 |
+
else:
|
696 |
+
if len(num_res_blocks) != len(channel_mult):
|
697 |
+
raise ValueError(
|
698 |
+
"provide num_res_blocks either as an int (globally constant) or "
|
699 |
+
"as a list/tuple (per-level) with the same length as channel_mult"
|
700 |
+
)
|
701 |
+
self.num_res_blocks = num_res_blocks
|
702 |
+
|
703 |
+
if num_attention_blocks is not None:
|
704 |
+
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
705 |
+
assert all(
|
706 |
+
map(
|
707 |
+
lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
|
708 |
+
range(len(num_attention_blocks)),
|
709 |
+
)
|
710 |
+
)
|
711 |
+
print(
|
712 |
+
f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
713 |
+
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
714 |
+
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
715 |
+
f"attention will still not be set."
|
716 |
+
)
|
717 |
+
|
718 |
+
self.attention_resolutions = attention_resolutions
|
719 |
+
self.dropout = dropout
|
720 |
+
self.channel_mult = channel_mult
|
721 |
+
self.conv_resample = conv_resample
|
722 |
+
self.num_classes = num_classes
|
723 |
+
self.num_heads = num_heads
|
724 |
+
self.num_head_channels = num_head_channels
|
725 |
+
self.num_heads_upsample = num_heads_upsample
|
726 |
+
self.predict_codebook_ids = n_embed is not None
|
727 |
+
|
728 |
+
self.ip_dim = ip_dim
|
729 |
+
self.ip_weight = ip_weight
|
730 |
+
|
731 |
+
if self.ip_dim > 0:
|
732 |
+
self.image_embed = Resampler(
|
733 |
+
dim=context_dim,
|
734 |
+
depth=4,
|
735 |
+
dim_head=64,
|
736 |
+
heads=12,
|
737 |
+
num_queries=ip_dim, # num token
|
738 |
+
embedding_dim=1280,
|
739 |
+
output_dim=context_dim,
|
740 |
+
ff_mult=4,
|
741 |
+
)
|
742 |
+
|
743 |
+
time_embed_dim = model_channels * 4
|
744 |
+
self.time_embed = nn.Sequential(
|
745 |
+
nn.Linear(model_channels, time_embed_dim),
|
746 |
+
nn.SiLU(),
|
747 |
+
nn.Linear(time_embed_dim, time_embed_dim),
|
748 |
+
)
|
749 |
+
|
750 |
+
if camera_dim is not None:
|
751 |
+
time_embed_dim = model_channels * 4
|
752 |
+
self.camera_embed = nn.Sequential(
|
753 |
+
nn.Linear(camera_dim, time_embed_dim),
|
754 |
+
nn.SiLU(),
|
755 |
+
nn.Linear(time_embed_dim, time_embed_dim),
|
756 |
+
)
|
757 |
+
|
758 |
+
if self.num_classes is not None:
|
759 |
+
if isinstance(self.num_classes, int):
|
760 |
+
self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
|
761 |
+
elif self.num_classes == "continuous":
|
762 |
+
# print("setting up linear c_adm embedding layer")
|
763 |
+
self.label_emb = nn.Linear(1, time_embed_dim)
|
764 |
+
elif self.num_classes == "sequential":
|
765 |
+
assert adm_in_channels is not None
|
766 |
+
self.label_emb = nn.Sequential(
|
767 |
+
nn.Sequential(
|
768 |
+
nn.Linear(adm_in_channels, time_embed_dim),
|
769 |
+
nn.SiLU(),
|
770 |
+
nn.Linear(time_embed_dim, time_embed_dim),
|
771 |
+
)
|
772 |
+
)
|
773 |
+
else:
|
774 |
+
raise ValueError()
|
775 |
+
|
776 |
+
self.input_blocks = nn.ModuleList(
|
777 |
+
[
|
778 |
+
CondSequential(
|
779 |
+
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
780 |
+
)
|
781 |
+
]
|
782 |
+
)
|
783 |
+
self._feature_size = model_channels
|
784 |
+
input_block_chans = [model_channels]
|
785 |
+
ch = model_channels
|
786 |
+
ds = 1
|
787 |
+
for level, mult in enumerate(channel_mult):
|
788 |
+
for nr in range(self.num_res_blocks[level]):
|
789 |
+
layers: List[Any] = [
|
790 |
+
ResBlock(
|
791 |
+
ch,
|
792 |
+
time_embed_dim,
|
793 |
+
dropout,
|
794 |
+
out_channels=mult * model_channels,
|
795 |
+
dims=dims,
|
796 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
797 |
+
)
|
798 |
+
]
|
799 |
+
ch = mult * model_channels
|
800 |
+
if ds in attention_resolutions:
|
801 |
+
if num_head_channels == -1:
|
802 |
+
dim_head = ch // num_heads
|
803 |
+
else:
|
804 |
+
num_heads = ch // num_head_channels
|
805 |
+
dim_head = num_head_channels
|
806 |
+
|
807 |
+
if num_attention_blocks is None or nr < num_attention_blocks[level]:
|
808 |
+
layers.append(
|
809 |
+
SpatialTransformer3D(
|
810 |
+
ch,
|
811 |
+
num_heads,
|
812 |
+
dim_head,
|
813 |
+
context_dim=context_dim,
|
814 |
+
depth=transformer_depth,
|
815 |
+
ip_dim=self.ip_dim,
|
816 |
+
ip_weight=self.ip_weight,
|
817 |
+
)
|
818 |
+
)
|
819 |
+
self.input_blocks.append(CondSequential(*layers))
|
820 |
+
self._feature_size += ch
|
821 |
+
input_block_chans.append(ch)
|
822 |
+
if level != len(channel_mult) - 1:
|
823 |
+
out_ch = ch
|
824 |
+
self.input_blocks.append(
|
825 |
+
CondSequential(
|
826 |
+
ResBlock(
|
827 |
+
ch,
|
828 |
+
time_embed_dim,
|
829 |
+
dropout,
|
830 |
+
out_channels=out_ch,
|
831 |
+
dims=dims,
|
832 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
833 |
+
down=True,
|
834 |
+
)
|
835 |
+
if resblock_updown
|
836 |
+
else Downsample(
|
837 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
838 |
+
)
|
839 |
+
)
|
840 |
+
)
|
841 |
+
ch = out_ch
|
842 |
+
input_block_chans.append(ch)
|
843 |
+
ds *= 2
|
844 |
+
self._feature_size += ch
|
845 |
+
|
846 |
+
if num_head_channels == -1:
|
847 |
+
dim_head = ch // num_heads
|
848 |
+
else:
|
849 |
+
num_heads = ch // num_head_channels
|
850 |
+
dim_head = num_head_channels
|
851 |
+
|
852 |
+
self.middle_block = CondSequential(
|
853 |
+
ResBlock(
|
854 |
+
ch,
|
855 |
+
time_embed_dim,
|
856 |
+
dropout,
|
857 |
+
dims=dims,
|
858 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
859 |
+
),
|
860 |
+
SpatialTransformer3D(
|
861 |
+
ch,
|
862 |
+
num_heads,
|
863 |
+
dim_head,
|
864 |
+
context_dim=context_dim,
|
865 |
+
depth=transformer_depth,
|
866 |
+
ip_dim=self.ip_dim,
|
867 |
+
ip_weight=self.ip_weight,
|
868 |
+
),
|
869 |
+
ResBlock(
|
870 |
+
ch,
|
871 |
+
time_embed_dim,
|
872 |
+
dropout,
|
873 |
+
dims=dims,
|
874 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
875 |
+
),
|
876 |
+
)
|
877 |
+
self._feature_size += ch
|
878 |
+
|
879 |
+
self.output_blocks = nn.ModuleList([])
|
880 |
+
for level, mult in list(enumerate(channel_mult))[::-1]:
|
881 |
+
for i in range(self.num_res_blocks[level] + 1):
|
882 |
+
ich = input_block_chans.pop()
|
883 |
+
layers = [
|
884 |
+
ResBlock(
|
885 |
+
ch + ich,
|
886 |
+
time_embed_dim,
|
887 |
+
dropout,
|
888 |
+
out_channels=model_channels * mult,
|
889 |
+
dims=dims,
|
890 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
891 |
+
)
|
892 |
+
]
|
893 |
+
ch = model_channels * mult
|
894 |
+
if ds in attention_resolutions:
|
895 |
+
if num_head_channels == -1:
|
896 |
+
dim_head = ch // num_heads
|
897 |
+
else:
|
898 |
+
num_heads = ch // num_head_channels
|
899 |
+
dim_head = num_head_channels
|
900 |
+
|
901 |
+
if num_attention_blocks is None or i < num_attention_blocks[level]:
|
902 |
+
layers.append(
|
903 |
+
SpatialTransformer3D(
|
904 |
+
ch,
|
905 |
+
num_heads,
|
906 |
+
dim_head,
|
907 |
+
context_dim=context_dim,
|
908 |
+
depth=transformer_depth,
|
909 |
+
ip_dim=self.ip_dim,
|
910 |
+
ip_weight=self.ip_weight,
|
911 |
+
)
|
912 |
+
)
|
913 |
+
if level and i == self.num_res_blocks[level]:
|
914 |
+
out_ch = ch
|
915 |
+
layers.append(
|
916 |
+
ResBlock(
|
917 |
+
ch,
|
918 |
+
time_embed_dim,
|
919 |
+
dropout,
|
920 |
+
out_channels=out_ch,
|
921 |
+
dims=dims,
|
922 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
923 |
+
up=True,
|
924 |
+
)
|
925 |
+
if resblock_updown
|
926 |
+
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
927 |
+
)
|
928 |
+
ds //= 2
|
929 |
+
self.output_blocks.append(CondSequential(*layers))
|
930 |
+
self._feature_size += ch
|
931 |
+
|
932 |
+
self.out = nn.Sequential(
|
933 |
+
nn.GroupNorm(32, ch),
|
934 |
+
nn.SiLU(),
|
935 |
+
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
936 |
+
)
|
937 |
+
if self.predict_codebook_ids:
|
938 |
+
self.id_predictor = nn.Sequential(
|
939 |
+
nn.GroupNorm(32, ch),
|
940 |
+
conv_nd(dims, model_channels, n_embed, 1),
|
941 |
+
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
942 |
+
)
|
943 |
+
|
944 |
+
def forward(
|
945 |
+
self,
|
946 |
+
x,
|
947 |
+
timesteps=None,
|
948 |
+
context=None,
|
949 |
+
y=None,
|
950 |
+
camera=None,
|
951 |
+
num_frames=1,
|
952 |
+
ip=None,
|
953 |
+
ip_img=None,
|
954 |
+
**kwargs,
|
955 |
+
):
|
956 |
+
"""
|
957 |
+
Apply the model to an input batch.
|
958 |
+
:param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
|
959 |
+
:param timesteps: a 1-D batch of timesteps.
|
960 |
+
:param context: conditioning plugged in via crossattn
|
961 |
+
:param y: an [N] Tensor of labels, if class-conditional.
|
962 |
+
:param num_frames: a integer indicating number of frames for tensor reshaping.
|
963 |
+
:return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
|
964 |
+
"""
|
965 |
+
assert (
|
966 |
+
x.shape[0] % num_frames == 0
|
967 |
+
), "input batch size must be dividable by num_frames!"
|
968 |
+
assert (y is not None) == (
|
969 |
+
self.num_classes is not None
|
970 |
+
), "must specify y if and only if the model is class-conditional"
|
971 |
+
|
972 |
+
hs = []
|
973 |
+
|
974 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
975 |
+
|
976 |
+
emb = self.time_embed(t_emb)
|
977 |
+
|
978 |
+
if self.num_classes is not None:
|
979 |
+
assert y is not None
|
980 |
+
assert y.shape[0] == x.shape[0]
|
981 |
+
emb = emb + self.label_emb(y)
|
982 |
+
|
983 |
+
# Add camera embeddings
|
984 |
+
if camera is not None:
|
985 |
+
emb = emb + self.camera_embed(camera)
|
986 |
+
|
987 |
+
# imagedream variant
|
988 |
+
if self.ip_dim > 0:
|
989 |
+
x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]
|
990 |
+
ip_emb = self.image_embed(ip)
|
991 |
+
context = torch.cat((context, ip_emb), 1)
|
992 |
+
|
993 |
+
h = x
|
994 |
+
for module in self.input_blocks:
|
995 |
+
h = module(h, emb, context, num_frames=num_frames)
|
996 |
+
hs.append(h)
|
997 |
+
h = self.middle_block(h, emb, context, num_frames=num_frames)
|
998 |
+
for module in self.output_blocks:
|
999 |
+
h = torch.cat([h, hs.pop()], dim=1)
|
1000 |
+
h = module(h, emb, context, num_frames=num_frames)
|
1001 |
+
h = h.type(x.dtype)
|
1002 |
+
if self.predict_codebook_ids:
|
1003 |
+
return self.id_predictor(h)
|
1004 |
+
else:
|
1005 |
+
return self.out(h)
|
mvdream/pipeline_mvdream.py
ADDED
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import inspect
|
4 |
+
import numpy as np
|
5 |
+
from typing import Callable, List, Optional, Union
|
6 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
|
7 |
+
from diffusers import AutoencoderKL, DiffusionPipeline
|
8 |
+
from diffusers.utils import (
|
9 |
+
deprecate,
|
10 |
+
is_accelerate_available,
|
11 |
+
is_accelerate_version,
|
12 |
+
logging,
|
13 |
+
)
|
14 |
+
from diffusers.configuration_utils import FrozenDict
|
15 |
+
from diffusers.schedulers import DDIMScheduler
|
16 |
+
from diffusers.utils.torch_utils import randn_tensor
|
17 |
+
|
18 |
+
from mvdream.mv_unet import MultiViewUNetModel, get_camera
|
19 |
+
|
20 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
21 |
+
|
22 |
+
|
23 |
+
class MVDreamPipeline(DiffusionPipeline):
|
24 |
+
|
25 |
+
_optional_components = ["feature_extractor", "image_encoder"]
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
vae: AutoencoderKL,
|
30 |
+
unet: MultiViewUNetModel,
|
31 |
+
tokenizer: CLIPTokenizer,
|
32 |
+
text_encoder: CLIPTextModel,
|
33 |
+
scheduler: DDIMScheduler,
|
34 |
+
# imagedream variant
|
35 |
+
feature_extractor: CLIPImageProcessor,
|
36 |
+
image_encoder: CLIPVisionModel,
|
37 |
+
requires_safety_checker: bool = False,
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore
|
42 |
+
deprecation_message = (
|
43 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
44 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore
|
45 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
46 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
47 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
48 |
+
" file"
|
49 |
+
)
|
50 |
+
deprecate(
|
51 |
+
"steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
|
52 |
+
)
|
53 |
+
new_config = dict(scheduler.config)
|
54 |
+
new_config["steps_offset"] = 1
|
55 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
56 |
+
|
57 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore
|
58 |
+
deprecation_message = (
|
59 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
60 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
61 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
62 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
63 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
64 |
+
)
|
65 |
+
deprecate(
|
66 |
+
"clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
|
67 |
+
)
|
68 |
+
new_config = dict(scheduler.config)
|
69 |
+
new_config["clip_sample"] = False
|
70 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
71 |
+
|
72 |
+
self.register_modules(
|
73 |
+
vae=vae,
|
74 |
+
unet=unet,
|
75 |
+
scheduler=scheduler,
|
76 |
+
tokenizer=tokenizer,
|
77 |
+
text_encoder=text_encoder,
|
78 |
+
feature_extractor=feature_extractor,
|
79 |
+
image_encoder=image_encoder,
|
80 |
+
)
|
81 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
82 |
+
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
83 |
+
|
84 |
+
def enable_vae_slicing(self):
|
85 |
+
r"""
|
86 |
+
Enable sliced VAE decoding.
|
87 |
+
|
88 |
+
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
89 |
+
steps. This is useful to save some memory and allow larger batch sizes.
|
90 |
+
"""
|
91 |
+
self.vae.enable_slicing()
|
92 |
+
|
93 |
+
def disable_vae_slicing(self):
|
94 |
+
r"""
|
95 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
|
96 |
+
computing decoding in one step.
|
97 |
+
"""
|
98 |
+
self.vae.disable_slicing()
|
99 |
+
|
100 |
+
def enable_vae_tiling(self):
|
101 |
+
r"""
|
102 |
+
Enable tiled VAE decoding.
|
103 |
+
|
104 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
|
105 |
+
several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
|
106 |
+
"""
|
107 |
+
self.vae.enable_tiling()
|
108 |
+
|
109 |
+
def disable_vae_tiling(self):
|
110 |
+
r"""
|
111 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
|
112 |
+
computing decoding in one step.
|
113 |
+
"""
|
114 |
+
self.vae.disable_tiling()
|
115 |
+
|
116 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
117 |
+
r"""
|
118 |
+
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
119 |
+
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
120 |
+
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
121 |
+
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
122 |
+
`enable_model_cpu_offload`, but performance is lower.
|
123 |
+
"""
|
124 |
+
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
|
125 |
+
from accelerate import cpu_offload
|
126 |
+
else:
|
127 |
+
raise ImportError(
|
128 |
+
"`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
|
129 |
+
)
|
130 |
+
|
131 |
+
device = torch.device(f"cuda:{gpu_id}")
|
132 |
+
|
133 |
+
if self.device.type != "cpu":
|
134 |
+
self.to("cpu", silence_dtype_warnings=True)
|
135 |
+
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
136 |
+
|
137 |
+
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
138 |
+
cpu_offload(cpu_offloaded_model, device)
|
139 |
+
|
140 |
+
def enable_model_cpu_offload(self, gpu_id=0):
|
141 |
+
r"""
|
142 |
+
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
143 |
+
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
144 |
+
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
145 |
+
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
146 |
+
"""
|
147 |
+
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
148 |
+
from accelerate import cpu_offload_with_hook
|
149 |
+
else:
|
150 |
+
raise ImportError(
|
151 |
+
"`enable_model_offload` requires `accelerate v0.17.0` or higher."
|
152 |
+
)
|
153 |
+
|
154 |
+
device = torch.device(f"cuda:{gpu_id}")
|
155 |
+
|
156 |
+
if self.device.type != "cpu":
|
157 |
+
self.to("cpu", silence_dtype_warnings=True)
|
158 |
+
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
159 |
+
|
160 |
+
hook = None
|
161 |
+
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
162 |
+
_, hook = cpu_offload_with_hook(
|
163 |
+
cpu_offloaded_model, device, prev_module_hook=hook
|
164 |
+
)
|
165 |
+
|
166 |
+
# We'll offload the last model manually.
|
167 |
+
self.final_offload_hook = hook
|
168 |
+
|
169 |
+
@property
|
170 |
+
def _execution_device(self):
|
171 |
+
r"""
|
172 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
173 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
174 |
+
hooks.
|
175 |
+
"""
|
176 |
+
if not hasattr(self.unet, "_hf_hook"):
|
177 |
+
return self.device
|
178 |
+
for module in self.unet.modules():
|
179 |
+
if (
|
180 |
+
hasattr(module, "_hf_hook")
|
181 |
+
and hasattr(module._hf_hook, "execution_device")
|
182 |
+
and module._hf_hook.execution_device is not None
|
183 |
+
):
|
184 |
+
return torch.device(module._hf_hook.execution_device)
|
185 |
+
return self.device
|
186 |
+
|
187 |
+
def _encode_prompt(
|
188 |
+
self,
|
189 |
+
prompt,
|
190 |
+
device,
|
191 |
+
num_images_per_prompt,
|
192 |
+
do_classifier_free_guidance: bool,
|
193 |
+
negative_prompt=None,
|
194 |
+
):
|
195 |
+
r"""
|
196 |
+
Encodes the prompt into text encoder hidden states.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
prompt (`str` or `List[str]`, *optional*):
|
200 |
+
prompt to be encoded
|
201 |
+
device: (`torch.device`):
|
202 |
+
torch device
|
203 |
+
num_images_per_prompt (`int`):
|
204 |
+
number of images that should be generated per prompt
|
205 |
+
do_classifier_free_guidance (`bool`):
|
206 |
+
whether to use classifier free guidance or not
|
207 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
208 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
209 |
+
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
210 |
+
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
211 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
212 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
213 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
214 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
215 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
216 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
217 |
+
argument.
|
218 |
+
"""
|
219 |
+
if prompt is not None and isinstance(prompt, str):
|
220 |
+
batch_size = 1
|
221 |
+
elif prompt is not None and isinstance(prompt, list):
|
222 |
+
batch_size = len(prompt)
|
223 |
+
else:
|
224 |
+
raise ValueError(
|
225 |
+
f"`prompt` should be either a string or a list of strings, but got {type(prompt)}."
|
226 |
+
)
|
227 |
+
|
228 |
+
text_inputs = self.tokenizer(
|
229 |
+
prompt,
|
230 |
+
padding="max_length",
|
231 |
+
max_length=self.tokenizer.model_max_length,
|
232 |
+
truncation=True,
|
233 |
+
return_tensors="pt",
|
234 |
+
)
|
235 |
+
text_input_ids = text_inputs.input_ids
|
236 |
+
untruncated_ids = self.tokenizer(
|
237 |
+
prompt, padding="longest", return_tensors="pt"
|
238 |
+
).input_ids
|
239 |
+
|
240 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
241 |
+
text_input_ids, untruncated_ids
|
242 |
+
):
|
243 |
+
removed_text = self.tokenizer.batch_decode(
|
244 |
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
245 |
+
)
|
246 |
+
logger.warning(
|
247 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
248 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
249 |
+
)
|
250 |
+
|
251 |
+
if (
|
252 |
+
hasattr(self.text_encoder.config, "use_attention_mask")
|
253 |
+
and self.text_encoder.config.use_attention_mask
|
254 |
+
):
|
255 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
256 |
+
else:
|
257 |
+
attention_mask = None
|
258 |
+
|
259 |
+
prompt_embeds = self.text_encoder(
|
260 |
+
text_input_ids.to(device),
|
261 |
+
attention_mask=attention_mask,
|
262 |
+
)
|
263 |
+
prompt_embeds = prompt_embeds[0]
|
264 |
+
|
265 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
266 |
+
|
267 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
268 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
269 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
270 |
+
prompt_embeds = prompt_embeds.view(
|
271 |
+
bs_embed * num_images_per_prompt, seq_len, -1
|
272 |
+
)
|
273 |
+
|
274 |
+
# get unconditional embeddings for classifier free guidance
|
275 |
+
if do_classifier_free_guidance:
|
276 |
+
uncond_tokens: List[str]
|
277 |
+
if negative_prompt is None:
|
278 |
+
uncond_tokens = [""] * batch_size
|
279 |
+
elif type(prompt) is not type(negative_prompt):
|
280 |
+
raise TypeError(
|
281 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
282 |
+
f" {type(prompt)}."
|
283 |
+
)
|
284 |
+
elif isinstance(negative_prompt, str):
|
285 |
+
uncond_tokens = [negative_prompt]
|
286 |
+
elif batch_size != len(negative_prompt):
|
287 |
+
raise ValueError(
|
288 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
289 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
290 |
+
" the batch size of `prompt`."
|
291 |
+
)
|
292 |
+
else:
|
293 |
+
uncond_tokens = negative_prompt
|
294 |
+
|
295 |
+
max_length = prompt_embeds.shape[1]
|
296 |
+
uncond_input = self.tokenizer(
|
297 |
+
uncond_tokens,
|
298 |
+
padding="max_length",
|
299 |
+
max_length=max_length,
|
300 |
+
truncation=True,
|
301 |
+
return_tensors="pt",
|
302 |
+
)
|
303 |
+
|
304 |
+
if (
|
305 |
+
hasattr(self.text_encoder.config, "use_attention_mask")
|
306 |
+
and self.text_encoder.config.use_attention_mask
|
307 |
+
):
|
308 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
309 |
+
else:
|
310 |
+
attention_mask = None
|
311 |
+
|
312 |
+
negative_prompt_embeds = self.text_encoder(
|
313 |
+
uncond_input.input_ids.to(device),
|
314 |
+
attention_mask=attention_mask,
|
315 |
+
)
|
316 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
317 |
+
|
318 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
319 |
+
seq_len = negative_prompt_embeds.shape[1]
|
320 |
+
|
321 |
+
negative_prompt_embeds = negative_prompt_embeds.to(
|
322 |
+
dtype=self.text_encoder.dtype, device=device
|
323 |
+
)
|
324 |
+
|
325 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
326 |
+
1, num_images_per_prompt, 1
|
327 |
+
)
|
328 |
+
negative_prompt_embeds = negative_prompt_embeds.view(
|
329 |
+
batch_size * num_images_per_prompt, seq_len, -1
|
330 |
+
)
|
331 |
+
|
332 |
+
# For classifier free guidance, we need to do two forward passes.
|
333 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
334 |
+
# to avoid doing two forward passes
|
335 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
336 |
+
|
337 |
+
return prompt_embeds
|
338 |
+
|
339 |
+
def decode_latents(self, latents):
|
340 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
341 |
+
image = self.vae.decode(latents).sample
|
342 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
343 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
344 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
345 |
+
return image
|
346 |
+
|
347 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
348 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
349 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
350 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
351 |
+
# and should be between [0, 1]
|
352 |
+
|
353 |
+
accepts_eta = "eta" in set(
|
354 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
355 |
+
)
|
356 |
+
extra_step_kwargs = {}
|
357 |
+
if accepts_eta:
|
358 |
+
extra_step_kwargs["eta"] = eta
|
359 |
+
|
360 |
+
# check if the scheduler accepts generator
|
361 |
+
accepts_generator = "generator" in set(
|
362 |
+
inspect.signature(self.scheduler.step).parameters.keys()
|
363 |
+
)
|
364 |
+
if accepts_generator:
|
365 |
+
extra_step_kwargs["generator"] = generator
|
366 |
+
return extra_step_kwargs
|
367 |
+
|
368 |
+
def prepare_latents(
|
369 |
+
self,
|
370 |
+
batch_size,
|
371 |
+
num_channels_latents,
|
372 |
+
height,
|
373 |
+
width,
|
374 |
+
dtype,
|
375 |
+
device,
|
376 |
+
generator,
|
377 |
+
latents=None,
|
378 |
+
):
|
379 |
+
shape = (
|
380 |
+
batch_size,
|
381 |
+
num_channels_latents,
|
382 |
+
height // self.vae_scale_factor,
|
383 |
+
width // self.vae_scale_factor,
|
384 |
+
)
|
385 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
386 |
+
raise ValueError(
|
387 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
388 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
389 |
+
)
|
390 |
+
|
391 |
+
if latents is None:
|
392 |
+
latents = randn_tensor(
|
393 |
+
shape, generator=generator, device=device, dtype=dtype
|
394 |
+
)
|
395 |
+
else:
|
396 |
+
latents = latents.to(device)
|
397 |
+
|
398 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
399 |
+
latents = latents * self.scheduler.init_noise_sigma
|
400 |
+
return latents
|
401 |
+
|
402 |
+
def encode_image(self, image, device, num_images_per_prompt):
|
403 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
404 |
+
|
405 |
+
if image.dtype == np.float32:
|
406 |
+
image = (image * 255).astype(np.uint8)
|
407 |
+
|
408 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
409 |
+
image = image.to(device=device, dtype=dtype)
|
410 |
+
|
411 |
+
image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
412 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
413 |
+
|
414 |
+
return torch.zeros_like(image_embeds), image_embeds
|
415 |
+
|
416 |
+
def encode_image_latents(self, image, device, num_images_per_prompt):
|
417 |
+
|
418 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
419 |
+
|
420 |
+
image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device) # [1, 3, H, W]
|
421 |
+
image = 2 * image - 1
|
422 |
+
image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
|
423 |
+
image = image.to(dtype=dtype)
|
424 |
+
|
425 |
+
posterior = self.vae.encode(image).latent_dist
|
426 |
+
latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
|
427 |
+
latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
|
428 |
+
|
429 |
+
return torch.zeros_like(latents), latents
|
430 |
+
|
431 |
+
@torch.no_grad()
|
432 |
+
def __call__(
|
433 |
+
self,
|
434 |
+
prompt: str = "",
|
435 |
+
image: Optional[np.ndarray] = None,
|
436 |
+
height: int = 256,
|
437 |
+
width: int = 256,
|
438 |
+
elevation: float = 0,
|
439 |
+
num_inference_steps: int = 50,
|
440 |
+
guidance_scale: float = 7.0,
|
441 |
+
negative_prompt: str = "",
|
442 |
+
num_images_per_prompt: int = 1,
|
443 |
+
eta: float = 0.0,
|
444 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
445 |
+
output_type: Optional[str] = "numpy", # pil, numpy, latents
|
446 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
447 |
+
callback_steps: int = 1,
|
448 |
+
num_frames: int = 4,
|
449 |
+
device=torch.device("cuda:0"),
|
450 |
+
):
|
451 |
+
self.unet = self.unet.to(device=device)
|
452 |
+
self.vae = self.vae.to(device=device)
|
453 |
+
self.text_encoder = self.text_encoder.to(device=device)
|
454 |
+
|
455 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
456 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
457 |
+
# corresponds to doing no classifier free guidance.
|
458 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
459 |
+
|
460 |
+
# Prepare timesteps
|
461 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
462 |
+
timesteps = self.scheduler.timesteps
|
463 |
+
|
464 |
+
# imagedream variant
|
465 |
+
if image is not None:
|
466 |
+
assert isinstance(image, np.ndarray) and image.dtype == np.float32
|
467 |
+
self.image_encoder = self.image_encoder.to(device=device)
|
468 |
+
image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)
|
469 |
+
image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)
|
470 |
+
|
471 |
+
_prompt_embeds = self._encode_prompt(
|
472 |
+
prompt=prompt,
|
473 |
+
device=device,
|
474 |
+
num_images_per_prompt=num_images_per_prompt,
|
475 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
476 |
+
negative_prompt=negative_prompt,
|
477 |
+
) # type: ignore
|
478 |
+
prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
|
479 |
+
|
480 |
+
# Prepare latent variables
|
481 |
+
actual_num_frames = num_frames if image is None else num_frames + 1
|
482 |
+
latents: torch.Tensor = self.prepare_latents(
|
483 |
+
actual_num_frames * num_images_per_prompt,
|
484 |
+
4,
|
485 |
+
height,
|
486 |
+
width,
|
487 |
+
prompt_embeds_pos.dtype,
|
488 |
+
device,
|
489 |
+
generator,
|
490 |
+
None,
|
491 |
+
)
|
492 |
+
|
493 |
+
if image is not None:
|
494 |
+
camera = get_camera(num_frames, elevation=elevation, extra_view=True).to(dtype=latents.dtype, device=device)
|
495 |
+
else:
|
496 |
+
camera = get_camera(num_frames, elevation=elevation, extra_view=False).to(dtype=latents.dtype, device=device)
|
497 |
+
camera = camera.repeat_interleave(num_images_per_prompt, dim=0)
|
498 |
+
|
499 |
+
# Prepare extra step kwargs.
|
500 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
501 |
+
|
502 |
+
# Denoising loop
|
503 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
504 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
505 |
+
for i, t in enumerate(timesteps):
|
506 |
+
# expand the latents if we are doing classifier free guidance
|
507 |
+
multiplier = 2 if do_classifier_free_guidance else 1
|
508 |
+
latent_model_input = torch.cat([latents] * multiplier)
|
509 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
510 |
+
|
511 |
+
unet_inputs = {
|
512 |
+
'x': latent_model_input,
|
513 |
+
'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device),
|
514 |
+
'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames),
|
515 |
+
'num_frames': actual_num_frames,
|
516 |
+
'camera': torch.cat([camera] * multiplier),
|
517 |
+
}
|
518 |
+
|
519 |
+
if image is not None:
|
520 |
+
unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames)
|
521 |
+
unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat
|
522 |
+
|
523 |
+
# predict the noise residual
|
524 |
+
noise_pred = self.unet.forward(**unet_inputs)
|
525 |
+
|
526 |
+
# perform guidance
|
527 |
+
if do_classifier_free_guidance:
|
528 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
529 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
530 |
+
noise_pred_text - noise_pred_uncond
|
531 |
+
)
|
532 |
+
|
533 |
+
# compute the previous noisy sample x_t -> x_t-1
|
534 |
+
latents: torch.Tensor = self.scheduler.step(
|
535 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
536 |
+
)[0]
|
537 |
+
|
538 |
+
# call the callback, if provided
|
539 |
+
if i == len(timesteps) - 1 or (
|
540 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
541 |
+
):
|
542 |
+
progress_bar.update()
|
543 |
+
if callback is not None and i % callback_steps == 0:
|
544 |
+
callback(i, t, latents) # type: ignore
|
545 |
+
|
546 |
+
# Post-processing
|
547 |
+
if output_type == "latent":
|
548 |
+
image = latents
|
549 |
+
elif output_type == "pil":
|
550 |
+
image = self.decode_latents(latents)
|
551 |
+
image = self.numpy_to_pil(image)
|
552 |
+
else: # numpy
|
553 |
+
image = self.decode_latents(latents)
|
554 |
+
|
555 |
+
# Offload last model to CPU
|
556 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
557 |
+
self.final_offload_hook.offload()
|
558 |
+
|
559 |
+
return image
|
readme.md
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
## L4GM: Large 4D Gaussian Reconstruction Model
|
3 |
+
<p align="center">
|
4 |
+
<img src="assets/teaser.jpg">
|
5 |
+
</p>
|
6 |
+
|
7 |
+
[**Paper**](https://arxiv.org/abs/2406.10324) | [**Project Page**](https://research.nvidia.com/labs/toronto-ai/l4gm/) | [**Model Weights**](https://huggingface.co/jiawei011/L4GM)
|
8 |
+
|
9 |
+
We present L4GM, the first 4D Large Reconstruction Model that produces animated objects from a single-view video input -- in a single feed-forward pass that takes only a second.
|
10 |
+
|
11 |
+
---
|
12 |
+
|
13 |
+
### Install
|
14 |
+
```bash
|
15 |
+
conda env create -f environment.yml
|
16 |
+
conda activate l4gm
|
17 |
+
```
|
18 |
+
|
19 |
+
### Inference
|
20 |
+
Download pretrained [L4GM model](https://huggingface.co/jiawei011/L4GM/blob/main/recon.safetensors) and [4D interpolation model](https://huggingface.co/jiawei011/L4GM/blob/main/interp.safetensors) to `pretrained/recon.safetensors` and `pretrained/interp.safetensors` respectively.
|
21 |
+
|
22 |
+
Select an input video. Remove its background and crop it to 256x256 with third-party tools. We provide some processed examples in the `data_test` folder.
|
23 |
+
|
24 |
+
1. Generate 3D by:
|
25 |
+
```sh
|
26 |
+
python infer_3d.py big --workspace results --resume pretrained/recon.safetensors --num_frames 1 --test_path data_test/otter-on-surfboard_fg.mp4
|
27 |
+
```
|
28 |
+
|
29 |
+
2. Generate 4D by:
|
30 |
+
```sh
|
31 |
+
python infer_4d.py big --workspace results --resume pretrained/recon.safetensors --interpresume pretrained/interp.safetensors --num_frames 16 --test_path data_test/otter-on-surfboard_fg.mp4
|
32 |
+
```
|
33 |
+
|
34 |
+
### Training
|
35 |
+
Render Objaverse with Blender scripts in the `blender_scripts` folder first.
|
36 |
+
|
37 |
+
Download pretrained [LGM](https://huggingface.co/ashawkey/LGM/blob/main/model_fixrot.safetensors) to `pretrained/model_fixrot.safetensors`.
|
38 |
+
|
39 |
+
L4GM model training:
|
40 |
+
```sh
|
41 |
+
accelerate launch \
|
42 |
+
--config_file acc_configs/gpu8.yaml \
|
43 |
+
main.py big \
|
44 |
+
--workspace workspace_recon \
|
45 |
+
--resume pretrained/model_fixrot.safetensors \
|
46 |
+
--data_mode 4d \
|
47 |
+
--num_epochs 200 \
|
48 |
+
--prob_cam_jitter 0 \
|
49 |
+
--datalist data_train/datalist_8fps.txt \
|
50 |
+
```
|
51 |
+
Our released checkpoint uses `--num_epochs 500`.
|
52 |
+
|
53 |
+
4D Interpolation model training:
|
54 |
+
```sh
|
55 |
+
accelerate launch \
|
56 |
+
--config_file acc_configs/gpu8.yaml \
|
57 |
+
main.py big \
|
58 |
+
--workspace workspace_interp \
|
59 |
+
--resume workspace_recon/model.safetensors \
|
60 |
+
--data_mode 4d_interp \
|
61 |
+
--num_frames 4 \
|
62 |
+
--num_epochs 200 \
|
63 |
+
--prob_cam_jitter 0 \
|
64 |
+
--prob_grid_distortion 0 \
|
65 |
+
--datalist data_train/datalist_24fps.txt \
|
66 |
+
```
|
67 |
+
|
68 |
+
### Citation
|
69 |
+
```bib
|
70 |
+
@inproceedings{ren2024l4gm,
|
71 |
+
title={L4GM: Large 4D Gaussian Reconstruction Model},
|
72 |
+
author={Jiawei Ren and Kevin Xie and Ashkan Mirzaei and Hanxue Liang and Xiaohui Zeng and Karsten Kreis and Ziwei Liu and Antonio Torralba and Sanja Fidler and Seung Wook Kim and Huan Ling},
|
73 |
+
booktitle={Proceedings of Neural Information Processing Systems(NeurIPS)},
|
74 |
+
month = {Dec},
|
75 |
+
year={2024}
|
76 |
+
}
|
77 |
+
```
|
requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
tyro
|
2 |
+
accelerate==0.28.0
|
3 |
+
imageio
|
4 |
+
imageio-ffmpeg
|
5 |
+
lpips
|
6 |
+
Pillow
|
7 |
+
safetensors
|
8 |
+
scikit-image
|
9 |
+
scikit-learn
|
10 |
+
scipy
|
11 |
+
tqdm
|
12 |
+
kiui >= 0.2.3
|
13 |
+
roma
|
14 |
+
plyfile
|
15 |
+
|
16 |
+
# mvdream
|
17 |
+
diffusers==0.27.2
|
18 |
+
huggingface_hub==0.23.5
|
19 |
+
transformers
|
20 |
+
|