kairunwen committed on
Commit
35e2073
1 Parent(s): 57e5e56
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +23 -0
  2. README.md +1 -1
  3. app.py +277 -0
  4. arguments/__init__.py +113 -0
  5. assets/example/TT-family-3-views/000301.jpg +0 -0
  6. assets/example/TT-family-3-views/000388.jpg +0 -0
  7. assets/example/TT-family-3-views/000491.jpg +0 -0
  8. assets/example/dl3dv-ba55-3-views/frame_60.jpg +0 -0
  9. assets/example/dl3dv-ba55-3-views/frame_61.jpg +0 -0
  10. assets/example/dl3dv-ba55-3-views/frame_62.jpg +0 -0
  11. assets/example/sora-santorini-3-views/frame_00.jpg +0 -0
  12. assets/example/sora-santorini-3-views/frame_06.jpg +0 -0
  13. assets/example/sora-santorini-3-views/frame_12.jpg +0 -0
  14. assets/load/.gitkeep +0 -0
  15. coarse_init_infer.py +100 -0
  16. gaussian_renderer/__init__.py +144 -0
  17. gaussian_renderer/__init__3dgs.py +100 -0
  18. gaussian_renderer/network_gui.py +86 -0
  19. lpipsPyTorch/__init__.py +21 -0
  20. lpipsPyTorch/modules/lpips.py +36 -0
  21. lpipsPyTorch/modules/networks.py +96 -0
  22. lpipsPyTorch/modules/utils.py +30 -0
  23. render_by_interp.py +152 -0
  24. requirements.txt +17 -0
  25. scene/__init__.py +96 -0
  26. scene/cameras.py +71 -0
  27. scene/colmap_loader.py +294 -0
  28. scene/dataset_readers.py +363 -0
  29. scene/gaussian_model.py +502 -0
  30. submodules/diff-gaussian-rasterization/.gitignore +3 -0
  31. submodules/diff-gaussian-rasterization/.gitmodules +3 -0
  32. submodules/diff-gaussian-rasterization/CMakeLists.txt +36 -0
  33. submodules/diff-gaussian-rasterization/LICENSE.md +83 -0
  34. submodules/diff-gaussian-rasterization/README.md +19 -0
  35. submodules/diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h +175 -0
  36. submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.cu +657 -0
  37. submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.h +65 -0
  38. submodules/diff-gaussian-rasterization/cuda_rasterizer/config.h +19 -0
  39. submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.cu +455 -0
  40. submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.h +66 -0
  41. submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h +88 -0
  42. submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu +434 -0
  43. submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h +74 -0
  44. submodules/diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py +221 -0
  45. submodules/diff-gaussian-rasterization/ext.cpp +19 -0
  46. submodules/diff-gaussian-rasterization/rasterize_points.cu +217 -0
  47. submodules/diff-gaussian-rasterization/rasterize_points.h +67 -0
  48. submodules/diff-gaussian-rasterization/setup.py +34 -0
  49. submodules/diff-gaussian-rasterization/third_party/glm/.appveyor.yml +92 -0
  50. submodules/diff-gaussian-rasterization/third_party/glm/.gitignore +61 -0
.gitignore ADDED
@@ -0,0 +1,23 @@
+ /.idea/
+ /work_dirs*
+ .vscode/
+ /tmp
+ /data
+ /checkpoints
+ *.so
+ *.patch
+ __pycache__/
+ *.egg-info/
+ /viz*
+ /submit*
+ build/
+ *.pyd
+ /cache*
+ *.stl
+ *.pth
+ /venv/
+ .nk8s
+ *.mp4
+ .vs
+ /exp/
+ /dev/
README.md CHANGED
@@ -6,7 +6,7 @@ colorTo: green
  sdk: gradio
  sdk_version: 4.20.1
  python_version: 3.10.13
- app_file: app_wrapper.py
+ app_file: app.py
  pinned: false
  license: mit
  short_description: Sparse-view SFM-free Gaussian Splatting in Seconds
app.py ADDED
@@ -0,0 +1,277 @@
+ import os, subprocess, shlex, sys, gc
+ import time
+ import torch
+ import numpy as np
+ import shutil
+ import argparse
+ import gradio as gr
+ import uuid
+ import spaces
+
+ subprocess.run(shlex.split("pip install wheel/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"))
+ subprocess.run(shlex.split("pip install wheel/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl"))
+ subprocess.run(shlex.split("pip install wheel/curope-0.0.0-cp310-cp310-linux_x86_64.whl"))
+
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+ os.sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "dust3r")))
+ # os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+ from dust3r.inference import inference
+ from dust3r.model import AsymmetricCroCo3DStereo
+ from dust3r.utils.device import to_numpy
+ from dust3r.image_pairs import make_pairs
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
+ from utils.dust3r_utils import compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images
+
+ from argparse import ArgumentParser, Namespace
+ from arguments import ModelParams, PipelineParams, OptimizationParams
+ from train_joint import training
+ from render_by_interp import render_sets
+ GRADIO_CACHE_FOLDER = './gradio_cache_folder'
+ ############################################################
+
+
+ def get_dust3r_args_parser():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
+     parser.add_argument("--model_path", type=str, default="submodules/dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", help="path to the model weights")
+     parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
+     parser.add_argument("--batch_size", type=int, default=1)
+     parser.add_argument("--schedule", type=str, default='linear')
+     parser.add_argument("--lr", type=float, default=0.01)
+     parser.add_argument("--niter", type=int, default=300)
+     parser.add_argument("--focal_avg", type=bool, default=True)
+     parser.add_argument("--n_views", type=int, default=3)
+     parser.add_argument("--base_path", type=str, default=GRADIO_CACHE_FOLDER)
+     return parser
+
+
+ @spaces.GPU(duration=300)
+ def process(inputfiles, input_path=None):
+
+     if input_path is not None:
+         imgs_path = './assets/example/' + input_path
+         imgs_names = sorted(os.listdir(imgs_path))
+
+         inputfiles = []
+         for imgs_name in imgs_names:
+             file_path = os.path.join(imgs_path, imgs_name)
+             print(file_path)
+             inputfiles.append(file_path)
+         print(inputfiles)
+
+     # ------ (1) Coarse Geometric Initialization ------
+     # os.system(f"rm -rf {GRADIO_CACHE_FOLDER}")
+     parser = get_dust3r_args_parser()
+     opt = parser.parse_args()
+
+     tmp_user_folder = str(uuid.uuid4()).replace("-", "")
+     opt.img_base_path = os.path.join(opt.base_path, tmp_user_folder)
+     img_folder_path = os.path.join(opt.img_base_path, "images")
+
+     model = AsymmetricCroCo3DStereo.from_pretrained(opt.model_path).to(opt.device)
+     os.makedirs(img_folder_path, exist_ok=True)
+
+     opt.n_views = len(inputfiles)
+     if opt.n_views == 1:
+         raise gr.Error("The number of input images should be greater than 1.")
+     print("Multiple images: ", inputfiles)
+     for image_path in inputfiles:
+         if input_path is not None:
+             shutil.copy(image_path, img_folder_path)
+         else:
+             shutil.move(image_path, img_folder_path)
+     train_img_list = sorted(os.listdir(img_folder_path))
+     assert len(train_img_list) == opt.n_views, f"Number of images in the folder is not equal to {opt.n_views}"
+     images, ori_size, imgs_resolution = load_images(img_folder_path, size=512)
+     resolutions_are_equal = len(set(imgs_resolution)) == 1
+     if not resolutions_are_equal:
+         raise gr.Error("All input images must have the same resolution.")
+     print("ori_size", ori_size)
+     start_time = time.time()
+     ######################################################
+     pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
+     output = inference(pairs, model, opt.device, batch_size=opt.batch_size)
+     output_colmap_path = img_folder_path.replace("images", "sparse/0")
+     os.makedirs(output_colmap_path, exist_ok=True)
+
+     scene = global_aligner(output, device=opt.device, mode=GlobalAlignerMode.PointCloudOptimizer)
+     loss = compute_global_alignment(scene=scene, init="mst", niter=opt.niter, schedule=opt.schedule, lr=opt.lr, focal_avg=opt.focal_avg)
+     scene = scene.clean_pointcloud()
+
+     imgs = to_numpy(scene.imgs)
+     focals = scene.get_focals()
+     poses = to_numpy(scene.get_im_poses())
+     pts3d = to_numpy(scene.get_pts3d())
+     scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
+     confidence_masks = to_numpy(scene.get_masks())
+     intrinsics = to_numpy(scene.get_intrinsics())
+     ######################################################
+     end_time = time.time()
+     print(f"Time taken for {opt.n_views} views: {end_time - start_time} seconds")
+     save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
+     save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)
+     pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
+     color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
+     color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
+     storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs)
+     pts_4_3dgs_all = np.array(pts3d).reshape(-1, 3)
+     np.save(output_colmap_path + "/pts_4_3dgs_all.npy", pts_4_3dgs_all)
+     np.save(output_colmap_path + "/focal.npy", np.array(focals.cpu()))
+
+     ### save VRAM
+     del scene
+     torch.cuda.empty_cache()
+     gc.collect()
+     ############################################################
+
+     # ------ (2) Fast 3D-Gaussian Optimization ------
+     parser = ArgumentParser(description="Training script parameters")
+     lp = ModelParams(parser)
+     op = OptimizationParams(parser)
+     pp = PipelineParams(parser)
+     parser.add_argument('--debug_from', type=int, default=-1)
+     parser.add_argument("--test_iterations", nargs="+", type=int, default=[])
+     parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
+     parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
+     parser.add_argument("--start_checkpoint", type=str, default=None)
+     parser.add_argument("--scene", type=str, default="demo")
+     parser.add_argument("--n_views", type=int, default=3)
+     parser.add_argument("--get_video", action="store_true")
+     parser.add_argument("--optim_pose", type=bool, default=True)
+     parser.add_argument("--skip_train", action="store_true")
+     parser.add_argument("--skip_test", action="store_true")
+     args = parser.parse_args(sys.argv[1:])
+     args.save_iterations.append(args.iterations)
+     args.model_path = opt.img_base_path + '/output/'
+     args.source_path = opt.img_base_path
+     # args.model_path = GRADIO_CACHE_FOLDER + '/output/'
+     # args.source_path = GRADIO_CACHE_FOLDER
+     args.iteration = 1000
+     os.makedirs(args.model_path, exist_ok=True)
+     training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args)
+     ############################################################
+
+     # ------ (3) Render video by interpolation ------
+     parser = ArgumentParser(description="Testing script parameters")
+     model = ModelParams(parser, sentinel=True)
+     pipeline = PipelineParams(parser)
+     args.eval = True
+     args.get_video = True
+     args.n_views = opt.n_views
+     render_sets(
+         model.extract(args),
+         args.iteration,
+         pipeline.extract(args),
+         args.skip_train,
+         args.skip_test,
+         args,
+     )
+     output_ply_path = opt.img_base_path + f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
+     output_video_path = opt.img_base_path + f'/output/demo_{opt.n_views}_view.mp4'
+     # output_ply_path = GRADIO_CACHE_FOLDER + f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
+     # output_video_path = GRADIO_CACHE_FOLDER + f'/output/demo_{opt.n_views}_view.mp4'
+
+     return output_video_path, output_ply_path, output_ply_path
+ ############################################################
+
+
+ _TITLE = '''InstantSplat'''
+ _DESCRIPTION = '''
+ <div style="display: flex; justify-content: center; align-items: center;">
+     <div style="width: 100%; text-align: center; font-size: 30px;">
+         <strong>InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds</strong>
+     </div>
+ </div>
+ <p></p>
+
+ <div align="center">
+     <a style="display:inline-block" href="https://instantsplat.github.io/"><img src='https://img.shields.io/badge/Project_Page-1c7d45?logo=gumtree'></a>&nbsp;
+     <a style="display:inline-block" href="https://www.youtube.com/watch?v=fxf_ypd7eD8"><img src='https://img.shields.io/badge/Demo_Video-E33122?logo=Youtube'></a>&nbsp;
+     <a style="display:inline-block" href="https://arxiv.org/abs/2403.20309"><img src="https://img.shields.io/badge/ArXiv-2403.20309-b31b1b?logo=arxiv" alt='arxiv'></a>
+ </div>
+ <p></p>
+
+ * Official demo of: [InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds](https://instantsplat.github.io/).
+ * Sparse-view examples for direct viewing: simply click an example at the bottom of the page to quickly view results on representative data.
+ * Training slows down when the resolution or the number of images is large. To match the reported performance, please run on your own GPU (A100/4090).
+ '''
+
+
+ # <a style="display:inline-block" href="https://github.com/VITA-Group/LightGaussian"><img src="https://img.shields.io/badge/Source_Code-black?logo=Github" alt='Github Source Code'></a>&nbsp;
+ # &nbsp;
+ # <a style="display:inline-block" href="https://www.nvidia.com/en-us/"><img src="https://img.shields.io/badge/Nvidia-575757?logo=nvidia" alt='Nvidia'></a>
+ # * If InstantSplat is helpful, please give us a star ⭐ on GitHub. Thanks! <a style="display:inline-block; margin-left: .5em" href="https://github.com/VITA-Group/LightGaussian"><img src='https://img.shields.io/github/stars/VITA-Group/LightGaussian?style=social'/></a>
+
+
+ # block = gr.Blocks(title=_TITLE).queue()
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Row():
+         with gr.Column(scale=1):
+             # gr.Markdown('# ' + _TITLE)
+             gr.Markdown(_DESCRIPTION)
+
+     with gr.Row(variant='panel'):
+         with gr.Tab("Input"):
+             inputfiles = gr.File(file_count="multiple", label="images")
+             input_path = gr.Textbox(visible=False, label="example_path")
+             button_gen = gr.Button("RUN")
+
+     with gr.Row(variant='panel'):
+         with gr.Tab("Output"):
+             with gr.Column(scale=2):
+                 output_model = gr.Model3D(
+                     label="3D Model (Gaussian)",
+                     # height=300,
+                     interactive=False,
+                     # clear_color=[1.0, 1.0, 1.0, 1.0]
+                 )
+                 output_file = gr.File(label="ply")
+             with gr.Column(scale=1):
+                 output_video = gr.Video(label="video")
+
+     button_gen.click(process, inputs=[inputfiles], outputs=[output_video, output_file, output_model])
+
+     # gr.Examples(
+     #     examples=[
+     #         "sora-santorini-3-views",
+     #         # "TT-family-3-views",
+     #         # "dl3dv-ba55-3-views",
+     #     ],
+     #     inputs=[input_path],
+     #     outputs=[output_video, output_file, output_model],
+     #     fn=lambda x: process(inputfiles=None, input_path=x),
+     #     cache_examples=True,
+     #     label='Sparse-view Examples'
+     # )
+ block.launch(server_name="0.0.0.0", share=False)
+
+
+ # Alternative single-page layout, kept for reference:
+ # block = gr.Blocks(title=_TITLE).queue()
+ # with block:
+ #     with gr.Row():
+ #         with gr.Column(scale=1):
+ #             gr.Markdown('# ' + _TITLE)
+ #             # gr.Markdown(_DESCRIPTION)
+
+ #     with gr.Row(variant='panel'):
+ #         with gr.Column(scale=1):
+ #             with gr.Tab("Input"):
+ #                 inputfiles = gr.File(file_count="multiple", label="images")
+ #                 button_gen = gr.Button("RUN")
+
+ #         with gr.Column(scale=2):
+ #             with gr.Tab("Output"):
+ #                 output_video = gr.Video(label="video")
+ #                 output_model = gr.Model3D(
+ #                     label="3D Model (Gaussian)",
+ #                     height=300,
+ #                     interactive=False,
+ #                 )
+ #                 output_file = gr.File(label="ply")
+
+ #     button_gen.click(process, inputs=[inputfiles], outputs=[output_video, output_file, output_model])
+
+ # block.launch(server_name="0.0.0.0", share=False)
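Note: the callback above chains three stages in one run — DUSt3R coarse geometric initialization, fast 3D-Gaussian optimization (`training`), and interpolated-trajectory rendering (`render_sets`). A minimal local invocation sketch, assuming the DUSt3R checkpoint and the prebuilt CUDA wheels are already installed:

```python
# Hypothetical local run of the process() callback defined in app.py above;
# the example folder name is one of the bundled assets under ./assets/example/.
video_path, ply_path, _ = process(
    inputfiles=None,
    input_path="sora-santorini-3-views",
)
print("video:", video_path)
print("gaussians:", ply_path)
```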
arguments/__init__.py ADDED
@@ -0,0 +1,113 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact george.drettakis@inria.fr
+ #
+
+ from argparse import ArgumentParser, Namespace
+ import sys
+ import os
+
+ class GroupParams:
+     pass
+
+ class ParamGroup:
+     def __init__(self, parser: ArgumentParser, name: str, fill_none=False):
+         group = parser.add_argument_group(name)
+         for key, value in vars(self).items():
+             shorthand = False
+             if key.startswith("_"):
+                 shorthand = True
+                 key = key[1:]
+             t = type(value)
+             value = value if not fill_none else None
+             if shorthand:
+                 if t == bool:
+                     group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
+                 else:
+                     group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
+             else:
+                 if t == bool:
+                     group.add_argument("--" + key, default=value, action="store_true")
+                 else:
+                     group.add_argument("--" + key, default=value, type=t)
+
+     def extract(self, args):
+         group = GroupParams()
+         for arg in vars(args).items():
+             if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
+                 setattr(group, arg[0], arg[1])
+         return group
+
+ class ModelParams(ParamGroup):
+     def __init__(self, parser, sentinel=False):
+         self.sh_degree = 3
+         self._source_path = ""
+         self._model_path = ""
+         self._images = "images"
+         self._resolution = -1
+         self._white_background = False
+         self.data_device = "cuda"
+         self.eval = False
+         super().__init__(parser, "Loading Parameters", sentinel)
+
+     def extract(self, args):
+         g = super().extract(args)
+         g.source_path = os.path.abspath(g.source_path)
+         return g
+
+ class PipelineParams(ParamGroup):
+     def __init__(self, parser):
+         self.convert_SHs_python = False
+         self.compute_cov3D_python = False
+         self.debug = False
+         super().__init__(parser, "Pipeline Parameters")
+
+ class OptimizationParams(ParamGroup):
+     def __init__(self, parser):
+         self.iterations = 1000
+         # self.iterations = 30_000
+         self.position_lr_init = 0.00016
+         self.position_lr_final = 0.0000016
+         self.position_lr_delay_mult = 0.01
+         self.position_lr_max_steps = 30_000
+         self.feature_lr = 0.0025
+         self.opacity_lr = 0.05
+         self.scaling_lr = 0.005
+         self.rotation_lr = 0.001
+         self.percent_dense = 0.01
+         self.lambda_dssim = 0.2
+         self.densification_interval = 100
+         self.opacity_reset_interval = 3000
+         self.densify_from_iter = 500
+         self.densify_until_iter = 15_000
+         self.densify_grad_threshold = 0.0002
+         self.random_background = False
+         super().__init__(parser, "Optimization Parameters")
+
+ def get_combined_args(parser: ArgumentParser):
+     cmdlne_string = sys.argv[1:]
+     cfgfile_string = "Namespace()"
+     args_cmdline = parser.parse_args(cmdlne_string)
+
+     try:
+         cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
+         print("Looking for config file in", cfgfilepath)
+         with open(cfgfilepath) as cfg_file:
+             print("Config file found: {}".format(cfgfilepath))
+             cfgfile_string = cfg_file.read()
+     except TypeError:
+         print("Config file not found")
+     args_cfgfile = eval(cfgfile_string)
+
+     merged_dict = vars(args_cfgfile).copy()
+     for k, v in vars(args_cmdline).items():
+         if v is not None:
+             merged_dict[k] = v
+     return Namespace(**merged_dict)
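Note: `ParamGroup` builds CLI flags by reflection over its instance attributes; a leading underscore also registers a one-letter shorthand, and `bool` attributes become `store_true` flags. A small usage sketch under those rules:

```python
# Sketch of the reflection above: --source_path gets the shorthand -s because
# ModelParams defines _source_path; extract() copies matching fields back out.
from argparse import ArgumentParser
from arguments import ModelParams

parser = ArgumentParser()
lp = ModelParams(parser)                     # registers --source_path/-s, --model_path/-m, ...
args = parser.parse_args(["-s", "./scene"])
dataset = lp.extract(args)                   # plain GroupParams object
print(dataset.source_path)                   # absolute path, resolved in ModelParams.extract
```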
assets/example/TT-family-3-views/000301.jpg ADDED
assets/example/TT-family-3-views/000388.jpg ADDED
assets/example/TT-family-3-views/000491.jpg ADDED
assets/example/dl3dv-ba55-3-views/frame_60.jpg ADDED
assets/example/dl3dv-ba55-3-views/frame_61.jpg ADDED
assets/example/dl3dv-ba55-3-views/frame_62.jpg ADDED
assets/example/sora-santorini-3-views/frame_00.jpg ADDED
assets/example/sora-santorini-3-views/frame_06.jpg ADDED
assets/example/sora-santorini-3-views/frame_12.jpg ADDED
assets/load/.gitkeep ADDED
File without changes
coarse_init_infer.py ADDED
@@ -0,0 +1,100 @@
+ import os
+ import shutil
+ import torch
+ import numpy as np
+ import argparse
+ import time
+
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+ os.sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "dust3r")))
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+
+ from dust3r.inference import inference
+ from dust3r.model import AsymmetricCroCo3DStereo
+ from dust3r.utils.device import to_numpy
+ from dust3r.image_pairs import make_pairs
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
+ from utils.dust3r_utils import compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images
+
+ def get_args_parser():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
+     # parser.add_argument("--model_path", type=str, default="./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", help="path to the model weights")
+     parser.add_argument("--model_path", type=str, default="submodules/dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", help="path to the model weights")
+     parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
+     parser.add_argument("--batch_size", type=int, default=1)
+     parser.add_argument("--schedule", type=str, default='linear')
+     parser.add_argument("--lr", type=float, default=0.01)
+     parser.add_argument("--niter", type=int, default=300)
+     parser.add_argument("--focal_avg", action="store_true")
+     # parser.add_argument("--focal_avg", type=bool, default=True)
+
+     parser.add_argument("--llffhold", type=int, default=2)
+     parser.add_argument("--n_views", type=int, default=12)
+     parser.add_argument("--img_base_path", type=str, default="/home/workspace/datasets/instantsplat/Tanks/Barn/24_views")
+
+     return parser
+
+ if __name__ == '__main__':
+
+     parser = get_args_parser()
+     args = parser.parse_args()
+
+     model_path = args.model_path
+     device = args.device
+     batch_size = args.batch_size
+     schedule = args.schedule
+     lr = args.lr
+     niter = args.niter
+     n_views = args.n_views
+     img_base_path = args.img_base_path
+     img_folder_path = os.path.join(img_base_path, "images")
+     os.makedirs(img_folder_path, exist_ok=True)
+     model = AsymmetricCroCo3DStereo.from_pretrained(model_path).to(device)
+     ############################################################
+
+     train_img_list = sorted(os.listdir(img_folder_path))
+     assert len(train_img_list) == n_views, f"Number of images ({len(train_img_list)}) in the folder ({img_folder_path}) is not equal to {n_views}"
+
+     # if len(os.listdir(img_folder_path)) != len(train_img_list):
+     #     for img_name in train_img_list:
+     #         src_path = os.path.join(img_base_path, "images", img_name)
+     #         tgt_path = os.path.join(img_folder_path, img_name)
+     #         print(src_path, tgt_path)
+     #         shutil.copy(src_path, tgt_path)
+     images, ori_size = load_images(img_folder_path, size=512)
+     print("ori_size", ori_size)
+
+     start_time = time.time()
+     ############################################################
+     pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
+     output = inference(pairs, model, args.device, batch_size=batch_size)
+     output_colmap_path = img_folder_path.replace("images", "sparse/0")
+     os.makedirs(output_colmap_path, exist_ok=True)
+
+     scene = global_aligner(output, device=args.device, mode=GlobalAlignerMode.PointCloudOptimizer)
+     loss = compute_global_alignment(scene=scene, init="mst", niter=niter, schedule=schedule, lr=lr, focal_avg=args.focal_avg)
+     scene = scene.clean_pointcloud()
+
+     imgs = to_numpy(scene.imgs)
+     focals = scene.get_focals()
+     poses = to_numpy(scene.get_im_poses())
+     pts3d = to_numpy(scene.get_pts3d())
+     scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
+     confidence_masks = to_numpy(scene.get_masks())
+     intrinsics = to_numpy(scene.get_intrinsics())
+     ############################################################
+     end_time = time.time()
+     print(f"Time taken for {n_views} views: {end_time - start_time} seconds")
+
+     # save
+     save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
+     save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)
+
+     pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
+     color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
+     color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
+     storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs)
+     pts_4_3dgs_all = np.array(pts3d).reshape(-1, 3)
+     np.save(output_colmap_path + "/pts_4_3dgs_all.npy", pts_4_3dgs_all)
+     np.save(output_colmap_path + "/focal.npy", np.array(focals.cpu()))
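Note: the point-cloud export filters each per-view DUSt3R point map by its confidence mask before merging, which is what the `np.concatenate`/`zip` one-liners above do. A shape-level sketch with dummy data:

```python
# Illustrative shapes only: per-view (H, W, 3) point maps and (H, W) boolean
# masks; p[m] keeps confident points, giving a flat (N, 3) initialization cloud.
import numpy as np

pts3d = [np.random.rand(4, 4, 3) for _ in range(3)]
masks = [np.random.rand(4, 4) > 0.5 for _ in range(3)]
pts = np.concatenate([p[m] for p, m in zip(pts3d, masks)])
print(pts.shape)  # (N, 3)
```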
gaussian_renderer/__init__.py ADDED
@@ -0,0 +1,144 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact george.drettakis@inria.fr
+ #
+
+ import torch
+ import math
+ from diff_gaussian_rasterization import (
+     GaussianRasterizationSettings,
+     GaussianRasterizer,
+ )
+ from scene.gaussian_model import GaussianModel
+ from utils.sh_utils import eval_sh
+ from utils.pose_utils import get_camera_from_tensor, quadmultiply
+
+
+ def render(
+     viewpoint_camera,
+     pc: GaussianModel,
+     pipe,
+     bg_color: torch.Tensor,
+     scaling_modifier=1.0,
+     override_color=None,
+     camera_pose=None,
+ ):
+     """
+     Render the scene.
+
+     Background tensor (bg_color) must be on GPU!
+     """
+
+     # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+     screenspace_points = (
+         torch.zeros_like(
+             pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
+         )
+         + 0
+     )
+     try:
+         screenspace_points.retain_grad()
+     except:
+         pass
+
+     # Set up rasterization configuration
+     tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+     tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+     # Set camera pose as identity. Then, we will transform the Gaussians around camera_pose
+     w2c = torch.eye(4).cuda()
+     projmatrix = (
+         w2c.unsqueeze(0).bmm(viewpoint_camera.projection_matrix.unsqueeze(0))
+     ).squeeze(0)
+     camera_pos = w2c.inverse()[3, :3]
+     raster_settings = GaussianRasterizationSettings(
+         image_height=int(viewpoint_camera.image_height),
+         image_width=int(viewpoint_camera.image_width),
+         tanfovx=tanfovx,
+         tanfovy=tanfovy,
+         bg=bg_color,
+         scale_modifier=scaling_modifier,
+         # viewmatrix=viewpoint_camera.world_view_transform,
+         # projmatrix=viewpoint_camera.full_proj_transform,
+         viewmatrix=w2c,
+         projmatrix=projmatrix,
+         sh_degree=pc.active_sh_degree,
+         # campos=viewpoint_camera.camera_center,
+         campos=camera_pos,
+         prefiltered=False,
+         debug=pipe.debug,
+     )
+
+     rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+     # means3D = pc.get_xyz
+     rel_w2c = get_camera_from_tensor(camera_pose)
+     # Transform mean and rot of Gaussians to camera frame
+     gaussians_xyz = pc._xyz.clone()
+     gaussians_rot = pc._rotation.clone()
+
+     xyz_ones = torch.ones(gaussians_xyz.shape[0], 1).cuda().float()
+     xyz_homo = torch.cat((gaussians_xyz, xyz_ones), dim=1)
+     gaussians_xyz_trans = (rel_w2c @ xyz_homo.T).T[:, :3]
+     gaussians_rot_trans = quadmultiply(camera_pose[:4], gaussians_rot)
+     means3D = gaussians_xyz_trans
+     means2D = screenspace_points
+     opacity = pc.get_opacity
+
+     # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+     # scaling / rotation by the rasterizer.
+     scales = None
+     rotations = None
+     cov3D_precomp = None
+     if pipe.compute_cov3D_python:
+         cov3D_precomp = pc.get_covariance(scaling_modifier)
+     else:
+         scales = pc.get_scaling
+         rotations = gaussians_rot_trans  # pc.get_rotation
+
+     # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+     # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
+     shs = None
+     colors_precomp = None
+     if override_color is None:
+         if pipe.convert_SHs_python:
+             shs_view = pc.get_features.transpose(1, 2).view(
+                 -1, 3, (pc.max_sh_degree + 1) ** 2
+             )
+             dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(
+                 pc.get_features.shape[0], 1
+             )
+             dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
+             sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
+             colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
+         else:
+             shs = pc.get_features
+     else:
+         colors_precomp = override_color
+
+     # Rasterize visible Gaussians to image, obtain their radii (on screen).
+     rendered_image, radii = rasterizer(
+         means3D=means3D,
+         means2D=means2D,
+         shs=shs,
+         colors_precomp=colors_precomp,
+         opacities=opacity,
+         scales=scales,
+         rotations=rotations,
+         cov3D_precomp=cov3D_precomp,
+     )
+
+     # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
+     # They will be excluded from value updates used in the splitting criteria.
+     return {
+         "render": rendered_image,
+         "viewspace_points": screenspace_points,
+         "visibility_filter": radii > 0,
+         "radii": radii,
+     }
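Note: unlike the stock 3DGS renderer, this variant keeps the rasterizer's view matrix at identity and instead moves the Gaussians themselves into the camera frame, so gradients can flow into `camera_pose` during joint pose optimization. A minimal sketch of that homogeneous point transform:

```python
# Sketch of the world-to-camera transform used above; rel_w2c is assumed to be
# a 4x4 matrix like the one returned by get_camera_from_tensor, xyz is (N, 3).
import torch

def transform_points(rel_w2c: torch.Tensor, xyz: torch.Tensor) -> torch.Tensor:
    ones = torch.ones(xyz.shape[0], 1, dtype=xyz.dtype, device=xyz.device)
    xyz_homo = torch.cat((xyz, ones), dim=1)   # (N, 4) homogeneous coordinates
    return (rel_w2c @ xyz_homo.T).T[:, :3]     # (N, 3) points in the camera frame

pts = torch.rand(5, 3)
print(torch.allclose(transform_points(torch.eye(4), pts), pts))  # identity pose: True
```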
gaussian_renderer/__init__3dgs.py ADDED
@@ -0,0 +1,100 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact george.drettakis@inria.fr
+ #
+
+ import torch
+ import math
+ from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
+ from scene.gaussian_model import GaussianModel
+ from utils.sh_utils import eval_sh
+
+ def render(viewpoint_camera, pc: GaussianModel, pipe, bg_color: torch.Tensor, scaling_modifier=1.0, override_color=None):
+     """
+     Render the scene.
+
+     Background tensor (bg_color) must be on GPU!
+     """
+
+     # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+     screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
+     try:
+         screenspace_points.retain_grad()
+     except:
+         pass
+
+     # Set up rasterization configuration
+     tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+     tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+     raster_settings = GaussianRasterizationSettings(
+         image_height=int(viewpoint_camera.image_height),
+         image_width=int(viewpoint_camera.image_width),
+         tanfovx=tanfovx,
+         tanfovy=tanfovy,
+         bg=bg_color,
+         scale_modifier=scaling_modifier,
+         viewmatrix=viewpoint_camera.world_view_transform,
+         projmatrix=viewpoint_camera.full_proj_transform,
+         sh_degree=pc.active_sh_degree,
+         campos=viewpoint_camera.camera_center,
+         prefiltered=False,
+         debug=pipe.debug
+     )
+
+     rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+     means3D = pc.get_xyz
+     means2D = screenspace_points
+     opacity = pc.get_opacity
+
+     # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+     # scaling / rotation by the rasterizer.
+     scales = None
+     rotations = None
+     cov3D_precomp = None
+     if pipe.compute_cov3D_python:
+         cov3D_precomp = pc.get_covariance(scaling_modifier)
+     else:
+         scales = pc.get_scaling
+         rotations = pc.get_rotation
+
+     # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+     # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
+     shs = None
+     colors_precomp = None
+     if override_color is None:
+         if pipe.convert_SHs_python:
+             shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree + 1) ** 2)
+             dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
+             dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
+             sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
+             colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
+         else:
+             shs = pc.get_features
+     else:
+         colors_precomp = override_color
+
+     # Rasterize visible Gaussians to image, obtain their radii (on screen).
+     rendered_image, radii = rasterizer(
+         means3D=means3D,
+         means2D=means2D,
+         shs=shs,
+         colors_precomp=colors_precomp,
+         opacities=opacity,
+         scales=scales,
+         rotations=rotations,
+         cov3D_precomp=cov3D_precomp)
+
+     # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
+     # They will be excluded from value updates used in the splitting criteria.
+     return {"render": rendered_image,
+             "viewspace_points": screenspace_points,
+             "visibility_filter": radii > 0,
+             "radii": radii}
gaussian_renderer/network_gui.py ADDED
@@ -0,0 +1,86 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact george.drettakis@inria.fr
+ #
+
+ import torch
+ import traceback
+ import socket
+ import json
+ from scene.cameras import MiniCam
+
+ host = "127.0.0.1"
+ port = 6009
+
+ conn = None
+ addr = None
+
+ listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+ def init(wish_host, wish_port):
+     global host, port, listener
+     host = wish_host
+     port = wish_port
+     listener.bind((host, port))
+     listener.listen()
+     listener.settimeout(0)
+
+ def try_connect():
+     global conn, addr, listener
+     try:
+         conn, addr = listener.accept()
+         print(f"\nConnected by {addr}")
+         conn.settimeout(None)
+     except Exception as inst:
+         pass
+
+ def read():
+     global conn
+     messageLength = conn.recv(4)
+     messageLength = int.from_bytes(messageLength, 'little')
+     message = conn.recv(messageLength)
+     return json.loads(message.decode("utf-8"))
+
+ def send(message_bytes, verify):
+     global conn
+     if message_bytes != None:
+         conn.sendall(message_bytes)
+     conn.sendall(len(verify).to_bytes(4, 'little'))
+     conn.sendall(bytes(verify, 'ascii'))
+
+ def receive():
+     message = read()
+
+     width = message["resolution_x"]
+     height = message["resolution_y"]
+
+     if width != 0 and height != 0:
+         try:
+             do_training = bool(message["train"])
+             fovy = message["fov_y"]
+             fovx = message["fov_x"]
+             znear = message["z_near"]
+             zfar = message["z_far"]
+             do_shs_python = bool(message["shs_python"])
+             do_rot_scale_python = bool(message["rot_scale_python"])
+             keep_alive = bool(message["keep_alive"])
+             scaling_modifier = message["scaling_modifier"]
+             world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
+             world_view_transform[:, 1] = -world_view_transform[:, 1]
+             world_view_transform[:, 2] = -world_view_transform[:, 2]
+             full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
+             full_proj_transform[:, 1] = -full_proj_transform[:, 1]
+             custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
+         except Exception as e:
+             print("")
+             traceback.print_exc()
+             raise e
+         return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
+     else:
+         return None, None, None, None, None, None
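Note: `read()` and `send()` implement a simple framing on top of the socket: a 4-byte little-endian length prefix followed by a UTF-8 JSON payload. A client-side sketch of the matching encoder (the helper name and payload are assumptions, not part of the viewer protocol spec):

```python
# Minimal counterpart to read() above: prefix the UTF-8 JSON body with its
# length as a 4-byte little-endian integer, then send both.
import json
import socket

def send_json(sock: socket.socket, payload: dict) -> None:
    body = json.dumps(payload).encode("utf-8")
    sock.sendall(len(body).to_bytes(4, "little"))  # 4-byte length header
    sock.sendall(body)                             # JSON body
```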
lpipsPyTorch/__init__.py ADDED
@@ -0,0 +1,21 @@
+ import torch
+
+ from .modules.lpips import LPIPS
+
+
+ def lpips(x: torch.Tensor,
+           y: torch.Tensor,
+           net_type: str = 'alex',
+           version: str = '0.1'):
+     r"""Function that measures
+     Learned Perceptual Image Patch Similarity (LPIPS).
+
+     Arguments:
+         x, y (torch.Tensor): the input tensors to compare.
+         net_type (str): the network type to compare the features:
+                         'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+         version (str): the version of LPIPS. Default: 0.1.
+     """
+     device = x.device
+     criterion = LPIPS(net_type, version).to(device)
+     return criterion(x, y)
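Note: `lpips()` constructs a fresh `LPIPS` module on every call (downloading pretrained weights on first use), which keeps evaluation scripts simple. A hedged usage sketch:

```python
# Usage sketch for the helper above; LPIPS v0.1 expects image batches scaled
# roughly to [-1, 1], with both tensors on the same device.
import torch
from lpipsPyTorch import lpips

x = torch.rand(1, 3, 64, 64) * 2 - 1
y = torch.rand(1, 3, 64, 64) * 2 - 1
print(lpips(x, y, net_type='alex'))  # perceptual distance; lower is more similar
```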
lpipsPyTorch/modules/lpips.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+ import torch.nn as nn
+
+ from .networks import get_network, LinLayers
+ from .utils import get_state_dict
+
+
+ class LPIPS(nn.Module):
+     r"""Creates a criterion that measures
+     Learned Perceptual Image Patch Similarity (LPIPS).
+
+     Arguments:
+         net_type (str): the network type to compare the features:
+                         'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+         version (str): the version of LPIPS. Default: 0.1.
+     """
+     def __init__(self, net_type: str = 'alex', version: str = '0.1'):
+
+         assert version in ['0.1'], 'v0.1 is only supported now'
+
+         super(LPIPS, self).__init__()
+
+         # pretrained network
+         self.net = get_network(net_type)
+
+         # linear layers
+         self.lin = LinLayers(self.net.n_channels_list)
+         self.lin.load_state_dict(get_state_dict(net_type, version))
+
+     def forward(self, x: torch.Tensor, y: torch.Tensor):
+         feat_x, feat_y = self.net(x), self.net(y)
+
+         diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
+         res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
+
+         return torch.sum(torch.cat(res, 0), 0, True)
lpipsPyTorch/modules/networks.py ADDED
@@ -0,0 +1,96 @@
+ from typing import Sequence
+
+ from itertools import chain
+
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+
+ from .utils import normalize_activation
+
+
+ def get_network(net_type: str):
+     if net_type == 'alex':
+         return AlexNet()
+     elif net_type == 'squeeze':
+         return SqueezeNet()
+     elif net_type == 'vgg':
+         return VGG16()
+     else:
+         raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
+
+
+ class LinLayers(nn.ModuleList):
+     def __init__(self, n_channels_list: Sequence[int]):
+         super(LinLayers, self).__init__([
+             nn.Sequential(
+                 nn.Identity(),
+                 nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
+             ) for nc in n_channels_list
+         ])
+
+         for param in self.parameters():
+             param.requires_grad = False
+
+
+ class BaseNet(nn.Module):
+     def __init__(self):
+         super(BaseNet, self).__init__()
+
+         # register buffer
+         self.register_buffer(
+             'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+         self.register_buffer(
+             'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+     def set_requires_grad(self, state: bool):
+         for param in chain(self.parameters(), self.buffers()):
+             param.requires_grad = state
+
+     def z_score(self, x: torch.Tensor):
+         return (x - self.mean) / self.std
+
+     def forward(self, x: torch.Tensor):
+         x = self.z_score(x)
+
+         output = []
+         for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
+             x = layer(x)
+             if i in self.target_layers:
+                 output.append(normalize_activation(x))
+             if len(output) == len(self.target_layers):
+                 break
+         return output
+
+
+ class SqueezeNet(BaseNet):
+     def __init__(self):
+         super(SqueezeNet, self).__init__()
+
+         self.layers = models.squeezenet1_1(True).features
+         self.target_layers = [2, 5, 8, 10, 11, 12, 13]
+         self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
+
+         self.set_requires_grad(False)
+
+
+ class AlexNet(BaseNet):
+     def __init__(self):
+         super(AlexNet, self).__init__()
+
+         self.layers = models.alexnet(True).features
+         self.target_layers = [2, 5, 8, 10, 12]
+         self.n_channels_list = [64, 192, 384, 256, 256]
+
+         self.set_requires_grad(False)
+
+
+ class VGG16(BaseNet):
+     def __init__(self):
+         super(VGG16, self).__init__()
+
+         self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
+         self.target_layers = [4, 9, 16, 23, 30]
+         self.n_channels_list = [64, 128, 256, 512, 512]
+
+         self.set_requires_grad(False)
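Note: each entry of `LinLayers` is a frozen 1x1 convolution that collapses a C-channel squared feature difference into one channel; `LPIPS.forward` then mean-pools it spatially. A shape sketch:

```python
# What one LinLayers stage does, with illustrative sizes: a (1, 64, 16, 16)
# squared-difference map -> (1, 1, 16, 16) -> mean-pooled to (1, 1, 1, 1).
import torch
import torch.nn as nn

lin = nn.Conv2d(64, 1, 1, 1, 0, bias=False)   # same signature as in LinLayers
diff = torch.rand(1, 64, 16, 16)
print(lin(diff).mean((2, 3), True).shape)     # torch.Size([1, 1, 1, 1])
```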
lpipsPyTorch/modules/utils.py ADDED
@@ -0,0 +1,30 @@
+ from collections import OrderedDict
+
+ import torch
+
+
+ def normalize_activation(x, eps=1e-10):
+     norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
+     return x / (norm_factor + eps)
+
+
+ def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
+     # build url
+     url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+         + f'master/lpips/weights/v{version}/{net_type}.pth'
+
+     # download
+     old_state_dict = torch.hub.load_state_dict_from_url(
+         url, progress=True,
+         map_location=None if torch.cuda.is_available() else torch.device('cpu')
+     )
+
+     # rename keys
+     new_state_dict = OrderedDict()
+     for key, val in old_state_dict.items():
+         new_key = key
+         new_key = new_key.replace('lin', '')
+         new_key = new_key.replace('model.', '')
+         new_state_dict[new_key] = val
+
+     return new_state_dict
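Note: `normalize_activation` rescales every spatial feature vector to unit length along the channel dimension, so only feature direction (not magnitude) enters the LPIPS distance. A quick check:

```python
# The per-pixel channel norm of the output should be ~1 everywhere.
import torch
from lpipsPyTorch.modules.utils import normalize_activation

feat = torch.rand(1, 8, 4, 4)
unit = normalize_activation(feat)
print(unit.pow(2).sum(dim=1).sqrt().mean())  # ~1.0
```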
render_by_interp.py ADDED
@@ -0,0 +1,152 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact george.drettakis@inria.fr
+ #
+
+ import torch
+ from scene import Scene
+ import os
+ from tqdm import tqdm
+ from os import makedirs
+ from gaussian_renderer import render
+ import torchvision
+ from utils.general_utils import safe_state
+ from argparse import ArgumentParser
+ from arguments import ModelParams, PipelineParams, get_combined_args
+ from gaussian_renderer import GaussianModel
+ from utils.pose_utils import get_tensor_from_camera
+ from utils.camera_utils import generate_interpolated_path
+ from utils.camera_utils import visualizer
+ import cv2
+ import numpy as np
+ import imageio
+
+
+ def save_interpolate_pose(model_path, iter, n_views):
+     org_pose = np.load(model_path + f"pose/pose_{iter}.npy")
+     # visualizer(org_pose, ["green" for _ in org_pose], model_path + "pose/poses_optimized.png")
+     # n_interp = int(10 * 30 / n_views)  # 10 seconds, fps=30
+     n_interp = int(5 * 30 / n_views)  # 5 seconds, fps=30
+     all_inter_pose = []
+     for i in range(n_views - 1):
+         tmp_inter_pose = generate_interpolated_path(poses=org_pose[i:i + 2], n_interp=n_interp)
+         all_inter_pose.append(tmp_inter_pose)
+     all_inter_pose = np.array(all_inter_pose).reshape(-1, 3, 4)
+
+     inter_pose_list = []
+     for p in all_inter_pose:
+         tmp_view = np.eye(4)
+         tmp_view[:3, :3] = p[:3, :3]
+         tmp_view[:3, 3] = p[:3, 3]
+         inter_pose_list.append(tmp_view)
+     inter_pose = np.stack(inter_pose_list, 0)
+     # visualizer(inter_pose, ["blue" for _ in inter_pose], model_path + "pose/poses_interpolated.png")
+     np.save(model_path + "pose/pose_interpolated.npy", inter_pose)
+
+
+ def images_to_video(image_folder, output_video_path, fps=30):
+     """
+     Convert images in a folder to a video.
+
+     Args:
+     - image_folder (str): The path to the folder containing the images.
+     - output_video_path (str): The path where the output video will be saved.
+     - fps (int): Frames per second for the output video.
+     """
+     images = []
+
+     for filename in sorted(os.listdir(image_folder)):
+         if filename.endswith(('.png', '.jpg', '.jpeg', '.JPG', '.PNG')):
+             image_path = os.path.join(image_folder, filename)
+             image = imageio.imread(image_path)
+             images.append(image)
+
+     imageio.mimwrite(output_video_path, images, fps=fps)
+
+
+ def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
+     render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
+     makedirs(render_path, exist_ok=True)
+
+     # for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
+     for idx, view in enumerate(views):
+         camera_pose = get_tensor_from_camera(view.world_view_transform.transpose(0, 1))
+         rendering = render(
+             view, gaussians, pipeline, background, camera_pose=camera_pose
+         )["render"]
+         gt = view.original_image[0:3, :, :]
+         torchvision.utils.save_image(
+             rendering, os.path.join(render_path, "{0:05d}".format(idx) + ".png")
+         )
+
+
+ def render_sets(
+     dataset: ModelParams,
+     iteration: int,
+     pipeline: PipelineParams,
+     skip_train: bool,
+     skip_test: bool,
+     args,
+ ):
+     # Applying interpolation
+     save_interpolate_pose(dataset.model_path, iteration, args.n_views)
+
+     with torch.no_grad():
+         gaussians = GaussianModel(dataset.sh_degree)
+         scene = Scene(dataset, gaussians, load_iteration=iteration, opt=args, shuffle=False)
+
+         bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+         background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+         # render interpolated views
+         render_set(
+             dataset.model_path,
+             "interp",
+             scene.loaded_iter,
+             scene.getTrainCameras(),
+             gaussians,
+             pipeline,
+             background,
+         )
+
+     if args.get_video:
+         image_folder = os.path.join(dataset.model_path, f'interp/ours_{args.iteration}/renders')
+         output_video_file = os.path.join(dataset.model_path, f'{args.scene}_{args.n_views}_view.mp4')
+         images_to_video(image_folder, output_video_file, fps=30)
+
+
+ if __name__ == "__main__":
+     # Set up command line argument parser
+     parser = ArgumentParser(description="Testing script parameters")
+     model = ModelParams(parser, sentinel=True)
+     pipeline = PipelineParams(parser)
+     parser.add_argument("--iteration", default=-1, type=int)
+     parser.add_argument("--skip_train", action="store_true")
+     parser.add_argument("--skip_test", action="store_true")
+     parser.add_argument("--quiet", action="store_true")
+
+     parser.add_argument("--get_video", action="store_true")
+     parser.add_argument("--n_views", default=None, type=int)
+     parser.add_argument("--scene", default=None, type=str)
+     args = get_combined_args(parser)
+     print("Rendering " + args.model_path)
+
+     # Initialize system state (RNG)
+     # safe_state(args.quiet)
+
+     render_sets(
+         model.extract(args),
+         args.iteration,
+         pipeline.extract(args),
+         args.skip_train,
+         args.skip_test,
+         args,
+     )
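Note: the interpolation budget in `save_interpolate_pose` targets roughly a 5-second video at 30 fps, dividing the frame count by `n_views` and interpolating between each adjacent pair of optimized poses. Worked arithmetic for the 3-view demo:

```python
# Frame budget as computed above (illustrative): 50 frames per pose pair,
# 2 pairs for 3 views, so about 100 frames (~3.3 s at 30 fps).
n_views, fps, seconds = 3, 30, 5
n_interp = int(seconds * fps / n_views)   # 50
total_frames = n_interp * (n_views - 1)   # 100
print(n_interp, total_frames)
```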
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ torch==2.2.0
+ torchvision
+ roma
+ evo
+ gradio==5.0.1
+ matplotlib
+ tqdm
+ opencv-python
+ scipy
+ einops
+ trimesh
+ tensorboard
+ pyglet<2
+ huggingface-hub[torch]>=0.22
+ plyfile
+ imageio[ffmpeg]
+ spaces
scene/__init__.py ADDED
@@ -0,0 +1,96 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact george.drettakis@inria.fr
+ #
+
+ import os
+ import random
+ import json
+ from utils.system_utils import searchForMaxIteration
+ from scene.dataset_readers import sceneLoadTypeCallbacks
+ from scene.gaussian_model import GaussianModel
+ from arguments import ModelParams
+ from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
+
+ class Scene:
+
+     gaussians: GaussianModel
+
+     def __init__(self, args: ModelParams, gaussians: GaussianModel, load_iteration=None, opt=None, shuffle=True, resolution_scales=[1.0]):
+         """
+         :param args: Model parameters; args.source_path must point to the COLMAP scene main folder.
+         """
+         self.model_path = args.model_path
+         self.loaded_iter = None
+         self.gaussians = gaussians
+
+         if load_iteration:
+             if load_iteration == -1:
+                 self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
+             else:
+                 self.loaded_iter = load_iteration
+             print("Loading trained model at iteration {}".format(self.loaded_iter))
+
+         self.train_cameras = {}
+         self.test_cameras = {}
+
+         if os.path.exists(os.path.join(args.source_path, "sparse")):
+             scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, args, opt)
+         elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
+             print("Found transforms_train.json file, assuming Blender data set!")
+             scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
+         else:
+             assert False, "Could not recognize scene type!"
+
+         if not self.loaded_iter:
+             with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply"), 'wb') as dest_file:
+                 dest_file.write(src_file.read())
+             json_cams = []
+             camlist = []
+             if scene_info.test_cameras:
+                 camlist.extend(scene_info.test_cameras)
+             if scene_info.train_cameras:
+                 camlist.extend(scene_info.train_cameras)
+             for id, cam in enumerate(camlist):
+                 json_cams.append(camera_to_JSON(id, cam))
+             with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
+                 json.dump(json_cams, file)
+
+         if shuffle:
+             random.shuffle(scene_info.train_cameras)  # Multi-res consistent random shuffling
+             random.shuffle(scene_info.test_cameras)  # Multi-res consistent random shuffling
+
+         self.cameras_extent = scene_info.nerf_normalization["radius"]
+
+         for resolution_scale in resolution_scales:
+             print("Loading Training Cameras")
+             self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
+             print('train_camera_num: ', len(self.train_cameras[resolution_scale]))
+             print("Loading Test Cameras")
+             self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
+             print('test_camera_num: ', len(self.test_cameras[resolution_scale]))
+
+         if self.loaded_iter:
+             self.gaussians.load_ply(os.path.join(self.model_path,
+                                                  "point_cloud",
+                                                  "iteration_" + str(self.loaded_iter),
+                                                  "point_cloud.ply"))
+         else:
+             self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
+             self.gaussians.init_RT_seq(self.train_cameras)
+
+     def save(self, iteration):
+         point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
+         self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
+
+     def getTrainCameras(self, scale=1.0):
+         return self.train_cameras[scale]
+
+     def getTestCameras(self, scale=1.0):
+         return self.test_cameras[scale]
scene/cameras.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ import torch
13
+ from torch import nn
14
+ import numpy as np
15
+ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
16
+
17
+ class Camera(nn.Module):
18
+ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
19
+ image_name, uid,
20
+ trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
21
+ ):
22
+ super(Camera, self).__init__()
23
+
24
+ self.uid = uid
25
+ self.colmap_id = colmap_id
26
+ self.R = R
27
+ self.T = T
28
+ self.FoVx = FoVx
29
+ self.FoVy = FoVy
30
+ self.image_name = image_name
31
+
32
+ try:
33
+ self.data_device = torch.device(data_device)
34
+ except Exception as e:
35
+ print(e)
36
+                 print(f"[Warning] Custom device {data_device} failed, falling back to the default CUDA device")
37
+ self.data_device = torch.device("cuda")
38
+
39
+ self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
40
+ self.image_width = self.original_image.shape[2]
41
+ self.image_height = self.original_image.shape[1]
42
+
43
+ if gt_alpha_mask is not None:
44
+ self.original_image *= gt_alpha_mask.to(self.data_device)
45
+ else:
46
+ self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
47
+
48
+ self.zfar = 100.0
49
+ self.znear = 0.01
50
+
51
+ self.trans = trans
52
+ self.scale = scale
53
+
54
+ self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
55
+ self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
56
+ self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
57
+ self.camera_center = self.world_view_transform.inverse()[3, :3]
58
+
59
+ class MiniCam:
60
+ def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
61
+ self.image_width = width
62
+ self.image_height = height
63
+ self.FoVy = fovy
64
+ self.FoVx = fovx
65
+ self.znear = znear
66
+ self.zfar = zfar
67
+ self.world_view_transform = world_view_transform
68
+ self.full_proj_transform = full_proj_transform
69
+ view_inv = torch.inverse(self.world_view_transform)
70
+ self.camera_center = view_inv[3][:3]
71
+
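Note that world_view_transform is stored transposed (matching the glm convention used by the CUDA rasterizer), which is why camera_center is read from row 3 of the inverse rather than the usual last column. A small numpy sketch of that identity, assuming a plain world-to-camera matrix W = [R | t; 0 1]:

    import numpy as np

    R, t = np.eye(3), np.array([1.0, 2.0, 3.0])
    W = np.block([[R, t[:, None]], [np.zeros((1, 3)), np.ones((1, 1))]])

    C = np.linalg.inv(W)[:3, 3]               # camera center, untransposed convention
    C_rowwise = np.linalg.inv(W.T)[3, :3]     # row 3 when the matrix is stored transposed, as in Camera/MiniCam
    assert np.allclose(C, C_rowwise)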
scene/colmap_loader.py ADDED
@@ -0,0 +1,294 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ import numpy as np
13
+ import collections
14
+ import struct
15
+
16
+ CameraModel = collections.namedtuple(
17
+ "CameraModel", ["model_id", "model_name", "num_params"])
18
+ Camera = collections.namedtuple(
19
+ "Camera", ["id", "model", "width", "height", "params"])
20
+ BaseImage = collections.namedtuple(
21
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
22
+ Point3D = collections.namedtuple(
23
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
24
+ CAMERA_MODELS = {
25
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
26
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
27
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
28
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
29
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
30
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
31
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
32
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
33
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
34
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
35
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
36
+ }
37
+ CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
38
+ for camera_model in CAMERA_MODELS])
39
+ CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
40
+ for camera_model in CAMERA_MODELS])
41
+
42
+
43
+ def qvec2rotmat(qvec):
44
+ return np.array([
45
+ [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
46
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
47
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
48
+ [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
49
+ 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
50
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
51
+ [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
52
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
53
+ 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
54
+
55
+ def rotmat2qvec(R):
56
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
57
+ K = np.array([
58
+ [Rxx - Ryy - Rzz, 0, 0, 0],
59
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
60
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
61
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
62
+ eigvals, eigvecs = np.linalg.eigh(K)
63
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
64
+ if qvec[0] < 0:
65
+ qvec *= -1
66
+ return qvec
67
+
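qvec2rotmat and rotmat2qvec should invert each other up to the usual quaternion sign ambiguity; a quick self-check sketch using the two helpers above:

    import numpy as np

    q = np.array([0.5, 0.5, 0.5, 0.5])        # unit quaternion in COLMAP's (w, x, y, z) order
    q_back = rotmat2qvec(qvec2rotmat(q))
    assert np.allclose(q, q_back) or np.allclose(q, -q_back)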
68
+ class Image(BaseImage):
69
+ def qvec2rotmat(self):
70
+ return qvec2rotmat(self.qvec)
71
+
72
+ def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
73
+ """Read and unpack the next bytes from a binary file.
74
+ :param fid:
75
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
76
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
77
+ :param endian_character: Any of {@, =, <, >, !}
78
+ :return: Tuple of read and unpacked values.
79
+ """
80
+ data = fid.read(num_bytes)
81
+ return struct.unpack(endian_character + format_char_sequence, data)
82
+
83
+ def read_points3D_text(path):
84
+ """
85
+ see: src/base/reconstruction.cc
86
+ void Reconstruction::ReadPoints3DText(const std::string& path)
87
+ void Reconstruction::WritePoints3DText(const std::string& path)
88
+ """
89
+ xyzs = None
90
+ rgbs = None
91
+ errors = None
92
+ num_points = 0
93
+ with open(path, "r") as fid:
94
+ while True:
95
+ line = fid.readline()
96
+ if not line:
97
+ break
98
+ line = line.strip()
99
+ if len(line) > 0 and line[0] != "#":
100
+ num_points += 1
101
+
102
+
103
+ xyzs = np.empty((num_points, 3))
104
+ rgbs = np.empty((num_points, 3))
105
+ errors = np.empty((num_points, 1))
106
+ count = 0
107
+ with open(path, "r") as fid:
108
+ while True:
109
+ line = fid.readline()
110
+ if not line:
111
+ break
112
+ line = line.strip()
113
+ if len(line) > 0 and line[0] != "#":
114
+ elems = line.split()
115
+ xyz = np.array(tuple(map(float, elems[1:4])))
116
+ rgb = np.array(tuple(map(int, elems[4:7])))
117
+ error = np.array(float(elems[7]))
118
+ xyzs[count] = xyz
119
+ rgbs[count] = rgb
120
+ errors[count] = error
121
+ count += 1
122
+
123
+ return xyzs, rgbs, errors
124
+
125
+ def read_points3D_binary(path_to_model_file):
126
+ """
127
+ see: src/base/reconstruction.cc
128
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
129
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
130
+ """
131
+
132
+
133
+ with open(path_to_model_file, "rb") as fid:
134
+ num_points = read_next_bytes(fid, 8, "Q")[0]
135
+
136
+ xyzs = np.empty((num_points, 3))
137
+ rgbs = np.empty((num_points, 3))
138
+ errors = np.empty((num_points, 1))
139
+
140
+ for p_id in range(num_points):
141
+ binary_point_line_properties = read_next_bytes(
142
+ fid, num_bytes=43, format_char_sequence="QdddBBBd")
143
+ xyz = np.array(binary_point_line_properties[1:4])
144
+ rgb = np.array(binary_point_line_properties[4:7])
145
+ error = np.array(binary_point_line_properties[7])
146
+ track_length = read_next_bytes(
147
+ fid, num_bytes=8, format_char_sequence="Q")[0]
148
+ track_elems = read_next_bytes(
149
+ fid, num_bytes=8*track_length,
150
+ format_char_sequence="ii"*track_length)
151
+ xyzs[p_id] = xyz
152
+ rgbs[p_id] = rgb
153
+ errors[p_id] = error
154
+ return xyzs, rgbs, errors
155
+
156
+ def read_intrinsics_text(path):
157
+ """
158
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
159
+ """
160
+ cameras = {}
161
+ with open(path, "r") as fid:
162
+ while True:
163
+ line = fid.readline()
164
+ if not line:
165
+ break
166
+ line = line.strip()
167
+ if len(line) > 0 and line[0] != "#":
168
+ elems = line.split()
169
+ camera_id = int(elems[0])
170
+ model = elems[1]
171
+                 assert model == "PINHOLE", "While the loader supports other types, the rest of the code assumes PINHOLE"
172
+ width = int(elems[2])
173
+ height = int(elems[3])
174
+ params = np.array(tuple(map(float, elems[4:])))
175
+ cameras[camera_id] = Camera(id=camera_id, model=model,
176
+ width=width, height=height,
177
+ params=params)
178
+ return cameras
179
+
180
+ def read_extrinsics_binary(path_to_model_file):
181
+ """
182
+ see: src/base/reconstruction.cc
183
+ void Reconstruction::ReadImagesBinary(const std::string& path)
184
+ void Reconstruction::WriteImagesBinary(const std::string& path)
185
+ """
186
+ images = {}
187
+ with open(path_to_model_file, "rb") as fid:
188
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
189
+ for _ in range(num_reg_images):
190
+ binary_image_properties = read_next_bytes(
191
+ fid, num_bytes=64, format_char_sequence="idddddddi")
192
+ image_id = binary_image_properties[0]
193
+ qvec = np.array(binary_image_properties[1:5])
194
+ tvec = np.array(binary_image_properties[5:8])
195
+ camera_id = binary_image_properties[8]
196
+ image_name = ""
197
+ current_char = read_next_bytes(fid, 1, "c")[0]
198
+ while current_char != b"\x00": # look for the ASCII 0 entry
199
+ image_name += current_char.decode("utf-8")
200
+ current_char = read_next_bytes(fid, 1, "c")[0]
201
+ num_points2D = read_next_bytes(fid, num_bytes=8,
202
+ format_char_sequence="Q")[0]
203
+ x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
204
+ format_char_sequence="ddq"*num_points2D)
205
+ xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
206
+ tuple(map(float, x_y_id_s[1::3]))])
207
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
208
+ images[image_id] = Image(
209
+ id=image_id, qvec=qvec, tvec=tvec,
210
+ camera_id=camera_id, name=image_name,
211
+ xys=xys, point3D_ids=point3D_ids)
212
+ return images
213
+
214
+
215
+ def read_intrinsics_binary(path_to_model_file):
216
+ """
217
+ see: src/base/reconstruction.cc
218
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
219
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
220
+ """
221
+ cameras = {}
222
+ with open(path_to_model_file, "rb") as fid:
223
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
224
+ for _ in range(num_cameras):
225
+ camera_properties = read_next_bytes(
226
+ fid, num_bytes=24, format_char_sequence="iiQQ")
227
+ camera_id = camera_properties[0]
228
+ model_id = camera_properties[1]
229
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
230
+ width = camera_properties[2]
231
+ height = camera_properties[3]
232
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
233
+ params = read_next_bytes(fid, num_bytes=8*num_params,
234
+ format_char_sequence="d"*num_params)
235
+ cameras[camera_id] = Camera(id=camera_id,
236
+ model=model_name,
237
+ width=width,
238
+ height=height,
239
+ params=np.array(params))
240
+ assert len(cameras) == num_cameras
241
+ return cameras
242
+
243
+
244
+ def read_extrinsics_text(path):
245
+ """
246
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
247
+ """
248
+ images = {}
249
+ with open(path, "r") as fid:
250
+ while True:
251
+ line = fid.readline()
252
+ if not line:
253
+ break
254
+ line = line.strip()
255
+ if len(line) > 0 and line[0] != "#":
256
+ elems = line.split()
257
+ image_id = int(elems[0])
258
+ qvec = np.array(tuple(map(float, elems[1:5])))
259
+ tvec = np.array(tuple(map(float, elems[5:8])))
260
+ camera_id = int(elems[8])
261
+ image_name = elems[9]
262
+ elems = fid.readline().split()
263
+ xys = np.column_stack([tuple(map(float, elems[0::3])),
264
+ tuple(map(float, elems[1::3]))])
265
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
266
+ images[image_id] = Image(
267
+ id=image_id, qvec=qvec, tvec=tvec,
268
+ camera_id=camera_id, name=image_name,
269
+ xys=xys, point3D_ids=point3D_ids)
270
+ return images
271
+
272
+
273
+ def read_colmap_bin_array(path):
274
+ """
275
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
276
+
277
+ :param path: path to the colmap binary file.
278
+     :return: nd array with the floating point values read from the file.
279
+ """
280
+ with open(path, "rb") as fid:
281
+ width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
282
+ usecols=(0, 1, 2), dtype=int)
283
+ fid.seek(0)
284
+ num_delimiter = 0
285
+ byte = fid.read(1)
286
+ while True:
287
+ if byte == b"&":
288
+ num_delimiter += 1
289
+ if num_delimiter >= 3:
290
+ break
291
+ byte = fid.read(1)
292
+ array = np.fromfile(fid, np.float32)
293
+ array = array.reshape((width, height, channels), order="F")
294
+ return np.transpose(array, (1, 0, 2)).squeeze()
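A minimal sketch of loading a text-format COLMAP model with the readers above; the sparse/0 layout is an assumption about where the model lives:

    import os

    model_dir = "some_scene/sparse/0"          # hypothetical path
    cams = read_intrinsics_text(os.path.join(model_dir, "cameras.txt"))
    images = read_extrinsics_text(os.path.join(model_dir, "images.txt"))
    xyzs, rgbs, errors = read_points3D_text(os.path.join(model_dir, "points3D.txt"))

    for image_id, img in images.items():
        R = img.qvec2rotmat()                  # world-to-camera rotation
        t = img.tvec                           # world-to-camera translation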
scene/dataset_readers.py ADDED
@@ -0,0 +1,363 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ import os
13
+ import sys
14
+ from PIL import Image
15
+ from typing import NamedTuple
16
+ from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
17
+ read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
18
+ from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
19
+ import numpy as np
20
+ import json
21
+ from pathlib import Path
22
+ from plyfile import PlyData, PlyElement
23
+ from utils.sh_utils import SH2RGB
24
+ from scene.gaussian_model import BasicPointCloud
25
+
26
+ class CameraInfo(NamedTuple):
27
+ uid: int
28
+ R: np.array
29
+ T: np.array
30
+ FovY: np.array
31
+ FovX: np.array
32
+ image: np.array
33
+ image_path: str
34
+ image_name: str
35
+ width: int
36
+ height: int
37
+
38
+
39
+ class SceneInfo(NamedTuple):
40
+ point_cloud: BasicPointCloud
41
+ train_cameras: list
42
+ test_cameras: list
43
+ nerf_normalization: dict
44
+ ply_path: str
45
+ train_poses: list
46
+ test_poses: list
47
+
48
+ def getNerfppNorm(cam_info):
49
+ def get_center_and_diag(cam_centers):
50
+ cam_centers = np.hstack(cam_centers)
51
+ avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
52
+ center = avg_cam_center
53
+ dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
54
+ diagonal = np.max(dist)
55
+ return center.flatten(), diagonal
56
+
57
+ cam_centers = []
58
+
59
+ for cam in cam_info:
60
+ W2C = getWorld2View2(cam.R, cam.T)
61
+ C2W = np.linalg.inv(W2C)
62
+ cam_centers.append(C2W[:3, 3:4])
63
+
64
+ center, diagonal = get_center_and_diag(cam_centers)
65
+ radius = diagonal * 1.1
66
+
67
+ translate = -center
68
+
69
+ return {"translate": translate, "radius": radius}
70
+
71
+ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder, eval):
72
+
73
+ cam_infos = []
74
+ poses=[]
75
+ for idx, key in enumerate(cam_extrinsics):
76
+ sys.stdout.write('\r')
77
+         # report progress on a single, continuously rewritten console line
78
+ sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
79
+ sys.stdout.flush()
80
+
81
+ if eval:
82
+ extr = cam_extrinsics[key]
83
+ intr = cam_intrinsics[1]
84
+ uid = idx+1
85
+
86
+ else:
87
+ extr = cam_extrinsics[key]
88
+ intr = cam_intrinsics[extr.camera_id]
89
+ uid = intr.id
90
+
91
+ height = intr.height
92
+ width = intr.width
93
+ R = np.transpose(qvec2rotmat(extr.qvec))
94
+ T = np.array(extr.tvec)
95
+ pose = np.vstack((np.hstack((R, T.reshape(3,-1))),np.array([[0, 0, 0, 1]])))
96
+ poses.append(pose)
97
+ if intr.model=="SIMPLE_PINHOLE":
98
+ focal_length_x = intr.params[0]
99
+ FovY = focal2fov(focal_length_x, height)
100
+ FovX = focal2fov(focal_length_x, width)
101
+ elif intr.model=="PINHOLE":
102
+ focal_length_x = intr.params[0]
103
+ focal_length_y = intr.params[1]
104
+ FovY = focal2fov(focal_length_y, height)
105
+ FovX = focal2fov(focal_length_x, width)
106
+ else:
107
+ assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
108
+
109
+
110
+ if eval:
111
+ tmp = os.path.dirname(os.path.dirname(os.path.join(images_folder)))
112
+ all_images_folder = os.path.join(tmp, 'images')
113
+ image_path = os.path.join(all_images_folder, os.path.basename(extr.name))
114
+ else:
115
+ image_path = os.path.join(images_folder, os.path.basename(extr.name))
116
+ image_name = os.path.basename(image_path).split(".")[0]
117
+ image = Image.open(image_path)
118
+
119
+
120
+ cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
121
+ image_path=image_path, image_name=image_name, width=width, height=height)
122
+
123
+ cam_infos.append(cam_info)
124
+ sys.stdout.write('\n')
125
+ return cam_infos, poses
126
+
127
+ # For interpolated video: used only when rendering along an interpolated camera path
128
+ def readColmapCamerasInterp(cam_extrinsics, cam_intrinsics, images_folder, model_path):
129
+
130
+     pose_interpolated_path = os.path.join(model_path, 'pose/pose_interpolated.npy')
131
+ pose_interpolated = np.load(pose_interpolated_path)
132
+ intr = cam_intrinsics[1]
133
+
134
+ cam_infos = []
135
+ poses=[]
136
+ for idx, pose_npy in enumerate(pose_interpolated):
137
+ sys.stdout.write('\r')
138
+ sys.stdout.write("Reading camera {}/{}".format(idx+1, pose_interpolated.shape[0]))
139
+ sys.stdout.flush()
140
+
141
+         extr = pose_npy
143
+         # intrinsics from cam_intrinsics[1] are shared by every interpolated frame
143
+ height = intr.height
144
+ width = intr.width
145
+
146
+ uid = idx
147
+ R = extr[:3, :3].transpose()
148
+ T = extr[:3, 3]
149
+ pose = np.vstack((np.hstack((R, T.reshape(3,-1))),np.array([[0, 0, 0, 1]])))
153
+ poses.append(pose)
154
+ if intr.model=="SIMPLE_PINHOLE":
155
+ focal_length_x = intr.params[0]
156
+ FovY = focal2fov(focal_length_x, height)
157
+ FovX = focal2fov(focal_length_x, width)
158
+ elif intr.model=="PINHOLE":
159
+ focal_length_x = intr.params[0]
160
+ focal_length_y = intr.params[1]
161
+ FovY = focal2fov(focal_length_y, height)
162
+ FovX = focal2fov(focal_length_x, width)
163
+ else:
164
+ assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
165
+
166
+         images_list = os.listdir(images_folder)
167
+ image_name_0 = images_list[0]
168
+ image_name = str(idx).zfill(4)
169
+         image = Image.open(os.path.join(images_folder, image_name_0))
170
+
171
+ cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
172
+ image_path=images_folder, image_name=image_name, width=width, height=height)
173
+ cam_infos.append(cam_info)
174
+
175
+ sys.stdout.write('\n')
176
+ return cam_infos, poses
177
+
178
+
179
+ def fetchPly(path):
180
+ plydata = PlyData.read(path)
181
+ vertices = plydata['vertex']
182
+ positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
183
+ colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
184
+ normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
185
+ return BasicPointCloud(points=positions, colors=colors, normals=normals)
186
+
187
+ def storePly(path, xyz, rgb):
188
+ # Define the dtype for the structured array
189
+ dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
190
+ ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
191
+ ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
192
+
193
+ normals = np.zeros_like(xyz)
194
+
195
+ elements = np.empty(xyz.shape[0], dtype=dtype)
196
+ attributes = np.concatenate((xyz, normals, rgb), axis=1)
197
+ elements[:] = list(map(tuple, attributes))
198
+
199
+ # Create the PlyData object and write to file
200
+ vertex_element = PlyElement.describe(elements, 'vertex')
201
+ ply_data = PlyData([vertex_element])
202
+ ply_data.write(path)
203
+
204
+ def readColmapSceneInfo(path, images, eval, args, opt, llffhold=2):
205
+ # try:
206
+ # cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
207
+ # cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
208
+ # cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
209
+ # cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
210
+ # except:
211
+
212
+     ##### Initialize test poses using point-cloud registration (PCD_Registration)
213
+     if eval and not opt.get_video:
214
+ print("Loading initial test pose for evaluation.")
215
+ cameras_extrinsic_file = os.path.join(path, "init_test_pose/sparse/0", "images.txt")
216
+ else:
217
+ cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
218
+
219
+ cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
220
+ cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
221
+ cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
222
+
223
+     reading_dir = "images" if images is None else images
224
+
225
+ if opt.get_video:
226
+ cam_infos_unsorted, poses = readColmapCamerasInterp(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir), model_path=args.model_path)
227
+ else:
228
+ cam_infos_unsorted, poses = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir), eval=eval)
229
+     sorting_indices = sorted(range(len(cam_infos_unsorted)), key=lambda x: cam_infos_unsorted[x].image_name)
230
+     cam_infos = [cam_infos_unsorted[i] for i in sorting_indices]
231
+     sorted_poses = [poses[i] for i in sorting_indices]
232
+     # cam_infos and sorted_poses now share the same image_name ordering
233
+
234
+ if eval:
240
+ train_cam_infos = cam_infos
241
+ test_cam_infos = cam_infos
242
+ train_poses = sorted_poses
243
+ test_poses = sorted_poses
244
+
245
+ else:
246
+ train_cam_infos = cam_infos
247
+ test_cam_infos = []
248
+ train_poses = sorted_poses
249
+ test_poses = []
250
+
251
+ nerf_normalization = getNerfppNorm(train_cam_infos)
252
+
253
+ ply_path = os.path.join(path, "sparse/0/points3D.ply")
254
+ bin_path = os.path.join(path, "sparse/0/points3D.bin")
255
+ txt_path = os.path.join(path, "sparse/0/points3D.txt")
256
+ if not os.path.exists(ply_path):
257
+         print("Converting points3D.bin to .ply; this happens only the first time the scene is opened.")
258
+ try:
259
+ xyz, rgb, _ = read_points3D_binary(bin_path)
260
+ except:
261
+ xyz, rgb, _ = read_points3D_text(txt_path)
262
+ storePly(ply_path, xyz, rgb)
263
+ try:
264
+ pcd = fetchPly(ply_path)
265
+ except:
266
+ pcd = None
267
+
+
273
+ scene_info = SceneInfo(point_cloud=pcd,
274
+ train_cameras=train_cam_infos,
275
+ test_cameras=test_cam_infos,
276
+ nerf_normalization=nerf_normalization,
277
+ ply_path=ply_path,
278
+ train_poses=train_poses,
279
+ test_poses=test_poses)
280
+ return scene_info
281
+
282
+ def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
283
+ cam_infos = []
284
+
285
+ with open(os.path.join(path, transformsfile)) as json_file:
286
+ contents = json.load(json_file)
287
+ fovx = contents["camera_angle_x"]
288
+
289
+ frames = contents["frames"]
290
+ for idx, frame in enumerate(frames):
291
+ cam_name = os.path.join(path, frame["file_path"] + extension)
292
+
293
+ # NeRF 'transform_matrix' is a camera-to-world transform
294
+ c2w = np.array(frame["transform_matrix"])
295
+ # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
296
+ c2w[:3, 1:3] *= -1
297
+
298
+ # get the world-to-camera transform and set R, T
299
+ w2c = np.linalg.inv(c2w)
300
+ R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code
301
+ T = w2c[:3, 3]
302
+
303
+ image_path = os.path.join(path, cam_name)
304
+ image_name = Path(cam_name).stem
305
+ image = Image.open(image_path)
306
+
307
+ im_data = np.array(image.convert("RGBA"))
308
+
309
+ bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
310
+
311
+ norm_data = im_data / 255.0
312
+ arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
313
+ image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
314
+
315
+ fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
316
+ FovY = fovy
317
+ FovX = fovx
318
+
319
+ cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
320
+ image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
321
+
322
+ return cam_infos
323
+
324
+ def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
325
+ print("Reading Training Transforms")
326
+ train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
327
+ print("Reading Test Transforms")
328
+ test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
329
+
330
+ if not eval:
331
+ train_cam_infos.extend(test_cam_infos)
332
+ test_cam_infos = []
333
+
334
+ nerf_normalization = getNerfppNorm(train_cam_infos)
335
+
336
+ ply_path = os.path.join(path, "points3d.ply")
337
+ if not os.path.exists(ply_path):
338
+ # Since this data set has no colmap data, we start with random points
339
+ num_pts = 100_000
340
+ print(f"Generating random point cloud ({num_pts})...")
341
+
342
+ # We create random points inside the bounds of the synthetic Blender scenes
343
+ xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
344
+ shs = np.random.random((num_pts, 3)) / 255.0
345
+ pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
346
+
347
+ storePly(ply_path, xyz, SH2RGB(shs) * 255)
348
+ try:
349
+ pcd = fetchPly(ply_path)
350
+ except:
351
+ pcd = None
352
+
353
+ scene_info = SceneInfo(point_cloud=pcd,
354
+ train_cameras=train_cam_infos,
355
+ test_cameras=test_cam_infos,
356
+ nerf_normalization=nerf_normalization,
357
+                            ply_path=ply_path,
+                            train_poses=[],   # SceneInfo fields have no defaults; Blender scenes carry no COLMAP poses
+                            test_poses=[])
358
+ return scene_info
359
+
360
+ sceneLoadTypeCallbacks = {
361
+ "Colmap": readColmapSceneInfo,
362
+ "Blender" : readNerfSyntheticInfo
363
+ }
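sceneLoadTypeCallbacks is the dispatch table consumed by Scene.__init__ earlier in this commit: a sparse/ directory selects the Colmap reader, a transforms_train.json selects the Blender reader. The same branching in isolation (a sketch, with hypothetical args/opt objects):

    import os

    def load_scene_info(source_path, args, opt):
        if os.path.exists(os.path.join(source_path, "sparse")):
            return sceneLoadTypeCallbacks["Colmap"](source_path, args.images, args.eval, args, opt)
        if os.path.exists(os.path.join(source_path, "transforms_train.json")):
            return sceneLoadTypeCallbacks["Blender"](source_path, args.white_background, args.eval)
        raise ValueError(f"Could not recognize scene type at {source_path}")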
scene/gaussian_model.py ADDED
@@ -0,0 +1,502 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ import torch
14
+ import numpy as np
15
+ from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
16
+ from torch import nn
17
+ import os
18
+ from utils.system_utils import mkdir_p
19
+ from plyfile import PlyData, PlyElement
20
+ from utils.sh_utils import RGB2SH
21
+ from simple_knn._C import distCUDA2
22
+ from utils.graphics_utils import BasicPointCloud
23
+ from utils.general_utils import strip_symmetric, build_scaling_rotation
24
+ from scipy.spatial.transform import Rotation as R
25
+ from utils.pose_utils import rotation2quad, get_tensor_from_camera
26
+ from utils.graphics_utils import getWorld2View2
27
+
28
+ def quaternion_to_rotation_matrix(quaternion):
29
+ """
30
+ Convert a quaternion to a rotation matrix.
31
+
32
+ Parameters:
33
+ - quaternion: A tensor of shape (..., 4) representing quaternions.
34
+
35
+ Returns:
36
+ - A tensor of shape (..., 3, 3) representing rotation matrices.
37
+ """
38
+ # Ensure quaternion is of float type for computation
39
+ quaternion = quaternion.float()
40
+
41
+ # Normalize the quaternion to unit length
42
+ quaternion = quaternion / quaternion.norm(p=2, dim=-1, keepdim=True)
43
+
44
+ # Extract components
45
+ w, x, y, z = quaternion[..., 0], quaternion[..., 1], quaternion[..., 2], quaternion[..., 3]
46
+
47
+ # Compute rotation matrix components
48
+ xx, yy, zz = x * x, y * y, z * z
49
+ xy, xz, yz = x * y, x * z, y * z
50
+ xw, yw, zw = x * w, y * w, z * w
51
+
52
+ # Assemble the rotation matrix
53
+ R = torch.stack([
54
+ torch.stack([1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw)], dim=-1),
55
+ torch.stack([ 2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw)], dim=-1),
56
+ torch.stack([ 2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy)], dim=-1)
57
+ ], dim=-2)
58
+
59
+ return R
60
+
61
+
62
+ class GaussianModel:
63
+
64
+ def setup_functions(self):
65
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
66
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
67
+ actual_covariance = L @ L.transpose(1, 2)
68
+ symm = strip_symmetric(actual_covariance)
69
+ return symm
70
+
71
+ self.scaling_activation = torch.exp
72
+ self.scaling_inverse_activation = torch.log
73
+
74
+ self.covariance_activation = build_covariance_from_scaling_rotation
75
+
76
+ self.opacity_activation = torch.sigmoid
77
+ self.inverse_opacity_activation = inverse_sigmoid
78
+
79
+ self.rotation_activation = torch.nn.functional.normalize
80
+
81
+
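setup_functions fixes the covariance parameterization: each Gaussian stores log-scales and an unnormalized quaternion, and the covariance is assembled as Sigma = (R S)(R S)^T = R S S^T R^T, symmetric positive definite by construction. A torch sketch of that property, reusing quaternion_to_rotation_matrix from the top of this file:

    import torch

    s = torch.tensor([[0.1, -0.3, 0.2]])          # log-scales for one Gaussian
    q = torch.tensor([[0.9, 0.1, 0.2, 0.1]])      # unnormalized quaternion

    S = torch.diag_embed(torch.exp(s))            # scaling_activation = exp keeps scales positive
    R = quaternion_to_rotation_matrix(q)          # normalizes q internally
    L = R @ S
    sigma = L @ L.transpose(1, 2)

    assert torch.allclose(sigma, sigma.transpose(1, 2))   # symmetric
    assert (torch.linalg.eigvalsh(sigma) > 0).all()       # positive definite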
82
+ def __init__(self, sh_degree : int):
83
+ self.active_sh_degree = 0
84
+ self.max_sh_degree = sh_degree
85
+ self._xyz = torch.empty(0)
86
+ self._features_dc = torch.empty(0)
87
+ self._features_rest = torch.empty(0)
88
+ self._scaling = torch.empty(0)
89
+ self._rotation = torch.empty(0)
90
+ self._opacity = torch.empty(0)
91
+ self.max_radii2D = torch.empty(0)
92
+ self.xyz_gradient_accum = torch.empty(0)
93
+ self.denom = torch.empty(0)
94
+ self.optimizer = None
95
+ self.percent_dense = 0
96
+ self.spatial_lr_scale = 0
97
+ self.setup_functions()
98
+
99
+ def capture(self):
100
+ return (
101
+ self.active_sh_degree,
102
+ self._xyz,
103
+ self._features_dc,
104
+ self._features_rest,
105
+ self._scaling,
106
+ self._rotation,
107
+ self._opacity,
108
+ self.max_radii2D,
109
+ self.xyz_gradient_accum,
110
+ self.denom,
111
+ self.optimizer.state_dict(),
112
+ self.spatial_lr_scale,
113
+ self.P,
114
+ )
115
+
116
+ def restore(self, model_args, training_args):
117
+ (self.active_sh_degree,
118
+ self._xyz,
119
+ self._features_dc,
120
+ self._features_rest,
121
+ self._scaling,
122
+ self._rotation,
123
+ self._opacity,
124
+ self.max_radii2D,
125
+ xyz_gradient_accum,
126
+ denom,
127
+ opt_dict,
128
+ self.spatial_lr_scale,
129
+ self.P) = model_args
130
+ self.training_setup(training_args)
131
+ self.xyz_gradient_accum = xyz_gradient_accum
132
+ self.denom = denom
133
+ self.optimizer.load_state_dict(opt_dict)
134
+
135
+ @property
136
+ def get_scaling(self):
137
+ return self.scaling_activation(self._scaling)
138
+
139
+ @property
140
+ def get_rotation(self):
141
+ return self.rotation_activation(self._rotation)
142
+
143
+ @property
144
+ def get_xyz(self):
145
+ return self._xyz
146
+
147
+ def compute_relative_world_to_camera(self, R1, t1, R2, t2):
148
+ # Create a row of zeros with a one at the end, for homogeneous coordinates
149
+ zero_row = np.array([[0, 0, 0, 1]], dtype=np.float32)
150
+
151
+ # Compute the inverse of the first extrinsic matrix
152
+ E1_inv = np.hstack([R1.T, -R1.T @ t1.reshape(-1, 1)]) # Transpose and reshape for correct dimensions
153
+ E1_inv = np.vstack([E1_inv, zero_row]) # Append the zero_row to make it a 4x4 matrix
154
+
155
+ # Compute the second extrinsic matrix
156
+ E2 = np.hstack([R2, -R2 @ t2.reshape(-1, 1)]) # No need to transpose R2
157
+ E2 = np.vstack([E2, zero_row]) # Append the zero_row to make it a 4x4 matrix
158
+
159
+ # Compute the relative transformation
160
+ E_rel = E2 @ E1_inv
161
+
162
+ return E_rel
163
+
164
+ def init_RT_seq(self, cam_list):
165
+ poses =[]
166
+ for cam in cam_list[1.0]:
167
+ p = get_tensor_from_camera(cam.world_view_transform.transpose(0, 1)) # R T -> quat t
168
+ poses.append(p)
169
+ poses = torch.stack(poses)
170
+ self.P = poses.cuda().requires_grad_(True)
171
+
172
+
173
+ def get_RT(self, idx):
174
+ pose = self.P[idx]
175
+ return pose
176
+
177
+ def get_RT_test(self, idx):
178
+ pose = self.test_P[idx]
179
+ return pose
180
+
181
+ @property
182
+ def get_features(self):
183
+ features_dc = self._features_dc
184
+ features_rest = self._features_rest
185
+ return torch.cat((features_dc, features_rest), dim=1)
186
+
187
+ @property
188
+ def get_opacity(self):
189
+ return self.opacity_activation(self._opacity)
190
+
191
+ def get_covariance(self, scaling_modifier = 1):
192
+ return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
193
+
194
+ def oneupSHdegree(self):
195
+ if self.active_sh_degree < self.max_sh_degree:
196
+ self.active_sh_degree += 1
197
+
198
+ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
199
+
201
+
202
+ self.spatial_lr_scale = spatial_lr_scale
203
+ fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
204
+ fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
205
+ features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
206
+ features[:, :3, 0 ] = fused_color
207
+ features[:, 3:, 1:] = 0.0
208
+
209
+ print("Number of points at initialisation : ", fused_point_cloud.shape[0])
210
+
211
+ dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
212
+ scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
213
+ rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
214
+ rots[:, 0] = 1
215
+
216
+ opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
217
+
218
+ self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
219
+ self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
220
+ self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
221
+ self._scaling = nn.Parameter(scales.requires_grad_(True))
222
+ self._rotation = nn.Parameter(rots.requires_grad_(True))
223
+ self._opacity = nn.Parameter(opacities.requires_grad_(True))
224
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
225
+
226
+ def training_setup(self, training_args):
227
+ self.percent_dense = training_args.percent_dense
228
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
229
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
230
+
231
+ l = [
232
+ {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
233
+ {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
234
+ {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
235
+ {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
236
+ {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
237
+ {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
238
+ ]
239
+
240
+ l_cam = [{'params': [self.P],'lr': training_args.rotation_lr*0.1, "name": "pose"},]
241
+
242
+ l += l_cam
243
+
244
+ self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
245
+ self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
246
+ lr_final=training_args.position_lr_final*self.spatial_lr_scale,
247
+ lr_delay_mult=training_args.position_lr_delay_mult,
248
+ max_steps=training_args.position_lr_max_steps)
249
+ self.cam_scheduler_args = get_expon_lr_func(
252
+ lr_init=training_args.rotation_lr*0.1,
253
+ lr_final=training_args.rotation_lr*0.001,
256
+ lr_delay_mult=training_args.position_lr_delay_mult,
257
+ max_steps=1000)
258
+
259
+ def update_learning_rate(self, iteration):
260
+ ''' Learning rate scheduling per step '''
261
+ for param_group in self.optimizer.param_groups:
262
+ if param_group["name"] == "pose":
263
+ lr = self.cam_scheduler_args(iteration)
265
+ param_group['lr'] = lr
266
+ if param_group["name"] == "xyz":
267
+ lr = self.xyz_scheduler_args(iteration)
268
+ param_group['lr'] = lr
270
+
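get_expon_lr_func (the standard 3DGS helper from utils.general_utils) interpolates log-linearly between lr_init and lr_final over max_steps, so the pose learning rate above decays geometrically from rotation_lr*0.1 to rotation_lr*0.001 within the first 1000 steps. The core schedule as a standalone sketch, ignoring the optional delay term:

    import numpy as np

    def expon_lr(step, lr_init, lr_final, max_steps):
        t = np.clip(step / max_steps, 0.0, 1.0)
        return np.exp((1 - t) * np.log(lr_init) + t * np.log(lr_final))

    assert np.isclose(expon_lr(0, 1e-3, 1e-5, 1000), 1e-3)
    assert np.isclose(expon_lr(500, 1e-3, 1e-5, 1000), 1e-4)   # geometric midpoint
    assert np.isclose(expon_lr(1000, 1e-3, 1e-5, 1000), 1e-5)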
271
+ def construct_list_of_attributes(self):
272
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
273
+ # All channels except the 3 DC
274
+ for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
275
+ l.append('f_dc_{}'.format(i))
276
+ for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
277
+ l.append('f_rest_{}'.format(i))
278
+ l.append('opacity')
279
+ for i in range(self._scaling.shape[1]):
280
+ l.append('scale_{}'.format(i))
281
+ for i in range(self._rotation.shape[1]):
282
+ l.append('rot_{}'.format(i))
283
+ return l
284
+
285
+ def save_ply(self, path):
286
+ mkdir_p(os.path.dirname(path))
287
+
288
+ xyz = self._xyz.detach().cpu().numpy()
289
+ normals = np.zeros_like(xyz)
290
+ f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
291
+ f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
292
+ opacities = self._opacity.detach().cpu().numpy()
293
+ scale = self._scaling.detach().cpu().numpy()
294
+ rotation = self._rotation.detach().cpu().numpy()
295
+
296
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
297
+
298
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
299
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
300
+ elements[:] = list(map(tuple, attributes))
301
+ el = PlyElement.describe(elements, 'vertex')
302
+ PlyData([el]).write(path)
303
+
304
+ def reset_opacity(self):
305
+ opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
306
+ optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
307
+ self._opacity = optimizable_tensors["opacity"]
308
+
309
+ def load_ply(self, path):
310
+ plydata = PlyData.read(path)
311
+
312
+ xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
313
+ np.asarray(plydata.elements[0]["y"]),
314
+ np.asarray(plydata.elements[0]["z"])), axis=1)
315
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
316
+
317
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
318
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
319
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
320
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
321
+
322
+ extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
323
+ extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
324
+ assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
325
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
326
+ for idx, attr_name in enumerate(extra_f_names):
327
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
328
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
329
+ features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
330
+
331
+ scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
332
+ scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
333
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
334
+ for idx, attr_name in enumerate(scale_names):
335
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
336
+
337
+ rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
338
+ rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
339
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
340
+ for idx, attr_name in enumerate(rot_names):
341
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
342
+
343
+ self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
344
+ self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
345
+ self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
346
+ self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
347
+ self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
348
+ self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
349
+
350
+ self.active_sh_degree = self.max_sh_degree
351
+
352
+ def replace_tensor_to_optimizer(self, tensor, name):
353
+ optimizable_tensors = {}
354
+ for group in self.optimizer.param_groups:
355
+ if group["name"] == name:
357
+ stored_state = self.optimizer.state.get(group['params'][0], None)
358
+ stored_state["exp_avg"] = torch.zeros_like(tensor)
359
+ stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
360
+
361
+ del self.optimizer.state[group['params'][0]]
362
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
363
+ self.optimizer.state[group['params'][0]] = stored_state
364
+
365
+ optimizable_tensors[group["name"]] = group["params"][0]
366
+ return optimizable_tensors
367
+
368
+ def _prune_optimizer(self, mask):
369
+ optimizable_tensors = {}
370
+ for group in self.optimizer.param_groups:
371
+ stored_state = self.optimizer.state.get(group['params'][0], None)
372
+ if stored_state is not None:
373
+ stored_state["exp_avg"] = stored_state["exp_avg"][mask]
374
+ stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
375
+
376
+ del self.optimizer.state[group['params'][0]]
377
+ group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
378
+ self.optimizer.state[group['params'][0]] = stored_state
379
+
380
+ optimizable_tensors[group["name"]] = group["params"][0]
381
+ else:
382
+ group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
383
+ optimizable_tensors[group["name"]] = group["params"][0]
384
+ return optimizable_tensors
385
+
386
+ def prune_points(self, mask):
387
+ valid_points_mask = ~mask
388
+ optimizable_tensors = self._prune_optimizer(valid_points_mask)
389
+
390
+ self._xyz = optimizable_tensors["xyz"]
391
+ self._features_dc = optimizable_tensors["f_dc"]
392
+ self._features_rest = optimizable_tensors["f_rest"]
393
+ self._opacity = optimizable_tensors["opacity"]
394
+ self._scaling = optimizable_tensors["scaling"]
395
+ self._rotation = optimizable_tensors["rotation"]
396
+
397
+ self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
398
+
399
+ self.denom = self.denom[valid_points_mask]
400
+ self.max_radii2D = self.max_radii2D[valid_points_mask]
401
+
402
+ def cat_tensors_to_optimizer(self, tensors_dict):
403
+ optimizable_tensors = {}
404
+ for group in self.optimizer.param_groups:
405
+ assert len(group["params"]) == 1
406
+ extension_tensor = tensors_dict[group["name"]]
407
+ stored_state = self.optimizer.state.get(group['params'][0], None)
408
+ if stored_state is not None:
409
+
410
+ stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
411
+ stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
412
+
413
+ del self.optimizer.state[group['params'][0]]
414
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
415
+ self.optimizer.state[group['params'][0]] = stored_state
416
+
417
+ optimizable_tensors[group["name"]] = group["params"][0]
418
+ else:
419
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
420
+ optimizable_tensors[group["name"]] = group["params"][0]
421
+
422
+ return optimizable_tensors
423
+
424
+ def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
425
+ d = {"xyz": new_xyz,
426
+ "f_dc": new_features_dc,
427
+ "f_rest": new_features_rest,
428
+ "opacity": new_opacities,
429
+ "scaling" : new_scaling,
430
+ "rotation" : new_rotation}
431
+
432
+ optimizable_tensors = self.cat_tensors_to_optimizer(d)
433
+ self._xyz = optimizable_tensors["xyz"]
434
+ self._features_dc = optimizable_tensors["f_dc"]
435
+ self._features_rest = optimizable_tensors["f_rest"]
436
+ self._opacity = optimizable_tensors["opacity"]
437
+ self._scaling = optimizable_tensors["scaling"]
438
+ self._rotation = optimizable_tensors["rotation"]
439
+
440
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
441
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
442
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
443
+
444
+ def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
445
+ n_init_points = self.get_xyz.shape[0]
446
+ # Extract points that satisfy the gradient condition
447
+ padded_grad = torch.zeros((n_init_points), device="cuda")
448
+ padded_grad[:grads.shape[0]] = grads.squeeze()
449
+ selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
450
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
451
+ torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
452
+
453
+ stds = self.get_scaling[selected_pts_mask].repeat(N,1)
454
+ means =torch.zeros((stds.size(0), 3),device="cuda")
455
+ samples = torch.normal(mean=means, std=stds)
456
+ rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
457
+ new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
458
+ new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
459
+ new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
460
+ new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
461
+ new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
462
+ new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
463
+
464
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
465
+
466
+ prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
467
+ self.prune_points(prune_filter)
468
+
469
+ def densify_and_clone(self, grads, grad_threshold, scene_extent):
470
+ # Extract points that satisfy the gradient condition
471
+ selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
472
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
473
+ torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
474
+
475
+ new_xyz = self._xyz[selected_pts_mask]
476
+ new_features_dc = self._features_dc[selected_pts_mask]
477
+ new_features_rest = self._features_rest[selected_pts_mask]
478
+ new_opacities = self._opacity[selected_pts_mask]
479
+ new_scaling = self._scaling[selected_pts_mask]
480
+ new_rotation = self._rotation[selected_pts_mask]
481
+
482
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
483
+
484
+ def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
485
+ grads = self.xyz_gradient_accum / self.denom
486
+ grads[grads.isnan()] = 0.0
487
+
488
+ # self.densify_and_clone(grads, max_grad, extent)
489
+ # self.densify_and_split(grads, max_grad, extent)
490
+
491
+ prune_mask = (self.get_opacity < min_opacity).squeeze()
492
+ if max_screen_size:
493
+ big_points_vs = self.max_radii2D > max_screen_size
494
+ big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
495
+ prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
496
+ self.prune_points(prune_mask)
497
+
498
+ torch.cuda.empty_cache()
499
+
500
+ def add_densification_stats(self, viewspace_point_tensor, update_filter):
501
+ self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
502
+ self.denom[update_filter] += 1
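add_densification_stats keeps, per Gaussian, a running sum of screen-space positional gradient norms and a visit count; densify_and_prune divides the two to get an average gradient (note the clone/split calls are commented out in this file, so only opacity- and size-based pruning is active here). The averaging logic in isolation:

    import torch

    grad_accum = torch.zeros(4, 1)
    denom = torch.zeros(4, 1)

    for _ in range(3):                                     # three optimization steps
        vs_grad = torch.rand(4, 2)                         # stand-in viewspace xy gradients
        visible = torch.tensor([True, True, False, True])  # update_filter
        grad_accum[visible] += vs_grad[visible].norm(dim=-1, keepdim=True)
        denom[visible] += 1

    avg_grad = grad_accum / denom                          # NaN where never visible
    avg_grad[avg_grad.isnan()] = 0.0                       # mirrors densify_and_prune's guard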
submodules/diff-gaussian-rasterization/.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ build/
2
+ diff_gaussian_rasterization.egg-info/
3
+ dist/
submodules/diff-gaussian-rasterization/.gitmodules ADDED
@@ -0,0 +1,3 @@
1
+ [submodule "third_party/glm"]
2
+ path = third_party/glm
3
+ url = https://github.com/g-truc/glm.git
submodules/diff-gaussian-rasterization/CMakeLists.txt ADDED
@@ -0,0 +1,36 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ cmake_minimum_required(VERSION 3.20)
13
+
14
+ project(DiffRast LANGUAGES CUDA CXX)
15
+
16
+ set(CMAKE_CXX_STANDARD 17)
17
+ set(CMAKE_CXX_EXTENSIONS OFF)
18
+ set(CMAKE_CUDA_STANDARD 17)
19
+
20
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
21
+
22
+ add_library(CudaRasterizer
23
+ cuda_rasterizer/backward.h
24
+ cuda_rasterizer/backward.cu
25
+ cuda_rasterizer/forward.h
26
+ cuda_rasterizer/forward.cu
27
+ cuda_rasterizer/auxiliary.h
28
+ cuda_rasterizer/rasterizer_impl.cu
29
+ cuda_rasterizer/rasterizer_impl.h
30
+ cuda_rasterizer/rasterizer.h
31
+ )
32
+
33
+ set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86")
34
+
35
+ target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer)
36
+ target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
submodules/diff-gaussian-rasterization/LICENSE.md ADDED
@@ -0,0 +1,83 @@
1
+ Gaussian-Splatting License
2
+ ===========================
3
+
4
+ **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
5
+ The *Software* is in the process of being registered with the Agence pour la Protection des
6
+ Programmes (APP).
7
+
8
+ The *Software* is still being developed by the *Licensor*.
9
+
10
+ *Licensor*'s goal is to allow the research community to use, test and evaluate
11
+ the *Software*.
12
+
13
+ ## 1. Definitions
14
+
15
+ *Licensee* means any person or entity that uses the *Software* and distributes
16
+ its *Work*.
17
+
18
+ *Licensor* means the owners of the *Software*, i.e Inria and MPII
19
+
20
+ *Software* means the original work of authorship made available under this
21
+ License ie gaussian-splatting.
22
+
23
+ *Work* means the *Software* and any additions to or derivative works of the
24
+ *Software* that are made available under this License.
25
+
26
+
27
+ ## 2. Purpose
28
+ This license is intended to define the rights granted to the *Licensee* by
29
+ Licensors under the *Software*.
30
+
31
+ ## 3. Rights granted
32
+
33
+ For the above reasons Licensors have decided to distribute the *Software*.
34
+ Licensors grant non-exclusive rights to use the *Software* for research purposes
35
+ to research users (both academic and industrial), free of charge, without right
36
+ to sublicense.. The *Software* may be used "non-commercially", i.e., for research
37
+ and/or evaluation purposes only.
38
+
39
+ Subject to the terms and conditions of this License, you are granted a
40
+ non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
41
+ publicly display, publicly perform and distribute its *Work* and any resulting
42
+ derivative works in any form.
43
+
44
+ ## 4. Limitations
45
+
46
+ **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
47
+ so under this License, (b) you include a complete copy of this License with
48
+ your distribution, and (c) you retain without modification any copyright,
49
+ patent, trademark, or attribution notices that are present in the *Work*.
50
+
51
+ **4.2 Derivative Works.** You may specify that additional or different terms apply
52
+ to the use, reproduction, and distribution of your derivative works of the *Work*
53
+ ("Your Terms") only if (a) Your Terms provide that the use limitation in
54
+ Section 2 applies to your derivative works, and (b) you identify the specific
55
+ derivative works that are subject to Your Terms. Notwithstanding Your Terms,
56
+ this License (including the redistribution requirements in Section 3.1) will
57
+ continue to apply to the *Work* itself.
58
+
59
+ **4.3** Any other use without of prior consent of Licensors is prohibited. Research
60
+ users explicitly acknowledge having received from Licensors all information
61
+ allowing to appreciate the adequacy between of the *Software* and their needs and
62
+ to undertake all necessary precautions for its execution and use.
63
+
64
+ **4.4** The *Software* is provided both as a compiled library file and as source
65
+ code. In case of using the *Software* for a publication or other results obtained
66
+ through the use of the *Software*, users are strongly encouraged to cite the
67
+ corresponding publications as explained in the documentation of the *Software*.
68
+
69
+ ## 5. Disclaimer
70
+
71
+ THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
72
+ WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
73
+ UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
74
+ CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
75
+ OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
76
+ USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
77
+ ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
78
+ AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
79
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
80
+ GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
81
+ HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
82
+ LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
83
+ IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
submodules/diff-gaussian-rasterization/README.md ADDED
@@ -0,0 +1,19 @@
1
+ # Differential Gaussian Rasterization
2
+
3
+ Used as the rasterization engine for the paper "3D Gaussian Splatting for Real-Time Radiance Field Rendering". If you use it in your own research, please cite us.
4
+
5
+ <section class="section" id="BibTeX">
6
+ <div class="container is-max-desktop content">
7
+ <h2 class="title">BibTeX</h2>
8
+ <pre><code>@Article{kerbl3Dgaussians,
9
+ author = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
10
+ title = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
11
+ journal = {ACM Transactions on Graphics},
12
+ number = {4},
13
+ volume = {42},
14
+ month = {July},
15
+ year = {2023},
16
+ url = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
17
+ }</code></pre>
18
+ </div>
19
+ </section>
submodules/diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h ADDED
@@ -0,0 +1,175 @@
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
13
+ #define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
14
+
15
+ #include "config.h"
16
+ #include "stdio.h"
17
+
18
+ #define BLOCK_SIZE (BLOCK_X * BLOCK_Y)
19
+ #define NUM_WARPS (BLOCK_SIZE/32)
20
+
21
+ // Spherical harmonics coefficients
22
+ __device__ const float SH_C0 = 0.28209479177387814f;
23
+ __device__ const float SH_C1 = 0.4886025119029199f;
24
+ __device__ const float SH_C2[] = {
25
+ 1.0925484305920792f,
26
+ -1.0925484305920792f,
27
+ 0.31539156525252005f,
28
+ -1.0925484305920792f,
29
+ 0.5462742152960396f
30
+ };
31
+ __device__ const float SH_C3[] = {
32
+ -0.5900435899266435f,
33
+ 2.890611442640554f,
34
+ -0.4570457994644658f,
35
+ 0.3731763325901154f,
36
+ -0.4570457994644658f,
37
+ 1.445305721320277f,
38
+ -0.5900435899266435f
39
+ };
40
+
41
+ __forceinline__ __device__ float ndc2Pix(float v, int S)
42
+ {
43
+ return ((v + 1.0) * S - 1.0) * 0.5;
44
+ }
45
+
46
+ __forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid)
47
+ {
48
+ rect_min = {
49
+ min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))),
50
+ min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y)))
51
+ };
52
+ rect_max = {
53
+ min(grid.x, max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))),
54
+ min(grid.y, max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y)))
55
+ };
56
+ }
57
+
58
+ __forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix)
59
+ {
60
+ float3 transformed = {
61
+ matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12],
62
+ matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13],
63
+ matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14],
64
+ };
65
+ return transformed;
66
+ }
67
+
68
+ __forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix)
69
+ {
70
+ float4 transformed = {
71
+ matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12],
72
+ matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13],
73
+ matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14],
74
+ matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15]
75
+ };
76
+ return transformed;
77
+ }
78
+
79
+ __forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix)
80
+ {
81
+ float3 transformed = {
82
+ matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z,
83
+ matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z,
84
+ matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z,
85
+ };
86
+ return transformed;
87
+ }
88
+
89
+ __forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix)
90
+ {
91
+ float3 transformed = {
92
+ matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z,
93
+ matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z,
94
+ matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z,
95
+ };
96
+ return transformed;
97
+ }
98
+
99
+ __forceinline__ __device__ float dnormvdz(float3 v, float3 dv)
100
+ {
101
+ float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
102
+ float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
103
+ float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32;
104
+ return dnormvdz;
105
+ }
106
+
107
+ __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv)
108
+ {
109
+ float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
110
+ float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
111
+
112
+ float3 dnormvdv;
113
+ dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32;
114
+ dnormvdv.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) * invsum32;
115
+ dnormvdv.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32;
116
+ return dnormvdv;
117
+ }
118
+
119
+ __forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv)
120
+ {
121
+ float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
122
+ float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
123
+
124
+ float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w };
125
+ float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w;
126
+ float4 dnormvdv;
127
+ dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32;
128
+ dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32;
129
+ dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32;
130
+ dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32;
131
+ return dnormvdv;
132
+ }
133
+
134
+ __forceinline__ __device__ float sigmoid(float x)
135
+ {
136
+ return 1.0f / (1.0f + expf(-x));
137
+ }
138
+
139
+ __forceinline__ __device__ bool in_frustum(int idx,
140
+ const float* orig_points,
141
+ const float* viewmatrix,
142
+ const float* projmatrix,
143
+ bool prefiltered,
144
+ float3& p_view)
145
+ {
146
+ float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };
147
+
148
+ // Bring points to screen space
149
+ float4 p_hom = transformPoint4x4(p_orig, projmatrix);
150
+ float p_w = 1.0f / (p_hom.w + 0.0000001f);
151
+ float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
152
+ p_view = transformPoint4x3(p_orig, viewmatrix);
153
+
154
+ if (p_view.z <= 0.01f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3)))
155
+ {
156
+ if (prefiltered)
157
+ {
158
+ printf("Point is filtered although prefiltered is set. This shouldn't happen!");
159
+ __trap();
160
+ }
161
+ return false;
162
+ }
163
+ return true;
164
+ }
165
+
166
+ #define CHECK_CUDA(A, debug) \
167
+ A; if(debug) { \
168
+ auto ret = cudaDeviceSynchronize(); \
169
+ if (ret != cudaSuccess) { \
170
+ std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \
171
+ throw std::runtime_error(cudaGetErrorString(ret)); \
172
+ } \
173
+ }
174
+
175
+ #endif
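
The tile-mapping helpers above are easiest to sanity-check on the host. The following is a minimal host-side sketch (not part of the submodule; the image size, NDC position, and radius are illustrative values) that mirrors `ndc2Pix` and the clamped tile-rectangle logic of `getRect`, showing how a projected Gaussian is mapped onto the 16x16 tile grid:

```cpp
// Host-side sketch: mirrors ndc2Pix/getRect from auxiliary.h.
#include <algorithm>
#include <cstdio>

constexpr int BLOCK_X = 16, BLOCK_Y = 16;

float ndc2Pix(float v, int S) { return ((v + 1.0f) * S - 1.0f) * 0.5f; }

int main() {
    int W = 1920, H = 1080;                    // assumed image size
    int grid_x = (W + BLOCK_X - 1) / BLOCK_X;  // tiles per row
    int grid_y = (H + BLOCK_Y - 1) / BLOCK_Y;  // tiles per column

    float px = ndc2Pix(0.25f, W);              // NDC x = 0.25 -> pixel x
    float py = ndc2Pix(-0.5f, H);              // NDC y = -0.5 -> pixel y
    int radius = 24;                           // 3-sigma extent in pixels

    // Same clamped min/max tile computation as getRect.
    int rx0 = std::min(grid_x, std::max(0, (int)((px - radius) / BLOCK_X)));
    int ry0 = std::min(grid_y, std::max(0, (int)((py - radius) / BLOCK_Y)));
    int rx1 = std::min(grid_x, std::max(0, (int)((px + radius + BLOCK_X - 1) / BLOCK_X)));
    int ry1 = std::min(grid_y, std::max(0, (int)((py + radius + BLOCK_Y - 1) / BLOCK_Y)));

    printf("pixel (%.1f, %.1f), tiles x:[%d,%d) y:[%d,%d), touched %d\n",
           px, py, rx0, rx1, ry0, ry1, (rx1 - rx0) * (ry1 - ry0));
}
```

The tile count computed this way is exactly what `preprocessCUDA` later stores in `tiles_touched`.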
submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.cu ADDED
@@ -0,0 +1,657 @@
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #include "backward.h"
13
+ #include "auxiliary.h"
14
+ #include <cooperative_groups.h>
15
+ #include <cooperative_groups/reduce.h>
16
+ namespace cg = cooperative_groups;
17
+
18
+ // Backward pass for conversion of spherical harmonics to RGB for
19
+ // each Gaussian.
20
+ __device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs)
21
+ {
22
+ // Compute intermediate values, as it is done during forward
23
+ glm::vec3 pos = means[idx];
24
+ glm::vec3 dir_orig = pos - campos;
25
+ glm::vec3 dir = dir_orig / glm::length(dir_orig);
26
+
27
+ glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
28
+
29
+ // Use PyTorch rule for clamping: if clamping was applied,
30
+ // gradient becomes 0.
31
+ glm::vec3 dL_dRGB = dL_dcolor[idx];
32
+ dL_dRGB.x *= clamped[3 * idx + 0] ? 0 : 1;
33
+ dL_dRGB.y *= clamped[3 * idx + 1] ? 0 : 1;
34
+ dL_dRGB.z *= clamped[3 * idx + 2] ? 0 : 1;
35
+
36
+ glm::vec3 dRGBdx(0, 0, 0);
37
+ glm::vec3 dRGBdy(0, 0, 0);
38
+ glm::vec3 dRGBdz(0, 0, 0);
39
+ float x = dir.x;
40
+ float y = dir.y;
41
+ float z = dir.z;
42
+
43
+ // Target location for this Gaussian to write SH gradients to
44
+ glm::vec3* dL_dsh = dL_dshs + idx * max_coeffs;
45
+
46
+ // No tricks here, just high school-level calculus.
47
+ float dRGBdsh0 = SH_C0;
48
+ dL_dsh[0] = dRGBdsh0 * dL_dRGB;
49
+ if (deg > 0)
50
+ {
51
+ float dRGBdsh1 = -SH_C1 * y;
52
+ float dRGBdsh2 = SH_C1 * z;
53
+ float dRGBdsh3 = -SH_C1 * x;
54
+ dL_dsh[1] = dRGBdsh1 * dL_dRGB;
55
+ dL_dsh[2] = dRGBdsh2 * dL_dRGB;
56
+ dL_dsh[3] = dRGBdsh3 * dL_dRGB;
57
+
58
+ dRGBdx = -SH_C1 * sh[3];
59
+ dRGBdy = -SH_C1 * sh[1];
60
+ dRGBdz = SH_C1 * sh[2];
61
+
62
+ if (deg > 1)
63
+ {
64
+ float xx = x * x, yy = y * y, zz = z * z;
65
+ float xy = x * y, yz = y * z, xz = x * z;
66
+
67
+ float dRGBdsh4 = SH_C2[0] * xy;
68
+ float dRGBdsh5 = SH_C2[1] * yz;
69
+ float dRGBdsh6 = SH_C2[2] * (2.f * zz - xx - yy);
70
+ float dRGBdsh7 = SH_C2[3] * xz;
71
+ float dRGBdsh8 = SH_C2[4] * (xx - yy);
72
+ dL_dsh[4] = dRGBdsh4 * dL_dRGB;
73
+ dL_dsh[5] = dRGBdsh5 * dL_dRGB;
74
+ dL_dsh[6] = dRGBdsh6 * dL_dRGB;
75
+ dL_dsh[7] = dRGBdsh7 * dL_dRGB;
76
+ dL_dsh[8] = dRGBdsh8 * dL_dRGB;
77
+
78
+ dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] + SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8];
79
+ dRGBdy += SH_C2[0] * x * sh[4] + SH_C2[1] * z * sh[5] + SH_C2[2] * 2.f * -y * sh[6] + SH_C2[4] * 2.f * -y * sh[8];
80
+ dRGBdz += SH_C2[1] * y * sh[5] + SH_C2[2] * 2.f * 2.f * z * sh[6] + SH_C2[3] * x * sh[7];
81
+
82
+ if (deg > 2)
83
+ {
84
+ float dRGBdsh9 = SH_C3[0] * y * (3.f * xx - yy);
85
+ float dRGBdsh10 = SH_C3[1] * xy * z;
86
+ float dRGBdsh11 = SH_C3[2] * y * (4.f * zz - xx - yy);
87
+ float dRGBdsh12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy);
88
+ float dRGBdsh13 = SH_C3[4] * x * (4.f * zz - xx - yy);
89
+ float dRGBdsh14 = SH_C3[5] * z * (xx - yy);
90
+ float dRGBdsh15 = SH_C3[6] * x * (xx - 3.f * yy);
91
+ dL_dsh[9] = dRGBdsh9 * dL_dRGB;
92
+ dL_dsh[10] = dRGBdsh10 * dL_dRGB;
93
+ dL_dsh[11] = dRGBdsh11 * dL_dRGB;
94
+ dL_dsh[12] = dRGBdsh12 * dL_dRGB;
95
+ dL_dsh[13] = dRGBdsh13 * dL_dRGB;
96
+ dL_dsh[14] = dRGBdsh14 * dL_dRGB;
97
+ dL_dsh[15] = dRGBdsh15 * dL_dRGB;
98
+
99
+ dRGBdx += (
100
+ SH_C3[0] * sh[9] * 3.f * 2.f * xy +
101
+ SH_C3[1] * sh[10] * yz +
102
+ SH_C3[2] * sh[11] * -2.f * xy +
103
+ SH_C3[3] * sh[12] * -3.f * 2.f * xz +
104
+ SH_C3[4] * sh[13] * (-3.f * xx + 4.f * zz - yy) +
105
+ SH_C3[5] * sh[14] * 2.f * xz +
106
+ SH_C3[6] * sh[15] * 3.f * (xx - yy));
107
+
108
+ dRGBdy += (
109
+ SH_C3[0] * sh[9] * 3.f * (xx - yy) +
110
+ SH_C3[1] * sh[10] * xz +
111
+ SH_C3[2] * sh[11] * (-3.f * yy + 4.f * zz - xx) +
112
+ SH_C3[3] * sh[12] * -3.f * 2.f * yz +
113
+ SH_C3[4] * sh[13] * -2.f * xy +
114
+ SH_C3[5] * sh[14] * -2.f * yz +
115
+ SH_C3[6] * sh[15] * -3.f * 2.f * xy);
116
+
117
+ dRGBdz += (
118
+ SH_C3[1] * sh[10] * xy +
119
+ SH_C3[2] * sh[11] * 4.f * 2.f * yz +
120
+ SH_C3[3] * sh[12] * 3.f * (2.f * zz - xx - yy) +
121
+ SH_C3[4] * sh[13] * 4.f * 2.f * xz +
122
+ SH_C3[5] * sh[14] * (xx - yy));
123
+ }
124
+ }
125
+ }
126
+
127
+ // The view direction is an input to the computation. View direction
128
+ // is influenced by the Gaussian's mean, so SHs gradients
129
+ // must propagate back into 3D position.
130
+ glm::vec3 dL_ddir(glm::dot(dRGBdx, dL_dRGB), glm::dot(dRGBdy, dL_dRGB), glm::dot(dRGBdz, dL_dRGB));
131
+
132
+ // Account for normalization of direction
133
+ float3 dL_dmean = dnormvdv(float3{ dir_orig.x, dir_orig.y, dir_orig.z }, float3{ dL_ddir.x, dL_ddir.y, dL_ddir.z });
134
+
135
+ // Gradients of loss w.r.t. Gaussian means, but only the portion
136
+ // that is caused because the mean affects the view-dependent color.
137
+ // Additional mean gradient is accumulated in below methods.
138
+ dL_dmeans[idx] += glm::vec3(dL_dmean.x, dL_dmean.y, dL_dmean.z);
139
+ }
140
+
141
+ // Backward version of INVERSE 2D covariance matrix computation
142
+ // (due to its length, launched as a separate kernel before the other
143
+ // backward steps contained in preprocess)
144
+ __global__ void computeCov2DCUDA(int P,
145
+ const float3* means,
146
+ const int* radii,
147
+ const float* cov3Ds,
148
+ const float h_x, float h_y,
149
+ const float tan_fovx, float tan_fovy,
150
+ const float* view_matrix,
151
+ const float* dL_dconics,
152
+ float3* dL_dmeans,
153
+ float* dL_dcov)
154
+ {
155
+ auto idx = cg::this_grid().thread_rank();
156
+ if (idx >= P || !(radii[idx] > 0))
157
+ return;
158
+
159
+ // Reading location of 3D covariance for this Gaussian
160
+ const float* cov3D = cov3Ds + 6 * idx;
161
+
162
+ // Fetch gradients, recompute 2D covariance and relevant
163
+ // intermediate forward results needed in the backward.
164
+ float3 mean = means[idx];
165
+ float3 dL_dconic = { dL_dconics[4 * idx], dL_dconics[4 * idx + 1], dL_dconics[4 * idx + 3] };
166
+ float3 t = transformPoint4x3(mean, view_matrix);
167
+
168
+ const float limx = 1.3f * tan_fovx;
169
+ const float limy = 1.3f * tan_fovy;
170
+ const float txtz = t.x / t.z;
171
+ const float tytz = t.y / t.z;
172
+ t.x = min(limx, max(-limx, txtz)) * t.z;
173
+ t.y = min(limy, max(-limy, tytz)) * t.z;
174
+
175
+ const float x_grad_mul = txtz < -limx || txtz > limx ? 0 : 1;
176
+ const float y_grad_mul = tytz < -limy || tytz > limy ? 0 : 1;
177
+
178
+ glm::mat3 J = glm::mat3(h_x / t.z, 0.0f, -(h_x * t.x) / (t.z * t.z),
179
+ 0.0f, h_y / t.z, -(h_y * t.y) / (t.z * t.z),
180
+ 0, 0, 0);
181
+
182
+ glm::mat3 W = glm::mat3(
183
+ view_matrix[0], view_matrix[4], view_matrix[8],
184
+ view_matrix[1], view_matrix[5], view_matrix[9],
185
+ view_matrix[2], view_matrix[6], view_matrix[10]);
186
+
187
+ glm::mat3 Vrk = glm::mat3(
188
+ cov3D[0], cov3D[1], cov3D[2],
189
+ cov3D[1], cov3D[3], cov3D[4],
190
+ cov3D[2], cov3D[4], cov3D[5]);
191
+
192
+ glm::mat3 T = W * J;
193
+
194
+ glm::mat3 cov2D = glm::transpose(T) * glm::transpose(Vrk) * T;
195
+
196
+ // Use helper variables for 2D covariance entries. More compact.
197
+ float a = cov2D[0][0] += 0.3f;
198
+ float b = cov2D[0][1];
199
+ float c = cov2D[1][1] += 0.3f;
200
+
201
+ float denom = a * c - b * b;
202
+ float dL_da = 0, dL_db = 0, dL_dc = 0;
203
+ float denom2inv = 1.0f / ((denom * denom) + 0.0000001f);
204
+
205
+ if (denom2inv != 0)
206
+ {
207
+ // Gradients of loss w.r.t. entries of 2D covariance matrix,
208
+ // given gradients of loss w.r.t. conic matrix (inverse covariance matrix).
209
+ // e.g., dL / da = dL / d_conic_a * d_conic_a / d_a
210
+ dL_da = denom2inv * (-c * c * dL_dconic.x + 2 * b * c * dL_dconic.y + (denom - a * c) * dL_dconic.z);
211
+ dL_dc = denom2inv * (-a * a * dL_dconic.z + 2 * a * b * dL_dconic.y + (denom - a * c) * dL_dconic.x);
212
+ dL_db = denom2inv * 2 * (b * c * dL_dconic.x - (denom + 2 * b * b) * dL_dconic.y + a * b * dL_dconic.z);
213
+
214
+ // Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
215
+ // given gradients w.r.t. 2D covariance matrix (diagonal).
216
+ // cov2D = transpose(T) * transpose(Vrk) * T;
217
+ dL_dcov[6 * idx + 0] = (T[0][0] * T[0][0] * dL_da + T[0][0] * T[1][0] * dL_db + T[1][0] * T[1][0] * dL_dc);
218
+ dL_dcov[6 * idx + 3] = (T[0][1] * T[0][1] * dL_da + T[0][1] * T[1][1] * dL_db + T[1][1] * T[1][1] * dL_dc);
219
+ dL_dcov[6 * idx + 5] = (T[0][2] * T[0][2] * dL_da + T[0][2] * T[1][2] * dL_db + T[1][2] * T[1][2] * dL_dc);
220
+
221
+ // Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
222
+ // given gradients w.r.t. 2D covariance matrix (off-diagonal).
223
+ // Off-diagonal elements appear twice --> double the gradient.
224
+ // cov2D = transpose(T) * transpose(Vrk) * T;
225
+ dL_dcov[6 * idx + 1] = 2 * T[0][0] * T[0][1] * dL_da + (T[0][0] * T[1][1] + T[0][1] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][1] * dL_dc;
226
+ dL_dcov[6 * idx + 2] = 2 * T[0][0] * T[0][2] * dL_da + (T[0][0] * T[1][2] + T[0][2] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][2] * dL_dc;
227
+ dL_dcov[6 * idx + 4] = 2 * T[0][2] * T[0][1] * dL_da + (T[0][1] * T[1][2] + T[0][2] * T[1][1]) * dL_db + 2 * T[1][1] * T[1][2] * dL_dc;
228
+ }
229
+ else
230
+ {
231
+ for (int i = 0; i < 6; i++)
232
+ dL_dcov[6 * idx + i] = 0;
233
+ }
234
+
235
+ // Gradients of loss w.r.t. upper 2x3 portion of intermediate matrix T
236
+ // cov2D = transpose(T) * transpose(Vrk) * T;
237
+ float dL_dT00 = 2 * (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_da +
238
+ (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_db;
239
+ float dL_dT01 = 2 * (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_da +
240
+ (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_db;
241
+ float dL_dT02 = 2 * (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_da +
242
+ (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_db;
243
+ float dL_dT10 = 2 * (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_dc +
244
+ (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_db;
245
+ float dL_dT11 = 2 * (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_dc +
246
+ (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_db;
247
+ float dL_dT12 = 2 * (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_dc +
248
+ (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_db;
249
+
250
+ // Gradients of loss w.r.t. upper 3x2 non-zero entries of Jacobian matrix
251
+ // T = W * J
252
+ float dL_dJ00 = W[0][0] * dL_dT00 + W[0][1] * dL_dT01 + W[0][2] * dL_dT02;
253
+ float dL_dJ02 = W[2][0] * dL_dT00 + W[2][1] * dL_dT01 + W[2][2] * dL_dT02;
254
+ float dL_dJ11 = W[1][0] * dL_dT10 + W[1][1] * dL_dT11 + W[1][2] * dL_dT12;
255
+ float dL_dJ12 = W[2][0] * dL_dT10 + W[2][1] * dL_dT11 + W[2][2] * dL_dT12;
256
+
257
+ float tz = 1.f / t.z;
258
+ float tz2 = tz * tz;
259
+ float tz3 = tz2 * tz;
260
+
261
+ // Gradients of loss w.r.t. transformed Gaussian mean t
262
+ float dL_dtx = x_grad_mul * -h_x * tz2 * dL_dJ02;
263
+ float dL_dty = y_grad_mul * -h_y * tz2 * dL_dJ12;
264
+ float dL_dtz = -h_x * tz2 * dL_dJ00 - h_y * tz2 * dL_dJ11 + (2 * h_x * t.x) * tz3 * dL_dJ02 + (2 * h_y * t.y) * tz3 * dL_dJ12;
265
+
266
+ // Account for transformation of mean to t
267
+ // t = transformPoint4x3(mean, view_matrix);
268
+ float3 dL_dmean = transformVec4x3Transpose({ dL_dtx, dL_dty, dL_dtz }, view_matrix);
269
+
270
+ // Gradients of loss w.r.t. Gaussian means, but only the portion
271
+ // that is caused because the mean affects the covariance matrix.
272
+ // Additional mean gradient is accumulated in BACKWARD::preprocess.
273
+ dL_dmeans[idx] = dL_dmean;
274
+ }
275
+
276
+ // Backward pass for the conversion of scale and rotation to a
277
+ // 3D covariance matrix for each Gaussian.
278
+ __device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const glm::vec4 rot, const float* dL_dcov3Ds, glm::vec3* dL_dscales, glm::vec4* dL_drots)
279
+ {
280
+ // Recompute (intermediate) results for the 3D covariance computation.
281
+ glm::vec4 q = rot;// / glm::length(rot);
282
+ float r = q.x;
283
+ float x = q.y;
284
+ float y = q.z;
285
+ float z = q.w;
286
+
287
+ glm::mat3 R = glm::mat3(
288
+ 1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y),
289
+ 2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x),
290
+ 2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y)
291
+ );
292
+
293
+ glm::mat3 S = glm::mat3(1.0f);
294
+
295
+ glm::vec3 s = mod * scale;
296
+ S[0][0] = s.x;
297
+ S[1][1] = s.y;
298
+ S[2][2] = s.z;
299
+
300
+ glm::mat3 M = S * R;
301
+
302
+ const float* dL_dcov3D = dL_dcov3Ds + 6 * idx;
303
+
304
+ glm::vec3 dunc(dL_dcov3D[0], dL_dcov3D[3], dL_dcov3D[5]);
305
+ glm::vec3 ounc = 0.5f * glm::vec3(dL_dcov3D[1], dL_dcov3D[2], dL_dcov3D[4]);
306
+
307
+ // Convert per-element covariance loss gradients to matrix form
308
+ glm::mat3 dL_dSigma = glm::mat3(
309
+ dL_dcov3D[0], 0.5f * dL_dcov3D[1], 0.5f * dL_dcov3D[2],
310
+ 0.5f * dL_dcov3D[1], dL_dcov3D[3], 0.5f * dL_dcov3D[4],
311
+ 0.5f * dL_dcov3D[2], 0.5f * dL_dcov3D[4], dL_dcov3D[5]
312
+ );
313
+
314
+ // Compute loss gradient w.r.t. matrix M
315
+ // dSigma_dM = 2 * M
316
+ glm::mat3 dL_dM = 2.0f * M * dL_dSigma;
317
+
318
+ glm::mat3 Rt = glm::transpose(R);
319
+ glm::mat3 dL_dMt = glm::transpose(dL_dM);
320
+
321
+ // Gradients of loss w.r.t. scale
322
+ glm::vec3* dL_dscale = dL_dscales + idx;
323
+ dL_dscale->x = glm::dot(Rt[0], dL_dMt[0]);
324
+ dL_dscale->y = glm::dot(Rt[1], dL_dMt[1]);
325
+ dL_dscale->z = glm::dot(Rt[2], dL_dMt[2]);
326
+
327
+ dL_dMt[0] *= s.x;
328
+ dL_dMt[1] *= s.y;
329
+ dL_dMt[2] *= s.z;
330
+
331
+ // Gradients of loss w.r.t. normalized quaternion
332
+ glm::vec4 dL_dq;
333
+ dL_dq.x = 2 * z * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * y * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * x * (dL_dMt[1][2] - dL_dMt[2][1]);
334
+ dL_dq.y = 2 * y * (dL_dMt[1][0] + dL_dMt[0][1]) + 2 * z * (dL_dMt[2][0] + dL_dMt[0][2]) + 2 * r * (dL_dMt[1][2] - dL_dMt[2][1]) - 4 * x * (dL_dMt[2][2] + dL_dMt[1][1]);
335
+ dL_dq.z = 2 * x * (dL_dMt[1][0] + dL_dMt[0][1]) + 2 * r * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * z * (dL_dMt[1][2] + dL_dMt[2][1]) - 4 * y * (dL_dMt[2][2] + dL_dMt[0][0]);
336
+ dL_dq.w = 2 * r * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * x * (dL_dMt[2][0] + dL_dMt[0][2]) + 2 * y * (dL_dMt[1][2] + dL_dMt[2][1]) - 4 * z * (dL_dMt[1][1] + dL_dMt[0][0]);
337
+
338
+ // Gradients of loss w.r.t. unnormalized quaternion
339
+ float4* dL_drot = (float4*)(dL_drots + idx);
340
+ *dL_drot = float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w };//dnormvdv(float4{ rot.x, rot.y, rot.z, rot.w }, float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w });
341
+ }
342
+
343
+ // Backward pass of the preprocessing steps, except
344
+ // for the covariance computation and inversion
345
+ // (those are handled by a previous kernel call)
346
+ template<int C>
347
+ __global__ void preprocessCUDA(
348
+ int P, int D, int M,
349
+ const float3* means,
350
+ const int* radii,
351
+ const float* shs,
352
+ const bool* clamped,
353
+ const glm::vec3* scales,
354
+ const glm::vec4* rotations,
355
+ const float scale_modifier,
356
+ const float* proj,
357
+ const glm::vec3* campos,
358
+ const float3* dL_dmean2D,
359
+ glm::vec3* dL_dmeans,
360
+ float* dL_dcolor,
361
+ float* dL_dcov3D,
362
+ float* dL_dsh,
363
+ glm::vec3* dL_dscale,
364
+ glm::vec4* dL_drot)
365
+ {
366
+ auto idx = cg::this_grid().thread_rank();
367
+ if (idx >= P || !(radii[idx] > 0))
368
+ return;
369
+
370
+ float3 m = means[idx];
371
+
372
+ // Taking care of gradients from the screenspace points
373
+ float4 m_hom = transformPoint4x4(m, proj);
374
+ float m_w = 1.0f / (m_hom.w + 0.0000001f);
375
+
376
+ // Compute loss gradient w.r.t. 3D means due to gradients of 2D means
377
+ // from rendering procedure
378
+ glm::vec3 dL_dmean;
379
+ float mul1 = (proj[0] * m.x + proj[4] * m.y + proj[8] * m.z + proj[12]) * m_w * m_w;
380
+ float mul2 = (proj[1] * m.x + proj[5] * m.y + proj[9] * m.z + proj[13]) * m_w * m_w;
381
+ dL_dmean.x = (proj[0] * m_w - proj[3] * mul1) * dL_dmean2D[idx].x + (proj[1] * m_w - proj[3] * mul2) * dL_dmean2D[idx].y;
382
+ dL_dmean.y = (proj[4] * m_w - proj[7] * mul1) * dL_dmean2D[idx].x + (proj[5] * m_w - proj[7] * mul2) * dL_dmean2D[idx].y;
383
+ dL_dmean.z = (proj[8] * m_w - proj[11] * mul1) * dL_dmean2D[idx].x + (proj[9] * m_w - proj[11] * mul2) * dL_dmean2D[idx].y;
384
+
385
+ // That's the second part of the mean gradient. Previous computation
386
+ // of cov2D and following SH conversion also affects it.
387
+ dL_dmeans[idx] += dL_dmean;
388
+
389
+ // Compute gradient updates due to computing colors from SHs
390
+ if (shs)
391
+ computeColorFromSH(idx, D, M, (glm::vec3*)means, *campos, shs, clamped, (glm::vec3*)dL_dcolor, (glm::vec3*)dL_dmeans, (glm::vec3*)dL_dsh);
392
+
393
+ // Compute gradient updates due to computing covariance from scale/rotation
394
+ if (scales)
395
+ computeCov3D(idx, scales[idx], scale_modifier, rotations[idx], dL_dcov3D, dL_dscale, dL_drot);
396
+ }
397
+
398
+ // Backward version of the rendering procedure.
399
+ template <uint32_t C>
400
+ __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
401
+ renderCUDA(
402
+ const uint2* __restrict__ ranges,
403
+ const uint32_t* __restrict__ point_list,
404
+ int W, int H,
405
+ const float* __restrict__ bg_color,
406
+ const float2* __restrict__ points_xy_image,
407
+ const float4* __restrict__ conic_opacity,
408
+ const float* __restrict__ colors,
409
+ const float* __restrict__ final_Ts,
410
+ const uint32_t* __restrict__ n_contrib,
411
+ const float* __restrict__ dL_dpixels,
412
+ float3* __restrict__ dL_dmean2D,
413
+ float4* __restrict__ dL_dconic2D,
414
+ float* __restrict__ dL_dopacity,
415
+ float* __restrict__ dL_dcolors)
416
+ {
417
+ // We rasterize again. Compute necessary block info.
418
+ auto block = cg::this_thread_block();
419
+ const uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
420
+ const uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y };
421
+ const uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) };
422
+ const uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y };
423
+ const uint32_t pix_id = W * pix.y + pix.x;
424
+ const float2 pixf = { (float)pix.x, (float)pix.y };
425
+
426
+ const bool inside = pix.x < W && pix.y < H;
427
+ const uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
428
+
429
+ const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
430
+
431
+ bool done = !inside;
432
+ int toDo = range.y - range.x;
433
+
434
+ __shared__ int collected_id[BLOCK_SIZE];
435
+ __shared__ float2 collected_xy[BLOCK_SIZE];
436
+ __shared__ float4 collected_conic_opacity[BLOCK_SIZE];
437
+ __shared__ float collected_colors[C * BLOCK_SIZE];
438
+
439
+ // In the forward, we stored the final value for T, the
440
+ // product of all (1 - alpha) factors.
441
+ const float T_final = inside ? final_Ts[pix_id] : 0;
442
+ float T = T_final;
443
+
444
+ // We start from the back. The ID of the last contributing
445
+ // Gaussian is known for each pixel from the forward pass.
446
+ uint32_t contributor = toDo;
447
+ const int last_contributor = inside ? n_contrib[pix_id] : 0;
448
+
449
+ float accum_rec[C] = { 0 };
450
+ float dL_dpixel[C];
451
+ if (inside)
452
+ for (int i = 0; i < C; i++)
453
+ dL_dpixel[i] = dL_dpixels[i * H * W + pix_id];
454
+
455
+ float last_alpha = 0;
456
+ float last_color[C] = { 0 };
457
+
458
+ // Gradient of pixel coordinate w.r.t. normalized
459
+ // screen-space viewport coordinates (-1 to 1)
460
+ const float ddelx_dx = 0.5 * W;
461
+ const float ddely_dy = 0.5 * H;
462
+
463
+ // Traverse all Gaussians
464
+ for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
465
+ {
466
+ // Load auxiliary data into shared memory, start in the BACK
467
+ // and load them in reverse order.
468
+ block.sync();
469
+ const int progress = i * BLOCK_SIZE + block.thread_rank();
470
+ if (range.x + progress < range.y)
471
+ {
472
+ const int coll_id = point_list[range.y - progress - 1];
473
+ collected_id[block.thread_rank()] = coll_id;
474
+ collected_xy[block.thread_rank()] = points_xy_image[coll_id];
475
+ collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
476
+ for (int i = 0; i < C; i++)
477
+ collected_colors[i * BLOCK_SIZE + block.thread_rank()] = colors[coll_id * C + i];
478
+ }
479
+ block.sync();
480
+
481
+ // Iterate over Gaussians
482
+ for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++)
483
+ {
484
+ // Keep track of current Gaussian ID. Skip, if this one
485
+ // is behind the last contributor for this pixel.
486
+ contributor--;
487
+ if (contributor >= last_contributor)
488
+ continue;
489
+
490
+ // Compute blending values, as before.
491
+ const float2 xy = collected_xy[j];
492
+ const float2 d = { xy.x - pixf.x, xy.y - pixf.y };
493
+ const float4 con_o = collected_conic_opacity[j];
494
+ const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
495
+ if (power > 0.0f)
496
+ continue;
497
+
498
+ const float G = exp(power);
499
+ const float alpha = min(0.99f, con_o.w * G);
500
+ if (alpha < 1.0f / 255.0f)
501
+ continue;
502
+
503
+ T = T / (1.f - alpha);
504
+ const float dchannel_dcolor = alpha * T;
505
+
506
+ // Propagate gradients to per-Gaussian colors and keep
507
+ // gradients w.r.t. alpha (blending factor for a Gaussian/pixel
508
+ // pair).
509
+ float dL_dalpha = 0.0f;
510
+ const int global_id = collected_id[j];
511
+ for (int ch = 0; ch < C; ch++)
512
+ {
513
+ const float c = collected_colors[ch * BLOCK_SIZE + j];
514
+ // Update last color (to be used in the next iteration)
515
+ accum_rec[ch] = last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch];
516
+ last_color[ch] = c;
517
+
518
+ const float dL_dchannel = dL_dpixel[ch];
519
+ dL_dalpha += (c - accum_rec[ch]) * dL_dchannel;
520
+ // Update the gradients w.r.t. color of the Gaussian.
521
+ // Atomic, since this pixel is just one of potentially
522
+ // many that were affected by this Gaussian.
523
+ atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel);
524
+ }
525
+ dL_dalpha *= T;
526
+ // Update last alpha (to be used in the next iteration)
527
+ last_alpha = alpha;
528
+
529
+ // Account for the fact that alpha also influences how much of
530
+ // the background color is added if nothing is left to blend.
531
+ float bg_dot_dpixel = 0;
532
+ for (int i = 0; i < C; i++)
533
+ bg_dot_dpixel += bg_color[i] * dL_dpixel[i];
534
+ dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel;
535
+
536
+
537
+ // Helpful reusable temporary variables
538
+ const float dL_dG = con_o.w * dL_dalpha;
539
+ const float gdx = G * d.x;
540
+ const float gdy = G * d.y;
541
+ const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y;
542
+ const float dG_ddely = -gdy * con_o.z - gdx * con_o.y;
543
+
544
+ // Update gradients w.r.t. 2D mean position of the Gaussian
545
+ atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx);
546
+ atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy);
547
+
548
+ // Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric)
549
+ atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG);
550
+ atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG);
551
+ atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG);
552
+
553
+ // Update gradients w.r.t. opacity of the Gaussian
554
+ atomicAdd(&(dL_dopacity[global_id]), G * dL_dalpha);
555
+ }
556
+ }
557
+ }
558
+
559
+ void BACKWARD::preprocess(
560
+ int P, int D, int M,
561
+ const float3* means3D,
562
+ const int* radii,
563
+ const float* shs,
564
+ const bool* clamped,
565
+ const glm::vec3* scales,
566
+ const glm::vec4* rotations,
567
+ const float scale_modifier,
568
+ const float* cov3Ds,
569
+ const float* viewmatrix,
570
+ const float* projmatrix,
571
+ const float focal_x, float focal_y,
572
+ const float tan_fovx, float tan_fovy,
573
+ const glm::vec3* campos,
574
+ const float3* dL_dmean2D,
575
+ const float* dL_dconic,
576
+ glm::vec3* dL_dmean3D,
577
+ float* dL_dcolor,
578
+ float* dL_dcov3D,
579
+ float* dL_dsh,
580
+ glm::vec3* dL_dscale,
581
+ glm::vec4* dL_drot)
582
+ {
583
+ // Propagate gradients for the path of 2D conic matrix computation.
584
+ // Somewhat long, thus it is its own kernel rather than being part of
585
+ // "preprocess". When done, loss gradient w.r.t. 3D means has been
586
+ // modified and gradient w.r.t. 3D covariance matrix has been computed.
587
+ computeCov2DCUDA << <(P + 255) / 256, 256 >> > (
588
+ P,
589
+ means3D,
590
+ radii,
591
+ cov3Ds,
592
+ focal_x,
593
+ focal_y,
594
+ tan_fovx,
595
+ tan_fovy,
596
+ viewmatrix,
597
+ dL_dconic,
598
+ (float3*)dL_dmean3D,
599
+ dL_dcov3D);
600
+
601
+ // Propagate gradients for remaining steps: finish 3D mean gradients,
602
+ // propagate color gradients to SH (if desired), propagate 3D covariance
603
+ // matrix gradients to scale and rotation.
604
+ preprocessCUDA<NUM_CHANNELS> << < (P + 255) / 256, 256 >> > (
605
+ P, D, M,
606
+ (float3*)means3D,
607
+ radii,
608
+ shs,
609
+ clamped,
610
+ (glm::vec3*)scales,
611
+ (glm::vec4*)rotations,
612
+ scale_modifier,
613
+ projmatrix,
614
+ campos,
615
+ (float3*)dL_dmean2D,
616
+ (glm::vec3*)dL_dmean3D,
617
+ dL_dcolor,
618
+ dL_dcov3D,
619
+ dL_dsh,
620
+ dL_dscale,
621
+ dL_drot);
622
+ }
623
+
624
+ void BACKWARD::render(
625
+ const dim3 grid, const dim3 block,
626
+ const uint2* ranges,
627
+ const uint32_t* point_list,
628
+ int W, int H,
629
+ const float* bg_color,
630
+ const float2* means2D,
631
+ const float4* conic_opacity,
632
+ const float* colors,
633
+ const float* final_Ts,
634
+ const uint32_t* n_contrib,
635
+ const float* dL_dpixels,
636
+ float3* dL_dmean2D,
637
+ float4* dL_dconic2D,
638
+ float* dL_dopacity,
639
+ float* dL_dcolors)
640
+ {
641
+ renderCUDA<NUM_CHANNELS> << <grid, block >> >(
642
+ ranges,
643
+ point_list,
644
+ W, H,
645
+ bg_color,
646
+ means2D,
647
+ conic_opacity,
648
+ colors,
649
+ final_Ts,
650
+ n_contrib,
651
+ dL_dpixels,
652
+ dL_dmean2D,
653
+ dL_dconic2D,
654
+ dL_dopacity,
655
+ dL_dcolors
656
+ );
657
+ }
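
The closed-form conic gradients in `computeCov2DCUDA` are easy to misread, and a quick finite-difference check clarifies them. In this host-side sketch (not part of the submodule) the loss is a stand-in, L = gx*conic_x + 2*gy*conic_y + gz*conic_z, with the off-diagonal counted twice as in a symmetric matrix; gx, gy, gz and the covariance entries a, b, c are example values standing in for `dL_dconic` and `cov2D`:

```cpp
// Host-side sketch: checks the kernel's dL_da/dL_db/dL_dc expressions
// against central finite differences of a stand-in loss.
#include <cstdio>

double lossL(double a, double b, double c, double gx, double gy, double gz) {
    double det = a * c - b * b;
    return (gx * c + 2.0 * gy * -b + gz * a) / det;  // conic = adj(cov)/det
}

int main() {
    double a = 2.1, b = 0.4, c = 1.7;      // example 2D covariance entries
    double gx = 0.3, gy = -0.8, gz = 0.5;  // example incoming gradients
    double det = a * c - b * b, d2i = 1.0 / (det * det);

    // Same expressions as in computeCov2DCUDA.
    double dL_da = d2i * (-c * c * gx + 2 * b * c * gy + (det - a * c) * gz);
    double dL_dc = d2i * (-a * a * gz + 2 * a * b * gy + (det - a * c) * gx);
    double dL_db = d2i * 2 * (b * c * gx - (det + 2 * b * b) * gy + a * b * gz);

    double eps = 1e-6;
    printf("da: %.6f vs %.6f\n", dL_da,
           (lossL(a + eps, b, c, gx, gy, gz) - lossL(a - eps, b, c, gx, gy, gz)) / (2 * eps));
    printf("db: %.6f vs %.6f\n", dL_db,
           (lossL(a, b + eps, c, gx, gy, gz) - lossL(a, b - eps, c, gx, gy, gz)) / (2 * eps));
    printf("dc: %.6f vs %.6f\n", dL_dc,
           (lossL(a, b, c + eps, gx, gy, gz) - lossL(a, b, c - eps, gx, gy, gz)) / (2 * eps));
}
```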
submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.h ADDED
@@ -0,0 +1,65 @@
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED
13
+ #define CUDA_RASTERIZER_BACKWARD_H_INCLUDED
14
+
15
+ #include <cuda.h>
16
+ #include "cuda_runtime.h"
17
+ #include "device_launch_parameters.h"
18
+ #define GLM_FORCE_CUDA
19
+ #include <glm/glm.hpp>
20
+
21
+ namespace BACKWARD
22
+ {
23
+ void render(
24
+ const dim3 grid, dim3 block,
25
+ const uint2* ranges,
26
+ const uint32_t* point_list,
27
+ int W, int H,
28
+ const float* bg_color,
29
+ const float2* means2D,
30
+ const float4* conic_opacity,
31
+ const float* colors,
32
+ const float* final_Ts,
33
+ const uint32_t* n_contrib,
34
+ const float* dL_dpixels,
35
+ float3* dL_dmean2D,
36
+ float4* dL_dconic2D,
37
+ float* dL_dopacity,
38
+ float* dL_dcolors);
39
+
40
+ void preprocess(
41
+ int P, int D, int M,
42
+ const float3* means,
43
+ const int* radii,
44
+ const float* shs,
45
+ const bool* clamped,
46
+ const glm::vec3* scales,
47
+ const glm::vec4* rotations,
48
+ const float scale_modifier,
49
+ const float* cov3Ds,
50
+ const float* view,
51
+ const float* proj,
52
+ const float focal_x, float focal_y,
53
+ const float tan_fovx, float tan_fovy,
54
+ const glm::vec3* campos,
55
+ const float3* dL_dmean2D,
56
+ const float* dL_dconics,
57
+ glm::vec3* dL_dmeans,
58
+ float* dL_dcolor,
59
+ float* dL_dcov3D,
60
+ float* dL_dsh,
61
+ glm::vec3* dL_dscale,
62
+ glm::vec4* dL_drot);
63
+ }
64
+
65
+ #endif
submodules/diff-gaussian-rasterization/cuda_rasterizer/config.h ADDED
@@ -0,0 +1,19 @@
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED
13
+ #define CUDA_RASTERIZER_CONFIG_H_INCLUDED
14
+
15
+ #define NUM_CHANNELS 3 // Default 3, RGB
16
+ #define BLOCK_X 16
17
+ #define BLOCK_Y 16
18
+
19
+ #endif
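
A short sketch of how the launch geometry follows from these constants, for an assumed 1920x1080 render target: one 16x16 thread block per screen tile, one thread per pixel, with the last row/column of tiles partially covered.

```cpp
// Sketch: tile grid implied by config.h for an illustrative image size.
#include <cstdio>
#define BLOCK_X 16
#define BLOCK_Y 16

int main() {
    int W = 1920, H = 1080;
    int grid_x = (W + BLOCK_X - 1) / BLOCK_X;  // 120 tiles
    int grid_y = (H + BLOCK_Y - 1) / BLOCK_Y;  // 68 tiles (last row partial)
    printf("grid %dx%d, %d threads per block\n", grid_x, grid_y, BLOCK_X * BLOCK_Y);
}
```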
submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.cu ADDED
@@ -0,0 +1,455 @@
1
+ /*
2
+ * Copyright (C) 2023, Inria
3
+ * GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ * All rights reserved.
5
+ *
6
+ * This software is free for non-commercial, research and evaluation use
7
+ * under the terms of the LICENSE.md file.
8
+ *
9
+ * For inquiries contact george.drettakis@inria.fr
10
+ */
11
+
12
+ #include "forward.h"
13
+ #include "auxiliary.h"
14
+ #include <cooperative_groups.h>
15
+ #include <cooperative_groups/reduce.h>
16
+ namespace cg = cooperative_groups;
17
+
18
+ // Forward method for converting the input spherical harmonics
19
+ // coefficients of each Gaussian to a simple RGB color.
20
+ __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped)
21
+ {
22
+ // The implementation is loosely based on code for
23
+ // "Differentiable Point-Based Radiance Fields for
24
+ // Efficient View Synthesis" by Zhang et al. (2022)
25
+ glm::vec3 pos = means[idx];
26
+ glm::vec3 dir = pos - campos;
27
+ dir = dir / glm::length(dir);
28
+
29
+ glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
30
+ glm::vec3 result = SH_C0 * sh[0];
31
+
32
+ if (deg > 0)
33
+ {
34
+ float x = dir.x;
35
+ float y = dir.y;
36
+ float z = dir.z;
37
+ result = result - SH_C1 * y * sh[1] + SH_C1 * z * sh[2] - SH_C1 * x * sh[3];
38
+
39
+ if (deg > 1)
40
+ {
41
+ float xx = x * x, yy = y * y, zz = z * z;
42
+ float xy = x * y, yz = y * z, xz = x * z;
43
+ result = result +
44
+ SH_C2[0] * xy * sh[4] +
45
+ SH_C2[1] * yz * sh[5] +
46
+ SH_C2[2] * (2.0f * zz - xx - yy) * sh[6] +
47
+ SH_C2[3] * xz * sh[7] +
48
+ SH_C2[4] * (xx - yy) * sh[8];
49
+
50
+ if (deg > 2)
51
+ {
52
+ result = result +
53
+ SH_C3[0] * y * (3.0f * xx - yy) * sh[9] +
54
+ SH_C3[1] * xy * z * sh[10] +
55
+ SH_C3[2] * y * (4.0f * zz - xx - yy) * sh[11] +
56
+ SH_C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * sh[12] +
57
+ SH_C3[4] * x * (4.0f * zz - xx - yy) * sh[13] +
58
+ SH_C3[5] * z * (xx - yy) * sh[14] +
59
+ SH_C3[6] * x * (xx - 3.0f * yy) * sh[15];
60
+ }
61
+ }
62
+ }
63
+ result += 0.5f;
64
+
65
+ // RGB colors are clamped to positive values. If values are
66
+ // clamped, we need to keep track of this for the backward pass.
67
+ clamped[3 * idx + 0] = (result.x < 0);
68
+ clamped[3 * idx + 1] = (result.y < 0);
69
+ clamped[3 * idx + 2] = (result.z < 0);
70
+ return glm::max(result, 0.0f);
71
+ }
72
+
73
+ // Forward version of 2D covariance matrix computation
74
+ __device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float* cov3D, const float* viewmatrix)
75
+ {
76
+ // The following models the steps outlined by equations 29
77
+ // and 31 in "EWA Splatting" (Zwicker et al., 2002).
78
+ // Additionally considers aspect / scaling of viewport.
79
+ // Transposes used to account for row-/column-major conventions.
80
+ float3 t = transformPoint4x3(mean, viewmatrix);
81
+
82
+ const float limx = 1.3f * tan_fovx;
83
+ const float limy = 1.3f * tan_fovy;
84
+ const float txtz = t.x / t.z;
85
+ const float tytz = t.y / t.z;
86
+ t.x = min(limx, max(-limx, txtz)) * t.z;
87
+ t.y = min(limy, max(-limy, tytz)) * t.z;
88
+
89
+ glm::mat3 J = glm::mat3(
90
+ focal_x / t.z, 0.0f, -(focal_x * t.x) / (t.z * t.z),
91
+ 0.0f, focal_y / t.z, -(focal_y * t.y) / (t.z * t.z),
92
+ 0, 0, 0);
93
+
94
+ glm::mat3 W = glm::mat3(
95
+ viewmatrix[0], viewmatrix[4], viewmatrix[8],
96
+ viewmatrix[1], viewmatrix[5], viewmatrix[9],
97
+ viewmatrix[2], viewmatrix[6], viewmatrix[10]);
98
+
99
+ glm::mat3 T = W * J;
100
+
101
+ glm::mat3 Vrk = glm::mat3(
102
+ cov3D[0], cov3D[1], cov3D[2],
103
+ cov3D[1], cov3D[3], cov3D[4],
104
+ cov3D[2], cov3D[4], cov3D[5]);
105
+
106
+ glm::mat3 cov = glm::transpose(T) * glm::transpose(Vrk) * T;
107
+
108
+ // Apply low-pass filter: every Gaussian should be at least
109
+ // one pixel wide/high. Discard 3rd row and column.
110
+ cov[0][0] += 0.3f;
111
+ cov[1][1] += 0.3f;
112
+ return { float(cov[0][0]), float(cov[0][1]), float(cov[1][1]) };
113
+ }
114
+
115
+ // Forward method for converting scale and rotation properties of each
116
+ // Gaussian to a 3D covariance matrix in world space. Also takes care
117
+ // of quaternion normalization.
118
+ __device__ void computeCov3D(const glm::vec3 scale, float mod, const glm::vec4 rot, float* cov3D)
119
+ {
120
+ // Create scaling matrix
121
+ glm::mat3 S = glm::mat3(1.0f);
122
+ S[0][0] = mod * scale.x;
123
+ S[1][1] = mod * scale.y;
124
+ S[2][2] = mod * scale.z;
125
+
126
+ // Normalize quaternion to get valid rotation
127
+ glm::vec4 q = rot;// / glm::length(rot);
128
+ float r = q.x;
129
+ float x = q.y;
130
+ float y = q.z;
131
+ float z = q.w;
132
+
133
+ // Compute rotation matrix from quaternion
134
+ glm::mat3 R = glm::mat3(
135
+ 1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y),
136
+ 2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x),
137
+ 2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y)
138
+ );
139
+
140
+ glm::mat3 M = S * R;
141
+
142
+ // Compute 3D world covariance matrix Sigma
143
+ glm::mat3 Sigma = glm::transpose(M) * M;
144
+
145
+ // Covariance is symmetric, only store upper right
146
+ cov3D[0] = Sigma[0][0];
147
+ cov3D[1] = Sigma[0][1];
148
+ cov3D[2] = Sigma[0][2];
149
+ cov3D[3] = Sigma[1][1];
150
+ cov3D[4] = Sigma[1][2];
151
+ cov3D[5] = Sigma[2][2];
152
+ }
153
+
154
+ // Perform initial steps for each Gaussian prior to rasterization.
155
+ template<int C>
156
+ __global__ void preprocessCUDA(int P, int D, int M,
157
+ const float* orig_points,
158
+ const glm::vec3* scales,
159
+ const float scale_modifier,
160
+ const glm::vec4* rotations,
161
+ const float* opacities,
162
+ const float* shs,
163
+ bool* clamped,
164
+ const float* cov3D_precomp,
165
+ const float* colors_precomp,
166
+ const float* viewmatrix,
167
+ const float* projmatrix,
168
+ const glm::vec3* cam_pos,
169
+ const int W, int H,
170
+ const float tan_fovx, float tan_fovy,
171
+ const float focal_x, float focal_y,
172
+ int* radii,
173
+ float2* points_xy_image,
174
+ float* depths,
175
+ float* cov3Ds,
176
+ float* rgb,
177
+ float4* conic_opacity,
178
+ const dim3 grid,
179
+ uint32_t* tiles_touched,
180
+ bool prefiltered)
181
+ {
182
+ auto idx = cg::this_grid().thread_rank();
183
+ if (idx >= P)
184
+ return;
185
+
186
+ // Initialize radius and touched tiles to 0. If this isn't changed,
187
+ // this Gaussian will not be processed further.
188
+ radii[idx] = 0;
189
+ tiles_touched[idx] = 0;
190
+
191
+ // Perform near culling, quit if outside.
192
+ float3 p_view;
193
+ if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view))
194
+ return;
195
+
196
+ // Transform point by projecting
197
+ float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };
198
+ float4 p_hom = transformPoint4x4(p_orig, projmatrix);
199
+ float p_w = 1.0f / (p_hom.w + 0.0000001f);
200
+ float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
201
+
202
+ // If 3D covariance matrix is precomputed, use it, otherwise compute
203
+ // from scaling and rotation parameters.
204
+ const float* cov3D;
205
+ if (cov3D_precomp != nullptr)
206
+ {
207
+ cov3D = cov3D_precomp + idx * 6;
208
+ }
209
+ else
210
+ {
211
+ computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6);
212
+ cov3D = cov3Ds + idx * 6;
213
+ }
214
+
215
+ // Compute 2D screen-space covariance matrix
216
+ float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix);
217
+
218
+ // Invert covariance (EWA algorithm)
219
+ float det = (cov.x * cov.z - cov.y * cov.y);
220
+ if (det == 0.0f)
221
+ return;
222
+ float det_inv = 1.f / det;
223
+ float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv };
224
+
225
+ // Compute extent in screen space (by finding eigenvalues of
226
+ // 2D covariance matrix). Use extent to compute a bounding rectangle
227
+ // of screen-space tiles that this Gaussian overlaps with. Quit if
228
+ // rectangle covers 0 tiles.
229
+ float mid = 0.5f * (cov.x + cov.z);
230
+ float lambda1 = mid + sqrt(max(0.1f, mid * mid - det));
231
+ float lambda2 = mid - sqrt(max(0.1f, mid * mid - det));
232
+ float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2)));
233
+ float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) };
234
+ uint2 rect_min, rect_max;
235
+ getRect(point_image, my_radius, rect_min, rect_max, grid);
236
+ if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0)
237
+ return;
238
+
239
+ // If colors have been precomputed, use them, otherwise convert
240
+ // spherical harmonics coefficients to RGB color.
241
+ if (colors_precomp == nullptr)
242
+ {
243
+ glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped);
244
+ rgb[idx * C + 0] = result.x;
245
+ rgb[idx * C + 1] = result.y;
246
+ rgb[idx * C + 2] = result.z;
247
+ }
248
+
249
+ // Store some useful helper data for the next steps.
250
+ depths[idx] = p_view.z;
251
+ radii[idx] = my_radius;
252
+ points_xy_image[idx] = point_image;
253
+ // Inverse 2D covariance and opacity neatly pack into one float4
254
+ conic_opacity[idx] = { conic.x, conic.y, conic.z, opacities[idx] };
255
+ tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x);
256
+ }
257
+
258
+ // Main rasterization method. Collaboratively works on one tile per
259
+ // block, each thread treats one pixel. Alternates between fetching
260
+ // and rasterizing data.
261
+ template <uint32_t CHANNELS>
262
+ __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
263
+ renderCUDA(
264
+ const uint2* __restrict__ ranges,
265
+ const uint32_t* __restrict__ point_list,
266
+ int W, int H,
267
+ const float2* __restrict__ points_xy_image,
268
+ const float* __restrict__ features,
269
+ const float4* __restrict__ conic_opacity,
270
+ float* __restrict__ final_T,
271
+ uint32_t* __restrict__ n_contrib,
272
+ const float* __restrict__ bg_color,
273
+ float* __restrict__ out_color)
274
+ {
275
+ // Identify current tile and associated min/max pixel range.
276
+ auto block = cg::this_thread_block();
277
+ uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
278
+ uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y };
279
+ uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) };
280
+ uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y };
281
+ uint32_t pix_id = W * pix.y + pix.x;
282
+ float2 pixf = { (float)pix.x, (float)pix.y };
283
+
284
+ // Check if this thread is associated with a valid pixel or outside.
285
+ bool inside = pix.x < W && pix.y < H;
286
+ // Done threads can help with fetching, but don't rasterize
287
+ bool done = !inside;
288
+
289
+ // Load start/end range of IDs to process in the sorted point list.
290
+ uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
291
+ const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
292
+ int toDo = range.y - range.x;
293
+
294
+ // Allocate storage for batches of collectively fetched data.
295
+ __shared__ int collected_id[BLOCK_SIZE];
296
+ __shared__ float2 collected_xy[BLOCK_SIZE];
297
+ __shared__ float4 collected_conic_opacity[BLOCK_SIZE];
298
+
299
+ // Initialize helper variables
300
+ float T = 1.0f;
301
+ uint32_t contributor = 0;
302
+ uint32_t last_contributor = 0;
303
+ float C[CHANNELS] = { 0 };
304
+
305
+ // Iterate over batches until all done or range is complete
306
+ for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
307
+ {
308
+ // End if entire block votes that it is done rasterizing
309
+ int num_done = __syncthreads_count(done);
310
+ if (num_done == BLOCK_SIZE)
311
+ break;
312
+
313
+ // Collectively fetch per-Gaussian data from global to shared
314
+ int progress = i * BLOCK_SIZE + block.thread_rank();
315
+ if (range.x + progress < range.y)
316
+ {
317
+ int coll_id = point_list[range.x + progress];
318
+ collected_id[block.thread_rank()] = coll_id;
319
+ collected_xy[block.thread_rank()] = points_xy_image[coll_id];
320
+ collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
321
+ }
322
+ block.sync();
323
+
324
+ // Iterate over current batch
325
+ for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++)
326
+ {
327
+ // Keep track of current position in range
328
+ contributor++;
329
+
330
+ // Resample using conic matrix (cf. "Surface
331
+ // Splatting" by Zwicker et al., 2001)
332
+ float2 xy = collected_xy[j];
333
+ float2 d = { xy.x - pixf.x, xy.y - pixf.y };
334
+ float4 con_o = collected_conic_opacity[j];
335
+ float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
336
+ if (power > 0.0f)
337
+ continue;
338
+
339
+ // Eq. (2) from 3D Gaussian splatting paper.
340
+ // Obtain alpha by multiplying with Gaussian opacity
341
+ // and its exponential falloff from mean.
342
+ // Avoid numerical instabilities (see paper appendix).
343
+ float alpha = min(0.99f, con_o.w * exp(power));
344
+ if (alpha < 1.0f / 255.0f)
345
+ continue;
346
+ float test_T = T * (1 - alpha);
347
+ if (test_T < 0.0001f)
348
+ {
349
+ done = true;
350
+ continue;
351
+ }
352
+
353
+ // Eq. (3) from 3D Gaussian splatting paper.
354
+ for (int ch = 0; ch < CHANNELS; ch++)
355
+ C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T;
356
+
357
+ T = test_T;
358
+
359
+ // Keep track of last range entry to update this
360
+ // pixel.
361
+ last_contributor = contributor;
362
+ }
363
+ }
364
+
365
+ // All threads that treat valid pixel write out their final
366
+ // rendering data to the frame and auxiliary buffers.
367
+ if (inside)
368
+ {
369
+ final_T[pix_id] = T;
370
+ n_contrib[pix_id] = last_contributor;
371
+ for (int ch = 0; ch < CHANNELS; ch++)
372
+ out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch];
373
+ }
374
+ }
375
+
376
+ void FORWARD::render(
377
+ const dim3 grid, dim3 block,
378
+ const uint2* ranges,
379
+ const uint32_t* point_list,
380
+ int W, int H,
381
+ const float2* means2D,
382
+ const float* colors,
383
+ const float4* conic_opacity,
384
+ float* final_T,
385
+ uint32_t* n_contrib,
386
+ const float* bg_color,
387
+ float* out_color)
388
+ {
389
+ renderCUDA<NUM_CHANNELS> << <grid, block >> > (
390
+ ranges,
391
+ point_list,
392
+ W, H,
393
+ means2D,
394
+ colors,
395
+ conic_opacity,
396
+ final_T,
397
+ n_contrib,
398
+ bg_color,
399
+ out_color);
400
+ }
401
+
402
+ void FORWARD::preprocess(int P, int D, int M,
403
+ const float* means3D,
404
+ const glm::vec3* scales,
405
+ const float scale_modifier,
406
+ const glm::vec4* rotations,
407
+ const float* opacities,
408
+ const float* shs,
409
+ bool* clamped,
410
+ const float* cov3D_precomp,
411
+ const float* colors_precomp,
412
+ const float* viewmatrix,
413
+ const float* projmatrix,
414
+ const glm::vec3* cam_pos,
415
+ const int W, int H,
416
+ const float focal_x, float focal_y,
417
+ const float tan_fovx, float tan_fovy,
418
+ int* radii,
419
+ float2* means2D,
420
+ float* depths,
421
+ float* cov3Ds,
422
+ float* rgb,
423
+ float4* conic_opacity,
424
+ const dim3 grid,
425
+ uint32_t* tiles_touched,
426
+ bool prefiltered)
427
+ {
428
+ preprocessCUDA<NUM_CHANNELS> << <(P + 255) / 256, 256 >> > (
429
+ P, D, M,
430
+ means3D,
431
+ scales,
432
+ scale_modifier,
433
+ rotations,
434
+ opacities,
435
+ shs,
436
+ clamped,
437
+ cov3D_precomp,
438
+ colors_precomp,
439
+ viewmatrix,
440
+ projmatrix,
441
+ cam_pos,
442
+ W, H,
443
+ tan_fovx, tan_fovy,
444
+ focal_x, focal_y,
445
+ radii,
446
+ means2D,
447
+ depths,
448
+ cov3Ds,
449
+ rgb,
450
+ conic_opacity,
451
+ grid,
452
+ tiles_touched,
453
+ prefiltered
454
+ );
455
+ }
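
The core of `renderCUDA` is the front-to-back compositing recurrence from Eqs. (2) and (3) of the paper. The sketch below (not part of the submodule; the alpha and color values are made up) replays it for a single pixel and channel, including the minimum-alpha and early-termination tests:

```cpp
// Host-side sketch: front-to-back alpha compositing as in renderCUDA.
#include <cstdio>

int main() {
    float alpha[3] = {0.6f, 0.3f, 0.9f};   // opacity * Gaussian falloff, sorted by depth
    float color[3] = {1.0f, 0.5f, 0.2f};   // single channel for brevity
    float bg = 0.1f;                       // background color

    float C = 0.0f, T = 1.0f;              // accumulated color and transmittance
    for (int i = 0; i < 3; i++) {
        if (alpha[i] < 1.0f / 255.0f) continue;  // skip negligible splats
        float test_T = T * (1.0f - alpha[i]);
        if (test_T < 0.0001f) break;             // early termination
        C += color[i] * alpha[i] * T;            // Eq. (3)
        T = test_T;
    }
    printf("pixel = %.4f (transmittance left for background: %.4f)\n",
           C + T * bg, T);
}
```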
submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.h ADDED
@@ -0,0 +1,66 @@
+ /*
+  * Copyright (C) 2023, Inria
+  * GRAPHDECO research group, https://team.inria.fr/graphdeco
+  * All rights reserved.
+  *
+  * This software is free for non-commercial, research and evaluation use
+  * under the terms of the LICENSE.md file.
+  *
+  * For inquiries contact george.drettakis@inria.fr
+  */
+ 
+ #ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED
+ #define CUDA_RASTERIZER_FORWARD_H_INCLUDED
+ 
+ #include <cuda.h>
+ #include "cuda_runtime.h"
+ #include "device_launch_parameters.h"
+ #define GLM_FORCE_CUDA
+ #include <glm/glm.hpp>
+ 
+ namespace FORWARD
+ {
+ 	// Perform initial steps for each Gaussian prior to rasterization.
+ 	void preprocess(int P, int D, int M,
+ 		const float* orig_points,
+ 		const glm::vec3* scales,
+ 		const float scale_modifier,
+ 		const glm::vec4* rotations,
+ 		const float* opacities,
+ 		const float* shs,
+ 		bool* clamped,
+ 		const float* cov3D_precomp,
+ 		const float* colors_precomp,
+ 		const float* viewmatrix,
+ 		const float* projmatrix,
+ 		const glm::vec3* cam_pos,
+ 		const int W, int H,
+ 		const float focal_x, float focal_y,
+ 		const float tan_fovx, float tan_fovy,
+ 		int* radii,
+ 		float2* points_xy_image,
+ 		float* depths,
+ 		float* cov3Ds,
+ 		float* colors,
+ 		float4* conic_opacity,
+ 		const dim3 grid,
+ 		uint32_t* tiles_touched,
+ 		bool prefiltered);
+ 
+ 	// Main rasterization method.
+ 	void render(
+ 		const dim3 grid, dim3 block,
+ 		const uint2* ranges,
+ 		const uint32_t* point_list,
+ 		int W, int H,
+ 		const float2* points_xy_image,
+ 		const float* features,
+ 		const float4* conic_opacity,
+ 		float* final_T,
+ 		uint32_t* n_contrib,
+ 		const float* bg_color,
+ 		float* out_color);
+ }
+ 
+ 
+ #endif
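The focal_x/focal_y and grid arguments declared here are derived by the caller (Rasterizer::forward in rasterizer_impl.cu further down) from the image size and FoV tangents. A minimal Python sketch of those two relations; the 16x16 tile size is an assumption matching the usual BLOCK_X/BLOCK_Y in cuda_rasterizer/config.h, which is not shown in this hunk:

```python
# Sketch only: pinhole focal lengths from FoV tangents, and the
# ceiling-divided tile grid (one thread block per 16x16 pixel tile).
def tile_and_focal(width, height, tan_fovx, tan_fovy, block_x=16, block_y=16):
    focal_x = width / (2.0 * tan_fovx)
    focal_y = height / (2.0 * tan_fovy)
    tile_grid = ((width + block_x - 1) // block_x,
                 (height + block_y - 1) // block_y)
    return focal_x, focal_y, tile_grid

print(tile_and_focal(1920, 1080, 0.7, 0.4))  # ~ (1371.4, 1350.0, (120, 68))
```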
submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h ADDED
@@ -0,0 +1,88 @@
+ /*
+  * Copyright (C) 2023, Inria
+  * GRAPHDECO research group, https://team.inria.fr/graphdeco
+  * All rights reserved.
+  *
+  * This software is free for non-commercial, research and evaluation use
+  * under the terms of the LICENSE.md file.
+  *
+  * For inquiries contact george.drettakis@inria.fr
+  */
+ 
+ #ifndef CUDA_RASTERIZER_H_INCLUDED
+ #define CUDA_RASTERIZER_H_INCLUDED
+ 
+ #include <vector>
+ #include <functional>
+ 
+ namespace CudaRasterizer
+ {
+ 	class Rasterizer
+ 	{
+ 	public:
+ 
+ 		static void markVisible(
+ 			int P,
+ 			float* means3D,
+ 			float* viewmatrix,
+ 			float* projmatrix,
+ 			bool* present);
+ 
+ 		static int forward(
+ 			std::function<char* (size_t)> geometryBuffer,
+ 			std::function<char* (size_t)> binningBuffer,
+ 			std::function<char* (size_t)> imageBuffer,
+ 			const int P, int D, int M,
+ 			const float* background,
+ 			const int width, int height,
+ 			const float* means3D,
+ 			const float* shs,
+ 			const float* colors_precomp,
+ 			const float* opacities,
+ 			const float* scales,
+ 			const float scale_modifier,
+ 			const float* rotations,
+ 			const float* cov3D_precomp,
+ 			const float* viewmatrix,
+ 			const float* projmatrix,
+ 			const float* cam_pos,
+ 			const float tan_fovx, float tan_fovy,
+ 			const bool prefiltered,
+ 			float* out_color,
+ 			int* radii = nullptr,
+ 			bool debug = false);
+ 
+ 		static void backward(
+ 			const int P, int D, int M, int R,
+ 			const float* background,
+ 			const int width, int height,
+ 			const float* means3D,
+ 			const float* shs,
+ 			const float* colors_precomp,
+ 			const float* scales,
+ 			const float scale_modifier,
+ 			const float* rotations,
+ 			const float* cov3D_precomp,
+ 			const float* viewmatrix,
+ 			const float* projmatrix,
+ 			const float* campos,
+ 			const float tan_fovx, float tan_fovy,
+ 			const int* radii,
+ 			char* geom_buffer,
+ 			char* binning_buffer,
+ 			char* image_buffer,
+ 			const float* dL_dpix,
+ 			float* dL_dmean2D,
+ 			float* dL_dconic,
+ 			float* dL_dopacity,
+ 			float* dL_dcolor,
+ 			float* dL_dmean3D,
+ 			float* dL_dcov3D,
+ 			float* dL_dsh,
+ 			float* dL_dscale,
+ 			float* dL_drot,
+ 			bool debug);
+ 	};
+ };
+ 
+ #endif
submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu ADDED
@@ -0,0 +1,434 @@
+ /*
+  * Copyright (C) 2023, Inria
+  * GRAPHDECO research group, https://team.inria.fr/graphdeco
+  * All rights reserved.
+  *
+  * This software is free for non-commercial, research and evaluation use
+  * under the terms of the LICENSE.md file.
+  *
+  * For inquiries contact george.drettakis@inria.fr
+  */
+ 
+ #include "rasterizer_impl.h"
+ #include <iostream>
+ #include <fstream>
+ #include <algorithm>
+ #include <numeric>
+ #include <cuda.h>
+ #include "cuda_runtime.h"
+ #include "device_launch_parameters.h"
+ #include <cub/cub.cuh>
+ #include <cub/device/device_radix_sort.cuh>
+ #define GLM_FORCE_CUDA
+ #include <glm/glm.hpp>
+ 
+ #include <cooperative_groups.h>
+ #include <cooperative_groups/reduce.h>
+ namespace cg = cooperative_groups;
+ 
+ #include "auxiliary.h"
+ #include "forward.h"
+ #include "backward.h"
+ 
+ // Helper function to find the next-highest bit of the MSB
+ // on the CPU.
+ uint32_t getHigherMsb(uint32_t n)
+ {
+ 	uint32_t msb = sizeof(n) * 4;
+ 	uint32_t step = msb;
+ 	while (step > 1)
+ 	{
+ 		step /= 2;
+ 		if (n >> msb)
+ 			msb += step;
+ 		else
+ 			msb -= step;
+ 	}
+ 	if (n >> msb)
+ 		msb++;
+ 	return msb;
+ }
+ 
+ // Wrapper method to call auxiliary coarse frustum containment test.
+ // Mark all Gaussians that pass it.
+ __global__ void checkFrustum(int P,
+ 	const float* orig_points,
+ 	const float* viewmatrix,
+ 	const float* projmatrix,
+ 	bool* present)
+ {
+ 	auto idx = cg::this_grid().thread_rank();
+ 	if (idx >= P)
+ 		return;
+ 
+ 	float3 p_view;
+ 	present[idx] = in_frustum(idx, orig_points, viewmatrix, projmatrix, false, p_view);
+ }
+ 
+ // Generates one key/value pair for all Gaussian / tile overlaps.
+ // Run once per Gaussian (1:N mapping).
+ __global__ void duplicateWithKeys(
+ 	int P,
+ 	const float2* points_xy,
+ 	const float* depths,
+ 	const uint32_t* offsets,
+ 	uint64_t* gaussian_keys_unsorted,
+ 	uint32_t* gaussian_values_unsorted,
+ 	int* radii,
+ 	dim3 grid)
+ {
+ 	auto idx = cg::this_grid().thread_rank();
+ 	if (idx >= P)
+ 		return;
+ 
+ 	// Generate no key/value pair for invisible Gaussians
+ 	if (radii[idx] > 0)
+ 	{
+ 		// Find this Gaussian's offset in buffer for writing keys/values.
+ 		uint32_t off = (idx == 0) ? 0 : offsets[idx - 1];
+ 		uint2 rect_min, rect_max;
+ 
+ 		getRect(points_xy[idx], radii[idx], rect_min, rect_max, grid);
+ 
+ 		// For each tile that the bounding rect overlaps, emit a
+ 		// key/value pair. The key is | tile ID | depth |,
+ 		// and the value is the ID of the Gaussian. Sorting the values
+ 		// with this key yields Gaussian IDs in a list, such that they
+ 		// are first sorted by tile and then by depth.
+ 		for (int y = rect_min.y; y < rect_max.y; y++)
+ 		{
+ 			for (int x = rect_min.x; x < rect_max.x; x++)
+ 			{
+ 				uint64_t key = y * grid.x + x;
+ 				key <<= 32;
+ 				key |= *((uint32_t*)&depths[idx]);
+ 				gaussian_keys_unsorted[off] = key;
+ 				gaussian_values_unsorted[off] = idx;
+ 				off++;
+ 			}
+ 		}
+ 	}
+ }
+ 
+ // Check each key to see whether it marks the start/end of one tile's range in
+ // the full sorted list. If so, write the start/end of this tile.
+ // Run once per instanced (duplicated) Gaussian ID.
+ __global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* ranges)
+ {
+ 	auto idx = cg::this_grid().thread_rank();
+ 	if (idx >= L)
+ 		return;
+ 
+ 	// Read tile ID from key. Update start/end of tile range if at limit.
+ 	uint64_t key = point_list_keys[idx];
+ 	uint32_t currtile = key >> 32;
+ 	if (idx == 0)
+ 		ranges[currtile].x = 0;
+ 	else
+ 	{
+ 		uint32_t prevtile = point_list_keys[idx - 1] >> 32;
+ 		if (currtile != prevtile)
+ 		{
+ 			ranges[prevtile].y = idx;
+ 			ranges[currtile].x = idx;
+ 		}
+ 	}
+ 	if (idx == L - 1)
+ 		ranges[currtile].y = L;
+ }
+ 
+ // Mark Gaussians as visible/invisible, based on view frustum testing
+ void CudaRasterizer::Rasterizer::markVisible(
+ 	int P,
+ 	float* means3D,
+ 	float* viewmatrix,
+ 	float* projmatrix,
+ 	bool* present)
+ {
+ 	checkFrustum << <(P + 255) / 256, 256 >> > (
+ 		P,
+ 		means3D,
+ 		viewmatrix, projmatrix,
+ 		present);
+ }
+ 
+ CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& chunk, size_t P)
+ {
+ 	GeometryState geom;
+ 	obtain(chunk, geom.depths, P, 128);
+ 	obtain(chunk, geom.clamped, P * 3, 128);
+ 	obtain(chunk, geom.internal_radii, P, 128);
+ 	obtain(chunk, geom.means2D, P, 128);
+ 	obtain(chunk, geom.cov3D, P * 6, 128);
+ 	obtain(chunk, geom.conic_opacity, P, 128);
+ 	obtain(chunk, geom.rgb, P * 3, 128);
+ 	obtain(chunk, geom.tiles_touched, P, 128);
+ 	cub::DeviceScan::InclusiveSum(nullptr, geom.scan_size, geom.tiles_touched, geom.tiles_touched, P);
+ 	obtain(chunk, geom.scanning_space, geom.scan_size, 128);
+ 	obtain(chunk, geom.point_offsets, P, 128);
+ 	return geom;
+ }
+ 
+ CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, size_t N)
+ {
+ 	ImageState img;
+ 	obtain(chunk, img.accum_alpha, N, 128);
+ 	obtain(chunk, img.n_contrib, N, 128);
+ 	obtain(chunk, img.ranges, N, 128);
+ 	return img;
+ }
+ 
+ CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chunk, size_t P)
+ {
+ 	BinningState binning;
+ 	obtain(chunk, binning.point_list, P, 128);
+ 	obtain(chunk, binning.point_list_unsorted, P, 128);
+ 	obtain(chunk, binning.point_list_keys, P, 128);
+ 	obtain(chunk, binning.point_list_keys_unsorted, P, 128);
+ 	cub::DeviceRadixSort::SortPairs(
+ 		nullptr, binning.sorting_size,
+ 		binning.point_list_keys_unsorted, binning.point_list_keys,
+ 		binning.point_list_unsorted, binning.point_list, P);
+ 	obtain(chunk, binning.list_sorting_space, binning.sorting_size, 128);
+ 	return binning;
+ }
+ 
+ // Forward rendering procedure for differentiable rasterization
+ // of Gaussians.
+ int CudaRasterizer::Rasterizer::forward(
+ 	std::function<char* (size_t)> geometryBuffer,
+ 	std::function<char* (size_t)> binningBuffer,
+ 	std::function<char* (size_t)> imageBuffer,
+ 	const int P, int D, int M,
+ 	const float* background,
+ 	const int width, int height,
+ 	const float* means3D,
+ 	const float* shs,
+ 	const float* colors_precomp,
+ 	const float* opacities,
+ 	const float* scales,
+ 	const float scale_modifier,
+ 	const float* rotations,
+ 	const float* cov3D_precomp,
+ 	const float* viewmatrix,
+ 	const float* projmatrix,
+ 	const float* cam_pos,
+ 	const float tan_fovx, float tan_fovy,
+ 	const bool prefiltered,
+ 	float* out_color,
+ 	int* radii,
+ 	bool debug)
+ {
+ 	const float focal_y = height / (2.0f * tan_fovy);
+ 	const float focal_x = width / (2.0f * tan_fovx);
+ 
+ 	size_t chunk_size = required<GeometryState>(P);
+ 	char* chunkptr = geometryBuffer(chunk_size);
+ 	GeometryState geomState = GeometryState::fromChunk(chunkptr, P);
+ 
+ 	if (radii == nullptr)
+ 	{
+ 		radii = geomState.internal_radii;
+ 	}
+ 
+ 	dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
+ 	dim3 block(BLOCK_X, BLOCK_Y, 1);
+ 
+ 	// Dynamically resize image-based auxiliary buffers during training
+ 	size_t img_chunk_size = required<ImageState>(width * height);
+ 	char* img_chunkptr = imageBuffer(img_chunk_size);
+ 	ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height);
+ 
+ 	if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
+ 	{
+ 		throw std::runtime_error("For non-RGB, provide precomputed Gaussian colors!");
+ 	}
+ 
+ 	// Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs to RGB)
+ 	CHECK_CUDA(FORWARD::preprocess(
+ 		P, D, M,
+ 		means3D,
+ 		(glm::vec3*)scales,
+ 		scale_modifier,
+ 		(glm::vec4*)rotations,
+ 		opacities,
+ 		shs,
+ 		geomState.clamped,
+ 		cov3D_precomp,
+ 		colors_precomp,
+ 		viewmatrix, projmatrix,
+ 		(glm::vec3*)cam_pos,
+ 		width, height,
+ 		focal_x, focal_y,
+ 		tan_fovx, tan_fovy,
+ 		radii,
+ 		geomState.means2D,
+ 		geomState.depths,
+ 		geomState.cov3D,
+ 		geomState.rgb,
+ 		geomState.conic_opacity,
+ 		tile_grid,
+ 		geomState.tiles_touched,
+ 		prefiltered
+ 	), debug)
+ 
+ 	// Compute prefix sum over full list of touched tile counts by Gaussians
+ 	// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
+ 	CHECK_CUDA(cub::DeviceScan::InclusiveSum(geomState.scanning_space, geomState.scan_size, geomState.tiles_touched, geomState.point_offsets, P), debug)
+ 
+ 	// Retrieve total number of Gaussian instances to launch and resize aux buffers
+ 	int num_rendered;
+ 	CHECK_CUDA(cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost), debug);
+ 
+ 	size_t binning_chunk_size = required<BinningState>(num_rendered);
+ 	char* binning_chunkptr = binningBuffer(binning_chunk_size);
+ 	BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered);
+ 
+ 	// For each instance to be rendered, produce adequate [ tile | depth ] key
+ 	// and corresponding duplicated Gaussian indices to be sorted
+ 	duplicateWithKeys << <(P + 255) / 256, 256 >> > (
+ 		P,
+ 		geomState.means2D,
+ 		geomState.depths,
+ 		geomState.point_offsets,
+ 		binningState.point_list_keys_unsorted,
+ 		binningState.point_list_unsorted,
+ 		radii,
+ 		tile_grid)
+ 	CHECK_CUDA(, debug)
+ 
+ 	int bit = getHigherMsb(tile_grid.x * tile_grid.y);
+ 
+ 	// Sort complete list of (duplicated) Gaussian indices by keys
+ 	CHECK_CUDA(cub::DeviceRadixSort::SortPairs(
+ 		binningState.list_sorting_space,
+ 		binningState.sorting_size,
+ 		binningState.point_list_keys_unsorted, binningState.point_list_keys,
+ 		binningState.point_list_unsorted, binningState.point_list,
+ 		num_rendered, 0, 32 + bit), debug)
+ 
+ 	CHECK_CUDA(cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2)), debug);
+ 
+ 	// Identify start and end of per-tile workloads in sorted list
+ 	if (num_rendered > 0)
+ 		identifyTileRanges << <(num_rendered + 255) / 256, 256 >> > (
+ 			num_rendered,
+ 			binningState.point_list_keys,
+ 			imgState.ranges);
+ 	CHECK_CUDA(, debug)
+ 
+ 	// Let each tile blend its range of Gaussians independently in parallel
+ 	const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : geomState.rgb;
+ 	CHECK_CUDA(FORWARD::render(
+ 		tile_grid, block,
+ 		imgState.ranges,
+ 		binningState.point_list,
+ 		width, height,
+ 		geomState.means2D,
+ 		feature_ptr,
+ 		geomState.conic_opacity,
+ 		imgState.accum_alpha,
+ 		imgState.n_contrib,
+ 		background,
+ 		out_color), debug)
+ 
+ 	return num_rendered;
+ }
+ 
+ // Produce necessary gradients for optimization, corresponding
+ // to forward render pass
+ void CudaRasterizer::Rasterizer::backward(
+ 	const int P, int D, int M, int R,
+ 	const float* background,
+ 	const int width, int height,
+ 	const float* means3D,
+ 	const float* shs,
+ 	const float* colors_precomp,
+ 	const float* scales,
+ 	const float scale_modifier,
+ 	const float* rotations,
+ 	const float* cov3D_precomp,
+ 	const float* viewmatrix,
+ 	const float* projmatrix,
+ 	const float* campos,
+ 	const float tan_fovx, float tan_fovy,
+ 	const int* radii,
+ 	char* geom_buffer,
+ 	char* binning_buffer,
+ 	char* img_buffer,
+ 	const float* dL_dpix,
+ 	float* dL_dmean2D,
+ 	float* dL_dconic,
+ 	float* dL_dopacity,
+ 	float* dL_dcolor,
+ 	float* dL_dmean3D,
+ 	float* dL_dcov3D,
+ 	float* dL_dsh,
+ 	float* dL_dscale,
+ 	float* dL_drot,
+ 	bool debug)
+ {
+ 	GeometryState geomState = GeometryState::fromChunk(geom_buffer, P);
+ 	BinningState binningState = BinningState::fromChunk(binning_buffer, R);
+ 	ImageState imgState = ImageState::fromChunk(img_buffer, width * height);
+ 
+ 	if (radii == nullptr)
+ 	{
+ 		radii = geomState.internal_radii;
+ 	}
+ 
+ 	const float focal_y = height / (2.0f * tan_fovy);
+ 	const float focal_x = width / (2.0f * tan_fovx);
+ 
+ 	const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
+ 	const dim3 block(BLOCK_X, BLOCK_Y, 1);
+ 
+ 	// Compute loss gradients w.r.t. 2D mean position, conic matrix,
+ 	// opacity and RGB of Gaussians from per-pixel loss gradients.
+ 	// If we were given precomputed colors and not SHs, use them.
+ 	const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb;
+ 	CHECK_CUDA(BACKWARD::render(
+ 		tile_grid,
+ 		block,
+ 		imgState.ranges,
+ 		binningState.point_list,
+ 		width, height,
+ 		background,
+ 		geomState.means2D,
+ 		geomState.conic_opacity,
+ 		color_ptr,
+ 		imgState.accum_alpha,
+ 		imgState.n_contrib,
+ 		dL_dpix,
+ 		(float3*)dL_dmean2D,
+ 		(float4*)dL_dconic,
+ 		dL_dopacity,
+ 		dL_dcolor), debug)
+ 
+ 	// Take care of the rest of preprocessing. Was the precomputed covariance
+ 	// given to us or a scales/rot pair? If precomputed, pass that. If not,
+ 	// use the one we computed ourselves.
+ 	const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : geomState.cov3D;
+ 	CHECK_CUDA(BACKWARD::preprocess(P, D, M,
+ 		(float3*)means3D,
+ 		radii,
+ 		shs,
+ 		geomState.clamped,
+ 		(glm::vec3*)scales,
+ 		(glm::vec4*)rotations,
+ 		scale_modifier,
+ 		cov3D_ptr,
+ 		viewmatrix,
+ 		projmatrix,
+ 		focal_x, focal_y,
+ 		tan_fovx, tan_fovy,
+ 		(glm::vec3*)campos,
+ 		(float3*)dL_dmean2D,
+ 		dL_dconic,
+ 		(glm::vec3*)dL_dmean3D,
+ 		dL_dcolor,
+ 		dL_dcov3D,
+ 		dL_dsh,
+ 		(glm::vec3*)dL_dscale,
+ 		(glm::vec4*)dL_drot), debug)
+ }
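A compact Python restatement of the binning pipeline above (duplicateWithKeys, the radix sort, then identifyTileRanges), with the float32 depth reinterpreted as its raw bit pattern exactly as `*((uint32_t*)&depths[idx])` does. Depth ordering via integer comparison relies on depths being non-negative IEEE-754 floats; the toy instance list is made up for illustration:

```python
import struct

def make_key(tile_id: int, depth: float) -> int:
    # Reinterpret float32 depth as raw bits; key layout is | tile ID | depth |.
    depth_bits = struct.unpack("<I", struct.pack("<f", depth))[0]
    return (tile_id << 32) | depth_bits

# Three duplicated Gaussian instances as (tile, depth) pairs.
instances = [(1, 0.7), (0, 2.0), (1, 0.3)]
pairs = sorted((make_key(t, d), gid) for gid, (t, d) in enumerate(instances))

# identifyTileRanges: tile t owns sorted_list[ranges[t][0]:ranges[t][1]].
ranges = {}
for i, (key, _) in enumerate(pairs):
    ranges.setdefault(key >> 32, [i, i])[1] = i + 1

print(ranges, [gid for _, gid in pairs])
# -> {0: [0, 1], 1: [1, 3]} [1, 2, 0]: per tile, front-to-back by depth.
# getHigherMsb(n) above matches Python's n.bit_length(), so the CUDA radix
# sort only has to inspect 32 + bit_length(num_tiles) key bits.
```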
submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h ADDED
@@ -0,0 +1,74 @@
+ /*
+  * Copyright (C) 2023, Inria
+  * GRAPHDECO research group, https://team.inria.fr/graphdeco
+  * All rights reserved.
+  *
+  * This software is free for non-commercial, research and evaluation use
+  * under the terms of the LICENSE.md file.
+  *
+  * For inquiries contact george.drettakis@inria.fr
+  */
+ 
+ #pragma once
+ 
+ #include <iostream>
+ #include <vector>
+ #include "rasterizer.h"
+ #include <cuda_runtime_api.h>
+ 
+ namespace CudaRasterizer
+ {
+ 	template <typename T>
+ 	static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment)
+ 	{
+ 		std::size_t offset = (reinterpret_cast<std::uintptr_t>(chunk) + alignment - 1) & ~(alignment - 1);
+ 		ptr = reinterpret_cast<T*>(offset);
+ 		chunk = reinterpret_cast<char*>(ptr + count);
+ 	}
+ 
+ 	struct GeometryState
+ 	{
+ 		size_t scan_size;
+ 		float* depths;
+ 		char* scanning_space;
+ 		bool* clamped;
+ 		int* internal_radii;
+ 		float2* means2D;
+ 		float* cov3D;
+ 		float4* conic_opacity;
+ 		float* rgb;
+ 		uint32_t* point_offsets;
+ 		uint32_t* tiles_touched;
+ 
+ 		static GeometryState fromChunk(char*& chunk, size_t P);
+ 	};
+ 
+ 	struct ImageState
+ 	{
+ 		uint2* ranges;
+ 		uint32_t* n_contrib;
+ 		float* accum_alpha;
+ 
+ 		static ImageState fromChunk(char*& chunk, size_t N);
+ 	};
+ 
+ 	struct BinningState
+ 	{
+ 		size_t sorting_size;
+ 		uint64_t* point_list_keys_unsorted;
+ 		uint64_t* point_list_keys;
+ 		uint32_t* point_list_unsorted;
+ 		uint32_t* point_list;
+ 		char* list_sorting_space;
+ 
+ 		static BinningState fromChunk(char*& chunk, size_t P);
+ 	};
+ 
+ 	template<typename T>
+ 	size_t required(size_t P)
+ 	{
+ 		char* size = nullptr;
+ 		T::fromChunk(size, P);
+ 		return ((size_t)size) + 128;
+ 	}
+ };
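obtain/required<T> implement a measure-then-allocate scheme: fromChunk is first run with a null base pointer so that the advanced pointer doubles as a byte count (required<T> adds 128 bytes of slack), then run again on real storage to hand out aligned sub-allocations. A Python sketch of the same pointer arithmetic, with illustrative sizes:

```python
def obtain(offset: int, count: int, itemsize: int, alignment: int = 128):
    start = (offset + alignment - 1) & ~(alignment - 1)  # round up to alignment
    return start, start + count * itemsize               # (ptr, advanced chunk)

def required(fields):  # fields: [(count, itemsize), ...] in fromChunk order
    offset = 0
    for count, itemsize in fields:
        _, offset = obtain(offset, count, itemsize)
    return offset + 128  # same slack as required<T>() above

# ImageState for a 640x480 image: accum_alpha (float32), n_contrib (uint32),
# ranges (uint2, 8 bytes), one entry each per pixel.
print(required([(640 * 480, 4), (640 * 480, 4), (640 * 480, 8)]))  # 4915328
```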
submodules/diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py ADDED
@@ -0,0 +1,221 @@
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact george.drettakis@inria.fr
+ #
+ 
+ from typing import NamedTuple
+ import torch.nn as nn
+ import torch
+ from . import _C
+ 
+ def cpu_deep_copy_tuple(input_tuple):
+     copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple]
+     return tuple(copied_tensors)
+ 
+ def rasterize_gaussians(
+     means3D,
+     means2D,
+     sh,
+     colors_precomp,
+     opacities,
+     scales,
+     rotations,
+     cov3Ds_precomp,
+     raster_settings,
+ ):
+     return _RasterizeGaussians.apply(
+         means3D,
+         means2D,
+         sh,
+         colors_precomp,
+         opacities,
+         scales,
+         rotations,
+         cov3Ds_precomp,
+         raster_settings,
+     )
+ 
+ class _RasterizeGaussians(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx,
+         means3D,
+         means2D,
+         sh,
+         colors_precomp,
+         opacities,
+         scales,
+         rotations,
+         cov3Ds_precomp,
+         raster_settings,
+     ):
+ 
+         # Restructure arguments the way that the C++ lib expects them
+         args = (
+             raster_settings.bg,
+             means3D,
+             colors_precomp,
+             opacities,
+             scales,
+             rotations,
+             raster_settings.scale_modifier,
+             cov3Ds_precomp,
+             raster_settings.viewmatrix,
+             raster_settings.projmatrix,
+             raster_settings.tanfovx,
+             raster_settings.tanfovy,
+             raster_settings.image_height,
+             raster_settings.image_width,
+             sh,
+             raster_settings.sh_degree,
+             raster_settings.campos,
+             raster_settings.prefiltered,
+             raster_settings.debug
+         )
+ 
+         # Invoke C++/CUDA rasterizer
+         if raster_settings.debug:
+             cpu_args = cpu_deep_copy_tuple(args)  # Copy them before they can be corrupted
+             try:
+                 num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
+             except Exception as ex:
+                 torch.save(cpu_args, "snapshot_fw.dump")
+                 print("\nAn error occurred in forward. Please forward snapshot_fw.dump for debugging.")
+                 raise ex
+         else:
+             num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
+ 
+         # Keep relevant tensors for backward
+         ctx.raster_settings = raster_settings
+         ctx.num_rendered = num_rendered
+         ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
+         return color, radii
+ 
+     @staticmethod
+     def backward(ctx, grad_out_color, _):
+ 
+         # Restore necessary values from context
+         num_rendered = ctx.num_rendered
+         raster_settings = ctx.raster_settings
+         colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors
+ 
+         # Restructure args as C++ method expects them
+         args = (raster_settings.bg,
+                 means3D,
+                 radii,
+                 colors_precomp,
+                 scales,
+                 rotations,
+                 raster_settings.scale_modifier,
+                 cov3Ds_precomp,
+                 raster_settings.viewmatrix,
+                 raster_settings.projmatrix,
+                 raster_settings.tanfovx,
+                 raster_settings.tanfovy,
+                 grad_out_color,
+                 sh,
+                 raster_settings.sh_degree,
+                 raster_settings.campos,
+                 geomBuffer,
+                 num_rendered,
+                 binningBuffer,
+                 imgBuffer,
+                 raster_settings.debug)
+ 
+         # Compute gradients for relevant tensors by invoking backward method
+         if raster_settings.debug:
+             cpu_args = cpu_deep_copy_tuple(args)  # Copy them before they can be corrupted
+             try:
+                 grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
+             except Exception as ex:
+                 torch.save(cpu_args, "snapshot_bw.dump")
+                 print("\nAn error occurred in backward. Writing snapshot_bw.dump for debugging.\n")
+                 raise ex
+         else:
+             grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
+ 
+         grads = (
+             grad_means3D,
+             grad_means2D,
+             grad_sh,
+             grad_colors_precomp,
+             grad_opacities,
+             grad_scales,
+             grad_rotations,
+             grad_cov3Ds_precomp,
+             None,
+         )
+ 
+         return grads
+ 
+ class GaussianRasterizationSettings(NamedTuple):
+     image_height: int
+     image_width: int
+     tanfovx: float
+     tanfovy: float
+     bg: torch.Tensor
+     scale_modifier: float
+     viewmatrix: torch.Tensor
+     projmatrix: torch.Tensor
+     sh_degree: int
+     campos: torch.Tensor
+     prefiltered: bool
+     debug: bool
+ 
+ class GaussianRasterizer(nn.Module):
+     def __init__(self, raster_settings):
+         super().__init__()
+         self.raster_settings = raster_settings
+ 
+     def markVisible(self, positions):
+         # Mark visible points (based on frustum culling for camera) with a boolean
+         with torch.no_grad():
+             raster_settings = self.raster_settings
+             visible = _C.mark_visible(
+                 positions,
+                 raster_settings.viewmatrix,
+                 raster_settings.projmatrix)
+ 
+         return visible
+ 
+     def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
+ 
+         raster_settings = self.raster_settings
+ 
+         if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
+             raise Exception('Please provide exactly one of either SHs or precomputed colors!')
+ 
+         if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None):
+             raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')
+ 
+         if shs is None:
+             shs = torch.Tensor([])
+         if colors_precomp is None:
+             colors_precomp = torch.Tensor([])
+ 
+         if scales is None:
+             scales = torch.Tensor([])
+         if rotations is None:
+             rotations = torch.Tensor([])
+         if cov3D_precomp is None:
+             cov3D_precomp = torch.Tensor([])
+ 
+         # Invoke C++/CUDA rasterization routine
+         return rasterize_gaussians(
+             means3D,
+             means2D,
+             shs,
+             colors_precomp,
+             opacities,
+             scales,
+             rotations,
+             cov3D_precomp,
+             raster_settings,
+         )
@@ -0,0 +1,19 @@
 
+ /*
+  * Copyright (C) 2023, Inria
+  * GRAPHDECO research group, https://team.inria.fr/graphdeco
+  * All rights reserved.
+  *
+  * This software is free for non-commercial, research and evaluation use
+  * under the terms of the LICENSE.md file.
+  *
+  * For inquiries contact george.drettakis@inria.fr
+  */
+ 
+ #include <torch/extension.h>
+ #include "rasterize_points.h"
+ 
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ 	m.def("rasterize_gaussians", &RasterizeGaussiansCUDA);
+ 	m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA);
+ 	m.def("mark_visible", &markVisible);
+ }
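The module above binds exactly three entry points on the compiled `_C` extension; a quick post-build sanity check (assuming the package has been built and installed):

```python
import torch  # import first so the CUDA runtime is loaded
from diff_gaussian_rasterization import _C

assert callable(_C.rasterize_gaussians)
assert callable(_C.rasterize_gaussians_backward)
assert callable(_C.mark_visible)
```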
submodules/diff-gaussian-rasterization/rasterize_points.cu ADDED
@@ -0,0 +1,217 @@
+ /*
+  * Copyright (C) 2023, Inria
+  * GRAPHDECO research group, https://team.inria.fr/graphdeco
+  * All rights reserved.
+  *
+  * This software is free for non-commercial, research and evaluation use
+  * under the terms of the LICENSE.md file.
+  *
+  * For inquiries contact george.drettakis@inria.fr
+  */
+ 
+ #include <math.h>
+ #include <torch/extension.h>
+ #include <cstdio>
+ #include <sstream>
+ #include <iostream>
+ #include <tuple>
+ #include <stdio.h>
+ #include <cuda_runtime_api.h>
+ #include <memory>
+ #include "cuda_rasterizer/config.h"
+ #include "cuda_rasterizer/rasterizer.h"
+ #include <fstream>
+ #include <string>
+ #include <functional>
+ 
+ std::function<char*(size_t N)> resizeFunctional(torch::Tensor& t) {
+ 	auto lambda = [&t](size_t N) {
+ 		t.resize_({(long long)N});
+ 		return reinterpret_cast<char*>(t.contiguous().data_ptr());
+ 	};
+ 	return lambda;
+ }
+ 
+ std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+ RasterizeGaussiansCUDA(
+ 	const torch::Tensor& background,
+ 	const torch::Tensor& means3D,
+ 	const torch::Tensor& colors,
+ 	const torch::Tensor& opacity,
+ 	const torch::Tensor& scales,
+ 	const torch::Tensor& rotations,
+ 	const float scale_modifier,
+ 	const torch::Tensor& cov3D_precomp,
+ 	const torch::Tensor& viewmatrix,
+ 	const torch::Tensor& projmatrix,
+ 	const float tan_fovx,
+ 	const float tan_fovy,
+ 	const int image_height,
+ 	const int image_width,
+ 	const torch::Tensor& sh,
+ 	const int degree,
+ 	const torch::Tensor& campos,
+ 	const bool prefiltered,
+ 	const bool debug)
+ {
+ 	if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
+ 		AT_ERROR("means3D must have dimensions (num_points, 3)");
+ 	}
+ 
+ 	const int P = means3D.size(0);
+ 	const int H = image_height;
+ 	const int W = image_width;
+ 
+ 	auto int_opts = means3D.options().dtype(torch::kInt32);
+ 	auto float_opts = means3D.options().dtype(torch::kFloat32);
+ 
+ 	torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
+ 	torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
+ 
+ 	torch::Device device(torch::kCUDA);
+ 	torch::TensorOptions options(torch::kByte);
+ 	torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
+ 	torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
+ 	torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
+ 	std::function<char*(size_t)> geomFunc = resizeFunctional(geomBuffer);
+ 	std::function<char*(size_t)> binningFunc = resizeFunctional(binningBuffer);
+ 	std::function<char*(size_t)> imgFunc = resizeFunctional(imgBuffer);
+ 
+ 	int rendered = 0;
+ 	if(P != 0)
+ 	{
+ 		int M = 0;
+ 		if(sh.size(0) != 0)
+ 		{
+ 			M = sh.size(1);
+ 		}
+ 
+ 		rendered = CudaRasterizer::Rasterizer::forward(
+ 			geomFunc,
+ 			binningFunc,
+ 			imgFunc,
+ 			P, degree, M,
+ 			background.contiguous().data<float>(),
+ 			W, H,
+ 			means3D.contiguous().data<float>(),
+ 			sh.contiguous().data_ptr<float>(),
+ 			colors.contiguous().data<float>(),
+ 			opacity.contiguous().data<float>(),
+ 			scales.contiguous().data_ptr<float>(),
+ 			scale_modifier,
+ 			rotations.contiguous().data_ptr<float>(),
+ 			cov3D_precomp.contiguous().data<float>(),
+ 			viewmatrix.contiguous().data<float>(),
+ 			projmatrix.contiguous().data<float>(),
+ 			campos.contiguous().data<float>(),
+ 			tan_fovx,
+ 			tan_fovy,
+ 			prefiltered,
+ 			out_color.contiguous().data<float>(),
+ 			radii.contiguous().data<int>(),
+ 			debug);
+ 	}
+ 	return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer);
+ }
+ 
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+ RasterizeGaussiansBackwardCUDA(
+ 	const torch::Tensor& background,
+ 	const torch::Tensor& means3D,
+ 	const torch::Tensor& radii,
+ 	const torch::Tensor& colors,
+ 	const torch::Tensor& scales,
+ 	const torch::Tensor& rotations,
+ 	const float scale_modifier,
+ 	const torch::Tensor& cov3D_precomp,
+ 	const torch::Tensor& viewmatrix,
+ 	const torch::Tensor& projmatrix,
+ 	const float tan_fovx,
+ 	const float tan_fovy,
+ 	const torch::Tensor& dL_dout_color,
+ 	const torch::Tensor& sh,
+ 	const int degree,
+ 	const torch::Tensor& campos,
+ 	const torch::Tensor& geomBuffer,
+ 	const int R,
+ 	const torch::Tensor& binningBuffer,
+ 	const torch::Tensor& imageBuffer,
+ 	const bool debug)
+ {
+ 	const int P = means3D.size(0);
+ 	const int H = dL_dout_color.size(1);
+ 	const int W = dL_dout_color.size(2);
+ 
+ 	int M = 0;
+ 	if(sh.size(0) != 0)
+ 	{
+ 		M = sh.size(1);
+ 	}
+ 
+ 	torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options());
+ 	torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options());
+ 	torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options());
+ 	torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options());
+ 	torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options());
+ 	torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options());
+ 	torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options());
+ 	torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options());
+ 	torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options());
+ 
+ 	if(P != 0)
+ 	{
+ 		CudaRasterizer::Rasterizer::backward(P, degree, M, R,
+ 			background.contiguous().data<float>(),
+ 			W, H,
+ 			means3D.contiguous().data<float>(),
+ 			sh.contiguous().data<float>(),
+ 			colors.contiguous().data<float>(),
+ 			scales.data_ptr<float>(),
+ 			scale_modifier,
+ 			rotations.data_ptr<float>(),
+ 			cov3D_precomp.contiguous().data<float>(),
+ 			viewmatrix.contiguous().data<float>(),
+ 			projmatrix.contiguous().data<float>(),
+ 			campos.contiguous().data<float>(),
+ 			tan_fovx,
+ 			tan_fovy,
+ 			radii.contiguous().data<int>(),
+ 			reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()),
+ 			reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
+ 			reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
+ 			dL_dout_color.contiguous().data<float>(),
+ 			dL_dmeans2D.contiguous().data<float>(),
+ 			dL_dconic.contiguous().data<float>(),
+ 			dL_dopacity.contiguous().data<float>(),
+ 			dL_dcolors.contiguous().data<float>(),
+ 			dL_dmeans3D.contiguous().data<float>(),
+ 			dL_dcov3D.contiguous().data<float>(),
+ 			dL_dsh.contiguous().data<float>(),
+ 			dL_dscales.contiguous().data<float>(),
+ 			dL_drotations.contiguous().data<float>(),
+ 			debug);
+ 	}
+ 
+ 	return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, dL_drotations);
+ }
+ 
+ torch::Tensor markVisible(
+ 	torch::Tensor& means3D,
+ 	torch::Tensor& viewmatrix,
+ 	torch::Tensor& projmatrix)
+ {
+ 	const int P = means3D.size(0);
+ 
+ 	torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool));
+ 
+ 	if(P != 0)
+ 	{
+ 		CudaRasterizer::Rasterizer::markVisible(P,
+ 			means3D.contiguous().data<float>(),
+ 			viewmatrix.contiguous().data<float>(),
+ 			projmatrix.contiguous().data<float>(),
+ 			present.contiguous().data<bool>());
+ 	}
+ 
+ 	return present;
+ }
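resizeFunctional at the top of this file is what lets the rasterizer size its own scratch space: a closure that captures a byte tensor, grows it on demand, and returns the raw storage address as the `char*` the CudaRasterizer callbacks expect. An illustrative Python analogue of that pattern:

```python
import torch

def resize_functional(t: torch.Tensor):
    def resize(n: int) -> int:
        t.resize_((n,))      # reallocate the captured buffer if needed
        return t.data_ptr()  # raw address of the contiguous storage
    return resize

geom = torch.empty(0, dtype=torch.uint8, device="cuda")
geom_func = resize_functional(geom)
ptr = geom_func(1024)        # geom now holds >= 1024 bytes; ptr is its address
```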
submodules/diff-gaussian-rasterization/rasterize_points.h ADDED
@@ -0,0 +1,67 @@
+ /*
+  * Copyright (C) 2023, Inria
+  * GRAPHDECO research group, https://team.inria.fr/graphdeco
+  * All rights reserved.
+  *
+  * This software is free for non-commercial, research and evaluation use
+  * under the terms of the LICENSE.md file.
+  *
+  * For inquiries contact george.drettakis@inria.fr
+  */
+ 
+ #pragma once
+ #include <torch/extension.h>
+ #include <cstdio>
+ #include <tuple>
+ #include <string>
+ 
+ std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+ RasterizeGaussiansCUDA(
+ 	const torch::Tensor& background,
+ 	const torch::Tensor& means3D,
+ 	const torch::Tensor& colors,
+ 	const torch::Tensor& opacity,
+ 	const torch::Tensor& scales,
+ 	const torch::Tensor& rotations,
+ 	const float scale_modifier,
+ 	const torch::Tensor& cov3D_precomp,
+ 	const torch::Tensor& viewmatrix,
+ 	const torch::Tensor& projmatrix,
+ 	const float tan_fovx,
+ 	const float tan_fovy,
+ 	const int image_height,
+ 	const int image_width,
+ 	const torch::Tensor& sh,
+ 	const int degree,
+ 	const torch::Tensor& campos,
+ 	const bool prefiltered,
+ 	const bool debug);
+ 
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+ RasterizeGaussiansBackwardCUDA(
+ 	const torch::Tensor& background,
+ 	const torch::Tensor& means3D,
+ 	const torch::Tensor& radii,
+ 	const torch::Tensor& colors,
+ 	const torch::Tensor& scales,
+ 	const torch::Tensor& rotations,
+ 	const float scale_modifier,
+ 	const torch::Tensor& cov3D_precomp,
+ 	const torch::Tensor& viewmatrix,
+ 	const torch::Tensor& projmatrix,
+ 	const float tan_fovx,
+ 	const float tan_fovy,
+ 	const torch::Tensor& dL_dout_color,
+ 	const torch::Tensor& sh,
+ 	const int degree,
+ 	const torch::Tensor& campos,
+ 	const torch::Tensor& geomBuffer,
+ 	const int R,
+ 	const torch::Tensor& binningBuffer,
+ 	const torch::Tensor& imageBuffer,
+ 	const bool debug);
+ 
+ torch::Tensor markVisible(
+ 	torch::Tensor& means3D,
+ 	torch::Tensor& viewmatrix,
+ 	torch::Tensor& projmatrix);
@@ -0,0 +1,34 @@
 
+ #
+ # Copyright (C) 2023, Inria
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
+ # All rights reserved.
+ #
+ # This software is free for non-commercial, research and evaluation use
+ # under the terms of the LICENSE.md file.
+ #
+ # For inquiries contact george.drettakis@inria.fr
+ #
+ 
+ from setuptools import setup
+ from torch.utils.cpp_extension import CUDAExtension, BuildExtension
+ import os
+ os.path.dirname(os.path.abspath(__file__))
+ 
+ setup(
+     name="diff_gaussian_rasterization",
+     packages=['diff_gaussian_rasterization'],
+     ext_modules=[
+         CUDAExtension(
+             name="diff_gaussian_rasterization._C",
+             sources=[
+                 "cuda_rasterizer/rasterizer_impl.cu",
+                 "cuda_rasterizer/forward.cu",
+                 "cuda_rasterizer/backward.cu",
+                 "rasterize_points.cu",
+                 "ext.cpp"],
+             extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]})
+     ],
+     cmdclass={
+         'build_ext': BuildExtension
+     }
+ )
submodules/diff-gaussian-rasterization/third_party/glm/.appveyor.yml ADDED
@@ -0,0 +1,92 @@
+ shallow_clone: true
+ 
+ platform:
+   - x86
+   - x64
+ 
+ configuration:
+   - Debug
+   - Release
+ 
+ image:
+   - Visual Studio 2013
+   - Visual Studio 2015
+   - Visual Studio 2017
+   - Visual Studio 2019
+ 
+ environment:
+   matrix:
+     - GLM_ARGUMENTS: -DGLM_TEST_FORCE_PURE=ON
+     - GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_SSE2=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON
+     - GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON
+     - GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_14=ON
+     - GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_17=ON
+ 
+ matrix:
+   exclude:
+     - image: Visual Studio 2013
+       GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON
+     - image: Visual Studio 2013
+       GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_14=ON
+     - image: Visual Studio 2013
+       GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_17=ON
+     - image: Visual Studio 2013
+       configuration: Debug
+     - image: Visual Studio 2015
+       GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_SSE2=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON
+     - image: Visual Studio 2015
+       GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_14=ON
+     - image: Visual Studio 2015
+       GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_17=ON
+     - image: Visual Studio 2015
+       platform: x86
+     - image: Visual Studio 2015
+       configuration: Debug
+     - image: Visual Studio 2017
+       platform: x86
+     - image: Visual Studio 2017
+       configuration: Debug
+     - image: Visual Studio 2019
+       platform: x64
+ 
+ branches:
+   only:
+     - master
+ 
+ before_build:
+   - ps: |
+       mkdir build
+       cd build
+ 
+       if ("$env:APPVEYOR_JOB_NAME" -match "Image: Visual Studio 2013") {
+           $env:generator="Visual Studio 12 2013"
+       }
+       if ("$env:APPVEYOR_JOB_NAME" -match "Image: Visual Studio 2015") {
+           $env:generator="Visual Studio 14 2015"
+       }
+       if ("$env:APPVEYOR_JOB_NAME" -match "Image: Visual Studio 2017") {
+           $env:generator="Visual Studio 15 2017"
+       }
+       if ("$env:APPVEYOR_JOB_NAME" -match "Image: Visual Studio 2019") {
+           $env:generator="Visual Studio 16 2019"
+       }
+       if ($env:PLATFORM -eq "x64") {
+           $env:generator="$env:generator Win64"
+       }
+       echo generator="$env:generator"
+       cmake .. -G "$env:generator" -DCMAKE_INSTALL_PREFIX="$env:APPVEYOR_BUILD_FOLDER/install" -DGLM_QUIET=ON -DGLM_TEST_ENABLE=ON "$env:GLM_ARGUMENTS"
+ 
+ build_script:
+   - cmake --build . --parallel --config %CONFIGURATION% -- /m /v:minimal
+   - cmake --build . --target install --parallel --config %CONFIGURATION% -- /m /v:minimal
+ 
+ test_script:
+   - ctest --parallel 4 --verbose -C %CONFIGURATION%
+   - cd ..
+   - ps: |
+       mkdir build_test_cmake
+       cd build_test_cmake
+       cmake ..\test\cmake\ -G "$env:generator" -DCMAKE_PREFIX_PATH="$env:APPVEYOR_BUILD_FOLDER/install"
+   - cmake --build . --parallel --config %CONFIGURATION% -- /m /v:minimal
+ 
+ deploy: off
submodules/diff-gaussian-rasterization/third_party/glm/.gitignore ADDED
@@ -0,0 +1,61 @@
+ # Compiled Object files
+ *.slo
+ *.lo
+ *.o
+ *.obj
+ 
+ # Precompiled Headers
+ *.gch
+ *.pch
+ 
+ # Compiled Dynamic libraries
+ *.so
+ *.dylib
+ *.dll
+ 
+ # Fortran module files
+ *.mod
+ 
+ # Compiled Static libraries
+ *.lai
+ *.la
+ *.a
+ *.lib
+ 
+ # Executables
+ *.exe
+ *.out
+ *.app
+ 
+ # CMake
+ CMakeCache.txt
+ CMakeFiles
+ cmake_install.cmake
+ install_manifest.txt
+ *.cmake
+ !glmConfig.cmake
+ !glmConfig-version.cmake
+ # ^ May need to add future .cmake files as exceptions
+ 
+ # Test logs
+ Testing/*
+ 
+ # Test input
+ test/gtc/*.dds
+ 
+ # Project Files
+ Makefile
+ *.cbp
+ *.user
+ 
+ # Misc.
+ *.log
+ 
+ # local build(s)
+ build*
+ 
+ /.vs
+ /.vscode
+ /CMakeSettings.json
+ .DS_Store
+ *.swp