jiawei011 committed on
Commit
12b7f59
1 Parent(s): 5f58ec6
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ build/
3
+ *.egg-info/
4
+ *.so
5
+ venv_*/
6
+ .vs/
7
+ .vscode/
8
+ .idea/
9
+
10
+ tmp_*
11
+ data?
12
+ data??
13
+ scripts2
14
+
15
+ model_cache
16
+
17
+ logs
18
+ videos
19
+ images
20
+ *.mp4
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 dreamgaussian
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
LICENSE_GAUSSIAN_SPLATTING.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Gaussian-Splatting License
2
+ ===========================
3
+
4
+ **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
5
+ The *Software* is in the process of being registered with the Agence pour la Protection des
6
+ Programmes (APP).
7
+
8
+ The *Software* is still being developed by the *Licensor*.
9
+
10
+ *Licensor*'s goal is to allow the research community to use, test and evaluate
11
+ the *Software*.
12
+
13
+ ## 1. Definitions
14
+
15
+ *Licensee* means any person or entity that uses the *Software* and distributes
16
+ its *Work*.
17
+
18
+ *Licensor* means the owners of the *Software*, i.e Inria and MPII
19
+
20
+ *Software* means the original work of authorship made available under this
21
+ License ie gaussian-splatting.
22
+
23
+ *Work* means the *Software* and any additions to or derivative works of the
24
+ *Software* that are made available under this License.
25
+
26
+
27
+ ## 2. Purpose
28
+ This license is intended to define the rights granted to the *Licensee* by
29
+ Licensors under the *Software*.
30
+
31
+ ## 3. Rights granted
32
+
33
+ For the above reasons Licensors have decided to distribute the *Software*.
34
+ Licensors grant non-exclusive rights to use the *Software* for research purposes
35
+ to research users (both academic and industrial), free of charge, without right
36
+ to sublicense.. The *Software* may be used "non-commercially", i.e., for research
37
+ and/or evaluation purposes only.
38
+
39
+ Subject to the terms and conditions of this License, you are granted a
40
+ non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
41
+ publicly display, publicly perform and distribute its *Work* and any resulting
42
+ derivative works in any form.
43
+
44
+ ## 4. Limitations
45
+
46
+ **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
47
+ so under this License, (b) you include a complete copy of this License with
48
+ your distribution, and (c) you retain without modification any copyright,
49
+ patent, trademark, or attribution notices that are present in the *Work*.
50
+
51
+ **4.2 Derivative Works.** You may specify that additional or different terms apply
52
+ to the use, reproduction, and distribution of your derivative works of the *Work*
53
+ ("Your Terms") only if (a) Your Terms provide that the use limitation in
54
+ Section 2 applies to your derivative works, and (b) you identify the specific
55
+ derivative works that are subject to Your Terms. Notwithstanding Your Terms,
56
+ this License (including the redistribution requirements in Section 3.1) will
57
+ continue to apply to the *Work* itself.
58
+
59
+ **4.3** Any other use without of prior consent of Licensors is prohibited. Research
60
+ users explicitly acknowledge having received from Licensors all information
61
+ allowing to appreciate the adequacy between of the *Software* and their needs and
62
+ to undertake all necessary precautions for its execution and use.
63
+
64
+ **4.4** The *Software* is provided both as a compiled library file and as source
65
+ code. In case of using the *Software* for a publication or other results obtained
66
+ through the use of the *Software*, users are strongly encouraged to cite the
67
+ corresponding publications as explained in the documentation of the *Software*.
68
+
69
+ ## 5. Disclaimer
70
+
71
+ THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
72
+ WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
73
+ UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
74
+ CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
75
+ OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
76
+ USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
77
+ ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
78
+ AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
79
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
80
+ GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
81
+ HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
82
+ LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
83
+ IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
README.md DELETED
@@ -1,13 +0,0 @@
1
- ---
2
- title: Dreamgaussian
3
- emoji: 🌍
4
- colorFrom: red
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 3.47.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from PIL import Image
4
+ import subprocess
5
+
6
+
7
+ # check if there is a picture uploaded or selected
8
def check_img_input(control_image):
    """Stop the Gradio event chain when no input image has been provided.

    Args:
        control_image: value of the image component; ``None`` when empty.

    Raises:
        gr.Error: if ``control_image`` is ``None``.
    """
    image_missing = control_image is None
    if image_missing:
        raise gr.Error("Please select or upload an input image")
11
+
12
+
13
def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, elevation_slider: float):
    """Save the input image and run the stage-1 optimisation script.

    Args:
        image_block: input image; it is written to ``tmp_data/``.
        preprocess_chk: when true, the image is saved as ``tmp_data/tmp.png``
            and ``process.py`` is run on it; otherwise the image is saved
            directly as ``tmp_data/tmp_rgba.png``.
        elevation_slider: forwarded to ``main.py`` as ``elevation=<value>``.

    Returns:
        The string ``'logs/tmp_mesh.glb'``.

    Raises:
        subprocess.CalledProcessError: if ``process.py`` or ``main.py`` exits
            with a non-zero status. Previously the failure was ignored and the
            mesh path was returned regardless.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs('tmp_data', exist_ok=True)

    if preprocess_chk:
        # save image to a designated path
        image_block.save('tmp_data/tmp.png')

        # preprocess image
        # NOTE(review): presumably process.py writes tmp_data/tmp_rgba.png,
        # which stage 1 reads below -- confirm against process.py.
        subprocess.run(['python', 'process.py', 'tmp_data/tmp.png'], check=True)
    else:
        image_block.save('tmp_data/tmp_rgba.png')

    # stage 1: an argument list (no shell) keeps every value a single argv entry.
    subprocess.run(
        [
            'python', 'main.py',
            '--config', 'configs/image.yaml',
            'input=tmp_data/tmp_rgba.png',
            'save_path=tmp',
            'mesh_format=glb',
            f'elevation={elevation_slider}',
            'force_cuda_rast=True',
        ],
        check=True,
    )

    return 'logs/tmp_mesh.glb'
31
+
32
+
33
def optimize_stage_2(elevation_slider: float):
    """Run the stage-2 script (``main2.py``) on ``tmp_data/tmp_rgba.png``.

    Args:
        elevation_slider: forwarded to ``main2.py`` as ``elevation=<value>``.

    Returns:
        The string ``'logs/tmp.glb'``.

    Raises:
        subprocess.CalledProcessError: if ``main2.py`` exits with a non-zero
            status. Previously the failure was ignored and the mesh path was
            returned regardless.
    """
    # stage 2: an argument list (no shell) keeps every value a single argv entry.
    subprocess.run(
        [
            'python', 'main2.py',
            '--config', 'configs/image.yaml',
            'input=tmp_data/tmp_rgba.png',
            'save_path=tmp',
            'mesh_format=glb',
            f'elevation={elevation_slider}',
            'force_cuda_rast=True',
        ],
        check=True,
    )

    return 'logs/tmp.glb'
40
+
41
+
42
if __name__ == "__main__":
    # Script entry point: build the Gradio demo and launch it.
    _TITLE = '''DreamGaussian: Generative Gaussian Splatting for Efficient 3D Content Creation'''

    # The arXiv badge label now matches the linked paper id (2309.16653), and
    # the project name is spelled correctly in the description.
    _DESCRIPTION = '''
<div>
<a style="display:inline-block" href="https://dreamgaussian.github.io"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2309.16653"><img src="https://img.shields.io/badge/2309.16653-f9f7f7?logo="></a>
<a style="display:inline-block; margin-left: .5em" href='https://github.com/dreamgaussian/dreamgaussian'><img src='https://img.shields.io/github/stars/dreamgaussian/dreamgaussian?style=social'/></a>
</div>
We present DreamGaussian, a 3D content generation framework that significantly improves the efficiency of 3D content creation.
'''
    _IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above) and click **Generate 3D**."

    # load images in 'data' folder as examples
    example_folder = os.path.join(os.path.dirname(__file__), 'data')
    example_fns = sorted(os.listdir(example_folder))
    examples_full = [os.path.join(example_folder, x) for x in example_fns if x.endswith('.png')]

    # Compose demo layout & data flow
    with gr.Blocks(title=_TITLE, theme=gr.themes.Soft()) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)

        # Image-to-3D
        with gr.Row(variant='panel'):
            with gr.Column(scale=5):
                image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image', tool=None)

                elevation_slider = gr.Slider(-90, 90, value=0, step=1, label='Estimated elevation angle')
                gr.Markdown(
                    "default to 0 (horizontal), range from [-90, 90]. If you upload a look-down image, try a value like -30")

                preprocess_chk = gr.Checkbox(True,
                                             label='Preprocess image automatically (remove background and recenter object)')

                gr.Examples(
                    examples=examples_full,  # NOTE: elements must match inputs list!
                    inputs=[image_block],
                    outputs=[image_block],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=40
                )
                img_run_btn = gr.Button("Generate 3D")
                img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)

            with gr.Column(scale=5):
                obj3d_stage1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model (Stage 1)")
                obj3d = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model (Final)")

            # if there is an input image, continue with inference
            # else display an error message
            img_run_btn.click(
                check_img_input, inputs=[image_block], queue=False
            ).success(
                optimize_stage_1,
                inputs=[image_block, preprocess_chk, elevation_slider],
                outputs=[obj3d_stage1],
            ).success(
                optimize_stage_2, inputs=[elevation_slider], outputs=[obj3d]
            )

    demo.queue().launch(share=True)
cam_utils.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.spatial.transform import Rotation as R
3
+
4
+ import torch
5
+
6
def dot(x, y):
    """Dot product over the last axis, keeping that axis with size 1.

    NumPy is used when ``x`` is an ``np.ndarray``; otherwise the torch
    reduction is used.
    """
    if not isinstance(x, np.ndarray):
        return torch.sum(x * y, -1, keepdim=True)
    return np.sum(x * y, -1, keepdims=True)
11
+
12
+
13
def length(x, eps=1e-20):
    """Euclidean norm over the last axis (kept with size 1).

    The squared norm is floored at ``eps`` before the square root, so the
    result is never exactly zero.
    """
    if isinstance(x, np.ndarray):
        squared = np.sum(x * x, axis=-1, keepdims=True)
        return np.sqrt(np.maximum(squared, eps))
    squared = dot(x, x)
    return torch.sqrt(torch.clamp(squared, min=eps))
18
+
19
+
20
def safe_normalize(x, eps=1e-20):
    """Divide ``x`` by its ``length``, which is floored via ``eps``."""
    norm = length(x, eps)
    return x / norm
22
+
23
+
24
def look_at(campos, target, opengl=True):
    """Build a rotation matrix whose columns are the right, up and forward axes.

    Args:
        campos: [N, 3], camera/eye position.
        target: [N, 3], object to look at.
        opengl: when true the third column points from ``target`` towards
            ``campos``; otherwise it points from ``campos`` towards ``target``.

    Returns:
        [N, 3, 3] rotation matrix (axes stacked along ``axis=1``).
    """
    world_up = np.array([0, 1, 0], dtype=np.float32)
    if opengl:
        forward_vector = safe_normalize(campos - target)
        right_vector = safe_normalize(np.cross(world_up, forward_vector))
        up_vector = safe_normalize(np.cross(forward_vector, right_vector))
    else:
        forward_vector = safe_normalize(target - campos)
        right_vector = safe_normalize(np.cross(forward_vector, world_up))
        up_vector = safe_normalize(np.cross(right_vector, forward_vector))
    # Local name deliberately differs from the module-level Rotation alias.
    rotation = np.stack([right_vector, up_vector, forward_vector], axis=1)
    return rotation
42
+
43
+
44
+ # elevation & azimuth to pose (cam2world) matrix
45
def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True):
    """Convert elevation and azimuth to a camera pose (cam2world) matrix.

    Args:
        elevation: scalar, in (-90, 90), from +y to -y is (-90, 90).
        azimuth: scalar, in (-180, 180), from +z to +x is (0, 90).
        radius: scalar distance between the camera and ``target``.
        is_degree: when true, both angles are converted from degrees.
        target: point to orbit around; defaults to the origin.
        opengl: forwarded to ``look_at``.

    Returns:
        [4, 4] float32 camera pose matrix.
    """
    if is_degree:
        elevation, azimuth = np.deg2rad(elevation), np.deg2rad(azimuth)
    if target is None:
        target = np.zeros([3], dtype=np.float32)

    offset = np.array([
        radius * np.cos(elevation) * np.sin(azimuth),
        - radius * np.sin(elevation),
        radius * np.cos(elevation) * np.cos(azimuth),
    ])
    campos = offset + target  # [3]

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = look_at(campos, target, opengl)
    pose[:3, 3] = campos
    return pose
63
+
64
+
65
class OrbitCamera:
    """Orbiting camera described by a rotation, a radius and a look-at centre."""

    def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):
        """Store the image size, orbit radius, vertical FOV (degrees) and clip planes."""
        self.W, self.H = W, H
        self.radius = r  # camera distance from center
        self.fovy = np.deg2rad(fovy)  # stored in radians
        self.near, self.far = near, far
        self.center = np.array([0, 0, 0], dtype=np.float32)  # look at this point
        self.rot = R.from_matrix(np.eye(3))
        self.up = np.array([0, 1, 0], dtype=np.float32)  # need to be normalized!

    @property
    def fovx(self):
        """Horizontal field of view in radians, derived from ``fovy`` and W / H."""
        half_tan = np.tan(self.fovy / 2)
        return 2 * np.arctan(half_tan * self.W / self.H)

    @property
    def campos(self):
        """Translation column of ``pose``."""
        return self.pose[:3, 3]

    @property
    def pose(self):
        """Camera pose (c2w) as a [4, 4] float32 matrix."""
        # first move camera to radius along z (opengl convention)
        offset = np.eye(4, dtype=np.float32)
        offset[2, 3] = self.radius
        # rotate
        rotation = np.eye(4, dtype=np.float32)
        rotation[:3, :3] = self.rot.as_matrix()
        c2w = rotation @ offset
        # translate
        c2w[:3, 3] -= self.center
        return c2w

    @property
    def view(self):
        """View matrix (w2c): the inverse of ``pose``."""
        return np.linalg.inv(self.pose)

    @property
    def perspective(self):
        """Perspective projection matrix ([4, 4], float32)."""
        y = np.tan(self.fovy / 2)
        aspect = self.W / self.H
        depth_range = self.far - self.near
        rows = [
            [1 / (y * aspect), 0, 0, 0],
            [0, -1 / y, 0, 0],
            [0, 0, -(self.far + self.near) / depth_range, -(2 * self.far * self.near) / depth_range],
            [0, 0, -1, 0],
        ]
        return np.array(rows, dtype=np.float32)

    @property
    def intrinsics(self):
        """Intrinsics as ``[focal, focal, W // 2, H // 2]`` (float32)."""
        focal = self.H / (2 * np.tan(self.fovy / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32)

    @property
    def mvp(self):
        """Product of ``perspective`` and the inverse of ``pose`` ([4, 4])."""
        w2c = np.linalg.inv(self.pose)
        return self.perspective @ w2c

    def orbit(self, dx, dy):
        """Rotate about ``self.up`` by ``dx`` and about the camera side axis by ``dy``."""
        side = self.rot.as_matrix()[:3, 0]
        about_up = R.from_rotvec(self.up * np.radians(-0.05 * dx))
        about_side = R.from_rotvec(side * np.radians(-0.05 * dy))
        self.rot = about_up * about_side * self.rot

    def scale(self, delta):
        """Multiply the radius by ``1.1 ** (-delta)``."""
        self.radius *= 1.1 ** (-delta)

    def pan(self, dx, dy, dz=0):
        """Shift ``center`` in the camera coordinate system (careful on the sensitivity!)."""
        shift = np.array([-dx, -dy, dz])
        self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ shift
configs/image.yaml ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Input
2
+ # input RGBA image path (defaults to None; can also be loaded in the GUI)
3
+ input:
4
+ # input text prompt (default to None, can be input in GUI too)
5
+ prompt:
6
+ # input mesh for stage 2 (auto-search from stage 1 output path if None)
7
+ mesh:
8
+ # estimated elevation angle for input image
9
+ elevation: 0
10
+ # reference image resolution
11
+ ref_size: 256
12
+ # density thresh for mesh extraction
13
+ density_thresh: 1
14
+
15
+ ### Output
16
+ outdir: logs
17
+ mesh_format: obj
18
+ save_path: ???
19
+
20
+ ### Training
21
+ # guidance loss weights (0 to disable)
22
+ lambda_sd: 0
23
+ lambda_zero123: 1
24
+ # training batch size per iter
25
+ batch_size: 1
26
+ # training iterations for stage 1
27
+ iters: 500
28
+ # training iterations for stage 2
29
+ iters_refine: 50
30
+ # training camera radius
31
+ radius: 2
32
+ # training camera fovy
33
+ fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61)
34
+ # checkpoint to load for stage 1 (should be a ply file)
35
+ load:
36
+ # whether to allow geometry training in stage 2
37
+ train_geo: False
38
+ # prob to invert background color during training (0 = always black, 1 = always white)
39
+ invert_bg_prob: 0.5
40
+
41
+
42
+ ### GUI
43
+ gui: False
44
+ force_cuda_rast: False
45
+ # GUI resolution
46
+ H: 800
47
+ W: 800
48
+
49
+ ### Gaussian splatting
50
+ num_pts: 5000
51
+ sh_degree: 0
52
+ position_lr_init: 0.001
53
+ position_lr_final: 0.00002
54
+ position_lr_delay_mult: 0.02
55
+ position_lr_max_steps: 500
56
+ feature_lr: 0.01
57
+ opacity_lr: 0.05
58
+ scaling_lr: 0.005
59
+ rotation_lr: 0.005
60
+ percent_dense: 0.1
61
+ density_start_iter: 100
62
+ density_end_iter: 3000
63
+ densification_interval: 100
64
+ opacity_reset_interval: 700
65
+ densify_grad_threshold: 0.5
66
+
67
+ ### Textured Mesh
68
+ geom_lr: 0.0001
69
+ texture_lr: 0.2
configs/text.yaml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Input
2
+ # input RGBA image path (defaults to None; can also be loaded in the GUI)
3
+ input:
4
+ # input text prompt (default to None, can be input in GUI too)
5
+ prompt:
6
+ # input mesh for stage 2 (auto-search from stage 1 output path if None)
7
+ mesh:
8
+ # estimated elevation angle for input image
9
+ elevation: 0
10
+ # reference image resolution
11
+ ref_size: 256
12
+ # density thresh for mesh extraction
13
+ density_thresh: 1
14
+
15
+ ### Output
16
+ outdir: logs
17
+ mesh_format: obj
18
+ save_path: ???
19
+
20
+ ### Training
21
+ # guidance loss weights (0 to disable)
22
+ lambda_sd: 1
23
+ lambda_zero123: 0
24
+ # training batch size per iter
25
+ batch_size: 1
26
+ # training iterations for stage 1
27
+ iters: 500
28
+ # training iterations for stage 2
29
+ iters_refine: 50
30
+ # training camera radius
31
+ radius: 2.5
32
+ # training camera fovy
33
+ fovy: 49.1
34
+ # checkpoint to load for stage 1 (should be a ply file)
35
+ load:
36
+ # whether to allow geometry training in stage 2
37
+ train_geo: False
38
+ # prob to invert background color during training (0 = always black, 1 = always white)
39
+ invert_bg_prob: 0.5
40
+
41
+ ### GUI
42
+ gui: False
43
+ force_cuda_rast: False
44
+ # GUI resolution
45
+ H: 800
46
+ W: 800
47
+
48
+ ### Gaussian splatting
49
+ num_pts: 1000
50
+ sh_degree: 0
51
+ position_lr_init: 0.001
52
+ position_lr_final: 0.00002
53
+ position_lr_delay_mult: 0.02
54
+ position_lr_max_steps: 500
55
+ feature_lr: 0.01
56
+ opacity_lr: 0.05
57
+ scaling_lr: 0.005
58
+ rotation_lr: 0.005
59
+ percent_dense: 0.1
60
+ density_start_iter: 100
61
+ density_end_iter: 3000
62
+ densification_interval: 50
63
+ opacity_reset_interval: 700
64
+ densify_grad_threshold: 0.01
65
+
66
+ ### Textured Mesh
67
+ geom_lr: 0.0001
68
+ texture_lr: 0.2
data/anya_rgba.png ADDED

Git LFS Details

  • SHA256: b8c3e8fe7fb51c4ae7f8b561e3780a50f1f25a9cb8c838d7fce4b38d773473f8
  • Pointer size: 130 Bytes
  • Size of remote file: 32.9 kB
data/catstatue_rgba.png ADDED

Git LFS Details

  • SHA256: 6a571efb23ff05f92d7363d32a4027c08137d84e9bde863c7dfca5086bd3005d
  • Pointer size: 130 Bytes
  • Size of remote file: 45.5 kB
data/csm_luigi_rgba.png ADDED

Git LFS Details

  • SHA256: 538fd1c3d1be3f0ef0cbdbf60d3e77821cb304dd68e3fbd62229191d5d050186
  • Pointer size: 130 Bytes
  • Size of remote file: 35.4 kB
data/test.png ADDED

Git LFS Details

  • SHA256: 479f4fa9a5d2fcbf81240533f347a0d080050162757702317c8d7e06401bb958
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
data/zelda_rgba.png ADDED

Git LFS Details

  • SHA256: b5e5004f1c64cbb9aceaf47c3594cfb89dfee64fbdf1a5a10faa5f51e87f0c4f
  • Pointer size: 130 Bytes
  • Size of remote file: 44.9 kB
grid_put.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
def stride_from_shape(shape):
    """Return row-major (C-order) element strides for ``shape`` as a list.

    The last dimension has stride 1; each earlier stride is the product of
    all later dimension sizes.
    """
    strides = [1]
    for dim in reversed(shape[1:]):
        # Prepend, so the list is built directly in final order.
        strides.insert(0, strides[0] * dim)
    return strides
9
+
10
+
11
def scatter_add_nd(input, indices, values):
    """Accumulate ``values`` into ``input`` at N-dimensional integer ``indices``.

    Args:
        input: [..., C], D grid dimensions + C channels; updated in place
            through a flattened view.
        indices: [N, D], long.
        values: [N, C].

    Returns:
        ``input`` viewed back to its grid shape ``[..., C]``.
    """
    n_dims = indices.shape[-1]
    channels = input.shape[-1]
    grid_size = input.shape[:-1]
    strides = stride_from_shape(grid_size)

    assert len(grid_size) == n_dims

    flat_input = input.view(-1, channels)  # [HW, C]
    stride_t = torch.tensor(strides, dtype=torch.long, device=indices.device)
    flat_index = (indices * stride_t).sum(-1)  # [N]

    flat_input.scatter_add_(0, flat_index.unsqueeze(1).repeat(1, channels), values)

    return flat_input.view(*grid_size, channels)
29
+
30
+
31
def scatter_add_nd_with_count(input, count, indices, values, weights=None):
    """Accumulate ``values`` into ``input`` and ``weights`` into ``count``.

    Args:
        input: [..., C], D grid dimensions + C channels; updated in place.
        count: [..., 1], D grid dimensions; updated in place.
        indices: [N, D], long.
        values: [N, C].
        weights: [N, 1] amounts added to ``count``; defaults to ones.

    Returns:
        ``(input, count)`` viewed back to ``[..., C]`` and ``[..., 1]``.
    """
    n_dims = indices.shape[-1]
    channels = input.shape[-1]
    grid_size = input.shape[:-1]
    strides = stride_from_shape(grid_size)

    assert len(grid_size) == n_dims

    flat_input = input.view(-1, channels)  # [HW, C]
    flat_count = count.view(-1, 1)

    stride_t = torch.tensor(strides, dtype=torch.long, device=indices.device)
    flat_index = (indices * stride_t).sum(-1).unsqueeze(1)  # [N, 1]

    if weights is None:
        weights = torch.ones_like(values[..., :1])

    flat_input.scatter_add_(0, flat_index.repeat(1, channels), values)
    flat_count.scatter_add_(0, flat_index, weights)

    return flat_input.view(*grid_size, channels), flat_count.view(*grid_size, 1)
56
+
57
def nearest_grid_put_2d(H, W, coords, values, return_count=False):
    """Scatter ``values`` onto an ``H x W`` grid at the nearest cell.

    Args:
        H, W: grid size.
        coords: [N, 2], float in [-1, 1]; column 0 maps to the H axis and
            column 1 to the W axis.
        values: [N, C].
        return_count: when true, return the raw sums and hit counts.

    Returns:
        ``(result, count)`` if ``return_count``; otherwise ``result`` with
        every hit cell divided by its count ([H, W, C]).
    """
    C = values.shape[-1]

    extent = torch.tensor([H - 1, W - 1], dtype=torch.float32, device=coords.device)
    indices = ((coords * 0.5 + 0.5) * extent).round().long()  # [N, 2]

    result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype)  # [H, W, C]
    count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)  # [H, W, 1]
    weights = torch.ones_like(values[..., :1])  # [N, 1]

    result, count = scatter_add_nd_with_count(result, count, indices, values, weights)

    if return_count:
        return result, count

    hit = count.squeeze(-1) > 0
    result[hit] = result[hit] / count[hit].repeat(1, C)

    return result
81
+
82
+
83
def linear_grid_put_2d(H, W, coords, values, return_count=False):
    """Scatter ``values`` onto an ``H x W`` grid with bilinear weights.

    Each point contributes to the four cells around it.

    Args:
        H, W: grid size.
        coords: [N, 2], float in [-1, 1]; column 0 maps to the H axis and
            column 1 to the W axis.
        values: [N, C].
        return_count: when true, return the raw weighted sums and weights.

    Returns:
        ``(result, count)`` if ``return_count``; otherwise ``result`` with
        every touched cell divided by its accumulated weight ([H, W, C]).
    """
    C = values.shape[-1]

    extent = torch.tensor([H - 1, W - 1], dtype=torch.float32, device=coords.device)
    indices = (coords * 0.5 + 0.5) * extent

    # Lower corner, clamped so the "+1" neighbours stay inside the grid.
    base = indices.floor().long()  # [N, 2]
    base[:, 0].clamp_(0, H - 2)
    base[:, 1].clamp_(0, W - 2)

    h = indices[..., 0] - base[..., 0].float()
    w = indices[..., 1] - base[..., 1].float()

    # (offset along H, offset along W) paired with its bilinear weight.
    corners = (
        ((0, 0), (1 - h) * (1 - w)),
        ((0, 1), (1 - h) * w),
        ((1, 0), h * (1 - w)),
        ((1, 1), h * w),
    )

    result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype)  # [H, W, C]
    count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)  # [H, W, 1]
    weights = torch.ones_like(values[..., :1])  # [N, 1]

    for offset, corner_w in corners:
        corner_idx = base + torch.tensor(offset, dtype=torch.long, device=indices.device)
        corner_w = corner_w.unsqueeze(1)
        result, count = scatter_add_nd_with_count(
            result, count, corner_idx, values * corner_w, weights * corner_w
        )

    if return_count:
        return result, count

    hit = count.squeeze(-1) > 0
    result[hit] = result[hit] / count[hit].repeat(1, C)

    return result
128
+
129
def mipmap_linear_grid_put_2d(H, W, coords, values, min_resolution=32, return_count=False):
    """Bilinear grid-put that fills empty cells from progressively coarser grids.

    Starting at ``H x W`` and halving both sides each round while the smaller
    side exceeds ``min_resolution``, cells whose count is still zero receive
    the bilinearly upsampled result of ``linear_grid_put_2d`` at that level.

    Args:
        H, W: output grid size.
        coords: [N, 2], float in [-1, 1].
        values: [N, C].
        min_resolution: levels are processed only while min(side) > this.
        return_count: when true, return the raw sums and weights.

    Returns:
        ``(result, count)`` if ``return_count``; otherwise the normalised
        ``result`` ([H, W, C]).
    """
    C = values.shape[-1]

    result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype)  # [H, W, C]
    count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)  # [H, W, 1]

    cur_H, cur_W = H, W

    while min(cur_H, cur_W) > min_resolution:
        # try to fill the holes
        holes = count.squeeze(-1) == 0
        if not holes.any():
            break

        level_result, level_count = linear_grid_put_2d(cur_H, cur_W, coords, values, return_count=True)

        up_result = F.interpolate(
            level_result.permute(2, 0, 1).unsqueeze(0).contiguous(),
            (H, W), mode='bilinear', align_corners=False,
        ).squeeze(0).permute(1, 2, 0).contiguous()
        up_count = F.interpolate(
            level_count.view(1, 1, cur_H, cur_W),
            (H, W), mode='bilinear', align_corners=False,
        ).view(H, W, 1)

        result[holes] = result[holes] + up_result[holes]
        count[holes] = count[holes] + up_count[holes]

        cur_H //= 2
        cur_W //= 2

    if return_count:
        return result, count

    hit = count.squeeze(-1) > 0
    result[hit] = result[hit] / count[hit].repeat(1, C)

    return result
160
+
161
def nearest_grid_put_3d(H, W, D, coords, values, return_count=False):
    """Scatter ``values`` onto an ``H x W x D`` grid at the nearest cell.

    Args:
        H, W, D: grid size.
        coords: [N, 3], float in [-1, 1]; columns map to the H, W and D axes.
        values: [N, C].
        return_count: when true, return the raw sums and hit counts.

    Returns:
        ``(result, count)`` if ``return_count``; otherwise ``result`` with
        every hit cell divided by its count ([H, W, D, C]).
    """
    C = values.shape[-1]

    extent = torch.tensor([H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device)
    indices = ((coords * 0.5 + 0.5) * extent).round().long()  # [N, 3]

    result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype)  # [H, W, D, C]
    count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype)  # [H, W, D, 1]
    weights = torch.ones_like(values[..., :1])  # [N, 1]

    result, count = scatter_add_nd_with_count(result, count, indices, values, weights)

    if return_count:
        return result, count

    hit = count.squeeze(-1) > 0
    result[hit] = result[hit] / count[hit].repeat(1, C)

    return result
185
+
186
+
187
def linear_grid_put_3d(H, W, D, coords, values, return_count=False):
    """Scatter ``values`` onto an ``H x W x D`` grid with trilinear weights.

    Each point contributes to the eight cells around it. The weight of a
    corner is built from the same per-axis offsets as its index, so the two
    always agree. Previously six of the eight corners were paired with the
    weight of a different corner (e.g. the +1-along-D corner received the
    +1-along-W weight), which deposited values into the wrong cells.

    Args:
        H, W, D: grid size.
        coords: [N, 3], float in [-1, 1]; columns map to the H, W and D axes.
        values: [N, C].
        return_count: when true, return the raw weighted sums and weights.

    Returns:
        ``(result, count)`` if ``return_count``; otherwise ``result`` with
        every touched cell divided by its accumulated weight ([H, W, D, C]).
    """
    C = values.shape[-1]

    extent = torch.tensor([H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device)
    indices = (coords * 0.5 + 0.5) * extent

    # Lower corner, clamped so the "+1" neighbours stay inside the grid.
    indices_000 = indices.floor().long()  # [N, 3]
    indices_000[:, 0].clamp_(0, H - 2)
    indices_000[:, 1].clamp_(0, W - 2)
    indices_000[:, 2].clamp_(0, D - 2)

    # Fractional position inside the cell along each axis.
    h = indices[..., 0] - indices_000[..., 0].float()
    w = indices[..., 1] - indices_000[..., 1].float()
    d = indices[..., 2] - indices_000[..., 2].float()

    result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype)  # [H, W, D, C]
    count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype)  # [H, W, D, 1]
    weights = torch.ones_like(values[..., :1])  # [N, 1]

    # Visit corners in the order 000, 001, 010, ..., 111 (offsets along H, W, D).
    for off_h in (0, 1):
        for off_w in (0, 1):
            for off_d in (0, 1):
                corner_w = (
                    (h if off_h else 1 - h)
                    * (w if off_w else 1 - w)
                    * (d if off_d else 1 - d)
                ).unsqueeze(1)
                corner_idx = indices_000 + torch.tensor(
                    [off_h, off_w, off_d], dtype=torch.long, device=indices.device
                )
                result, count = scatter_add_nd_with_count(
                    result, count, corner_idx, values * corner_w, weights * corner_w
                )

    if return_count:
        return result, count

    mask = (count.squeeze(-1) > 0)
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result
242
+
243
def mipmap_linear_grid_put_3d(H, W, D, coords, values, min_resolution=32, return_count=False):
    """Trilinear grid-put that fills empty cells from progressively coarser grids.

    Starting at ``H x W x D`` and halving every side each round while the
    smallest side exceeds ``min_resolution``, cells whose count is still zero
    receive the trilinearly upsampled result of ``linear_grid_put_3d`` at
    that level.

    Args:
        H, W, D: output grid size.
        coords: [N, 3], float in [-1, 1].
        values: [N, C].
        min_resolution: levels are processed only while min(side) > this.
        return_count: when true, return the raw sums and weights.

    Returns:
        ``(result, count)`` if ``return_count``; otherwise the normalised
        ``result`` ([H, W, D, C]).
    """
    C = values.shape[-1]

    result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype)  # [H, W, D, C]
    count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype)  # [H, W, D, 1]
    cur_H, cur_W, cur_D = H, W, D

    while min(min(cur_H, cur_W), cur_D) > min_resolution:
        # try to fill the holes
        holes = count.squeeze(-1) == 0
        if not holes.any():
            break

        level_result, level_count = linear_grid_put_3d(cur_H, cur_W, cur_D, coords, values, return_count=True)

        up_result = F.interpolate(
            level_result.permute(3, 0, 1, 2).unsqueeze(0).contiguous(),
            (H, W, D), mode='trilinear', align_corners=False,
        ).squeeze(0).permute(1, 2, 3, 0).contiguous()
        up_count = F.interpolate(
            level_count.view(1, 1, cur_H, cur_W, cur_D),
            (H, W, D), mode='trilinear', align_corners=False,
        ).view(H, W, D, 1)

        result[holes] = result[holes] + up_result[holes]
        count[holes] = count[holes] + up_count[holes]

        cur_H //= 2
        cur_W //= 2
        cur_D //= 2

    if return_count:
        return result, count

    hit = count.squeeze(-1) > 0
    result[hit] = result[hit] / count[hit].repeat(1, C)

    return result
274
+
275
+
276
def grid_put(shape, coords, values, mode='linear-mipmap', min_resolution=32, return_raw=False):
    """Dispatch to the 2D/3D grid-put implementation selected by ``mode``.

    Args:
        shape: [D], list/tuple giving the grid size (D must be 2 or 3).
        coords: [N, D], float in [-1, 1].
        values: [N, C].
        mode: ``'nearest'``, ``'linear'`` or ``'linear-mipmap'``.
        min_resolution: only used by ``'linear-mipmap'``.
        return_raw: passed on as the ``return_count`` argument.

    Raises:
        AssertionError: if ``len(shape)`` is not 2 or 3.
        NotImplementedError: if ``mode`` is not one of the supported names.
    """
    D = len(shape)
    assert D in [2, 3], f'only support D == 2 or 3, but got D == {D}'

    if mode not in ('nearest', 'linear', 'linear-mipmap'):
        raise NotImplementedError(f"got mode {mode}")

    is_2d = D == 2

    if mode == 'nearest':
        put = nearest_grid_put_2d if is_2d else nearest_grid_put_3d
        return put(*shape, coords, values, return_raw)

    if mode == 'linear':
        put = linear_grid_put_2d if is_2d else linear_grid_put_3d
        return put(*shape, coords, values, return_raw)

    # mode == 'linear-mipmap'
    put = mipmap_linear_grid_put_2d if is_2d else mipmap_linear_grid_put_3d
    return put(*shape, coords, values, min_resolution, return_raw)
gs_renderer.py ADDED
@@ -0,0 +1,820 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import numpy as np
4
+ from typing import NamedTuple
5
+ from plyfile import PlyData, PlyElement
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from diff_gaussian_rasterization import (
11
+ GaussianRasterizationSettings,
12
+ GaussianRasterizer,
13
+ )
14
+ from simple_knn._C import distCUDA2
15
+
16
+ from sh_utils import eval_sh, SH2RGB, RGB2SH
17
+ from mesh import Mesh
18
+ from mesh_utils import decimate_mesh, clean_mesh
19
+
20
+ import kiui
21
+
22
def inverse_sigmoid(x):
    """Return the logit ``log(x / (1 - x))`` of tensor ``x`` element-wise."""
    odds = x / (1 - x)
    return torch.log(odds)
24
+
25
def get_expon_lr_func(
    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
    """Build a step -> learning-rate schedule.

    The returned callable interpolates log-linearly from ``lr_init`` (step 0)
    to ``lr_final`` (step ``max_steps``). When ``lr_delay_steps`` is positive
    the rate is additionally scaled by a factor that rises from
    ``lr_delay_mult`` to 1 along a sine quarter-wave over the first
    ``lr_delay_steps`` steps.

    Args:
        lr_init: learning rate at step 0.
        lr_final: learning rate at ``max_steps`` and beyond.
        lr_delay_steps: length of the warm-up ramp; 0 disables it.
        lr_delay_mult: warm-up scale applied at step 0.
        max_steps: step at which ``lr_final`` is reached.

    Returns:
        A function mapping a step index to a learning rate.
    """

    def helper(step):
        # Constant schedule: every other parameter is irrelevant.
        if lr_init == lr_final:
            return lr_init

        # A negative step (or an all-zero schedule) disables the parameter.
        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
            return 0.0

        delay_rate = 1.0
        if lr_delay_steps > 0:
            # A kind of reverse cosine decay.
            warmup = np.clip(step / lr_delay_steps, 0, 1)
            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
                0.5 * np.pi * warmup
            )

        t = np.clip(step / max_steps, 0, 1)
        log_lr = np.log(lr_init) * (1 - t) + np.log(lr_final) * t
        return delay_rate * np.exp(log_lr)

    return helper
48
+
49
+
50
def strip_lowerdiag(L):
    """Pack the six unique entries of a batch of symmetric 3x3 matrices.

    Reads the entries with row <= column, in the order
    (0,0), (0,1), (0,2), (1,1), (1,2), (2,2).

    Args:
        L: tensor of shape [N, 3, 3].

    Returns:
        float32 tensor of shape [N, 6], allocated on the same device as
        ``L`` (previously hard-coded to "cuda", which failed on machines
        without a GPU).
    """
    rows = [0, 0, 0, 1, 1, 2]
    cols = [0, 1, 2, 1, 2, 2]
    # Advanced indexing gathers all six entries at once and always yields a
    # fresh tensor, so the caller's matrix is never aliased.
    return L[:, rows, cols].to(dtype=torch.float)
60
+
61
def strip_symmetric(sym):
    """Return the packed [N, 6] form of the symmetric matrices ``sym``.

    Thin alias for ``strip_lowerdiag``.
    """
    packed = strip_lowerdiag(sym)
    return packed
63
+
64
def gaussian_3d_coeff(xyzs, covs):
    """Evaluate unnormalised 3D Gaussian weights ``exp(-0.5 * p^T S^-1 p)``.

    Args:
        xyzs: [N, 3] offsets ``p`` at which to evaluate.
        covs: [N, 6] packed symmetric covariances ``S`` in the order
            (0,0), (0,1), (0,2), (1,1), (1,2), (2,2).

    Returns:
        [N] tensor of weights; entries whose exponent came out positive
        are forced to 0.
    """
    x, y, z = (xyzs[:, i] for i in range(3))
    a, b, c, d, e, f = (covs[:, i] for i in range(6))

    # Closed-form inverse of the symmetric 3x3 matrix via cofactors.
    # eps must be small enough !!!
    det = a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24
    inv_det = 1 / det
    inv_a = (d * f - e**2) * inv_det
    inv_b = (e * c - b * f) * inv_det
    inv_c = (e * b - c * d) * inv_det
    inv_d = (a * f - c**2) * inv_det
    inv_e = (b * c - e * a) * inv_det
    inv_f = (a * d - b**2) * inv_det

    power = -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f) - x * y * inv_b - x * z * inv_c - y * z * inv_e

    # abnormal values... make weights 0
    power = power.masked_fill(power > 0, -1e10)

    return torch.exp(power)
84
+
85
def build_rotation(r):
    """Convert a batch of quaternions to 3x3 rotation matrices.

    Each quaternion is normalised first; its first component is used as the
    scalar part and the remaining three as the vector part.

    Args:
        r: tensor of shape [N, 4].

    Returns:
        tensor of shape [N, 3, 3] in the default dtype, allocated on the
        same device as ``r`` (previously hard-coded to 'cuda', which failed
        on machines without a GPU).
    """
    norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
    q = r / norm[:, None]

    # Follow the input's device instead of assuming a GPU is present.
    R = torch.zeros((q.size(0), 3, 3), device=q.device)

    w, x, y, z = (q[:, i] for i in range(4))

    R[:, 0, 0] = 1 - 2 * (y*y + z*z)
    R[:, 0, 1] = 2 * (x*y - w*z)
    R[:, 0, 2] = 2 * (x*z + w*y)
    R[:, 1, 0] = 2 * (x*y + w*z)
    R[:, 1, 1] = 1 - 2 * (x*x + z*z)
    R[:, 1, 2] = 2 * (y*z - w*x)
    R[:, 2, 0] = 2 * (x*z - w*y)
    R[:, 2, 1] = 2 * (y*z + w*x)
    R[:, 2, 2] = 1 - 2 * (x*x + y*y)
    return R
107
+
108
def build_scaling_rotation(s, r):
    """Return ``R @ diag(s)`` for a batch of scales and quaternions.

    Args:
        s: tensor of shape [N, 3] with per-axis scales.
        r: tensor of shape [N, 4] with quaternions, converted to rotation
            matrices by ``build_rotation``.

    Returns:
        tensor of shape [N, 3, 3].
    """
    R = build_rotation(r)

    # Allocate the scale matrix next to the rotation it is multiplied with,
    # rather than on a hard-coded "cuda" device.
    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=R.device)
    diag = [0, 1, 2]
    L[:, diag, diag] = s[:, :3]

    return R @ L
118
+
119
class BasicPointCloud(NamedTuple):
    """Immutable bundle of the per-point arrays describing a point cloud."""

    # ``np.ndarray`` is the array type; ``np.array`` is only the factory
    # function and is not a valid annotation.
    points: np.ndarray   # point positions
    colors: np.ndarray   # per-point colors
    normals: np.ndarray  # per-point normals
123
+
124
+
125
+ class GaussianModel:
126