fffiloni committed
Commit
2cdb96e
1 Parent(s): 1131d88

Migrated from GitHub

CONTRIBUTING.md ADDED
@@ -0,0 +1,34 @@
+ Developer Certificate of Origin
+ Version 1.1
+
+ Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
+
+ Everyone is permitted to copy and distribute verbatim copies of this
+ license document, but changing it is not allowed.
+
+
+ Developer's Certificate of Origin 1.1
+
+ By making a contribution to this project, I certify that:
+
+ (a) The contribution was created in whole or in part by me and I
+ have the right to submit it under the open source license
+ indicated in the file; or
+
+ (b) The contribution is based upon previous work that, to the best
+ of my knowledge, is covered under an appropriate open source
+ license and I have the right under that license to submit that
+ work with modifications, whether created in whole or in part
+ by me, under the same open source license (unless I am
+ permitted to submit under a different license), as indicated
+ in the file; or
+
+ (c) The contribution was provided directly to me by some other
+ person who certified (a), (b) or (c) and I have not modified
+ it.
+
+ (d) I understand and agree that this project and the contribution
+ are public and that a record of the contribution (including all
+ personal information I submit with it, including my sign-off) is
+ maintained indefinitely and may be redistributed consistent with
+ this project or the open source license(s) involved.
LICENSE ADDED
@@ -0,0 +1,228 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ Copyright 2024 NVIDIA Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+
+ PORTIONS LICENSED AS FOLLOWS
+
+ > core/utils.py
+ > mvdream/mv_unet.py
+ > mvdream/pipeline_mvdream.py
+
+ MIT License
+
+ Copyright (c) 2024 3D Topia
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+
+ > core/attention.py
+
+ Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ This source code is licensed under the Apache License, Version 2.0
+ found in the LICENSE file in the root directory of this source tree.
acc_configs/gpu1.yaml ADDED
@@ -0,0 +1,15 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: 'NO'
+ downcast_bf16: 'no'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 1
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
acc_configs/gpu4.yaml ADDED
@@ -0,0 +1,15 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: fp16
+ num_machines: 1
+ num_processes: 4
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
acc_configs/gpu6.yaml ADDED
@@ -0,0 +1,15 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: fp16
+ num_machines: 1
+ num_processes: 6
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
acc_configs/gpu8.yaml ADDED
@@ -0,0 +1,15 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
assets/teaser.jpg ADDED
blender_scripts/render_objaverse.py ADDED
@@ -0,0 +1,537 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse, sys, os, math, re
17
+ import bpy
18
+ from mathutils import Vector, Matrix
19
+ import numpy as np
20
+ import cv2
21
+ import signal
22
+ from contextlib import contextmanager
23
+ from loguru import logger
24
+ from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Set, Tuple
25
+ import random
26
+ class TimeoutException(Exception): pass
27
+
28
+ logger.info('Rendering started.')
29
+
30
+ @contextmanager
31
+ def time_limit(seconds):
32
+ def signal_handler(signum, frame):
33
+ raise TimeoutException("Timed out!")
34
+ signal.signal(signal.SIGALRM, signal_handler)
35
+ signal.alarm(seconds)
36
+ try:
37
+ yield
38
+ finally:
39
+ signal.alarm(0)
40
+
41
+ parser = argparse.ArgumentParser(description='Renders a given obj file by rotating a camera around it.')
42
+ parser.add_argument(
43
+ '--seed', type=int, default=0,
44
+ help='random seed')
45
+ parser.add_argument(
46
+ '--views', type=int, default=4,
47
+ help='number of views to be rendered')
48
+ parser.add_argument(
49
+ 'obj', type=str,
50
+ help='Path to the obj file to be rendered.')
51
+ parser.add_argument(
52
+ '--output_folder', type=str, default='/tmp',
53
+ help='The path the output will be dumped to.')
54
+ parser.add_argument(
55
+ '--scale', type=float, default=1,
56
+ help='Scaling factor applied to model. Depends on size of mesh.')
57
+ parser.add_argument(
58
+ '--format', type=str, default='PNG',
59
+ help='Format of files generated. Either PNG or OPEN_EXR')
60
+
61
+ parser.add_argument(
62
+ '--resolution', type=int, default=512,
63
+ help='Resolution of the images.')
64
+ parser.add_argument(
65
+ '--engine', type=str, default='CYCLES',
66
+ help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...')
67
+ parser.add_argument(
68
+ '--gpu', type=int, default=0,
69
+ help='GPU index to use.')
70
+ parser.add_argument(
71
+ '--animation_idx', type=int, default=0,
72
+ help='The index of animation')
73
+
74
+ parser.add_argument(
75
+ '--camera_option', type=str, default='fixed',
76
+ help='Camera Options')
77
+ parser.add_argument(
78
+ '--fixed_animation_length', type=int, default=-1,
79
+ help='Set animation length to a fixed number of frames')
80
+ parser.add_argument(
81
+ '--step_angle', type=int, default=3,
82
+ help='Angle in degree for each step camera rotation')
83
+ parser.add_argument(
84
+ '--downsample', type=int, default=1,
85
+ help='Downsample ratio. No downsample by default')
86
+
87
+ argv = sys.argv[sys.argv.index("--") + 1:]
88
+ args = parser.parse_args(argv)
89
+
90
+
91
+ model_identifier = os.path.split(args.obj)[1].split('.')[0]
92
+ synset_idx = args.obj.split('/')[-2]
93
+
94
+ save_root = os.path.join(os.path.abspath(args.output_folder), synset_idx, model_identifier, f'{args.animation_idx:03d}')
95
+
96
+ # Set up rendering
97
+ context = bpy.context
98
+ scene = bpy.context.scene
99
+ render = bpy.context.scene.render
100
+
101
+ render.engine = args.engine  # e.g. 'CYCLES' or 'BLENDER_EEVEE'
102
+ render.image_settings.color_mode = 'RGBA' # ('RGB', 'RGBA', ...)
103
+ render.image_settings.file_format = args.format # ('PNG', 'OPEN_EXR', 'JPEG, ...)
104
+ render.resolution_x = args.resolution
105
+ render.resolution_y = args.resolution
106
+ render.resolution_percentage = 100
107
+ bpy.context.scene.cycles.filter_width = 0.01
108
+ bpy.context.scene.render.film_transparent = True
109
+ render_depth_normal = False
110
+ bpy.context.scene.cycles.device = 'GPU'
111
+ bpy.context.scene.cycles.diffuse_bounces = 1
112
+ bpy.context.scene.cycles.glossy_bounces = 1
113
+ bpy.context.scene.cycles.transparent_max_bounces = 1
114
+ bpy.context.scene.cycles.transmission_bounces = 1
115
+ bpy.context.scene.cycles.samples = 16
116
+ bpy.context.scene.cycles.use_denoising = True
117
+ bpy.context.scene.cycles.denoiser = 'OPTIX'
118
+ bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
119
+ bpy.context.scene.cycles.device = 'GPU'
120
+
121
+
122
+ def enable_cuda_devices():
123
+ prefs = bpy.context.preferences
124
+ cprefs = prefs.addons['cycles'].preferences
125
+ cprefs.get_devices()
126
+ # Attempt to set GPU device types if available
127
+ for compute_device_type in ('CUDA', 'OPENCL', 'NONE'):
128
+ try:
129
+ cprefs.compute_device_type = compute_device_type
130
+ print("Compute device selected: {0}".format(compute_device_type))
131
+ break
132
+ except TypeError:
133
+ pass
134
+
135
+ # Any CUDA/OPENCL devices?
136
+ acceleratedTypes = ['CUDA', 'OPENCL', 'OPTIX']
137
+ acceleratedTypes = ['CUDA', 'OPENCL']
138
+ accelerated = any(device.type in acceleratedTypes for device in cprefs.devices)
139
+ print('Accelerated render = {0}'.format(accelerated))
140
+
141
+ # If we have CUDA/OPENCL devices, enable only them, otherwise enable
142
+ # all devices (assumed to be CPU)
143
+ print(cprefs.devices)
144
+ for idx, device in enumerate(cprefs.devices):
145
+ device.use = (not accelerated or device.type in acceleratedTypes)# and idx == args.gpu
146
+ print('Device enabled ({type}) = {enabled}'.format(type=device.type, enabled=device.use))
147
+ return accelerated
148
+
149
+ enable_cuda_devices()
150
+ context.active_object.select_set(True)
151
+ bpy.ops.object.delete()
152
+
153
+ # Import textured mesh
154
+ bpy.ops.object.select_all(action='DESELECT')
155
+
156
+ try:
157
+ with time_limit(1000):
158
+ imported_object = bpy.ops.import_scene.gltf(filepath=args.obj, merge_vertices=True, guess_original_bind_pose=False, bone_heuristic="TEMPERANCE")
159
+ except TimeoutException as e:
160
+ print("Timed out finished!")
161
+ exit()
162
+
163
+
164
+ # count animated frames
165
+ animation_names = []
166
+ ending_frame_list = {}
167
+ for k in bpy.data.actions.keys():
168
+ matched_obj_name = ''
169
+ for obj in bpy.context.selected_objects:
170
+ if '_'+obj.name in k and len(obj.name) > len(matched_obj_name):
171
+ matched_obj_name = obj.name
172
+ a_name = k.replace('_'+matched_obj_name, '')
173
+ a = bpy.data.actions[k]
174
+ frame_start, frame_end = map(int, a.frame_range)
175
+ logger.info(f'{k} | frame start: {frame_start}, frame end: {frame_end} | fps: {bpy.context.scene.render.fps}')
176
+ if a_name not in animation_names:
177
+ animation_names.append(a_name)
178
+ ending_frame_list[a_name] = frame_end
179
+ else:
180
+ ending_frame_list[a_name] = max(frame_end, ending_frame_list[a_name])
181
+
182
+
183
+
184
+ selected_a_name = animation_names[args.animation_idx]
185
+ max_frame = ending_frame_list[selected_a_name]
186
+ for obj in bpy.context.selected_objects:
187
+ if obj.animation_data is not None:
188
+ obj_a_name = selected_a_name+'_'+obj.name
189
+ if obj_a_name in bpy.data.actions:
190
+ print('Found ', obj_a_name)
191
+ obj.animation_data.action = bpy.data.actions[obj_a_name]
192
+ else:
193
+ print('Miss ', obj_a_name)
194
+
195
+ num_frames = args.fixed_animation_length if args.fixed_animation_length != -1 else max_frame
196
+ num_frames = num_frames // args.downsample
197
+
198
+ if num_frames == 0:
199
+ print("No animation!")
200
+ exit()
201
+
202
+ # from https://github.com/allenai/objaverse-xl/blob/main/scripts/rendering/blender_script.py
203
+ def get_3x4_RT_matrix_from_blender(cam: bpy.types.Object):
204
+ """Returns the 3x4 RT matrix from the given camera.
205
+
206
+ Taken from Zero123, which in turn was taken from
207
+ https://github.com/panmari/stanford-shapenet-renderer/blob/master/render_blender.py
208
+
209
+ Args:
210
+ cam (bpy.types.Object): The camera object.
211
+
212
+ Returns:
213
+ Matrix: The 3x4 RT matrix from the given camera.
214
+ """
215
+ # Use matrix_world instead to account for all constraints
216
+ location, rotation = cam.matrix_world.decompose()[0:2]
217
+ R_world2bcam = rotation.to_matrix().transposed()
218
+
219
+ # Use location from matrix_world to account for constraints:
220
+ T_world2bcam = -1 * R_world2bcam @ location
221
+
222
+ # put into 3x4 matrix
223
+ RT = Matrix(
224
+ (
225
+ R_world2bcam[0][:] + (T_world2bcam[0],),
226
+ R_world2bcam[1][:] + (T_world2bcam[1],),
227
+ R_world2bcam[2][:] + (T_world2bcam[2],),
228
+ )
229
+ )
230
+ return RT
231
+ def _create_light(
232
+ name: str,
233
+ light_type: Literal["POINT", "SUN", "SPOT", "AREA"],
234
+ location: Tuple[float, float, float],
235
+ rotation: Tuple[float, float, float],
236
+ energy: float,
237
+ use_shadow: bool = False,
238
+ specular_factor: float = 1.0,
239
+ ):
240
+ """Creates a light object.
241
+
242
+ Args:
243
+ name (str): Name of the light object.
244
+ light_type (Literal["POINT", "SUN", "SPOT", "AREA"]): Type of the light.
245
+ location (Tuple[float, float, float]): Location of the light.
246
+ rotation (Tuple[float, float, float]): Rotation of the light.
247
+ energy (float): Energy of the light.
248
+ use_shadow (bool, optional): Whether to use shadows. Defaults to False.
249
+ specular_factor (float, optional): Specular factor of the light. Defaults to 1.0.
250
+
251
+ Returns:
252
+ bpy.types.Object: The light object.
253
+ """
254
+
255
+ light_data = bpy.data.lights.new(name=name, type=light_type)
256
+ light_object = bpy.data.objects.new(name, light_data)
257
+ bpy.context.collection.objects.link(light_object)
258
+ light_object.location = location
259
+ light_object.rotation_euler = rotation
260
+ light_data.use_shadow = use_shadow
261
+ light_data.specular_factor = specular_factor
262
+ light_data.energy = energy
263
+ return light_object
264
+
265
+
266
+ def randomize_lighting() -> Dict[str, bpy.types.Object]:
267
+ """Randomizes the lighting in the scene.
268
+
269
+ Returns:
270
+ Dict[str, bpy.types.Object]: Dictionary of the lights in the scene. The keys are
271
+ "key_light", "fill_light", "rim_light", and "bottom_light".
272
+ """
273
+
274
+ # Clear existing lights
275
+ bpy.ops.object.select_all(action="DESELECT")
276
+ bpy.ops.object.select_by_type(type="LIGHT")
277
+ bpy.ops.object.delete()
278
+
279
+ # Create key light
280
+ key_light = _create_light(
281
+ name="Key_Light",
282
+ light_type="SUN",
283
+ location=(0, 0, 0),
284
+ rotation=(0.785398, 0, -0.785398),
285
+ # energy=random.choice([3, 4, 5]),
286
+ energy=4,
287
+ )
288
+
289
+ # Create fill light
290
+ fill_light = _create_light(
291
+ name="Fill_Light",
292
+ light_type="SUN",
293
+ location=(0, 0, 0),
294
+ rotation=(0.785398, 0, 2.35619),
295
+ # energy=random.choice([2, 3, 4]),
296
+ energy=3,
297
+ )
298
+
299
+ # Create rim light
300
+ rim_light = _create_light(
301
+ name="Rim_Light",
302
+ light_type="SUN",
303
+ location=(0, 0, 0),
304
+ rotation=(-0.785398, 0, -3.92699),
305
+ # energy=random.choice([3, 4, 5]),
306
+ energy=4,
307
+ )
308
+
309
+ # Create bottom light
310
+ bottom_light = _create_light(
311
+ name="Bottom_Light",
312
+ light_type="SUN",
313
+ location=(0, 0, 0),
314
+ rotation=(3.14159, 0, 0),
315
+ # energy=random.choice([1, 2, 3]),
316
+ energy=2,
317
+ )
318
+
319
+ return dict(
320
+ key_light=key_light,
321
+ fill_light=fill_light,
322
+ rim_light=rim_light,
323
+ bottom_light=bottom_light,
324
+ )
325
+
326
+ def scene_bbox(
327
+ single_obj = None, ignore_matrix = False
328
+ ):
329
+ """Returns the bounding box of the scene.
330
+
331
+ Taken from Shap-E rendering script
332
+ (https://github.com/openai/shap-e/blob/main/shap_e/rendering/blender/blender_script.py#L68-L82)
333
+
334
+ Args:
335
+ single_obj (Optional[bpy.types.Object], optional): If not None, only computes
336
+ the bounding box for the given object. Defaults to None.
337
+ ignore_matrix (bool, optional): Whether to ignore the object's matrix. Defaults
338
+ to False.
339
+
340
+ Raises:
341
+ RuntimeError: If there are no objects in the scene.
342
+
343
+ Returns:
344
+ Tuple[Vector, Vector]: The minimum and maximum coordinates of the bounding box.
345
+ """
346
+ bbox_min = (math.inf,) * 3
347
+ bbox_max = (-math.inf,) * 3
348
+ found = False
349
+ for i in range(num_frames):
350
+ bpy.context.scene.frame_set(i * args.downsample)
351
+ for obj in get_scene_meshes() if single_obj is None else [single_obj]:
352
+ found = True
353
+ for coord in obj.bound_box:
354
+ coord = Vector(coord)
355
+ if not ignore_matrix:
356
+ coord = obj.matrix_world @ coord
357
+ bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord))
358
+ bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord))
359
+
360
+ if not found:
361
+ raise RuntimeError("no objects in scene to compute bounding box for")
362
+
363
+ return Vector(bbox_min), Vector(bbox_max)
364
+
365
+ def get_scene_meshes():
366
+ """Returns all meshes in the scene.
367
+
368
+ Yields:
369
+ Generator[bpy.types.Object, None, None]: Generator of all meshes in the scene.
370
+ """
371
+ for obj in bpy.context.scene.objects.values():
372
+ if isinstance(obj.data, (bpy.types.Mesh)):
373
+ yield obj
374
+
375
+ def get_scene_root_objects():
376
+ """Returns all root objects in the scene.
377
+
378
+ Yields:
379
+ Generator[bpy.types.Object, None, None]: Generator of all root objects in the
380
+ scene.
381
+ """
382
+ for obj in bpy.context.scene.objects.values():
383
+ if not obj.parent:
384
+ yield obj
385
+
386
+ def normalize_scene():
387
+ """Normalizes the scene by scaling and translating it to fit in a unit cube centered
388
+ at the origin.
389
+
390
+ Mostly taken from the Point-E / Shap-E rendering script
391
+ (https://github.com/openai/point-e/blob/main/point_e/evals/scripts/blender_script.py#L97-L112),
392
+ but fix for multiple root objects: (see bug report here:
393
+ https://github.com/openai/shap-e/pull/60).
394
+
395
+ Returns:
396
+ None
397
+ """
398
+ if len(list(get_scene_root_objects())) > 1:
399
+ # create an empty object to be used as a parent for all root objects
400
+ parent_empty = bpy.data.objects.new("ParentEmpty", None)
401
+ bpy.context.scene.collection.objects.link(parent_empty)
402
+
403
+ # parent all root objects to the empty object
404
+ for obj in get_scene_root_objects():
405
+ if obj != parent_empty:
406
+ obj.parent = parent_empty
407
+
408
+ bbox_min, bbox_max = scene_bbox()
409
+ scale = 1 / max(bbox_max - bbox_min)
410
+ logger.info(f"Scale: {scale}")
411
+ for obj in get_scene_root_objects():
412
+ obj.scale = obj.scale * scale
413
+
414
+ # Apply scale to matrix_world.
415
+ bpy.context.view_layer.update()
416
+ bbox_min, bbox_max = scene_bbox()
417
+ offset = -(bbox_min + bbox_max) / 2
418
+ for obj in get_scene_root_objects():
419
+ obj.matrix_world.translation += offset
420
+ bpy.ops.object.select_all(action="DESELECT")
421
+
422
+ # unparent the camera
423
+ bpy.data.objects["Camera"].parent = None
424
+
425
+ normalize_scene()
426
+
427
+ randomize_lighting()
428
+
429
+ # Place camera
430
+ cam = scene.objects['Camera']
431
+ cam.location = (0, 1.5, 0) # radius equals to 1
432
+ cam.data.lens = 35
433
+ cam.data.sensor_width = 32
434
+
435
+ cam_constraint = cam.constraints.new(type='TRACK_TO')
436
+ cam_constraint.track_axis = 'TRACK_NEGATIVE_Z'
437
+ cam_constraint.up_axis = 'UP_Y'
438
+
439
+ cam_empty = bpy.data.objects.new("Empty", None)
440
+ cam_empty.location = (0, 0, 0)
441
+ cam.parent = cam_empty
442
+
443
+ scene.collection.objects.link(cam_empty)
444
+ context.view_layer.objects.active = cam_empty
445
+ cam_constraint.target = cam_empty
446
+
447
+ stepsize = 360.0 / args.views
448
+ rotation_mode = 'XYZ'
449
+
450
+
451
+ np.random.seed(args.seed)
452
+
453
+ if args.camera_option == "fixed":
454
+ for scene in bpy.data.scenes:
455
+ scene.cycles.device = 'GPU'
456
+
457
+ elevation_angle = 0.
458
+ rotation_angle = 0.
459
+
460
+ for view_idx in range(args.views):
461
+ img_folder = os.path.join(save_root, f'{view_idx:03d}', 'img')
462
+ mask_folder = os.path.join(save_root, f'{view_idx:03d}', 'mask')
463
+ camera_folder = os.path.join(save_root, f'{view_idx:03d}', 'camera')
464
+
465
+ os.makedirs(img_folder, exist_ok=True)
466
+ os.makedirs(mask_folder, exist_ok=True)
467
+ os.makedirs(camera_folder, exist_ok=True)
468
+
469
+ np.save(os.path.join(camera_folder, 'rotation'), np.array([rotation_angle + view_idx * stepsize for _ in range(num_frames)]))
470
+ np.save(os.path.join(camera_folder, 'elevation'), np.array([elevation_angle for _ in range(num_frames)]))
471
+
472
+ cam_empty.rotation_euler[2] = math.radians(rotation_angle + view_idx * stepsize)
473
+ cam_empty.rotation_euler[0] = math.radians(elevation_angle)
474
+
475
+ # save camera RT matrix
476
+ rt_matrix = get_3x4_RT_matrix_from_blender(cam)
477
+ rt_matrix_path = os.path.join(camera_folder, "rt_matrix.npy")
478
+ np.save(rt_matrix_path, rt_matrix)
479
+ for i in range(0, num_frames):
480
+ bpy.context.scene.frame_set(i * args.downsample)
481
+ render_file_path = os.path.join(img_folder,'%03d.png' % (i))
482
+ scene.render.filepath = render_file_path
483
+ bpy.ops.render.render(write_still=True)
484
+
485
+ for i in range(0, num_frames):
486
+ img = cv2.imread(os.path.join(img_folder, '%03d.png' % (i)), cv2.IMREAD_UNCHANGED)
487
+ mask = img[:, :, 3:4] / 255.0
488
+ white_img = img[:, :, :3] * mask + np.ones_like(img[:, :, :3]) * (1 - mask) * 255
489
+ white_img = np.clip(white_img, 0, 255)
490
+ cv2.imwrite(os.path.join(img_folder, '%03d.jpg' % (i)), white_img)
491
+ cv2.imwrite(os.path.join(mask_folder, '%03d.png'%(i)), img[:, :, 3])
492
+ os.system('rm %s'%(os.path.join(img_folder, '%03d.png' % (i))))
493
+
494
+ elif args.camera_option == "random":
495
+ for scene in bpy.data.scenes:
496
+ scene.cycles.device = 'GPU'
497
+
498
+ for view_idx in range(args.views):
499
+ elevation_angle = np.random.rand(1) * 35 - 5 # [-5, 30]
500
+ rotation_angle = np.random.rand(1) * 360
501
+
502
+ img_folder = os.path.join(save_root, f'{view_idx:03d}', 'img')
503
+ mask_folder = os.path.join(save_root, f'{view_idx:03d}', 'mask')
504
+ camera_folder = os.path.join(save_root, f'{view_idx:03d}', 'camera')
505
+
506
+ os.makedirs(img_folder, exist_ok=True)
507
+ os.makedirs(mask_folder, exist_ok=True)
508
+ os.makedirs(camera_folder, exist_ok=True)
509
+
510
+ np.save(os.path.join(camera_folder, 'rotation'), np.array([rotation_angle for _ in range(num_frames)]))
511
+ np.save(os.path.join(camera_folder, 'elevation'), np.array([elevation_angle for _ in range(num_frames)]))
512
+
513
+ cam_empty.rotation_euler[2] = math.radians(rotation_angle)
514
+ cam_empty.rotation_euler[0] = math.radians(elevation_angle)
515
+
516
+ # save camera RT matrix
517
+ rt_matrix = get_3x4_RT_matrix_from_blender(cam)
518
+ rt_matrix_path = os.path.join(camera_folder, "rt_matrix.npy")
519
+ np.save(rt_matrix_path, rt_matrix)
520
+
521
+ for i in range(0, num_frames):
522
+ bpy.context.scene.frame_set(i * args.downsample)
523
+ render_file_path = os.path.join(img_folder,'%03d.png' % (i))
524
+ scene.render.filepath = render_file_path
525
+ bpy.ops.render.render(write_still=True)
526
+
527
+ for i in range(0, num_frames):
528
+ img = cv2.imread(os.path.join(img_folder, '%03d.png' % (i)), cv2.IMREAD_UNCHANGED)
529
+ mask = img[:, :, 3:4] / 255.0
530
+ white_img = img[:, :, :3] * mask + np.ones_like(img[:, :, :3]) * (1 - mask) * 255
531
+ white_img = np.clip(white_img, 0, 255)
532
+ cv2.imwrite(os.path.join(img_folder, '%03d.jpg' % (i)), white_img)
533
+ cv2.imwrite(os.path.join(mask_folder, '%03d.png'%(i)), img[:, :, 3])
534
+ os.system('rm %s'%(os.path.join(img_folder, '%03d.png' % (i))))
535
+
536
+ else:
537
+ raise NotImplementedError(f'unknown camera_option: {args.camera_option}')
core/__init__.py ADDED
File without changes
core/attention.py ADDED
@@ -0,0 +1,85 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+ import os
+ import warnings
+
+ from torch import Tensor
+ from torch import nn
+
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+ try:
+     if XFORMERS_ENABLED:
+         from xformers.ops import memory_efficient_attention, unbind
+
+         XFORMERS_AVAILABLE = True
+         # warnings.warn("xFormers is available (Attention)")
+     else:
+         warnings.warn("xFormers is disabled (Attention)")
+         raise ImportError
+ except ImportError:
+     XFORMERS_AVAILABLE = False
+     warnings.warn("xFormers is not available (Attention)")
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         proj_bias: bool = True,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim, bias=proj_bias)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x: Tensor) -> Tensor:
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+         q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+         attn = q @ k.transpose(-2, -1)
+
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class MemEffAttention(Attention):
+     def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+         if not XFORMERS_AVAILABLE:
+             if attn_bias is not None:
+                 raise AssertionError("xFormers is required for using nested tensors")
+             return super().forward(x)
+
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+         q, k, v = unbind(qkv, 2)
+
+         x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+
+         x = x.reshape([B, N, C])
+
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
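A quick way to sanity-check the module above is a single forward pass on random tokens. This is a minimal sketch, not part of the commit; the shapes are illustrative, and when xFormers is installed the memory-efficient path expects CUDA tensors, so run the test on GPU in that case.

    import torch
    from core.attention import MemEffAttention

    attn = MemEffAttention(dim=64, num_heads=8)
    tokens = torch.randn(2, 16, 64)   # [batch, tokens, dim]
    out = attn(tokens)                # falls back to plain softmax attention if xFormers is absent
    assert out.shape == tokens.shape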
core/gs.py ADDED
@@ -0,0 +1,198 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from core.options import Options
+
+ import kiui
+
+ from gsplat.rendering import rasterization
+
+ class GaussianRenderer:
+     def __init__(self, opt: Options):
+
+         self.opt = opt
+         self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")
+
+         # intrinsics
+         self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
+         self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
+         self.proj_matrix[0, 0] = 1 / self.tan_half_fov
+         self.proj_matrix[1, 1] = 1 / self.tan_half_fov
+         self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
+         self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
+         self.proj_matrix[2, 3] = 1
+
+         f = self.opt.output_size / (2 * self.tan_half_fov)
+         self.K = torch.tensor([[f, 0., self.opt.output_size/2.], [0., f, self.opt.output_size/2.], [0., 0., 1.]], dtype=torch.float32, device="cuda")
+
+     def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None):
+         # gaussians: [B, N, 14]
+         # cam_view, cam_view_proj: [B, V, 4, 4]
+         # cam_pos: [B, V, 3]
+
+         device = gaussians.device
+         B, V = cam_view.shape[:2]
+
+         # loop of loop...
+         images = []
+         alphas = []
+         for b in range(B):
+
+             # pos, opacity, scale, rotation, shs
+             means3D = gaussians[b, :, 0:3].contiguous().float()
+             opacity = gaussians[b, :, 3:4].contiguous().float()
+             scales = gaussians[b, :, 4:7].contiguous().float()
+             rotations = gaussians[b, :, 7:11].contiguous().float()
+             rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 3]
+
+             # render novel views
+             view_matrix = cam_view[b].float()
+             view_proj_matrix = cam_view_proj[b].float()
+             campos = cam_pos[b].float()
+
+             viewmat = view_matrix.transpose(2, 1) # [V, 4, 4]
+
+
+             rendered_image_all, rendered_alpha_all, info = rasterization(
+                 means=means3D,
+                 quats=rotations,
+                 scales=scales,
+                 opacities=opacity.squeeze(-1),
+                 colors=rgbs,
+                 viewmats=viewmat,
+                 Ks=torch.stack([self.K for _ in range(V)]),
+                 width=self.opt.output_size,
+                 height=self.opt.output_size,
+                 near_plane=self.opt.znear,
+                 far_plane=self.opt.zfar,
+                 packed=False,
+                 backgrounds=torch.stack([self.bg_color for _ in range(V)]) if self.bg_color is not None else None,
+                 render_mode="RGB",
+             )
+             for rendered_image, rendered_alpha in zip(rendered_image_all, rendered_alpha_all):
+
+                 rendered_image = rendered_image.permute(2, 0, 1)
+                 rendered_image = rendered_image.clamp(0, 1)
+
+                 rendered_alpha = rendered_alpha.permute(2, 0, 1)
+
+                 images.append(rendered_image)
+                 alphas.append(rendered_alpha)
+
+         images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size, self.opt.output_size)
+         alphas = torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size)
+
+         return {
+             "image": images, # [B, V, 3, H, W]
+             "alpha": alphas, # [B, V, 1, H, W]
+         }
+
+
+     def save_ply(self, gaussians, path, compatible=True):
+         # gaussians: [B, N, 14]
+         # compatible: save pre-activated gaussians as in the original paper
+
+         assert gaussians.shape[0] == 1, 'only support batch size 1'
+
+         from plyfile import PlyData, PlyElement
+
+         means3D = gaussians[0, :, 0:3].contiguous().float()
+         opacity = gaussians[0, :, 3:4].contiguous().float()
+         scales = gaussians[0, :, 4:7].contiguous().float()
+         rotations = gaussians[0, :, 7:11].contiguous().float()
+         shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3]
+
+         # prune by opacity
+         mask = opacity.squeeze(-1) >= 0.005
+         means3D = means3D[mask]
+         opacity = opacity[mask]
+         scales = scales[mask]
+         rotations = rotations[mask]
+         shs = shs[mask]
+
+         # invert activation to make it compatible with the original ply format
+         if compatible:
+             opacity = kiui.op.inverse_sigmoid(opacity)
+             scales = torch.log(scales + 1e-8)
+             shs = (shs - 0.5) / 0.28209479177387814
+
+         xyzs = means3D.detach().cpu().numpy()
+         f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
+         opacities = opacity.detach().cpu().numpy()
+         scales = scales.detach().cpu().numpy()
+         rotations = rotations.detach().cpu().numpy()
+
+         l = ['x', 'y', 'z']
+         # All channels except the 3 DC
+         for i in range(f_dc.shape[1]):
+             l.append('f_dc_{}'.format(i))
+         l.append('opacity')
+         for i in range(scales.shape[1]):
+             l.append('scale_{}'.format(i))
+         for i in range(rotations.shape[1]):
+             l.append('rot_{}'.format(i))
+
+         dtype_full = [(attribute, 'f4') for attribute in l]
+
+         elements = np.empty(xyzs.shape[0], dtype=dtype_full)
+         attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)
+         elements[:] = list(map(tuple, attributes))
+         el = PlyElement.describe(elements, 'vertex')
+
+         PlyData([el]).write(path)
+
+     def load_ply(self, path, compatible=True):
+
+         from plyfile import PlyData, PlyElement
+
+         plydata = PlyData.read(path)
+
+         xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
+                         np.asarray(plydata.elements[0]["y"]),
+                         np.asarray(plydata.elements[0]["z"])), axis=1)
+         print("Number of points at loading : ", xyz.shape[0])
+
+         opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
+
+         shs = np.zeros((xyz.shape[0], 3))
+         shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
+         shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"])
+         shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"])
+
+         scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
+         scales = np.zeros((xyz.shape[0], len(scale_names)))
+         for idx, attr_name in enumerate(scale_names):
+             scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+         rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")]
+         rots = np.zeros((xyz.shape[0], len(rot_names)))
+         for idx, attr_name in enumerate(rot_names):
+             rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+         gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1)
+         gaussians = torch.from_numpy(gaussians).float() # cpu
+
+         if compatible:
+             gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4])
+             gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7])
+             gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5
+
+         return gaussians
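The 14 channels per Gaussian used throughout this file are position (3), opacity (1), scale (3), rotation quaternion (4) and RGB color (3); save_ply and load_ply serialize exactly that layout. A minimal round-trip sketch, assuming a CUDA device (the renderer allocates its intrinsics on "cuda") and a hypothetical PLY path:

    from core.options import config_defaults
    from core.gs import GaussianRenderer

    opt = config_defaults['big']                     # preset defined in core/options.py
    renderer = GaussianRenderer(opt)                 # builds K from opt.fovy and opt.output_size
    gaussians = renderer.load_ply('example.ply')     # hypothetical path -> [N, 14] float tensor on CPU
    renderer.save_ply(gaussians.unsqueeze(0), 'example_copy.ply')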
core/models.py ADDED
@@ -0,0 +1,195 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+ import kiui
+ from kiui.lpips import LPIPS
+
+ from core.unet import UNet
+ from core.options import Options
+ from core.gs import GaussianRenderer
+
+
+
+ class LGM(nn.Module):
+     def __init__(
+         self,
+         opt: Options,
+     ):
+         super().__init__()
+
+         self.opt = opt
+
+         # unet
+         self.unet = UNet(
+             9, 14 * self.opt.gaussian_perpixel,
+             down_channels=self.opt.down_channels,
+             down_attention=self.opt.down_attention,
+             mid_attention=self.opt.mid_attention,
+             up_channels=self.opt.up_channels,
+             up_attention=self.opt.up_attention,
+             num_views=self.opt.num_input_views,
+             num_frames=self.opt.num_frames,
+             use_temp_attn=self.opt.use_temp_attn
+         )
+
+         # last conv
+         self.conv = nn.Conv2d(14 * self.opt.gaussian_perpixel, 14 * self.opt.gaussian_perpixel, kernel_size=1) # NOTE: maybe remove it if train again
+
+         # Gaussian Renderer
+         self.gs = GaussianRenderer(opt)
+
+         # activations...
+         self.pos_act = lambda x: x.clamp(-1, 1)
+         self.scale_act = lambda x: 0.1 * F.softplus(x)
+         self.opacity_act = lambda x: torch.sigmoid(x)
+         self.rot_act = lambda x: F.normalize(x, dim=-1)
+         self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5 # NOTE: may use sigmoid if train again
+
+         # LPIPS loss
+         if self.opt.lambda_lpips > 0:
+             self.lpips_loss = LPIPS(net='vgg')
+             self.lpips_loss.requires_grad_(False)
+
+
+     def state_dict(self, **kwargs):
+         # remove lpips_loss
+         state_dict = super().state_dict(**kwargs)
+         for k in list(state_dict.keys()):
+             if 'lpips_loss' in k:
+                 del state_dict[k]
+         return state_dict
+
+
+     def prepare_default_rays(self, device, elevation=0):
+
+         from kiui.cam import orbit_camera
+         from core.utils import get_rays
+
+         cam_poses = np.stack([
+             orbit_camera(elevation, 0, radius=self.opt.cam_radius),
+             orbit_camera(elevation, 90, radius=self.opt.cam_radius),
+             orbit_camera(elevation, 180, radius=self.opt.cam_radius),
+             orbit_camera(elevation, 270, radius=self.opt.cam_radius),
+         ], axis=0) # [4, 4, 4]
+         cam_poses = torch.from_numpy(cam_poses)
+
+         rays_embeddings = []
+         for i in range(cam_poses.shape[0]):
+             rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3]
+             rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
+             rays_embeddings.append(rays_plucker)
+
+             ## visualize rays for plotting figure
+             # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True)
+
+         rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w]
+
+         return rays_embeddings
+
+
+
+     def forward_gaussians(self, images):
+         # images: [B, T, 4, 9, H, W]
+         # return: Gaussians: [B, dim_t]
+
+         B, TV, C, H, W = images.shape
+         T = self.opt.num_frames
+         V = TV // T
+         images = images.view(B*T*V, C, H, W)
+
+         x = self.unet(images) # [B*4, 14, h, w]
+         x = self.conv(x) # [B*4, 14, h, w]
+
+         x = x.reshape(B*T, V, 14 * self.opt.gaussian_perpixel, self.opt.splat_size, self.opt.splat_size)
+
+         x = x.permute(0, 1, 3, 4, 2).reshape(B*T, -1, 14).contiguous()
+
+         pos = self.pos_act(x[..., 0:3]) # [B, N, 3]
+         opacity = self.opacity_act(x[..., 3:4])
+         scale = self.scale_act(x[..., 4:7])
+         rotation = self.rot_act(x[..., 7:11])
+         rgbs = self.rgb_act(x[..., 11:])
+
+         gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, T, N, 14]
+
+         return gaussians
+
+
+     def forward(self, data, step_ratio=1):
+         # data: output of the dataloader
+         # return: loss
+
+         results = {}
+         loss = 0
+
+         images = data['input'] # [B, Tx4, 9, h, W], input features
+
+         B, TV, C, H, W = images.shape
+         T = self.opt.num_frames
+
+         # use the first view to predict gaussians
+         gaussians = self.forward_gaussians(images) # [B * T, N, 14]
+
+         results['gaussians'] = gaussians
+
+         # always use white bg
+         bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)
+
+         # use the other views for rendering and supervision
+         data['cam_view'] = data['cam_view'].reshape(B*T, -1, *data['cam_view'].shape[2:])
+         data['cam_view_proj'] = data['cam_view_proj'].reshape(B*T, -1, *data['cam_view_proj'].shape[2:])
+         data['cam_pos'] = data['cam_pos'].reshape(B*T, -1, *data['cam_pos'].shape[2:])
+
+         results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color)
+         pred_images = results['image'] # [B*T, V, C, output_size, output_size]
+         pred_alphas = results['alpha'] # [B*T, V, 1, output_size, output_size]
+
+         results['images_pred'] = pred_images
+         results['alphas_pred'] = pred_alphas
+
+
+         data['images_output'] = data['images_output'].reshape(B*T, -1, *data['images_output'].shape[2:])
+         data['masks_output'] = data['masks_output'].reshape(B*T, -1, *data['masks_output'].shape[2:])
+
+         gt_images = data['images_output'] # [B*T, V, 3, output_size, output_size], ground-truth novel views
+         gt_masks = data['masks_output'] # [B*T, V, 1, output_size, output_size], ground-truth masks
+
+         gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)
+
+         loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks)
+         loss = loss + loss_mse
+
+         if self.opt.lambda_lpips > 0:
+             loss_lpips = self.lpips_loss(
+                 # downsampled to at most 256 to reduce memory cost
+                 F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
+                 F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
+             ).mean()
+             results['loss_lpips'] = loss_lpips
+             loss = loss + self.opt.lambda_lpips * loss_lpips
+
+         results['loss'] = loss
+
+         # metric
+         with torch.no_grad():
+             psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))
+             results['psnr'] = psnr
+
+         return results
core/options.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import tyro
17
+ from dataclasses import dataclass
18
+ from typing import Tuple, Literal, Dict, Optional
19
+
20
+
21
+ @dataclass
22
+ class Options:
23
+ ### model
24
+ # Unet image input size
25
+ input_size: int = 256
26
+ # Unet definition
27
+ down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024)
28
+ down_attention: Tuple[bool, ...] = (False, False, False, True, True, True)
29
+ mid_attention: bool = True
30
+ up_channels: Tuple[int, ...] = (1024, 1024, 512, 256)
31
+ up_attention: Tuple[bool, ...] = (True, True, True, False)
32
+ # Unet output size, dependent on the input_size and U-Net structure!
33
+ splat_size: int = 64
34
+ # gaussian render size
35
+ output_size: int = 256
36
+
37
+ ### dataset
38
+ # data mode (only s3 is supported for now)
39
+ data_mode: str = '4d'
40
+ # fovy of the dataset
41
+ fovy: float = 49.1
42
+ # camera near plane
43
+ znear: float = 0.5
44
+ # camera far plane
45
+ zfar: float = 2.5
46
+ # number of all views (input + output)
47
+ num_views: int = 12
48
+ # number of input views
49
+ num_input_views: int = 4
50
+ # camera radius
51
+ cam_radius: float = 1.5 # to better use [-1, 1]^3 space
52
+ # num workers
53
+ num_workers: int = 16
54
+ datalist: str = ''  # path to a text file listing the training items
55
+
56
+ ### training
57
+ # workspace
58
+ workspace: str = './workspace'
59
+ # resume
60
+ resume: Optional[str] = None
61
+ # batch size (per-GPU)
62
+ batch_size: int = 8
63
+ # gradient accumulation
64
+ gradient_accumulation_steps: int = 1
65
+ # training epochs
66
+ num_epochs: int = 30
67
+ # lpips loss weight
68
+ lambda_lpips: float = 1.0
69
+ # gradient clip
70
+ gradient_clip: float = 1.0
71
+ # mixed precision
72
+ mixed_precision: str = 'bf16'
73
+ # learning rate
74
+ lr: float = 4e-4
75
+ # augmentation prob for grid distortion
76
+ prob_grid_distortion: float = 0.5
77
+ # augmentation prob for camera jitter
78
+ prob_cam_jitter: float = 0.5
79
+ # number of gaussians per pixel
80
+ gaussian_perpixel: int = 1
81
+
82
+ ### testing
83
+ # test image path
84
+ test_path: Optional[str] = None
85
+
86
+ ### misc
87
+ # nvdiffrast backend setting
88
+ force_cuda_rast: bool = False
89
+ # render fancy video with gaussian scaling effect
90
+ fancy_video: bool = False
91
+
92
+ # 4D
93
+ num_frames: int = 8
94
+ use_temp_attn: bool = True
95
+ shuffle_input: bool = True
96
+
97
+ # s3
98
+ sample_by_anim: bool = True
99
+
100
+ # interp
101
+ interpresume: Optional[str] = None
102
+ interpolate_rate: int = 3
103
+
104
+
105
+ # all the default settings
106
+ config_defaults: Dict[str, Options] = {}
107
+ config_doc: Dict[str, str] = {}
108
+
109
+ config_doc['lrm'] = 'the default settings for LGM'
110
+ config_defaults['lrm'] = Options()
111
+
112
+
113
+ config_doc['big'] = 'big model with higher resolution Gaussians'
114
+ config_defaults['big'] = Options(
115
+ input_size=256,
116
+ up_channels=(1024, 1024, 512, 256, 128), # one more decoder
117
+ up_attention=(True, True, True, False, False),
118
+ splat_size=128,
119
+ output_size=512, # render & supervise Gaussians at a higher resolution.
120
+ batch_size=1,
121
+ num_views=8,
122
+ gradient_accumulation_steps=1,
123
+ mixed_precision='bf16',
124
+ resume='pretrained/model_fp16_fixrot.safetensors',
125
+ )
126
+
127
+
128
+ AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc)
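As a usage note, `AllConfigs` above is meant to be parsed from the command line with tyro, which is also how the inference scripts later in this commit consume it; a minimal sketch, assuming this package layout, is:

import tyro
from core.options import AllConfigs

opt = tyro.cli(AllConfigs)  # picks one of the subcommands defined above, e.g. 'lrm' or 'big'
print(opt.input_size, opt.splat_size, opt.output_size)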
core/provider_objaverse_4d.py ADDED
@@ -0,0 +1,254 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import cv2
18
+ import random
19
+ import numpy as np
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import torchvision.transforms.functional as TF
25
+ from torch.utils.data import Dataset
26
+
27
+ import kiui
28
+ from core.options import Options
29
+ from core.utils import get_rays, grid_distortion, orbit_camera_jitter
30
+
31
+ from kiui.cam import orbit_camera
32
+
33
+ import tarfile
34
+ from io import BytesIO
35
+
36
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
37
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
38
+
39
+
40
+ def load_np_array_from_tar(tar, path):
41
+ array_file = BytesIO()
42
+ array_file.write(tar.extractfile(path).read())
43
+ array_file.seek(0)
44
+ return np.load(array_file)
45
+
46
+
47
+ class ObjaverseDataset(Dataset):
48
+
49
+ def _warn(self):
50
+ raise NotImplementedError('this dataset is just an example and cannot be used directly, you should modify it to your own setting! (search keyword TODO)')
51
+
52
+ def __init__(self, opt: Options, training=True, evaluating=False):
53
+
54
+ self.opt = opt
55
+ self.training = training
56
+ self.evaluating = evaluating
57
+
58
+ self.items = []
59
+ with open(self.opt.datalist, 'r') as f:
60
+ for line in f.readlines():
61
+ self.items.append(line.strip())
62
+
63
+
64
+ anim_map = {}
65
+ for x in self.items:
66
+ k = x.split('-')[1]
67
+ if k in anim_map:
68
+ anim_map[k] += '|'+x
69
+ else:
70
+ anim_map[k] = x
71
+ self.items = list(anim_map.values())
72
+
73
+
74
+ # naive split
75
+ if self.training:
76
+ self.items = self.items[:-self.opt.batch_size]
77
+ elif self.evaluating:
78
+ self.items = self.items[::1000]
79
+ else:
80
+ self.items = self.items[-self.opt.batch_size:]
81
+
82
+
83
+ # default camera intrinsics
84
+ self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
85
+ self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
86
+ self.proj_matrix[0, 0] = 1 / self.tan_half_fov
87
+ self.proj_matrix[1, 1] = 1 / self.tan_half_fov
88
+ self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear)
89
+ self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear)
90
+ self.proj_matrix[2, 3] = 1
91
+
92
+ def __len__(self):
93
+ return len(self.items)
94
+
95
+ def _get_batch(self, idx):
96
+ if self.training:
97
+ uid = random.choice(self.items[idx].split('|'))
98
+ else:
99
+ uid = self.items[idx].split('|')[0]
100
+
101
+ results = {}
102
+
103
+ # load num_views images
104
+ images = []
105
+ masks = []
106
+ cam_poses = []
107
+
108
+ if self.training and self.opt.shuffle_input:
109
+ vids = np.random.permutation(np.arange(32, 48))[:self.opt.num_input_views].tolist() + np.random.permutation(32).tolist()
110
+ else:
111
+ vids = np.arange(32, 48, 4).tolist() + np.arange(32).tolist()
112
+
113
+
114
+ random_tar_name = 'random_clip/' + uid
115
+ fixed_16_tar_name = 'fixed_16_clip/' + uid
116
+
117
+ local_random_tar_name = os.environ["DATA_HOME"] + random_tar_name.replace('/', '-')
118
+ local_fixed_16_tar_name = os.environ["DATA_HOME"] + fixed_16_tar_name.replace('/', '-')
119
+
120
+ tar_random = tarfile.open(local_random_tar_name)
121
+ tar_fixed = tarfile.open(local_fixed_16_tar_name)
122
+
123
+
124
+ T = self.opt.num_frames
125
+ for t_idx in range(T):
126
+ t = t_idx
127
+ vid_cnt = 0
128
+ for vid in vids:
129
+ if vid >= 32:
130
+ vid = vid % 32
131
+ tar = tar_fixed
132
+ else:
133
+ tar = tar_random
134
+
135
+ image_path = os.path.join('.', f'{vid:03d}/img', f'{t:03d}.jpg')
136
+ mask_path = os.path.join('.', f'{vid:03d}/mask', f'{t:03d}.png')
137
+
138
+ elevation_path = os.path.join('.', f'{vid:03d}/camera', f'elevation.npy')
139
+ rotation_path = os.path.join('.', f'{vid:03d}/camera', f'rotation.npy')
140
+
141
+ image = np.frombuffer(tar.extractfile(image_path).read(), np.uint8)
142
+ image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [H, W, 3] in [0, 1]
143
+
144
+ azi = load_np_array_from_tar(tar, rotation_path)[t, None]
145
+ elevation = load_np_array_from_tar(tar, elevation_path)[t, None] * -1 # to align with pretrained LGM
146
+ azi = float(azi)
147
+ elevation = float(elevation)
148
+ c2w = torch.from_numpy(orbit_camera(elevation, azi, radius=1.5, opengl=True))
149
+
150
+ image = image.permute(2, 0, 1) # [3, H, W]
151
+
152
+ mask = np.frombuffer(tar.extractfile(mask_path).read(), np.uint8)
153
+ mask = torch.from_numpy(cv2.imdecode(mask, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255).unsqueeze(0) # [1, H, W] in [0, 1]
154
+
155
+ image = F.interpolate(image.unsqueeze(0), size=(512, 512), mode='nearest').squeeze(0)
156
+ mask = F.interpolate(mask.unsqueeze(0), size=(512, 512), mode='nearest').squeeze(0)
157
+
158
+ image = image[:3] * mask + (1 - mask) # [3, 512, 512], to white bg
159
+ image = image[[2,1,0]].contiguous() # bgr to rgb
160
+
161
+ images.append(image)
162
+ masks.append(mask.squeeze(0))
163
+ cam_poses.append(c2w)
164
+
165
+ vid_cnt += 1
166
+ if vid_cnt == self.opt.num_views:
167
+ break
168
+
169
+ if vid_cnt < self.opt.num_views:
170
+ print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!')
171
+ n = self.opt.num_views - vid_cnt
172
+ images = images + [images[-1]] * n
173
+ masks = masks + [masks[-1]] * n
174
+ cam_poses = cam_poses + [cam_poses[-1]] * n
175
+
176
+ images = torch.stack(images, dim=0) # [V, C, H, W]
177
+ masks = torch.stack(masks, dim=0) # [V, H, W]
178
+ cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4]
179
+
180
+ # normalized camera feats as in paper (transform the first pose to a fixed position)
181
+ transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0])
182
+ cam_poses = transform.unsqueeze(0) @ cam_poses # [V, 4, 4]
183
+
184
+ images_input = F.interpolate(images.reshape(T, self.opt.num_views, *images.shape[1:])[:, :self.opt.num_input_views].reshape(-1, *images.shape[1:]).clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W]
185
+ cam_poses_input = cam_poses.reshape(T, self.opt.num_views, *cam_poses.shape[1:])[:, :self.opt.num_input_views].reshape(-1, *cam_poses.shape[1:]).clone()
186
+
187
+ # data augmentation
188
+ if self.training:
189
+ images_input = images_input.reshape(T, self.opt.num_input_views, *images_input.shape[1:])
190
+ cam_poses_input = cam_poses_input.reshape(T, self.opt.num_input_views, *cam_poses.shape[1:])
191
+
192
+ # apply random grid distortion to simulate 3D inconsistency
193
+ if random.random() < self.opt.prob_grid_distortion:
194
+ for t in range(T):
195
+ images_input[t, 1:] = grid_distortion(images_input[t, 1:])
196
+ # apply camera jittering (only to input!)
197
+ if random.random() < self.opt.prob_cam_jitter:
198
+ for t in range(T):
199
+ cam_poses_input[t, 1:] = orbit_camera_jitter(cam_poses_input[t, 1:])
200
+
201
+ images_input = images_input.reshape(-1, *images_input.shape[2:])
202
+ cam_poses_input = cam_poses_input.reshape(-1, *cam_poses.shape[1:])
203
+
204
+ # freeze the non-reference views at frame 0 (only view 0 varies over time)
205
+ images_input = images_input.reshape(T, self.opt.num_input_views, *images_input.shape[1:])
206
+ images_input[1:, 1:] = images_input[0:1, 1:]
207
+ images_input = images_input.reshape(-1, *images_input.shape[2:])
208
+
209
+ cam_poses_input = cam_poses_input.reshape(T, self.opt.num_input_views, *cam_poses.shape[1:])
210
+ cam_poses_input[1:, 1:] = cam_poses_input[0:1, 1:]
211
+ cam_poses_input = cam_poses_input.reshape(-1, *cam_poses.shape[1:])
212
+
213
+ images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
214
+
215
+ # resize render ground-truth images, range still in [0, 1]
216
+ results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, C, output_size, output_size]
217
+ results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, 1, output_size, output_size]
218
+
219
+ # build rays for input views
220
+ rays_embeddings = []
221
+ for i in range(self.opt.num_input_views * T):
222
+ rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3]
223
+ rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
224
+ rays_embeddings.append(rays_plucker)
225
+
226
+
227
+ rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w]
228
+
229
+ final_input = torch.cat([images_input, rays_embeddings], dim=1) # [V=4, 9, H, W]
230
+ results['input'] = final_input
231
+
232
+ # opengl to colmap camera for gaussian renderer
233
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
234
+
235
+ # cameras needed by gaussian rasterizer
236
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
237
+ cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
238
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
239
+
240
+ results['cam_view'] = cam_view
241
+ results['cam_view_proj'] = cam_view_proj
242
+ results['cam_pos'] = cam_pos
243
+
244
+ return results
245
+
246
+ def __getitem__(self, idx):
247
+ while True:
248
+ try:
249
+ results = self._get_batch(idx)
250
+ break
251
+ except Exception as e:
252
+ print(f"{e}")
253
+ idx = random.randint(0, len(self.items) - 1)
254
+ return results
core/provider_objaverse_4d_interp.py ADDED
@@ -0,0 +1,283 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import cv2
18
+ import random
19
+ import numpy as np
20
+ import copy
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ import torchvision.transforms.functional as TF
26
+ from torch.utils.data import Dataset
27
+
28
+ import kiui
29
+ from core.options import Options
30
+ from core.utils import get_rays, grid_distortion, orbit_camera_jitter
31
+
32
+ from kiui.cam import orbit_camera
33
+
34
+ import tarfile
35
+ from io import BytesIO
36
+
37
+
38
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
39
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
40
+
41
+
42
+ def load_np_array_from_tar(tar, path):
43
+ array_file = BytesIO()
44
+ array_file.write(tar.extractfile(path).read())
45
+ array_file.seek(0)
46
+ return np.load(array_file)
47
+
48
+ def interpolate_tensors(tensor):
49
+ # Extract the first and last tensors along the first dimension (B)
50
+ start_tensor = tensor[0] # shape [4, 3, 256, 256]
51
+ end_tensor = tensor[-1] # shape [4, 3, 256, 256]
52
+ tensor_interp = copy.deepcopy(tensor)
53
+
54
+ # Iterate over the range from 1 to second-last index
55
+ for i in range(1, tensor.size(0) - 1):
56
+ # Calculate the weight for interpolation
57
+
58
+ weight = i / (tensor.size(0) - 1)
59
+ # Interpolate between start_tensor and end_tensor
60
+ tensor_interp[i] = torch.lerp(start_tensor, end_tensor, weight)
61
+
62
+
63
+ return tensor_interp
64
+
65
+ class ObjaverseDataset(Dataset):
66
+
67
+ def _warn(self):
68
+ raise NotImplementedError('this dataset is just an example and cannot be used directly, you should modify it to your own setting! (search keyword TODO)')
69
+
70
+ def __init__(self, opt: Options, training=True, evaluating=False):
71
+
72
+ self.opt = opt
73
+ self.training = training
74
+ self.evaluating = evaluating
75
+
76
+ self.items = []
77
+ with open(self.opt.datalist, 'r') as f:
78
+ for line in f.readlines():
79
+ self.items.append(line.strip())
80
+
81
+ anim_map = {}
82
+ for x in self.items:
83
+ k = x.split('-')[1]
84
+ if k in anim_map:
85
+ anim_map[k] += '|'+x
86
+ else:
87
+ anim_map[k] = x
88
+ self.items = list(anim_map.values())
89
+
90
+
91
+
92
+ # naive split
93
+ if self.training:
94
+ self.items = self.items[:-self.opt.batch_size]
95
+ elif self.evaluating:
96
+ self.items = self.items[::1000]
97
+ else:
98
+ self.items = self.items[-self.opt.batch_size:]
99
+
100
+
101
+ # default camera intrinsics
102
+ self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
103
+ self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
104
+ self.proj_matrix[0, 0] = 1 / self.tan_half_fov
105
+ self.proj_matrix[1, 1] = 1 / self.tan_half_fov
106
+ self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear)
107
+ self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear)
108
+ self.proj_matrix[2, 3] = 1
109
+
110
+
111
+ def __len__(self):
112
+ return len(self.items)
113
+
114
+ def _get_batch(self, idx):
115
+ # uid = self.items[idx]
116
+ if self.training:
117
+ uid = random.choice(self.items[idx].split('|'))
118
+ else:
119
+ uid = self.items[idx].split('|')[0]
120
+
121
+ results = {}
122
+
123
+ # load num_views images
124
+ images = []
125
+ masks = []
126
+ cam_poses = []
127
+
128
+ if self.training and self.opt.shuffle_input:
129
+ vids = np.random.permutation(np.arange(32, 48))[:self.opt.num_input_views].tolist() + np.random.permutation(32).tolist()
130
+ else:
131
+ vids = np.arange(32, 48, 4).tolist() + np.arange(32).tolist()
132
+
133
+ random_tar_name = 'random_24fps/' + uid
134
+ fixed_16_tar_name = 'fixed_16_24fps/' + uid
135
+
136
+ local_random_tar_name = os.environ["DATA_HOME"] + random_tar_name.replace('/', '-')
137
+ local_fixed_16_tar_name = os.environ["DATA_HOME"] + fixed_16_tar_name.replace('/', '-')
138
+
139
+ tar_random = tarfile.open(local_random_tar_name)
140
+ tar_fixed = tarfile.open(local_fixed_16_tar_name)
141
+
142
+ max_T = 24
143
+
144
+ T = self.opt.num_frames
145
+
146
+ start_frame = np.random.randint(max_T - T)
147
+
148
+ for t_idx in range(T):
149
+ t = start_frame + t_idx
150
+ vid_cnt = 0
151
+ for vid in vids:
152
+ if vid >= 32:
153
+ vid = vid % 32
154
+ tar = tar_fixed
155
+ else:
156
+ tar = tar_random
157
+
158
+ image_path = os.path.join('.', f'{vid:03d}/img', f'{t:03d}.jpg')
159
+ mask_path = os.path.join('.', f'{vid:03d}/mask', f'{t:03d}.png')
160
+
161
+ elevation_path = os.path.join('.', f'{vid:03d}/camera', f'elevation.npy')
162
+ rotation_path = os.path.join('.', f'{vid:03d}/camera', f'rotation.npy')
163
+
164
+ try:
165
+ image = np.frombuffer(tar.extractfile(image_path).read(), np.uint8)
166
+ image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [H, W, 3] in [0, 1]
167
+
168
+ azi = load_np_array_from_tar(tar, rotation_path)[t, None]
169
+ elevation = load_np_array_from_tar(tar, elevation_path)[t, None] * -1 # to align with pretrained LGM
170
+ azi = float(azi)
171
+ elevation = float(elevation)
172
+ c2w = torch.from_numpy(orbit_camera(elevation, azi, radius=1.5, opengl=True))
173
+
174
+ image = image.permute(2, 0, 1) # [3, H, W]
175
+
176
+ mask = np.frombuffer(tar.extractfile(mask_path).read(), np.uint8)
177
+ mask = torch.from_numpy(cv2.imdecode(mask, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255).unsqueeze(0) # [1, H, W] in [0, 1]
178
+ except Exception:
179
+ # fall back to another sample if this frame or view fails to load
180
+ return self.__getitem__(idx - 1)
181
+ image = F.interpolate(image.unsqueeze(0), size=(512, 512), mode='nearest').squeeze(0)
182
+ mask = F.interpolate(mask.unsqueeze(0), size=(512, 512), mode='nearest').squeeze(0)
183
+
184
+ image = image[:3] * mask + (1 - mask) # [3, 512, 512], to white bg
185
+ image = image[[2,1,0]].contiguous() # bgr to rgb
186
+
187
+ images.append(image)
188
+ masks.append(mask.squeeze(0))
189
+ cam_poses.append(c2w)
190
+
191
+ vid_cnt += 1
192
+ if vid_cnt == self.opt.num_views:
193
+ break
194
+
195
+ if vid_cnt < self.opt.num_views:
196
+ print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!')
197
+ n = self.opt.num_views - vid_cnt
198
+ images = images + [images[-1]] * n
199
+ masks = masks + [masks[-1]] * n
200
+ cam_poses = cam_poses + [cam_poses[-1]] * n
201
+
202
+ images = torch.stack(images, dim=0) # [V, C, H, W]
203
+ masks = torch.stack(masks, dim=0) # [V, H, W]
204
+ cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4]
205
+
206
+ # normalized camera feats as in paper (transform the first pose to a fixed position)
207
+ transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0])
208
+ cam_poses = transform.unsqueeze(0) @ cam_poses # [V, 4, 4]
209
+
210
+ images_input = F.interpolate(images.reshape(T, self.opt.num_views, *images.shape[1:])[:, :self.opt.num_input_views].reshape(-1, *images.shape[1:]).clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W]
211
+ cam_poses_input = cam_poses.reshape(T, self.opt.num_views, *cam_poses.shape[1:])[:, :self.opt.num_input_views].reshape(-1, *cam_poses.shape[1:]).clone()
212
+
213
+ # data augmentation
214
+ if self.training:
215
+ images_input = images_input.reshape(T, self.opt.num_input_views, *images_input.shape[1:])
216
+ cam_poses_input = cam_poses_input.reshape(T, self.opt.num_input_views, *cam_poses.shape[1:])
217
+
218
+ # apply random grid distortion to simulate 3D inconsistency
219
+ if random.random() < self.opt.prob_grid_distortion:
220
+ for t in range(T):
221
+ images_input[t, 1:] = grid_distortion(images_input[t, 1:])
222
+ # apply camera jittering (only to input!)
223
+ if random.random() < self.opt.prob_cam_jitter:
224
+ for t in range(T):
225
+ cam_poses_input[t, 1:] = orbit_camera_jitter(cam_poses_input[t, 1:])
226
+
227
+ images_input = images_input.reshape(-1, *images_input.shape[2:])
228
+ cam_poses_input = cam_poses_input.reshape(-1, *cam_poses.shape[1:])
229
+
230
+ # replace the intermediate frames with a linear interpolation between the first and last frames
231
+ images_input = images_input.reshape(T, self.opt.num_input_views, *images_input.shape[1:])
232
+
233
+ images_input_interp = interpolate_tensors(images_input)
234
+
235
+ images_input[1:-1, :] = images_input_interp[1:-1, :]
236
+ images_input = images_input.reshape(-1, *images_input.shape[2:])
237
+
238
+ cam_poses_input = cam_poses_input.reshape(T, self.opt.num_input_views, *cam_poses.shape[1:])
239
+ cam_poses_input[1:, 1:] = cam_poses_input[0:1, 1:]
240
+ cam_poses_input = cam_poses_input.reshape(-1, *cam_poses.shape[1:])
241
+
242
+ images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
243
+
244
+ # resize render ground-truth images, range still in [0, 1]
245
+ results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, C, output_size, output_size]
246
+ results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, 1, output_size, output_size]
247
+
248
+ # build rays for input views
249
+ rays_embeddings = []
250
+ for i in range(self.opt.num_input_views * T):
251
+ rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3]
252
+ rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
253
+ rays_embeddings.append(rays_plucker)
254
+
255
+
256
+ rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w]
257
+
258
+ final_input = torch.cat([images_input, rays_embeddings], dim=1) # [V=4, 9, H, W]
259
+ results['input'] = final_input
260
+
261
+ # opengl to colmap camera for gaussian renderer
262
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
263
+
264
+ # cameras needed by gaussian rasterizer
265
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
266
+ cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
267
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
268
+
269
+ results['cam_view'] = cam_view
270
+ results['cam_view_proj'] = cam_view_proj
271
+ results['cam_pos'] = cam_pos
272
+
273
+ return results
274
+
275
+ def __getitem__(self, idx):
276
+ while True:
277
+ try:
278
+ results = self._get_batch(idx)
279
+ break
280
+ except Exception as e:
281
+ # print(f"{e}")
282
+ idx = random.randint(0, len(self.items) - 1)
283
+ return results
core/unet.py ADDED
@@ -0,0 +1,428 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ import numpy as np
21
+ from typing import Tuple, Literal
22
+ from functools import partial
23
+
24
+ from core.attention import MemEffAttention
25
+
26
+
27
+ class MVAttention(nn.Module):
28
+ def __init__(
29
+ self,
30
+ dim: int,
31
+ num_heads: int = 8,
32
+ qkv_bias: bool = False,
33
+ proj_bias: bool = True,
34
+ attn_drop: float = 0.0,
35
+ proj_drop: float = 0.0,
36
+ groups: int = 32,
37
+ eps: float = 1e-5,
38
+ residual: bool = True,
39
+ skip_scale: float = 1,
40
+ num_views: int = 4,
41
+ num_frames: int = 8
42
+ ):
43
+ super().__init__()
44
+
45
+ self.residual = residual
46
+ self.skip_scale = skip_scale
47
+ self.num_views = num_views
48
+ self.num_frames = num_frames
49
+
50
+ self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True)
51
+ self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop)
52
+
53
+ def forward(self, x):
54
+ # x: [B*T*V, C, H, W]
55
+ BTV, C, H, W = x.shape
56
+ BT = BTV // self.num_views # assert BTV % self.num_views == 0
57
+
58
+ res = x
59
+ x = self.norm(x)
60
+
61
+ x = x.reshape(BT, self.num_views, C, H, W).permute(0, 1, 3, 4, 2).contiguous().reshape(BT, -1, C).contiguous()
62
+ x = self.attn(x)
63
+ x = x.reshape(BT, self.num_views, H, W, C).permute(0, 1, 4, 2, 3).contiguous().reshape(BTV, C, H, W).contiguous()
64
+
65
+ if self.residual:
66
+ x = (x + res) * self.skip_scale
67
+ return x
68
+
69
+
70
+ class TempAttention(nn.Module):
71
+ def __init__(
72
+ self,
73
+ dim: int,
74
+ num_heads: int = 8,
75
+ qkv_bias: bool = False,
76
+ proj_bias: bool = True,
77
+ attn_drop: float = 0.0,
78
+ proj_drop: float = 0.0,
79
+ groups: int = 32,
80
+ eps: float = 1e-5,
81
+ residual: bool = True,
82
+ skip_scale: float = 1,
83
+ num_views: int = 4,
84
+ num_frames: int = 8
85
+ ):
86
+ super().__init__()
87
+
88
+ self.residual = residual
89
+ self.skip_scale = skip_scale
90
+ self.num_views = num_views
91
+ self.num_frames = num_frames
92
+
93
+ self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True)
94
+ self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop)
95
+
96
+ def forward(self, x):
97
+ # x: [B*T*V, C, H, W]
98
+ BTV, C, H, W = x.shape
99
+ BV = BTV // self.num_frames # assert BTV % self.num_frames == 0
100
+ B = BV // self.num_views
101
+
102
+ res = x
103
+ x = self.norm(x)
104
+
105
+ # BTV -> BVT
106
+ x = x.reshape(B, self.num_frames, self.num_views, C, H, W).permute(0, 2, 1, 3, 4, 5).contiguous()
107
+
108
+ x = x.reshape(BV, self.num_frames, C, H, W).permute(0, 1, 3, 4, 2).contiguous().reshape(BV, -1, C).contiguous()
109
+ x = self.attn(x)
110
+ x = x.reshape(BV, self.num_frames, H, W, C).permute(0, 1, 4, 2, 3).contiguous().reshape(BTV, C, H, W).contiguous()
111
+
112
+ # BVT -> BTV
113
+ x = x.reshape(B, self.num_views, self.num_frames, C, H, W).permute(0, 2, 1, 3, 4, 5).contiguous().reshape(BTV, C, H, W).contiguous()
114
+
115
+ if self.residual:
116
+ x = (x + res) * self.skip_scale
117
+ return x
118
+
119
+
120
+ class ResnetBlock(nn.Module):
121
+ def __init__(
122
+ self,
123
+ in_channels: int,
124
+ out_channels: int,
125
+ resample: Literal['default', 'up', 'down'] = 'default',
126
+ groups: int = 32,
127
+ eps: float = 1e-5,
128
+ skip_scale: float = 1, # multiplied to output
129
+ ):
130
+ super().__init__()
131
+
132
+ self.in_channels = in_channels
133
+ self.out_channels = out_channels
134
+ self.skip_scale = skip_scale
135
+
136
+ self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
137
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
138
+
139
+ self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
140
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
141
+
142
+ self.act = F.silu
143
+
144
+ self.resample = None
145
+ if resample == 'up':
146
+ self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
147
+ elif resample == 'down':
148
+ self.resample = nn.AvgPool2d(kernel_size=2, stride=2)
149
+
150
+ self.shortcut = nn.Identity()
151
+ if self.in_channels != self.out_channels:
152
+ self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
153
+
154
+
155
+ def forward(self, x):
156
+ res = x
157
+
158
+ x = self.norm1(x)
159
+ x = self.act(x)
160
+
161
+ if self.resample:
162
+ res = self.resample(res)
163
+ x = self.resample(x)
164
+
165
+ x = self.conv1(x)
166
+ x = self.norm2(x)
167
+ x = self.act(x)
168
+ x = self.conv2(x)
169
+
170
+ x = (x + self.shortcut(res)) * self.skip_scale
171
+
172
+ return x
173
+
174
+ class DownBlock(nn.Module):
175
+ def __init__(
176
+ self,
177
+ in_channels: int,
178
+ out_channels: int,
179
+ num_layers: int = 1,
180
+ downsample: bool = True,
181
+ attention: bool = True,
182
+ attention_heads: int = 16,
183
+ skip_scale: float = 1,
184
+ num_views: int = 4,
185
+ num_frames: int = 8,
186
+ use_temp_attn=True
187
+ ):
188
+ super().__init__()
189
+
190
+ nets = []
191
+ attns = []
192
+ t_attns = []
193
+ for i in range(num_layers):
194
+ in_channels = in_channels if i == 0 else out_channels
195
+ nets.append(ResnetBlock(in_channels, out_channels, skip_scale=skip_scale))
196
+ if attention:
197
+ attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale, num_views=num_views, num_frames=num_frames))
198
+ t_attns.append(TempAttention(out_channels, attention_heads, skip_scale=skip_scale, num_views=num_views, num_frames=num_frames) if use_temp_attn else None)
199
+ else:
200
+ attns.append(None)
201
+ t_attns.append(None)
202
+ self.nets = nn.ModuleList(nets)
203
+ self.attns = nn.ModuleList(attns)
204
+ self.t_attns = nn.ModuleList(t_attns)
205
+
206
+ self.downsample = None
207
+ if downsample:
208
+ self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
209
+
210
+ def forward(self, x):
211
+ xs = []
212
+
213
+ for attn, t_attn, net in zip(self.attns, self.t_attns, self.nets):
214
+ x = net(x)
215
+ if attn:
216
+ x = attn(x)
217
+ if t_attn:
218
+ x = t_attn(x)
219
+ xs.append(x)
220
+
221
+ if self.downsample:
222
+ x = self.downsample(x)
223
+ xs.append(x)
224
+
225
+ return x, xs
226
+
227
+
228
+ class MidBlock(nn.Module):
229
+ def __init__(
230
+ self,
231
+ in_channels: int,
232
+ num_layers: int = 1,
233
+ attention: bool = True,
234
+ attention_heads: int = 16,
235
+ skip_scale: float = 1,
236
+ num_views: int = 4,
237
+ num_frames: int = 8,
238
+ use_temp_attn=True
239
+ ):
240
+ super().__init__()
241
+
242
+ nets = []
243
+ attns = []
244
+ t_attns = []
245
+ # first layer
246
+ nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
247
+ # more layers
248
+ for i in range(num_layers):
249
+ nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
250
+ if attention:
251
+ attns.append(MVAttention(in_channels, attention_heads, skip_scale=skip_scale, num_views=num_views, num_frames=num_frames))
252
+ t_attns.append(TempAttention(in_channels, attention_heads, skip_scale=skip_scale, num_views=num_views, num_frames=num_frames) if use_temp_attn else None)
253
+ else:
254
+ attns.append(None)
255
+ t_attns.append(None)
256
+ self.nets = nn.ModuleList(nets)
257
+ self.attns = nn.ModuleList(attns)
258
+ self.t_attns = nn.ModuleList(t_attns)
259
+
260
+ def forward(self, x):
261
+ x = self.nets[0](x)
262
+ for attn, t_attn, net in zip(self.attns, self.t_attns, self.nets[1:]):
263
+ if attn:
264
+ x = attn(x)
265
+ if t_attn:
266
+ x = t_attn(x)
267
+ x = net(x)
268
+ return x
269
+
270
+
271
+ class UpBlock(nn.Module):
272
+ def __init__(
273
+ self,
274
+ in_channels: int,
275
+ prev_out_channels: int,
276
+ out_channels: int,
277
+ num_layers: int = 1,
278
+ upsample: bool = True,
279
+ attention: bool = True,
280
+ attention_heads: int = 16,
281
+ skip_scale: float = 1,
282
+ num_views: int = 4,
283
+ num_frames: int = 8,
284
+ use_temp_attn=True
285
+ ):
286
+ super().__init__()
287
+
288
+ nets = []
289
+ attns = []
290
+ t_attns = []
291
+ for i in range(num_layers):
292
+ cin = in_channels if i == 0 else out_channels
293
+ cskip = prev_out_channels if (i == num_layers - 1) else out_channels
294
+
295
+ nets.append(ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale))
296
+ if attention:
297
+ attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale, num_views=num_views, num_frames=num_frames))
298
+ t_attns.append(TempAttention(out_channels, attention_heads, skip_scale=skip_scale, num_views=num_views, num_frames=num_frames) if use_temp_attn else None)
299
+ else:
300
+ attns.append(None)
301
+ t_attns.append(None)
302
+ self.nets = nn.ModuleList(nets)
303
+ self.attns = nn.ModuleList(attns)
304
+ self.t_attns = nn.ModuleList(t_attns)
305
+
306
+ self.upsample = None
307
+ if upsample:
308
+ self.upsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
309
+
310
+ def forward(self, x, xs):
311
+
312
+ for attn, t_attn, net in zip(self.attns, self.t_attns, self.nets):
313
+ res_x = xs[-1]
314
+ xs = xs[:-1]
315
+ x = torch.cat([x, res_x], dim=1)
316
+ x = net(x)
317
+ if attn:
318
+ x = attn(x)
319
+ if t_attn:
320
+ x = t_attn(x)
321
+
322
+ if self.upsample:
323
+ x = F.interpolate(x, scale_factor=2.0, mode='nearest')
324
+ x = self.upsample(x)
325
+
326
+ return x
327
+
328
+
329
+ # it could be asymmetric!
330
+ class UNet(nn.Module):
331
+ def __init__(
332
+ self,
333
+ in_channels: int = 3,
334
+ out_channels: int = 3,
335
+ down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
336
+ down_attention: Tuple[bool, ...] = (False, False, False, True, True),
337
+ mid_attention: bool = True,
338
+ up_channels: Tuple[int, ...] = (1024, 512, 256),
339
+ up_attention: Tuple[bool, ...] = (True, True, False),
340
+ layers_per_block: int = 2,
341
+ skip_scale: float = np.sqrt(0.5),
342
+ num_views: int = 4,
343
+ num_frames: int = 8,
344
+ use_temp_attn: bool = True
345
+ ):
346
+ super().__init__()
347
+
348
+ # first
349
+ self.conv_in = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1)
350
+
351
+ # down
352
+ down_blocks = []
353
+ cout = down_channels[0]
354
+ for i in range(len(down_channels)):
355
+ cin = cout
356
+ cout = down_channels[i]
357
+
358
+ down_blocks.append(DownBlock(
359
+ cin, cout,
360
+ num_layers=layers_per_block,
361
+ downsample=(i != len(down_channels) - 1), # not final layer
362
+ attention=down_attention[i],
363
+ skip_scale=skip_scale,
364
+ num_views=num_views,
365
+ num_frames=num_frames,
366
+ use_temp_attn=use_temp_attn
367
+ ))
368
+ self.down_blocks = nn.ModuleList(down_blocks)
369
+
370
+ # mid
371
+ self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale, num_views=num_views, num_frames=num_frames, use_temp_attn=use_temp_attn)
372
+
373
+ # up
374
+ up_blocks = []
375
+ cout = up_channels[0]
376
+ for i in range(len(up_channels)):
377
+ cin = cout
378
+ cout = up_channels[i]
379
+ cskip = down_channels[max(-2 - i, -len(down_channels))] # for asymmetric U-Net
380
+
381
+ up_blocks.append(UpBlock(
382
+ cin, cskip, cout,
383
+ num_layers=layers_per_block + 1, # one more layer for up
384
+ upsample=(i != len(up_channels) - 1), # not final layer
385
+ attention=up_attention[i],
386
+ skip_scale=skip_scale,
387
+ num_views=num_views,
388
+ num_frames=num_frames,
389
+ use_temp_attn=use_temp_attn
390
+ ))
391
+ self.up_blocks = nn.ModuleList(up_blocks)
392
+
393
+ # last
394
+ self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=32, eps=1e-5)
395
+ self.conv_out = nn.Conv2d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
396
+
397
+
398
+ def forward(self, x, return_mid_feature=False):
399
+ # x: [B, Cin, H, W]
400
+
401
+ # first
402
+ x = self.conv_in(x)
403
+
404
+ # down
405
+ xss = [x]
406
+ for block in self.down_blocks:
407
+ x, xs = block(x)
408
+ xss.extend(xs)
409
+
410
+ # mid
411
+ x = self.mid_block(x)
412
+ mid_feature = (x, xss)
413
+
414
+ # up
415
+ for block in self.up_blocks:
416
+ xs = xss[-len(block.nets):]
417
+ xss = xss[:-len(block.nets)]
418
+ x = block(x, xs)
419
+
420
+ # last
421
+ x = self.norm_out(x)
422
+ x = F.silu(x)
423
+ x = self.conv_out(x) # [B, Cout, H', W']
424
+
425
+ if return_mid_feature:
426
+ return x, *mid_feature
427
+
428
+ return x
core/utils.py ADDED
@@ -0,0 +1,109 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ import roma
8
+ from kiui.op import safe_normalize
9
+
10
+ def get_rays(pose, h, w, fovy, opengl=True):
11
+
12
+ x, y = torch.meshgrid(
13
+ torch.arange(w, device=pose.device),
14
+ torch.arange(h, device=pose.device),
15
+ indexing="xy",
16
+ )
17
+ x = x.flatten()
18
+ y = y.flatten()
19
+
20
+ cx = w * 0.5
21
+ cy = h * 0.5
22
+
23
+ focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))
24
+
25
+ camera_dirs = F.pad(
26
+ torch.stack(
27
+ [
28
+ (x - cx + 0.5) / focal,
29
+ (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
30
+ ],
31
+ dim=-1,
32
+ ),
33
+ (0, 1),
34
+ value=(-1.0 if opengl else 1.0),
35
+ ) # [hw, 3]
36
+
37
+ rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
38
+ rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
39
+
40
+ rays_o = rays_o.view(h, w, 3)
41
+ rays_d = safe_normalize(rays_d).view(h, w, 3)
42
+
43
+ return rays_o, rays_d
44
+
45
+ def orbit_camera_jitter(poses, strength=0.1):
46
+ # poses: [B, 4, 4], assume orbit camera in opengl format
47
+ # random orbital rotate
48
+
49
+ B = poses.shape[0]
50
+ rotvec_x = poses[:, :3, 1] * strength * np.pi * (torch.rand(B, 1, device=poses.device) * 2 - 1)
51
+ rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * (torch.rand(B, 1, device=poses.device) * 2 - 1)
52
+
53
+ rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y)
54
+ R = rot @ poses[:, :3, :3]
55
+ T = rot @ poses[:, :3, 3:]
56
+
57
+ new_poses = poses.clone()
58
+ new_poses[:, :3, :3] = R
59
+ new_poses[:, :3, 3:] = T
60
+
61
+ return new_poses
62
+
63
+ def grid_distortion(images, strength=0.5):
64
+ # images: [B, C, H, W]
65
+ # num_steps: int, grid resolution for distortion
66
+ # strength: float in [0, 1], strength of distortion
67
+
68
+ B, C, H, W = images.shape
69
+
70
+ num_steps = np.random.randint(8, 17)
71
+ grid_steps = torch.linspace(-1, 1, num_steps)
72
+
73
+ # have to loop batch...
74
+ grids = []
75
+ for b in range(B):
76
+ # construct displacement
77
+ x_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive
78
+ x_steps = (x_steps + strength * (torch.rand_like(x_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb
79
+ x_steps = (x_steps * W).long() # [num_steps]
80
+ x_steps[0] = 0
81
+ x_steps[-1] = W
82
+ xs = []
83
+ for i in range(num_steps - 1):
84
+ xs.append(torch.linspace(grid_steps[i], grid_steps[i + 1], x_steps[i + 1] - x_steps[i]))
85
+ xs = torch.cat(xs, dim=0) # [W]
86
+
87
+ y_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive
88
+ y_steps = (y_steps + strength * (torch.rand_like(y_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb
89
+ y_steps = (y_steps * H).long() # [num_steps]
90
+ y_steps[0] = 0
91
+ y_steps[-1] = H
92
+ ys = []
93
+ for i in range(num_steps - 1):
94
+ ys.append(torch.linspace(grid_steps[i], grid_steps[i + 1], y_steps[i + 1] - y_steps[i]))
95
+ ys = torch.cat(ys, dim=0) # [H]
96
+
97
+ # construct grid
98
+ grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W]
99
+ grid = torch.stack([grid_x, grid_y], dim=-1) # [H, W, 2]
100
+
101
+ grids.append(grid)
102
+
103
+ grids = torch.stack(grids, dim=0).to(images.device) # [B, H, W, 2]
104
+
105
+ # grid sample
106
+ images = F.grid_sample(images, grids, align_corners=False)
107
+
108
+ return images
109
+
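As a usage note, `get_rays` above is what the data providers use to build the 6-channel Plucker ray embedding that gets concatenated with the RGB input; a minimal sketch mirroring that code (with the default fovy of 49.1 and camera radius 1.5 from core/options.py) is:

import torch
from kiui.cam import orbit_camera
from core.utils import get_rays

pose = torch.from_numpy(orbit_camera(0, 30, radius=1.5, opengl=True))       # [4, 4] camera-to-world
rays_o, rays_d = get_rays(pose, 256, 256, 49.1)                              # [256, 256, 3] each
plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1)   # [256, 256, 6]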
data_test/000000_fg.mp4 ADDED
Binary file (226 kB).

data_test/000070_fg.mp4 ADDED
Binary file (291 kB).

data_test/000370_fg.mp4 ADDED
Binary file (273 kB).

data_test/blooming_rose_fg.mp4 ADDED
Binary file (83.5 kB).

data_test/cat_king_fg.mp4 ADDED
Binary file (266 kB).

data_test/dancing_robot_fg.mp4 ADDED
Binary file (64.2 kB).

data_test/lifting1_fg.mp4 ADDED
Binary file (311 kB).

data_test/monster-with-melting-candle_fg.mp4 ADDED
Binary file (365 kB).

data_test/otter-on-surfboard_fg.mp4 ADDED
Binary file (305 kB).

data_test/sighing_frog_fg.mp4 ADDED
Binary file (90.1 kB).
 
environment.yml ADDED
@@ -0,0 +1,32 @@
1
+ name: l4gm
2
+ channels:
3
+ - pyg
4
+ - nvidia/label/cuda-12.1.0
5
+ - pytorch
6
+ - conda-forge
7
+ - xformers
8
+ dependencies:
9
+ - python=3.10
10
+ - pytorch=2.5.1
11
+ - pytorch-cuda=12.1
12
+ - torchvision
13
+ - xformers
14
+ - cuda
15
+ - cuda-nvcc
16
+ - numpy<2.0.0
17
+ - scipy
18
+ - rich
19
+ - pip
20
+ - setuptools
21
+ - ninja
22
+ - tqdm
23
+ - ray-default
24
+ - flatten-dict
25
+ - gcc_linux-64=11
26
+ - gxx_linux-64=11
27
+ - opencv
28
+ - transformers
29
+ - einops
30
+ - pip:
31
+ - -r requirements.txt
32
+ - git+https://github.com/nerfstudio-project/gsplat.git@v1.4.0
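This environment is typically created with `conda env create -f environment.yml` and activated with `conda activate l4gm` (the name declared above); the pip section then installs requirements.txt plus the pinned gsplat v1.4.0.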
infer_3d.py ADDED
@@ -0,0 +1,233 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import imageio.v3 as iio
17
+ import cv2
18
+ import numpy as np
19
+ import imageio
20
+
21
+ import os
22
+ import tyro
23
+ import glob
24
+ import imageio
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ import torchvision.transforms.functional as TF
30
+ from safetensors.torch import load_file
31
+ import time
32
+
33
+ import kiui
34
+ from kiui.cam import orbit_camera
35
+
36
+ from core.options import AllConfigs, Options
37
+ from core.models import LGM
38
+ from mvdream.pipeline_mvdream import MVDreamPipeline
39
+
40
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
41
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
42
+
43
+ opt = tyro.cli(AllConfigs)
44
+
45
+ # model
46
+ model = LGM(opt)
47
+
48
+ # resume pretrained checkpoint
49
+ if opt.resume is not None:
50
+ if opt.resume.endswith('safetensors'):
51
+ ckpt = load_file(opt.resume, device='cpu')
52
+ else:
53
+ ckpt = torch.load(opt.resume, map_location='cpu')
54
+ model.load_state_dict(ckpt, strict=False)
55
+ print(f'[INFO] Loaded checkpoint from {opt.resume}')
56
+ else:
57
+ print(f'[WARN] model randomly initialized, are you sure?')
58
+
59
+ # device
60
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
61
+ model = model.half().to(device)
62
+ model.eval()
63
+
64
+ bg_color = torch.tensor([255, 255, 255], dtype=torch.float32, device=device) / 255.  # use the selected device
65
+
66
+ rays_embeddings = model.prepare_default_rays(device)
67
+
68
+ tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
69
+ proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
70
+ proj_matrix[0, 0] = 1 / tan_half_fov
71
+ proj_matrix[1, 1] = 1 / tan_half_fov
72
+ proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
73
+ proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
74
+ proj_matrix[2, 3] = 1
75
+
76
+ # load image dream
77
+ pipe = MVDreamPipeline.from_pretrained(
78
+ "ashawkey/imagedream-ipmv-diffusers", # remote weights
79
+ torch_dtype=torch.float16,
80
+ trust_remote_code=True,
81
+ # local_files_only=True,
82
+ )
83
+ pipe = pipe.to(device)
84
+
85
+
86
+ def process_eval_video(video_path, T):
87
+ frames = iio.imread(video_path)
88
+ frames = [frames[x] for x in range(frames.shape[0])]
89
+ V = opt.num_input_views
90
+ img_TV = []
91
+ for t in range(T):
92
+
93
+ img = frames[t]
94
+
95
+ img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
96
+ img = img.astype(np.float32) / 255.0
97
+
98
+ img_V = []
99
+ for v in range(V):
100
+ img_V.append(img)
101
+ img_TV.append(np.stack(img_V, axis=0))
102
+
103
+ return np.stack(img_TV, axis=0)
104
+
105
+
106
+ # process function
107
+ def process(opt: Options, path):
108
+ name = os.path.splitext(os.path.basename(path))[0]
109
+ print(f'[INFO] Processing {path} --> {name}')
110
+ os.makedirs(opt.workspace, exist_ok=True)
111
+
112
+ ref_video = process_eval_video(path, opt.num_frames) # [T, V, 256, 256, 3]
113
+
114
+
115
+ start_time = time.time()  # start timing 3D generation
116
+
117
+ cv2.imwrite(os.path.join(opt.workspace, f'{name}_orig.png'), ref_video[0,0][..., ::-1] * 255)
118
+
119
+ mv_image = pipe('', ref_video[0,0], guidance_scale=5, num_inference_steps=30, elevation=0)
120
+ for v in range(4):
121
+ cv2.imwrite(os.path.join(opt.workspace, f'{name}_mv_{(v-1)%4:03d}.png'), mv_image[v][..., ::-1] * 255)
122
+ mv_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32
123
+
124
+
125
+ # generate gaussians
126
+ input_image = torch.from_numpy(mv_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
127
+ input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
128
+ input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
129
+
130
+ input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
131
+
132
+ with torch.no_grad():
133
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
134
+ gaussians_all_frame = model.forward_gaussians(input_image)
135
+
136
+ B, T, V = 1, gaussians_all_frame.shape[0]//opt.batch_size, opt.num_views
137
+ gaussians_all_frame = gaussians_all_frame.reshape(B, T, *gaussians_all_frame.shape[1:])
138
+
139
+ # align azimuth
140
+ best_azi = 0
141
+ best_diff = 1e8
142
+ for v, azi in enumerate(np.arange(-180, 180, 1)):
143
+ gaussians = gaussians_all_frame[:, 0]
144
+
145
+ cam_poses = torch.from_numpy(orbit_camera(0, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
146
+
147
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
148
+
149
+ # cameras needed by gaussian rasterizer
150
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
151
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
152
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
153
+
154
+ result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)
155
+ image = result['image']
156
+ alpha = result['alpha']
157
+
158
+ image = image.squeeze(1).permute(0,2,3,1).squeeze(0).contiguous().float().cpu().numpy()
159
+ image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
160
+
161
+ diff = np.mean((image - ref_video[0, 0]) ** 2)
162
+
163
+ if diff < best_diff:
164
+ best_diff = diff
165
+ best_azi = azi
166
+
167
+ print("Best aligned azimuth: ", best_azi)
168
+
169
+ mv_image = []
170
+ for v, azi in enumerate(np.arange(0, 360, 90)):
171
+ gaussians = gaussians_all_frame[:, 0]
172
+
173
+ cam_poses = torch.from_numpy(orbit_camera(0, azi + best_azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
174
+
175
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
176
+
177
+ # cameras needed by gaussian rasterizer
178
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
179
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
180
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
181
+
182
+ scale = 1
183
+
184
+ result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)
185
+ image = result['image']
186
+ alpha = result['alpha']
187
+
188
+ imageio.imwrite(os.path.join(opt.workspace, f'{name}_{v:03d}.png'), (image.squeeze(1).permute(0,2,3,1).squeeze(0).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
189
+
190
+ if azi in [0, 90, 180, 270]:
191
+ rendered_image = image.squeeze(1)
192
+ rendered_image = F.interpolate(rendered_image, (256, 256))
193
+ rendered_image = rendered_image.permute(0,2,3,1).contiguous().float().cpu().numpy()
194
+ mv_image.append(rendered_image)
195
+ mv_image = np.concatenate(mv_image, axis=0)
196
+ print(f"Generating 3D took {time.time() - start_time:.2f} s")
197
+
198
+ images = []
199
+ azimuth = np.arange(0, 360, 4, dtype=np.int32)
200
+ elevation = 0
201
+ for azi in azimuth:
202
+ gaussians = gaussians_all_frame[:, 0]
203
+
204
+ cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
205
+
206
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
207
+
208
+ # cameras needed by gaussian rasterizer
209
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
210
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
211
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
212
+
213
+ scale = 1
214
+
215
+ image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)['image']
216
+ images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
217
+
218
+ images = np.concatenate(images, axis=0)
219
+ imageio.mimwrite(os.path.join(opt.workspace, f'{name}.mp4'), images, fps=30)
220
+
221
+
222
+ torch.cuda.empty_cache()
223
+
224
+
225
+
226
+ assert opt.test_path is not None
227
+ if os.path.isdir(opt.test_path):
228
+ file_paths = glob.glob(os.path.join(opt.test_path, "*"))
229
+ else:
230
+ file_paths = [opt.test_path]
231
+
232
+ for path in sorted(file_paths):
233
+ process(opt, path)
infer_4d.py ADDED
@@ -0,0 +1,392 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import imageio.v3 as iio
17
+ import cv2
18
+ import numpy as np
19
+ import imageio
20
+
21
+ from copy import deepcopy
22
+ import os
23
+ import tyro
24
+ import glob
25
+ import imageio
26
+ import numpy as np
27
+ import tqdm
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ import torchvision.transforms.functional as TF
32
+ from safetensors.torch import load_file
33
+
34
+ import kiui
35
+ from kiui.cam import orbit_camera
36
+
37
+ from core.options import AllConfigs, Options
38
+ from core.models import LGM
39
+ import time
40
+
41
+ from core.utils import get_rays, grid_distortion, orbit_camera_jitter
42
+
43
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
44
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
45
+
46
+
47
+ USE_INTERPOLATION = True # set to false to disable interpolation
48
+ MAX_RUNS = 100
49
+ VIDEO_FPS = 30
50
+
51
+ opt = tyro.cli(AllConfigs)
52
+
53
+ # model
54
+ model = LGM(opt)
55
+
56
+ # resume pretrained checkpoint
57
+ if opt.resume is not None:
58
+ if opt.resume.endswith('safetensors'):
59
+ ckpt = load_file(opt.resume, device='cpu')
60
+ else:
61
+ ckpt = torch.load(opt.resume, map_location='cpu')
62
+ model.load_state_dict(ckpt, strict=False)
63
+ print(f'[INFO] Loaded checkpoint from {opt.resume}')
64
+ else:
65
+ print(f'[WARN] model randomly initialized, are you sure?')
66
+
67
+ # device
68
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
69
+ model = model.half().to(device)
70
+ model.eval()
71
+
72
+ bg_color = torch.tensor([255, 255, 255], dtype=torch.float32, device="cuda") / 255.
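+ # white rendering background; note the device is hard-coded to "cuda" here, unlike the device selection above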
73
+
74
+
75
+ rays_embeddings = model.prepare_default_rays(device)
76
+ rays_embeddings = torch.cat([rays_embeddings for _ in range(opt.num_frames)])
77
+
78
+
79
+ interp_opt = deepcopy(opt)
80
+ interp_opt.num_frames = 4
81
+ model_interp = LGM(interp_opt)
82
+ # resume pretrained checkpoint
83
+ if interp_opt.interpresume is not None:
84
+ if interp_opt.interpresume.endswith('safetensors'):
85
+ ckpt = load_file(interp_opt.interpresume, device='cpu')
86
+ else:
87
+ ckpt = torch.load(interp_opt.interpresume, map_location='cpu')
88
+ model_interp.load_state_dict(ckpt, strict=False)
89
+ print(f'[INFO] Loaded Interp checkpoint from {interp_opt.interpresume}')
90
+ else:
91
+ print(f'[WARN] model_interp randomly initialized, are you sure?')
92
+
93
+ # device
94
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
95
+ model_interp = model_interp.half().to(device)
96
+ model_interp.eval()
97
+
98
+
99
+ interp_rays_embeddings = model_interp.prepare_default_rays(device)
100
+ interp_rays_embeddings = torch.cat([interp_rays_embeddings for _ in range(interp_opt.num_frames)])
101
+
102
+ tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
103
+ proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
104
+ proj_matrix[0, 0] = 1 / tan_half_fov
105
+ proj_matrix[1, 1] = 1 / tan_half_fov
106
+ proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
107
+ proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
108
+ proj_matrix[2, 3] = 1
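+ # perspective projection stored transposed (row-vector convention), so it is applied below as cam_view @ proj_matrix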
109
+
110
+ def interpolate_tensors(tensor):
111
+ # Extract the first and last tensors along the first dimension (B)
112
+ start_tensor = tensor[0] # shape [4, 3, 256, 256]
113
+ end_tensor = tensor[-1] # shape [4, 3, 256, 256]
114
+ tensor_interp = deepcopy(tensor)
115
+
116
+ # Iterate over the range from 1 to second-last index
117
+
118
+ for i in range(1, tensor.shape[0] - 1):
119
+ # Calculate the weight for interpolation
120
+
121
+ weight = i / (tensor.shape[0] - 1)
122
+ # Interpolate between start_tensor and end_tensor
123
+ tensor_interp[i] = torch.lerp(start_tensor, end_tensor, weight)
124
+
125
+
126
+ return tensor_interp
127
+
128
+ def process_eval_video(frames, video_path, T, start_t=0, downsample_rate=1):
129
+ L = frames.shape[0]
130
+ vid_name = video_path.split('/')[-1].split('.')[0]
131
+ total_frames = L//downsample_rate
132
+ print(f'{start_t} / {total_frames}')
133
+ frames = [frames[x] for x in range(frames.shape[0])]
134
+ V = opt.num_input_views
135
+ img_TV = []
136
+ for t in range(T):
137
+ t += start_t
138
+ t = min(t, L//downsample_rate-1)
139
+ t *= downsample_rate
140
+
141
+ img = frames[t]
142
+
143
+ img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
144
+ img = img.astype(np.float32) / 255.0
145
+
146
+ img_V = []
147
+ for v in range(V):
148
+ img_V.append(img)
149
+ img_TV.append(np.stack(img_V, axis=0))
150
+
151
+ return np.stack(img_TV, axis=0), L//downsample_rate - start_t
152
+
153
+ def load_mv_img(name, img_dir):
154
+ img_list = []
155
+ for v in range(4):
156
+ img = kiui.read_image(os.path.join(img_dir, name + f'_{v:03d}.png'), mode='uint8')
157
+ img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
158
+ img = img / 255.
159
+ img_list.append(img)
160
+ return np.stack(img_list, axis=0)
161
+
162
+
163
+
164
+ # process function
165
+ def process(opt: Options, path):
166
+ name = os.path.splitext(os.path.basename(path))[0]
167
+ print(f'[INFO] Processing {path} --> {name}')
168
+ os.makedirs(opt.workspace, exist_ok=True)
169
+ frames = iio.imread(path)
170
+ img_dir = opt.workspace
171
+ mv_image = load_mv_img(name, img_dir)
172
+
173
+ print(iio.immeta(path))
174
+ FPS = int(iio.immeta(path)['fps'])
175
+ downsample_rate = FPS // 15 if FPS > 15 else 1 # default reconstruction fps 15
176
+
177
+
178
+
179
+ with torch.inference_mode():
180
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
181
+ start_t = 0
182
+ gaussians_all_frame_all_run = []
183
+ gaussians_all_frame_all_run_w_interp = []
184
+ for run_idx in range(MAX_RUNS):
185
+ ref_video, end_t = process_eval_video(frames, path, opt.num_frames, start_t, downsample_rate=downsample_rate)
186
+ ref_video[:, 1:] = mv_image[None, 1:] # keep view 0 from the video; broadcast the fixed multi-view images to the remaining views for every frame
187
+ input_image = torch.from_numpy(ref_video).reshape([-1, *ref_video.shape[2:]]).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
188
+ input_image = F.interpolate(input_image, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
189
+ input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
190
+ input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
191
+
192
+ end_time = time.time()
193
+
194
+ gaussians_all_frame = model.forward_gaussians(input_image)
195
+ print(f"Forward pass takes {time.time()-end_time} s")
196
+
197
+ B, T, V = 1, gaussians_all_frame.shape[0]//opt.batch_size, opt.num_views
198
+ gaussians_all_frame = gaussians_all_frame.reshape(B, T, *gaussians_all_frame.shape[1:])
199
+
200
+ if run_idx > 0:
201
+ gaussians_all_frame_wo_inter = gaussians_all_frame[:, 1:max(end_t, 1)]
202
+ else:
203
+ gaussians_all_frame_wo_inter = gaussians_all_frame
204
+
205
+ if gaussians_all_frame_wo_inter.shape[1] > 0 and USE_INTERPOLATION:
206
+ # render multiview video
207
+ render_img_TV = []
208
+ for t in range(gaussians_all_frame.shape[1]):
209
+ render_img_V = []
210
+ for v, azi in enumerate(np.arange(0, 360, 90)):
211
+
212
+ gaussians = gaussians_all_frame[:, t]
213
+
214
+ cam_poses = torch.from_numpy(orbit_camera(0, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
215
+
216
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
217
+
218
+ # cameras needed by gaussian rasterizer
219
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
220
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
221
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
222
+
223
+ rendered_image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)['image']
224
+ rendered_image = rendered_image.squeeze(1)
225
+ rendered_image = F.interpolate(rendered_image, (256, 256))
226
+ rendered_image = rendered_image.permute(0,2,3,1).contiguous().float().cpu().numpy() # B H W C
227
+
228
+ render_img_V.append(rendered_image)
229
+ render_img_V = np.concatenate(render_img_V, axis=0) # V H W C
230
+ render_img_TV.append(render_img_V)
231
+ render_img_TV = np.stack(render_img_TV, axis=0) # T V H W C
232
+ ref_video = np.concatenate([np.stack([ref_video[ttt] for _ in range(opt.interpolate_rate)], 0) for ttt in range(ref_video.shape[0])], 0)
233
+
234
+
235
+ for tt in range(gaussians_all_frame_wo_inter.shape[1] - 1):
236
+
237
+ curr_ref_video = deepcopy(ref_video[tt * opt.interpolate_rate: tt * opt.interpolate_rate + interp_opt.num_frames])
238
+ curr_ref_video[0, 1:] = render_img_TV[tt, 1:]
239
+
240
+ curr_ref_video[-1, 1:] = render_img_TV[tt+1, 1:]
241
+
242
+
243
+ curr_ref_video = torch.from_numpy(curr_ref_video).float().to(
244
+ device) # [4, 3, 256, 256]
245
+
246
+ images_input_interp = interpolate_tensors(curr_ref_video)
247
+
248
+ curr_ref_video[1:-1, :] = images_input_interp[1:-1, :]
249
+
250
+ input_image_interp = curr_ref_video.reshape([-1, *curr_ref_video.shape[2:]]).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
251
+ input_image_interp = F.interpolate(input_image_interp, size=(interp_opt.input_size, interp_opt.input_size), mode='bilinear',
252
+ align_corners=False)
253
+ input_image_interp = TF.normalize(input_image_interp, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
254
+
255
+ input_image_interp = torch.cat([input_image_interp, interp_rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
256
+
257
+ end_time = time.time()
258
+ gaussians_interp_all_frame = model_interp.forward_gaussians(input_image_interp)
259
+ print(f"Interpolate forward pass takes {time.time()-end_time} s")
260
+
261
+ B, T, V = 1, gaussians_interp_all_frame.shape[0] // opt.batch_size, opt.num_views
262
+ gaussians_interp_all_frame = gaussians_interp_all_frame.reshape(B, T, *gaussians_interp_all_frame.shape[1:])
263
+
264
+ if tt > 0:
265
+ gaussians_interp_all_frame = gaussians_interp_all_frame[:, 1:]
266
+
267
+ gaussians_all_frame_all_run_w_interp.append(gaussians_interp_all_frame)
268
+
269
+
270
+
271
+ gaussians_all_frame_all_run.append(gaussians_all_frame_wo_inter)
272
+ start_t += opt.num_frames - 1
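+ # consecutive chunks overlap by one frame; the duplicated first frame is dropped above when run_idx > 0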
273
+
274
+ mv_image = []
275
+ for v, azi in enumerate(np.arange(0, 360, 90)):
276
+ gaussians = gaussians_all_frame_wo_inter[:, -1]
277
+ cam_poses = torch.from_numpy(orbit_camera(0, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
278
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
279
+ # cameras needed by gaussian rasterizer
280
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
281
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
282
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
283
+
284
+ rendered_image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)['image']
285
+ rendered_image = rendered_image.squeeze(1)
286
+ rendered_image = F.interpolate(rendered_image, (256, 256))
287
+ rendered_image = rendered_image.permute(0,2,3,1).contiguous().float().cpu().numpy()
288
+ mv_image.append(rendered_image)
289
+ mv_image = np.concatenate(mv_image, axis=0)
290
+ elif gaussians_all_frame_wo_inter.shape[1] > 0:
291
+ gaussians_all_frame_all_run.append(gaussians_all_frame_wo_inter)
292
+ start_t += opt.num_frames - 1
293
+ else:
294
+ break
295
+
296
+ gaussians_all_frame_wo_interp = torch.cat(gaussians_all_frame_all_run, dim=1)
297
+ if USE_INTERPOLATION:
298
+ gaussians_all_frame_w_interp = torch.cat(gaussians_all_frame_all_run_w_interp, dim=1)
299
+
300
+ if USE_INTERPOLATION:
301
+ zip_dump = zip(["wo_interp", "w_interp"], [gaussians_all_frame_wo_interp, gaussians_all_frame_w_interp])
302
+ else:
303
+ zip_dump = zip(["wo_interp"], [gaussians_all_frame_wo_interp])
304
+
305
+ for sv_name, gaussians_all_frame in zip_dump:
306
+ if sv_name == "w_interp":
307
+ ANIM_FPS = FPS / downsample_rate * gaussians_all_frame_w_interp.shape[1] / gaussians_all_frame_wo_interp.shape[1]
308
+ else:
309
+ ANIM_FPS = FPS / downsample_rate
310
+ print(f"{sv_name} | input video fps: {FPS} | downsample rate: {downsample_rate} | animation fps: {ANIM_FPS} | output video fps: {VIDEO_FPS}")
311
+ render_img_TV = []
312
+ for t in range(gaussians_all_frame.shape[1]):
313
+ render_img_V = []
314
+ for v, azi in enumerate(np.arange(0, 360, 90)):
315
+
316
+ gaussians = gaussians_all_frame[:, t]
317
+
318
+ cam_poses = torch.from_numpy(orbit_camera(0, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
319
+
320
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
321
+
322
+ # cameras needed by gaussian rasterizer
323
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
324
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
325
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
326
+
327
+ result = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)
328
+ image = result['image']
329
+ alpha = result['alpha']
330
+
331
+ render_img_V.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
332
+ render_img_V = np.concatenate(render_img_V, axis=2)
333
+ render_img_TV.append(render_img_V)
334
+ render_img_TV = np.concatenate(render_img_TV, axis=0)
335
+
336
+
337
+ images = []
338
+ azimuth = np.arange(0, 360, 1*30/VIDEO_FPS, dtype=np.int32)
339
+ elevation = 0
340
+ t = 0
341
+ delta_t = ANIM_FPS / VIDEO_FPS
342
+ for azi in azimuth:
343
+ if azi in [0, 90, 180, 270]:
344
+ cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
345
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
346
+
347
+ # cameras needed by gaussian rasterizer
348
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
349
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
350
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
351
+
352
+ for _ in range(45):
353
+ gaussians = gaussians_all_frame[:, int(t) % gaussians_all_frame.shape[1]]
354
+ t += delta_t
355
+ image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)['image']
356
+ images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
357
+ else:
358
+ gaussians = gaussians_all_frame[:, int(t) % gaussians_all_frame.shape[1]]
359
+ t += delta_t
360
+
361
+ cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
362
+
363
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
364
+
365
+ # cameras needed by gaussian rasterizer
366
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
367
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
368
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
369
+
370
+ image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), bg_color=bg_color)['image']
371
+ images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
372
+
373
+ images = np.concatenate(images, axis=0)
374
+
375
+ torch.cuda.empty_cache()
376
+
377
+
378
+ imageio.mimwrite(os.path.join(opt.workspace, f'{sv_name}_{name}_fixed.mp4'), render_img_TV, fps=ANIM_FPS)
379
+ print("Fixed video saved.")
380
+ imageio.mimwrite(os.path.join(opt.workspace, f'{sv_name}_{name}.mp4'), images, fps=VIDEO_FPS)
381
+ print("Stop video saved.")
382
+
383
+
384
+ assert opt.test_path is not None
385
+
386
+ if os.path.isdir(opt.test_path):
387
+ file_paths = glob.glob(os.path.join(opt.test_path, "*"))
388
+ else:
389
+ file_paths = [opt.test_path]
390
+
391
+ for path in sorted(file_paths):
392
+ process(opt, path)
main.py ADDED
@@ -0,0 +1,274 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import tyro
17
+ import time
18
+ import random
19
+
20
+ import torch
21
+ from core.options import AllConfigs
22
+ from core.models import LGM
23
+
24
+ from accelerate import Accelerator, DistributedDataParallelKwargs
25
+ from safetensors.torch import load_file
26
+
27
+ import kiui
28
+ from PIL import Image
29
+
30
+ import json
31
+ import os
32
+ import numpy as np
33
+ import imageio
34
+
35
+ def main():
36
+ opt = tyro.cli(AllConfigs)
37
+
38
+ accelerator = Accelerator(
39
+ mixed_precision=opt.mixed_precision,
40
+ gradient_accumulation_steps=opt.gradient_accumulation_steps,
41
+ # kwargs_handlers=[ddp_kwargs],
42
+ )
43
+ if accelerator.is_main_process:
44
+ print(opt)
45
+
46
+ # model
47
+ model = LGM(opt)
48
+
49
+ epoch_start = 0
50
+ if os.path.exists(f'{opt.workspace}/model.safetensors') and os.path.exists(f'{opt.workspace}/metadata.json'):
51
+ opt.resume = f'{opt.workspace}/model.safetensors'
52
+ with open(f'{opt.workspace}/metadata.json', 'r') as f:
53
+ dc = json.load(f)
54
+ epoch_start = dc['epoch'] + 1
55
+
56
+
57
+ # resume
58
+ if opt.resume is not None and opt.resume != 'None':
59
+ if opt.resume.endswith('safetensors'):
60
+ ckpt = load_file(opt.resume, device='cpu')
61
+ else:
62
+ ckpt = torch.load(opt.resume, map_location='cpu')
63
+
64
+ # tolerant load (only load matching shapes)
65
+ # model.load_state_dict(ckpt, strict=False)
66
+ state_dict = model.state_dict()
67
+ for k, v in ckpt.items():
68
+ if k in state_dict:
69
+ if state_dict[k].shape == v.shape:
70
+ state_dict[k].copy_(v)
71
+ else:
72
+ accelerator.print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
73
+ else:
74
+ accelerator.print(f'[WARN] unexpected param {k}: {v.shape}')
75
+
76
+ # data
77
+ if opt.data_mode == '4d':
78
+ from core.provider_objaverse_4d import ObjaverseDataset as Dataset
79
+ elif opt.data_mode == '4d_interp':
80
+ from core.provider_objaverse_4d_interp import ObjaverseDataset as Dataset
81
+ else:
82
+ raise NotImplementedError
83
+
84
+ train_dataset = Dataset(opt, training=True)
85
+ train_dataloader = torch.utils.data.DataLoader(
86
+ train_dataset,
87
+ batch_size=opt.batch_size,
88
+ shuffle=True,
89
+ num_workers=opt.num_workers,
90
+ pin_memory=True,
91
+ drop_last=True,
92
+ )
93
+
94
+ test_dataset = Dataset(opt, training=False)
95
+ test_dataloader = torch.utils.data.DataLoader(
96
+ test_dataset,
97
+ batch_size=opt.batch_size,
98
+ shuffle=False,
99
+ num_workers=0,
100
+ pin_memory=True,
101
+ drop_last=False,
102
+ )
103
+
104
+ # optimizer
105
+ optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=0.05, betas=(0.9, 0.95))
106
+
107
+ # scheduler (per-iteration)
108
+ total_steps = opt.num_epochs * len(train_dataloader)
109
+ pct_start = 3000 / total_steps
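+ # warm up the learning rate over roughly the first 3000 optimizer steps of the one-cycle schedule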
110
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=opt.lr, total_steps=total_steps, pct_start=pct_start)
111
+
112
+ if epoch_start > 0:
113
+ optimizer.load_state_dict(torch.load(os.path.join(opt.workspace, 'optimizer.pth'), map_location='cpu'))
114
+ scheduler.load_state_dict(torch.load(os.path.join(opt.workspace, 'scheduler.pth')))
115
+
116
+ # accelerate
117
+ model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare(
118
+ model, optimizer, train_dataloader, test_dataloader, scheduler
119
+ )
120
+
121
+
122
+
123
+ # loop
124
+ os.makedirs(opt.workspace, exist_ok=True)
125
+ end_time = time.time()
126
+ for epoch in range(epoch_start, opt.num_epochs):
127
+ # train
128
+ model.train()
129
+ total_loss = 0
130
+ total_psnr = 0
131
+ for i, data in enumerate(train_dataloader):
132
+ with accelerator.accumulate(model):
133
+
134
+ optimizer.zero_grad()
135
+
136
+ step_ratio = (epoch + i / len(train_dataloader)) / opt.num_epochs
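+ # overall training progress in [0, 1), forwarded to the model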
137
+
138
+ out = model(data, step_ratio)
139
+ loss = out['loss']
140
+ psnr = out['psnr']
141
+ accelerator.backward(loss)
142
+
143
+ # gradient clipping
144
+ if accelerator.sync_gradients:
145
+ accelerator.clip_grad_norm_(model.parameters(), opt.gradient_clip)
146
+
147
+ optimizer.step()
148
+ scheduler.step()
149
+
150
+ total_loss += loss.detach()
151
+ total_psnr += psnr.detach()
152
+
153
+ if accelerator.is_main_process:
154
+ # logging
155
+ if i % 10 == 0:
156
+ mem_free, mem_total = torch.cuda.mem_get_info()
157
+ print(f"[INFO] {i}/{len(train_dataloader)} mem: {(mem_total-mem_free)/1024**3:.2f}/{mem_total/1024**3:.2f}G lr: {scheduler.get_last_lr()[0]:.7f} step_ratio: {step_ratio:.4f} loss: {loss.item():.6f} time: {time.time() - end_time:.6f}")
158
+ end_time = time.time()
159
+
160
+ # save log images
161
+ if i % 500 == 0:
162
+ if '4d' in opt.data_mode:
163
+ B, T, V = opt.batch_size, opt.num_frames, opt.num_views
164
+
165
+ gt_images = data['images_output'].reshape(B, T, V, *data['images_output'].shape[2:]).detach() # [B, V, 3, output_size, output_size]
166
+ pred_images = out['images_pred'].reshape(B, T, V, *out['images_pred'].shape[2:]).detach() # [B, V, 3, output_size, output_size]
167
+
168
+ train_gt_images = []
169
+ train_pred_images = []
170
+ for t in range(T):
171
+ train_gt_images_V = []
172
+ train_pred_images_V = []
173
+ for v in range(V):
174
+ train_gt_images_V.append((gt_images[:, t, v].permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
175
+ train_pred_images_V.append((pred_images[:, t, v].permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
176
+ train_gt_images.append(np.concatenate(train_gt_images_V, axis=2))
177
+ train_pred_images.append(np.concatenate(train_pred_images_V, axis=2))
178
+ train_gt_images = np.concatenate(train_gt_images, axis=0)
179
+ train_pred_images = np.concatenate(train_pred_images, axis=0)
180
+ imageio.mimwrite(f'{opt.workspace}/train_gt_images_{epoch}_{i}.mp4', train_gt_images, fps=8)
181
+ imageio.mimwrite(f'{opt.workspace}/train_pred_images_{epoch}_{i}.mp4', train_pred_images, fps=8)
182
+
183
+
184
+ elif '3d' in opt.data_mode:
185
+ gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
186
+ gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3]
187
+ kiui.write_image(f'{opt.workspace}/train_gt_images_{epoch}_{i}.jpg', gt_images)
188
+
189
+ pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
190
+ pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3)
191
+ kiui.write_image(f'{opt.workspace}/train_pred_images_{epoch}_{i}.jpg', pred_images)
192
+ else:
193
+ raise NotImplementedError
194
+
195
+
196
+ total_loss = accelerator.gather_for_metrics(total_loss).mean()
197
+ total_psnr = accelerator.gather_for_metrics(total_psnr).mean()
198
+ if accelerator.is_main_process:
199
+ total_loss /= len(train_dataloader)
200
+ total_psnr /= len(train_dataloader)
201
+ accelerator.print(f"[train] epoch: {epoch} loss: {total_loss.item():.6f} psnr: {total_psnr.item():.4f}")
202
+
203
+ # checkpoint
204
+ accelerator.wait_for_everyone()
205
+ accelerator.save_model(model, opt.workspace)
206
+ accelerator.save_model(model, os.path.join(opt.workspace, 'backup'))
207
+ if accelerator.is_main_process:
208
+ torch.save(optimizer.state_dict(), os.path.join(opt.workspace, 'optimizer.pth'))
209
+ torch.save(scheduler.state_dict(), os.path.join(opt.workspace, 'scheduler.pth'))
210
+ with open(f'{opt.workspace}/metadata.json', 'w') as f:
211
+ json.dump({'epoch': epoch}, f)
212
+
213
+ torch.save(optimizer.state_dict(), os.path.join(opt.workspace, 'backup', 'optimizer.pth'))
214
+ torch.save(scheduler.state_dict(), os.path.join(opt.workspace, 'backup', 'scheduler.pth'))
215
+ with open(f'{opt.workspace}/backup/metadata.json', 'w') as f:
216
+ json.dump({'epoch': epoch}, f)
217
+
218
+
219
+ # eval
220
+ with torch.no_grad():
221
+ model.eval()
222
+ total_psnr = 0
223
+ for i, data in enumerate(test_dataloader):
224
+
225
+ out = model(data)
226
+
227
+ psnr = out['psnr']
228
+ total_psnr += psnr.detach()
229
+
230
+ # save some images
231
+ if accelerator.is_main_process:
232
+ if '4d' in opt.data_mode:
233
+ B, T, V = opt.batch_size, opt.num_frames, opt.num_views
234
+
235
+ gt_images = data['images_output'].reshape(-1, T, V, *data['images_output'].shape[2:]).detach() # [B, V, 3, output_size, output_size]
236
+ pred_images = out['images_pred'].reshape(-1, T, V, *out['images_pred'].shape[2:]).detach() # [B, V, 3, output_size, output_size]
237
+
238
+ eval_gt_images = []
239
+ eval_pred_images = []
240
+ for t in range(T):
241
+ eval_gt_images_V = []
242
+ eval_pred_images_V = []
243
+ for v in range(V):
244
+ eval_gt_images_V.append((gt_images[:, t, v].permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
245
+ eval_pred_images_V.append((pred_images[:, t, v].permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
246
+ eval_gt_images.append(np.concatenate(eval_gt_images_V, axis=2))
247
+ eval_pred_images.append(np.concatenate(eval_pred_images_V, axis=2))
248
+ eval_gt_images = np.concatenate(eval_gt_images, axis=0)
249
+ eval_pred_images = np.concatenate(eval_pred_images, axis=0)
250
+ imageio.mimwrite(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.mp4', eval_gt_images, fps=8)
251
+ imageio.mimwrite(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.mp4', eval_pred_images, fps=8)
252
+
253
+ elif '3d' in opt.data_mode:
254
+ gt_images = data['images_output'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
255
+ gt_images = gt_images.transpose(0, 3, 1, 4, 2).reshape(-1, gt_images.shape[1] * gt_images.shape[3], 3) # [B*output_size, V*output_size, 3]
256
+ kiui.write_image(f'{opt.workspace}/eval_gt_images_{epoch}_{i}.jpg', gt_images)
257
+
258
+ pred_images = out['images_pred'].detach().cpu().numpy() # [B, V, 3, output_size, output_size]
259
+ pred_images = pred_images.transpose(0, 3, 1, 4, 2).reshape(-1, pred_images.shape[1] * pred_images.shape[3], 3)
260
+ kiui.write_image(f'{opt.workspace}/eval_pred_images_{epoch}_{i}.jpg', pred_images)
261
+ else:
262
+ raise NotImplementedError
263
+
264
+ torch.cuda.empty_cache()
265
+
266
+ total_psnr = accelerator.gather_for_metrics(total_psnr).mean()
267
+ if accelerator.is_main_process:
268
+ total_psnr /= len(test_dataloader)
269
+ accelerator.print(f"[eval] epoch: {epoch} psnr: {psnr:.4f}")
270
+
271
+
272
+
273
+ if __name__ == "__main__":
274
+ main()
mvdream/mv_unet.py ADDED
@@ -0,0 +1,1005 @@
1
+ import math
2
+ import numpy as np
3
+ from inspect import isfunction
4
+ from typing import Optional, Any, List
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from diffusers.configuration_utils import ConfigMixin
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+
14
+ # require xformers!
15
+ import xformers
16
+ import xformers.ops
17
+
18
+ from kiui.cam import orbit_camera
19
+
20
+ def get_camera(
21
+ num_frames, elevation=0, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
22
+ ):
23
+ angle_gap = azimuth_span / num_frames
24
+ cameras = []
25
+ for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
26
+
27
+ pose = orbit_camera(elevation, azimuth, radius=1) # [4, 4]
28
+
29
+ # opengl to blender
30
+ if blender_coord:
31
+ pose[2] *= -1
32
+ pose[[1, 2]] = pose[[2, 1]]
33
+
34
+ cameras.append(pose.flatten())
35
+
36
+ if extra_view:
37
+ cameras.append(np.zeros_like(cameras[0]))
38
+
39
+ return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
40
+
41
+
42
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
43
+ """
44
+ Create sinusoidal timestep embeddings.
45
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
46
+ These may be fractional.
47
+ :param dim: the dimension of the output.
48
+ :param max_period: controls the minimum frequency of the embeddings.
49
+ :return: an [N x dim] Tensor of positional embeddings.
50
+ """
51
+ if not repeat_only:
52
+ half = dim // 2
53
+ freqs = torch.exp(
54
+ -math.log(max_period)
55
+ * torch.arange(start=0, end=half, dtype=torch.float32)
56
+ / half
57
+ ).to(device=timesteps.device)
58
+ args = timesteps[:, None] * freqs[None]
59
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
60
+ if dim % 2:
61
+ embedding = torch.cat(
62
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
63
+ )
64
+ else:
65
+ embedding = repeat(timesteps, "b -> b d", d=dim)
66
+ # import pdb; pdb.set_trace()
67
+ return embedding
68
+
69
+
70
+ def zero_module(module):
71
+ """
72
+ Zero out the parameters of a module and return it.
73
+ """
74
+ for p in module.parameters():
75
+ p.detach().zero_()
76
+ return module
77
+
78
+
79
+ def conv_nd(dims, *args, **kwargs):
80
+ """
81
+ Create a 1D, 2D, or 3D convolution module.
82
+ """
83
+ if dims == 1:
84
+ return nn.Conv1d(*args, **kwargs)
85
+ elif dims == 2:
86
+ return nn.Conv2d(*args, **kwargs)
87
+ elif dims == 3:
88
+ return nn.Conv3d(*args, **kwargs)
89
+ raise ValueError(f"unsupported dimensions: {dims}")
90
+
91
+
92
+ def avg_pool_nd(dims, *args, **kwargs):
93
+ """
94
+ Create a 1D, 2D, or 3D average pooling module.
95
+ """
96
+ if dims == 1:
97
+ return nn.AvgPool1d(*args, **kwargs)
98
+ elif dims == 2:
99
+ return nn.AvgPool2d(*args, **kwargs)
100
+ elif dims == 3:
101
+ return nn.AvgPool3d(*args, **kwargs)
102
+ raise ValueError(f"unsupported dimensions: {dims}")
103
+
104
+
105
+ def default(val, d):
106
+ if val is not None:
107
+ return val
108
+ return d() if isfunction(d) else d
109
+
110
+
111
+ class GEGLU(nn.Module):
112
+ def __init__(self, dim_in, dim_out):
113
+ super().__init__()
114
+ self.proj = nn.Linear(dim_in, dim_out * 2)
115
+
116
+ def forward(self, x):
117
+ x, gate = self.proj(x).chunk(2, dim=-1)
118
+ return x * F.gelu(gate)
119
+
120
+
121
+ class FeedForward(nn.Module):
122
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
123
+ super().__init__()
124
+ inner_dim = int(dim * mult)
125
+ dim_out = default(dim_out, dim)
126
+ project_in = (
127
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
128
+ if not glu
129
+ else GEGLU(dim, inner_dim)
130
+ )
131
+
132
+ self.net = nn.Sequential(
133
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
134
+ )
135
+
136
+ def forward(self, x):
137
+ return self.net(x)
138
+
139
+
140
+ class MemoryEfficientCrossAttention(nn.Module):
141
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
142
+ def __init__(
143
+ self,
144
+ query_dim,
145
+ context_dim=None,
146
+ heads=8,
147
+ dim_head=64,
148
+ dropout=0.0,
149
+ ip_dim=0,
150
+ ip_weight=1,
151
+ ):
152
+ super().__init__()
153
+
154
+ inner_dim = dim_head * heads
155
+ context_dim = default(context_dim, query_dim)
156
+
157
+ self.heads = heads
158
+ self.dim_head = dim_head
159
+
160
+ self.ip_dim = ip_dim
161
+ self.ip_weight = ip_weight
162
+
163
+ if self.ip_dim > 0:
164
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
165
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
166
+
167
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
168
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
169
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
170
+
171
+ self.to_out = nn.Sequential(
172
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
173
+ )
174
+ self.attention_op: Optional[Any] = None
175
+
176
+ def forward(self, x, context=None):
177
+ q = self.to_q(x)
178
+ context = default(context, x)
179
+
180
+ if self.ip_dim > 0:
181
+ # context: [B, 77 + 16(ip), 1024]
182
+ token_len = context.shape[1]
183
+ context_ip = context[:, -self.ip_dim :, :]
184
+ k_ip = self.to_k_ip(context_ip)
185
+ v_ip = self.to_v_ip(context_ip)
186
+ context = context[:, : (token_len - self.ip_dim), :]
187
+
188
+ k = self.to_k(context)
189
+ v = self.to_v(context)
190
+
191
+ b, _, _ = q.shape
192
+ q, k, v = map(
193
+ lambda t: t.unsqueeze(3)
194
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
195
+ .permute(0, 2, 1, 3)
196
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
197
+ .contiguous(),
198
+ (q, k, v),
199
+ )
200
+
201
+ # actually compute the attention, what we cannot get enough of
202
+ out = xformers.ops.memory_efficient_attention(
203
+ q, k, v, attn_bias=None, op=self.attention_op
204
+ )
205
+
206
+ if self.ip_dim > 0:
207
+ k_ip, v_ip = map(
208
+ lambda t: t.unsqueeze(3)
209
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
210
+ .permute(0, 2, 1, 3)
211
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
212
+ .contiguous(),
213
+ (k_ip, v_ip),
214
+ )
215
+ # actually compute the attention, what we cannot get enough of
216
+ out_ip = xformers.ops.memory_efficient_attention(
217
+ q, k_ip, v_ip, attn_bias=None, op=self.attention_op
218
+ )
219
+ out = out + self.ip_weight * out_ip
220
+
221
+ out = (
222
+ out.unsqueeze(0)
223
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
224
+ .permute(0, 2, 1, 3)
225
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
226
+ )
227
+ return self.to_out(out)
228
+
229
+
230
+ class BasicTransformerBlock3D(nn.Module):
231
+
232
+ def __init__(
233
+ self,
234
+ dim,
235
+ n_heads,
236
+ d_head,
237
+ context_dim,
238
+ dropout=0.0,
239
+ gated_ff=True,
240
+ ip_dim=0,
241
+ ip_weight=1,
242
+ ):
243
+ super().__init__()
244
+
245
+ self.attn1 = MemoryEfficientCrossAttention(
246
+ query_dim=dim,
247
+ context_dim=None, # self-attention
248
+ heads=n_heads,
249
+ dim_head=d_head,
250
+ dropout=dropout,
251
+ )
252
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
253
+ self.attn2 = MemoryEfficientCrossAttention(
254
+ query_dim=dim,
255
+ context_dim=context_dim,
256
+ heads=n_heads,
257
+ dim_head=d_head,
258
+ dropout=dropout,
259
+ # ip only applies to cross-attention
260
+ ip_dim=ip_dim,
261
+ ip_weight=ip_weight,
262
+ )
263
+ self.norm1 = nn.LayerNorm(dim)
264
+ self.norm2 = nn.LayerNorm(dim)
265
+ self.norm3 = nn.LayerNorm(dim)
266
+
267
+ def forward(self, x, context=None, num_frames=1):
268
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
269
+ x = self.attn1(self.norm1(x), context=None) + x
270
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
271
+ x = self.attn2(self.norm2(x), context=context) + x
272
+ x = self.ff(self.norm3(x)) + x
273
+ return x
274
+
275
+
276
+ class SpatialTransformer3D(nn.Module):
277
+
278
+ def __init__(
279
+ self,
280
+ in_channels,
281
+ n_heads,
282
+ d_head,
283
+ context_dim, # cross attention input dim
284
+ depth=1,
285
+ dropout=0.0,
286
+ ip_dim=0,
287
+ ip_weight=1,
288
+ ):
289
+ super().__init__()
290
+
291
+ if not isinstance(context_dim, list):
292
+ context_dim = [context_dim]
293
+
294
+ self.in_channels = in_channels
295
+
296
+ inner_dim = n_heads * d_head
297
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
298
+ self.proj_in = nn.Linear(in_channels, inner_dim)
299
+
300
+ self.transformer_blocks = nn.ModuleList(
301
+ [
302
+ BasicTransformerBlock3D(
303
+ inner_dim,
304
+ n_heads,
305
+ d_head,
306
+ context_dim=context_dim[d],
307
+ dropout=dropout,
308
+ ip_dim=ip_dim,
309
+ ip_weight=ip_weight,
310
+ )
311
+ for d in range(depth)
312
+ ]
313
+ )
314
+
315
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
316
+
317
+
318
+ def forward(self, x, context=None, num_frames=1):
319
+ # note: if no context is given, cross-attention defaults to self-attention
320
+ if not isinstance(context, list):
321
+ context = [context]
322
+ b, c, h, w = x.shape
323
+ x_in = x
324
+ x = self.norm(x)
325
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
326
+ x = self.proj_in(x)
327
+ for i, block in enumerate(self.transformer_blocks):
328
+ x = block(x, context=context[i], num_frames=num_frames)
329
+ x = self.proj_out(x)
330
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
331
+
332
+ return x + x_in
333
+
334
+
335
+ class PerceiverAttention(nn.Module):
336
+ def __init__(self, *, dim, dim_head=64, heads=8):
337
+ super().__init__()
338
+ self.scale = dim_head ** -0.5
339
+ self.dim_head = dim_head
340
+ self.heads = heads
341
+ inner_dim = dim_head * heads
342
+
343
+ self.norm1 = nn.LayerNorm(dim)
344
+ self.norm2 = nn.LayerNorm(dim)
345
+
346
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
347
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
348
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
349
+
350
+ def forward(self, x, latents):
351
+ """
352
+ Args:
353
+ x (torch.Tensor): image features
354
+ shape (b, n1, D)
355
+ latent (torch.Tensor): latent features
356
+ shape (b, n2, D)
357
+ """
358
+ x = self.norm1(x)
359
+ latents = self.norm2(latents)
360
+
361
+ b, l, _ = latents.shape
362
+
363
+ q = self.to_q(latents)
364
+ kv_input = torch.cat((x, latents), dim=-2)
365
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
366
+
367
+ q, k, v = map(
368
+ lambda t: t.reshape(b, t.shape[1], self.heads, -1)
369
+ .transpose(1, 2)
370
+ .reshape(b, self.heads, t.shape[1], -1)
371
+ .contiguous(),
372
+ (q, k, v),
373
+ )
374
+
375
+ # attention
376
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
377
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
378
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
379
+ out = weight @ v
380
+
381
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
382
+
383
+ return self.to_out(out)
384
+
385
+
386
+ class Resampler(nn.Module):
387
+ def __init__(
388
+ self,
389
+ dim=1024,
390
+ depth=8,
391
+ dim_head=64,
392
+ heads=16,
393
+ num_queries=8,
394
+ embedding_dim=768,
395
+ output_dim=1024,
396
+ ff_mult=4,
397
+ ):
398
+ super().__init__()
399
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
400
+ self.proj_in = nn.Linear(embedding_dim, dim)
401
+ self.proj_out = nn.Linear(dim, output_dim)
402
+ self.norm_out = nn.LayerNorm(output_dim)
403
+
404
+ self.layers = nn.ModuleList([])
405
+ for _ in range(depth):
406
+ self.layers.append(
407
+ nn.ModuleList(
408
+ [
409
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
410
+ nn.Sequential(
411
+ nn.LayerNorm(dim),
412
+ nn.Linear(dim, dim * ff_mult, bias=False),
413
+ nn.GELU(),
414
+ nn.Linear(dim * ff_mult, dim, bias=False),
415
+ )
416
+ ]
417
+ )
418
+ )
419
+
420
+ def forward(self, x):
421
+ latents = self.latents.repeat(x.size(0), 1, 1)
422
+ x = self.proj_in(x)
423
+ for attn, ff in self.layers:
424
+ latents = attn(x, latents) + latents
425
+ latents = ff(latents) + latents
426
+
427
+ latents = self.proj_out(latents)
428
+ return self.norm_out(latents)
429
+
430
+
431
+ class CondSequential(nn.Sequential):
432
+ """
433
+ A sequential module that passes timestep embeddings to the children that
434
+ support it as an extra input.
435
+ """
436
+
437
+ def forward(self, x, emb, context=None, num_frames=1):
438
+ for layer in self:
439
+ if isinstance(layer, ResBlock):
440
+ x = layer(x, emb)
441
+ elif isinstance(layer, SpatialTransformer3D):
442
+ x = layer(x, context, num_frames=num_frames)
443
+ else:
444
+ x = layer(x)
445
+ return x
446
+
447
+
448
+ class Upsample(nn.Module):
449
+ """
450
+ An upsampling layer with an optional convolution.
451
+ :param channels: channels in the inputs and outputs.
452
+ :param use_conv: a bool determining if a convolution is applied.
453
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
454
+ upsampling occurs in the inner-two dimensions.
455
+ """
456
+
457
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
458
+ super().__init__()
459
+ self.channels = channels
460
+ self.out_channels = out_channels or channels
461
+ self.use_conv = use_conv
462
+ self.dims = dims
463
+ if use_conv:
464
+ self.conv = conv_nd(
465
+ dims, self.channels, self.out_channels, 3, padding=padding
466
+ )
467
+
468
+ def forward(self, x):
469
+ assert x.shape[1] == self.channels
470
+ if self.dims == 3:
471
+ x = F.interpolate(
472
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
473
+ )
474
+ else:
475
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
476
+ if self.use_conv:
477
+ x = self.conv(x)
478
+ return x
479
+
480
+
481
+ class Downsample(nn.Module):
482
+ """
483
+ A downsampling layer with an optional convolution.
484
+ :param channels: channels in the inputs and outputs.
485
+ :param use_conv: a bool determining if a convolution is applied.
486
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
487
+ downsampling occurs in the inner-two dimensions.
488
+ """
489
+
490
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
491
+ super().__init__()
492
+ self.channels = channels
493
+ self.out_channels = out_channels or channels
494
+ self.use_conv = use_conv
495
+ self.dims = dims
496
+ stride = 2 if dims != 3 else (1, 2, 2)
497
+ if use_conv:
498
+ self.op = conv_nd(
499
+ dims,
500
+ self.channels,
501
+ self.out_channels,
502
+ 3,
503
+ stride=stride,
504
+ padding=padding,
505
+ )
506
+ else:
507
+ assert self.channels == self.out_channels
508
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
509
+
510
+ def forward(self, x):
511
+ assert x.shape[1] == self.channels
512
+ return self.op(x)
513
+
514
+
515
+ class ResBlock(nn.Module):
516
+ """
517
+ A residual block that can optionally change the number of channels.
518
+ :param channels: the number of input channels.
519
+ :param emb_channels: the number of timestep embedding channels.
520
+ :param dropout: the rate of dropout.
521
+ :param out_channels: if specified, the number of out channels.
522
+ :param use_conv: if True and out_channels is specified, use a spatial
523
+ convolution instead of a smaller 1x1 convolution to change the
524
+ channels in the skip connection.
525
+ :param dims: determines if the signal is 1D, 2D, or 3D.
526
+ :param up: if True, use this block for upsampling.
527
+ :param down: if True, use this block for downsampling.
528
+ """
529
+
530
+ def __init__(
531
+ self,
532
+ channels,
533
+ emb_channels,
534
+ dropout,
535
+ out_channels=None,
536
+ use_conv=False,
537
+ use_scale_shift_norm=False,
538
+ dims=2,
539
+ up=False,
540
+ down=False,
541
+ ):
542
+ super().__init__()
543
+ self.channels = channels
544
+ self.emb_channels = emb_channels
545
+ self.dropout = dropout
546
+ self.out_channels = out_channels or channels
547
+ self.use_conv = use_conv
548
+ self.use_scale_shift_norm = use_scale_shift_norm
549
+
550
+ self.in_layers = nn.Sequential(
551
+ nn.GroupNorm(32, channels),
552
+ nn.SiLU(),
553
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
554
+ )
555
+
556
+ self.updown = up or down
557
+
558
+ if up:
559
+ self.h_upd = Upsample(channels, False, dims)
560
+ self.x_upd = Upsample(channels, False, dims)
561
+ elif down:
562
+ self.h_upd = Downsample(channels, False, dims)
563
+ self.x_upd = Downsample(channels, False, dims)
564
+ else:
565
+ self.h_upd = self.x_upd = nn.Identity()
566
+
567
+ self.emb_layers = nn.Sequential(
568
+ nn.SiLU(),
569
+ nn.Linear(
570
+ emb_channels,
571
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
572
+ ),
573
+ )
574
+ self.out_layers = nn.Sequential(
575
+ nn.GroupNorm(32, self.out_channels),
576
+ nn.SiLU(),
577
+ nn.Dropout(p=dropout),
578
+ zero_module(
579
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
580
+ ),
581
+ )
582
+
583
+ if self.out_channels == channels:
584
+ self.skip_connection = nn.Identity()
585
+ elif use_conv:
586
+ self.skip_connection = conv_nd(
587
+ dims, channels, self.out_channels, 3, padding=1
588
+ )
589
+ else:
590
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
591
+
592
+ def forward(self, x, emb):
593
+ if self.updown:
594
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
595
+ h = in_rest(x)
596
+ h = self.h_upd(h)
597
+ x = self.x_upd(x)
598
+ h = in_conv(h)
599
+ else:
600
+ h = self.in_layers(x)
601
+ emb_out = self.emb_layers(emb).type(h.dtype)
602
+ while len(emb_out.shape) < len(h.shape):
603
+ emb_out = emb_out[..., None]
604
+ if self.use_scale_shift_norm:
605
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
606
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
607
+ h = out_norm(h) * (1 + scale) + shift
608
+ h = out_rest(h)
609
+ else:
610
+ h = h + emb_out
611
+ h = self.out_layers(h)
612
+ return self.skip_connection(x) + h
613
+
614
+
615
+ class MultiViewUNetModel(ModelMixin, ConfigMixin):
616
+ """
617
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
618
+ :param in_channels: channels in the input Tensor.
619
+ :param model_channels: base channel count for the model.
620
+ :param out_channels: channels in the output Tensor.
621
+ :param num_res_blocks: number of residual blocks per downsample.
622
+ :param attention_resolutions: a collection of downsample rates at which
623
+ attention will take place. May be a set, list, or tuple.
624
+ For example, if this contains 4, then at 4x downsampling, attention
625
+ will be used.
626
+ :param dropout: the dropout probability.
627
+ :param channel_mult: channel multiplier for each level of the UNet.
628
+ :param conv_resample: if True, use learned convolutions for upsampling and
629
+ downsampling.
630
+ :param dims: determines if the signal is 1D, 2D, or 3D.
631
+ :param num_classes: if specified (as an int), then this model will be
632
+ class-conditional with `num_classes` classes.
633
+ :param num_heads: the number of attention heads in each attention layer.
634
+ :param num_heads_channels: if specified, ignore num_heads and instead use
635
+ a fixed channel width per attention head.
636
+ :param num_heads_upsample: works with num_heads to set a different number
637
+ of heads for upsampling. Deprecated.
638
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
639
+ :param resblock_updown: use residual blocks for up/downsampling.
640
+ :param use_new_attention_order: use a different attention pattern for potentially
641
+ increased efficiency.
642
+ :param camera_dim: dimensionality of camera input.
643
+ """
644
+
645
+ def __init__(
646
+ self,
647
+ image_size,
648
+ in_channels,
649
+ model_channels,
650
+ out_channels,
651
+ num_res_blocks,
652
+ attention_resolutions,
653
+ dropout=0,
654
+ channel_mult=(1, 2, 4, 8),
655
+ conv_resample=True,
656
+ dims=2,
657
+ num_classes=None,
658
+ num_heads=-1,
659
+ num_head_channels=-1,
660
+ num_heads_upsample=-1,
661
+ use_scale_shift_norm=False,
662
+ resblock_updown=False,
663
+ transformer_depth=1,
664
+ context_dim=None,
665
+ n_embed=None,
666
+ num_attention_blocks=None,
667
+ adm_in_channels=None,
668
+ camera_dim=None,
669
+ ip_dim=0, # imagedream uses ip_dim > 0
670
+ ip_weight=1.0,
671
+ **kwargs,
672
+ ):
673
+ super().__init__()
674
+ assert context_dim is not None
675
+
676
+ if num_heads_upsample == -1:
677
+ num_heads_upsample = num_heads
678
+
679
+ if num_heads == -1:
680
+ assert (
681
+ num_head_channels != -1
682
+ ), "Either num_heads or num_head_channels has to be set"
683
+
684
+ if num_head_channels == -1:
685
+ assert (
686
+ num_heads != -1
687
+ ), "Either num_heads or num_head_channels has to be set"
688
+
689
+ self.image_size = image_size
690
+ self.in_channels = in_channels
691
+ self.model_channels = model_channels
692
+ self.out_channels = out_channels
693
+ if isinstance(num_res_blocks, int):
694
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
695
+ else:
696
+ if len(num_res_blocks) != len(channel_mult):
697
+ raise ValueError(
698
+ "provide num_res_blocks either as an int (globally constant) or "
699
+ "as a list/tuple (per-level) with the same length as channel_mult"
700
+ )
701
+ self.num_res_blocks = num_res_blocks
702
+
703
+ if num_attention_blocks is not None:
704
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
705
+ assert all(
706
+ map(
707
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
708
+ range(len(num_attention_blocks)),
709
+ )
710
+ )
711
+ print(
712
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
713
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
714
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
715
+ f"attention will still not be set."
716
+ )
717
+
718
+ self.attention_resolutions = attention_resolutions
719
+ self.dropout = dropout
720
+ self.channel_mult = channel_mult
721
+ self.conv_resample = conv_resample
722
+ self.num_classes = num_classes
723
+ self.num_heads = num_heads
724
+ self.num_head_channels = num_head_channels
725
+ self.num_heads_upsample = num_heads_upsample
726
+ self.predict_codebook_ids = n_embed is not None
727
+
728
+ self.ip_dim = ip_dim
729
+ self.ip_weight = ip_weight
730
+
731
+ if self.ip_dim > 0:
732
+ self.image_embed = Resampler(
733
+ dim=context_dim,
734
+ depth=4,
735
+ dim_head=64,
736
+ heads=12,
737
+ num_queries=ip_dim, # num token
738
+ embedding_dim=1280,
739
+ output_dim=context_dim,
740
+ ff_mult=4,
741
+ )
742
+
743
+ time_embed_dim = model_channels * 4
744
+ self.time_embed = nn.Sequential(
745
+ nn.Linear(model_channels, time_embed_dim),
746
+ nn.SiLU(),
747
+ nn.Linear(time_embed_dim, time_embed_dim),
748
+ )
749
+
750
+ if camera_dim is not None:
751
+ time_embed_dim = model_channels * 4
752
+ self.camera_embed = nn.Sequential(
753
+ nn.Linear(camera_dim, time_embed_dim),
754
+ nn.SiLU(),
755
+ nn.Linear(time_embed_dim, time_embed_dim),
756
+ )
757
+
758
+ if self.num_classes is not None:
759
+ if isinstance(self.num_classes, int):
760
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
761
+ elif self.num_classes == "continuous":
762
+ # print("setting up linear c_adm embedding layer")
763
+ self.label_emb = nn.Linear(1, time_embed_dim)
764
+ elif self.num_classes == "sequential":
765
+ assert adm_in_channels is not None
766
+ self.label_emb = nn.Sequential(
767
+ nn.Sequential(
768
+ nn.Linear(adm_in_channels, time_embed_dim),
769
+ nn.SiLU(),
770
+ nn.Linear(time_embed_dim, time_embed_dim),
771
+ )
772
+ )
773
+ else:
774
+ raise ValueError()
775
+
776
+ self.input_blocks = nn.ModuleList(
777
+ [
778
+ CondSequential(
779
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
780
+ )
781
+ ]
782
+ )
783
+ self._feature_size = model_channels
784
+ input_block_chans = [model_channels]
785
+ ch = model_channels
786
+ ds = 1
787
+ for level, mult in enumerate(channel_mult):
788
+ for nr in range(self.num_res_blocks[level]):
789
+ layers: List[Any] = [
790
+ ResBlock(
791
+ ch,
792
+ time_embed_dim,
793
+ dropout,
794
+ out_channels=mult * model_channels,
795
+ dims=dims,
796
+ use_scale_shift_norm=use_scale_shift_norm,
797
+ )
798
+ ]
799
+ ch = mult * model_channels
800
+ if ds in attention_resolutions:
801
+ if num_head_channels == -1:
802
+ dim_head = ch // num_heads
803
+ else:
804
+ num_heads = ch // num_head_channels
805
+ dim_head = num_head_channels
806
+
807
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
808
+ layers.append(
809
+ SpatialTransformer3D(
810
+ ch,
811
+ num_heads,
812
+ dim_head,
813
+ context_dim=context_dim,
814
+ depth=transformer_depth,
815
+ ip_dim=self.ip_dim,
816
+ ip_weight=self.ip_weight,
817
+ )
818
+ )
819
+ self.input_blocks.append(CondSequential(*layers))
820
+ self._feature_size += ch
821
+ input_block_chans.append(ch)
822
+ if level != len(channel_mult) - 1:
823
+ out_ch = ch
824
+ self.input_blocks.append(
825
+ CondSequential(
826
+ ResBlock(
827
+ ch,
828
+ time_embed_dim,
829
+ dropout,
830
+ out_channels=out_ch,
831
+ dims=dims,
832
+ use_scale_shift_norm=use_scale_shift_norm,
833
+ down=True,
834
+ )
835
+ if resblock_updown
836
+ else Downsample(
837
+ ch, conv_resample, dims=dims, out_channels=out_ch
838
+ )
839
+ )
840
+ )
841
+ ch = out_ch
842
+ input_block_chans.append(ch)
843
+ ds *= 2
844
+ self._feature_size += ch
845
+
846
+ if num_head_channels == -1:
847
+ dim_head = ch // num_heads
848
+ else:
849
+ num_heads = ch // num_head_channels
850
+ dim_head = num_head_channels
851
+
852
+ self.middle_block = CondSequential(
853
+ ResBlock(
854
+ ch,
855
+ time_embed_dim,
856
+ dropout,
857
+ dims=dims,
858
+ use_scale_shift_norm=use_scale_shift_norm,
859
+ ),
860
+ SpatialTransformer3D(
861
+ ch,
862
+ num_heads,
863
+ dim_head,
864
+ context_dim=context_dim,
865
+ depth=transformer_depth,
866
+ ip_dim=self.ip_dim,
867
+ ip_weight=self.ip_weight,
868
+ ),
869
+ ResBlock(
870
+ ch,
871
+ time_embed_dim,
872
+ dropout,
873
+ dims=dims,
874
+ use_scale_shift_norm=use_scale_shift_norm,
875
+ ),
876
+ )
877
+ self._feature_size += ch
878
+
879
+ self.output_blocks = nn.ModuleList([])
880
+ for level, mult in list(enumerate(channel_mult))[::-1]:
881
+ for i in range(self.num_res_blocks[level] + 1):
882
+ ich = input_block_chans.pop()
883
+ layers = [
884
+ ResBlock(
885
+ ch + ich,
886
+ time_embed_dim,
887
+ dropout,
888
+ out_channels=model_channels * mult,
889
+ dims=dims,
890
+ use_scale_shift_norm=use_scale_shift_norm,
891
+ )
892
+ ]
893
+ ch = model_channels * mult
894
+ if ds in attention_resolutions:
895
+ if num_head_channels == -1:
896
+ dim_head = ch // num_heads
897
+ else:
898
+ num_heads = ch // num_head_channels
899
+ dim_head = num_head_channels
900
+
901
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
902
+ layers.append(
903
+ SpatialTransformer3D(
904
+ ch,
905
+ num_heads,
906
+ dim_head,
907
+ context_dim=context_dim,
908
+ depth=transformer_depth,
909
+ ip_dim=self.ip_dim,
910
+ ip_weight=self.ip_weight,
911
+ )
912
+ )
913
+ if level and i == self.num_res_blocks[level]:
914
+ out_ch = ch
915
+ layers.append(
916
+ ResBlock(
917
+ ch,
918
+ time_embed_dim,
919
+ dropout,
920
+ out_channels=out_ch,
921
+ dims=dims,
922
+ use_scale_shift_norm=use_scale_shift_norm,
923
+ up=True,
924
+ )
925
+ if resblock_updown
926
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
927
+ )
928
+ ds //= 2
929
+ self.output_blocks.append(CondSequential(*layers))
930
+ self._feature_size += ch
931
+
932
+ self.out = nn.Sequential(
933
+ nn.GroupNorm(32, ch),
934
+ nn.SiLU(),
935
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
936
+ )
937
+ if self.predict_codebook_ids:
938
+ self.id_predictor = nn.Sequential(
939
+ nn.GroupNorm(32, ch),
940
+ conv_nd(dims, model_channels, n_embed, 1),
941
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
942
+ )
943
+
944
+ def forward(
945
+ self,
946
+ x,
947
+ timesteps=None,
948
+ context=None,
949
+ y=None,
950
+ camera=None,
951
+ num_frames=1,
952
+ ip=None,
953
+ ip_img=None,
954
+ **kwargs,
955
+ ):
956
+ """
957
+ Apply the model to an input batch.
958
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
959
+ :param timesteps: a 1-D batch of timesteps.
960
+ :param context: conditioning plugged in via crossattn
961
+ :param y: an [N] Tensor of labels, if class-conditional.
962
+ :param num_frames: an integer indicating the number of frames used for tensor reshaping.
963
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
964
+ """
965
+ assert (
966
+ x.shape[0] % num_frames == 0
967
+ ), "input batch size must be dividable by num_frames!"
968
+ assert (y is not None) == (
969
+ self.num_classes is not None
970
+ ), "must specify y if and only if the model is class-conditional"
971
+
972
+ hs = []
973
+
974
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
975
+
976
+ emb = self.time_embed(t_emb)
977
+
978
+ if self.num_classes is not None:
979
+ assert y is not None
980
+ assert y.shape[0] == x.shape[0]
981
+ emb = emb + self.label_emb(y)
982
+
983
+ # Add camera embeddings
984
+ if camera is not None:
985
+ emb = emb + self.camera_embed(camera)
986
+
987
+ # imagedream variant
988
+ if self.ip_dim > 0:
989
+ x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # overwrite the last view of each object, e.g. indices 4 and 9 when num_frames == 5
990
+ ip_emb = self.image_embed(ip)
991
+ context = torch.cat((context, ip_emb), 1)
992
+
993
+ h = x
994
+ for module in self.input_blocks:
995
+ h = module(h, emb, context, num_frames=num_frames)
996
+ hs.append(h)
997
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
998
+ for module in self.output_blocks:
999
+ h = torch.cat([h, hs.pop()], dim=1)
1000
+ h = module(h, emb, context, num_frames=num_frames)
1001
+ h = h.type(x.dtype)
1002
+ if self.predict_codebook_ids:
1003
+ return self.id_predictor(h)
1004
+ else:
1005
+ return self.out(h)
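
For reference, `forward` above expects the views of each object to be interleaved along the batch dimension (`x` is `[(N x F) x C x ...]`), and the imagedream branch overwrites the last view of every object with the reference image latent. A minimal shape-only sketch of that layout, using dummy tensors instead of the real model (all sizes here are assumptions):

```python
# Illustrative batch layout only; sizes are assumptions and no real UNet is involved.
import torch

N, F, C, H, W = 2, 5, 4, 32, 32     # N objects, F views each (imagedream: 4 views + 1 reference slot)
x = torch.randn(N * F, C, H, W)     # interleaved: [obj0_v0 .. obj0_v4, obj1_v0 .. obj1_v4]
ip_img = torch.randn(N, C, H, W)    # one reference latent per object

assert x.shape[0] % F == 0          # mirrors the divisibility check in forward()
x[(F - 1)::F] = ip_img              # fills indices 4 and 9, the last view of each object
```
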
mvdream/pipeline_mvdream.py ADDED
@@ -0,0 +1,559 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import inspect
4
+ import numpy as np
5
+ from typing import Callable, List, Optional, Union
6
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
7
+ from diffusers import AutoencoderKL, DiffusionPipeline
8
+ from diffusers.utils import (
9
+ deprecate,
10
+ is_accelerate_available,
11
+ is_accelerate_version,
12
+ logging,
13
+ )
14
+ from diffusers.configuration_utils import FrozenDict
15
+ from diffusers.schedulers import DDIMScheduler
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+
18
+ from mvdream.mv_unet import MultiViewUNetModel, get_camera
19
+
20
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
21
+
22
+
23
+ class MVDreamPipeline(DiffusionPipeline):
24
+
25
+ _optional_components = ["feature_extractor", "image_encoder"]
26
+
27
+ def __init__(
28
+ self,
29
+ vae: AutoencoderKL,
30
+ unet: MultiViewUNetModel,
31
+ tokenizer: CLIPTokenizer,
32
+ text_encoder: CLIPTextModel,
33
+ scheduler: DDIMScheduler,
34
+ # imagedream variant
35
+ feature_extractor: CLIPImageProcessor,
36
+ image_encoder: CLIPVisionModel,
37
+ requires_safety_checker: bool = False,
38
+ ):
39
+ super().__init__()
40
+
41
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore
42
+ deprecation_message = (
43
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
44
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore
45
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
46
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
47
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
48
+ " file"
49
+ )
50
+ deprecate(
51
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
52
+ )
53
+ new_config = dict(scheduler.config)
54
+ new_config["steps_offset"] = 1
55
+ scheduler._internal_dict = FrozenDict(new_config)
56
+
57
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore
58
+ deprecation_message = (
59
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
60
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
61
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
62
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
63
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
64
+ )
65
+ deprecate(
66
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
67
+ )
68
+ new_config = dict(scheduler.config)
69
+ new_config["clip_sample"] = False
70
+ scheduler._internal_dict = FrozenDict(new_config)
71
+
72
+ self.register_modules(
73
+ vae=vae,
74
+ unet=unet,
75
+ scheduler=scheduler,
76
+ tokenizer=tokenizer,
77
+ text_encoder=text_encoder,
78
+ feature_extractor=feature_extractor,
79
+ image_encoder=image_encoder,
80
+ )
81
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
82
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
83
+
84
+ def enable_vae_slicing(self):
85
+ r"""
86
+ Enable sliced VAE decoding.
87
+
88
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
89
+ steps. This is useful to save some memory and allow larger batch sizes.
90
+ """
91
+ self.vae.enable_slicing()
92
+
93
+ def disable_vae_slicing(self):
94
+ r"""
95
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
96
+ computing decoding in one step.
97
+ """
98
+ self.vae.disable_slicing()
99
+
100
+ def enable_vae_tiling(self):
101
+ r"""
102
+ Enable tiled VAE decoding.
103
+
104
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
105
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
106
+ """
107
+ self.vae.enable_tiling()
108
+
109
+ def disable_vae_tiling(self):
110
+ r"""
111
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
112
+ computing decoding in one step.
113
+ """
114
+ self.vae.disable_tiling()
115
+
116
+ def enable_sequential_cpu_offload(self, gpu_id=0):
117
+ r"""
118
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
119
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
120
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
121
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
122
+ `enable_model_cpu_offload`, but performance is lower.
123
+ """
124
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
125
+ from accelerate import cpu_offload
126
+ else:
127
+ raise ImportError(
128
+ "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
129
+ )
130
+
131
+ device = torch.device(f"cuda:{gpu_id}")
132
+
133
+ if self.device.type != "cpu":
134
+ self.to("cpu", silence_dtype_warnings=True)
135
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
136
+
137
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
138
+ cpu_offload(cpu_offloaded_model, device)
139
+
140
+ def enable_model_cpu_offload(self, gpu_id=0):
141
+ r"""
142
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
143
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
144
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
145
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
146
+ """
147
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
148
+ from accelerate import cpu_offload_with_hook
149
+ else:
150
+ raise ImportError(
151
+ "`enable_model_offload` requires `accelerate v0.17.0` or higher."
152
+ )
153
+
154
+ device = torch.device(f"cuda:{gpu_id}")
155
+
156
+ if self.device.type != "cpu":
157
+ self.to("cpu", silence_dtype_warnings=True)
158
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
159
+
160
+ hook = None
161
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
162
+ _, hook = cpu_offload_with_hook(
163
+ cpu_offloaded_model, device, prev_module_hook=hook
164
+ )
165
+
166
+ # We'll offload the last model manually.
167
+ self.final_offload_hook = hook
168
+
169
+ @property
170
+ def _execution_device(self):
171
+ r"""
172
+ Returns the device on which the pipeline's models will be executed. After calling
173
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
174
+ hooks.
175
+ """
176
+ if not hasattr(self.unet, "_hf_hook"):
177
+ return self.device
178
+ for module in self.unet.modules():
179
+ if (
180
+ hasattr(module, "_hf_hook")
181
+ and hasattr(module._hf_hook, "execution_device")
182
+ and module._hf_hook.execution_device is not None
183
+ ):
184
+ return torch.device(module._hf_hook.execution_device)
185
+ return self.device
186
+
187
+ def _encode_prompt(
188
+ self,
189
+ prompt,
190
+ device,
191
+ num_images_per_prompt,
192
+ do_classifier_free_guidance: bool,
193
+ negative_prompt=None,
194
+ ):
195
+ r"""
196
+ Encodes the prompt into text encoder hidden states.
197
+
198
+ Args:
199
+ prompt (`str` or `List[str]`, *optional*):
200
+ prompt to be encoded
201
+ device: (`torch.device`):
202
+ torch device
203
+ num_images_per_prompt (`int`):
204
+ number of images that should be generated per prompt
205
+ do_classifier_free_guidance (`bool`):
206
+ whether to use classifier free guidance or not
207
+ negative_prompt (`str` or `List[str]`, *optional*):
208
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
209
+ `negative_prompt_embeds` instead.
210
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
211
+ prompt_embeds (`torch.FloatTensor`, *optional*):
212
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
213
+ provided, text embeddings will be generated from `prompt` input argument.
214
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
215
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
216
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
217
+ argument.
218
+ """
219
+ if prompt is not None and isinstance(prompt, str):
220
+ batch_size = 1
221
+ elif prompt is not None and isinstance(prompt, list):
222
+ batch_size = len(prompt)
223
+ else:
224
+ raise ValueError(
225
+ f"`prompt` should be either a string or a list of strings, but got {type(prompt)}."
226
+ )
227
+
228
+ text_inputs = self.tokenizer(
229
+ prompt,
230
+ padding="max_length",
231
+ max_length=self.tokenizer.model_max_length,
232
+ truncation=True,
233
+ return_tensors="pt",
234
+ )
235
+ text_input_ids = text_inputs.input_ids
236
+ untruncated_ids = self.tokenizer(
237
+ prompt, padding="longest", return_tensors="pt"
238
+ ).input_ids
239
+
240
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
241
+ text_input_ids, untruncated_ids
242
+ ):
243
+ removed_text = self.tokenizer.batch_decode(
244
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
245
+ )
246
+ logger.warning(
247
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
248
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
249
+ )
250
+
251
+ if (
252
+ hasattr(self.text_encoder.config, "use_attention_mask")
253
+ and self.text_encoder.config.use_attention_mask
254
+ ):
255
+ attention_mask = text_inputs.attention_mask.to(device)
256
+ else:
257
+ attention_mask = None
258
+
259
+ prompt_embeds = self.text_encoder(
260
+ text_input_ids.to(device),
261
+ attention_mask=attention_mask,
262
+ )
263
+ prompt_embeds = prompt_embeds[0]
264
+
265
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
266
+
267
+ bs_embed, seq_len, _ = prompt_embeds.shape
268
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
269
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
270
+ prompt_embeds = prompt_embeds.view(
271
+ bs_embed * num_images_per_prompt, seq_len, -1
272
+ )
273
+
274
+ # get unconditional embeddings for classifier free guidance
275
+ if do_classifier_free_guidance:
276
+ uncond_tokens: List[str]
277
+ if negative_prompt is None:
278
+ uncond_tokens = [""] * batch_size
279
+ elif type(prompt) is not type(negative_prompt):
280
+ raise TypeError(
281
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
282
+ f" {type(prompt)}."
283
+ )
284
+ elif isinstance(negative_prompt, str):
285
+ uncond_tokens = [negative_prompt]
286
+ elif batch_size != len(negative_prompt):
287
+ raise ValueError(
288
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
289
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
290
+ " the batch size of `prompt`."
291
+ )
292
+ else:
293
+ uncond_tokens = negative_prompt
294
+
295
+ max_length = prompt_embeds.shape[1]
296
+ uncond_input = self.tokenizer(
297
+ uncond_tokens,
298
+ padding="max_length",
299
+ max_length=max_length,
300
+ truncation=True,
301
+ return_tensors="pt",
302
+ )
303
+
304
+ if (
305
+ hasattr(self.text_encoder.config, "use_attention_mask")
306
+ and self.text_encoder.config.use_attention_mask
307
+ ):
308
+ attention_mask = uncond_input.attention_mask.to(device)
309
+ else:
310
+ attention_mask = None
311
+
312
+ negative_prompt_embeds = self.text_encoder(
313
+ uncond_input.input_ids.to(device),
314
+ attention_mask=attention_mask,
315
+ )
316
+ negative_prompt_embeds = negative_prompt_embeds[0]
317
+
318
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
319
+ seq_len = negative_prompt_embeds.shape[1]
320
+
321
+ negative_prompt_embeds = negative_prompt_embeds.to(
322
+ dtype=self.text_encoder.dtype, device=device
323
+ )
324
+
325
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
326
+ 1, num_images_per_prompt, 1
327
+ )
328
+ negative_prompt_embeds = negative_prompt_embeds.view(
329
+ batch_size * num_images_per_prompt, seq_len, -1
330
+ )
331
+
332
+ # For classifier free guidance, we need to do two forward passes.
333
+ # Here we concatenate the unconditional and text embeddings into a single batch
334
+ # to avoid doing two forward passes
335
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
336
+
337
+ return prompt_embeds
338
+
339
+ def decode_latents(self, latents):
340
+ latents = 1 / self.vae.config.scaling_factor * latents
341
+ image = self.vae.decode(latents).sample
342
+ image = (image / 2 + 0.5).clamp(0, 1)
343
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
344
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
345
+ return image
346
+
347
+ def prepare_extra_step_kwargs(self, generator, eta):
348
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
349
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
350
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
351
+ # and should be between [0, 1]
352
+
353
+ accepts_eta = "eta" in set(
354
+ inspect.signature(self.scheduler.step).parameters.keys()
355
+ )
356
+ extra_step_kwargs = {}
357
+ if accepts_eta:
358
+ extra_step_kwargs["eta"] = eta
359
+
360
+ # check if the scheduler accepts generator
361
+ accepts_generator = "generator" in set(
362
+ inspect.signature(self.scheduler.step).parameters.keys()
363
+ )
364
+ if accepts_generator:
365
+ extra_step_kwargs["generator"] = generator
366
+ return extra_step_kwargs
367
+
368
+ def prepare_latents(
369
+ self,
370
+ batch_size,
371
+ num_channels_latents,
372
+ height,
373
+ width,
374
+ dtype,
375
+ device,
376
+ generator,
377
+ latents=None,
378
+ ):
379
+ shape = (
380
+ batch_size,
381
+ num_channels_latents,
382
+ height // self.vae_scale_factor,
383
+ width // self.vae_scale_factor,
384
+ )
385
+ if isinstance(generator, list) and len(generator) != batch_size:
386
+ raise ValueError(
387
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
388
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
389
+ )
390
+
391
+ if latents is None:
392
+ latents = randn_tensor(
393
+ shape, generator=generator, device=device, dtype=dtype
394
+ )
395
+ else:
396
+ latents = latents.to(device)
397
+
398
+ # scale the initial noise by the standard deviation required by the scheduler
399
+ latents = latents * self.scheduler.init_noise_sigma
400
+ return latents
401
+
402
+ def encode_image(self, image, device, num_images_per_prompt):
403
+ dtype = next(self.image_encoder.parameters()).dtype
404
+
405
+ if image.dtype == np.float32:
406
+ image = (image * 255).astype(np.uint8)
407
+
408
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
409
+ image = image.to(device=device, dtype=dtype)
410
+
411
+ image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
412
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
413
+
414
+ return torch.zeros_like(image_embeds), image_embeds
415
+
416
+ def encode_image_latents(self, image, device, num_images_per_prompt):
417
+
418
+ dtype = next(self.image_encoder.parameters()).dtype
419
+
420
+ image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device) # [1, 3, H, W]
421
+ image = 2 * image - 1
422
+ image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
423
+ image = image.to(dtype=dtype)
424
+
425
+ posterior = self.vae.encode(image).latent_dist
426
+ latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
427
+ latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
428
+
429
+ return torch.zeros_like(latents), latents
430
+
431
+ @torch.no_grad()
432
+ def __call__(
433
+ self,
434
+ prompt: str = "",
435
+ image: Optional[np.ndarray] = None,
436
+ height: int = 256,
437
+ width: int = 256,
438
+ elevation: float = 0,
439
+ num_inference_steps: int = 50,
440
+ guidance_scale: float = 7.0,
441
+ negative_prompt: str = "",
442
+ num_images_per_prompt: int = 1,
443
+ eta: float = 0.0,
444
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
445
+ output_type: Optional[str] = "numpy", # pil, numpy, latents
446
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
447
+ callback_steps: int = 1,
448
+ num_frames: int = 4,
449
+ device=torch.device("cuda:0"),
450
+ ):
451
+ self.unet = self.unet.to(device=device)
452
+ self.vae = self.vae.to(device=device)
453
+ self.text_encoder = self.text_encoder.to(device=device)
454
+
455
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
456
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
457
+ # corresponds to doing no classifier free guidance.
458
+ do_classifier_free_guidance = guidance_scale > 1.0
459
+
460
+ # Prepare timesteps
461
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
462
+ timesteps = self.scheduler.timesteps
463
+
464
+ # imagedream variant
465
+ if image is not None:
466
+ assert isinstance(image, np.ndarray) and image.dtype == np.float32
467
+ self.image_encoder = self.image_encoder.to(device=device)
468
+ image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)
469
+ image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)
470
+
471
+ _prompt_embeds = self._encode_prompt(
472
+ prompt=prompt,
473
+ device=device,
474
+ num_images_per_prompt=num_images_per_prompt,
475
+ do_classifier_free_guidance=do_classifier_free_guidance,
476
+ negative_prompt=negative_prompt,
477
+ ) # type: ignore
478
+ prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
479
+
480
+ # Prepare latent variables
481
+ actual_num_frames = num_frames if image is None else num_frames + 1
482
+ latents: torch.Tensor = self.prepare_latents(
483
+ actual_num_frames * num_images_per_prompt,
484
+ 4,
485
+ height,
486
+ width,
487
+ prompt_embeds_pos.dtype,
488
+ device,
489
+ generator,
490
+ None,
491
+ )
492
+
493
+ if image is not None:
494
+ camera = get_camera(num_frames, elevation=elevation, extra_view=True).to(dtype=latents.dtype, device=device)
495
+ else:
496
+ camera = get_camera(num_frames, elevation=elevation, extra_view=False).to(dtype=latents.dtype, device=device)
497
+ camera = camera.repeat_interleave(num_images_per_prompt, dim=0)
498
+
499
+ # Prepare extra step kwargs.
500
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
501
+
502
+ # Denoising loop
503
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
504
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
505
+ for i, t in enumerate(timesteps):
506
+ # expand the latents if we are doing classifier free guidance
507
+ multiplier = 2 if do_classifier_free_guidance else 1
508
+ latent_model_input = torch.cat([latents] * multiplier)
509
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
510
+
511
+ unet_inputs = {
512
+ 'x': latent_model_input,
513
+ 'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device),
514
+ 'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames),
515
+ 'num_frames': actual_num_frames,
516
+ 'camera': torch.cat([camera] * multiplier),
517
+ }
518
+
519
+ if image is not None:
520
+ unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames)
521
+ unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat
522
+
523
+ # predict the noise residual
524
+ noise_pred = self.unet.forward(**unet_inputs)
525
+
526
+ # perform guidance
527
+ if do_classifier_free_guidance:
528
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
529
+ noise_pred = noise_pred_uncond + guidance_scale * (
530
+ noise_pred_text - noise_pred_uncond
531
+ )
532
+
533
+ # compute the previous noisy sample x_t -> x_t-1
534
+ latents: torch.Tensor = self.scheduler.step(
535
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
536
+ )[0]
537
+
538
+ # call the callback, if provided
539
+ if i == len(timesteps) - 1 or (
540
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
541
+ ):
542
+ progress_bar.update()
543
+ if callback is not None and i % callback_steps == 0:
544
+ callback(i, t, latents) # type: ignore
545
+
546
+ # Post-processing
547
+ if output_type == "latent":
548
+ image = latents
549
+ elif output_type == "pil":
550
+ image = self.decode_latents(latents)
551
+ image = self.numpy_to_pil(image)
552
+ else: # numpy
553
+ image = self.decode_latents(latents)
554
+
555
+ # Offload last model to CPU
556
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
557
+ self.final_offload_hook.offload()
558
+
559
+ return image
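
For orientation, here is a minimal text-to-multiview usage sketch of `MVDreamPipeline`. The checkpoint identifier is an assumption (any diffusers-format MVDream checkpoint would do), not something shipped in this repository:

```python
# Hypothetical usage sketch; the checkpoint path below is an assumption.
import torch
from mvdream.pipeline_mvdream import MVDreamPipeline

pipe = MVDreamPipeline.from_pretrained(
    "path/to/a-diffusers-format-mvdream-checkpoint",  # assumption, not part of this repo
    torch_dtype=torch.float16,
)

views = pipe(
    prompt="a wooden toy sailboat",
    negative_prompt="",
    num_frames=4,              # four surrounding views
    num_inference_steps=30,
    guidance_scale=7.0,
    output_type="numpy",       # float array of shape [4, 256, 256, 3]
)
```

For the imagedream variant, additionally pass `image` as a float32 RGB array in [0, 1]; the pipeline then denoises `num_frames + 1` views, with the extra view tied to the reference image.
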
readme.md ADDED
@@ -0,0 +1,77 @@
1
+
2
+ ## L4GM: Large 4D Gaussian Reconstruction Model
3
+ <p align="center">
4
+ <img src="assets/teaser.jpg">
5
+ </p>
6
+
7
+ [**Paper**](https://arxiv.org/abs/2406.10324) | [**Project Page**](https://research.nvidia.com/labs/toronto-ai/l4gm/) | [**Model Weights**](https://huggingface.co/jiawei011/L4GM)
8
+
9
+ We present L4GM, the first 4D Large Reconstruction Model that produces animated objects from a single-view video input -- in a single feed-forward pass that takes only a second.
10
+
11
+ ---
12
+
13
+ ### Install
14
+ ```bash
15
+ conda env create -f environment.yml
16
+ conda activate l4gm
17
+ ```
18
+
19
+ ### Inference
20
+ Download pretrained [L4GM model](https://huggingface.co/jiawei011/L4GM/blob/main/recon.safetensors) and [4D interpolation model](https://huggingface.co/jiawei011/L4GM/blob/main/interp.safetensors) to `pretrained/recon.safetensors` and `pretrained/interp.safetensors` respectively.
21
+
22
+ Select an input video. Remove its background and crop it to 256x256 with third-party tools. We provide some processed examples in the `data_test` folder.
23
+
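
For example, a rough preprocessing sketch is shown below; `rembg` and the file names are assumptions, and any background-removal tool works just as well:

```python
# Rough preprocessing sketch: per-frame background removal + resize to 256x256.
# `rembg` is an assumed third-party dependency, not listed in requirements.txt.
import imageio
import numpy as np
from PIL import Image
from rembg import remove

frames = []
for frame in imageio.get_reader("input.mp4"):
    fg = remove(Image.fromarray(frame))          # RGBA frame with background removed
    fg = fg.resize((256, 256), Image.LANCZOS)
    canvas = Image.new("RGB", fg.size, (255, 255, 255))
    canvas.paste(fg, mask=fg.split()[-1])        # composite onto a white background
    frames.append(np.asarray(canvas))

imageio.mimsave("data_test/my_video_fg.mp4", frames, fps=8)
```
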
24
+ 1. Generate 3D by:
25
+ ```sh
26
+ python infer_3d.py big --workspace results --resume pretrained/recon.safetensors --num_frames 1 --test_path data_test/otter-on-surfboard_fg.mp4
27
+ ```
28
+
29
+ 2. Generate 4D by:
30
+ ```sh
31
+ python infer_4d.py big --workspace results --resume pretrained/recon.safetensors --interpresume pretrained/interp.safetensors --num_frames 16 --test_path data_test/otter-on-surfboard_fg.mp4
32
+ ```
33
+
34
+ ### Training
35
+ Render Objaverse with Blender scripts in the `blender_scripts` folder first.
36
+
37
+ Download pretrained [LGM](https://huggingface.co/ashawkey/LGM/blob/main/model_fixrot.safetensors) to `pretrained/model_fixrot.safetensors`.
38
+
39
+ L4GM model training:
40
+ ```sh
41
+ accelerate launch \
42
+ --config_file acc_configs/gpu8.yaml \
43
+ main.py big \
44
+ --workspace workspace_recon \
45
+ --resume pretrained/model_fixrot.safetensors \
46
+ --data_mode 4d \
47
+ --num_epochs 200 \
48
+ --prob_cam_jitter 0 \
49
+ --datalist data_train/datalist_8fps.txt \
50
+ ```
51
+ Our released checkpoint uses `--num_epochs 500`.
52
+
53
+ 4D Interpolation model training:
54
+ ```sh
55
+ accelerate launch \
56
+ --config_file acc_configs/gpu8.yaml \
57
+ main.py big \
58
+ --workspace workspace_interp \
59
+ --resume workspace_recon/model.safetensors \
60
+ --data_mode 4d_interp \
61
+ --num_frames 4 \
62
+ --num_epochs 200 \
63
+ --prob_cam_jitter 0 \
64
+ --prob_grid_distortion 0 \
65
+ --datalist data_train/datalist_24fps.txt \
66
+ ```
67
+
68
+ ### Citation
69
+ ```bib
70
+ @inproceedings{ren2024l4gm,
71
+ title={L4GM: Large 4D Gaussian Reconstruction Model},
72
+ author={Jiawei Ren and Kevin Xie and Ashkan Mirzaei and Hanxue Liang and Xiaohui Zeng and Karsten Kreis and Ziwei Liu and Antonio Torralba and Sanja Fidler and Seung Wook Kim and Huan Ling},
73
+ booktitle={Proceedings of Neural Information Processing Systems (NeurIPS)},
74
+ month = {Dec},
75
+ year={2024}
76
+ }
77
+ ```
requirements.txt ADDED
@@ -0,0 +1,20 @@
1
+ tyro
2
+ accelerate==0.28.0
3
+ imageio
4
+ imageio-ffmpeg
5
+ lpips
6
+ Pillow
7
+ safetensors
8
+ scikit-image
9
+ scikit-learn
10
+ scipy
11
+ tqdm
12
+ kiui >= 0.2.3
13
+ roma
14
+ plyfile
15
+
16
+ # mvdream
17
+ diffusers==0.27.2
18
+ huggingface_hub==0.23.5
19
+ transformers
20
+