Spaces:
Running
Running
Adarsh Patel
commited on
Commit
•
4baad62
1
Parent(s):
93adcf7
files added
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .gitignore +1 -0
- LICENSE +201 -0
- README.md +7 -5
- app.py +391 -0
- configs/instant-mesh-base.yaml +22 -0
- configs/instant-mesh-large.yaml +22 -0
- configs/instant-nerf-base.yaml +21 -0
- configs/instant-nerf-large.yaml +21 -0
- examples/bird.jpg +0 -0
- examples/bubble_mart_blue.png +0 -0
- examples/cake.jpg +0 -0
- examples/cartoon_dinosaur.png +0 -0
- examples/cartoon_panda.png +3 -0
- examples/chair_armed.png +0 -0
- examples/chair_comfort.jpg +0 -0
- examples/chair_wood.jpg +0 -0
- examples/chest.jpg +0 -0
- examples/cute_horse.jpg +0 -0
- examples/cute_tiger.jpg +0 -0
- examples/earphone.jpg +0 -0
- examples/fox.jpg +0 -0
- examples/fruit.jpg +0 -0
- examples/fruit_elephant.jpg +0 -0
- examples/genshin_building.png +0 -0
- examples/genshin_teapot.png +0 -0
- examples/hatsune_miku.png +0 -0
- examples/house2.jpg +0 -0
- examples/mushroom_teapot.jpg +0 -0
- examples/pikachu.png +0 -0
- examples/plant.jpg +0 -0
- examples/robot.jpg +0 -0
- examples/sea_turtle.png +0 -0
- examples/skating_shoe.jpg +0 -0
- examples/sorting_board.png +0 -0
- examples/sword.png +0 -0
- examples/toy_car.jpg +0 -0
- examples/watermelon.png +0 -0
- examples/whitedog.png +0 -0
- examples/x_teapot.jpg +0 -0
- examples/x_toyduck.jpg +0 -0
- requirements.txt +23 -0
- src/__init__.py +0 -0
- src/data/__init__.py +0 -0
- src/data/objaverse.py +329 -0
- src/model.py +310 -0
- src/model_mesh.py +325 -0
- src/models/__init__.py +0 -0
- src/models/decoder/__init__.py +0 -0
- src/models/decoder/transformer.py +123 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
examples/cartoon_panda.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: InstantMesh
|
3 |
+
emoji: 📚
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: green
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.26.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
short_description: Create a 3D model from an image in 10 seconds!
|
11 |
+
license: apache-2.0
|
12 |
---
|
13 |
|
14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
|
3 |
+
import os
|
4 |
+
import imageio
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import rembg
|
8 |
+
from PIL import Image
|
9 |
+
from torchvision.transforms import v2
|
10 |
+
from pytorch_lightning import seed_everything
|
11 |
+
from omegaconf import OmegaConf
|
12 |
+
from einops import rearrange, repeat
|
13 |
+
from tqdm import tqdm
|
14 |
+
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
|
15 |
+
|
16 |
+
from src.utils.train_util import instantiate_from_config
|
17 |
+
from src.utils.camera_util import (
|
18 |
+
FOV_to_intrinsics,
|
19 |
+
get_zero123plus_input_cameras,
|
20 |
+
get_circular_camera_poses,
|
21 |
+
)
|
22 |
+
from src.utils.mesh_util import save_obj, save_glb
|
23 |
+
from src.utils.infer_util import remove_background, resize_foreground, images_to_video
|
24 |
+
|
25 |
+
import tempfile
|
26 |
+
from functools import partial
|
27 |
+
|
28 |
+
from huggingface_hub import hf_hub_download
|
29 |
+
|
30 |
+
import gradio as gr
|
31 |
+
|
32 |
+
|
33 |
+
def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
|
34 |
+
"""
|
35 |
+
Get the rendering camera parameters.
|
36 |
+
"""
|
37 |
+
c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
|
38 |
+
if is_flexicubes:
|
39 |
+
cameras = torch.linalg.inv(c2ws)
|
40 |
+
cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
41 |
+
else:
|
42 |
+
extrinsics = c2ws.flatten(-2)
|
43 |
+
intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
|
44 |
+
cameras = torch.cat([extrinsics, intrinsics], dim=-1)
|
45 |
+
cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
|
46 |
+
return cameras
|
47 |
+
|
48 |
+
|
49 |
+
def images_to_video(images, output_path, fps=30):
|
50 |
+
# images: (N, C, H, W)
|
51 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
52 |
+
frames = []
|
53 |
+
for i in range(images.shape[0]):
|
54 |
+
frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
|
55 |
+
assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
|
56 |
+
f"Frame shape mismatch: {frame.shape} vs {images.shape}"
|
57 |
+
assert frame.min() >= 0 and frame.max() <= 255, \
|
58 |
+
f"Frame value out of range: {frame.min()} ~ {frame.max()}"
|
59 |
+
frames.append(frame)
|
60 |
+
imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
|
61 |
+
|
62 |
+
|
63 |
+
###############################################################################
|
64 |
+
# Configuration.
|
65 |
+
###############################################################################
|
66 |
+
|
67 |
+
import shutil
|
68 |
+
|
69 |
+
def find_cuda():
|
70 |
+
# Check if CUDA_HOME or CUDA_PATH environment variables are set
|
71 |
+
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
|
72 |
+
|
73 |
+
if cuda_home and os.path.exists(cuda_home):
|
74 |
+
return cuda_home
|
75 |
+
|
76 |
+
# Search for the nvcc executable in the system's PATH
|
77 |
+
nvcc_path = shutil.which('nvcc')
|
78 |
+
|
79 |
+
if nvcc_path:
|
80 |
+
# Remove the 'bin/nvcc' part to get the CUDA installation path
|
81 |
+
cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
|
82 |
+
return cuda_path
|
83 |
+
|
84 |
+
return None
|
85 |
+
|
86 |
+
cuda_path = find_cuda()
|
87 |
+
|
88 |
+
if cuda_path:
|
89 |
+
print(f"CUDA installation found at: {cuda_path}")
|
90 |
+
else:
|
91 |
+
print("CUDA installation not found")
|
92 |
+
|
93 |
+
config_path = 'configs/instant-mesh-large.yaml'
|
94 |
+
config = OmegaConf.load(config_path)
|
95 |
+
config_name = os.path.basename(config_path).replace('.yaml', '')
|
96 |
+
model_config = config.model_config
|
97 |
+
infer_config = config.infer_config
|
98 |
+
|
99 |
+
IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
|
100 |
+
|
101 |
+
device = torch.device('cuda')
|
102 |
+
|
103 |
+
# load diffusion model
|
104 |
+
print('Loading diffusion model ...')
|
105 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
106 |
+
"sudo-ai/zero123plus-v1.2",
|
107 |
+
custom_pipeline="zero123plus",
|
108 |
+
torch_dtype=torch.float16,
|
109 |
+
)
|
110 |
+
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
111 |
+
pipeline.scheduler.config, timestep_spacing='trailing'
|
112 |
+
)
|
113 |
+
|
114 |
+
# load custom white-background UNet
|
115 |
+
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
|
116 |
+
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
|
117 |
+
pipeline.unet.load_state_dict(state_dict, strict=True)
|
118 |
+
|
119 |
+
pipeline = pipeline.to(device)
|
120 |
+
|
121 |
+
# load reconstruction model
|
122 |
+
print('Loading reconstruction model ...')
|
123 |
+
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
|
124 |
+
model = instantiate_from_config(model_config)
|
125 |
+
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
126 |
+
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
|
127 |
+
model.load_state_dict(state_dict, strict=True)
|
128 |
+
|
129 |
+
model = model.to(device)
|
130 |
+
|
131 |
+
print('Loading Finished!')
|
132 |
+
|
133 |
+
|
134 |
+
def check_input_image(input_image):
|
135 |
+
if input_image is None:
|
136 |
+
raise gr.Error("No image uploaded!")
|
137 |
+
|
138 |
+
|
139 |
+
def preprocess(input_image, do_remove_background):
|
140 |
+
|
141 |
+
rembg_session = rembg.new_session() if do_remove_background else None
|
142 |
+
|
143 |
+
if do_remove_background:
|
144 |
+
input_image = remove_background(input_image, rembg_session)
|
145 |
+
input_image = resize_foreground(input_image, 0.85)
|
146 |
+
|
147 |
+
return input_image
|
148 |
+
|
149 |
+
|
150 |
+
@spaces.GPU
|
151 |
+
def generate_mvs(input_image, sample_steps, sample_seed):
|
152 |
+
|
153 |
+
seed_everything(sample_seed)
|
154 |
+
|
155 |
+
# sampling
|
156 |
+
z123_image = pipeline(
|
157 |
+
input_image,
|
158 |
+
num_inference_steps=sample_steps
|
159 |
+
).images[0]
|
160 |
+
|
161 |
+
show_image = np.asarray(z123_image, dtype=np.uint8)
|
162 |
+
show_image = torch.from_numpy(show_image) # (960, 640, 3)
|
163 |
+
show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
|
164 |
+
show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
|
165 |
+
show_image = Image.fromarray(show_image.numpy())
|
166 |
+
|
167 |
+
return z123_image, show_image
|
168 |
+
|
169 |
+
|
170 |
+
@spaces.GPU
|
171 |
+
def make3d(images):
|
172 |
+
|
173 |
+
global model
|
174 |
+
if IS_FLEXICUBES:
|
175 |
+
model.init_flexicubes_geometry(device, use_renderer=False)
|
176 |
+
model = model.eval()
|
177 |
+
|
178 |
+
images = np.asarray(images, dtype=np.float32) / 255.0
|
179 |
+
images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
|
180 |
+
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
|
181 |
+
|
182 |
+
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
|
183 |
+
render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
|
184 |
+
|
185 |
+
images = images.unsqueeze(0).to(device)
|
186 |
+
images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
|
187 |
+
|
188 |
+
mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
|
189 |
+
print(mesh_fpath)
|
190 |
+
mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
|
191 |
+
mesh_dirname = os.path.dirname(mesh_fpath)
|
192 |
+
video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
|
193 |
+
mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
|
194 |
+
|
195 |
+
with torch.no_grad():
|
196 |
+
# get triplane
|
197 |
+
planes = model.forward_planes(images, input_cameras)
|
198 |
+
|
199 |
+
# # get video
|
200 |
+
# chunk_size = 20 if IS_FLEXICUBES else 1
|
201 |
+
# render_size = 384
|
202 |
+
|
203 |
+
# frames = []
|
204 |
+
# for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
|
205 |
+
# if IS_FLEXICUBES:
|
206 |
+
# frame = model.forward_geometry(
|
207 |
+
# planes,
|
208 |
+
# render_cameras[:, i:i+chunk_size],
|
209 |
+
# render_size=render_size,
|
210 |
+
# )['img']
|
211 |
+
# else:
|
212 |
+
# frame = model.synthesizer(
|
213 |
+
# planes,
|
214 |
+
# cameras=render_cameras[:, i:i+chunk_size],
|
215 |
+
# render_size=render_size,
|
216 |
+
# )['images_rgb']
|
217 |
+
# frames.append(frame)
|
218 |
+
# frames = torch.cat(frames, dim=1)
|
219 |
+
|
220 |
+
# images_to_video(
|
221 |
+
# frames[0],
|
222 |
+
# video_fpath,
|
223 |
+
# fps=30,
|
224 |
+
# )
|
225 |
+
|
226 |
+
# print(f"Video saved to {video_fpath}")
|
227 |
+
|
228 |
+
# get mesh
|
229 |
+
mesh_out = model.extract_mesh(
|
230 |
+
planes,
|
231 |
+
use_texture_map=False,
|
232 |
+
**infer_config,
|
233 |
+
)
|
234 |
+
|
235 |
+
vertices, faces, vertex_colors = mesh_out
|
236 |
+
vertices = vertices[:, [1, 2, 0]]
|
237 |
+
|
238 |
+
save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
|
239 |
+
save_obj(vertices, faces, vertex_colors, mesh_fpath)
|
240 |
+
|
241 |
+
print(f"Mesh saved to {mesh_fpath}")
|
242 |
+
|
243 |
+
return mesh_fpath, mesh_glb_fpath
|
244 |
+
|
245 |
+
|
246 |
+
_HEADER_ = '''
|
247 |
+
<h2><b>Welcome to 3DFusion!</b></h2>
|
248 |
+
<h2><a href='https://github.com/TencentARC/InstantMesh' target='_blank'><b>3D Mesh Generation from Single Images with 3DFusion</b></a></h2>
|
249 |
+
|
250 |
+
3DFusion is a cutting-edge, efficient 3D mesh generation tool based on the powerful LRM/Instant3D architecture.
|
251 |
+
|
252 |
+
Code and Original Framework: <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>InstantMesh GitHub</a>. Technical report: <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a>.
|
253 |
+
|
254 |
+
❗️**Important Notes:**
|
255 |
+
- This demo exports both `.obj` and `.glb` meshes, including vertex colors.
|
256 |
+
- The 3D mesh generation depends on the quality of generated multi-view images, so try different seed values (default: 42) for optimal results.
|
257 |
+
'''
|
258 |
+
|
259 |
+
_CITE_ = r"""
|
260 |
+
If you find **3DFusion** helpful, please give a ⭐ to the original <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>InstantMesh repository</a>. We appreciate the work of the TencentARC team! [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/InstantMesh?style=social)](https://github.com/TencentARC/InstantMesh)
|
261 |
+
---
|
262 |
+
📝 **Citation**
|
263 |
+
|
264 |
+
If you use this work for research or applications, cite it as follows:
|
265 |
+
```bibtex
|
266 |
+
@article{xu2024instantmesh,
|
267 |
+
title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
|
268 |
+
author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
|
269 |
+
journal={arXiv preprint arXiv:2404.07191},
|
270 |
+
year={2024}
|
271 |
+
}
|
272 |
+
```
|
273 |
+
|
274 |
+
📋 **License**
|
275 |
+
|
276 |
+
Apache-2.0 LICENSE. Please refer to the [LICENSE file](https://huggingface.co/spaces/TencentARC/InstantMesh/blob/main/LICENSE) for details.
|
277 |
+
|
278 |
+
📧 **Contact**
|
279 |
+
|
280 |
+
If you have any questions, feel free to open a discussion or contact us at <b>bluestyle928@gmail.com</b>.
|
281 |
+
"""
|
282 |
+
|
283 |
+
|
284 |
+
with gr.Blocks() as demo:
|
285 |
+
gr.Markdown(_HEADER_)
|
286 |
+
with gr.Row(variant="panel"):
|
287 |
+
with gr.Column():
|
288 |
+
with gr.Row():
|
289 |
+
input_image = gr.Image(
|
290 |
+
label="Input Image",
|
291 |
+
image_mode="RGBA",
|
292 |
+
sources="upload",
|
293 |
+
#width=256,
|
294 |
+
#height=256,
|
295 |
+
type="pil",
|
296 |
+
elem_id="content_image",
|
297 |
+
)
|
298 |
+
processed_image = gr.Image(
|
299 |
+
label="Processed Image",
|
300 |
+
image_mode="RGBA",
|
301 |
+
#width=256,
|
302 |
+
#height=256,
|
303 |
+
type="pil",
|
304 |
+
interactive=False
|
305 |
+
)
|
306 |
+
with gr.Row():
|
307 |
+
with gr.Group():
|
308 |
+
do_remove_background = gr.Checkbox(
|
309 |
+
label="Remove Background", value=True
|
310 |
+
)
|
311 |
+
sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
|
312 |
+
|
313 |
+
sample_steps = gr.Slider(
|
314 |
+
label="Sample Steps",
|
315 |
+
minimum=30,
|
316 |
+
maximum=75,
|
317 |
+
value=75,
|
318 |
+
step=5
|
319 |
+
)
|
320 |
+
|
321 |
+
with gr.Row():
|
322 |
+
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
323 |
+
|
324 |
+
with gr.Row(variant="panel"):
|
325 |
+
gr.Examples(
|
326 |
+
examples=[
|
327 |
+
os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
|
328 |
+
],
|
329 |
+
inputs=[input_image],
|
330 |
+
label="Examples",
|
331 |
+
cache_examples=False,
|
332 |
+
examples_per_page=16
|
333 |
+
)
|
334 |
+
|
335 |
+
with gr.Column():
|
336 |
+
|
337 |
+
with gr.Row():
|
338 |
+
|
339 |
+
with gr.Column():
|
340 |
+
mv_show_images = gr.Image(
|
341 |
+
label="Generated Multi-views",
|
342 |
+
type="pil",
|
343 |
+
width=379,
|
344 |
+
interactive=False
|
345 |
+
)
|
346 |
+
|
347 |
+
# with gr.Column():
|
348 |
+
# output_video = gr.Video(
|
349 |
+
# label="video", format="mp4",
|
350 |
+
# width=379,
|
351 |
+
# autoplay=True,
|
352 |
+
# interactive=False
|
353 |
+
# )
|
354 |
+
|
355 |
+
with gr.Row():
|
356 |
+
with gr.Tab("OBJ"):
|
357 |
+
output_model_obj = gr.Model3D(
|
358 |
+
label="Output Model (OBJ Format)",
|
359 |
+
interactive=False,
|
360 |
+
)
|
361 |
+
gr.Markdown("Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
|
362 |
+
with gr.Tab("GLB"):
|
363 |
+
output_model_glb = gr.Model3D(
|
364 |
+
label="Output Model (GLB Format)",
|
365 |
+
interactive=False,
|
366 |
+
)
|
367 |
+
gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
|
368 |
+
|
369 |
+
with gr.Row():
|
370 |
+
gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
|
371 |
+
|
372 |
+
gr.Markdown(_CITE_)
|
373 |
+
|
374 |
+
mv_images = gr.State()
|
375 |
+
|
376 |
+
submit.click(fn=check_input_image, inputs=[input_image]).success(
|
377 |
+
fn=preprocess,
|
378 |
+
inputs=[input_image, do_remove_background],
|
379 |
+
outputs=[processed_image],
|
380 |
+
).success(
|
381 |
+
fn=generate_mvs,
|
382 |
+
inputs=[processed_image, sample_steps, sample_seed],
|
383 |
+
outputs=[mv_images, mv_show_images]
|
384 |
+
|
385 |
+
).success(
|
386 |
+
fn=make3d,
|
387 |
+
inputs=[mv_images],
|
388 |
+
outputs=[output_model_obj, output_model_glb]
|
389 |
+
)
|
390 |
+
|
391 |
+
demo.launch()
|
configs/instant-mesh-base.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_config:
|
2 |
+
target: src.models.lrm_mesh.InstantMesh
|
3 |
+
params:
|
4 |
+
encoder_feat_dim: 768
|
5 |
+
encoder_freeze: false
|
6 |
+
encoder_model_name: facebook/dino-vitb16
|
7 |
+
transformer_dim: 1024
|
8 |
+
transformer_layers: 12
|
9 |
+
transformer_heads: 16
|
10 |
+
triplane_low_res: 32
|
11 |
+
triplane_high_res: 64
|
12 |
+
triplane_dim: 40
|
13 |
+
rendering_samples_per_ray: 96
|
14 |
+
grid_res: 128
|
15 |
+
grid_scale: 2.1
|
16 |
+
|
17 |
+
|
18 |
+
infer_config:
|
19 |
+
unet_path: ckpts/diffusion_pytorch_model.bin
|
20 |
+
model_path: ckpts/instant_mesh_base.ckpt
|
21 |
+
texture_resolution: 1024
|
22 |
+
render_resolution: 512
|
configs/instant-mesh-large.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_config:
|
2 |
+
target: src.models.lrm_mesh.InstantMesh
|
3 |
+
params:
|
4 |
+
encoder_feat_dim: 768
|
5 |
+
encoder_freeze: false
|
6 |
+
encoder_model_name: facebook/dino-vitb16
|
7 |
+
transformer_dim: 1024
|
8 |
+
transformer_layers: 16
|
9 |
+
transformer_heads: 16
|
10 |
+
triplane_low_res: 32
|
11 |
+
triplane_high_res: 64
|
12 |
+
triplane_dim: 80
|
13 |
+
rendering_samples_per_ray: 128
|
14 |
+
grid_res: 128
|
15 |
+
grid_scale: 2.1
|
16 |
+
|
17 |
+
|
18 |
+
infer_config:
|
19 |
+
unet_path: ckpts/diffusion_pytorch_model.bin
|
20 |
+
model_path: ckpts/instant_mesh_large.ckpt
|
21 |
+
texture_resolution: 1024
|
22 |
+
render_resolution: 512
|
configs/instant-nerf-base.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_config:
|
2 |
+
target: src.models.lrm.InstantNeRF
|
3 |
+
params:
|
4 |
+
encoder_feat_dim: 768
|
5 |
+
encoder_freeze: false
|
6 |
+
encoder_model_name: facebook/dino-vitb16
|
7 |
+
transformer_dim: 1024
|
8 |
+
transformer_layers: 12
|
9 |
+
transformer_heads: 16
|
10 |
+
triplane_low_res: 32
|
11 |
+
triplane_high_res: 64
|
12 |
+
triplane_dim: 40
|
13 |
+
rendering_samples_per_ray: 96
|
14 |
+
|
15 |
+
|
16 |
+
infer_config:
|
17 |
+
unet_path: ckpts/diffusion_pytorch_model.bin
|
18 |
+
model_path: ckpts/instant_nerf_base.ckpt
|
19 |
+
mesh_threshold: 10.0
|
20 |
+
mesh_resolution: 256
|
21 |
+
render_resolution: 384
|
configs/instant-nerf-large.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_config:
|
2 |
+
target: src.models.lrm.InstantNeRF
|
3 |
+
params:
|
4 |
+
encoder_feat_dim: 768
|
5 |
+
encoder_freeze: false
|
6 |
+
encoder_model_name: facebook/dino-vitb16
|
7 |
+
transformer_dim: 1024
|
8 |
+
transformer_layers: 16
|
9 |
+
transformer_heads: 16
|
10 |
+
triplane_low_res: 32
|
11 |
+
triplane_high_res: 64
|
12 |
+
triplane_dim: 80
|
13 |
+
rendering_samples_per_ray: 128
|
14 |
+
|
15 |
+
|
16 |
+
infer_config:
|
17 |
+
unet_path: ckpts/diffusion_pytorch_model.bin
|
18 |
+
model_path: ckpts/instant_nerf_large.ckpt
|
19 |
+
mesh_threshold: 10.0
|
20 |
+
mesh_resolution: 256
|
21 |
+
render_resolution: 384
|
examples/bird.jpg
ADDED
examples/bubble_mart_blue.png
ADDED
examples/cake.jpg
ADDED
examples/cartoon_dinosaur.png
ADDED
examples/cartoon_panda.png
ADDED
Git LFS Details
|
examples/chair_armed.png
ADDED
examples/chair_comfort.jpg
ADDED
examples/chair_wood.jpg
ADDED
examples/chest.jpg
ADDED
examples/cute_horse.jpg
ADDED
examples/cute_tiger.jpg
ADDED
examples/earphone.jpg
ADDED
examples/fox.jpg
ADDED
examples/fruit.jpg
ADDED
examples/fruit_elephant.jpg
ADDED
examples/genshin_building.png
ADDED
examples/genshin_teapot.png
ADDED
examples/hatsune_miku.png
ADDED
examples/house2.jpg
ADDED
examples/mushroom_teapot.jpg
ADDED
examples/pikachu.png
ADDED
examples/plant.jpg
ADDED
examples/robot.jpg
ADDED
examples/sea_turtle.png
ADDED
examples/skating_shoe.jpg
ADDED
examples/sorting_board.png
ADDED
examples/sword.png
ADDED
examples/toy_car.jpg
ADDED
examples/watermelon.png
ADDED
examples/whitedog.png
ADDED
examples/x_teapot.jpg
ADDED
examples/x_toyduck.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.1.0
|
2 |
+
torchvision==0.16.0
|
3 |
+
torchaudio==2.1.0
|
4 |
+
pytorch-lightning==2.1.2
|
5 |
+
einops
|
6 |
+
omegaconf
|
7 |
+
deepspeed
|
8 |
+
torchmetrics
|
9 |
+
webdataset
|
10 |
+
accelerate
|
11 |
+
tensorboard
|
12 |
+
PyMCubes
|
13 |
+
trimesh
|
14 |
+
rembg
|
15 |
+
transformers==4.34.1
|
16 |
+
diffusers==0.19.3
|
17 |
+
bitsandbytes
|
18 |
+
imageio[ffmpeg]
|
19 |
+
xatlas
|
20 |
+
plyfile
|
21 |
+
xformers==0.0.22.post7
|
22 |
+
git+https://github.com/NVlabs/nvdiffrast/
|
23 |
+
huggingface-hub
|
src/__init__.py
ADDED
File without changes
|
src/data/__init__.py
ADDED
File without changes
|
src/data/objaverse.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
import math
|
3 |
+
import json
|
4 |
+
import importlib
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import random
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
import webdataset as wds
|
12 |
+
import pytorch_lightning as pl
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from torch.utils.data import Dataset
|
17 |
+
from torch.utils.data import DataLoader
|
18 |
+
from torch.utils.data.distributed import DistributedSampler
|
19 |
+
from torchvision import transforms
|
20 |
+
|
21 |
+
from src.utils.train_util import instantiate_from_config
|
22 |
+
from src.utils.camera_util import (
|
23 |
+
FOV_to_intrinsics,
|
24 |
+
center_looking_at_camera_pose,
|
25 |
+
get_surrounding_views,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
class DataModuleFromConfig(pl.LightningDataModule):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
batch_size=8,
|
33 |
+
num_workers=4,
|
34 |
+
train=None,
|
35 |
+
validation=None,
|
36 |
+
test=None,
|
37 |
+
**kwargs,
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.batch_size = batch_size
|
42 |
+
self.num_workers = num_workers
|
43 |
+
|
44 |
+
self.dataset_configs = dict()
|
45 |
+
if train is not None:
|
46 |
+
self.dataset_configs['train'] = train
|
47 |
+
if validation is not None:
|
48 |
+
self.dataset_configs['validation'] = validation
|
49 |
+
if test is not None:
|
50 |
+
self.dataset_configs['test'] = test
|
51 |
+
|
52 |
+
def setup(self, stage):
|
53 |
+
|
54 |
+
if stage in ['fit']:
|
55 |
+
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
|
56 |
+
else:
|
57 |
+
raise NotImplementedError
|
58 |
+
|
59 |
+
def train_dataloader(self):
|
60 |
+
|
61 |
+
sampler = DistributedSampler(self.datasets['train'])
|
62 |
+
return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
|
63 |
+
|
64 |
+
def val_dataloader(self):
|
65 |
+
|
66 |
+
sampler = DistributedSampler(self.datasets['validation'])
|
67 |
+
return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler)
|
68 |
+
|
69 |
+
def test_dataloader(self):
|
70 |
+
|
71 |
+
return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
72 |
+
|
73 |
+
|
74 |
+
class ObjaverseData(Dataset):
|
75 |
+
def __init__(self,
|
76 |
+
root_dir='objaverse/',
|
77 |
+
meta_fname='valid_paths.json',
|
78 |
+
input_image_dir='rendering_random_32views',
|
79 |
+
target_image_dir='rendering_random_32views',
|
80 |
+
input_view_num=6,
|
81 |
+
target_view_num=2,
|
82 |
+
total_view_n=32,
|
83 |
+
fov=50,
|
84 |
+
camera_rotation=True,
|
85 |
+
validation=False,
|
86 |
+
):
|
87 |
+
self.root_dir = Path(root_dir)
|
88 |
+
self.input_image_dir = input_image_dir
|
89 |
+
self.target_image_dir = target_image_dir
|
90 |
+
|
91 |
+
self.input_view_num = input_view_num
|
92 |
+
self.target_view_num = target_view_num
|
93 |
+
self.total_view_n = total_view_n
|
94 |
+
self.fov = fov
|
95 |
+
self.camera_rotation = camera_rotation
|
96 |
+
|
97 |
+
with open(os.path.join(root_dir, meta_fname)) as f:
|
98 |
+
filtered_dict = json.load(f)
|
99 |
+
paths = filtered_dict['good_objs']
|
100 |
+
self.paths = paths
|
101 |
+
|
102 |
+
self.depth_scale = 4.0
|
103 |
+
|
104 |
+
total_objects = len(self.paths)
|
105 |
+
print('============= length of dataset %d =============' % len(self.paths))
|
106 |
+
|
107 |
+
def __len__(self):
|
108 |
+
return len(self.paths)
|
109 |
+
|
110 |
+
def load_im(self, path, color):
|
111 |
+
'''
|
112 |
+
replace background pixel with random color in rendering
|
113 |
+
'''
|
114 |
+
pil_img = Image.open(path)
|
115 |
+
|
116 |
+
image = np.asarray(pil_img, dtype=np.float32) / 255.
|
117 |
+
alpha = image[:, :, 3:]
|
118 |
+
image = image[:, :, :3] * alpha + color * (1 - alpha)
|
119 |
+
|
120 |
+
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
|
121 |
+
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
|
122 |
+
return image, alpha
|
123 |
+
|
124 |
+
def __getitem__(self, index):
|
125 |
+
# load data
|
126 |
+
while True:
|
127 |
+
input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
|
128 |
+
target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
|
129 |
+
|
130 |
+
indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
|
131 |
+
input_indices = indices[:self.input_view_num]
|
132 |
+
target_indices = indices[self.input_view_num:]
|
133 |
+
|
134 |
+
'''background color, default: white'''
|
135 |
+
bg_white = [1., 1., 1.]
|
136 |
+
bg_black = [0., 0., 0.]
|
137 |
+
|
138 |
+
image_list = []
|
139 |
+
alpha_list = []
|
140 |
+
depth_list = []
|
141 |
+
normal_list = []
|
142 |
+
pose_list = []
|
143 |
+
|
144 |
+
try:
|
145 |
+
input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses']
|
146 |
+
for idx in input_indices:
|
147 |
+
image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white)
|
148 |
+
normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black)
|
149 |
+
depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
|
150 |
+
depth = torch.from_numpy(depth).unsqueeze(0)
|
151 |
+
pose = input_cameras[idx]
|
152 |
+
pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
|
153 |
+
|
154 |
+
image_list.append(image)
|
155 |
+
alpha_list.append(alpha)
|
156 |
+
depth_list.append(depth)
|
157 |
+
normal_list.append(normal)
|
158 |
+
pose_list.append(pose)
|
159 |
+
|
160 |
+
target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses']
|
161 |
+
for idx in target_indices:
|
162 |
+
image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white)
|
163 |
+
normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black)
|
164 |
+
depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
|
165 |
+
depth = torch.from_numpy(depth).unsqueeze(0)
|
166 |
+
pose = target_cameras[idx]
|
167 |
+
pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
|
168 |
+
|
169 |
+
image_list.append(image)
|
170 |
+
alpha_list.append(alpha)
|
171 |
+
depth_list.append(depth)
|
172 |
+
normal_list.append(normal)
|
173 |
+
pose_list.append(pose)
|
174 |
+
|
175 |
+
except Exception as e:
|
176 |
+
print(e)
|
177 |
+
index = np.random.randint(0, len(self.paths))
|
178 |
+
continue
|
179 |
+
|
180 |
+
break
|
181 |
+
|
182 |
+
images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
|
183 |
+
alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
|
184 |
+
depths = torch.stack(depth_list, dim=0).float() # (6+V, 1, H, W)
|
185 |
+
normals = torch.stack(normal_list, dim=0).float() # (6+V, 3, H, W)
|
186 |
+
w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float() # (6+V, 4, 4)
|
187 |
+
c2ws = torch.linalg.inv(w2cs).float()
|
188 |
+
|
189 |
+
normals = normals * 2.0 - 1.0
|
190 |
+
normals = F.normalize(normals, dim=1)
|
191 |
+
normals = (normals + 1.0) / 2.0
|
192 |
+
normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
|
193 |
+
|
194 |
+
# random rotation along z axis
|
195 |
+
if self.camera_rotation:
|
196 |
+
degree = np.random.uniform(0, math.pi * 2)
|
197 |
+
rot = torch.tensor([
|
198 |
+
[np.cos(degree), -np.sin(degree), 0, 0],
|
199 |
+
[np.sin(degree), np.cos(degree), 0, 0],
|
200 |
+
[0, 0, 1, 0],
|
201 |
+
[0, 0, 0, 1],
|
202 |
+
]).unsqueeze(0).float()
|
203 |
+
c2ws = torch.matmul(rot, c2ws)
|
204 |
+
|
205 |
+
# rotate normals
|
206 |
+
N, _, H, W = normals.shape
|
207 |
+
normals = normals * 2.0 - 1.0
|
208 |
+
normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
|
209 |
+
normals = F.normalize(normals, dim=1)
|
210 |
+
normals = (normals + 1.0) / 2.0
|
211 |
+
normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
|
212 |
+
|
213 |
+
# random scaling
|
214 |
+
if np.random.rand() < 0.5:
|
215 |
+
scale = np.random.uniform(0.8, 1.0)
|
216 |
+
c2ws[:, :3, 3] *= scale
|
217 |
+
depths *= scale
|
218 |
+
|
219 |
+
# instrinsics of perspective cameras
|
220 |
+
K = FOV_to_intrinsics(self.fov)
|
221 |
+
Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
|
222 |
+
|
223 |
+
data = {
|
224 |
+
'input_images': images[:self.input_view_num], # (6, 3, H, W)
|
225 |
+
'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W)
|
226 |
+
'input_depths': depths[:self.input_view_num], # (6, 1, H, W)
|
227 |
+
'input_normals': normals[:self.input_view_num], # (6, 3, H, W)
|
228 |
+
'input_c2ws': c2ws_input[:self.input_view_num], # (6, 4, 4)
|
229 |
+
'input_Ks': Ks[:self.input_view_num], # (6, 3, 3)
|
230 |
+
|
231 |
+
# lrm generator input and supervision
|
232 |
+
'target_images': images[self.input_view_num:], # (V, 3, H, W)
|
233 |
+
'target_alphas': alphas[self.input_view_num:], # (V, 1, H, W)
|
234 |
+
'target_depths': depths[self.input_view_num:], # (V, 1, H, W)
|
235 |
+
'target_normals': normals[self.input_view_num:], # (V, 3, H, W)
|
236 |
+
'target_c2ws': c2ws[self.input_view_num:], # (V, 4, 4)
|
237 |
+
'target_Ks': Ks[self.input_view_num:], # (V, 3, 3)
|
238 |
+
|
239 |
+
'depth_available': 1,
|
240 |
+
}
|
241 |
+
return data
|
242 |
+
|
243 |
+
|
244 |
+
class ValidationData(Dataset):
|
245 |
+
def __init__(self,
|
246 |
+
root_dir='objaverse/',
|
247 |
+
input_view_num=6,
|
248 |
+
input_image_size=256,
|
249 |
+
fov=50,
|
250 |
+
):
|
251 |
+
self.root_dir = Path(root_dir)
|
252 |
+
self.input_view_num = input_view_num
|
253 |
+
self.input_image_size = input_image_size
|
254 |
+
self.fov = fov
|
255 |
+
|
256 |
+
self.paths = sorted(os.listdir(self.root_dir))
|
257 |
+
print('============= length of dataset %d =============' % len(self.paths))
|
258 |
+
|
259 |
+
cam_distance = 2.5
|
260 |
+
azimuths = np.array([30, 90, 150, 210, 270, 330])
|
261 |
+
elevations = np.array([30, -20, 30, -20, 30, -20])
|
262 |
+
azimuths = np.deg2rad(azimuths)
|
263 |
+
elevations = np.deg2rad(elevations)
|
264 |
+
|
265 |
+
x = cam_distance * np.cos(elevations) * np.cos(azimuths)
|
266 |
+
y = cam_distance * np.cos(elevations) * np.sin(azimuths)
|
267 |
+
z = cam_distance * np.sin(elevations)
|
268 |
+
|
269 |
+
cam_locations = np.stack([x, y, z], axis=-1)
|
270 |
+
cam_locations = torch.from_numpy(cam_locations).float()
|
271 |
+
c2ws = center_looking_at_camera_pose(cam_locations)
|
272 |
+
self.c2ws = c2ws.float()
|
273 |
+
self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
|
274 |
+
|
275 |
+
render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
|
276 |
+
render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
|
277 |
+
self.render_c2ws = render_c2ws.float()
|
278 |
+
self.render_Ks = render_Ks.float()
|
279 |
+
|
280 |
+
def __len__(self):
|
281 |
+
return len(self.paths)
|
282 |
+
|
283 |
+
def load_im(self, path, color):
|
284 |
+
'''
|
285 |
+
replace background pixel with random color in rendering
|
286 |
+
'''
|
287 |
+
pil_img = Image.open(path)
|
288 |
+
pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
|
289 |
+
|
290 |
+
image = np.asarray(pil_img, dtype=np.float32) / 255.
|
291 |
+
if image.shape[-1] == 4:
|
292 |
+
alpha = image[:, :, 3:]
|
293 |
+
image = image[:, :, :3] * alpha + color * (1 - alpha)
|
294 |
+
else:
|
295 |
+
alpha = np.ones_like(image[:, :, :1])
|
296 |
+
|
297 |
+
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
|
298 |
+
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
|
299 |
+
return image, alpha
|
300 |
+
|
301 |
+
def __getitem__(self, index):
|
302 |
+
# load data
|
303 |
+
input_image_path = os.path.join(self.root_dir, self.paths[index])
|
304 |
+
|
305 |
+
'''background color, default: white'''
|
306 |
+
# color = np.random.uniform(0.48, 0.52)
|
307 |
+
bkg_color = [1.0, 1.0, 1.0]
|
308 |
+
|
309 |
+
image_list = []
|
310 |
+
alpha_list = []
|
311 |
+
|
312 |
+
for idx in range(self.input_view_num):
|
313 |
+
image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
|
314 |
+
image_list.append(image)
|
315 |
+
alpha_list.append(alpha)
|
316 |
+
|
317 |
+
images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
|
318 |
+
alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
|
319 |
+
|
320 |
+
data = {
|
321 |
+
'input_images': images, # (6, 3, H, W)
|
322 |
+
'input_alphas': alphas, # (6, 1, H, W)
|
323 |
+
'input_c2ws': self.c2ws, # (6, 4, 4)
|
324 |
+
'input_Ks': self.Ks, # (6, 3, 3)
|
325 |
+
|
326 |
+
'render_c2ws': self.render_c2ws,
|
327 |
+
'render_Ks': self.render_Ks,
|
328 |
+
}
|
329 |
+
return data
|
src/model.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchvision.transforms import v2
|
6 |
+
from torchvision.utils import make_grid, save_image
|
7 |
+
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
from src.utils.train_util import instantiate_from_config
|
12 |
+
|
13 |
+
|
14 |
+
class MVRecon(pl.LightningModule):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
lrm_generator_config,
|
18 |
+
lrm_path=None,
|
19 |
+
input_size=256,
|
20 |
+
render_size=192,
|
21 |
+
):
|
22 |
+
super(MVRecon, self).__init__()
|
23 |
+
|
24 |
+
self.input_size = input_size
|
25 |
+
self.render_size = render_size
|
26 |
+
|
27 |
+
# init modules
|
28 |
+
self.lrm_generator = instantiate_from_config(lrm_generator_config)
|
29 |
+
if lrm_path is not None:
|
30 |
+
lrm_ckpt = torch.load(lrm_path)
|
31 |
+
self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)
|
32 |
+
|
33 |
+
self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
|
34 |
+
|
35 |
+
self.validation_step_outputs = []
|
36 |
+
|
37 |
+
def on_fit_start(self):
|
38 |
+
if self.global_rank == 0:
|
39 |
+
os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
|
40 |
+
os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
|
41 |
+
|
42 |
+
def prepare_batch_data(self, batch):
|
43 |
+
lrm_generator_input = {}
|
44 |
+
render_gt = {} # for supervision
|
45 |
+
|
46 |
+
# input images
|
47 |
+
images = batch['input_images']
|
48 |
+
images = v2.functional.resize(
|
49 |
+
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
50 |
+
|
51 |
+
lrm_generator_input['images'] = images.to(self.device)
|
52 |
+
|
53 |
+
# input cameras and render cameras
|
54 |
+
input_c2ws = batch['input_c2ws'].flatten(-2)
|
55 |
+
input_Ks = batch['input_Ks'].flatten(-2)
|
56 |
+
target_c2ws = batch['target_c2ws'].flatten(-2)
|
57 |
+
target_Ks = batch['target_Ks'].flatten(-2)
|
58 |
+
render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
|
59 |
+
render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
|
60 |
+
render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)
|
61 |
+
|
62 |
+
input_extrinsics = input_c2ws[:, :, :12]
|
63 |
+
input_intrinsics = torch.stack([
|
64 |
+
input_Ks[:, :, 0], input_Ks[:, :, 4],
|
65 |
+
input_Ks[:, :, 2], input_Ks[:, :, 5],
|
66 |
+
], dim=-1)
|
67 |
+
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
68 |
+
|
69 |
+
# add noise to input cameras
|
70 |
+
cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
|
71 |
+
|
72 |
+
lrm_generator_input['cameras'] = cameras.to(self.device)
|
73 |
+
lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
|
74 |
+
|
75 |
+
# target images
|
76 |
+
target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
|
77 |
+
target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
|
78 |
+
target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
|
79 |
+
|
80 |
+
# random crop
|
81 |
+
render_size = np.random.randint(self.render_size, 513)
|
82 |
+
target_images = v2.functional.resize(
|
83 |
+
target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
|
84 |
+
target_depths = v2.functional.resize(
|
85 |
+
target_depths, render_size, interpolation=0, antialias=True)
|
86 |
+
target_alphas = v2.functional.resize(
|
87 |
+
target_alphas, render_size, interpolation=0, antialias=True)
|
88 |
+
|
89 |
+
crop_params = v2.RandomCrop.get_params(
|
90 |
+
target_images, output_size=(self.render_size, self.render_size))
|
91 |
+
target_images = v2.functional.crop(target_images, *crop_params)
|
92 |
+
target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
|
93 |
+
target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]
|
94 |
+
|
95 |
+
lrm_generator_input['render_size'] = render_size
|
96 |
+
lrm_generator_input['crop_params'] = crop_params
|
97 |
+
|
98 |
+
render_gt['target_images'] = target_images.to(self.device)
|
99 |
+
render_gt['target_depths'] = target_depths.to(self.device)
|
100 |
+
render_gt['target_alphas'] = target_alphas.to(self.device)
|
101 |
+
|
102 |
+
return lrm_generator_input, render_gt
|
103 |
+
|
104 |
+
def prepare_validation_batch_data(self, batch):
|
105 |
+
lrm_generator_input = {}
|
106 |
+
|
107 |
+
# input images
|
108 |
+
images = batch['input_images']
|
109 |
+
images = v2.functional.resize(
|
110 |
+
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
111 |
+
|
112 |
+
lrm_generator_input['images'] = images.to(self.device)
|
113 |
+
|
114 |
+
input_c2ws = batch['input_c2ws'].flatten(-2)
|
115 |
+
input_Ks = batch['input_Ks'].flatten(-2)
|
116 |
+
|
117 |
+
input_extrinsics = input_c2ws[:, :, :12]
|
118 |
+
input_intrinsics = torch.stack([
|
119 |
+
input_Ks[:, :, 0], input_Ks[:, :, 4],
|
120 |
+
input_Ks[:, :, 2], input_Ks[:, :, 5],
|
121 |
+
], dim=-1)
|
122 |
+
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
123 |
+
|
124 |
+
lrm_generator_input['cameras'] = cameras.to(self.device)
|
125 |
+
|
126 |
+
render_c2ws = batch['render_c2ws'].flatten(-2)
|
127 |
+
render_Ks = batch['render_Ks'].flatten(-2)
|
128 |
+
render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)
|
129 |
+
|
130 |
+
lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
|
131 |
+
lrm_generator_input['render_size'] = 384
|
132 |
+
lrm_generator_input['crop_params'] = None
|
133 |
+
|
134 |
+
return lrm_generator_input
|
135 |
+
|
136 |
+
def forward_lrm_generator(
|
137 |
+
self,
|
138 |
+
images,
|
139 |
+
cameras,
|
140 |
+
render_cameras,
|
141 |
+
render_size=192,
|
142 |
+
crop_params=None,
|
143 |
+
chunk_size=1,
|
144 |
+
):
|
145 |
+
planes = torch.utils.checkpoint.checkpoint(
|
146 |
+
self.lrm_generator.forward_planes,
|
147 |
+
images,
|
148 |
+
cameras,
|
149 |
+
use_reentrant=False,
|
150 |
+
)
|
151 |
+
frames = []
|
152 |
+
for i in range(0, render_cameras.shape[1], chunk_size):
|
153 |
+
frames.append(
|
154 |
+
torch.utils.checkpoint.checkpoint(
|
155 |
+
self.lrm_generator.synthesizer,
|
156 |
+
planes,
|
157 |
+
cameras=render_cameras[:, i:i+chunk_size],
|
158 |
+
render_size=render_size,
|
159 |
+
crop_params=crop_params,
|
160 |
+
use_reentrant=False
|
161 |
+
)
|
162 |
+
)
|
163 |
+
frames = {
|
164 |
+
k: torch.cat([r[k] for r in frames], dim=1)
|
165 |
+
for k in frames[0].keys()
|
166 |
+
}
|
167 |
+
return frames
|
168 |
+
|
169 |
+
def forward(self, lrm_generator_input):
|
170 |
+
images = lrm_generator_input['images']
|
171 |
+
cameras = lrm_generator_input['cameras']
|
172 |
+
render_cameras = lrm_generator_input['render_cameras']
|
173 |
+
render_size = lrm_generator_input['render_size']
|
174 |
+
crop_params = lrm_generator_input['crop_params']
|
175 |
+
|
176 |
+
out = self.forward_lrm_generator(
|
177 |
+
images,
|
178 |
+
cameras,
|
179 |
+
render_cameras,
|
180 |
+
render_size=render_size,
|
181 |
+
crop_params=crop_params,
|
182 |
+
chunk_size=1,
|
183 |
+
)
|
184 |
+
render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
|
185 |
+
render_depths = out['images_depth']
|
186 |
+
render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)
|
187 |
+
|
188 |
+
out = {
|
189 |
+
'render_images': render_images,
|
190 |
+
'render_depths': render_depths,
|
191 |
+
'render_alphas': render_alphas,
|
192 |
+
}
|
193 |
+
return out
|
194 |
+
|
195 |
+
def training_step(self, batch, batch_idx):
|
196 |
+
lrm_generator_input, render_gt = self.prepare_batch_data(batch)
|
197 |
+
|
198 |
+
render_out = self.forward(lrm_generator_input)
|
199 |
+
|
200 |
+
loss, loss_dict = self.compute_loss(render_out, render_gt)
|
201 |
+
|
202 |
+
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
203 |
+
|
204 |
+
if self.global_step % 1000 == 0 and self.global_rank == 0:
|
205 |
+
B, N, C, H, W = render_gt['target_images'].shape
|
206 |
+
N_in = lrm_generator_input['images'].shape[1]
|
207 |
+
|
208 |
+
input_images = v2.functional.resize(
|
209 |
+
lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
|
210 |
+
input_images = torch.cat(
|
211 |
+
[input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)
|
212 |
+
|
213 |
+
input_images = rearrange(
|
214 |
+
input_images, 'b n c h w -> b c h (n w)')
|
215 |
+
target_images = rearrange(
|
216 |
+
render_gt['target_images'], 'b n c h w -> b c h (n w)')
|
217 |
+
render_images = rearrange(
|
218 |
+
render_out['render_images'], 'b n c h w -> b c h (n w)')
|
219 |
+
target_alphas = rearrange(
|
220 |
+
repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
221 |
+
render_alphas = rearrange(
|
222 |
+
repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
223 |
+
target_depths = rearrange(
|
224 |
+
repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
225 |
+
render_depths = rearrange(
|
226 |
+
repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
227 |
+
MAX_DEPTH = torch.max(target_depths)
|
228 |
+
target_depths = target_depths / MAX_DEPTH * target_alphas
|
229 |
+
render_depths = render_depths / MAX_DEPTH
|
230 |
+
|
231 |
+
grid = torch.cat([
|
232 |
+
input_images,
|
233 |
+
target_images, render_images,
|
234 |
+
target_alphas, render_alphas,
|
235 |
+
target_depths, render_depths,
|
236 |
+
], dim=-2)
|
237 |
+
grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
|
238 |
+
|
239 |
+
save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
|
240 |
+
|
241 |
+
return loss
|
242 |
+
|
243 |
+
def compute_loss(self, render_out, render_gt):
|
244 |
+
# NOTE: the rgb value range of OpenLRM is [0, 1]
|
245 |
+
render_images = render_out['render_images']
|
246 |
+
target_images = render_gt['target_images'].to(render_images)
|
247 |
+
render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
248 |
+
target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
249 |
+
|
250 |
+
loss_mse = F.mse_loss(render_images, target_images)
|
251 |
+
loss_lpips = 2.0 * self.lpips(render_images, target_images)
|
252 |
+
|
253 |
+
render_alphas = render_out['render_alphas']
|
254 |
+
target_alphas = render_gt['target_alphas']
|
255 |
+
loss_mask = F.mse_loss(render_alphas, target_alphas)
|
256 |
+
|
257 |
+
loss = loss_mse + loss_lpips + loss_mask
|
258 |
+
|
259 |
+
prefix = 'train'
|
260 |
+
loss_dict = {}
|
261 |
+
loss_dict.update({f'{prefix}/loss_mse': loss_mse})
|
262 |
+
loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
|
263 |
+
loss_dict.update({f'{prefix}/loss_mask': loss_mask})
|
264 |
+
loss_dict.update({f'{prefix}/loss': loss})
|
265 |
+
|
266 |
+
return loss, loss_dict
|
267 |
+
|
268 |
+
@torch.no_grad()
|
269 |
+
def validation_step(self, batch, batch_idx):
|
270 |
+
lrm_generator_input = self.prepare_validation_batch_data(batch)
|
271 |
+
|
272 |
+
render_out = self.forward(lrm_generator_input)
|
273 |
+
render_images = render_out['render_images']
|
274 |
+
render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
|
275 |
+
|
276 |
+
self.validation_step_outputs.append(render_images)
|
277 |
+
|
278 |
+
def on_validation_epoch_end(self):
|
279 |
+
images = torch.cat(self.validation_step_outputs, dim=-1)
|
280 |
+
|
281 |
+
all_images = self.all_gather(images)
|
282 |
+
all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
|
283 |
+
|
284 |
+
if self.global_rank == 0:
|
285 |
+
image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
|
286 |
+
|
287 |
+
grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
|
288 |
+
save_image(grid, image_path)
|
289 |
+
print(f"Saved image to {image_path}")
|
290 |
+
|
291 |
+
self.validation_step_outputs.clear()
|
292 |
+
|
293 |
+
def configure_optimizers(self):
|
294 |
+
lr = self.learning_rate
|
295 |
+
|
296 |
+
params = []
|
297 |
+
|
298 |
+
lrm_params_fast, lrm_params_slow = [], []
|
299 |
+
for n, p in self.lrm_generator.named_parameters():
|
300 |
+
if 'adaLN_modulation' in n or 'camera_embedder' in n:
|
301 |
+
lrm_params_fast.append(p)
|
302 |
+
else:
|
303 |
+
lrm_params_slow.append(p)
|
304 |
+
params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
|
305 |
+
params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
|
306 |
+
|
307 |
+
optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
|
308 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
|
309 |
+
|
310 |
+
return {'optimizer': optimizer, 'lr_scheduler': scheduler}
|
src/model_mesh.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchvision.transforms import v2
|
6 |
+
from torchvision.utils import make_grid, save_image
|
7 |
+
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
from src.utils.train_util import instantiate_from_config
|
12 |
+
|
13 |
+
|
14 |
+
# Regulrarization loss for FlexiCubes
|
15 |
+
def sdf_reg_loss_batch(sdf, all_edges):
|
16 |
+
sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
|
17 |
+
mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
|
18 |
+
sdf_f1x6x2 = sdf_f1x6x2[mask]
|
19 |
+
sdf_diff = F.binary_cross_entropy_with_logits(
|
20 |
+
sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
|
21 |
+
F.binary_cross_entropy_with_logits(
|
22 |
+
sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
|
23 |
+
return sdf_diff
|
24 |
+
|
25 |
+
|
26 |
+
class MVRecon(pl.LightningModule):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
lrm_generator_config,
|
30 |
+
input_size=256,
|
31 |
+
render_size=512,
|
32 |
+
init_ckpt=None,
|
33 |
+
):
|
34 |
+
super(MVRecon, self).__init__()
|
35 |
+
|
36 |
+
self.input_size = input_size
|
37 |
+
self.render_size = render_size
|
38 |
+
|
39 |
+
# init modules
|
40 |
+
self.lrm_generator = instantiate_from_config(lrm_generator_config)
|
41 |
+
|
42 |
+
self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
|
43 |
+
|
44 |
+
# Load weights from pretrained MVRecon model, and use the mlp
|
45 |
+
# weights to initialize the weights of sdf and rgb mlps.
|
46 |
+
if init_ckpt is not None:
|
47 |
+
sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
|
48 |
+
sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
|
49 |
+
sd_fc = {}
|
50 |
+
for k, v in sd.items():
|
51 |
+
if k.startswith('lrm_generator.synthesizer.decoder.net.'):
|
52 |
+
if k.startswith('lrm_generator.synthesizer.decoder.net.6.'): # last layer
|
53 |
+
# Here we assume the density filed's isosurface threshold is t,
|
54 |
+
# we reverse the sign of density filed to initialize SDF field.
|
55 |
+
# -(w*x + b - t) = (-w)*x + (t - b)
|
56 |
+
if 'weight' in k:
|
57 |
+
sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
|
58 |
+
else:
|
59 |
+
sd_fc[k.replace('net.', 'net_sdf.')] = 3.0 - v[0:1]
|
60 |
+
sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
|
61 |
+
else:
|
62 |
+
sd_fc[k.replace('net.', 'net_sdf.')] = v
|
63 |
+
sd_fc[k.replace('net.', 'net_rgb.')] = v
|
64 |
+
else:
|
65 |
+
sd_fc[k] = v
|
66 |
+
sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
|
67 |
+
# missing `net_deformation` and `net_weight` parameters
|
68 |
+
self.lrm_generator.load_state_dict(sd_fc, strict=False)
|
69 |
+
print(f'Loaded weights from {init_ckpt}')
|
70 |
+
|
71 |
+
self.validation_step_outputs = []
|
72 |
+
|
73 |
+
def on_fit_start(self):
|
74 |
+
device = torch.device(f'cuda:{self.global_rank}')
|
75 |
+
self.lrm_generator.init_flexicubes_geometry(device)
|
76 |
+
if self.global_rank == 0:
|
77 |
+
os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
|
78 |
+
os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
|
79 |
+
|
80 |
+
def prepare_batch_data(self, batch):
|
81 |
+
lrm_generator_input = {}
|
82 |
+
render_gt = {}
|
83 |
+
|
84 |
+
# input images
|
85 |
+
images = batch['input_images']
|
86 |
+
images = v2.functional.resize(
|
87 |
+
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
88 |
+
|
89 |
+
lrm_generator_input['images'] = images.to(self.device)
|
90 |
+
|
91 |
+
# input cameras and render cameras
|
92 |
+
input_c2ws = batch['input_c2ws']
|
93 |
+
input_Ks = batch['input_Ks']
|
94 |
+
target_c2ws = batch['target_c2ws']
|
95 |
+
|
96 |
+
render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1)
|
97 |
+
render_w2cs = torch.linalg.inv(render_c2ws)
|
98 |
+
|
99 |
+
input_extrinsics = input_c2ws.flatten(-2)
|
100 |
+
input_extrinsics = input_extrinsics[:, :, :12]
|
101 |
+
input_intrinsics = input_Ks.flatten(-2)
|
102 |
+
input_intrinsics = torch.stack([
|
103 |
+
input_intrinsics[:, :, 0], input_intrinsics[:, :, 4],
|
104 |
+
input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
|
105 |
+
], dim=-1)
|
106 |
+
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
107 |
+
|
108 |
+
# add noise to input_cameras
|
109 |
+
cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
|
110 |
+
|
111 |
+
lrm_generator_input['cameras'] = cameras.to(self.device)
|
112 |
+
lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
|
113 |
+
|
114 |
+
# target images
|
115 |
+
target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
|
116 |
+
target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
|
117 |
+
target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
|
118 |
+
target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)
|
119 |
+
|
120 |
+
render_size = self.render_size
|
121 |
+
target_images = v2.functional.resize(
|
122 |
+
target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
|
123 |
+
target_depths = v2.functional.resize(
|
124 |
+
target_depths, render_size, interpolation=0, antialias=True)
|
125 |
+
target_alphas = v2.functional.resize(
|
126 |
+
target_alphas, render_size, interpolation=0, antialias=True)
|
127 |
+
target_normals = v2.functional.resize(
|
128 |
+
target_normals, render_size, interpolation=3, antialias=True)
|
129 |
+
|
130 |
+
lrm_generator_input['render_size'] = render_size
|
131 |
+
|
132 |
+
render_gt['target_images'] = target_images.to(self.device)
|
133 |
+
render_gt['target_depths'] = target_depths.to(self.device)
|
134 |
+
render_gt['target_alphas'] = target_alphas.to(self.device)
|
135 |
+
render_gt['target_normals'] = target_normals.to(self.device)
|
136 |
+
|
137 |
+
return lrm_generator_input, render_gt
|
138 |
+
|
139 |
+
def prepare_validation_batch_data(self, batch):
|
140 |
+
lrm_generator_input = {}
|
141 |
+
|
142 |
+
# input images
|
143 |
+
images = batch['input_images']
|
144 |
+
images = v2.functional.resize(
|
145 |
+
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
146 |
+
|
147 |
+
lrm_generator_input['images'] = images.to(self.device)
|
148 |
+
|
149 |
+
# input cameras
|
150 |
+
input_c2ws = batch['input_c2ws'].flatten(-2)
|
151 |
+
input_Ks = batch['input_Ks'].flatten(-2)
|
152 |
+
|
153 |
+
input_extrinsics = input_c2ws[:, :, :12]
|
154 |
+
input_intrinsics = torch.stack([
|
155 |
+
input_Ks[:, :, 0], input_Ks[:, :, 4],
|
156 |
+
input_Ks[:, :, 2], input_Ks[:, :, 5],
|
157 |
+
], dim=-1)
|
158 |
+
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
159 |
+
|
160 |
+
lrm_generator_input['cameras'] = cameras.to(self.device)
|
161 |
+
|
162 |
+
# render cameras
|
163 |
+
render_c2ws = batch['render_c2ws']
|
164 |
+
render_w2cs = torch.linalg.inv(render_c2ws)
|
165 |
+
|
166 |
+
lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
|
167 |
+
lrm_generator_input['render_size'] = 384
|
168 |
+
|
169 |
+
return lrm_generator_input
|
170 |
+
|
171 |
+
def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512):
|
172 |
+
planes = torch.utils.checkpoint.checkpoint(
|
173 |
+
self.lrm_generator.forward_planes,
|
174 |
+
images,
|
175 |
+
cameras,
|
176 |
+
use_reentrant=False,
|
177 |
+
)
|
178 |
+
out = self.lrm_generator.forward_geometry(
|
179 |
+
planes,
|
180 |
+
render_cameras,
|
181 |
+
render_size,
|
182 |
+
)
|
183 |
+
return out
|
184 |
+
|
185 |
+
def forward(self, lrm_generator_input):
|
186 |
+
images = lrm_generator_input['images']
|
187 |
+
cameras = lrm_generator_input['cameras']
|
188 |
+
render_cameras = lrm_generator_input['render_cameras']
|
189 |
+
render_size = lrm_generator_input['render_size']
|
190 |
+
|
191 |
+
out = self.forward_lrm_generator(
|
192 |
+
images, cameras, render_cameras, render_size=render_size)
|
193 |
+
|
194 |
+
return out
|
195 |
+
|
196 |
+
def training_step(self, batch, batch_idx):
|
197 |
+
lrm_generator_input, render_gt = self.prepare_batch_data(batch)
|
198 |
+
|
199 |
+
render_out = self.forward(lrm_generator_input)
|
200 |
+
|
201 |
+
loss, loss_dict = self.compute_loss(render_out, render_gt)
|
202 |
+
|
203 |
+
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
204 |
+
|
205 |
+
if self.global_step % 1000 == 0 and self.global_rank == 0:
|
206 |
+
B, N, C, H, W = render_gt['target_images'].shape
|
207 |
+
N_in = lrm_generator_input['images'].shape[1]
|
208 |
+
|
209 |
+
target_images = rearrange(
|
210 |
+
render_gt['target_images'], 'b n c h w -> b c h (n w)')
|
211 |
+
render_images = rearrange(
|
212 |
+
render_out['img'], 'b n c h w -> b c h (n w)')
|
213 |
+
target_alphas = rearrange(
|
214 |
+
repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
215 |
+
render_alphas = rearrange(
|
216 |
+
repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
217 |
+
target_depths = rearrange(
|
218 |
+
repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
219 |
+
render_depths = rearrange(
|
220 |
+
repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
221 |
+
target_normals = rearrange(
|
222 |
+
render_gt['target_normals'], 'b n c h w -> b c h (n w)')
|
223 |
+
render_normals = rearrange(
|
224 |
+
render_out['normal'], 'b n c h w -> b c h (n w)')
|
225 |
+
MAX_DEPTH = torch.max(target_depths)
|
226 |
+
target_depths = target_depths / MAX_DEPTH * target_alphas
|
227 |
+
render_depths = render_depths / MAX_DEPTH
|
228 |
+
|
229 |
+
grid = torch.cat([
|
230 |
+
target_images, render_images,
|
231 |
+
target_alphas, render_alphas,
|
232 |
+
target_depths, render_depths,
|
233 |
+
target_normals, render_normals,
|
234 |
+
], dim=-2)
|
235 |
+
grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
|
236 |
+
|
237 |
+
image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
|
238 |
+
save_image(grid, image_path)
|
239 |
+
print(f"Saved image to {image_path}")
|
240 |
+
|
241 |
+
return loss
|
242 |
+
|
243 |
+
def compute_loss(self, render_out, render_gt):
|
244 |
+
# NOTE: the rgb value range of OpenLRM is [0, 1]
|
245 |
+
render_images = render_out['img']
|
246 |
+
target_images = render_gt['target_images'].to(render_images)
|
247 |
+
render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
248 |
+
target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
249 |
+
loss_mse = F.mse_loss(render_images, target_images)
|
250 |
+
loss_lpips = 2.0 * self.lpips(render_images, target_images)
|
251 |
+
|
252 |
+
render_alphas = render_out['mask']
|
253 |
+
target_alphas = render_gt['target_alphas']
|
254 |
+
loss_mask = F.mse_loss(render_alphas, target_alphas)
|
255 |
+
|
256 |
+
render_depths = render_out['depth']
|
257 |
+
target_depths = render_gt['target_depths']
|
258 |
+
loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0])
|
259 |
+
|
260 |
+
render_normals = render_out['normal'] * 2.0 - 1.0
|
261 |
+
target_normals = render_gt['target_normals'] * 2.0 - 1.0
|
262 |
+
similarity = (render_normals * target_normals).sum(dim=-3).abs()
|
263 |
+
normal_mask = target_alphas.squeeze(-3)
|
264 |
+
loss_normal = 1 - similarity[normal_mask>0].mean()
|
265 |
+
loss_normal = 0.2 * loss_normal
|
266 |
+
|
267 |
+
# flexicubes regularization loss
|
268 |
+
sdf = render_out['sdf']
|
269 |
+
sdf_reg_loss = render_out['sdf_reg_loss']
|
270 |
+
sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
|
271 |
+
_, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
|
272 |
+
flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
|
273 |
+
flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1
|
274 |
+
|
275 |
+
loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg
|
276 |
+
|
277 |
+
loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg
|
278 |
+
|
279 |
+
prefix = 'train'
|
280 |
+
loss_dict = {}
|
281 |
+
loss_dict.update({f'{prefix}/loss_mse': loss_mse})
|
282 |
+
loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
|
283 |
+
loss_dict.update({f'{prefix}/loss_mask': loss_mask})
|
284 |
+
loss_dict.update({f'{prefix}/loss_normal': loss_normal})
|
285 |
+
loss_dict.update({f'{prefix}/loss_depth': loss_depth})
|
286 |
+
loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy})
|
287 |
+
loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg})
|
288 |
+
loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg})
|
289 |
+
loss_dict.update({f'{prefix}/loss': loss})
|
290 |
+
|
291 |
+
return loss, loss_dict
|
292 |
+
|
293 |
+
@torch.no_grad()
|
294 |
+
def validation_step(self, batch, batch_idx):
|
295 |
+
lrm_generator_input = self.prepare_validation_batch_data(batch)
|
296 |
+
|
297 |
+
render_out = self.forward(lrm_generator_input)
|
298 |
+
render_images = render_out['img']
|
299 |
+
render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
|
300 |
+
|
301 |
+
self.validation_step_outputs.append(render_images)
|
302 |
+
|
303 |
+
def on_validation_epoch_end(self):
|
304 |
+
images = torch.cat(self.validation_step_outputs, dim=-1)
|
305 |
+
|
306 |
+
all_images = self.all_gather(images)
|
307 |
+
all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
|
308 |
+
|
309 |
+
if self.global_rank == 0:
|
310 |
+
image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
|
311 |
+
|
312 |
+
grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
|
313 |
+
save_image(grid, image_path)
|
314 |
+
print(f"Saved image to {image_path}")
|
315 |
+
|
316 |
+
self.validation_step_outputs.clear()
|
317 |
+
|
318 |
+
def configure_optimizers(self):
|
319 |
+
lr = self.learning_rate
|
320 |
+
|
321 |
+
optimizer = torch.optim.AdamW(
|
322 |
+
self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
|
323 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)
|
324 |
+
|
325 |
+
return {'optimizer': optimizer, 'lr_scheduler': scheduler}
|
src/models/__init__.py
ADDED
File without changes
|
src/models/decoder/__init__.py
ADDED
File without changes
|
src/models/decoder/transformer.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
|
20 |
+
class BasicTransformerBlock(nn.Module):
|
21 |
+
"""
|
22 |
+
Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
|
23 |
+
"""
|
24 |
+
# use attention from torch.nn.MultiHeadAttention
|
25 |
+
# Block contains a cross-attention layer, a self-attention layer, and a MLP
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
inner_dim: int,
|
29 |
+
cond_dim: int,
|
30 |
+
num_heads: int,
|
31 |
+
eps: float,
|
32 |
+
attn_drop: float = 0.,
|
33 |
+
attn_bias: bool = False,
|
34 |
+
mlp_ratio: float = 4.,
|
35 |
+
mlp_drop: float = 0.,
|
36 |
+
):
|
37 |
+
super().__init__()
|
38 |
+
|
39 |
+
self.norm1 = nn.LayerNorm(inner_dim)
|
40 |
+
self.cross_attn = nn.MultiheadAttention(
|
41 |
+
embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
|
42 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
43 |
+
self.norm2 = nn.LayerNorm(inner_dim)
|
44 |
+
self.self_attn = nn.MultiheadAttention(
|
45 |
+
embed_dim=inner_dim, num_heads=num_heads,
|
46 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
47 |
+
self.norm3 = nn.LayerNorm(inner_dim)
|
48 |
+
self.mlp = nn.Sequential(
|
49 |
+
nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
|
50 |
+
nn.GELU(),
|
51 |
+
nn.Dropout(mlp_drop),
|
52 |
+
nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
|
53 |
+
nn.Dropout(mlp_drop),
|
54 |
+
)
|
55 |
+
|
56 |
+
def forward(self, x, cond):
|
57 |
+
# x: [N, L, D]
|
58 |
+
# cond: [N, L_cond, D_cond]
|
59 |
+
x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
|
60 |
+
before_sa = self.norm2(x)
|
61 |
+
x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
|
62 |
+
x = x + self.mlp(self.norm3(x))
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
+
class TriplaneTransformer(nn.Module):
|
67 |
+
"""
|
68 |
+
Transformer with condition that generates a triplane representation.
|
69 |
+
|
70 |
+
Reference:
|
71 |
+
Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
|
72 |
+
"""
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
inner_dim: int,
|
76 |
+
image_feat_dim: int,
|
77 |
+
triplane_low_res: int,
|
78 |
+
triplane_high_res: int,
|
79 |
+
triplane_dim: int,
|
80 |
+
num_layers: int,
|
81 |
+
num_heads: int,
|
82 |
+
eps: float = 1e-6,
|
83 |
+
):
|
84 |
+
super().__init__()
|
85 |
+
|
86 |
+
# attributes
|
87 |
+
self.triplane_low_res = triplane_low_res
|
88 |
+
self.triplane_high_res = triplane_high_res
|
89 |
+
self.triplane_dim = triplane_dim
|
90 |
+
|
91 |
+
# modules
|
92 |
+
# initialize pos_embed with 1/sqrt(dim) * N(0, 1)
|
93 |
+
self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
|
94 |
+
self.layers = nn.ModuleList([
|
95 |
+
BasicTransformerBlock(
|
96 |
+
inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
|
97 |
+
for _ in range(num_layers)
|
98 |
+
])
|
99 |
+
self.norm = nn.LayerNorm(inner_dim, eps=eps)
|
100 |
+
self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
|
101 |
+
|
102 |
+
def forward(self, image_feats):
|
103 |
+
# image_feats: [N, L_cond, D_cond]
|
104 |
+
|
105 |
+
N = image_feats.shape[0]
|
106 |
+
H = W = self.triplane_low_res
|
107 |
+
L = 3 * H * W
|
108 |
+
|
109 |
+
x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
|
110 |
+
for layer in self.layers:
|
111 |
+
x = layer(x, image_feats)
|
112 |
+
x = self.norm(x)
|
113 |
+
|
114 |
+
# separate each plane and apply deconv
|
115 |
+
x = x.view(N, 3, H, W, -1)
|
116 |
+
x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
|
117 |
+
x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
|
118 |
+
x = self.deconv(x) # [3*N, D', H', W']
|
119 |
+
x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
|
120 |
+
x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
|
121 |
+
x = x.contiguous()
|
122 |
+
|
123 |
+
return x
|