Spaces:
Running
on
A10G
Running
on
A10G
init
Browse files- .gitattributes +1 -0
- .gitignore +20 -0
- LICENSE +21 -0
- LICENSE_GAUSSIAN_SPLATTING.md +83 -0
- README.md +0 -13
- app.py +105 -0
- cam_utils.py +146 -0
- configs/image.yaml +69 -0
- configs/text.yaml +68 -0
- data/anya_rgba.png +3 -0
- data/catstatue_rgba.png +3 -0
- data/csm_luigi_rgba.png +3 -0
- data/test.png +3 -0
- data/zelda_rgba.png +3 -0
- grid_put.py +300 -0
- gs_renderer.py +820 -0
- guidance/sd_utils.py +334 -0
- guidance/zero123_utils.py +226 -0
- main.py +882 -0
- main2.py +671 -0
- mesh.py +622 -0
- mesh_renderer.py +154 -0
- mesh_utils.py +147 -0
- process.py +92 -0
- readme.md +139 -0
- requirements.txt +37 -0
- scripts/convert_obj_to_video.py +20 -0
- scripts/run.sh +5 -0
- scripts/run_sd.sh +31 -0
- scripts/runall.py +48 -0
- scripts/runall_sd.py +45 -0
- sh_utils.py +118 -0
- simple-knn/ext.cpp +17 -0
- simple-knn/setup.py +35 -0
- simple-knn/simple_knn.cu +221 -0
- simple-knn/simple_knn.h +21 -0
- simple-knn/simple_knn/.gitkeep +0 -0
- simple-knn/spatial.cu +26 -0
- simple-knn/spatial.h +14 -0
- zero123.py +666 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.png* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
build/
|
3 |
+
*.egg-info/
|
4 |
+
*.so
|
5 |
+
venv_*/
|
6 |
+
.vs/
|
7 |
+
.vscode/
|
8 |
+
.idea/
|
9 |
+
|
10 |
+
tmp_*
|
11 |
+
data?
|
12 |
+
data??
|
13 |
+
scripts2
|
14 |
+
|
15 |
+
model_cache
|
16 |
+
|
17 |
+
logs
|
18 |
+
videos
|
19 |
+
images
|
20 |
+
*.mp4
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 dreamgaussian
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
LICENSE_GAUSSIAN_SPLATTING.md
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Gaussian-Splatting License
|
2 |
+
===========================
|
3 |
+
|
4 |
+
**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
|
5 |
+
The *Software* is in the process of being registered with the Agence pour la Protection des
|
6 |
+
Programmes (APP).
|
7 |
+
|
8 |
+
The *Software* is still being developed by the *Licensor*.
|
9 |
+
|
10 |
+
*Licensor*'s goal is to allow the research community to use, test and evaluate
|
11 |
+
the *Software*.
|
12 |
+
|
13 |
+
## 1. Definitions
|
14 |
+
|
15 |
+
*Licensee* means any person or entity that uses the *Software* and distributes
|
16 |
+
its *Work*.
|
17 |
+
|
18 |
+
*Licensor* means the owners of the *Software*, i.e Inria and MPII
|
19 |
+
|
20 |
+
*Software* means the original work of authorship made available under this
|
21 |
+
License ie gaussian-splatting.
|
22 |
+
|
23 |
+
*Work* means the *Software* and any additions to or derivative works of the
|
24 |
+
*Software* that are made available under this License.
|
25 |
+
|
26 |
+
|
27 |
+
## 2. Purpose
|
28 |
+
This license is intended to define the rights granted to the *Licensee* by
|
29 |
+
Licensors under the *Software*.
|
30 |
+
|
31 |
+
## 3. Rights granted
|
32 |
+
|
33 |
+
For the above reasons Licensors have decided to distribute the *Software*.
|
34 |
+
Licensors grant non-exclusive rights to use the *Software* for research purposes
|
35 |
+
to research users (both academic and industrial), free of charge, without right
|
36 |
+
to sublicense.. The *Software* may be used "non-commercially", i.e., for research
|
37 |
+
and/or evaluation purposes only.
|
38 |
+
|
39 |
+
Subject to the terms and conditions of this License, you are granted a
|
40 |
+
non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
|
41 |
+
publicly display, publicly perform and distribute its *Work* and any resulting
|
42 |
+
derivative works in any form.
|
43 |
+
|
44 |
+
## 4. Limitations
|
45 |
+
|
46 |
+
**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
|
47 |
+
so under this License, (b) you include a complete copy of this License with
|
48 |
+
your distribution, and (c) you retain without modification any copyright,
|
49 |
+
patent, trademark, or attribution notices that are present in the *Work*.
|
50 |
+
|
51 |
+
**4.2 Derivative Works.** You may specify that additional or different terms apply
|
52 |
+
to the use, reproduction, and distribution of your derivative works of the *Work*
|
53 |
+
("Your Terms") only if (a) Your Terms provide that the use limitation in
|
54 |
+
Section 2 applies to your derivative works, and (b) you identify the specific
|
55 |
+
derivative works that are subject to Your Terms. Notwithstanding Your Terms,
|
56 |
+
this License (including the redistribution requirements in Section 3.1) will
|
57 |
+
continue to apply to the *Work* itself.
|
58 |
+
|
59 |
+
**4.3** Any other use without of prior consent of Licensors is prohibited. Research
|
60 |
+
users explicitly acknowledge having received from Licensors all information
|
61 |
+
allowing to appreciate the adequacy between of the *Software* and their needs and
|
62 |
+
to undertake all necessary precautions for its execution and use.
|
63 |
+
|
64 |
+
**4.4** The *Software* is provided both as a compiled library file and as source
|
65 |
+
code. In case of using the *Software* for a publication or other results obtained
|
66 |
+
through the use of the *Software*, users are strongly encouraged to cite the
|
67 |
+
corresponding publications as explained in the documentation of the *Software*.
|
68 |
+
|
69 |
+
## 5. Disclaimer
|
70 |
+
|
71 |
+
THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
|
72 |
+
WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
|
73 |
+
UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
|
74 |
+
CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
|
75 |
+
OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
|
76 |
+
USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
|
77 |
+
ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
|
78 |
+
AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
79 |
+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
|
80 |
+
GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
|
81 |
+
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
|
82 |
+
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
|
83 |
+
IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
|
README.md
DELETED
@@ -1,13 +0,0 @@
|
|
1 |
-
---
|
2 |
-
title: Dreamgaussian
|
3 |
-
emoji: 🌍
|
4 |
-
colorFrom: red
|
5 |
-
colorTo: purple
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 3.47.1
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: mit
|
11 |
-
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
from PIL import Image
|
4 |
+
import subprocess
|
5 |
+
|
6 |
+
|
7 |
+
# check if there is a picture uploaded or selected
|
8 |
+
def check_img_input(control_image):
|
9 |
+
if control_image is None:
|
10 |
+
raise gr.Error("Please select or upload an input image")
|
11 |
+
|
12 |
+
|
13 |
+
def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, elevation_slider: float):
|
14 |
+
if not os.path.exists('tmp_data'):
|
15 |
+
os.makedirs('tmp_data')
|
16 |
+
if preprocess_chk:
|
17 |
+
# save image to a designated path
|
18 |
+
image_block.save('tmp_data/tmp.png')
|
19 |
+
|
20 |
+
# preprocess image
|
21 |
+
subprocess.run([f'python process.py tmp_data/tmp.png'], shell=True)
|
22 |
+
else:
|
23 |
+
image_block.save('tmp_data/tmp_rgba.png')
|
24 |
+
|
25 |
+
# stage 1
|
26 |
+
subprocess.run([
|
27 |
+
f'python main.py --config configs/image.yaml input=tmp_data/tmp_rgba.png save_path=tmp mesh_format=glb elevation={elevation_slider} force_cuda_rast=True'],
|
28 |
+
shell=True)
|
29 |
+
|
30 |
+
return f'logs/tmp_mesh.glb'
|
31 |
+
|
32 |
+
|
33 |
+
def optimize_stage_2(elevation_slider: float):
|
34 |
+
# stage 2
|
35 |
+
subprocess.run([
|
36 |
+
f'python main2.py --config configs/image.yaml input=tmp_data/tmp_rgba.png save_path=tmp mesh_format=glb elevation={elevation_slider} force_cuda_rast=True'],
|
37 |
+
shell=True)
|
38 |
+
|
39 |
+
return f'logs/tmp.glb'
|
40 |
+
|
41 |
+
|
42 |
+
if __name__ == "__main__":
|
43 |
+
_TITLE = '''DreamGaussian: Generative Gaussian Splatting for Efficient 3D Content Creation'''
|
44 |
+
|
45 |
+
_DESCRIPTION = '''
|
46 |
+
<div>
|
47 |
+
<a style="display:inline-block" href="https://dreamgaussian.github.io"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
|
48 |
+
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2309.16653"><img src="https://img.shields.io/badge/2306.16928-f9f7f7?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADcAAABMCAYAAADJPi9EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAuIwAALiMBeKU/dgAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAAa2SURBVHja3Zt7bBRFGMAXUCDGF4rY7m7bAwuhlggKStFgLBgFEkCIIRJEEoOBYHwRFYKilUgEReVNJEGCJJpehHI3M9vZvd3bUP1DjNhEIRQQsQgSHiJgQZ5dv7krWEvvdmZ7d7vHJN+ft/f99pv5XvOtJMFCqvoCUpTdIEeRLC+L9Ox5i3Q9LACaCeK0kXoSChVcD3C/tQPHpAEsquQ73IkUcEz2kcLCknyGW5MGjkljRFVL8xJOKyi4CwCOuQAeAkfTP1+tNxLkogvgEbDgffkJqKqvuMA5ifOpqg/5qWecRstNg7xoUTI1Fovdxg8oy2s5AP8CGeYHmGngeZaOL4I4LXLcpHg4149/GDz4xqgsb+UAbMKKUpkrqHA43MUyyJpWUK0EHeG2YKRXr7tB+QMcgGewLD+ebTDbtrtbBt7UPlhS4rV4IvcDI7J8P1OeA/AcAI7LHljN7aB8XTowJmZt9EFRD/o0SDMH4HlwMhMyDWZZSAHFf3YDs3RS49WDLuaAY3IJq+qzmQKLxXAZKN7oDoYbdV3v5elPqiSpMyiOuAEVZVqHXb1OhloUH+MA+ztO0cAO/RkrfyBE7OAEbAZvO8vzVtTRWFD6DAfY5biBM3PWiaL0a4lvXICwnV8WjmE6ntYmhqX2jjp5LbMZjCw/wbYeN6CizOa2GMVzQOlmHjB4Ceuyk6LJ8huccEmR5Xddg7OOV/NAtchW+E3XbOag60QA4Qwuarca0bRuEJyr+cFQwzcY98huxhAKdQelt4kAQpj4qJ3gvFXAYn+aJumXk1yPlpQUgtIHhbYoFMUstNRRWgjnpl4A7IKlayNymqFHFaWCpV9CFry3LGxR1CgA5kB5M8OX2goApwpaz6mdOMGxtAgXWJySxb4WuQD4qTDgU+N5AAnzpr7ChSWpCyisiQJqY0Y7FtmSKpbV23b45kC0KHBxcQ9QeI8w4KgnHRPVtIU7rOtbioLVg5Hl/qDwSVFAMqLSMSObroCdZYlzIJtMRFVHCaRo/wFWPgaAXzdbBpkc2A4aKzCNd97+URQuESYGDDhIVfWOQIKZJu4D2+oXlgDTV1865gUQZDts756BArMNMoR1oa46BYqbyPixZz1ZUFV3sgwoGBajuBKATl3btIn8QYYMuezRgrsiRUWyr2BxA40EkPMpA/Hm6gbUu7fjEXA3azP6AsbKD9bxdUuhjM9W7fII52BF+daRpE4+WA3P501+jbfmHvQKyFqMuXf7Ot4mkN2fr50y+bRH61X7AXdUpHSxaPQ4GVbR5AGw3g+434XgQGKfr72I+vQRhfsu92dOx7WicInzt3CBg1RVpMm0NveWo2SqFzgmdNZMbriILD+S+zoueWf2vSdAipzacWN5nMl6XxNlUHa/J8DoJodUDE0HR8Ll5V0lPxcrLEHZPV4AzS83OLis7FowVa3RSku7BSNxJqQAlN3hBTC2apmDSkpaw22wJemGQFUG7J4MlP3JC6A+f96V7vRyX9It3nzT/GrjIU8edM7rMSnIi10f476lzbE1K7yEiEuWro0OJBguLCwDuFOJc1Na6sRWL/cCeMIwUN9ggSVbe3v/5/EgzTKWLvEAiBrYRUkgwNI2ZaFQNT75UDxEUEx97zYnzpmiLEmbaYCbNxYtFAb0/Z4AztgUrhyxuNgxPnhfHFDHz/vTgFWUQZxTRkkJhQ6YNdVUEPAfO6ZV5BRss6LcCVb7VaAma9giy0XJZBt9IQh42NY0NSdgbLIPlLUF6rEdrdt0CUCK1wsCbkcI3ZSLc7ZSwGLbmJXbPsNxnE5xilYKAobZ77LpGZ8TAIun+/iCKQoF71IxQDI3K2CCd+ARNvXg9sykBcnHAoCZG4u66hlDoQLe6QV4CRtFSxZQ+D0BwNO2jgdkzoGoah1nj3FVlSR19taTSYxI8QLut23U8dsgzqHulJNCQpcqBnpTALCuQ6NSYLHpmR5i42gZzuIdcrMMvMJbQlxe3jXxyZnLACl7ARm/FjPIDOY8ODtpM71sxwfcZpvBeUzKWmfNINM5AS+wO0Khh7dMqKccu4+qatarZjYAwDlgetzStHtEt+XedsBOQtU9XMrRgjg4KTnc5nr+dmqadit/4C4uLm8DuA9koJTj1TL7fI5nDL+qqoo/FLGAzL7dYT17PzvAcQONYSUQRxW/QMrHZVIyik0ZuQA2mzp+Ji8BW4YM3Mbzm9inaHkJCGfrUZZjujiYailfFwA8DHIy3acwUj4v9vUVa+SmgNsl5fuyDTKovW9/IAmfLV0Pi2UncA515kjYdrwC9i9rpuHiq3JwtAAAAABJRU5ErkJggg=="></a>
|
49 |
+
<a style="display:inline-block; margin-left: .5em" href='https://github.com/dreamgaussian/dreamgaussian'><img src='https://img.shields.io/github/stars/dreamgaussian/dreamgaussian?style=social'/></a>
|
50 |
+
</div>
|
51 |
+
We present DreamGausssion, a 3D content generation framework that significantly improves the efficiency of 3D content creation.
|
52 |
+
'''
|
53 |
+
_IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above) and click **Generate 3D**."
|
54 |
+
|
55 |
+
# load images in 'data' folder as examples
|
56 |
+
example_folder = os.path.join(os.path.dirname(__file__), 'data')
|
57 |
+
example_fns = os.listdir(example_folder)
|
58 |
+
example_fns.sort()
|
59 |
+
examples_full = [os.path.join(example_folder, x) for x in example_fns if x.endswith('.png')]
|
60 |
+
|
61 |
+
# Compose demo layout & data flow
|
62 |
+
with gr.Blocks(title=_TITLE, theme=gr.themes.Soft()) as demo:
|
63 |
+
with gr.Row():
|
64 |
+
with gr.Column(scale=1):
|
65 |
+
gr.Markdown('# ' + _TITLE)
|
66 |
+
gr.Markdown(_DESCRIPTION)
|
67 |
+
|
68 |
+
# Image-to-3D
|
69 |
+
with gr.Row(variant='panel'):
|
70 |
+
with gr.Column(scale=5):
|
71 |
+
image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image', tool=None)
|
72 |
+
|
73 |
+
elevation_slider = gr.Slider(-90, 90, value=0, step=1, label='Estimated elevation angle')
|
74 |
+
gr.Markdown(
|
75 |
+
"default to 0 (horizontal), range from [-90, 90]. If you upload a look-down image, try a value like -30")
|
76 |
+
|
77 |
+
preprocess_chk = gr.Checkbox(True,
|
78 |
+
label='Preprocess image automatically (remove background and recenter object)')
|
79 |
+
|
80 |
+
gr.Examples(
|
81 |
+
examples=examples_full, # NOTE: elements must match inputs list!
|
82 |
+
inputs=[image_block],
|
83 |
+
outputs=[image_block],
|
84 |
+
cache_examples=False,
|
85 |
+
label='Examples (click one of the images below to start)',
|
86 |
+
examples_per_page=40
|
87 |
+
)
|
88 |
+
img_run_btn = gr.Button("Generate 3D")
|
89 |
+
img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)
|
90 |
+
|
91 |
+
with gr.Column(scale=5):
|
92 |
+
obj3d_stage1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model (Stage 1)")
|
93 |
+
obj3d = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model (Final)")
|
94 |
+
|
95 |
+
# if there is an input image, continue with inference
|
96 |
+
# else display an error message
|
97 |
+
img_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(optimize_stage_1,
|
98 |
+
inputs=[image_block,
|
99 |
+
preprocess_chk,
|
100 |
+
elevation_slider],
|
101 |
+
outputs=[
|
102 |
+
obj3d_stage1]).success(
|
103 |
+
optimize_stage_2, inputs=[elevation_slider], outputs=[obj3d])
|
104 |
+
|
105 |
+
demo.queue().launch(share=True)
|
cam_utils.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy.spatial.transform import Rotation as R
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
def dot(x, y):
|
7 |
+
if isinstance(x, np.ndarray):
|
8 |
+
return np.sum(x * y, -1, keepdims=True)
|
9 |
+
else:
|
10 |
+
return torch.sum(x * y, -1, keepdim=True)
|
11 |
+
|
12 |
+
|
13 |
+
def length(x, eps=1e-20):
|
14 |
+
if isinstance(x, np.ndarray):
|
15 |
+
return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
|
16 |
+
else:
|
17 |
+
return torch.sqrt(torch.clamp(dot(x, x), min=eps))
|
18 |
+
|
19 |
+
|
20 |
+
def safe_normalize(x, eps=1e-20):
|
21 |
+
return x / length(x, eps)
|
22 |
+
|
23 |
+
|
24 |
+
def look_at(campos, target, opengl=True):
|
25 |
+
# campos: [N, 3], camera/eye position
|
26 |
+
# target: [N, 3], object to look at
|
27 |
+
# return: [N, 3, 3], rotation matrix
|
28 |
+
if not opengl:
|
29 |
+
# camera forward aligns with -z
|
30 |
+
forward_vector = safe_normalize(target - campos)
|
31 |
+
up_vector = np.array([0, 1, 0], dtype=np.float32)
|
32 |
+
right_vector = safe_normalize(np.cross(forward_vector, up_vector))
|
33 |
+
up_vector = safe_normalize(np.cross(right_vector, forward_vector))
|
34 |
+
else:
|
35 |
+
# camera forward aligns with +z
|
36 |
+
forward_vector = safe_normalize(campos - target)
|
37 |
+
up_vector = np.array([0, 1, 0], dtype=np.float32)
|
38 |
+
right_vector = safe_normalize(np.cross(up_vector, forward_vector))
|
39 |
+
up_vector = safe_normalize(np.cross(forward_vector, right_vector))
|
40 |
+
R = np.stack([right_vector, up_vector, forward_vector], axis=1)
|
41 |
+
return R
|
42 |
+
|
43 |
+
|
44 |
+
# elevation & azimuth to pose (cam2world) matrix
|
45 |
+
def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True):
|
46 |
+
# radius: scalar
|
47 |
+
# elevation: scalar, in (-90, 90), from +y to -y is (-90, 90)
|
48 |
+
# azimuth: scalar, in (-180, 180), from +z to +x is (0, 90)
|
49 |
+
# return: [4, 4], camera pose matrix
|
50 |
+
if is_degree:
|
51 |
+
elevation = np.deg2rad(elevation)
|
52 |
+
azimuth = np.deg2rad(azimuth)
|
53 |
+
x = radius * np.cos(elevation) * np.sin(azimuth)
|
54 |
+
y = - radius * np.sin(elevation)
|
55 |
+
z = radius * np.cos(elevation) * np.cos(azimuth)
|
56 |
+
if target is None:
|
57 |
+
target = np.zeros([3], dtype=np.float32)
|
58 |
+
campos = np.array([x, y, z]) + target # [3]
|
59 |
+
T = np.eye(4, dtype=np.float32)
|
60 |
+
T[:3, :3] = look_at(campos, target, opengl)
|
61 |
+
T[:3, 3] = campos
|
62 |
+
return T
|
63 |
+
|
64 |
+
|
65 |
+
class OrbitCamera:
|
66 |
+
def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):
|
67 |
+
self.W = W
|
68 |
+
self.H = H
|
69 |
+
self.radius = r # camera distance from center
|
70 |
+
self.fovy = np.deg2rad(fovy) # deg 2 rad
|
71 |
+
self.near = near
|
72 |
+
self.far = far
|
73 |
+
self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point
|
74 |
+
self.rot = R.from_matrix(np.eye(3))
|
75 |
+
self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized!
|
76 |
+
|
77 |
+
@property
|
78 |
+
def fovx(self):
|
79 |
+
return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H)
|
80 |
+
|
81 |
+
@property
|
82 |
+
def campos(self):
|
83 |
+
return self.pose[:3, 3]
|
84 |
+
|
85 |
+
# pose (c2w)
|
86 |
+
@property
|
87 |
+
def pose(self):
|
88 |
+
# first move camera to radius
|
89 |
+
res = np.eye(4, dtype=np.float32)
|
90 |
+
res[2, 3] = self.radius # opengl convention...
|
91 |
+
# rotate
|
92 |
+
rot = np.eye(4, dtype=np.float32)
|
93 |
+
rot[:3, :3] = self.rot.as_matrix()
|
94 |
+
res = rot @ res
|
95 |
+
# translate
|
96 |
+
res[:3, 3] -= self.center
|
97 |
+
return res
|
98 |
+
|
99 |
+
# view (w2c)
|
100 |
+
@property
|
101 |
+
def view(self):
|
102 |
+
return np.linalg.inv(self.pose)
|
103 |
+
|
104 |
+
# projection (perspective)
|
105 |
+
@property
|
106 |
+
def perspective(self):
|
107 |
+
y = np.tan(self.fovy / 2)
|
108 |
+
aspect = self.W / self.H
|
109 |
+
return np.array(
|
110 |
+
[
|
111 |
+
[1 / (y * aspect), 0, 0, 0],
|
112 |
+
[0, -1 / y, 0, 0],
|
113 |
+
[
|
114 |
+
0,
|
115 |
+
0,
|
116 |
+
-(self.far + self.near) / (self.far - self.near),
|
117 |
+
-(2 * self.far * self.near) / (self.far - self.near),
|
118 |
+
],
|
119 |
+
[0, 0, -1, 0],
|
120 |
+
],
|
121 |
+
dtype=np.float32,
|
122 |
+
)
|
123 |
+
|
124 |
+
# intrinsics
|
125 |
+
@property
|
126 |
+
def intrinsics(self):
|
127 |
+
focal = self.H / (2 * np.tan(self.fovy / 2))
|
128 |
+
return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32)
|
129 |
+
|
130 |
+
@property
|
131 |
+
def mvp(self):
|
132 |
+
return self.perspective @ np.linalg.inv(self.pose) # [4, 4]
|
133 |
+
|
134 |
+
def orbit(self, dx, dy):
|
135 |
+
# rotate along camera up/side axis!
|
136 |
+
side = self.rot.as_matrix()[:3, 0]
|
137 |
+
rotvec_x = self.up * np.radians(-0.05 * dx)
|
138 |
+
rotvec_y = side * np.radians(-0.05 * dy)
|
139 |
+
self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot
|
140 |
+
|
141 |
+
def scale(self, delta):
|
142 |
+
self.radius *= 1.1 ** (-delta)
|
143 |
+
|
144 |
+
def pan(self, dx, dy, dz=0):
|
145 |
+
# pan in camera coordinate system (careful on the sensitivity!)
|
146 |
+
self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz])
|
configs/image.yaml
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Input
|
2 |
+
# input rgba image path (default to None, can be load in GUI too)
|
3 |
+
input:
|
4 |
+
# input text prompt (default to None, can be input in GUI too)
|
5 |
+
prompt:
|
6 |
+
# input mesh for stage 2 (auto-search from stage 1 output path if None)
|
7 |
+
mesh:
|
8 |
+
# estimated elevation angle for input image
|
9 |
+
elevation: 0
|
10 |
+
# reference image resolution
|
11 |
+
ref_size: 256
|
12 |
+
# density thresh for mesh extraction
|
13 |
+
density_thresh: 1
|
14 |
+
|
15 |
+
### Output
|
16 |
+
outdir: logs
|
17 |
+
mesh_format: obj
|
18 |
+
save_path: ???
|
19 |
+
|
20 |
+
### Training
|
21 |
+
# guidance loss weights (0 to disable)
|
22 |
+
lambda_sd: 0
|
23 |
+
lambda_zero123: 1
|
24 |
+
# training batch size per iter
|
25 |
+
batch_size: 1
|
26 |
+
# training iterations for stage 1
|
27 |
+
iters: 500
|
28 |
+
# training iterations for stage 2
|
29 |
+
iters_refine: 50
|
30 |
+
# training camera radius
|
31 |
+
radius: 2
|
32 |
+
# training camera fovy
|
33 |
+
fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61
|
34 |
+
# checkpoint to load for stage 1 (should be a ply file)
|
35 |
+
load:
|
36 |
+
# whether allow geom training in stage 2
|
37 |
+
train_geo: False
|
38 |
+
# prob to invert background color during training (0 = always black, 1 = always white)
|
39 |
+
invert_bg_prob: 0.5
|
40 |
+
|
41 |
+
|
42 |
+
### GUI
|
43 |
+
gui: False
|
44 |
+
force_cuda_rast: False
|
45 |
+
# GUI resolution
|
46 |
+
H: 800
|
47 |
+
W: 800
|
48 |
+
|
49 |
+
### Gaussian splatting
|
50 |
+
num_pts: 5000
|
51 |
+
sh_degree: 0
|
52 |
+
position_lr_init: 0.001
|
53 |
+
position_lr_final: 0.00002
|
54 |
+
position_lr_delay_mult: 0.02
|
55 |
+
position_lr_max_steps: 500
|
56 |
+
feature_lr: 0.01
|
57 |
+
opacity_lr: 0.05
|
58 |
+
scaling_lr: 0.005
|
59 |
+
rotation_lr: 0.005
|
60 |
+
percent_dense: 0.1
|
61 |
+
density_start_iter: 100
|
62 |
+
density_end_iter: 3000
|
63 |
+
densification_interval: 100
|
64 |
+
opacity_reset_interval: 700
|
65 |
+
densify_grad_threshold: 0.5
|
66 |
+
|
67 |
+
### Textured Mesh
|
68 |
+
geom_lr: 0.0001
|
69 |
+
texture_lr: 0.2
|
configs/text.yaml
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Input
|
2 |
+
# input rgba image path (default to None, can be load in GUI too)
|
3 |
+
input:
|
4 |
+
# input text prompt (default to None, can be input in GUI too)
|
5 |
+
prompt:
|
6 |
+
# input mesh for stage 2 (auto-search from stage 1 output path if None)
|
7 |
+
mesh:
|
8 |
+
# estimated elevation angle for input image
|
9 |
+
elevation: 0
|
10 |
+
# reference image resolution
|
11 |
+
ref_size: 256
|
12 |
+
# density thresh for mesh extraction
|
13 |
+
density_thresh: 1
|
14 |
+
|
15 |
+
### Output
|
16 |
+
outdir: logs
|
17 |
+
mesh_format: obj
|
18 |
+
save_path: ???
|
19 |
+
|
20 |
+
### Training
|
21 |
+
# guidance loss weights (0 to disable)
|
22 |
+
lambda_sd: 1
|
23 |
+
lambda_zero123: 0
|
24 |
+
# training batch size per iter
|
25 |
+
batch_size: 1
|
26 |
+
# training iterations for stage 1
|
27 |
+
iters: 500
|
28 |
+
# training iterations for stage 2
|
29 |
+
iters_refine: 50
|
30 |
+
# training camera radius
|
31 |
+
radius: 2.5
|
32 |
+
# training camera fovy
|
33 |
+
fovy: 49.1
|
34 |
+
# checkpoint to load for stage 1 (should be a ply file)
|
35 |
+
load:
|
36 |
+
# whether allow geom training in stage 2
|
37 |
+
train_geo: False
|
38 |
+
# prob to invert background color during training (0 = always black, 1 = always white)
|
39 |
+
invert_bg_prob: 0.5
|
40 |
+
|
41 |
+
### GUI
|
42 |
+
gui: False
|
43 |
+
force_cuda_rast: False
|
44 |
+
# GUI resolution
|
45 |
+
H: 800
|
46 |
+
W: 800
|
47 |
+
|
48 |
+
### Gaussian splatting
|
49 |
+
num_pts: 1000
|
50 |
+
sh_degree: 0
|
51 |
+
position_lr_init: 0.001
|
52 |
+
position_lr_final: 0.00002
|
53 |
+
position_lr_delay_mult: 0.02
|
54 |
+
position_lr_max_steps: 500
|
55 |
+
feature_lr: 0.01
|
56 |
+
opacity_lr: 0.05
|
57 |
+
scaling_lr: 0.005
|
58 |
+
rotation_lr: 0.005
|
59 |
+
percent_dense: 0.1
|
60 |
+
density_start_iter: 100
|
61 |
+
density_end_iter: 3000
|
62 |
+
densification_interval: 50
|
63 |
+
opacity_reset_interval: 700
|
64 |
+
densify_grad_threshold: 0.01
|
65 |
+
|
66 |
+
### Textured Mesh
|
67 |
+
geom_lr: 0.0001
|
68 |
+
texture_lr: 0.2
|
data/anya_rgba.png
ADDED
Git LFS Details
|
data/catstatue_rgba.png
ADDED
Git LFS Details
|
data/csm_luigi_rgba.png
ADDED
Git LFS Details
|
data/test.png
ADDED
Git LFS Details
|
data/zelda_rgba.png
ADDED
Git LFS Details
|
grid_put.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
def stride_from_shape(shape):
|
5 |
+
stride = [1]
|
6 |
+
for x in reversed(shape[1:]):
|
7 |
+
stride.append(stride[-1] * x)
|
8 |
+
return list(reversed(stride))
|
9 |
+
|
10 |
+
|
11 |
+
def scatter_add_nd(input, indices, values):
|
12 |
+
# input: [..., C], D dimension + C channel
|
13 |
+
# indices: [N, D], long
|
14 |
+
# values: [N, C]
|
15 |
+
|
16 |
+
D = indices.shape[-1]
|
17 |
+
C = input.shape[-1]
|
18 |
+
size = input.shape[:-1]
|
19 |
+
stride = stride_from_shape(size)
|
20 |
+
|
21 |
+
assert len(size) == D
|
22 |
+
|
23 |
+
input = input.view(-1, C) # [HW, C]
|
24 |
+
flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1) # [N]
|
25 |
+
|
26 |
+
input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
|
27 |
+
|
28 |
+
return input.view(*size, C)
|
29 |
+
|
30 |
+
|
31 |
+
def scatter_add_nd_with_count(input, count, indices, values, weights=None):
|
32 |
+
# input: [..., C], D dimension + C channel
|
33 |
+
# count: [..., 1], D dimension
|
34 |
+
# indices: [N, D], long
|
35 |
+
# values: [N, C]
|
36 |
+
|
37 |
+
D = indices.shape[-1]
|
38 |
+
C = input.shape[-1]
|
39 |
+
size = input.shape[:-1]
|
40 |
+
stride = stride_from_shape(size)
|
41 |
+
|
42 |
+
assert len(size) == D
|
43 |
+
|
44 |
+
input = input.view(-1, C) # [HW, C]
|
45 |
+
count = count.view(-1, 1)
|
46 |
+
|
47 |
+
flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1) # [N]
|
48 |
+
|
49 |
+
if weights is None:
|
50 |
+
weights = torch.ones_like(values[..., :1])
|
51 |
+
|
52 |
+
input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
|
53 |
+
count.scatter_add_(0, flatten_indices.unsqueeze(1), weights)
|
54 |
+
|
55 |
+
return input.view(*size, C), count.view(*size, 1)
|
56 |
+
|
57 |
+
def nearest_grid_put_2d(H, W, coords, values, return_count=False):
|
58 |
+
# coords: [N, 2], float in [-1, 1]
|
59 |
+
# values: [N, C]
|
60 |
+
|
61 |
+
C = values.shape[-1]
|
62 |
+
|
63 |
+
indices = (coords * 0.5 + 0.5) * torch.tensor(
|
64 |
+
[H - 1, W - 1], dtype=torch.float32, device=coords.device
|
65 |
+
)
|
66 |
+
indices = indices.round().long() # [N, 2]
|
67 |
+
|
68 |
+
result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C]
|
69 |
+
count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
|
70 |
+
weights = torch.ones_like(values[..., :1]) # [N, 1]
|
71 |
+
|
72 |
+
result, count = scatter_add_nd_with_count(result, count, indices, values, weights)
|
73 |
+
|
74 |
+
if return_count:
|
75 |
+
return result, count
|
76 |
+
|
77 |
+
mask = (count.squeeze(-1) > 0)
|
78 |
+
result[mask] = result[mask] / count[mask].repeat(1, C)
|
79 |
+
|
80 |
+
return result
|
81 |
+
|
82 |
+
|
83 |
+
def linear_grid_put_2d(H, W, coords, values, return_count=False):
|
84 |
+
# coords: [N, 2], float in [-1, 1]
|
85 |
+
# values: [N, C]
|
86 |
+
|
87 |
+
C = values.shape[-1]
|
88 |
+
|
89 |
+
indices = (coords * 0.5 + 0.5) * torch.tensor(
|
90 |
+
[H - 1, W - 1], dtype=torch.float32, device=coords.device
|
91 |
+
)
|
92 |
+
indices_00 = indices.floor().long() # [N, 2]
|
93 |
+
indices_00[:, 0].clamp_(0, H - 2)
|
94 |
+
indices_00[:, 1].clamp_(0, W - 2)
|
95 |
+
indices_01 = indices_00 + torch.tensor(
|
96 |
+
[0, 1], dtype=torch.long, device=indices.device
|
97 |
+
)
|
98 |
+
indices_10 = indices_00 + torch.tensor(
|
99 |
+
[1, 0], dtype=torch.long, device=indices.device
|
100 |
+
)
|
101 |
+
indices_11 = indices_00 + torch.tensor(
|
102 |
+
[1, 1], dtype=torch.long, device=indices.device
|
103 |
+
)
|
104 |
+
|
105 |
+
h = indices[..., 0] - indices_00[..., 0].float()
|
106 |
+
w = indices[..., 1] - indices_00[..., 1].float()
|
107 |
+
w_00 = (1 - h) * (1 - w)
|
108 |
+
w_01 = (1 - h) * w
|
109 |
+
w_10 = h * (1 - w)
|
110 |
+
w_11 = h * w
|
111 |
+
|
112 |
+
result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C]
|
113 |
+
count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
|
114 |
+
weights = torch.ones_like(values[..., :1]) # [N, 1]
|
115 |
+
|
116 |
+
result, count = scatter_add_nd_with_count(result, count, indices_00, values * w_00.unsqueeze(1), weights* w_00.unsqueeze(1))
|
117 |
+
result, count = scatter_add_nd_with_count(result, count, indices_01, values * w_01.unsqueeze(1), weights* w_01.unsqueeze(1))
|
118 |
+
result, count = scatter_add_nd_with_count(result, count, indices_10, values * w_10.unsqueeze(1), weights* w_10.unsqueeze(1))
|
119 |
+
result, count = scatter_add_nd_with_count(result, count, indices_11, values * w_11.unsqueeze(1), weights* w_11.unsqueeze(1))
|
120 |
+
|
121 |
+
if return_count:
|
122 |
+
return result, count
|
123 |
+
|
124 |
+
mask = (count.squeeze(-1) > 0)
|
125 |
+
result[mask] = result[mask] / count[mask].repeat(1, C)
|
126 |
+
|
127 |
+
return result
|
128 |
+
|
129 |
+
def mipmap_linear_grid_put_2d(H, W, coords, values, min_resolution=32, return_count=False):
|
130 |
+
# coords: [N, 2], float in [-1, 1]
|
131 |
+
# values: [N, C]
|
132 |
+
|
133 |
+
C = values.shape[-1]
|
134 |
+
|
135 |
+
result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C]
|
136 |
+
count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
|
137 |
+
|
138 |
+
cur_H, cur_W = H, W
|
139 |
+
|
140 |
+
while min(cur_H, cur_W) > min_resolution:
|
141 |
+
|
142 |
+
# try to fill the holes
|
143 |
+
mask = (count.squeeze(-1) == 0)
|
144 |
+
if not mask.any():
|
145 |
+
break
|
146 |
+
|
147 |
+
cur_result, cur_count = linear_grid_put_2d(cur_H, cur_W, coords, values, return_count=True)
|
148 |
+
result[mask] = result[mask] + F.interpolate(cur_result.permute(2,0,1).unsqueeze(0).contiguous(), (H, W), mode='bilinear', align_corners=False).squeeze(0).permute(1,2,0).contiguous()[mask]
|
149 |
+
count[mask] = count[mask] + F.interpolate(cur_count.view(1, 1, cur_H, cur_W), (H, W), mode='bilinear', align_corners=False).view(H, W, 1)[mask]
|
150 |
+
cur_H //= 2
|
151 |
+
cur_W //= 2
|
152 |
+
|
153 |
+
if return_count:
|
154 |
+
return result, count
|
155 |
+
|
156 |
+
mask = (count.squeeze(-1) > 0)
|
157 |
+
result[mask] = result[mask] / count[mask].repeat(1, C)
|
158 |
+
|
159 |
+
return result
|
160 |
+
|
161 |
+
def nearest_grid_put_3d(H, W, D, coords, values, return_count=False):
|
162 |
+
# coords: [N, 3], float in [-1, 1]
|
163 |
+
# values: [N, C]
|
164 |
+
|
165 |
+
C = values.shape[-1]
|
166 |
+
|
167 |
+
indices = (coords * 0.5 + 0.5) * torch.tensor(
|
168 |
+
[H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device
|
169 |
+
)
|
170 |
+
indices = indices.round().long() # [N, 2]
|
171 |
+
|
172 |
+
result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, C]
|
173 |
+
count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
|
174 |
+
weights = torch.ones_like(values[..., :1]) # [N, 1]
|
175 |
+
|
176 |
+
result, count = scatter_add_nd_with_count(result, count, indices, values, weights)
|
177 |
+
|
178 |
+
if return_count:
|
179 |
+
return result, count
|
180 |
+
|
181 |
+
mask = (count.squeeze(-1) > 0)
|
182 |
+
result[mask] = result[mask] / count[mask].repeat(1, C)
|
183 |
+
|
184 |
+
return result
|
185 |
+
|
186 |
+
|
187 |
+
def linear_grid_put_3d(H, W, D, coords, values, return_count=False):
|
188 |
+
# coords: [N, 3], float in [-1, 1]
|
189 |
+
# values: [N, C]
|
190 |
+
|
191 |
+
C = values.shape[-1]
|
192 |
+
|
193 |
+
indices = (coords * 0.5 + 0.5) * torch.tensor(
|
194 |
+
[H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device
|
195 |
+
)
|
196 |
+
indices_000 = indices.floor().long() # [N, 3]
|
197 |
+
indices_000[:, 0].clamp_(0, H - 2)
|
198 |
+
indices_000[:, 1].clamp_(0, W - 2)
|
199 |
+
indices_000[:, 2].clamp_(0, D - 2)
|
200 |
+
|
201 |
+
indices_001 = indices_000 + torch.tensor([0, 0, 1], dtype=torch.long, device=indices.device)
|
202 |
+
indices_010 = indices_000 + torch.tensor([0, 1, 0], dtype=torch.long, device=indices.device)
|
203 |
+
indices_011 = indices_000 + torch.tensor([0, 1, 1], dtype=torch.long, device=indices.device)
|
204 |
+
indices_100 = indices_000 + torch.tensor([1, 0, 0], dtype=torch.long, device=indices.device)
|
205 |
+
indices_101 = indices_000 + torch.tensor([1, 0, 1], dtype=torch.long, device=indices.device)
|
206 |
+
indices_110 = indices_000 + torch.tensor([1, 1, 0], dtype=torch.long, device=indices.device)
|
207 |
+
indices_111 = indices_000 + torch.tensor([1, 1, 1], dtype=torch.long, device=indices.device)
|
208 |
+
|
209 |
+
h = indices[..., 0] - indices_000[..., 0].float()
|
210 |
+
w = indices[..., 1] - indices_000[..., 1].float()
|
211 |
+
d = indices[..., 2] - indices_000[..., 2].float()
|
212 |
+
|
213 |
+
w_000 = (1 - h) * (1 - w) * (1 - d)
|
214 |
+
w_001 = (1 - h) * w * (1 - d)
|
215 |
+
w_010 = h * (1 - w) * (1 - d)
|
216 |
+
w_011 = h * w * (1 - d)
|
217 |
+
w_100 = (1 - h) * (1 - w) * d
|
218 |
+
w_101 = (1 - h) * w * d
|
219 |
+
w_110 = h * (1 - w) * d
|
220 |
+
w_111 = h * w * d
|
221 |
+
|
222 |
+
result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, D, C]
|
223 |
+
count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, D, 1]
|
224 |
+
weights = torch.ones_like(values[..., :1]) # [N, 1]
|
225 |
+
|
226 |
+
result, count = scatter_add_nd_with_count(result, count, indices_000, values * w_000.unsqueeze(1), weights * w_000.unsqueeze(1))
|
227 |
+
result, count = scatter_add_nd_with_count(result, count, indices_001, values * w_001.unsqueeze(1), weights * w_001.unsqueeze(1))
|
228 |
+
result, count = scatter_add_nd_with_count(result, count, indices_010, values * w_010.unsqueeze(1), weights * w_010.unsqueeze(1))
|
229 |
+
result, count = scatter_add_nd_with_count(result, count, indices_011, values * w_011.unsqueeze(1), weights * w_011.unsqueeze(1))
|
230 |
+
result, count = scatter_add_nd_with_count(result, count, indices_100, values * w_100.unsqueeze(1), weights * w_100.unsqueeze(1))
|
231 |
+
result, count = scatter_add_nd_with_count(result, count, indices_101, values * w_101.unsqueeze(1), weights * w_101.unsqueeze(1))
|
232 |
+
result, count = scatter_add_nd_with_count(result, count, indices_110, values * w_110.unsqueeze(1), weights * w_110.unsqueeze(1))
|
233 |
+
result, count = scatter_add_nd_with_count(result, count, indices_111, values * w_111.unsqueeze(1), weights * w_111.unsqueeze(1))
|
234 |
+
|
235 |
+
if return_count:
|
236 |
+
return result, count
|
237 |
+
|
238 |
+
mask = (count.squeeze(-1) > 0)
|
239 |
+
result[mask] = result[mask] / count[mask].repeat(1, C)
|
240 |
+
|
241 |
+
return result
|
242 |
+
|
243 |
+
def mipmap_linear_grid_put_3d(H, W, D, coords, values, min_resolution=32, return_count=False):
|
244 |
+
# coords: [N, 3], float in [-1, 1]
|
245 |
+
# values: [N, C]
|
246 |
+
|
247 |
+
C = values.shape[-1]
|
248 |
+
|
249 |
+
result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, D, C]
|
250 |
+
count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, D, 1]
|
251 |
+
cur_H, cur_W, cur_D = H, W, D
|
252 |
+
|
253 |
+
while min(min(cur_H, cur_W), cur_D) > min_resolution:
|
254 |
+
|
255 |
+
# try to fill the holes
|
256 |
+
mask = (count.squeeze(-1) == 0)
|
257 |
+
if not mask.any():
|
258 |
+
break
|
259 |
+
|
260 |
+
cur_result, cur_count = linear_grid_put_3d(cur_H, cur_W, cur_D, coords, values, return_count=True)
|
261 |
+
result[mask] = result[mask] + F.interpolate(cur_result.permute(3,0,1,2).unsqueeze(0).contiguous(), (H, W, D), mode='trilinear', align_corners=False).squeeze(0).permute(1,2,3,0).contiguous()[mask]
|
262 |
+
count[mask] = count[mask] + F.interpolate(cur_count.view(1, 1, cur_H, cur_W, cur_D), (H, W, D), mode='trilinear', align_corners=False).view(H, W, D, 1)[mask]
|
263 |
+
cur_H //= 2
|
264 |
+
cur_W //= 2
|
265 |
+
cur_D //= 2
|
266 |
+
|
267 |
+
if return_count:
|
268 |
+
return result, count
|
269 |
+
|
270 |
+
mask = (count.squeeze(-1) > 0)
|
271 |
+
result[mask] = result[mask] / count[mask].repeat(1, C)
|
272 |
+
|
273 |
+
return result
|
274 |
+
|
275 |
+
|
276 |
+
def grid_put(shape, coords, values, mode='linear-mipmap', min_resolution=32, return_raw=False):
|
277 |
+
# shape: [D], list/tuple
|
278 |
+
# coords: [N, D], float in [-1, 1]
|
279 |
+
# values: [N, C]
|
280 |
+
|
281 |
+
D = len(shape)
|
282 |
+
assert D in [2, 3], f'only support D == 2 or 3, but got D == {D}'
|
283 |
+
|
284 |
+
if mode == 'nearest':
|
285 |
+
if D == 2:
|
286 |
+
return nearest_grid_put_2d(*shape, coords, values, return_raw)
|
287 |
+
else:
|
288 |
+
return nearest_grid_put_3d(*shape, coords, values, return_raw)
|
289 |
+
elif mode == 'linear':
|
290 |
+
if D == 2:
|
291 |
+
return linear_grid_put_2d(*shape, coords, values, return_raw)
|
292 |
+
else:
|
293 |
+
return linear_grid_put_3d(*shape, coords, values, return_raw)
|
294 |
+
elif mode == 'linear-mipmap':
|
295 |
+
if D == 2:
|
296 |
+
return mipmap_linear_grid_put_2d(*shape, coords, values, min_resolution, return_raw)
|
297 |
+
else:
|
298 |
+
return mipmap_linear_grid_put_3d(*shape, coords, values, min_resolution, return_raw)
|
299 |
+
else:
|
300 |
+
raise NotImplementedError(f"got mode {mode}")
|
gs_renderer.py
ADDED
@@ -0,0 +1,820 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
from typing import NamedTuple
|
5 |
+
from plyfile import PlyData, PlyElement
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
from diff_gaussian_rasterization import (
|
11 |
+
GaussianRasterizationSettings,
|
12 |
+
GaussianRasterizer,
|
13 |
+
)
|
14 |
+
from simple_knn._C import distCUDA2
|
15 |
+
|
16 |
+
from sh_utils import eval_sh, SH2RGB, RGB2SH
|
17 |
+
from mesh import Mesh
|
18 |
+
from mesh_utils import decimate_mesh, clean_mesh
|
19 |
+
|
20 |
+
import kiui
|
21 |
+
|
22 |
+
def inverse_sigmoid(x):
|
23 |
+
return torch.log(x/(1-x))
|
24 |
+
|
25 |
+
def get_expon_lr_func(
|
26 |
+
lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
|
27 |
+
):
|
28 |
+
|
29 |
+
def helper(step):
|
30 |
+
if lr_init == lr_final:
|
31 |
+
# constant lr, ignore other params
|
32 |
+
return lr_init
|
33 |
+
if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
|
34 |
+
# Disable this parameter
|
35 |
+
return 0.0
|
36 |
+
if lr_delay_steps > 0:
|
37 |
+
# A kind of reverse cosine decay.
|
38 |
+
delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
|
39 |
+
0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
|
40 |
+
)
|
41 |
+
else:
|
42 |
+
delay_rate = 1.0
|
43 |
+
t = np.clip(step / max_steps, 0, 1)
|
44 |
+
log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
|
45 |
+
return delay_rate * log_lerp
|
46 |
+
|
47 |
+
return helper
|
48 |
+
|
49 |
+
|
50 |
+
def strip_lowerdiag(L):
|
51 |
+
uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
|
52 |
+
|
53 |
+
uncertainty[:, 0] = L[:, 0, 0]
|
54 |
+
uncertainty[:, 1] = L[:, 0, 1]
|
55 |
+
uncertainty[:, 2] = L[:, 0, 2]
|
56 |
+
uncertainty[:, 3] = L[:, 1, 1]
|
57 |
+
uncertainty[:, 4] = L[:, 1, 2]
|
58 |
+
uncertainty[:, 5] = L[:, 2, 2]
|
59 |
+
return uncertainty
|
60 |
+
|
61 |
+
def strip_symmetric(sym):
|
62 |
+
return strip_lowerdiag(sym)
|
63 |
+
|
64 |
+
def gaussian_3d_coeff(xyzs, covs):
|
65 |
+
# xyzs: [N, 3]
|
66 |
+
# covs: [N, 6]
|
67 |
+
x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2]
|
68 |
+
a, b, c, d, e, f = covs[:, 0], covs[:, 1], covs[:, 2], covs[:, 3], covs[:, 4], covs[:, 5]
|
69 |
+
|
70 |
+
# eps must be small enough !!!
|
71 |
+
inv_det = 1 / (a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24)
|
72 |
+
inv_a = (d * f - e**2) * inv_det
|
73 |
+
inv_b = (e * c - b * f) * inv_det
|
74 |
+
inv_c = (e * b - c * d) * inv_det
|
75 |
+
inv_d = (a * f - c**2) * inv_det
|
76 |
+
inv_e = (b * c - e * a) * inv_det
|
77 |
+
inv_f = (a * d - b**2) * inv_det
|
78 |
+
|
79 |
+
power = -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f) - x * y * inv_b - x * z * inv_c - y * z * inv_e
|
80 |
+
|
81 |
+
power[power > 0] = -1e10 # abnormal values... make weights 0
|
82 |
+
|
83 |
+
return torch.exp(power)
|
84 |
+
|
85 |
+
def build_rotation(r):
|
86 |
+
norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
|
87 |
+
|
88 |
+
q = r / norm[:, None]
|
89 |
+
|
90 |
+
R = torch.zeros((q.size(0), 3, 3), device='cuda')
|
91 |
+
|
92 |
+
r = q[:, 0]
|
93 |
+
x = q[:, 1]
|
94 |
+
y = q[:, 2]
|
95 |
+
z = q[:, 3]
|
96 |
+
|
97 |
+
R[:, 0, 0] = 1 - 2 * (y*y + z*z)
|
98 |
+
R[:, 0, 1] = 2 * (x*y - r*z)
|
99 |
+
R[:, 0, 2] = 2 * (x*z + r*y)
|
100 |
+
R[:, 1, 0] = 2 * (x*y + r*z)
|
101 |
+
R[:, 1, 1] = 1 - 2 * (x*x + z*z)
|
102 |
+
R[:, 1, 2] = 2 * (y*z - r*x)
|
103 |
+
R[:, 2, 0] = 2 * (x*z - r*y)
|
104 |
+
R[:, 2, 1] = 2 * (y*z + r*x)
|
105 |
+
R[:, 2, 2] = 1 - 2 * (x*x + y*y)
|
106 |
+
return R
|
107 |
+
|
108 |
+
def build_scaling_rotation(s, r):
|
109 |
+
L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
|
110 |
+
R = build_rotation(r)
|
111 |
+
|
112 |
+
L[:,0,0] = s[:,0]
|
113 |
+
L[:,1,1] = s[:,1]
|
114 |
+
L[:,2,2] = s[:,2]
|
115 |
+
|
116 |
+
L = R @ L
|
117 |
+
return L
|
118 |
+
|
119 |
+
class BasicPointCloud(NamedTuple):
|
120 |
+
points: np.array
|
121 |
+
colors: np.array
|
122 |
+
normals: np.array
|
123 |
+
|
124 |
+
|
125 |
+
class GaussianModel:
|
126 |