Spaces:
Running
on
Zero
Running
on
Zero
add code
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +23 -0
- README.md +1 -1
- app.py +277 -0
- arguments/__init__.py +113 -0
- assets/example/TT-family-3-views/000301.jpg +0 -0
- assets/example/TT-family-3-views/000388.jpg +0 -0
- assets/example/TT-family-3-views/000491.jpg +0 -0
- assets/example/dl3dv-ba55-3-views/frame_60.jpg +0 -0
- assets/example/dl3dv-ba55-3-views/frame_61.jpg +0 -0
- assets/example/dl3dv-ba55-3-views/frame_62.jpg +0 -0
- assets/example/sora-santorini-3-views/frame_00.jpg +0 -0
- assets/example/sora-santorini-3-views/frame_06.jpg +0 -0
- assets/example/sora-santorini-3-views/frame_12.jpg +0 -0
- assets/load/.gitkeep +0 -0
- coarse_init_infer.py +100 -0
- gaussian_renderer/__init__.py +144 -0
- gaussian_renderer/__init__3dgs.py +100 -0
- gaussian_renderer/network_gui.py +86 -0
- lpipsPyTorch/__init__.py +21 -0
- lpipsPyTorch/modules/lpips.py +36 -0
- lpipsPyTorch/modules/networks.py +96 -0
- lpipsPyTorch/modules/utils.py +30 -0
- render_by_interp.py +152 -0
- requirements.txt +17 -0
- scene/__init__.py +96 -0
- scene/cameras.py +71 -0
- scene/colmap_loader.py +294 -0
- scene/dataset_readers.py +363 -0
- scene/gaussian_model.py +502 -0
- submodules/diff-gaussian-rasterization/.gitignore +3 -0
- submodules/diff-gaussian-rasterization/.gitmodules +3 -0
- submodules/diff-gaussian-rasterization/CMakeLists.txt +36 -0
- submodules/diff-gaussian-rasterization/LICENSE.md +83 -0
- submodules/diff-gaussian-rasterization/README.md +19 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h +175 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.cu +657 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.h +65 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/config.h +19 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.cu +455 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.h +66 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h +88 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu +434 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h +74 -0
- submodules/diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py +221 -0
- submodules/diff-gaussian-rasterization/ext.cpp +19 -0
- submodules/diff-gaussian-rasterization/rasterize_points.cu +217 -0
- submodules/diff-gaussian-rasterization/rasterize_points.h +67 -0
- submodules/diff-gaussian-rasterization/setup.py +34 -0
- submodules/diff-gaussian-rasterization/third_party/glm/.appveyor.yml +92 -0
- submodules/diff-gaussian-rasterization/third_party/glm/.gitignore +61 -0
.gitignore
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/.idea/
|
2 |
+
/work_dirs*
|
3 |
+
.vscode/
|
4 |
+
/tmp
|
5 |
+
/data
|
6 |
+
/checkpoints
|
7 |
+
*.so
|
8 |
+
*.patch
|
9 |
+
__pycache__/
|
10 |
+
*.egg-info/
|
11 |
+
/viz*
|
12 |
+
/submit*
|
13 |
+
build/
|
14 |
+
*.pyd
|
15 |
+
/cache*
|
16 |
+
*.stl
|
17 |
+
*.pth
|
18 |
+
/venv/
|
19 |
+
.nk8s
|
20 |
+
*.mp4
|
21 |
+
.vs
|
22 |
+
/exp/
|
23 |
+
/dev/
|
README.md
CHANGED
@@ -6,7 +6,7 @@ colorTo: green
|
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.20.1
|
8 |
python_version: 3.10.13
|
9 |
-
app_file:
|
10 |
pinned: false
|
11 |
license: mit
|
12 |
short_description: Sparse-view SFM-free Gaussian Splatting in Seconds
|
|
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.20.1
|
8 |
python_version: 3.10.13
|
9 |
+
app_file: app.py
|
10 |
pinned: false
|
11 |
license: mit
|
12 |
short_description: Sparse-view SFM-free Gaussian Splatting in Seconds
|
app.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, subprocess, shlex, sys, gc
|
2 |
+
import time
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import shutil
|
6 |
+
import argparse
|
7 |
+
import gradio as gr
|
8 |
+
import uuid
|
9 |
+
import spaces
|
10 |
+
|
11 |
+
subprocess.run(shlex.split("pip install wheel/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"))
|
12 |
+
subprocess.run(shlex.split("pip install wheel/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl"))
|
13 |
+
subprocess.run(shlex.split("pip install wheel/curope-0.0.0-cp310-cp310-linux_x86_64.whl"))
|
14 |
+
|
15 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
16 |
+
os.sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "dust3r")))
|
17 |
+
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
18 |
+
from dust3r.inference import inference
|
19 |
+
from dust3r.model import AsymmetricCroCo3DStereo
|
20 |
+
from dust3r.utils.device import to_numpy
|
21 |
+
from dust3r.image_pairs import make_pairs
|
22 |
+
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
|
23 |
+
from utils.dust3r_utils import compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images
|
24 |
+
|
25 |
+
from argparse import ArgumentParser, Namespace
|
26 |
+
from arguments import ModelParams, PipelineParams, OptimizationParams
|
27 |
+
from train_joint import training
|
28 |
+
from render_by_interp import render_sets
|
29 |
+
GRADIO_CACHE_FOLDER = './gradio_cache_folder'
|
30 |
+
#############################################################################################################################################
|
31 |
+
|
32 |
+
|
33 |
+
def get_dust3r_args_parser():
|
34 |
+
parser = argparse.ArgumentParser()
|
35 |
+
parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
|
36 |
+
parser.add_argument("--model_path", type=str, default="submodules/dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", help="path to the model weights")
|
37 |
+
parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
|
38 |
+
parser.add_argument("--batch_size", type=int, default=1)
|
39 |
+
parser.add_argument("--schedule", type=str, default='linear')
|
40 |
+
parser.add_argument("--lr", type=float, default=0.01)
|
41 |
+
parser.add_argument("--niter", type=int, default=300)
|
42 |
+
parser.add_argument("--focal_avg", type=bool, default=True)
|
43 |
+
parser.add_argument("--n_views", type=int, default=3)
|
44 |
+
parser.add_argument("--base_path", type=str, default=GRADIO_CACHE_FOLDER)
|
45 |
+
return parser
|
46 |
+
|
47 |
+
|
48 |
+
@spaces.GPU(duration=300)
|
49 |
+
def process(inputfiles, input_path=None):
|
50 |
+
|
51 |
+
if input_path is not None:
|
52 |
+
imgs_path = './assets/example/' + input_path
|
53 |
+
imgs_names = sorted(os.listdir(imgs_path))
|
54 |
+
|
55 |
+
inputfiles = []
|
56 |
+
for imgs_name in imgs_names:
|
57 |
+
file_path = os.path.join(imgs_path, imgs_name)
|
58 |
+
print(file_path)
|
59 |
+
inputfiles.append(file_path)
|
60 |
+
print(inputfiles)
|
61 |
+
|
62 |
+
# ------ (1) Coarse Geometric Initialization ------
|
63 |
+
# os.system(f"rm -rf {GRADIO_CACHE_FOLDER}")
|
64 |
+
parser = get_dust3r_args_parser()
|
65 |
+
opt = parser.parse_args()
|
66 |
+
|
67 |
+
tmp_user_folder = str(uuid.uuid4()).replace("-", "")
|
68 |
+
opt.img_base_path = os.path.join(opt.base_path, tmp_user_folder)
|
69 |
+
img_folder_path = os.path.join(opt.img_base_path, "images")
|
70 |
+
|
71 |
+
img_folder_path = os.path.join(opt.img_base_path, "images")
|
72 |
+
model = AsymmetricCroCo3DStereo.from_pretrained(opt.model_path).to(opt.device)
|
73 |
+
os.makedirs(img_folder_path, exist_ok=True)
|
74 |
+
|
75 |
+
opt.n_views = len(inputfiles)
|
76 |
+
if opt.n_views == 1:
|
77 |
+
raise gr.Error("The number of input images should be greater than 1.")
|
78 |
+
print("Multiple images: ", inputfiles)
|
79 |
+
for image_path in inputfiles:
|
80 |
+
if input_path is not None:
|
81 |
+
shutil.copy(image_path, img_folder_path)
|
82 |
+
else:
|
83 |
+
shutil.move(image_path, img_folder_path)
|
84 |
+
train_img_list = sorted(os.listdir(img_folder_path))
|
85 |
+
assert len(train_img_list)==opt.n_views, f"Number of images in the folder is not equal to {opt.n_views}"
|
86 |
+
images, ori_size, imgs_resolution = load_images(img_folder_path, size=512)
|
87 |
+
resolutions_are_equal = len(set(imgs_resolution)) == 1
|
88 |
+
if resolutions_are_equal == False:
|
89 |
+
raise gr.Error("The resolution of the input image should be the same.")
|
90 |
+
print("ori_size", ori_size)
|
91 |
+
start_time = time.time()
|
92 |
+
######################################################
|
93 |
+
pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
|
94 |
+
output = inference(pairs, model, opt.device, batch_size=opt.batch_size)
|
95 |
+
output_colmap_path=img_folder_path.replace("images", "sparse/0")
|
96 |
+
os.makedirs(output_colmap_path, exist_ok=True)
|
97 |
+
|
98 |
+
scene = global_aligner(output, device=opt.device, mode=GlobalAlignerMode.PointCloudOptimizer)
|
99 |
+
loss = compute_global_alignment(scene=scene, init="mst", niter=opt.niter, schedule=opt.schedule, lr=opt.lr, focal_avg=opt.focal_avg)
|
100 |
+
scene = scene.clean_pointcloud()
|
101 |
+
|
102 |
+
imgs = to_numpy(scene.imgs)
|
103 |
+
focals = scene.get_focals()
|
104 |
+
poses = to_numpy(scene.get_im_poses())
|
105 |
+
pts3d = to_numpy(scene.get_pts3d())
|
106 |
+
scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
|
107 |
+
confidence_masks = to_numpy(scene.get_masks())
|
108 |
+
intrinsics = to_numpy(scene.get_intrinsics())
|
109 |
+
######################################################
|
110 |
+
end_time = time.time()
|
111 |
+
print(f"Time taken for {opt.n_views} views: {end_time-start_time} seconds")
|
112 |
+
save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
|
113 |
+
save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)
|
114 |
+
pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
|
115 |
+
color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
|
116 |
+
color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
|
117 |
+
storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs)
|
118 |
+
pts_4_3dgs_all = np.array(pts3d).reshape(-1, 3)
|
119 |
+
np.save(output_colmap_path + "/pts_4_3dgs_all.npy", pts_4_3dgs_all)
|
120 |
+
np.save(output_colmap_path + "/focal.npy", np.array(focals.cpu()))
|
121 |
+
|
122 |
+
### save VRAM
|
123 |
+
del scene
|
124 |
+
torch.cuda.empty_cache()
|
125 |
+
gc.collect()
|
126 |
+
##################################################################################################################################################
|
127 |
+
|
128 |
+
# ------ (2) Fast 3D-Gaussian Optimization ------
|
129 |
+
parser = ArgumentParser(description="Training script parameters")
|
130 |
+
lp = ModelParams(parser)
|
131 |
+
op = OptimizationParams(parser)
|
132 |
+
pp = PipelineParams(parser)
|
133 |
+
parser.add_argument('--debug_from', type=int, default=-1)
|
134 |
+
parser.add_argument("--test_iterations", nargs="+", type=int, default=[])
|
135 |
+
parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
|
136 |
+
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
|
137 |
+
parser.add_argument("--start_checkpoint", type=str, default = None)
|
138 |
+
parser.add_argument("--scene", type=str, default="demo")
|
139 |
+
parser.add_argument("--n_views", type=int, default=3)
|
140 |
+
parser.add_argument("--get_video", action="store_true")
|
141 |
+
parser.add_argument("--optim_pose", type=bool, default=True)
|
142 |
+
parser.add_argument("--skip_train", action="store_true")
|
143 |
+
parser.add_argument("--skip_test", action="store_true")
|
144 |
+
args = parser.parse_args(sys.argv[1:])
|
145 |
+
args.save_iterations.append(args.iterations)
|
146 |
+
args.model_path = opt.img_base_path + '/output/'
|
147 |
+
args.source_path = opt.img_base_path
|
148 |
+
# args.model_path = GRADIO_CACHE_FOLDER + '/output/'
|
149 |
+
# args.source_path = GRADIO_CACHE_FOLDER
|
150 |
+
args.iteration = 1000
|
151 |
+
os.makedirs(args.model_path, exist_ok=True)
|
152 |
+
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args)
|
153 |
+
##################################################################################################################################################
|
154 |
+
|
155 |
+
# ------ (3) Render video by interpolation ------
|
156 |
+
parser = ArgumentParser(description="Testing script parameters")
|
157 |
+
model = ModelParams(parser, sentinel=True)
|
158 |
+
pipeline = PipelineParams(parser)
|
159 |
+
args.eval = True
|
160 |
+
args.get_video = True
|
161 |
+
args.n_views = opt.n_views
|
162 |
+
render_sets(
|
163 |
+
model.extract(args),
|
164 |
+
args.iteration,
|
165 |
+
pipeline.extract(args),
|
166 |
+
args.skip_train,
|
167 |
+
args.skip_test,
|
168 |
+
args,
|
169 |
+
)
|
170 |
+
output_ply_path = opt.img_base_path + f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
|
171 |
+
output_video_path = opt.img_base_path + f'/output/demo_{opt.n_views}_view.mp4'
|
172 |
+
# output_ply_path = GRADIO_CACHE_FOLDER+ f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
|
173 |
+
# output_video_path = GRADIO_CACHE_FOLDER+ f'/output/demo_{opt.n_views}_view.mp4'
|
174 |
+
|
175 |
+
return output_video_path, output_ply_path, output_ply_path
|
176 |
+
##################################################################################################################################################
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
_TITLE = '''InstantSplat'''
|
181 |
+
_DESCRIPTION = '''
|
182 |
+
<div style="display: flex; justify-content: center; align-items: center;">
|
183 |
+
<div style="width: 100%; text-align: center; font-size: 30px;">
|
184 |
+
<strong>InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds</strong>
|
185 |
+
</div>
|
186 |
+
</div>
|
187 |
+
<p></p>
|
188 |
+
|
189 |
+
<div align="center">
|
190 |
+
<a style="display:inline-block" href="https://instantsplat.github.io/"><img src='https://img.shields.io/badge/Project_Page-1c7d45?logo=gumtree'></a>
|
191 |
+
<a style="display:inline-block" href="https://www.youtube.com/watch?v=fxf_ypd7eD8"><img src='https://img.shields.io/badge/Demo_Video-E33122?logo=Youtube'></a>
|
192 |
+
<a style="display:inline-block" href="https://arxiv.org/abs/2403.20309"><img src="https://img.shields.io/badge/ArXiv-2403.20309-b31b1b?logo=arxiv" alt='arxiv'></a>
|
193 |
+
</div>
|
194 |
+
<p></p>
|
195 |
+
|
196 |
+
* Official demo of: [InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds](https://instantsplat.github.io/).
|
197 |
+
* Sparse-view examples for direct viewing: you can simply click the examples (in the bottom of the page), to quickly view the results on representative data.
|
198 |
+
* Training speeds may slow if the resolution or number of images is large. To achieve performance comparable to what has been reported, please conduct tests on your own GPU (A100/4090).
|
199 |
+
'''
|
200 |
+
|
201 |
+
|
202 |
+
# <a style="display:inline-block" href="https://github.com/VITA-Group/LightGaussian"><img src="https://img.shields.io/badge/Source_Code-black?logo=Github" alt='Github Source Code'></a>
|
203 |
+
#
|
204 |
+
# <a style="display:inline-block" href="https://www.nvidia.com/en-us/"><img src="https://img.shields.io/badge/Nvidia-575757?logo=nvidia" alt='Nvidia'></a>
|
205 |
+
# * If InstantSplat is helpful, please give us a star ⭐ on Github. Thanks! <a style="display:inline-block; margin-left: .5em" href="https://github.com/VITA-Group/LightGaussian"><img src='https://img.shields.io/github/stars/VITA-Group/LightGaussian?style=social'/></a>
|
206 |
+
|
207 |
+
|
208 |
+
# block = gr.Blocks(title=_TITLE).queue()
|
209 |
+
block = gr.Blocks().queue()
|
210 |
+
with block:
|
211 |
+
with gr.Row():
|
212 |
+
with gr.Column(scale=1):
|
213 |
+
# gr.Markdown('# ' + _TITLE)
|
214 |
+
gr.Markdown(_DESCRIPTION)
|
215 |
+
|
216 |
+
with gr.Row(variant='panel'):
|
217 |
+
with gr.Tab("Input"):
|
218 |
+
inputfiles = gr.File(file_count="multiple", label="images")
|
219 |
+
input_path = gr.Textbox(visible=False, label="example_path")
|
220 |
+
button_gen = gr.Button("RUN")
|
221 |
+
|
222 |
+
with gr.Row(variant='panel'):
|
223 |
+
with gr.Tab("Output"):
|
224 |
+
with gr.Column(scale=2):
|
225 |
+
output_model = gr.Model3D(
|
226 |
+
label="3D Model (Gaussian)",
|
227 |
+
# height=300,
|
228 |
+
interactive=False,
|
229 |
+
# clear_color=[1.0, 1.0, 1.0, 1.0]
|
230 |
+
)
|
231 |
+
output_file = gr.File(label="ply")
|
232 |
+
with gr.Column(scale=1):
|
233 |
+
output_video = gr.Video(label="video")
|
234 |
+
|
235 |
+
button_gen.click(process, inputs=[inputfiles], outputs=[ output_video, output_file, output_model])
|
236 |
+
|
237 |
+
# gr.Examples(
|
238 |
+
# examples=[
|
239 |
+
# "sora-santorini-3-views",
|
240 |
+
# # "TT-family-3-views",
|
241 |
+
# # "dl3dv-ba55-3-views",
|
242 |
+
# ],
|
243 |
+
# inputs=[input_path],
|
244 |
+
# outputs=[output_video, output_file, output_model],
|
245 |
+
# fn=lambda x: process(inputfiles=None, input_path=x),
|
246 |
+
# cache_examples=True,
|
247 |
+
# label='Sparse-view Examples'
|
248 |
+
# )
|
249 |
+
block.launch(server_name="0.0.0.0", share=False)
|
250 |
+
|
251 |
+
|
252 |
+
# block = gr.Blocks(title=_TITLE).queue()
|
253 |
+
# with block:
|
254 |
+
# with gr.Row():
|
255 |
+
# with gr.Column(scale=1):
|
256 |
+
# gr.Markdown('# ' + _TITLE)
|
257 |
+
# # gr.Markdown(_DESCRIPTION)
|
258 |
+
|
259 |
+
# with gr.Row(variant='panel'):
|
260 |
+
# with gr.Column(scale=1):
|
261 |
+
# with gr.Tab("Input"):
|
262 |
+
# inputfiles = gr.File(file_count="multiple", label="images")
|
263 |
+
# button_gen = gr.Button("RUN")
|
264 |
+
|
265 |
+
# with gr.Column(scale=2):
|
266 |
+
# with gr.Tab("Output"):
|
267 |
+
# output_video = gr.Video(label="video")
|
268 |
+
# output_model = gr.Model3D(
|
269 |
+
# label="3D Model (Gaussian)",
|
270 |
+
# height=300,
|
271 |
+
# interactive=False,
|
272 |
+
# )
|
273 |
+
# output_file = gr.File(label="ply")
|
274 |
+
|
275 |
+
# button_gen.click(process, inputs=[inputfiles], outputs=[ output_video, output_file, output_model])
|
276 |
+
|
277 |
+
# block.launch(server_name="0.0.0.0", share=False)
|
arguments/__init__.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact george.drettakis@inria.fr
|
10 |
+
#
|
11 |
+
|
12 |
+
from argparse import ArgumentParser, Namespace
|
13 |
+
import sys
|
14 |
+
import os
|
15 |
+
|
16 |
+
class GroupParams:
|
17 |
+
pass
|
18 |
+
|
19 |
+
class ParamGroup:
|
20 |
+
def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
|
21 |
+
group = parser.add_argument_group(name)
|
22 |
+
for key, value in vars(self).items():
|
23 |
+
shorthand = False
|
24 |
+
if key.startswith("_"):
|
25 |
+
shorthand = True
|
26 |
+
key = key[1:]
|
27 |
+
t = type(value)
|
28 |
+
value = value if not fill_none else None
|
29 |
+
if shorthand:
|
30 |
+
if t == bool:
|
31 |
+
group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
|
32 |
+
else:
|
33 |
+
group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
|
34 |
+
else:
|
35 |
+
if t == bool:
|
36 |
+
group.add_argument("--" + key, default=value, action="store_true")
|
37 |
+
else:
|
38 |
+
group.add_argument("--" + key, default=value, type=t)
|
39 |
+
|
40 |
+
def extract(self, args):
|
41 |
+
group = GroupParams()
|
42 |
+
for arg in vars(args).items():
|
43 |
+
if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
|
44 |
+
setattr(group, arg[0], arg[1])
|
45 |
+
return group
|
46 |
+
|
47 |
+
class ModelParams(ParamGroup):
|
48 |
+
def __init__(self, parser, sentinel=False):
|
49 |
+
self.sh_degree = 3
|
50 |
+
self._source_path = ""
|
51 |
+
self._model_path = ""
|
52 |
+
self._images = "images"
|
53 |
+
self._resolution = -1
|
54 |
+
self._white_background = False
|
55 |
+
self.data_device = "cuda"
|
56 |
+
self.eval = False
|
57 |
+
super().__init__(parser, "Loading Parameters", sentinel)
|
58 |
+
|
59 |
+
def extract(self, args):
|
60 |
+
g = super().extract(args)
|
61 |
+
g.source_path = os.path.abspath(g.source_path)
|
62 |
+
return g
|
63 |
+
|
64 |
+
class PipelineParams(ParamGroup):
|
65 |
+
def __init__(self, parser):
|
66 |
+
self.convert_SHs_python = False
|
67 |
+
self.compute_cov3D_python = False
|
68 |
+
self.debug = False
|
69 |
+
super().__init__(parser, "Pipeline Parameters")
|
70 |
+
|
71 |
+
class OptimizationParams(ParamGroup):
|
72 |
+
def __init__(self, parser):
|
73 |
+
self.iterations = 1000
|
74 |
+
# self.iterations = 30_000
|
75 |
+
self.position_lr_init = 0.00016
|
76 |
+
self.position_lr_final = 0.0000016
|
77 |
+
self.position_lr_delay_mult = 0.01
|
78 |
+
self.position_lr_max_steps = 30_000
|
79 |
+
self.feature_lr = 0.0025
|
80 |
+
self.opacity_lr = 0.05
|
81 |
+
self.scaling_lr = 0.005
|
82 |
+
self.rotation_lr = 0.001
|
83 |
+
self.percent_dense = 0.01
|
84 |
+
self.lambda_dssim = 0.2
|
85 |
+
self.densification_interval = 100
|
86 |
+
self.opacity_reset_interval = 3000
|
87 |
+
self.densify_from_iter = 500
|
88 |
+
self.densify_until_iter = 15_000
|
89 |
+
self.densify_grad_threshold = 0.0002
|
90 |
+
self.random_background = False
|
91 |
+
super().__init__(parser, "Optimization Parameters")
|
92 |
+
|
93 |
+
def get_combined_args(parser : ArgumentParser):
|
94 |
+
cmdlne_string = sys.argv[1:]
|
95 |
+
cfgfile_string = "Namespace()"
|
96 |
+
args_cmdline = parser.parse_args(cmdlne_string)
|
97 |
+
|
98 |
+
try:
|
99 |
+
cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
|
100 |
+
print("Looking for config file in", cfgfilepath)
|
101 |
+
with open(cfgfilepath) as cfg_file:
|
102 |
+
print("Config file found: {}".format(cfgfilepath))
|
103 |
+
cfgfile_string = cfg_file.read()
|
104 |
+
except TypeError:
|
105 |
+
print("Config file not found at")
|
106 |
+
pass
|
107 |
+
args_cfgfile = eval(cfgfile_string)
|
108 |
+
|
109 |
+
merged_dict = vars(args_cfgfile).copy()
|
110 |
+
for k,v in vars(args_cmdline).items():
|
111 |
+
if v != None:
|
112 |
+
merged_dict[k] = v
|
113 |
+
return Namespace(**merged_dict)
|
assets/example/TT-family-3-views/000301.jpg
ADDED
assets/example/TT-family-3-views/000388.jpg
ADDED
assets/example/TT-family-3-views/000491.jpg
ADDED
assets/example/dl3dv-ba55-3-views/frame_60.jpg
ADDED
assets/example/dl3dv-ba55-3-views/frame_61.jpg
ADDED
assets/example/dl3dv-ba55-3-views/frame_62.jpg
ADDED
assets/example/sora-santorini-3-views/frame_00.jpg
ADDED
assets/example/sora-santorini-3-views/frame_06.jpg
ADDED
assets/example/sora-santorini-3-views/frame_12.jpg
ADDED
assets/load/.gitkeep
ADDED
File without changes
|
coarse_init_infer.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import argparse
|
6 |
+
import time
|
7 |
+
|
8 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
9 |
+
os.sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "dust3r")))
|
10 |
+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
11 |
+
|
12 |
+
from dust3r.inference import inference
|
13 |
+
from dust3r.model import AsymmetricCroCo3DStereo
|
14 |
+
from dust3r.utils.device import to_numpy
|
15 |
+
from dust3r.image_pairs import make_pairs
|
16 |
+
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
|
17 |
+
from utils.dust3r_utils import compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images
|
18 |
+
|
19 |
+
def get_args_parser():
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
|
22 |
+
# parser.add_argument("--model_path", type=str, default="./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", help="path to the model weights")
|
23 |
+
parser.add_argument("--model_path", type=str, default="submodules/dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", help="path to the model weights")
|
24 |
+
parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
|
25 |
+
parser.add_argument("--batch_size", type=int, default=1)
|
26 |
+
parser.add_argument("--schedule", type=str, default='linear')
|
27 |
+
parser.add_argument("--lr", type=float, default=0.01)
|
28 |
+
parser.add_argument("--niter", type=int, default=300)
|
29 |
+
parser.add_argument("--focal_avg", action="store_true")
|
30 |
+
# parser.add_argument("--focal_avg", type=bool, default=True)
|
31 |
+
|
32 |
+
parser.add_argument("--llffhold", type=int, default=2)
|
33 |
+
parser.add_argument("--n_views", type=int, default=12)
|
34 |
+
parser.add_argument("--img_base_path", type=str, default="/home/workspace/datasets/instantsplat/Tanks/Barn/24_views")
|
35 |
+
|
36 |
+
return parser
|
37 |
+
|
38 |
+
if __name__ == '__main__':
|
39 |
+
|
40 |
+
parser = get_args_parser()
|
41 |
+
args = parser.parse_args()
|
42 |
+
|
43 |
+
model_path = args.model_path
|
44 |
+
device = args.device
|
45 |
+
batch_size = args.batch_size
|
46 |
+
schedule = args.schedule
|
47 |
+
lr = args.lr
|
48 |
+
niter = args.niter
|
49 |
+
n_views = args.n_views
|
50 |
+
img_base_path = args.img_base_path
|
51 |
+
img_folder_path = os.path.join(img_base_path, "images")
|
52 |
+
os.makedirs(img_folder_path, exist_ok=True)
|
53 |
+
model = AsymmetricCroCo3DStereo.from_pretrained(model_path).to(device)
|
54 |
+
##########################################################################################################################################################################################
|
55 |
+
|
56 |
+
train_img_list = sorted(os.listdir(img_folder_path))
|
57 |
+
assert len(train_img_list)==n_views, f"Number of images ({len(train_img_list)}) in the folder ({img_folder_path}) is not equal to {n_views}"
|
58 |
+
|
59 |
+
# if len(os.listdir(img_folder_path)) != len(train_img_list):
|
60 |
+
# for img_name in train_img_list:
|
61 |
+
# src_path = os.path.join(img_base_path, "images", img_name)
|
62 |
+
# tgt_path = os.path.join(img_folder_path, img_name)
|
63 |
+
# print(src_path, tgt_path)
|
64 |
+
# shutil.copy(src_path, tgt_path)
|
65 |
+
images, ori_size = load_images(img_folder_path, size=512)
|
66 |
+
print("ori_size", ori_size)
|
67 |
+
|
68 |
+
start_time = time.time()
|
69 |
+
##########################################################################################################################################################################################
|
70 |
+
pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
|
71 |
+
output = inference(pairs, model, args.device, batch_size=batch_size)
|
72 |
+
output_colmap_path=img_folder_path.replace("images", "sparse/0")
|
73 |
+
os.makedirs(output_colmap_path, exist_ok=True)
|
74 |
+
|
75 |
+
scene = global_aligner(output, device=args.device, mode=GlobalAlignerMode.PointCloudOptimizer)
|
76 |
+
loss = compute_global_alignment(scene=scene, init="mst", niter=niter, schedule=schedule, lr=lr, focal_avg=args.focal_avg)
|
77 |
+
scene = scene.clean_pointcloud()
|
78 |
+
|
79 |
+
imgs = to_numpy(scene.imgs)
|
80 |
+
focals = scene.get_focals()
|
81 |
+
poses = to_numpy(scene.get_im_poses())
|
82 |
+
pts3d = to_numpy(scene.get_pts3d())
|
83 |
+
scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
|
84 |
+
confidence_masks = to_numpy(scene.get_masks())
|
85 |
+
intrinsics = to_numpy(scene.get_intrinsics())
|
86 |
+
##########################################################################################################################################################################################
|
87 |
+
end_time = time.time()
|
88 |
+
print(f"Time taken for {n_views} views: {end_time-start_time} seconds")
|
89 |
+
|
90 |
+
# save
|
91 |
+
save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
|
92 |
+
save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)
|
93 |
+
|
94 |
+
pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
|
95 |
+
color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
|
96 |
+
color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
|
97 |
+
storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs)
|
98 |
+
pts_4_3dgs_all = np.array(pts3d).reshape(-1, 3)
|
99 |
+
np.save(output_colmap_path + "/pts_4_3dgs_all.npy", pts_4_3dgs_all)
|
100 |
+
np.save(output_colmap_path + "/focal.npy", np.array(focals.cpu()))
|
gaussian_renderer/__init__.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact george.drettakis@inria.fr
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import math
|
14 |
+
from diff_gaussian_rasterization import (
|
15 |
+
GaussianRasterizationSettings,
|
16 |
+
GaussianRasterizer,
|
17 |
+
)
|
18 |
+
from scene.gaussian_model import GaussianModel
|
19 |
+
from utils.sh_utils import eval_sh
|
20 |
+
from utils.pose_utils import get_camera_from_tensor, quadmultiply
|
21 |
+
|
22 |
+
|
23 |
+
def render(
|
24 |
+
viewpoint_camera,
|
25 |
+
pc: GaussianModel,
|
26 |
+
pipe,
|
27 |
+
bg_color: torch.Tensor,
|
28 |
+
scaling_modifier=1.0,
|
29 |
+
override_color=None,
|
30 |
+
camera_pose=None,
|
31 |
+
):
|
32 |
+
"""
|
33 |
+
Render the scene.
|
34 |
+
|
35 |
+
Background tensor (bg_color) must be on GPU!
|
36 |
+
"""
|
37 |
+
|
38 |
+
# Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
|
39 |
+
screenspace_points = (
|
40 |
+
torch.zeros_like(
|
41 |
+
pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
|
42 |
+
)
|
43 |
+
+ 0
|
44 |
+
)
|
45 |
+
try:
|
46 |
+
screenspace_points.retain_grad()
|
47 |
+
except:
|
48 |
+
pass
|
49 |
+
|
50 |
+
# Set up rasterization configuration
|
51 |
+
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
|
52 |
+
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
|
53 |
+
|
54 |
+
# Set camera pose as identity. Then, we will transform the Gaussians around camera_pose
|
55 |
+
w2c = torch.eye(4).cuda()
|
56 |
+
projmatrix = (
|
57 |
+
w2c.unsqueeze(0).bmm(viewpoint_camera.projection_matrix.unsqueeze(0))
|
58 |
+
).squeeze(0)
|
59 |
+
camera_pos = w2c.inverse()[3, :3]
|
60 |
+
raster_settings = GaussianRasterizationSettings(
|
61 |
+
image_height=int(viewpoint_camera.image_height),
|
62 |
+
image_width=int(viewpoint_camera.image_width),
|
63 |
+
tanfovx=tanfovx,
|
64 |
+
tanfovy=tanfovy,
|
65 |
+
bg=bg_color,
|
66 |
+
scale_modifier=scaling_modifier,
|
67 |
+
# viewmatrix=viewpoint_camera.world_view_transform,
|
68 |
+
# projmatrix=viewpoint_camera.full_proj_transform,
|
69 |
+
viewmatrix=w2c,
|
70 |
+
projmatrix=projmatrix,
|
71 |
+
sh_degree=pc.active_sh_degree,
|
72 |
+
# campos=viewpoint_camera.camera_center,
|
73 |
+
campos=camera_pos,
|
74 |
+
prefiltered=False,
|
75 |
+
debug=pipe.debug,
|
76 |
+
)
|
77 |
+
|
78 |
+
rasterizer = GaussianRasterizer(raster_settings=raster_settings)
|
79 |
+
|
80 |
+
# means3D = pc.get_xyz
|
81 |
+
rel_w2c = get_camera_from_tensor(camera_pose)
|
82 |
+
# Transform mean and rot of Gaussians to camera frame
|
83 |
+
gaussians_xyz = pc._xyz.clone()
|
84 |
+
gaussians_rot = pc._rotation.clone()
|
85 |
+
|
86 |
+
xyz_ones = torch.ones(gaussians_xyz.shape[0], 1).cuda().float()
|
87 |
+
xyz_homo = torch.cat((gaussians_xyz, xyz_ones), dim=1)
|
88 |
+
gaussians_xyz_trans = (rel_w2c @ xyz_homo.T).T[:, :3]
|
89 |
+
gaussians_rot_trans = quadmultiply(camera_pose[:4], gaussians_rot)
|
90 |
+
means3D = gaussians_xyz_trans
|
91 |
+
means2D = screenspace_points
|
92 |
+
opacity = pc.get_opacity
|
93 |
+
|
94 |
+
# If precomputed 3d covariance is provided, use it. If not, then it will be computed from
|
95 |
+
# scaling / rotation by the rasterizer.
|
96 |
+
scales = None
|
97 |
+
rotations = None
|
98 |
+
cov3D_precomp = None
|
99 |
+
if pipe.compute_cov3D_python:
|
100 |
+
cov3D_precomp = pc.get_covariance(scaling_modifier)
|
101 |
+
else:
|
102 |
+
scales = pc.get_scaling
|
103 |
+
rotations = gaussians_rot_trans # pc.get_rotation
|
104 |
+
|
105 |
+
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
|
106 |
+
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
|
107 |
+
shs = None
|
108 |
+
colors_precomp = None
|
109 |
+
if override_color is None:
|
110 |
+
if pipe.convert_SHs_python:
|
111 |
+
shs_view = pc.get_features.transpose(1, 2).view(
|
112 |
+
-1, 3, (pc.max_sh_degree + 1) ** 2
|
113 |
+
)
|
114 |
+
dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(
|
115 |
+
pc.get_features.shape[0], 1
|
116 |
+
)
|
117 |
+
dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
|
118 |
+
sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
|
119 |
+
colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
|
120 |
+
else:
|
121 |
+
shs = pc.get_features
|
122 |
+
else:
|
123 |
+
colors_precomp = override_color
|
124 |
+
|
125 |
+
# Rasterize visible Gaussians to image, obtain their radii (on screen).
|
126 |
+
rendered_image, radii = rasterizer(
|
127 |
+
means3D=means3D,
|
128 |
+
means2D=means2D,
|
129 |
+
shs=shs,
|
130 |
+
colors_precomp=colors_precomp,
|
131 |
+
opacities=opacity,
|
132 |
+
scales=scales,
|
133 |
+
rotations=rotations,
|
134 |
+
cov3D_precomp=cov3D_precomp,
|
135 |
+
)
|
136 |
+
|
137 |
+
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
|
138 |
+
# They will be excluded from value updates used in the splitting criteria.
|
139 |
+
return {
|
140 |
+
"render": rendered_image,
|
141 |
+
"viewspace_points": screenspace_points,
|
142 |
+
"visibility_filter": radii > 0,
|
143 |
+
"radii": radii,
|
144 |
+
}
|
gaussian_renderer/__init__3dgs.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact george.drettakis@inria.fr
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import math
|
14 |
+
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
|
15 |
+
from scene.gaussian_model import GaussianModel
|
16 |
+
from utils.sh_utils import eval_sh
|
17 |
+
|
18 |
+
def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
|
19 |
+
"""
|
20 |
+
Render the scene.
|
21 |
+
|
22 |
+
Background tensor (bg_color) must be on GPU!
|
23 |
+
"""
|
24 |
+
|
25 |
+
# Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
|
26 |
+
screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
|
27 |
+
try:
|
28 |
+
screenspace_points.retain_grad()
|
29 |
+
except:
|
30 |
+
pass
|
31 |
+
|
32 |
+
# Set up rasterization configuration
|
33 |
+
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
|
34 |
+
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
|
35 |
+
|
36 |
+
raster_settings = GaussianRasterizationSettings(
|
37 |
+
image_height=int(viewpoint_camera.image_height),
|
38 |
+
image_width=int(viewpoint_camera.image_width),
|
39 |
+
tanfovx=tanfovx,
|
40 |
+
tanfovy=tanfovy,
|
41 |
+
bg=bg_color,
|
42 |
+
scale_modifier=scaling_modifier,
|
43 |
+
viewmatrix=viewpoint_camera.world_view_transform,
|
44 |
+
projmatrix=viewpoint_camera.full_proj_transform,
|
45 |
+
sh_degree=pc.active_sh_degree,
|
46 |
+
campos=viewpoint_camera.camera_center,
|
47 |
+
prefiltered=False,
|
48 |
+
debug=pipe.debug
|
49 |
+
)
|
50 |
+
|
51 |
+
rasterizer = GaussianRasterizer(raster_settings=raster_settings)
|
52 |
+
|
53 |
+
means3D = pc.get_xyz
|
54 |
+
means2D = screenspace_points
|
55 |
+
opacity = pc.get_opacity
|
56 |
+
|
57 |
+
# If precomputed 3d covariance is provided, use it. If not, then it will be computed from
|
58 |
+
# scaling / rotation by the rasterizer.
|
59 |
+
scales = None
|
60 |
+
rotations = None
|
61 |
+
cov3D_precomp = None
|
62 |
+
if pipe.compute_cov3D_python:
|
63 |
+
cov3D_precomp = pc.get_covariance(scaling_modifier)
|
64 |
+
else:
|
65 |
+
scales = pc.get_scaling
|
66 |
+
rotations = pc.get_rotation
|
67 |
+
|
68 |
+
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
|
69 |
+
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
|
70 |
+
shs = None
|
71 |
+
colors_precomp = None
|
72 |
+
if override_color is None:
|
73 |
+
if pipe.convert_SHs_python:
|
74 |
+
shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
|
75 |
+
dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
|
76 |
+
dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
|
77 |
+
sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
|
78 |
+
colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
|
79 |
+
else:
|
80 |
+
shs = pc.get_features
|
81 |
+
else:
|
82 |
+
colors_precomp = override_color
|
83 |
+
|
84 |
+
# Rasterize visible Gaussians to image, obtain their radii (on screen).
|
85 |
+
rendered_image, radii = rasterizer(
|
86 |
+
means3D = means3D,
|
87 |
+
means2D = means2D,
|
88 |
+
shs = shs,
|
89 |
+
colors_precomp = colors_precomp,
|
90 |
+
opacities = opacity,
|
91 |
+
scales = scales,
|
92 |
+
rotations = rotations,
|
93 |
+
cov3D_precomp = cov3D_precomp)
|
94 |
+
|
95 |
+
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
|
96 |
+
# They will be excluded from value updates used in the splitting criteria.
|
97 |
+
return {"render": rendered_image,
|
98 |
+
"viewspace_points": screenspace_points,
|
99 |
+
"visibility_filter" : radii > 0,
|
100 |
+
"radii": radii}
|
gaussian_renderer/network_gui.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact george.drettakis@inria.fr
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import traceback
|
14 |
+
import socket
|
15 |
+
import json
|
16 |
+
from scene.cameras import MiniCam
|
17 |
+
|
18 |
+
host = "127.0.0.1"
|
19 |
+
port = 6009
|
20 |
+
|
21 |
+
conn = None
|
22 |
+
addr = None
|
23 |
+
|
24 |
+
listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
25 |
+
|
26 |
+
def init(wish_host, wish_port):
|
27 |
+
global host, port, listener
|
28 |
+
host = wish_host
|
29 |
+
port = wish_port
|
30 |
+
listener.bind((host, port))
|
31 |
+
listener.listen()
|
32 |
+
listener.settimeout(0)
|
33 |
+
|
34 |
+
def try_connect():
|
35 |
+
global conn, addr, listener
|
36 |
+
try:
|
37 |
+
conn, addr = listener.accept()
|
38 |
+
print(f"\nConnected by {addr}")
|
39 |
+
conn.settimeout(None)
|
40 |
+
except Exception as inst:
|
41 |
+
pass
|
42 |
+
|
43 |
+
def read():
|
44 |
+
global conn
|
45 |
+
messageLength = conn.recv(4)
|
46 |
+
messageLength = int.from_bytes(messageLength, 'little')
|
47 |
+
message = conn.recv(messageLength)
|
48 |
+
return json.loads(message.decode("utf-8"))
|
49 |
+
|
50 |
+
def send(message_bytes, verify):
|
51 |
+
global conn
|
52 |
+
if message_bytes != None:
|
53 |
+
conn.sendall(message_bytes)
|
54 |
+
conn.sendall(len(verify).to_bytes(4, 'little'))
|
55 |
+
conn.sendall(bytes(verify, 'ascii'))
|
56 |
+
|
57 |
+
def receive():
|
58 |
+
message = read()
|
59 |
+
|
60 |
+
width = message["resolution_x"]
|
61 |
+
height = message["resolution_y"]
|
62 |
+
|
63 |
+
if width != 0 and height != 0:
|
64 |
+
try:
|
65 |
+
do_training = bool(message["train"])
|
66 |
+
fovy = message["fov_y"]
|
67 |
+
fovx = message["fov_x"]
|
68 |
+
znear = message["z_near"]
|
69 |
+
zfar = message["z_far"]
|
70 |
+
do_shs_python = bool(message["shs_python"])
|
71 |
+
do_rot_scale_python = bool(message["rot_scale_python"])
|
72 |
+
keep_alive = bool(message["keep_alive"])
|
73 |
+
scaling_modifier = message["scaling_modifier"]
|
74 |
+
world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
|
75 |
+
world_view_transform[:,1] = -world_view_transform[:,1]
|
76 |
+
world_view_transform[:,2] = -world_view_transform[:,2]
|
77 |
+
full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
|
78 |
+
full_proj_transform[:,1] = -full_proj_transform[:,1]
|
79 |
+
custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
|
80 |
+
except Exception as e:
|
81 |
+
print("")
|
82 |
+
traceback.print_exc()
|
83 |
+
raise e
|
84 |
+
return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
|
85 |
+
else:
|
86 |
+
return None, None, None, None, None, None
|
lpipsPyTorch/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from .modules.lpips import LPIPS
|
4 |
+
|
5 |
+
|
6 |
+
def lpips(x: torch.Tensor,
|
7 |
+
y: torch.Tensor,
|
8 |
+
net_type: str = 'alex',
|
9 |
+
version: str = '0.1'):
|
10 |
+
r"""Function that measures
|
11 |
+
Learned Perceptual Image Patch Similarity (LPIPS).
|
12 |
+
|
13 |
+
Arguments:
|
14 |
+
x, y (torch.Tensor): the input tensors to compare.
|
15 |
+
net_type (str): the network type to compare the features:
|
16 |
+
'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
|
17 |
+
version (str): the version of LPIPS. Default: 0.1.
|
18 |
+
"""
|
19 |
+
device = x.device
|
20 |
+
criterion = LPIPS(net_type, version).to(device)
|
21 |
+
return criterion(x, y)
|
lpipsPyTorch/modules/lpips.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .networks import get_network, LinLayers
|
5 |
+
from .utils import get_state_dict
|
6 |
+
|
7 |
+
|
8 |
+
class LPIPS(nn.Module):
|
9 |
+
r"""Creates a criterion that measures
|
10 |
+
Learned Perceptual Image Patch Similarity (LPIPS).
|
11 |
+
|
12 |
+
Arguments:
|
13 |
+
net_type (str): the network type to compare the features:
|
14 |
+
'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
|
15 |
+
version (str): the version of LPIPS. Default: 0.1.
|
16 |
+
"""
|
17 |
+
def __init__(self, net_type: str = 'alex', version: str = '0.1'):
|
18 |
+
|
19 |
+
assert version in ['0.1'], 'v0.1 is only supported now'
|
20 |
+
|
21 |
+
super(LPIPS, self).__init__()
|
22 |
+
|
23 |
+
# pretrained network
|
24 |
+
self.net = get_network(net_type)
|
25 |
+
|
26 |
+
# linear layers
|
27 |
+
self.lin = LinLayers(self.net.n_channels_list)
|
28 |
+
self.lin.load_state_dict(get_state_dict(net_type, version))
|
29 |
+
|
30 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
31 |
+
feat_x, feat_y = self.net(x), self.net(y)
|
32 |
+
|
33 |
+
diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
|
34 |
+
res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
|
35 |
+
|
36 |
+
return torch.sum(torch.cat(res, 0), 0, True)
|
lpipsPyTorch/modules/networks.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Sequence
|
2 |
+
|
3 |
+
from itertools import chain
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torchvision import models
|
8 |
+
|
9 |
+
from .utils import normalize_activation
|
10 |
+
|
11 |
+
|
12 |
+
def get_network(net_type: str):
|
13 |
+
if net_type == 'alex':
|
14 |
+
return AlexNet()
|
15 |
+
elif net_type == 'squeeze':
|
16 |
+
return SqueezeNet()
|
17 |
+
elif net_type == 'vgg':
|
18 |
+
return VGG16()
|
19 |
+
else:
|
20 |
+
raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
|
21 |
+
|
22 |
+
|
23 |
+
class LinLayers(nn.ModuleList):
|
24 |
+
def __init__(self, n_channels_list: Sequence[int]):
|
25 |
+
super(LinLayers, self).__init__([
|
26 |
+
nn.Sequential(
|
27 |
+
nn.Identity(),
|
28 |
+
nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
|
29 |
+
) for nc in n_channels_list
|
30 |
+
])
|
31 |
+
|
32 |
+
for param in self.parameters():
|
33 |
+
param.requires_grad = False
|
34 |
+
|
35 |
+
|
36 |
+
class BaseNet(nn.Module):
|
37 |
+
def __init__(self):
|
38 |
+
super(BaseNet, self).__init__()
|
39 |
+
|
40 |
+
# register buffer
|
41 |
+
self.register_buffer(
|
42 |
+
'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
|
43 |
+
self.register_buffer(
|
44 |
+
'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
|
45 |
+
|
46 |
+
def set_requires_grad(self, state: bool):
|
47 |
+
for param in chain(self.parameters(), self.buffers()):
|
48 |
+
param.requires_grad = state
|
49 |
+
|
50 |
+
def z_score(self, x: torch.Tensor):
|
51 |
+
return (x - self.mean) / self.std
|
52 |
+
|
53 |
+
def forward(self, x: torch.Tensor):
|
54 |
+
x = self.z_score(x)
|
55 |
+
|
56 |
+
output = []
|
57 |
+
for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
|
58 |
+
x = layer(x)
|
59 |
+
if i in self.target_layers:
|
60 |
+
output.append(normalize_activation(x))
|
61 |
+
if len(output) == len(self.target_layers):
|
62 |
+
break
|
63 |
+
return output
|
64 |
+
|
65 |
+
|
66 |
+
class SqueezeNet(BaseNet):
|
67 |
+
def __init__(self):
|
68 |
+
super(SqueezeNet, self).__init__()
|
69 |
+
|
70 |
+
self.layers = models.squeezenet1_1(True).features
|
71 |
+
self.target_layers = [2, 5, 8, 10, 11, 12, 13]
|
72 |
+
self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
|
73 |
+
|
74 |
+
self.set_requires_grad(False)
|
75 |
+
|
76 |
+
|
77 |
+
class AlexNet(BaseNet):
|
78 |
+
def __init__(self):
|
79 |
+
super(AlexNet, self).__init__()
|
80 |
+
|
81 |
+
self.layers = models.alexnet(True).features
|
82 |
+
self.target_layers = [2, 5, 8, 10, 12]
|
83 |
+
self.n_channels_list = [64, 192, 384, 256, 256]
|
84 |
+
|
85 |
+
self.set_requires_grad(False)
|
86 |
+
|
87 |
+
|
88 |
+
class VGG16(BaseNet):
|
89 |
+
def __init__(self):
|
90 |
+
super(VGG16, self).__init__()
|
91 |
+
|
92 |
+
self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
|
93 |
+
self.target_layers = [4, 9, 16, 23, 30]
|
94 |
+
self.n_channels_list = [64, 128, 256, 512, 512]
|
95 |
+
|
96 |
+
self.set_requires_grad(False)
|
lpipsPyTorch/modules/utils.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def normalize_activation(x, eps=1e-10):
|
7 |
+
norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
|
8 |
+
return x / (norm_factor + eps)
|
9 |
+
|
10 |
+
|
11 |
+
def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
|
12 |
+
# build url
|
13 |
+
url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
|
14 |
+
+ f'master/lpips/weights/v{version}/{net_type}.pth'
|
15 |
+
|
16 |
+
# download
|
17 |
+
old_state_dict = torch.hub.load_state_dict_from_url(
|
18 |
+
url, progress=True,
|
19 |
+
map_location=None if torch.cuda.is_available() else torch.device('cpu')
|
20 |
+
)
|
21 |
+
|
22 |
+
# rename keys
|
23 |
+
new_state_dict = OrderedDict()
|
24 |
+
for key, val in old_state_dict.items():
|
25 |
+
new_key = key
|
26 |
+
new_key = new_key.replace('lin', '')
|
27 |
+
new_key = new_key.replace('model.', '')
|
28 |
+
new_state_dict[new_key] = val
|
29 |
+
|
30 |
+
return new_state_dict
|
render_by_interp.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact george.drettakis@inria.fr
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from scene import Scene
|
14 |
+
import os
|
15 |
+
from tqdm import tqdm
|
16 |
+
from os import makedirs
|
17 |
+
from gaussian_renderer import render
|
18 |
+
import torchvision
|
19 |
+
from utils.general_utils import safe_state
|
20 |
+
from argparse import ArgumentParser
|
21 |
+
from arguments import ModelParams, PipelineParams, get_combined_args
|
22 |
+
from gaussian_renderer import GaussianModel
|
23 |
+
from utils.pose_utils import get_tensor_from_camera
|
24 |
+
from utils.camera_utils import generate_interpolated_path
|
25 |
+
from utils.camera_utils import visualizer
|
26 |
+
import cv2
|
27 |
+
import numpy as np
|
28 |
+
import imageio
|
29 |
+
|
30 |
+
|
31 |
+
def save_interpolate_pose(model_path, iter, n_views):
|
32 |
+
|
33 |
+
org_pose = np.load(model_path + f"pose/pose_{iter}.npy")
|
34 |
+
# visualizer(org_pose, ["green" for _ in org_pose], model_path + "pose/poses_optimized.png")
|
35 |
+
# n_interp = int(10 * 30 / n_views) # 10second, fps=30
|
36 |
+
n_interp = int(5 * 30 / n_views) # 5second, fps=30
|
37 |
+
all_inter_pose = []
|
38 |
+
for i in range(n_views-1):
|
39 |
+
tmp_inter_pose = generate_interpolated_path(poses=org_pose[i:i+2], n_interp=n_interp)
|
40 |
+
all_inter_pose.append(tmp_inter_pose)
|
41 |
+
all_inter_pose = np.array(all_inter_pose).reshape(-1, 3, 4)
|
42 |
+
|
43 |
+
inter_pose_list = []
|
44 |
+
for p in all_inter_pose:
|
45 |
+
tmp_view = np.eye(4)
|
46 |
+
tmp_view[:3, :3] = p[:3, :3]
|
47 |
+
tmp_view[:3, 3] = p[:3, 3]
|
48 |
+
inter_pose_list.append(tmp_view)
|
49 |
+
inter_pose = np.stack(inter_pose_list, 0)
|
50 |
+
# visualizer(inter_pose, ["blue" for _ in inter_pose], model_path + "pose/poses_interpolated.png")
|
51 |
+
np.save(model_path + "pose/pose_interpolated.npy", inter_pose)
|
52 |
+
|
53 |
+
|
54 |
+
def images_to_video(image_folder, output_video_path, fps=30):
|
55 |
+
"""
|
56 |
+
Convert images in a folder to a video.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
- image_folder (str): The path to the folder containing the images.
|
60 |
+
- output_video_path (str): The path where the output video will be saved.
|
61 |
+
- fps (int): Frames per second for the output video.
|
62 |
+
"""
|
63 |
+
images = []
|
64 |
+
|
65 |
+
for filename in sorted(os.listdir(image_folder)):
|
66 |
+
if filename.endswith(('.png', '.jpg', '.jpeg', '.JPG', '.PNG')):
|
67 |
+
image_path = os.path.join(image_folder, filename)
|
68 |
+
image = imageio.imread(image_path)
|
69 |
+
images.append(image)
|
70 |
+
|
71 |
+
imageio.mimwrite(output_video_path, images, fps=fps)
|
72 |
+
|
73 |
+
|
74 |
+
def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
|
75 |
+
render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
|
76 |
+
makedirs(render_path, exist_ok=True)
|
77 |
+
|
78 |
+
# for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
|
79 |
+
for idx, view in enumerate(views):
|
80 |
+
camera_pose = get_tensor_from_camera(view.world_view_transform.transpose(0, 1))
|
81 |
+
rendering = render(
|
82 |
+
view, gaussians, pipeline, background, camera_pose=camera_pose
|
83 |
+
)["render"]
|
84 |
+
gt = view.original_image[0:3, :, :]
|
85 |
+
torchvision.utils.save_image(
|
86 |
+
rendering, os.path.join(render_path, "{0:05d}".format(idx) + ".png")
|
87 |
+
)
|
88 |
+
|
89 |
+
|
90 |
+
def render_sets(
|
91 |
+
dataset: ModelParams,
|
92 |
+
iteration: int,
|
93 |
+
pipeline: PipelineParams,
|
94 |
+
skip_train: bool,
|
95 |
+
skip_test: bool,
|
96 |
+
args,
|
97 |
+
):
|
98 |
+
|
99 |
+
# Applying interpolation
|
100 |
+
save_interpolate_pose(dataset.model_path, iteration, args.n_views)
|
101 |
+
|
102 |
+
with torch.no_grad():
|
103 |
+
gaussians = GaussianModel(dataset.sh_degree)
|
104 |
+
scene = Scene(dataset, gaussians, load_iteration=iteration, opt=args, shuffle=False)
|
105 |
+
|
106 |
+
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
|
107 |
+
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
|
108 |
+
|
109 |
+
# render interpolated views
|
110 |
+
render_set(
|
111 |
+
dataset.model_path,
|
112 |
+
"interp",
|
113 |
+
scene.loaded_iter,
|
114 |
+
scene.getTrainCameras(),
|
115 |
+
gaussians,
|
116 |
+
pipeline,
|
117 |
+
background,
|
118 |
+
)
|
119 |
+
|
120 |
+
if args.get_video:
|
121 |
+
image_folder = os.path.join(dataset.model_path, f'interp/ours_{args.iteration}/renders')
|
122 |
+
output_video_file = os.path.join(dataset.model_path, f'{args.scene}_{args.n_views}_view.mp4')
|
123 |
+
images_to_video(image_folder, output_video_file, fps=30)
|
124 |
+
|
125 |
+
|
126 |
+
if __name__ == "__main__":
|
127 |
+
# Set up command line argument parser
|
128 |
+
parser = ArgumentParser(description="Testing script parameters")
|
129 |
+
model = ModelParams(parser, sentinel=True)
|
130 |
+
pipeline = PipelineParams(parser)
|
131 |
+
parser.add_argument("--iteration", default=-1, type=int)
|
132 |
+
parser.add_argument("--skip_train", action="store_true")
|
133 |
+
parser.add_argument("--skip_test", action="store_true")
|
134 |
+
parser.add_argument("--quiet", action="store_true")
|
135 |
+
|
136 |
+
parser.add_argument("--get_video", action="store_true")
|
137 |
+
parser.add_argument("--n_views", default=None, type=int)
|
138 |
+
parser.add_argument("--scene", default=None, type=str)
|
139 |
+
args = get_combined_args(parser)
|
140 |
+
print("Rendering " + args.model_path)
|
141 |
+
|
142 |
+
# Initialize system state (RNG)
|
143 |
+
# safe_state(args.quiet)
|
144 |
+
|
145 |
+
render_sets(
|
146 |
+
model.extract(args),
|
147 |
+
args.iteration,
|
148 |
+
pipeline.extract(args),
|
149 |
+
args.skip_train,
|
150 |
+
args.skip_test,
|
151 |
+
args,
|
152 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.2.0
|
2 |
+
torchvision
|
3 |
+
roma
|
4 |
+
evo
|
5 |
+
gradio==5.0.1
|
6 |
+
matplotlib
|
7 |
+
tqdm
|
8 |
+
opencv-python
|
9 |
+
scipy
|
10 |
+
einops
|
11 |
+
trimesh
|
12 |
+
tensorboard
|
13 |
+
pyglet<2
|
14 |
+
huggingface-hub[torch]>=0.22
|
15 |
+
plyfile
|
16 |
+
imageio[ffmpeg]
|
17 |
+
spaces
|
scene/__init__.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact george.drettakis@inria.fr
|
10 |
+
#
|
11 |
+
|
12 |
+
import os
|
13 |
+
import random
|
14 |
+
import json
|
15 |
+
from utils.system_utils import searchForMaxIteration
|
16 |
+
from scene.dataset_readers import sceneLoadTypeCallbacks
|
17 |
+
from scene.gaussian_model import GaussianModel
|
18 |
+
from arguments import ModelParams
|
19 |
+
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
|
20 |
+
|
21 |
+
class Scene:
|
22 |
+
|
23 |
+
gaussians : GaussianModel
|
24 |
+
|
25 |
+
def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, opt=None, shuffle=True, resolution_scales=[1.0]):
|
26 |
+
"""b
|
27 |
+
:param path: Path to colmap scene main folder.
|
28 |
+
"""
|
29 |
+
self.model_path = args.model_path
|
30 |
+
self.loaded_iter = None
|
31 |
+
self.gaussians = gaussians
|
32 |
+
|
33 |
+
if load_iteration:
|
34 |
+
if load_iteration == -1:
|
35 |
+
self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
|
36 |
+
else:
|
37 |
+
self.loaded_iter = load_iteration
|
38 |
+
print("Loading trained model at iteration {}".format(self.loaded_iter))
|
39 |
+
|
40 |
+
self.train_cameras = {}
|
41 |
+
self.test_cameras = {}
|
42 |
+
|
43 |
+
if os.path.exists(os.path.join(args.source_path, "sparse")):
|
44 |
+
scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, args, opt)
|
45 |
+
elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
|
46 |
+
print("Found transforms_train.json file, assuming Blender data set!")
|
47 |
+
scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
|
48 |
+
else:
|
49 |
+
assert False, "Could not recognize scene type!"
|
50 |
+
|
51 |
+
if not self.loaded_iter:
|
52 |
+
with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
|
53 |
+
dest_file.write(src_file.read())
|
54 |
+
json_cams = []
|
55 |
+
camlist = []
|
56 |
+
if scene_info.test_cameras:
|
57 |
+
camlist.extend(scene_info.test_cameras)
|
58 |
+
if scene_info.train_cameras:
|
59 |
+
camlist.extend(scene_info.train_cameras)
|
60 |
+
for id, cam in enumerate(camlist):
|
61 |
+
json_cams.append(camera_to_JSON(id, cam))
|
62 |
+
with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
|
63 |
+
json.dump(json_cams, file)
|
64 |
+
|
65 |
+
if shuffle:
|
66 |
+
random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
|
67 |
+
random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
|
68 |
+
|
69 |
+
self.cameras_extent = scene_info.nerf_normalization["radius"]
|
70 |
+
|
71 |
+
for resolution_scale in resolution_scales:
|
72 |
+
print("Loading Training Cameras")
|
73 |
+
self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
|
74 |
+
print('train_camera_num: ', len(self.train_cameras[resolution_scale]))
|
75 |
+
print("Loading Test Cameras")
|
76 |
+
self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
|
77 |
+
print('test_camera_num: ', len(self.test_cameras[resolution_scale]))
|
78 |
+
|
79 |
+
if self.loaded_iter:
|
80 |
+
self.gaussians.load_ply(os.path.join(self.model_path,
|
81 |
+
"point_cloud",
|
82 |
+
"iteration_" + str(self.loaded_iter),
|
83 |
+
"point_cloud.ply"))
|
84 |
+
else:
|
85 |
+
self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
|
86 |
+
self.gaussians.init_RT_seq(self.train_cameras)
|
87 |
+
|
88 |
+
def save(self, iteration):
|
89 |
+
point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
|
90 |
+
self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
|
91 |
+
|
92 |
+
def getTrainCameras(self, scale=1.0):
|
93 |
+
return self.train_cameras[scale]
|
94 |
+
|
95 |
+
def getTestCameras(self, scale=1.0):
|
96 |
+
return self.test_cameras[scale]
|
scene/cameras.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact george.drettakis@inria.fr
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
import numpy as np
|
15 |
+
from utils.graphics_utils import getWorld2View2, getProjectionMatrix
|
16 |
+
|
17 |
+
class Camera(nn.Module):
|
18 |
+
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
|
19 |
+
image_name, uid,
|
20 |
+
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
|
21 |
+
):
|
22 |
+
super(Camera, self).__init__()
|
23 |
+
|
24 |
+
self.uid = uid
|
25 |
+
self.colmap_id = colmap_id
|
26 |
+
self.R = R
|
27 |
+
self.T = T
|
28 |
+
self.FoVx = FoVx
|
29 |
+
self.FoVy = FoVy
|
30 |
+
self.image_name = image_name
|
31 |
+
|
32 |
+
try:
|
33 |
+
self.data_device = torch.device(data_device)
|
34 |
+
except Exception as e:
|
35 |
+
print(e)
|
36 |
+
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
|
37 |
+
self.data_device = torch.device("cuda")
|
38 |
+
|
39 |
+
self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
|
40 |
+
self.image_width = self.original_image.shape[2]
|
41 |
+
self.image_height = self.original_image.shape[1]
|
42 |
+
|
43 |
+
if gt_alpha_mask is not None:
|
44 |
+
self.original_image *= gt_alpha_mask.to(self.data_device)
|
45 |
+
else:
|
46 |
+
self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
|
47 |
+
|
48 |
+
self.zfar = 100.0
|
49 |
+
self.znear = 0.01
|
50 |
+
|
51 |
+
self.trans = trans
|
52 |
+
self.scale = scale
|
53 |
+
|
54 |
+
self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
|
55 |
+
self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
|
56 |
+
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
|
57 |
+
self.camera_center = self.world_view_transform.inverse()[3, :3]
|
58 |
+
|
59 |
+
class MiniCam:
|
60 |
+
def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
|
61 |
+
self.image_width = width
|
62 |
+
self.image_height = height
|
63 |
+
self.FoVy = fovy
|
64 |
+
self.FoVx = fovx
|
65 |
+
self.znear = znear
|
66 |
+
self.zfar = zfar
|
67 |
+
self.world_view_transform = world_view_transform
|
68 |
+
self.full_proj_transform = full_proj_transform
|
69 |
+
view_inv = torch.inverse(self.world_view_transform)
|
70 |
+
self.camera_center = view_inv[3][:3]
|
71 |
+
|
scene/colmap_loader.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact george.drettakis@inria.fr
|
10 |
+
#
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import collections
|
14 |
+
import struct
|
15 |
+
|
16 |
+
CameraModel = collections.namedtuple(
|
17 |
+
"CameraModel", ["model_id", "model_name", "num_params"])
|
18 |
+
Camera = collections.namedtuple(
|
19 |
+
"Camera", ["id", "model", "width", "height", "params"])
|
20 |
+
BaseImage = collections.namedtuple(
|
21 |
+
"Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
|
22 |
+
Point3D = collections.namedtuple(
|
23 |
+
"Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
|
24 |
+
CAMERA_MODELS = {
|
25 |
+
CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
|
26 |
+
CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
|
27 |
+
CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
|
28 |
+
CameraModel(model_id=3, model_name="RADIAL", num_params=5),
|
29 |
+
CameraModel(model_id=4, model_name="OPENCV", num_params=8),
|
30 |
+
CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
|
31 |
+
CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
|
32 |
+
CameraModel(model_id=7, model_name="FOV", num_params=5),
|
33 |
+
CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
|
34 |
+
CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
|
35 |
+
CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
|
36 |
+
}
|
37 |
+
CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
|
38 |
+
for camera_model in CAMERA_MODELS])
|
39 |
+
CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
|
40 |
+
for camera_model in CAMERA_MODELS])
|
41 |
+
|
42 |
+
|
43 |
+
def qvec2rotmat(qvec):
|
44 |
+
return np.array([
|
45 |
+
[1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
|
46 |
+
2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
|
47 |
+
2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
|
48 |
+
[2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
|
49 |
+
1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
|
50 |
+
2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
|
51 |
+
[2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
|
52 |
+
2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
|
53 |
+
1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
|
54 |
+
|
55 |
+
def rotmat2qvec(R):
|
56 |
+
Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
|
57 |
+
K = np.array([
|
58 |
+
[Rxx - Ryy - Rzz, 0, 0, 0],
|
59 |
+
[Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
|
60 |
+
[Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
|
61 |
+
[Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
|
62 |
+
eigvals, eigvecs = np.linalg.eigh(K)
|
63 |
+
qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
|
64 |
+
if qvec[0] < 0:
|
65 |
+
qvec *= -1
|
66 |
+
return qvec
|
67 |
+
|
68 |
+
class Image(BaseImage):
|
69 |
+
def qvec2rotmat(self):
|
70 |
+
return qvec2rotmat(self.qvec)
|
71 |
+
|
72 |
+
def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
|
73 |
+
"""Read and unpack the next bytes from a binary file.
|
74 |
+
:param fid:
|
75 |
+
:param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
|
76 |
+
:param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
|
77 |
+
:param endian_character: Any of {@, =, <, >, !}
|
78 |
+
:return: Tuple of read and unpacked values.
|
79 |
+
"""
|
80 |
+
data = fid.read(num_bytes)
|
81 |
+
return struct.unpack(endian_character + format_char_sequence, data)
|
82 |
+
|
83 |
+
def read_points3D_text(path):
|
84 |
+
"""
|
85 |
+
see: src/base/reconstruction.cc
|
86 |
+
void Reconstruction::ReadPoints3DText(const std::string& path)
|
87 |
+
void Reconstruction::WritePoints3DText(const std::string& path)
|
88 |
+
"""
|
89 |
+
xyzs = None
|
90 |
+
rgbs = None
|
91 |
+
errors = None
|
92 |
+
num_points = 0
|
93 |
+
with open(path, "r") as fid:
|
94 |
+
while True:
|
95 |
+
line = fid.readline()
|
96 |
+
if not line:
|
97 |
+
break
|
98 |
+
line = line.strip()
|
99 |
+
if len(line) > 0 and line[0] != "#":
|
100 |
+
num_points += 1
|
101 |
+
|
102 |
+
|
103 |
+
xyzs = np.empty((num_points, 3))
|
104 |
+
rgbs = np.empty((num_points, 3))
|
105 |
+
errors = np.empty((num_points, 1))
|
106 |
+
count = 0
|
107 |
+
with open(path, "r") as fid:
|
108 |
+
while True:
|
109 |
+
line = fid.readline()
|
110 |
+
if not line:
|
111 |
+
break
|
112 |
+
line = line.strip()
|
113 |
+
if len(line) > 0 and line[0] != "#":
|
114 |
+
elems = line.split()
|
115 |
+
xyz = np.array(tuple(map(float, elems[1:4])))
|
116 |
+
rgb = np.array(tuple(map(int, elems[4:7])))
|
117 |
+
error = np.array(float(elems[7]))
|
118 |
+
xyzs[count] = xyz
|
119 |
+
rgbs[count] = rgb
|
120 |
+
errors[count] = error
|
121 |
+
count += 1
|
122 |
+
|
123 |
+
return xyzs, rgbs, errors
|
124 |
+
|
125 |
+
def read_points3D_binary(path_to_model_file):
|
126 |
+
"""
|
127 |
+
see: src/base/reconstruction.cc
|
128 |
+
void Reconstruction::ReadPoints3DBinary(const std::string& path)
|
129 |
+
void Reconstruction::WritePoints3DBinary(const std::string& path)
|
130 |
+
"""
|
131 |
+
|
132 |
+
|
133 |
+
with open(path_to_model_file, "rb") as fid:
|
134 |
+
num_points = read_next_bytes(fid, 8, "Q")[0]
|
135 |
+
|
136 |
+
xyzs = np.empty((num_points, 3))
|
137 |
+
rgbs = np.empty((num_points, 3))
|
138 |
+
errors = np.empty((num_points, 1))
|
139 |
+
|
140 |
+
for p_id in range(num_points):
|
141 |
+
binary_point_line_properties = read_next_bytes(
|
142 |
+
fid, num_bytes=43, format_char_sequence="QdddBBBd")
|
143 |
+
xyz = np.array(binary_point_line_properties[1:4])
|
144 |
+
rgb = np.array(binary_point_line_properties[4:7])
|
145 |
+
error = np.array(binary_point_line_properties[7])
|
146 |
+
track_length = read_next_bytes(
|
147 |
+
fid, num_bytes=8, format_char_sequence="Q")[0]
|
148 |
+
track_elems = read_next_bytes(
|
149 |
+
fid, num_bytes=8*track_length,
|
150 |
+
format_char_sequence="ii"*track_length)
|
151 |
+
xyzs[p_id] = xyz
|
152 |
+
rgbs[p_id] = rgb
|
153 |
+
errors[p_id] = error
|
154 |
+
return xyzs, rgbs, errors
|
155 |
+
|
156 |
+
def read_intrinsics_text(path):
|
157 |
+
"""
|
158 |
+
Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
|
159 |
+
"""
|
160 |
+
cameras = {}
|
161 |
+
with open(path, "r") as fid:
|
162 |
+
while True:
|
163 |
+
line = fid.readline()
|
164 |
+
if not line:
|
165 |
+
break
|
166 |
+
line = line.strip()
|
167 |
+
if len(line) > 0 and line[0] != "#":
|
168 |
+
elems = line.split()
|
169 |
+
camera_id = int(elems[0])
|
170 |
+
model = elems[1]
|
171 |
+
assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE"
|
172 |
+
width = int(elems[2])
|
173 |
+
height = int(elems[3])
|
174 |
+
params = np.array(tuple(map(float, elems[4:])))
|
175 |
+
cameras[camera_id] = Camera(id=camera_id, model=model,
|
176 |
+
width=width, height=height,
|
177 |
+
params=params)
|
178 |
+
return cameras
|
179 |
+
|
180 |
+
def read_extrinsics_binary(path_to_model_file):
|
181 |
+
"""
|
182 |
+
see: src/base/reconstruction.cc
|
183 |
+
void Reconstruction::ReadImagesBinary(const std::string& path)
|
184 |
+
void Reconstruction::WriteImagesBinary(const std::string& path)
|
185 |
+
"""
|
186 |
+
images = {}
|
187 |
+
with open(path_to_model_file, "rb") as fid:
|
188 |
+
num_reg_images = read_next_bytes(fid, 8, "Q")[0]
|
189 |
+
for _ in range(num_reg_images):
|
190 |
+
binary_image_properties = read_next_bytes(
|
191 |
+
fid, num_bytes=64, format_char_sequence="idddddddi")
|
192 |
+
image_id = binary_image_properties[0]
|
193 |
+
qvec = np.array(binary_image_properties[1:5])
|
194 |
+
tvec = np.array(binary_image_properties[5:8])
|
195 |
+
camera_id = binary_image_properties[8]
|
196 |
+
image_name = ""
|
197 |
+
current_char = read_next_bytes(fid, 1, "c")[0]
|
198 |
+
while current_char != b"\x00": # look for the ASCII 0 entry
|
199 |
+
image_name += current_char.decode("utf-8")
|
200 |
+
current_char = read_next_bytes(fid, 1, "c")[0]
|
201 |
+
num_points2D = read_next_bytes(fid, num_bytes=8,
|
202 |
+
format_char_sequence="Q")[0]
|
203 |
+
x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
|
204 |
+
format_char_sequence="ddq"*num_points2D)
|
205 |
+
xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
|
206 |
+
tuple(map(float, x_y_id_s[1::3]))])
|
207 |
+
point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
|
208 |
+
images[image_id] = Image(
|
209 |
+
id=image_id, qvec=qvec, tvec=tvec,
|
210 |
+
camera_id=camera_id, name=image_name,
|
211 |
+
xys=xys, point3D_ids=point3D_ids)
|
212 |
+
return images
|
213 |
+
|
214 |
+
|
215 |
+
def read_intrinsics_binary(path_to_model_file):
|
216 |
+
"""
|
217 |
+
see: src/base/reconstruction.cc
|
218 |
+
void Reconstruction::WriteCamerasBinary(const std::string& path)
|
219 |
+
void Reconstruction::ReadCamerasBinary(const std::string& path)
|
220 |
+
"""
|
221 |
+
cameras = {}
|
222 |
+
with open(path_to_model_file, "rb") as fid:
|
223 |
+
num_cameras = read_next_bytes(fid, 8, "Q")[0]
|
224 |
+
for _ in range(num_cameras):
|
225 |
+
camera_properties = read_next_bytes(
|
226 |
+
fid, num_bytes=24, format_char_sequence="iiQQ")
|
227 |
+
camera_id = camera_properties[0]
|
228 |
+
model_id = camera_properties[1]
|
229 |
+
model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
|
230 |
+
width = camera_properties[2]
|
231 |
+
height = camera_properties[3]
|
232 |
+
num_params = CAMERA_MODEL_IDS[model_id].num_params
|
233 |
+
params = read_next_bytes(fid, num_bytes=8*num_params,
|
234 |
+
format_char_sequence="d"*num_params)
|
235 |
+
cameras[camera_id] = Camera(id=camera_id,
|
236 |
+
model=model_name,
|
237 |
+
width=width,
|
238 |
+
height=height,
|
239 |
+
params=np.array(params))
|
240 |
+
assert len(cameras) == num_cameras
|
241 |
+
return cameras
|
242 |
+
|
243 |
+
|
244 |
+
def read_extrinsics_text(path):
|
245 |
+
"""
|
246 |
+
Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
|
247 |
+
"""
|
248 |
+
images = {}
|
249 |
+
with open(path, "r") as fid:
|
250 |
+
while True:
|
251 |
+
line = fid.readline()
|
252 |
+
if not line:
|
253 |
+
break
|
254 |
+
line = line.strip()
|
255 |
+
if len(line) > 0 and line[0] != "#":
|
256 |
+
elems = line.split()
|
257 |
+
image_id = int(elems[0])
|
258 |
+
qvec = np.array(tuple(map(float, elems[1:5])))
|
259 |
+
tvec = np.array(tuple(map(float, elems[5:8])))
|
260 |
+
camera_id = int(elems[8])
|
261 |
+
image_name = elems[9]
|
262 |
+
elems = fid.readline().split()
|
263 |
+
xys = np.column_stack([tuple(map(float, elems[0::3])),
|
264 |
+
tuple(map(float, elems[1::3]))])
|
265 |
+
point3D_ids = np.array(tuple(map(int, elems[2::3])))
|
266 |
+
images[image_id] = Image(
|
267 |
+
id=image_id, qvec=qvec, tvec=tvec,
|
268 |
+
camera_id=camera_id, name=image_name,
|
269 |
+
xys=xys, point3D_ids=point3D_ids)
|
270 |
+
return images
|
271 |
+
|
272 |
+
|
273 |
+
def read_colmap_bin_array(path):
|
274 |
+
"""
|
275 |
+
Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
|
276 |
+
|
277 |
+
:param path: path to the colmap binary file.
|
278 |
+
:return: nd array with the floating point values in the value
|
279 |
+
"""
|
280 |
+
with open(path, "rb") as fid:
|
281 |
+
width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
|
282 |
+
usecols=(0, 1, 2), dtype=int)
|
283 |
+
fid.seek(0)
|
284 |
+
num_delimiter = 0
|
285 |
+
byte = fid.read(1)
|
286 |
+
while True:
|
287 |
+
if byte == b"&":
|
288 |
+
num_delimiter += 1
|
289 |
+
if num_delimiter >= 3:
|
290 |
+
break
|
291 |
+
byte = fid.read(1)
|
292 |
+
array = np.fromfile(fid, np.float32)
|
293 |
+
array = array.reshape((width, height, channels), order="F")
|
294 |
+
return np.transpose(array, (1, 0, 2)).squeeze()
|
scene/dataset_readers.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact george.drettakis@inria.fr
|
10 |
+
#
|
11 |
+
|
12 |
+
import os
|
13 |
+
import sys
|
14 |
+
from PIL import Image
|
15 |
+
from typing import NamedTuple
|
16 |
+
from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
|
17 |
+
read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
|
18 |
+
from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
|
19 |
+
import numpy as np
|
20 |
+
import json
|
21 |
+
from pathlib import Path
|
22 |
+
from plyfile import PlyData, PlyElement
|
23 |
+
from utils.sh_utils import SH2RGB
|
24 |
+
from scene.gaussian_model import BasicPointCloud
|
25 |
+
|
26 |
+
class CameraInfo(NamedTuple):
|
27 |
+
uid: int
|
28 |
+
R: np.array
|
29 |
+
T: np.array
|
30 |
+
FovY: np.array
|
31 |
+
FovX: np.array
|
32 |
+
image: np.array
|
33 |
+
image_path: str
|
34 |
+
image_name: str
|
35 |
+
width: int
|
36 |
+
height: int
|
37 |
+
|
38 |
+
|
39 |
+
class SceneInfo(NamedTuple):
|
40 |
+
point_cloud: BasicPointCloud
|
41 |
+
train_cameras: list
|
42 |
+
test_cameras: list
|
43 |
+
nerf_normalization: dict
|
44 |
+
ply_path: str
|
45 |
+
train_poses: list
|
46 |
+
test_poses: list
|
47 |
+
|
48 |
+
def getNerfppNorm(cam_info):
|
49 |
+
def get_center_and_diag(cam_centers):
|
50 |
+
cam_centers = np.hstack(cam_centers)
|
51 |
+
avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
|
52 |
+
center = avg_cam_center
|
53 |
+
dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
|
54 |
+
diagonal = np.max(dist)
|
55 |
+
return center.flatten(), diagonal
|
56 |
+
|
57 |
+
cam_centers = []
|
58 |
+
|
59 |
+
for cam in cam_info:
|
60 |
+
W2C = getWorld2View2(cam.R, cam.T)
|
61 |
+
C2W = np.linalg.inv(W2C)
|
62 |
+
cam_centers.append(C2W[:3, 3:4])
|
63 |
+
|
64 |
+
center, diagonal = get_center_and_diag(cam_centers)
|
65 |
+
radius = diagonal * 1.1
|
66 |
+
|
67 |
+
translate = -center
|
68 |
+
|
69 |
+
return {"translate": translate, "radius": radius}
|
70 |
+
|
71 |
+
def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder, eval):
|
72 |
+
|
73 |
+
cam_infos = []
|
74 |
+
poses=[]
|
75 |
+
for idx, key in enumerate(cam_extrinsics):
|
76 |
+
sys.stdout.write('\r')
|
77 |
+
# the exact output you're looking for:
|
78 |
+
sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
|
79 |
+
sys.stdout.flush()
|
80 |
+
|
81 |
+
if eval:
|
82 |
+
extr = cam_extrinsics[key]
|
83 |
+
intr = cam_intrinsics[1]
|
84 |
+
uid = idx+1
|
85 |
+
|
86 |
+
else:
|
87 |
+
extr = cam_extrinsics[key]
|
88 |
+
intr = cam_intrinsics[extr.camera_id]
|
89 |
+
uid = intr.id
|
90 |
+
|
91 |
+
height = intr.height
|
92 |
+
width = intr.width
|
93 |
+
R = np.transpose(qvec2rotmat(extr.qvec))
|
94 |
+
T = np.array(extr.tvec)
|
95 |
+
pose = np.vstack((np.hstack((R, T.reshape(3,-1))),np.array([[0, 0, 0, 1]])))
|
96 |
+
poses.append(pose)
|
97 |
+
if intr.model=="SIMPLE_PINHOLE":
|
98 |
+
focal_length_x = intr.params[0]
|
99 |
+
FovY = focal2fov(focal_length_x, height)
|
100 |
+
FovX = focal2fov(focal_length_x, width)
|
101 |
+
elif intr.model=="PINHOLE":
|
102 |
+
focal_length_x = intr.params[0]
|
103 |
+
focal_length_y = intr.params[1]
|
104 |
+
FovY = focal2fov(focal_length_y, height)
|
105 |
+
FovX = focal2fov(focal_length_x, width)
|
106 |
+
else:
|
107 |
+
assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
|
108 |
+
|
109 |
+
|
110 |
+
if eval:
|
111 |
+
tmp = os.path.dirname(os.path.dirname(os.path.join(images_folder)))
|
112 |
+
all_images_folder = os.path.join(tmp, 'images')
|
113 |
+
image_path = os.path.join(all_images_folder, os.path.basename(extr.name))
|
114 |
+
else:
|
115 |
+
image_path = os.path.join(images_folder, os.path.basename(extr.name))
|
116 |
+
image_name = os.path.basename(image_path).split(".")[0]
|
117 |
+
image = Image.open(image_path)
|
118 |
+
|
119 |
+
|
120 |
+
cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
|
121 |
+
image_path=image_path, image_name=image_name, width=width, height=height)
|
122 |
+
|
123 |
+
cam_infos.append(cam_info)
|
124 |
+
sys.stdout.write('\n')
|
125 |
+
return cam_infos, poses
|
126 |
+
|
127 |
+
# For interpolated video, open when only render interpolated video
|
128 |
+
def readColmapCamerasInterp(cam_extrinsics, cam_intrinsics, images_folder, model_path):
|
129 |
+
|
130 |
+
pose_interpolated_path = model_path + 'pose/pose_interpolated.npy'
|
131 |
+
pose_interpolated = np.load(pose_interpolated_path)
|
132 |
+
intr = cam_intrinsics[1]
|
133 |
+
|
134 |
+
cam_infos = []
|
135 |
+
poses=[]
|
136 |
+
for idx, pose_npy in enumerate(pose_interpolated):
|
137 |
+
sys.stdout.write('\r')
|
138 |
+
sys.stdout.write("Reading camera {}/{}".format(idx+1, pose_interpolated.shape[0]))
|
139 |
+
sys.stdout.flush()
|
140 |
+
|
141 |
+
extr = pose_npy
|
142 |
+
intr = intr
|
143 |
+
height = intr.height
|
144 |
+
width = intr.width
|
145 |
+
|
146 |
+
uid = idx
|
147 |
+
R = extr[:3, :3].transpose()
|
148 |
+
T = extr[:3, 3]
|
149 |
+
pose = np.vstack((np.hstack((R, T.reshape(3,-1))),np.array([[0, 0, 0, 1]])))
|
150 |
+
# print(uid)
|
151 |
+
# print(pose.shape)
|
152 |
+
# pose = np.linalg.inv(pose)
|
153 |
+
poses.append(pose)
|
154 |
+
if intr.model=="SIMPLE_PINHOLE":
|
155 |
+
focal_length_x = intr.params[0]
|
156 |
+
FovY = focal2fov(focal_length_x, height)
|
157 |
+
FovX = focal2fov(focal_length_x, width)
|
158 |
+
elif intr.model=="PINHOLE":
|
159 |
+
focal_length_x = intr.params[0]
|
160 |
+
focal_length_y = intr.params[1]
|
161 |
+
FovY = focal2fov(focal_length_y, height)
|
162 |
+
FovX = focal2fov(focal_length_x, width)
|
163 |
+
else:
|
164 |
+
assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
|
165 |
+
|
166 |
+
images_list = os.listdir(os.path.join(images_folder))
|
167 |
+
image_name_0 = images_list[0]
|
168 |
+
image_name = str(idx).zfill(4)
|
169 |
+
image = Image.open(images_folder + '/' + image_name_0)
|
170 |
+
|
171 |
+
cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
|
172 |
+
image_path=images_folder, image_name=image_name, width=width, height=height)
|
173 |
+
cam_infos.append(cam_info)
|
174 |
+
|
175 |
+
sys.stdout.write('\n')
|
176 |
+
return cam_infos, poses
|
177 |
+
|
178 |
+
|
179 |
+
def fetchPly(path):
|
180 |
+
plydata = PlyData.read(path)
|
181 |
+
vertices = plydata['vertex']
|
182 |
+
positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
|
183 |
+
colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
|
184 |
+
normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
|
185 |
+
return BasicPointCloud(points=positions, colors=colors, normals=normals)
|
186 |
+
|
187 |
+
def storePly(path, xyz, rgb):
|
188 |
+
# Define the dtype for the structured array
|
189 |
+
dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
|
190 |
+
('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
|
191 |
+
('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
|
192 |
+
|
193 |
+
normals = np.zeros_like(xyz)
|
194 |
+
|
195 |
+
elements = np.empty(xyz.shape[0], dtype=dtype)
|
196 |
+
attributes = np.concatenate((xyz, normals, rgb), axis=1)
|
197 |
+
elements[:] = list(map(tuple, attributes))
|
198 |
+
|
199 |
+
# Create the PlyData object and write to file
|
200 |
+
vertex_element = PlyElement.describe(elements, 'vertex')
|
201 |
+
ply_data = PlyData([vertex_element])
|
202 |
+
ply_data.write(path)
|
203 |
+
|
204 |
+
def readColmapSceneInfo(path, images, eval, args, opt, llffhold=2):
|
205 |
+
# try:
|
206 |
+
# cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
|
207 |
+
# cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
|
208 |
+
# cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
|
209 |
+
# cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
|
210 |
+
# except:
|
211 |
+
|
212 |
+
##### For initializing test pose using PCD_Registration
|
213 |
+
if eval and opt.get_video==False:
|
214 |
+
print("Loading initial test pose for evaluation.")
|
215 |
+
cameras_extrinsic_file = os.path.join(path, "init_test_pose/sparse/0", "images.txt")
|
216 |
+
else:
|
217 |
+
cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
|
218 |
+
|
219 |
+
cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
|
220 |
+
cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
|
221 |
+
cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
|
222 |
+
|
223 |
+
reading_dir = "images" if images == None else images
|
224 |
+
|
225 |
+
if opt.get_video:
|
226 |
+
cam_infos_unsorted, poses = readColmapCamerasInterp(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir), model_path=args.model_path)
|
227 |
+
else:
|
228 |
+
cam_infos_unsorted, poses = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir), eval=eval)
|
229 |
+
sorting_indices = sorted(range(len(cam_infos_unsorted)), key=lambda x: cam_infos_unsorted[x].image_name)
|
230 |
+
cam_infos = [cam_infos_unsorted[i] for i in sorting_indices]
|
231 |
+
sorted_poses = [poses[i] for i in sorting_indices]
|
232 |
+
cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
|
233 |
+
|
234 |
+
if eval:
|
235 |
+
# train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx+1) % llffhold != 0]
|
236 |
+
# test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx+1) % llffhold == 0]
|
237 |
+
# train_poses = [c for idx, c in enumerate(sorted_poses) if (idx+1) % llffhold != 0]
|
238 |
+
# test_poses = [c for idx, c in enumerate(sorted_poses) if (idx+1) % llffhold == 0]
|
239 |
+
|
240 |
+
train_cam_infos = cam_infos
|
241 |
+
test_cam_infos = cam_infos
|
242 |
+
train_poses = sorted_poses
|
243 |
+
test_poses = sorted_poses
|
244 |
+
|
245 |
+
else:
|
246 |
+
train_cam_infos = cam_infos
|
247 |
+
test_cam_infos = []
|
248 |
+
train_poses = sorted_poses
|
249 |
+
test_poses = []
|
250 |
+
|
251 |
+
nerf_normalization = getNerfppNorm(train_cam_infos)
|
252 |
+
|
253 |
+
ply_path = os.path.join(path, "sparse/0/points3D.ply")
|
254 |
+
bin_path = os.path.join(path, "sparse/0/points3D.bin")
|
255 |
+
txt_path = os.path.join(path, "sparse/0/points3D.txt")
|
256 |
+
if not os.path.exists(ply_path):
|
257 |
+
print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
|
258 |
+
try:
|
259 |
+
xyz, rgb, _ = read_points3D_binary(bin_path)
|
260 |
+
except:
|
261 |
+
xyz, rgb, _ = read_points3D_text(txt_path)
|
262 |
+
storePly(ply_path, xyz, rgb)
|
263 |
+
try:
|
264 |
+
pcd = fetchPly(ply_path)
|
265 |
+
except:
|
266 |
+
pcd = None
|
267 |
+
|
268 |
+
# np.save("poses_family.npy", sorted_poses)
|
269 |
+
# breakpoint()
|
270 |
+
# np.save("3dpoints.npy", pcd.points)
|
271 |
+
# np.save("3dcolors.npy", pcd.colors)
|
272 |
+
|
273 |
+
scene_info = SceneInfo(point_cloud=pcd,
|
274 |
+
train_cameras=train_cam_infos,
|
275 |
+
test_cameras=test_cam_infos,
|
276 |
+
nerf_normalization=nerf_normalization,
|
277 |
+
ply_path=ply_path,
|
278 |
+
train_poses=train_poses,
|
279 |
+
test_poses=test_poses)
|
280 |
+
return scene_info
|
281 |
+
|
282 |
+
def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
|
283 |
+
cam_infos = []
|
284 |
+
|
285 |
+
with open(os.path.join(path, transformsfile)) as json_file:
|
286 |
+
contents = json.load(json_file)
|
287 |
+
fovx = contents["camera_angle_x"]
|
288 |
+
|
289 |
+
frames = contents["frames"]
|
290 |
+
for idx, frame in enumerate(frames):
|
291 |
+
cam_name = os.path.join(path, frame["file_path"] + extension)
|
292 |
+
|
293 |
+
# NeRF 'transform_matrix' is a camera-to-world transform
|
294 |
+
c2w = np.array(frame["transform_matrix"])
|
295 |
+
# change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
|
296 |
+
c2w[:3, 1:3] *= -1
|
297 |
+
|
298 |
+
# get the world-to-camera transform and set R, T
|
299 |
+
w2c = np.linalg.inv(c2w)
|
300 |
+
R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code
|
301 |
+
T = w2c[:3, 3]
|
302 |
+
|
303 |
+
image_path = os.path.join(path, cam_name)
|
304 |
+
image_name = Path(cam_name).stem
|
305 |
+
image = Image.open(image_path)
|
306 |
+
|
307 |
+
im_data = np.array(image.convert("RGBA"))
|
308 |
+
|
309 |
+
bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
|
310 |
+
|
311 |
+
norm_data = im_data / 255.0
|
312 |
+
arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
|
313 |
+
image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
|
314 |
+
|
315 |
+
fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
|
316 |
+
FovY = fovy
|
317 |
+
FovX = fovx
|
318 |
+
|
319 |
+
cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
|
320 |
+
image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
|
321 |
+
|
322 |
+
return cam_infos
|
323 |
+
|
324 |
+
def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
|
325 |
+
print("Reading Training Transforms")
|
326 |
+
train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
|
327 |
+
print("Reading Test Transforms")
|
328 |
+
test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
|
329 |
+
|
330 |
+
if not eval:
|
331 |
+
train_cam_infos.extend(test_cam_infos)
|
332 |
+
test_cam_infos = []
|
333 |
+
|
334 |
+
nerf_normalization = getNerfppNorm(train_cam_infos)
|
335 |
+
|
336 |
+
ply_path = os.path.join(path, "points3d.ply")
|
337 |
+
if not os.path.exists(ply_path):
|
338 |
+
# Since this data set has no colmap data, we start with random points
|
339 |
+
num_pts = 100_000
|
340 |
+
print(f"Generating random point cloud ({num_pts})...")
|
341 |
+
|
342 |
+
# We create random points inside the bounds of the synthetic Blender scenes
|
343 |
+
xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
|
344 |
+
shs = np.random.random((num_pts, 3)) / 255.0
|
345 |
+
pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
|
346 |
+
|
347 |
+
storePly(ply_path, xyz, SH2RGB(shs) * 255)
|
348 |
+
try:
|
349 |
+
pcd = fetchPly(ply_path)
|
350 |
+
except:
|
351 |
+
pcd = None
|
352 |
+
|
353 |
+
scene_info = SceneInfo(point_cloud=pcd,
|
354 |
+
train_cameras=train_cam_infos,
|
355 |
+
test_cameras=test_cam_infos,
|
356 |
+
nerf_normalization=nerf_normalization,
|
357 |
+
ply_path=ply_path)
|
358 |
+
return scene_info
|
359 |
+
|
360 |
+
sceneLoadTypeCallbacks = {
|
361 |
+
"Colmap": readColmapSceneInfo,
|
362 |
+
"Blender" : readNerfSyntheticInfo
|
363 |
+
}
|
scene/gaussian_model.py
ADDED
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact george.drettakis@inria.fr
|
10 |
+
#
|
11 |
+
|
12 |
+
import torch
|
13 |
+
# from lietorch import SO3, SE3, Sim3, LieGroupParameter
|
14 |
+
import numpy as np
|
15 |
+
from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
|
16 |
+
from torch import nn
|
17 |
+
import os
|
18 |
+
from utils.system_utils import mkdir_p
|
19 |
+
from plyfile import PlyData, PlyElement
|
20 |
+
from utils.sh_utils import RGB2SH
|
21 |
+
from simple_knn._C import distCUDA2
|
22 |
+
from utils.graphics_utils import BasicPointCloud
|
23 |
+
from utils.general_utils import strip_symmetric, build_scaling_rotation
|
24 |
+
from scipy.spatial.transform import Rotation as R
|
25 |
+
from utils.pose_utils import rotation2quad, get_tensor_from_camera
|
26 |
+
from utils.graphics_utils import getWorld2View2
|
27 |
+
|
28 |
+
def quaternion_to_rotation_matrix(quaternion):
|
29 |
+
"""
|
30 |
+
Convert a quaternion to a rotation matrix.
|
31 |
+
|
32 |
+
Parameters:
|
33 |
+
- quaternion: A tensor of shape (..., 4) representing quaternions.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
- A tensor of shape (..., 3, 3) representing rotation matrices.
|
37 |
+
"""
|
38 |
+
# Ensure quaternion is of float type for computation
|
39 |
+
quaternion = quaternion.float()
|
40 |
+
|
41 |
+
# Normalize the quaternion to unit length
|
42 |
+
quaternion = quaternion / quaternion.norm(p=2, dim=-1, keepdim=True)
|
43 |
+
|
44 |
+
# Extract components
|
45 |
+
w, x, y, z = quaternion[..., 0], quaternion[..., 1], quaternion[..., 2], quaternion[..., 3]
|
46 |
+
|
47 |
+
# Compute rotation matrix components
|
48 |
+
xx, yy, zz = x * x, y * y, z * z
|
49 |
+
xy, xz, yz = x * y, x * z, y * z
|
50 |
+
xw, yw, zw = x * w, y * w, z * w
|
51 |
+
|
52 |
+
# Assemble the rotation matrix
|
53 |
+
R = torch.stack([
|
54 |
+
torch.stack([1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw)], dim=-1),
|
55 |
+
torch.stack([ 2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw)], dim=-1),
|
56 |
+
torch.stack([ 2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy)], dim=-1)
|
57 |
+
], dim=-2)
|
58 |
+
|
59 |
+
return R
|
60 |
+
|
61 |
+
|
62 |
+
class GaussianModel:
|
63 |
+
|
64 |
+
def setup_functions(self):
|
65 |
+
def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
|
66 |
+
L = build_scaling_rotation(scaling_modifier * scaling, rotation)
|
67 |
+
actual_covariance = L @ L.transpose(1, 2)
|
68 |
+
symm = strip_symmetric(actual_covariance)
|
69 |
+
return symm
|
70 |
+
|
71 |
+
self.scaling_activation = torch.exp
|
72 |
+
self.scaling_inverse_activation = torch.log
|
73 |
+
|
74 |
+
self.covariance_activation = build_covariance_from_scaling_rotation
|
75 |
+
|
76 |
+
self.opacity_activation = torch.sigmoid
|
77 |
+
self.inverse_opacity_activation = inverse_sigmoid
|
78 |
+
|
79 |
+
self.rotation_activation = torch.nn.functional.normalize
|
80 |
+
|
81 |
+
|
82 |
+
def __init__(self, sh_degree : int):
|
83 |
+
self.active_sh_degree = 0
|
84 |
+
self.max_sh_degree = sh_degree
|
85 |
+
self._xyz = torch.empty(0)
|
86 |
+
self._features_dc = torch.empty(0)
|
87 |
+
self._features_rest = torch.empty(0)
|
88 |
+
self._scaling = torch.empty(0)
|
89 |
+
self._rotation = torch.empty(0)
|
90 |
+
self._opacity = torch.empty(0)
|
91 |
+
self.max_radii2D = torch.empty(0)
|
92 |
+
self.xyz_gradient_accum = torch.empty(0)
|
93 |
+
self.denom = torch.empty(0)
|
94 |
+
self.optimizer = None
|
95 |
+
self.percent_dense = 0
|
96 |
+
self.spatial_lr_scale = 0
|
97 |
+
self.setup_functions()
|
98 |
+
|
99 |
+
def capture(self):
|
100 |
+
return (
|
101 |
+
self.active_sh_degree,
|
102 |
+
self._xyz,
|
103 |
+
self._features_dc,
|
104 |
+
self._features_rest,
|
105 |
+
self._scaling,
|
106 |
+
self._rotation,
|
107 |
+
self._opacity,
|
108 |
+
self.max_radii2D,
|
109 |
+
self.xyz_gradient_accum,
|
110 |
+
self.denom,
|
111 |
+
self.optimizer.state_dict(),
|
112 |
+
self.spatial_lr_scale,
|
113 |
+
self.P,
|
114 |
+
)
|
115 |
+
|
116 |
+
def restore(self, model_args, training_args):
|
117 |
+
(self.active_sh_degree,
|
118 |
+
self._xyz,
|
119 |
+
self._features_dc,
|
120 |
+
self._features_rest,
|
121 |
+
self._scaling,
|
122 |
+
self._rotation,
|
123 |
+
self._opacity,
|
124 |
+
self.max_radii2D,
|
125 |
+
xyz_gradient_accum,
|
126 |
+
denom,
|
127 |
+
opt_dict,
|
128 |
+
self.spatial_lr_scale,
|
129 |
+
self.P) = model_args
|
130 |
+
self.training_setup(training_args)
|
131 |
+
self.xyz_gradient_accum = xyz_gradient_accum
|
132 |
+
self.denom = denom
|
133 |
+
self.optimizer.load_state_dict(opt_dict)
|
134 |
+
|
135 |
+
@property
|
136 |
+
def get_scaling(self):
|
137 |
+
return self.scaling_activation(self._scaling)
|
138 |
+
|
139 |
+
@property
|
140 |
+
def get_rotation(self):
|
141 |
+
return self.rotation_activation(self._rotation)
|
142 |
+
|
143 |
+
@property
|
144 |
+
def get_xyz(self):
|
145 |
+
return self._xyz
|
146 |
+
|
147 |
+
def compute_relative_world_to_camera(self, R1, t1, R2, t2):
|
148 |
+
# Create a row of zeros with a one at the end, for homogeneous coordinates
|
149 |
+
zero_row = np.array([[0, 0, 0, 1]], dtype=np.float32)
|
150 |
+
|
151 |
+
# Compute the inverse of the first extrinsic matrix
|
152 |
+
E1_inv = np.hstack([R1.T, -R1.T @ t1.reshape(-1, 1)]) # Transpose and reshape for correct dimensions
|
153 |
+
E1_inv = np.vstack([E1_inv, zero_row]) # Append the zero_row to make it a 4x4 matrix
|
154 |
+
|
155 |
+
# Compute the second extrinsic matrix
|
156 |
+
E2 = np.hstack([R2, -R2 @ t2.reshape(-1, 1)]) # No need to transpose R2
|
157 |
+
E2 = np.vstack([E2, zero_row]) # Append the zero_row to make it a 4x4 matrix
|
158 |
+
|
159 |
+
# Compute the relative transformation
|
160 |
+
E_rel = E2 @ E1_inv
|
161 |
+
|
162 |
+
return E_rel
|
163 |
+
|
164 |
+
def init_RT_seq(self, cam_list):
|
165 |
+
poses =[]
|
166 |
+
for cam in cam_list[1.0]:
|
167 |
+
p = get_tensor_from_camera(cam.world_view_transform.transpose(0, 1)) # R T -> quat t
|
168 |
+
poses.append(p)
|
169 |
+
poses = torch.stack(poses)
|
170 |
+
self.P = poses.cuda().requires_grad_(True)
|
171 |
+
|
172 |
+
|
173 |
+
def get_RT(self, idx):
|
174 |
+
pose = self.P[idx]
|
175 |
+
return pose
|
176 |
+
|
177 |
+
def get_RT_test(self, idx):
|
178 |
+
pose = self.test_P[idx]
|
179 |
+
return pose
|
180 |
+
|
181 |
+
@property
|
182 |
+
def get_features(self):
|
183 |
+
features_dc = self._features_dc
|
184 |
+
features_rest = self._features_rest
|
185 |
+
return torch.cat((features_dc, features_rest), dim=1)
|
186 |
+
|
187 |
+
@property
|
188 |
+
def get_opacity(self):
|
189 |
+
return self.opacity_activation(self._opacity)
|
190 |
+
|
191 |
+
def get_covariance(self, scaling_modifier = 1):
|
192 |
+
return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
|
193 |
+
|
194 |
+
def oneupSHdegree(self):
|
195 |
+
if self.active_sh_degree < self.max_sh_degree:
|
196 |
+
self.active_sh_degree += 1
|
197 |
+
|
198 |
+
def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
|
199 |
+
|
200 |
+
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # gradio
|
201 |
+
|
202 |
+
self.spatial_lr_scale = spatial_lr_scale
|
203 |
+
fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
|
204 |
+
fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
|
205 |
+
features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
|
206 |
+
features[:, :3, 0 ] = fused_color
|
207 |
+
features[:, 3:, 1:] = 0.0
|
208 |
+
|
209 |
+
print("Number of points at initialisation : ", fused_point_cloud.shape[0])
|
210 |
+
|
211 |
+
dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
|
212 |
+
scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
|
213 |
+
rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
|
214 |
+
rots[:, 0] = 1
|
215 |
+
|
216 |
+
opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
|
217 |
+
|
218 |
+
self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
|
219 |
+
self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
|
220 |
+
self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
|
221 |
+
self._scaling = nn.Parameter(scales.requires_grad_(True))
|
222 |
+
self._rotation = nn.Parameter(rots.requires_grad_(True))
|
223 |
+
self._opacity = nn.Parameter(opacities.requires_grad_(True))
|
224 |
+
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
|
225 |
+
|
226 |
+
def training_setup(self, training_args):
|
227 |
+
self.percent_dense = training_args.percent_dense
|
228 |
+
self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
|
229 |
+
self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
|
230 |
+
|
231 |
+
l = [
|
232 |
+
{'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
|
233 |
+
{'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
|
234 |
+
{'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
|
235 |
+
{'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
|
236 |
+
{'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
|
237 |
+
{'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
|
238 |
+
]
|
239 |
+
|
240 |
+
l_cam = [{'params': [self.P],'lr': training_args.rotation_lr*0.1, "name": "pose"},]
|
241 |
+
|
242 |
+
l += l_cam
|
243 |
+
|
244 |
+
self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
|
245 |
+
self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
|
246 |
+
lr_final=training_args.position_lr_final*self.spatial_lr_scale,
|
247 |
+
lr_delay_mult=training_args.position_lr_delay_mult,
|
248 |
+
max_steps=training_args.position_lr_max_steps)
|
249 |
+
self.cam_scheduler_args = get_expon_lr_func(
|
250 |
+
# lr_init=0,
|
251 |
+
# lr_final=0,
|
252 |
+
lr_init=training_args.rotation_lr*0.1,
|
253 |
+
lr_final=training_args.rotation_lr*0.001,
|
254 |
+
# lr_init=training_args.position_lr_init*self.spatial_lr_scale*10,
|
255 |
+
# lr_final=training_args.position_lr_final*self.spatial_lr_scale*10,
|
256 |
+
lr_delay_mult=training_args.position_lr_delay_mult,
|
257 |
+
max_steps=1000)
|
258 |
+
|
259 |
+
def update_learning_rate(self, iteration):
|
260 |
+
''' Learning rate scheduling per step '''
|
261 |
+
for param_group in self.optimizer.param_groups:
|
262 |
+
if param_group["name"] == "pose":
|
263 |
+
lr = self.cam_scheduler_args(iteration)
|
264 |
+
# print("pose learning rate", iteration, lr)
|
265 |
+
param_group['lr'] = lr
|
266 |
+
if param_group["name"] == "xyz":
|
267 |
+
lr = self.xyz_scheduler_args(iteration)
|
268 |
+
param_group['lr'] = lr
|
269 |
+
# return lr
|
270 |
+
|
271 |
+
def construct_list_of_attributes(self):
|
272 |
+
l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
|
273 |
+
# All channels except the 3 DC
|
274 |
+
for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
|
275 |
+
l.append('f_dc_{}'.format(i))
|
276 |
+
for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
|
277 |
+
l.append('f_rest_{}'.format(i))
|
278 |
+
l.append('opacity')
|
279 |
+
for i in range(self._scaling.shape[1]):
|
280 |
+
l.append('scale_{}'.format(i))
|
281 |
+
for i in range(self._rotation.shape[1]):
|
282 |
+
l.append('rot_{}'.format(i))
|
283 |
+
return l
|
284 |
+
|
285 |
+
def save_ply(self, path):
|
286 |
+
mkdir_p(os.path.dirname(path))
|
287 |
+
|
288 |
+
xyz = self._xyz.detach().cpu().numpy()
|
289 |
+
normals = np.zeros_like(xyz)
|
290 |
+
f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
|
291 |
+
f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
|
292 |
+
opacities = self._opacity.detach().cpu().numpy()
|
293 |
+
scale = self._scaling.detach().cpu().numpy()
|
294 |
+
rotation = self._rotation.detach().cpu().numpy()
|
295 |
+
|
296 |
+
dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
|
297 |
+
|
298 |
+
elements = np.empty(xyz.shape[0], dtype=dtype_full)
|
299 |
+
attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
|
300 |
+
elements[:] = list(map(tuple, attributes))
|
301 |
+
el = PlyElement.describe(elements, 'vertex')
|
302 |
+
PlyData([el]).write(path)
|
303 |
+
|
304 |
+
def reset_opacity(self):
|
305 |
+
opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
|
306 |
+
optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
|
307 |
+
self._opacity = optimizable_tensors["opacity"]
|
308 |
+
|
309 |
+
def load_ply(self, path):
|
310 |
+
plydata = PlyData.read(path)
|
311 |
+
|
312 |
+
xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
|
313 |
+
np.asarray(plydata.elements[0]["y"]),
|
314 |
+
np.asarray(plydata.elements[0]["z"])), axis=1)
|
315 |
+
opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
|
316 |
+
|
317 |
+
features_dc = np.zeros((xyz.shape[0], 3, 1))
|
318 |
+
features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
|
319 |
+
features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
|
320 |
+
features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
|
321 |
+
|
322 |
+
extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
|
323 |
+
extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
|
324 |
+
assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
|
325 |
+
features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
|
326 |
+
for idx, attr_name in enumerate(extra_f_names):
|
327 |
+
features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
328 |
+
# Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
|
329 |
+
features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
|
330 |
+
|
331 |
+
scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
|
332 |
+
scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
|
333 |
+
scales = np.zeros((xyz.shape[0], len(scale_names)))
|
334 |
+
for idx, attr_name in enumerate(scale_names):
|
335 |
+
scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
336 |
+
|
337 |
+
rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
|
338 |
+
rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
|
339 |
+
rots = np.zeros((xyz.shape[0], len(rot_names)))
|
340 |
+
for idx, attr_name in enumerate(rot_names):
|
341 |
+
rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
|
342 |
+
|
343 |
+
self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
|
344 |
+
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
|
345 |
+
self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
|
346 |
+
self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
|
347 |
+
self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
|
348 |
+
self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
|
349 |
+
|
350 |
+
self.active_sh_degree = self.max_sh_degree
|
351 |
+
|
352 |
+
def replace_tensor_to_optimizer(self, tensor, name):
|
353 |
+
optimizable_tensors = {}
|
354 |
+
for group in self.optimizer.param_groups:
|
355 |
+
if group["name"] == name:
|
356 |
+
# breakpoint()
|
357 |
+
stored_state = self.optimizer.state.get(group['params'][0], None)
|
358 |
+
stored_state["exp_avg"] = torch.zeros_like(tensor)
|
359 |
+
stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
|
360 |
+
|
361 |
+
del self.optimizer.state[group['params'][0]]
|
362 |
+
group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
|
363 |
+
self.optimizer.state[group['params'][0]] = stored_state
|
364 |
+
|
365 |
+
optimizable_tensors[group["name"]] = group["params"][0]
|
366 |
+
return optimizable_tensors
|
367 |
+
|
368 |
+
def _prune_optimizer(self, mask):
|
369 |
+
optimizable_tensors = {}
|
370 |
+
for group in self.optimizer.param_groups:
|
371 |
+
stored_state = self.optimizer.state.get(group['params'][0], None)
|
372 |
+
if stored_state is not None:
|
373 |
+
stored_state["exp_avg"] = stored_state["exp_avg"][mask]
|
374 |
+
stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
|
375 |
+
|
376 |
+
del self.optimizer.state[group['params'][0]]
|
377 |
+
group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
|
378 |
+
self.optimizer.state[group['params'][0]] = stored_state
|
379 |
+
|
380 |
+
optimizable_tensors[group["name"]] = group["params"][0]
|
381 |
+
else:
|
382 |
+
group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
|
383 |
+
optimizable_tensors[group["name"]] = group["params"][0]
|
384 |
+
return optimizable_tensors
|
385 |
+
|
386 |
+
def prune_points(self, mask):
|
387 |
+
valid_points_mask = ~mask
|
388 |
+
optimizable_tensors = self._prune_optimizer(valid_points_mask)
|
389 |
+
|
390 |
+
self._xyz = optimizable_tensors["xyz"]
|
391 |
+
self._features_dc = optimizable_tensors["f_dc"]
|
392 |
+
self._features_rest = optimizable_tensors["f_rest"]
|
393 |
+
self._opacity = optimizable_tensors["opacity"]
|
394 |
+
self._scaling = optimizable_tensors["scaling"]
|
395 |
+
self._rotation = optimizable_tensors["rotation"]
|
396 |
+
|
397 |
+
self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
|
398 |
+
|
399 |
+
self.denom = self.denom[valid_points_mask]
|
400 |
+
self.max_radii2D = self.max_radii2D[valid_points_mask]
|
401 |
+
|
402 |
+
def cat_tensors_to_optimizer(self, tensors_dict):
|
403 |
+
optimizable_tensors = {}
|
404 |
+
for group in self.optimizer.param_groups:
|
405 |
+
assert len(group["params"]) == 1
|
406 |
+
extension_tensor = tensors_dict[group["name"]]
|
407 |
+
stored_state = self.optimizer.state.get(group['params'][0], None)
|
408 |
+
if stored_state is not None:
|
409 |
+
|
410 |
+
stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
|
411 |
+
stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
|
412 |
+
|
413 |
+
del self.optimizer.state[group['params'][0]]
|
414 |
+
group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
|
415 |
+
self.optimizer.state[group['params'][0]] = stored_state
|
416 |
+
|
417 |
+
optimizable_tensors[group["name"]] = group["params"][0]
|
418 |
+
else:
|
419 |
+
group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
|
420 |
+
optimizable_tensors[group["name"]] = group["params"][0]
|
421 |
+
|
422 |
+
return optimizable_tensors
|
423 |
+
|
424 |
+
def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
|
425 |
+
d = {"xyz": new_xyz,
|
426 |
+
"f_dc": new_features_dc,
|
427 |
+
"f_rest": new_features_rest,
|
428 |
+
"opacity": new_opacities,
|
429 |
+
"scaling" : new_scaling,
|
430 |
+
"rotation" : new_rotation}
|
431 |
+
|
432 |
+
optimizable_tensors = self.cat_tensors_to_optimizer(d)
|
433 |
+
self._xyz = optimizable_tensors["xyz"]
|
434 |
+
self._features_dc = optimizable_tensors["f_dc"]
|
435 |
+
self._features_rest = optimizable_tensors["f_rest"]
|
436 |
+
self._opacity = optimizable_tensors["opacity"]
|
437 |
+
self._scaling = optimizable_tensors["scaling"]
|
438 |
+
self._rotation = optimizable_tensors["rotation"]
|
439 |
+
|
440 |
+
self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
|
441 |
+
self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
|
442 |
+
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
|
443 |
+
|
444 |
+
def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
|
445 |
+
n_init_points = self.get_xyz.shape[0]
|
446 |
+
# Extract points that satisfy the gradient condition
|
447 |
+
padded_grad = torch.zeros((n_init_points), device="cuda")
|
448 |
+
padded_grad[:grads.shape[0]] = grads.squeeze()
|
449 |
+
selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
|
450 |
+
selected_pts_mask = torch.logical_and(selected_pts_mask,
|
451 |
+
torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
|
452 |
+
|
453 |
+
stds = self.get_scaling[selected_pts_mask].repeat(N,1)
|
454 |
+
means =torch.zeros((stds.size(0), 3),device="cuda")
|
455 |
+
samples = torch.normal(mean=means, std=stds)
|
456 |
+
rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
|
457 |
+
new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
|
458 |
+
new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
|
459 |
+
new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
|
460 |
+
new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
|
461 |
+
new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
|
462 |
+
new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
|
463 |
+
|
464 |
+
self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
|
465 |
+
|
466 |
+
prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
|
467 |
+
self.prune_points(prune_filter)
|
468 |
+
|
469 |
+
def densify_and_clone(self, grads, grad_threshold, scene_extent):
|
470 |
+
# Extract points that satisfy the gradient condition
|
471 |
+
selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
|
472 |
+
selected_pts_mask = torch.logical_and(selected_pts_mask,
|
473 |
+
torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
|
474 |
+
|
475 |
+
new_xyz = self._xyz[selected_pts_mask]
|
476 |
+
new_features_dc = self._features_dc[selected_pts_mask]
|
477 |
+
new_features_rest = self._features_rest[selected_pts_mask]
|
478 |
+
new_opacities = self._opacity[selected_pts_mask]
|
479 |
+
new_scaling = self._scaling[selected_pts_mask]
|
480 |
+
new_rotation = self._rotation[selected_pts_mask]
|
481 |
+
|
482 |
+
self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
|
483 |
+
|
484 |
+
def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
|
485 |
+
grads = self.xyz_gradient_accum / self.denom
|
486 |
+
grads[grads.isnan()] = 0.0
|
487 |
+
|
488 |
+
# self.densify_and_clone(grads, max_grad, extent)
|
489 |
+
# self.densify_and_split(grads, max_grad, extent)
|
490 |
+
|
491 |
+
prune_mask = (self.get_opacity < min_opacity).squeeze()
|
492 |
+
if max_screen_size:
|
493 |
+
big_points_vs = self.max_radii2D > max_screen_size
|
494 |
+
big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
|
495 |
+
prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
|
496 |
+
self.prune_points(prune_mask)
|
497 |
+
|
498 |
+
torch.cuda.empty_cache()
|
499 |
+
|
500 |
+
def add_densification_stats(self, viewspace_point_tensor, update_filter):
|
501 |
+
self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
|
502 |
+
self.denom[update_filter] += 1
|
submodules/diff-gaussian-rasterization/.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
build/
|
2 |
+
diff_gaussian_rasterization.egg-info/
|
3 |
+
dist/
|
submodules/diff-gaussian-rasterization/.gitmodules
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "third_party/glm"]
|
2 |
+
path = third_party/glm
|
3 |
+
url = https://github.com/g-truc/glm.git
|
submodules/diff-gaussian-rasterization/CMakeLists.txt
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact george.drettakis@inria.fr
|
10 |
+
#
|
11 |
+
|
12 |
+
cmake_minimum_required(VERSION 3.20)
|
13 |
+
|
14 |
+
project(DiffRast LANGUAGES CUDA CXX)
|
15 |
+
|
16 |
+
set(CMAKE_CXX_STANDARD 17)
|
17 |
+
set(CMAKE_CXX_EXTENSIONS OFF)
|
18 |
+
set(CMAKE_CUDA_STANDARD 17)
|
19 |
+
|
20 |
+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
21 |
+
|
22 |
+
add_library(CudaRasterizer
|
23 |
+
cuda_rasterizer/backward.h
|
24 |
+
cuda_rasterizer/backward.cu
|
25 |
+
cuda_rasterizer/forward.h
|
26 |
+
cuda_rasterizer/forward.cu
|
27 |
+
cuda_rasterizer/auxiliary.h
|
28 |
+
cuda_rasterizer/rasterizer_impl.cu
|
29 |
+
cuda_rasterizer/rasterizer_impl.h
|
30 |
+
cuda_rasterizer/rasterizer.h
|
31 |
+
)
|
32 |
+
|
33 |
+
set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86")
|
34 |
+
|
35 |
+
target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer)
|
36 |
+
target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
submodules/diff-gaussian-rasterization/LICENSE.md
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Gaussian-Splatting License
|
2 |
+
===========================
|
3 |
+
|
4 |
+
**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
|
5 |
+
The *Software* is in the process of being registered with the Agence pour la Protection des
|
6 |
+
Programmes (APP).
|
7 |
+
|
8 |
+
The *Software* is still being developed by the *Licensor*.
|
9 |
+
|
10 |
+
*Licensor*'s goal is to allow the research community to use, test and evaluate
|
11 |
+
the *Software*.
|
12 |
+
|
13 |
+
## 1. Definitions
|
14 |
+
|
15 |
+
*Licensee* means any person or entity that uses the *Software* and distributes
|
16 |
+
its *Work*.
|
17 |
+
|
18 |
+
*Licensor* means the owners of the *Software*, i.e Inria and MPII
|
19 |
+
|
20 |
+
*Software* means the original work of authorship made available under this
|
21 |
+
License ie gaussian-splatting.
|
22 |
+
|
23 |
+
*Work* means the *Software* and any additions to or derivative works of the
|
24 |
+
*Software* that are made available under this License.
|
25 |
+
|
26 |
+
|
27 |
+
## 2. Purpose
|
28 |
+
This license is intended to define the rights granted to the *Licensee* by
|
29 |
+
Licensors under the *Software*.
|
30 |
+
|
31 |
+
## 3. Rights granted
|
32 |
+
|
33 |
+
For the above reasons Licensors have decided to distribute the *Software*.
|
34 |
+
Licensors grant non-exclusive rights to use the *Software* for research purposes
|
35 |
+
to research users (both academic and industrial), free of charge, without right
|
36 |
+
to sublicense.. The *Software* may be used "non-commercially", i.e., for research
|
37 |
+
and/or evaluation purposes only.
|
38 |
+
|
39 |
+
Subject to the terms and conditions of this License, you are granted a
|
40 |
+
non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
|
41 |
+
publicly display, publicly perform and distribute its *Work* and any resulting
|
42 |
+
derivative works in any form.
|
43 |
+
|
44 |
+
## 4. Limitations
|
45 |
+
|
46 |
+
**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
|
47 |
+
so under this License, (b) you include a complete copy of this License with
|
48 |
+
your distribution, and (c) you retain without modification any copyright,
|
49 |
+
patent, trademark, or attribution notices that are present in the *Work*.
|
50 |
+
|
51 |
+
**4.2 Derivative Works.** You may specify that additional or different terms apply
|
52 |
+
to the use, reproduction, and distribution of your derivative works of the *Work*
|
53 |
+
("Your Terms") only if (a) Your Terms provide that the use limitation in
|
54 |
+
Section 2 applies to your derivative works, and (b) you identify the specific
|
55 |
+
derivative works that are subject to Your Terms. Notwithstanding Your Terms,
|
56 |
+
this License (including the redistribution requirements in Section 3.1) will
|
57 |
+
continue to apply to the *Work* itself.
|
58 |
+
|
59 |
+
**4.3** Any other use without of prior consent of Licensors is prohibited. Research
|
60 |
+
users explicitly acknowledge having received from Licensors all information
|
61 |
+
allowing to appreciate the adequacy between of the *Software* and their needs and
|
62 |
+
to undertake all necessary precautions for its execution and use.
|
63 |
+
|
64 |
+
**4.4** The *Software* is provided both as a compiled library file and as source
|
65 |
+
code. In case of using the *Software* for a publication or other results obtained
|
66 |
+
through the use of the *Software*, users are strongly encouraged to cite the
|
67 |
+
corresponding publications as explained in the documentation of the *Software*.
|
68 |
+
|
69 |
+
## 5. Disclaimer
|
70 |
+
|
71 |
+
THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
|
72 |
+
WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
|
73 |
+
UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
|
74 |
+
CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
|
75 |
+
OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
|
76 |
+
USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
|
77 |
+
ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
|
78 |
+
AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
79 |
+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
|
80 |
+
GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
|
81 |
+
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
|
82 |
+
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
|
83 |
+
IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
|
submodules/diff-gaussian-rasterization/README.md
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Differential Gaussian Rasterization
|
2 |
+
|
3 |
+
Used as the rasterization engine for the paper "3D Gaussian Splatting for Real-Time Rendering of Radiance Fields". If you can make use of it in your own research, please be so kind to cite us.
|
4 |
+
|
5 |
+
<section class="section" id="BibTeX">
|
6 |
+
<div class="container is-max-desktop content">
|
7 |
+
<h2 class="title">BibTeX</h2>
|
8 |
+
<pre><code>@Article{kerbl3Dgaussians,
|
9 |
+
author = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
|
10 |
+
title = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
|
11 |
+
journal = {ACM Transactions on Graphics},
|
12 |
+
number = {4},
|
13 |
+
volume = {42},
|
14 |
+
month = {July},
|
15 |
+
year = {2023},
|
16 |
+
url = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
|
17 |
+
}</code></pre>
|
18 |
+
</div>
|
19 |
+
</section>
|
submodules/diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact george.drettakis@inria.fr
|
10 |
+
*/
|
11 |
+
|
12 |
+
#ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
|
13 |
+
#define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
|
14 |
+
|
15 |
+
#include "config.h"
|
16 |
+
#include "stdio.h"
|
17 |
+
|
18 |
+
#define BLOCK_SIZE (BLOCK_X * BLOCK_Y)
|
19 |
+
#define NUM_WARPS (BLOCK_SIZE/32)
|
20 |
+
|
21 |
+
// Spherical harmonics coefficients
|
22 |
+
__device__ const float SH_C0 = 0.28209479177387814f;
|
23 |
+
__device__ const float SH_C1 = 0.4886025119029199f;
|
24 |
+
__device__ const float SH_C2[] = {
|
25 |
+
1.0925484305920792f,
|
26 |
+
-1.0925484305920792f,
|
27 |
+
0.31539156525252005f,
|
28 |
+
-1.0925484305920792f,
|
29 |
+
0.5462742152960396f
|
30 |
+
};
|
31 |
+
__device__ const float SH_C3[] = {
|
32 |
+
-0.5900435899266435f,
|
33 |
+
2.890611442640554f,
|
34 |
+
-0.4570457994644658f,
|
35 |
+
0.3731763325901154f,
|
36 |
+
-0.4570457994644658f,
|
37 |
+
1.445305721320277f,
|
38 |
+
-0.5900435899266435f
|
39 |
+
};
|
40 |
+
|
41 |
+
__forceinline__ __device__ float ndc2Pix(float v, int S)
|
42 |
+
{
|
43 |
+
return ((v + 1.0) * S - 1.0) * 0.5;
|
44 |
+
}
|
45 |
+
|
46 |
+
__forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid)
|
47 |
+
{
|
48 |
+
rect_min = {
|
49 |
+
min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))),
|
50 |
+
min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y)))
|
51 |
+
};
|
52 |
+
rect_max = {
|
53 |
+
min(grid.x, max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))),
|
54 |
+
min(grid.y, max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y)))
|
55 |
+
};
|
56 |
+
}
|
57 |
+
|
58 |
+
__forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix)
|
59 |
+
{
|
60 |
+
float3 transformed = {
|
61 |
+
matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12],
|
62 |
+
matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13],
|
63 |
+
matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14],
|
64 |
+
};
|
65 |
+
return transformed;
|
66 |
+
}
|
67 |
+
|
68 |
+
__forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix)
|
69 |
+
{
|
70 |
+
float4 transformed = {
|
71 |
+
matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12],
|
72 |
+
matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13],
|
73 |
+
matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14],
|
74 |
+
matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15]
|
75 |
+
};
|
76 |
+
return transformed;
|
77 |
+
}
|
78 |
+
|
79 |
+
__forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix)
|
80 |
+
{
|
81 |
+
float3 transformed = {
|
82 |
+
matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z,
|
83 |
+
matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z,
|
84 |
+
matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z,
|
85 |
+
};
|
86 |
+
return transformed;
|
87 |
+
}
|
88 |
+
|
89 |
+
__forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix)
|
90 |
+
{
|
91 |
+
float3 transformed = {
|
92 |
+
matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z,
|
93 |
+
matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z,
|
94 |
+
matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z,
|
95 |
+
};
|
96 |
+
return transformed;
|
97 |
+
}
|
98 |
+
|
99 |
+
__forceinline__ __device__ float dnormvdz(float3 v, float3 dv)
|
100 |
+
{
|
101 |
+
float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
|
102 |
+
float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
|
103 |
+
float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32;
|
104 |
+
return dnormvdz;
|
105 |
+
}
|
106 |
+
|
107 |
+
__forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv)
|
108 |
+
{
|
109 |
+
float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
|
110 |
+
float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
|
111 |
+
|
112 |
+
float3 dnormvdv;
|
113 |
+
dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32;
|
114 |
+
dnormvdv.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) * invsum32;
|
115 |
+
dnormvdv.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32;
|
116 |
+
return dnormvdv;
|
117 |
+
}
|
118 |
+
|
119 |
+
__forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv)
|
120 |
+
{
|
121 |
+
float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
|
122 |
+
float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
|
123 |
+
|
124 |
+
float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w };
|
125 |
+
float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w;
|
126 |
+
float4 dnormvdv;
|
127 |
+
dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32;
|
128 |
+
dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32;
|
129 |
+
dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32;
|
130 |
+
dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32;
|
131 |
+
return dnormvdv;
|
132 |
+
}
|
133 |
+
|
134 |
+
__forceinline__ __device__ float sigmoid(float x)
|
135 |
+
{
|
136 |
+
return 1.0f / (1.0f + expf(-x));
|
137 |
+
}
|
138 |
+
|
139 |
+
__forceinline__ __device__ bool in_frustum(int idx,
|
140 |
+
const float* orig_points,
|
141 |
+
const float* viewmatrix,
|
142 |
+
const float* projmatrix,
|
143 |
+
bool prefiltered,
|
144 |
+
float3& p_view)
|
145 |
+
{
|
146 |
+
float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };
|
147 |
+
|
148 |
+
// Bring points to screen space
|
149 |
+
float4 p_hom = transformPoint4x4(p_orig, projmatrix);
|
150 |
+
float p_w = 1.0f / (p_hom.w + 0.0000001f);
|
151 |
+
float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
|
152 |
+
p_view = transformPoint4x3(p_orig, viewmatrix);
|
153 |
+
|
154 |
+
if (p_view.z <= 0.01f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3)))
|
155 |
+
{
|
156 |
+
if (prefiltered)
|
157 |
+
{
|
158 |
+
printf("Point is filtered although prefiltered is set. This shouldn't happen!");
|
159 |
+
__trap();
|
160 |
+
}
|
161 |
+
return false;
|
162 |
+
}
|
163 |
+
return true;
|
164 |
+
}
|
165 |
+
|
166 |
+
#define CHECK_CUDA(A, debug) \
|
167 |
+
A; if(debug) { \
|
168 |
+
auto ret = cudaDeviceSynchronize(); \
|
169 |
+
if (ret != cudaSuccess) { \
|
170 |
+
std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \
|
171 |
+
throw std::runtime_error(cudaGetErrorString(ret)); \
|
172 |
+
} \
|
173 |
+
}
|
174 |
+
|
175 |
+
#endif
|
submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.cu
ADDED
@@ -0,0 +1,657 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact george.drettakis@inria.fr
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include "backward.h"
|
13 |
+
#include "auxiliary.h"
|
14 |
+
#include <cooperative_groups.h>
|
15 |
+
#include <cooperative_groups/reduce.h>
|
16 |
+
namespace cg = cooperative_groups;
|
17 |
+
|
18 |
+
// Backward pass for conversion of spherical harmonics to RGB for
|
19 |
+
// each Gaussian.
|
20 |
+
__device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs)
|
21 |
+
{
|
22 |
+
// Compute intermediate values, as it is done during forward
|
23 |
+
glm::vec3 pos = means[idx];
|
24 |
+
glm::vec3 dir_orig = pos - campos;
|
25 |
+
glm::vec3 dir = dir_orig / glm::length(dir_orig);
|
26 |
+
|
27 |
+
glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
|
28 |
+
|
29 |
+
// Use PyTorch rule for clamping: if clamping was applied,
|
30 |
+
// gradient becomes 0.
|
31 |
+
glm::vec3 dL_dRGB = dL_dcolor[idx];
|
32 |
+
dL_dRGB.x *= clamped[3 * idx + 0] ? 0 : 1;
|
33 |
+
dL_dRGB.y *= clamped[3 * idx + 1] ? 0 : 1;
|
34 |
+
dL_dRGB.z *= clamped[3 * idx + 2] ? 0 : 1;
|
35 |
+
|
36 |
+
glm::vec3 dRGBdx(0, 0, 0);
|
37 |
+
glm::vec3 dRGBdy(0, 0, 0);
|
38 |
+
glm::vec3 dRGBdz(0, 0, 0);
|
39 |
+
float x = dir.x;
|
40 |
+
float y = dir.y;
|
41 |
+
float z = dir.z;
|
42 |
+
|
43 |
+
// Target location for this Gaussian to write SH gradients to
|
44 |
+
glm::vec3* dL_dsh = dL_dshs + idx * max_coeffs;
|
45 |
+
|
46 |
+
// No tricks here, just high school-level calculus.
|
47 |
+
float dRGBdsh0 = SH_C0;
|
48 |
+
dL_dsh[0] = dRGBdsh0 * dL_dRGB;
|
49 |
+
if (deg > 0)
|
50 |
+
{
|
51 |
+
float dRGBdsh1 = -SH_C1 * y;
|
52 |
+
float dRGBdsh2 = SH_C1 * z;
|
53 |
+
float dRGBdsh3 = -SH_C1 * x;
|
54 |
+
dL_dsh[1] = dRGBdsh1 * dL_dRGB;
|
55 |
+
dL_dsh[2] = dRGBdsh2 * dL_dRGB;
|
56 |
+
dL_dsh[3] = dRGBdsh3 * dL_dRGB;
|
57 |
+
|
58 |
+
dRGBdx = -SH_C1 * sh[3];
|
59 |
+
dRGBdy = -SH_C1 * sh[1];
|
60 |
+
dRGBdz = SH_C1 * sh[2];
|
61 |
+
|
62 |
+
if (deg > 1)
|
63 |
+
{
|
64 |
+
float xx = x * x, yy = y * y, zz = z * z;
|
65 |
+
float xy = x * y, yz = y * z, xz = x * z;
|
66 |
+
|
67 |
+
float dRGBdsh4 = SH_C2[0] * xy;
|
68 |
+
float dRGBdsh5 = SH_C2[1] * yz;
|
69 |
+
float dRGBdsh6 = SH_C2[2] * (2.f * zz - xx - yy);
|
70 |
+
float dRGBdsh7 = SH_C2[3] * xz;
|
71 |
+
float dRGBdsh8 = SH_C2[4] * (xx - yy);
|
72 |
+
dL_dsh[4] = dRGBdsh4 * dL_dRGB;
|
73 |
+
dL_dsh[5] = dRGBdsh5 * dL_dRGB;
|
74 |
+
dL_dsh[6] = dRGBdsh6 * dL_dRGB;
|
75 |
+
dL_dsh[7] = dRGBdsh7 * dL_dRGB;
|
76 |
+
dL_dsh[8] = dRGBdsh8 * dL_dRGB;
|
77 |
+
|
78 |
+
dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] + SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8];
|
79 |
+
dRGBdy += SH_C2[0] * x * sh[4] + SH_C2[1] * z * sh[5] + SH_C2[2] * 2.f * -y * sh[6] + SH_C2[4] * 2.f * -y * sh[8];
|
80 |
+
dRGBdz += SH_C2[1] * y * sh[5] + SH_C2[2] * 2.f * 2.f * z * sh[6] + SH_C2[3] * x * sh[7];
|
81 |
+
|
82 |
+
if (deg > 2)
|
83 |
+
{
|
84 |
+
float dRGBdsh9 = SH_C3[0] * y * (3.f * xx - yy);
|
85 |
+
float dRGBdsh10 = SH_C3[1] * xy * z;
|
86 |
+
float dRGBdsh11 = SH_C3[2] * y * (4.f * zz - xx - yy);
|
87 |
+
float dRGBdsh12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy);
|
88 |
+
float dRGBdsh13 = SH_C3[4] * x * (4.f * zz - xx - yy);
|
89 |
+
float dRGBdsh14 = SH_C3[5] * z * (xx - yy);
|
90 |
+
float dRGBdsh15 = SH_C3[6] * x * (xx - 3.f * yy);
|
91 |
+
dL_dsh[9] = dRGBdsh9 * dL_dRGB;
|
92 |
+
dL_dsh[10] = dRGBdsh10 * dL_dRGB;
|
93 |
+
dL_dsh[11] = dRGBdsh11 * dL_dRGB;
|
94 |
+
dL_dsh[12] = dRGBdsh12 * dL_dRGB;
|
95 |
+
dL_dsh[13] = dRGBdsh13 * dL_dRGB;
|
96 |
+
dL_dsh[14] = dRGBdsh14 * dL_dRGB;
|
97 |
+
dL_dsh[15] = dRGBdsh15 * dL_dRGB;
|
98 |
+
|
99 |
+
dRGBdx += (
|
100 |
+
SH_C3[0] * sh[9] * 3.f * 2.f * xy +
|
101 |
+
SH_C3[1] * sh[10] * yz +
|
102 |
+
SH_C3[2] * sh[11] * -2.f * xy +
|
103 |
+
SH_C3[3] * sh[12] * -3.f * 2.f * xz +
|
104 |
+
SH_C3[4] * sh[13] * (-3.f * xx + 4.f * zz - yy) +
|
105 |
+
SH_C3[5] * sh[14] * 2.f * xz +
|
106 |
+
SH_C3[6] * sh[15] * 3.f * (xx - yy));
|
107 |
+
|
108 |
+
dRGBdy += (
|
109 |
+
SH_C3[0] * sh[9] * 3.f * (xx - yy) +
|
110 |
+
SH_C3[1] * sh[10] * xz +
|
111 |
+
SH_C3[2] * sh[11] * (-3.f * yy + 4.f * zz - xx) +
|
112 |
+
SH_C3[3] * sh[12] * -3.f * 2.f * yz +
|
113 |
+
SH_C3[4] * sh[13] * -2.f * xy +
|
114 |
+
SH_C3[5] * sh[14] * -2.f * yz +
|
115 |
+
SH_C3[6] * sh[15] * -3.f * 2.f * xy);
|
116 |
+
|
117 |
+
dRGBdz += (
|
118 |
+
SH_C3[1] * sh[10] * xy +
|
119 |
+
SH_C3[2] * sh[11] * 4.f * 2.f * yz +
|
120 |
+
SH_C3[3] * sh[12] * 3.f * (2.f * zz - xx - yy) +
|
121 |
+
SH_C3[4] * sh[13] * 4.f * 2.f * xz +
|
122 |
+
SH_C3[5] * sh[14] * (xx - yy));
|
123 |
+
}
|
124 |
+
}
|
125 |
+
}
|
126 |
+
|
127 |
+
// The view direction is an input to the computation. View direction
|
128 |
+
// is influenced by the Gaussian's mean, so SHs gradients
|
129 |
+
// must propagate back into 3D position.
|
130 |
+
glm::vec3 dL_ddir(glm::dot(dRGBdx, dL_dRGB), glm::dot(dRGBdy, dL_dRGB), glm::dot(dRGBdz, dL_dRGB));
|
131 |
+
|
132 |
+
// Account for normalization of direction
|
133 |
+
float3 dL_dmean = dnormvdv(float3{ dir_orig.x, dir_orig.y, dir_orig.z }, float3{ dL_ddir.x, dL_ddir.y, dL_ddir.z });
|
134 |
+
|
135 |
+
// Gradients of loss w.r.t. Gaussian means, but only the portion
|
136 |
+
// that is caused because the mean affects the view-dependent color.
|
137 |
+
// Additional mean gradient is accumulated in below methods.
|
138 |
+
dL_dmeans[idx] += glm::vec3(dL_dmean.x, dL_dmean.y, dL_dmean.z);
|
139 |
+
}
|
140 |
+
|
141 |
+
// Backward version of INVERSE 2D covariance matrix computation
|
142 |
+
// (due to length launched as separate kernel before other
|
143 |
+
// backward steps contained in preprocess)
|
144 |
+
__global__ void computeCov2DCUDA(int P,
|
145 |
+
const float3* means,
|
146 |
+
const int* radii,
|
147 |
+
const float* cov3Ds,
|
148 |
+
const float h_x, float h_y,
|
149 |
+
const float tan_fovx, float tan_fovy,
|
150 |
+
const float* view_matrix,
|
151 |
+
const float* dL_dconics,
|
152 |
+
float3* dL_dmeans,
|
153 |
+
float* dL_dcov)
|
154 |
+
{
|
155 |
+
auto idx = cg::this_grid().thread_rank();
|
156 |
+
if (idx >= P || !(radii[idx] > 0))
|
157 |
+
return;
|
158 |
+
|
159 |
+
// Reading location of 3D covariance for this Gaussian
|
160 |
+
const float* cov3D = cov3Ds + 6 * idx;
|
161 |
+
|
162 |
+
// Fetch gradients, recompute 2D covariance and relevant
|
163 |
+
// intermediate forward results needed in the backward.
|
164 |
+
float3 mean = means[idx];
|
165 |
+
float3 dL_dconic = { dL_dconics[4 * idx], dL_dconics[4 * idx + 1], dL_dconics[4 * idx + 3] };
|
166 |
+
float3 t = transformPoint4x3(mean, view_matrix);
|
167 |
+
|
168 |
+
const float limx = 1.3f * tan_fovx;
|
169 |
+
const float limy = 1.3f * tan_fovy;
|
170 |
+
const float txtz = t.x / t.z;
|
171 |
+
const float tytz = t.y / t.z;
|
172 |
+
t.x = min(limx, max(-limx, txtz)) * t.z;
|
173 |
+
t.y = min(limy, max(-limy, tytz)) * t.z;
|
174 |
+
|
175 |
+
const float x_grad_mul = txtz < -limx || txtz > limx ? 0 : 1;
|
176 |
+
const float y_grad_mul = tytz < -limy || tytz > limy ? 0 : 1;
|
177 |
+
|
178 |
+
glm::mat3 J = glm::mat3(h_x / t.z, 0.0f, -(h_x * t.x) / (t.z * t.z),
|
179 |
+
0.0f, h_y / t.z, -(h_y * t.y) / (t.z * t.z),
|
180 |
+
0, 0, 0);
|
181 |
+
|
182 |
+
glm::mat3 W = glm::mat3(
|
183 |
+
view_matrix[0], view_matrix[4], view_matrix[8],
|
184 |
+
view_matrix[1], view_matrix[5], view_matrix[9],
|
185 |
+
view_matrix[2], view_matrix[6], view_matrix[10]);
|
186 |
+
|
187 |
+
glm::mat3 Vrk = glm::mat3(
|
188 |
+
cov3D[0], cov3D[1], cov3D[2],
|
189 |
+
cov3D[1], cov3D[3], cov3D[4],
|
190 |
+
cov3D[2], cov3D[4], cov3D[5]);
|
191 |
+
|
192 |
+
glm::mat3 T = W * J;
|
193 |
+
|
194 |
+
glm::mat3 cov2D = glm::transpose(T) * glm::transpose(Vrk) * T;
|
195 |
+
|
196 |
+
// Use helper variables for 2D covariance entries. More compact.
|
197 |
+
float a = cov2D[0][0] += 0.3f;
|
198 |
+
float b = cov2D[0][1];
|
199 |
+
float c = cov2D[1][1] += 0.3f;
|
200 |
+
|
201 |
+
float denom = a * c - b * b;
|
202 |
+
float dL_da = 0, dL_db = 0, dL_dc = 0;
|
203 |
+
float denom2inv = 1.0f / ((denom * denom) + 0.0000001f);
|
204 |
+
|
205 |
+
if (denom2inv != 0)
|
206 |
+
{
|
207 |
+
// Gradients of loss w.r.t. entries of 2D covariance matrix,
|
208 |
+
// given gradients of loss w.r.t. conic matrix (inverse covariance matrix).
|
209 |
+
// e.g., dL / da = dL / d_conic_a * d_conic_a / d_a
|
210 |
+
dL_da = denom2inv * (-c * c * dL_dconic.x + 2 * b * c * dL_dconic.y + (denom - a * c) * dL_dconic.z);
|
211 |
+
dL_dc = denom2inv * (-a * a * dL_dconic.z + 2 * a * b * dL_dconic.y + (denom - a * c) * dL_dconic.x);
|
212 |
+
dL_db = denom2inv * 2 * (b * c * dL_dconic.x - (denom + 2 * b * b) * dL_dconic.y + a * b * dL_dconic.z);
|
213 |
+
|
214 |
+
// Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
|
215 |
+
// given gradients w.r.t. 2D covariance matrix (diagonal).
|
216 |
+
// cov2D = transpose(T) * transpose(Vrk) * T;
|
217 |
+
dL_dcov[6 * idx + 0] = (T[0][0] * T[0][0] * dL_da + T[0][0] * T[1][0] * dL_db + T[1][0] * T[1][0] * dL_dc);
|
218 |
+
dL_dcov[6 * idx + 3] = (T[0][1] * T[0][1] * dL_da + T[0][1] * T[1][1] * dL_db + T[1][1] * T[1][1] * dL_dc);
|
219 |
+
dL_dcov[6 * idx + 5] = (T[0][2] * T[0][2] * dL_da + T[0][2] * T[1][2] * dL_db + T[1][2] * T[1][2] * dL_dc);
|
220 |
+
|
221 |
+
// Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry,
|
222 |
+
// given gradients w.r.t. 2D covariance matrix (off-diagonal).
|
223 |
+
// Off-diagonal elements appear twice --> double the gradient.
|
224 |
+
// cov2D = transpose(T) * transpose(Vrk) * T;
|
225 |
+
dL_dcov[6 * idx + 1] = 2 * T[0][0] * T[0][1] * dL_da + (T[0][0] * T[1][1] + T[0][1] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][1] * dL_dc;
|
226 |
+
dL_dcov[6 * idx + 2] = 2 * T[0][0] * T[0][2] * dL_da + (T[0][0] * T[1][2] + T[0][2] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][2] * dL_dc;
|
227 |
+
dL_dcov[6 * idx + 4] = 2 * T[0][2] * T[0][1] * dL_da + (T[0][1] * T[1][2] + T[0][2] * T[1][1]) * dL_db + 2 * T[1][1] * T[1][2] * dL_dc;
|
228 |
+
}
|
229 |
+
else
|
230 |
+
{
|
231 |
+
for (int i = 0; i < 6; i++)
|
232 |
+
dL_dcov[6 * idx + i] = 0;
|
233 |
+
}
|
234 |
+
|
235 |
+
// Gradients of loss w.r.t. upper 2x3 portion of intermediate matrix T
|
236 |
+
// cov2D = transpose(T) * transpose(Vrk) * T;
|
237 |
+
float dL_dT00 = 2 * (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_da +
|
238 |
+
(T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_db;
|
239 |
+
float dL_dT01 = 2 * (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_da +
|
240 |
+
(T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_db;
|
241 |
+
float dL_dT02 = 2 * (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_da +
|
242 |
+
(T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_db;
|
243 |
+
float dL_dT10 = 2 * (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_dc +
|
244 |
+
(T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_db;
|
245 |
+
float dL_dT11 = 2 * (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_dc +
|
246 |
+
(T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_db;
|
247 |
+
float dL_dT12 = 2 * (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_dc +
|
248 |
+
(T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_db;
|
249 |
+
|
250 |
+
// Gradients of loss w.r.t. upper 3x2 non-zero entries of Jacobian matrix
|
251 |
+
// T = W * J
|
252 |
+
float dL_dJ00 = W[0][0] * dL_dT00 + W[0][1] * dL_dT01 + W[0][2] * dL_dT02;
|
253 |
+
float dL_dJ02 = W[2][0] * dL_dT00 + W[2][1] * dL_dT01 + W[2][2] * dL_dT02;
|
254 |
+
float dL_dJ11 = W[1][0] * dL_dT10 + W[1][1] * dL_dT11 + W[1][2] * dL_dT12;
|
255 |
+
float dL_dJ12 = W[2][0] * dL_dT10 + W[2][1] * dL_dT11 + W[2][2] * dL_dT12;
|
256 |
+
|
257 |
+
float tz = 1.f / t.z;
|
258 |
+
float tz2 = tz * tz;
|
259 |
+
float tz3 = tz2 * tz;
|
260 |
+
|
261 |
+
// Gradients of loss w.r.t. transformed Gaussian mean t
|
262 |
+
float dL_dtx = x_grad_mul * -h_x * tz2 * dL_dJ02;
|
263 |
+
float dL_dty = y_grad_mul * -h_y * tz2 * dL_dJ12;
|
264 |
+
float dL_dtz = -h_x * tz2 * dL_dJ00 - h_y * tz2 * dL_dJ11 + (2 * h_x * t.x) * tz3 * dL_dJ02 + (2 * h_y * t.y) * tz3 * dL_dJ12;
|
265 |
+
|
266 |
+
// Account for transformation of mean to t
|
267 |
+
// t = transformPoint4x3(mean, view_matrix);
|
268 |
+
float3 dL_dmean = transformVec4x3Transpose({ dL_dtx, dL_dty, dL_dtz }, view_matrix);
|
269 |
+
|
270 |
+
// Gradients of loss w.r.t. Gaussian means, but only the portion
|
271 |
+
// that is caused because the mean affects the covariance matrix.
|
272 |
+
// Additional mean gradient is accumulated in BACKWARD::preprocess.
|
273 |
+
dL_dmeans[idx] = dL_dmean;
|
274 |
+
}
|
275 |
+
|
276 |
+
// Backward pass for the conversion of scale and rotation to a
|
277 |
+
// 3D covariance matrix for each Gaussian.
|
278 |
+
__device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const glm::vec4 rot, const float* dL_dcov3Ds, glm::vec3* dL_dscales, glm::vec4* dL_drots)
|
279 |
+
{
|
280 |
+
// Recompute (intermediate) results for the 3D covariance computation.
|
281 |
+
glm::vec4 q = rot;// / glm::length(rot);
|
282 |
+
float r = q.x;
|
283 |
+
float x = q.y;
|
284 |
+
float y = q.z;
|
285 |
+
float z = q.w;
|
286 |
+
|
287 |
+
glm::mat3 R = glm::mat3(
|
288 |
+
1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y),
|
289 |
+
2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x),
|
290 |
+
2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y)
|
291 |
+
);
|
292 |
+
|
293 |
+
glm::mat3 S = glm::mat3(1.0f);
|
294 |
+
|
295 |
+
glm::vec3 s = mod * scale;
|
296 |
+
S[0][0] = s.x;
|
297 |
+
S[1][1] = s.y;
|
298 |
+
S[2][2] = s.z;
|
299 |
+
|
300 |
+
glm::mat3 M = S * R;
|
301 |
+
|
302 |
+
const float* dL_dcov3D = dL_dcov3Ds + 6 * idx;
|
303 |
+
|
304 |
+
glm::vec3 dunc(dL_dcov3D[0], dL_dcov3D[3], dL_dcov3D[5]);
|
305 |
+
glm::vec3 ounc = 0.5f * glm::vec3(dL_dcov3D[1], dL_dcov3D[2], dL_dcov3D[4]);
|
306 |
+
|
307 |
+
// Convert per-element covariance loss gradients to matrix form
|
308 |
+
glm::mat3 dL_dSigma = glm::mat3(
|
309 |
+
dL_dcov3D[0], 0.5f * dL_dcov3D[1], 0.5f * dL_dcov3D[2],
|
310 |
+
0.5f * dL_dcov3D[1], dL_dcov3D[3], 0.5f * dL_dcov3D[4],
|
311 |
+
0.5f * dL_dcov3D[2], 0.5f * dL_dcov3D[4], dL_dcov3D[5]
|
312 |
+
);
|
313 |
+
|
314 |
+
// Compute loss gradient w.r.t. matrix M
|
315 |
+
// dSigma_dM = 2 * M
|
316 |
+
glm::mat3 dL_dM = 2.0f * M * dL_dSigma;
|
317 |
+
|
318 |
+
glm::mat3 Rt = glm::transpose(R);
|
319 |
+
glm::mat3 dL_dMt = glm::transpose(dL_dM);
|
320 |
+
|
321 |
+
// Gradients of loss w.r.t. scale
|
322 |
+
glm::vec3* dL_dscale = dL_dscales + idx;
|
323 |
+
dL_dscale->x = glm::dot(Rt[0], dL_dMt[0]);
|
324 |
+
dL_dscale->y = glm::dot(Rt[1], dL_dMt[1]);
|
325 |
+
dL_dscale->z = glm::dot(Rt[2], dL_dMt[2]);
|
326 |
+
|
327 |
+
dL_dMt[0] *= s.x;
|
328 |
+
dL_dMt[1] *= s.y;
|
329 |
+
dL_dMt[2] *= s.z;
|
330 |
+
|
331 |
+
// Gradients of loss w.r.t. normalized quaternion
|
332 |
+
glm::vec4 dL_dq;
|
333 |
+
dL_dq.x = 2 * z * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * y * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * x * (dL_dMt[1][2] - dL_dMt[2][1]);
|
334 |
+
dL_dq.y = 2 * y * (dL_dMt[1][0] + dL_dMt[0][1]) + 2 * z * (dL_dMt[2][0] + dL_dMt[0][2]) + 2 * r * (dL_dMt[1][2] - dL_dMt[2][1]) - 4 * x * (dL_dMt[2][2] + dL_dMt[1][1]);
|
335 |
+
dL_dq.z = 2 * x * (dL_dMt[1][0] + dL_dMt[0][1]) + 2 * r * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * z * (dL_dMt[1][2] + dL_dMt[2][1]) - 4 * y * (dL_dMt[2][2] + dL_dMt[0][0]);
|
336 |
+
dL_dq.w = 2 * r * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * x * (dL_dMt[2][0] + dL_dMt[0][2]) + 2 * y * (dL_dMt[1][2] + dL_dMt[2][1]) - 4 * z * (dL_dMt[1][1] + dL_dMt[0][0]);
|
337 |
+
|
338 |
+
// Gradients of loss w.r.t. unnormalized quaternion
|
339 |
+
float4* dL_drot = (float4*)(dL_drots + idx);
|
340 |
+
*dL_drot = float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w };//dnormvdv(float4{ rot.x, rot.y, rot.z, rot.w }, float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w });
|
341 |
+
}
|
342 |
+
|
343 |
+
// Backward pass of the preprocessing steps, except
|
344 |
+
// for the covariance computation and inversion
|
345 |
+
// (those are handled by a previous kernel call)
|
346 |
+
template<int C>
|
347 |
+
__global__ void preprocessCUDA(
|
348 |
+
int P, int D, int M,
|
349 |
+
const float3* means,
|
350 |
+
const int* radii,
|
351 |
+
const float* shs,
|
352 |
+
const bool* clamped,
|
353 |
+
const glm::vec3* scales,
|
354 |
+
const glm::vec4* rotations,
|
355 |
+
const float scale_modifier,
|
356 |
+
const float* proj,
|
357 |
+
const glm::vec3* campos,
|
358 |
+
const float3* dL_dmean2D,
|
359 |
+
glm::vec3* dL_dmeans,
|
360 |
+
float* dL_dcolor,
|
361 |
+
float* dL_dcov3D,
|
362 |
+
float* dL_dsh,
|
363 |
+
glm::vec3* dL_dscale,
|
364 |
+
glm::vec4* dL_drot)
|
365 |
+
{
|
366 |
+
auto idx = cg::this_grid().thread_rank();
|
367 |
+
if (idx >= P || !(radii[idx] > 0))
|
368 |
+
return;
|
369 |
+
|
370 |
+
float3 m = means[idx];
|
371 |
+
|
372 |
+
// Taking care of gradients from the screenspace points
|
373 |
+
float4 m_hom = transformPoint4x4(m, proj);
|
374 |
+
float m_w = 1.0f / (m_hom.w + 0.0000001f);
|
375 |
+
|
376 |
+
// Compute loss gradient w.r.t. 3D means due to gradients of 2D means
|
377 |
+
// from rendering procedure
|
378 |
+
glm::vec3 dL_dmean;
|
379 |
+
float mul1 = (proj[0] * m.x + proj[4] * m.y + proj[8] * m.z + proj[12]) * m_w * m_w;
|
380 |
+
float mul2 = (proj[1] * m.x + proj[5] * m.y + proj[9] * m.z + proj[13]) * m_w * m_w;
|
381 |
+
dL_dmean.x = (proj[0] * m_w - proj[3] * mul1) * dL_dmean2D[idx].x + (proj[1] * m_w - proj[3] * mul2) * dL_dmean2D[idx].y;
|
382 |
+
dL_dmean.y = (proj[4] * m_w - proj[7] * mul1) * dL_dmean2D[idx].x + (proj[5] * m_w - proj[7] * mul2) * dL_dmean2D[idx].y;
|
383 |
+
dL_dmean.z = (proj[8] * m_w - proj[11] * mul1) * dL_dmean2D[idx].x + (proj[9] * m_w - proj[11] * mul2) * dL_dmean2D[idx].y;
|
384 |
+
|
385 |
+
// That's the second part of the mean gradient. Previous computation
|
386 |
+
// of cov2D and following SH conversion also affects it.
|
387 |
+
dL_dmeans[idx] += dL_dmean;
|
388 |
+
|
389 |
+
// Compute gradient updates due to computing colors from SHs
|
390 |
+
if (shs)
|
391 |
+
computeColorFromSH(idx, D, M, (glm::vec3*)means, *campos, shs, clamped, (glm::vec3*)dL_dcolor, (glm::vec3*)dL_dmeans, (glm::vec3*)dL_dsh);
|
392 |
+
|
393 |
+
// Compute gradient updates due to computing covariance from scale/rotation
|
394 |
+
if (scales)
|
395 |
+
computeCov3D(idx, scales[idx], scale_modifier, rotations[idx], dL_dcov3D, dL_dscale, dL_drot);
|
396 |
+
}
|
397 |
+
|
398 |
+
// Backward version of the rendering procedure.
|
399 |
+
template <uint32_t C>
|
400 |
+
__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
|
401 |
+
renderCUDA(
|
402 |
+
const uint2* __restrict__ ranges,
|
403 |
+
const uint32_t* __restrict__ point_list,
|
404 |
+
int W, int H,
|
405 |
+
const float* __restrict__ bg_color,
|
406 |
+
const float2* __restrict__ points_xy_image,
|
407 |
+
const float4* __restrict__ conic_opacity,
|
408 |
+
const float* __restrict__ colors,
|
409 |
+
const float* __restrict__ final_Ts,
|
410 |
+
const uint32_t* __restrict__ n_contrib,
|
411 |
+
const float* __restrict__ dL_dpixels,
|
412 |
+
float3* __restrict__ dL_dmean2D,
|
413 |
+
float4* __restrict__ dL_dconic2D,
|
414 |
+
float* __restrict__ dL_dopacity,
|
415 |
+
float* __restrict__ dL_dcolors)
|
416 |
+
{
|
417 |
+
// We rasterize again. Compute necessary block info.
|
418 |
+
auto block = cg::this_thread_block();
|
419 |
+
const uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
|
420 |
+
const uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y };
|
421 |
+
const uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) };
|
422 |
+
const uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y };
|
423 |
+
const uint32_t pix_id = W * pix.y + pix.x;
|
424 |
+
const float2 pixf = { (float)pix.x, (float)pix.y };
|
425 |
+
|
426 |
+
const bool inside = pix.x < W&& pix.y < H;
|
427 |
+
const uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
|
428 |
+
|
429 |
+
const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
|
430 |
+
|
431 |
+
bool done = !inside;
|
432 |
+
int toDo = range.y - range.x;
|
433 |
+
|
434 |
+
__shared__ int collected_id[BLOCK_SIZE];
|
435 |
+
__shared__ float2 collected_xy[BLOCK_SIZE];
|
436 |
+
__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
|
437 |
+
__shared__ float collected_colors[C * BLOCK_SIZE];
|
438 |
+
|
439 |
+
// In the forward, we stored the final value for T, the
|
440 |
+
// product of all (1 - alpha) factors.
|
441 |
+
const float T_final = inside ? final_Ts[pix_id] : 0;
|
442 |
+
float T = T_final;
|
443 |
+
|
444 |
+
// We start from the back. The ID of the last contributing
|
445 |
+
// Gaussian is known from each pixel from the forward.
|
446 |
+
uint32_t contributor = toDo;
|
447 |
+
const int last_contributor = inside ? n_contrib[pix_id] : 0;
|
448 |
+
|
449 |
+
float accum_rec[C] = { 0 };
|
450 |
+
float dL_dpixel[C];
|
451 |
+
if (inside)
|
452 |
+
for (int i = 0; i < C; i++)
|
453 |
+
dL_dpixel[i] = dL_dpixels[i * H * W + pix_id];
|
454 |
+
|
455 |
+
float last_alpha = 0;
|
456 |
+
float last_color[C] = { 0 };
|
457 |
+
|
458 |
+
// Gradient of pixel coordinate w.r.t. normalized
|
459 |
+
// screen-space viewport corrdinates (-1 to 1)
|
460 |
+
const float ddelx_dx = 0.5 * W;
|
461 |
+
const float ddely_dy = 0.5 * H;
|
462 |
+
|
463 |
+
// Traverse all Gaussians
|
464 |
+
for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
|
465 |
+
{
|
466 |
+
// Load auxiliary data into shared memory, start in the BACK
|
467 |
+
// and load them in revers order.
|
468 |
+
block.sync();
|
469 |
+
const int progress = i * BLOCK_SIZE + block.thread_rank();
|
470 |
+
if (range.x + progress < range.y)
|
471 |
+
{
|
472 |
+
const int coll_id = point_list[range.y - progress - 1];
|
473 |
+
collected_id[block.thread_rank()] = coll_id;
|
474 |
+
collected_xy[block.thread_rank()] = points_xy_image[coll_id];
|
475 |
+
collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
|
476 |
+
for (int i = 0; i < C; i++)
|
477 |
+
collected_colors[i * BLOCK_SIZE + block.thread_rank()] = colors[coll_id * C + i];
|
478 |
+
}
|
479 |
+
block.sync();
|
480 |
+
|
481 |
+
// Iterate over Gaussians
|
482 |
+
for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++)
|
483 |
+
{
|
484 |
+
// Keep track of current Gaussian ID. Skip, if this one
|
485 |
+
// is behind the last contributor for this pixel.
|
486 |
+
contributor--;
|
487 |
+
if (contributor >= last_contributor)
|
488 |
+
continue;
|
489 |
+
|
490 |
+
// Compute blending values, as before.
|
491 |
+
const float2 xy = collected_xy[j];
|
492 |
+
const float2 d = { xy.x - pixf.x, xy.y - pixf.y };
|
493 |
+
const float4 con_o = collected_conic_opacity[j];
|
494 |
+
const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
|
495 |
+
if (power > 0.0f)
|
496 |
+
continue;
|
497 |
+
|
498 |
+
const float G = exp(power);
|
499 |
+
const float alpha = min(0.99f, con_o.w * G);
|
500 |
+
if (alpha < 1.0f / 255.0f)
|
501 |
+
continue;
|
502 |
+
|
503 |
+
T = T / (1.f - alpha);
|
504 |
+
const float dchannel_dcolor = alpha * T;
|
505 |
+
|
506 |
+
// Propagate gradients to per-Gaussian colors and keep
|
507 |
+
// gradients w.r.t. alpha (blending factor for a Gaussian/pixel
|
508 |
+
// pair).
|
509 |
+
float dL_dalpha = 0.0f;
|
510 |
+
const int global_id = collected_id[j];
|
511 |
+
for (int ch = 0; ch < C; ch++)
|
512 |
+
{
|
513 |
+
const float c = collected_colors[ch * BLOCK_SIZE + j];
|
514 |
+
// Update last color (to be used in the next iteration)
|
515 |
+
accum_rec[ch] = last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch];
|
516 |
+
last_color[ch] = c;
|
517 |
+
|
518 |
+
const float dL_dchannel = dL_dpixel[ch];
|
519 |
+
dL_dalpha += (c - accum_rec[ch]) * dL_dchannel;
|
520 |
+
// Update the gradients w.r.t. color of the Gaussian.
|
521 |
+
// Atomic, since this pixel is just one of potentially
|
522 |
+
// many that were affected by this Gaussian.
|
523 |
+
atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel);
|
524 |
+
}
|
525 |
+
dL_dalpha *= T;
|
526 |
+
// Update last alpha (to be used in the next iteration)
|
527 |
+
last_alpha = alpha;
|
528 |
+
|
529 |
+
// Account for fact that alpha also influences how much of
|
530 |
+
// the background color is added if nothing left to blend
|
531 |
+
float bg_dot_dpixel = 0;
|
532 |
+
for (int i = 0; i < C; i++)
|
533 |
+
bg_dot_dpixel += bg_color[i] * dL_dpixel[i];
|
534 |
+
dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel;
|
535 |
+
|
536 |
+
|
537 |
+
// Helpful reusable temporary variables
|
538 |
+
const float dL_dG = con_o.w * dL_dalpha;
|
539 |
+
const float gdx = G * d.x;
|
540 |
+
const float gdy = G * d.y;
|
541 |
+
const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y;
|
542 |
+
const float dG_ddely = -gdy * con_o.z - gdx * con_o.y;
|
543 |
+
|
544 |
+
// Update gradients w.r.t. 2D mean position of the Gaussian
|
545 |
+
atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx);
|
546 |
+
atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy);
|
547 |
+
|
548 |
+
// Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric)
|
549 |
+
atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG);
|
550 |
+
atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG);
|
551 |
+
atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG);
|
552 |
+
|
553 |
+
// Update gradients w.r.t. opacity of the Gaussian
|
554 |
+
atomicAdd(&(dL_dopacity[global_id]), G * dL_dalpha);
|
555 |
+
}
|
556 |
+
}
|
557 |
+
}
|
558 |
+
|
559 |
+
void BACKWARD::preprocess(
|
560 |
+
int P, int D, int M,
|
561 |
+
const float3* means3D,
|
562 |
+
const int* radii,
|
563 |
+
const float* shs,
|
564 |
+
const bool* clamped,
|
565 |
+
const glm::vec3* scales,
|
566 |
+
const glm::vec4* rotations,
|
567 |
+
const float scale_modifier,
|
568 |
+
const float* cov3Ds,
|
569 |
+
const float* viewmatrix,
|
570 |
+
const float* projmatrix,
|
571 |
+
const float focal_x, float focal_y,
|
572 |
+
const float tan_fovx, float tan_fovy,
|
573 |
+
const glm::vec3* campos,
|
574 |
+
const float3* dL_dmean2D,
|
575 |
+
const float* dL_dconic,
|
576 |
+
glm::vec3* dL_dmean3D,
|
577 |
+
float* dL_dcolor,
|
578 |
+
float* dL_dcov3D,
|
579 |
+
float* dL_dsh,
|
580 |
+
glm::vec3* dL_dscale,
|
581 |
+
glm::vec4* dL_drot)
|
582 |
+
{
|
583 |
+
// Propagate gradients for the path of 2D conic matrix computation.
|
584 |
+
// Somewhat long, thus it is its own kernel rather than being part of
|
585 |
+
// "preprocess". When done, loss gradient w.r.t. 3D means has been
|
586 |
+
// modified and gradient w.r.t. 3D covariance matrix has been computed.
|
587 |
+
computeCov2DCUDA << <(P + 255) / 256, 256 >> > (
|
588 |
+
P,
|
589 |
+
means3D,
|
590 |
+
radii,
|
591 |
+
cov3Ds,
|
592 |
+
focal_x,
|
593 |
+
focal_y,
|
594 |
+
tan_fovx,
|
595 |
+
tan_fovy,
|
596 |
+
viewmatrix,
|
597 |
+
dL_dconic,
|
598 |
+
(float3*)dL_dmean3D,
|
599 |
+
dL_dcov3D);
|
600 |
+
|
601 |
+
// Propagate gradients for remaining steps: finish 3D mean gradients,
|
602 |
+
// propagate color gradients to SH (if desireD), propagate 3D covariance
|
603 |
+
// matrix gradients to scale and rotation.
|
604 |
+
preprocessCUDA<NUM_CHANNELS> << < (P + 255) / 256, 256 >> > (
|
605 |
+
P, D, M,
|
606 |
+
(float3*)means3D,
|
607 |
+
radii,
|
608 |
+
shs,
|
609 |
+
clamped,
|
610 |
+
(glm::vec3*)scales,
|
611 |
+
(glm::vec4*)rotations,
|
612 |
+
scale_modifier,
|
613 |
+
projmatrix,
|
614 |
+
campos,
|
615 |
+
(float3*)dL_dmean2D,
|
616 |
+
(glm::vec3*)dL_dmean3D,
|
617 |
+
dL_dcolor,
|
618 |
+
dL_dcov3D,
|
619 |
+
dL_dsh,
|
620 |
+
dL_dscale,
|
621 |
+
dL_drot);
|
622 |
+
}
|
623 |
+
|
624 |
+
void BACKWARD::render(
|
625 |
+
const dim3 grid, const dim3 block,
|
626 |
+
const uint2* ranges,
|
627 |
+
const uint32_t* point_list,
|
628 |
+
int W, int H,
|
629 |
+
const float* bg_color,
|
630 |
+
const float2* means2D,
|
631 |
+
const float4* conic_opacity,
|
632 |
+
const float* colors,
|
633 |
+
const float* final_Ts,
|
634 |
+
const uint32_t* n_contrib,
|
635 |
+
const float* dL_dpixels,
|
636 |
+
float3* dL_dmean2D,
|
637 |
+
float4* dL_dconic2D,
|
638 |
+
float* dL_dopacity,
|
639 |
+
float* dL_dcolors)
|
640 |
+
{
|
641 |
+
renderCUDA<NUM_CHANNELS> << <grid, block >> >(
|
642 |
+
ranges,
|
643 |
+
point_list,
|
644 |
+
W, H,
|
645 |
+
bg_color,
|
646 |
+
means2D,
|
647 |
+
conic_opacity,
|
648 |
+
colors,
|
649 |
+
final_Ts,
|
650 |
+
n_contrib,
|
651 |
+
dL_dpixels,
|
652 |
+
dL_dmean2D,
|
653 |
+
dL_dconic2D,
|
654 |
+
dL_dopacity,
|
655 |
+
dL_dcolors
|
656 |
+
);
|
657 |
+
}
|
submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.h
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact george.drettakis@inria.fr
|
10 |
+
*/
|
11 |
+
|
12 |
+
#ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED
|
13 |
+
#define CUDA_RASTERIZER_BACKWARD_H_INCLUDED
|
14 |
+
|
15 |
+
#include <cuda.h>
|
16 |
+
#include "cuda_runtime.h"
|
17 |
+
#include "device_launch_parameters.h"
|
18 |
+
#define GLM_FORCE_CUDA
|
19 |
+
#include <glm/glm.hpp>
|
20 |
+
|
21 |
+
namespace BACKWARD
|
22 |
+
{
|
23 |
+
void render(
|
24 |
+
const dim3 grid, dim3 block,
|
25 |
+
const uint2* ranges,
|
26 |
+
const uint32_t* point_list,
|
27 |
+
int W, int H,
|
28 |
+
const float* bg_color,
|
29 |
+
const float2* means2D,
|
30 |
+
const float4* conic_opacity,
|
31 |
+
const float* colors,
|
32 |
+
const float* final_Ts,
|
33 |
+
const uint32_t* n_contrib,
|
34 |
+
const float* dL_dpixels,
|
35 |
+
float3* dL_dmean2D,
|
36 |
+
float4* dL_dconic2D,
|
37 |
+
float* dL_dopacity,
|
38 |
+
float* dL_dcolors);
|
39 |
+
|
40 |
+
void preprocess(
|
41 |
+
int P, int D, int M,
|
42 |
+
const float3* means,
|
43 |
+
const int* radii,
|
44 |
+
const float* shs,
|
45 |
+
const bool* clamped,
|
46 |
+
const glm::vec3* scales,
|
47 |
+
const glm::vec4* rotations,
|
48 |
+
const float scale_modifier,
|
49 |
+
const float* cov3Ds,
|
50 |
+
const float* view,
|
51 |
+
const float* proj,
|
52 |
+
const float focal_x, float focal_y,
|
53 |
+
const float tan_fovx, float tan_fovy,
|
54 |
+
const glm::vec3* campos,
|
55 |
+
const float3* dL_dmean2D,
|
56 |
+
const float* dL_dconics,
|
57 |
+
glm::vec3* dL_dmeans,
|
58 |
+
float* dL_dcolor,
|
59 |
+
float* dL_dcov3D,
|
60 |
+
float* dL_dsh,
|
61 |
+
glm::vec3* dL_dscale,
|
62 |
+
glm::vec4* dL_drot);
|
63 |
+
}
|
64 |
+
|
65 |
+
#endif
|
submodules/diff-gaussian-rasterization/cuda_rasterizer/config.h
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact george.drettakis@inria.fr
|
10 |
+
*/
|
11 |
+
|
12 |
+
#ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED
|
13 |
+
#define CUDA_RASTERIZER_CONFIG_H_INCLUDED
|
14 |
+
|
15 |
+
#define NUM_CHANNELS 3 // Default 3, RGB
|
16 |
+
#define BLOCK_X 16
|
17 |
+
#define BLOCK_Y 16
|
18 |
+
|
19 |
+
#endif
|
submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.cu
ADDED
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact george.drettakis@inria.fr
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include "forward.h"
|
13 |
+
#include "auxiliary.h"
|
14 |
+
#include <cooperative_groups.h>
|
15 |
+
#include <cooperative_groups/reduce.h>
|
16 |
+
namespace cg = cooperative_groups;
|
17 |
+
|
18 |
+
// Forward method for converting the input spherical harmonics
|
19 |
+
// coefficients of each Gaussian to a simple RGB color.
|
20 |
+
__device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped)
|
21 |
+
{
|
22 |
+
// The implementation is loosely based on code for
|
23 |
+
// "Differentiable Point-Based Radiance Fields for
|
24 |
+
// Efficient View Synthesis" by Zhang et al. (2022)
|
25 |
+
glm::vec3 pos = means[idx];
|
26 |
+
glm::vec3 dir = pos - campos;
|
27 |
+
dir = dir / glm::length(dir);
|
28 |
+
|
29 |
+
glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
|
30 |
+
glm::vec3 result = SH_C0 * sh[0];
|
31 |
+
|
32 |
+
if (deg > 0)
|
33 |
+
{
|
34 |
+
float x = dir.x;
|
35 |
+
float y = dir.y;
|
36 |
+
float z = dir.z;
|
37 |
+
result = result - SH_C1 * y * sh[1] + SH_C1 * z * sh[2] - SH_C1 * x * sh[3];
|
38 |
+
|
39 |
+
if (deg > 1)
|
40 |
+
{
|
41 |
+
float xx = x * x, yy = y * y, zz = z * z;
|
42 |
+
float xy = x * y, yz = y * z, xz = x * z;
|
43 |
+
result = result +
|
44 |
+
SH_C2[0] * xy * sh[4] +
|
45 |
+
SH_C2[1] * yz * sh[5] +
|
46 |
+
SH_C2[2] * (2.0f * zz - xx - yy) * sh[6] +
|
47 |
+
SH_C2[3] * xz * sh[7] +
|
48 |
+
SH_C2[4] * (xx - yy) * sh[8];
|
49 |
+
|
50 |
+
if (deg > 2)
|
51 |
+
{
|
52 |
+
result = result +
|
53 |
+
SH_C3[0] * y * (3.0f * xx - yy) * sh[9] +
|
54 |
+
SH_C3[1] * xy * z * sh[10] +
|
55 |
+
SH_C3[2] * y * (4.0f * zz - xx - yy) * sh[11] +
|
56 |
+
SH_C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * sh[12] +
|
57 |
+
SH_C3[4] * x * (4.0f * zz - xx - yy) * sh[13] +
|
58 |
+
SH_C3[5] * z * (xx - yy) * sh[14] +
|
59 |
+
SH_C3[6] * x * (xx - 3.0f * yy) * sh[15];
|
60 |
+
}
|
61 |
+
}
|
62 |
+
}
|
63 |
+
result += 0.5f;
|
64 |
+
|
65 |
+
// RGB colors are clamped to positive values. If values are
|
66 |
+
// clamped, we need to keep track of this for the backward pass.
|
67 |
+
clamped[3 * idx + 0] = (result.x < 0);
|
68 |
+
clamped[3 * idx + 1] = (result.y < 0);
|
69 |
+
clamped[3 * idx + 2] = (result.z < 0);
|
70 |
+
return glm::max(result, 0.0f);
|
71 |
+
}
|
72 |
+
|
73 |
+
// Forward version of 2D covariance matrix computation
|
74 |
+
__device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float* cov3D, const float* viewmatrix)
|
75 |
+
{
|
76 |
+
// The following models the steps outlined by equations 29
|
77 |
+
// and 31 in "EWA Splatting" (Zwicker et al., 2002).
|
78 |
+
// Additionally considers aspect / scaling of viewport.
|
79 |
+
// Transposes used to account for row-/column-major conventions.
|
80 |
+
float3 t = transformPoint4x3(mean, viewmatrix);
|
81 |
+
|
82 |
+
const float limx = 1.3f * tan_fovx;
|
83 |
+
const float limy = 1.3f * tan_fovy;
|
84 |
+
const float txtz = t.x / t.z;
|
85 |
+
const float tytz = t.y / t.z;
|
86 |
+
t.x = min(limx, max(-limx, txtz)) * t.z;
|
87 |
+
t.y = min(limy, max(-limy, tytz)) * t.z;
|
88 |
+
|
89 |
+
glm::mat3 J = glm::mat3(
|
90 |
+
focal_x / t.z, 0.0f, -(focal_x * t.x) / (t.z * t.z),
|
91 |
+
0.0f, focal_y / t.z, -(focal_y * t.y) / (t.z * t.z),
|
92 |
+
0, 0, 0);
|
93 |
+
|
94 |
+
glm::mat3 W = glm::mat3(
|
95 |
+
viewmatrix[0], viewmatrix[4], viewmatrix[8],
|
96 |
+
viewmatrix[1], viewmatrix[5], viewmatrix[9],
|
97 |
+
viewmatrix[2], viewmatrix[6], viewmatrix[10]);
|
98 |
+
|
99 |
+
glm::mat3 T = W * J;
|
100 |
+
|
101 |
+
glm::mat3 Vrk = glm::mat3(
|
102 |
+
cov3D[0], cov3D[1], cov3D[2],
|
103 |
+
cov3D[1], cov3D[3], cov3D[4],
|
104 |
+
cov3D[2], cov3D[4], cov3D[5]);
|
105 |
+
|
106 |
+
glm::mat3 cov = glm::transpose(T) * glm::transpose(Vrk) * T;
|
107 |
+
|
108 |
+
// Apply low-pass filter: every Gaussian should be at least
|
109 |
+
// one pixel wide/high. Discard 3rd row and column.
|
110 |
+
cov[0][0] += 0.3f;
|
111 |
+
cov[1][1] += 0.3f;
|
112 |
+
return { float(cov[0][0]), float(cov[0][1]), float(cov[1][1]) };
|
113 |
+
}
|
114 |
+
|
115 |
+
// Forward method for converting scale and rotation properties of each
|
116 |
+
// Gaussian to a 3D covariance matrix in world space. Also takes care
|
117 |
+
// of quaternion normalization.
|
118 |
+
__device__ void computeCov3D(const glm::vec3 scale, float mod, const glm::vec4 rot, float* cov3D)
|
119 |
+
{
|
120 |
+
// Create scaling matrix
|
121 |
+
glm::mat3 S = glm::mat3(1.0f);
|
122 |
+
S[0][0] = mod * scale.x;
|
123 |
+
S[1][1] = mod * scale.y;
|
124 |
+
S[2][2] = mod * scale.z;
|
125 |
+
|
126 |
+
// Normalize quaternion to get valid rotation
|
127 |
+
glm::vec4 q = rot;// / glm::length(rot);
|
128 |
+
float r = q.x;
|
129 |
+
float x = q.y;
|
130 |
+
float y = q.z;
|
131 |
+
float z = q.w;
|
132 |
+
|
133 |
+
// Compute rotation matrix from quaternion
|
134 |
+
glm::mat3 R = glm::mat3(
|
135 |
+
1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y),
|
136 |
+
2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x),
|
137 |
+
2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y)
|
138 |
+
);
|
139 |
+
|
140 |
+
glm::mat3 M = S * R;
|
141 |
+
|
142 |
+
// Compute 3D world covariance matrix Sigma
|
143 |
+
glm::mat3 Sigma = glm::transpose(M) * M;
|
144 |
+
|
145 |
+
// Covariance is symmetric, only store upper right
|
146 |
+
cov3D[0] = Sigma[0][0];
|
147 |
+
cov3D[1] = Sigma[0][1];
|
148 |
+
cov3D[2] = Sigma[0][2];
|
149 |
+
cov3D[3] = Sigma[1][1];
|
150 |
+
cov3D[4] = Sigma[1][2];
|
151 |
+
cov3D[5] = Sigma[2][2];
|
152 |
+
}
|
153 |
+
|
154 |
+
// Perform initial steps for each Gaussian prior to rasterization.
|
155 |
+
template<int C>
|
156 |
+
__global__ void preprocessCUDA(int P, int D, int M,
|
157 |
+
const float* orig_points,
|
158 |
+
const glm::vec3* scales,
|
159 |
+
const float scale_modifier,
|
160 |
+
const glm::vec4* rotations,
|
161 |
+
const float* opacities,
|
162 |
+
const float* shs,
|
163 |
+
bool* clamped,
|
164 |
+
const float* cov3D_precomp,
|
165 |
+
const float* colors_precomp,
|
166 |
+
const float* viewmatrix,
|
167 |
+
const float* projmatrix,
|
168 |
+
const glm::vec3* cam_pos,
|
169 |
+
const int W, int H,
|
170 |
+
const float tan_fovx, float tan_fovy,
|
171 |
+
const float focal_x, float focal_y,
|
172 |
+
int* radii,
|
173 |
+
float2* points_xy_image,
|
174 |
+
float* depths,
|
175 |
+
float* cov3Ds,
|
176 |
+
float* rgb,
|
177 |
+
float4* conic_opacity,
|
178 |
+
const dim3 grid,
|
179 |
+
uint32_t* tiles_touched,
|
180 |
+
bool prefiltered)
|
181 |
+
{
|
182 |
+
auto idx = cg::this_grid().thread_rank();
|
183 |
+
if (idx >= P)
|
184 |
+
return;
|
185 |
+
|
186 |
+
// Initialize radius and touched tiles to 0. If this isn't changed,
|
187 |
+
// this Gaussian will not be processed further.
|
188 |
+
radii[idx] = 0;
|
189 |
+
tiles_touched[idx] = 0;
|
190 |
+
|
191 |
+
// Perform near culling, quit if outside.
|
192 |
+
float3 p_view;
|
193 |
+
if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view))
|
194 |
+
return;
|
195 |
+
|
196 |
+
// Transform point by projecting
|
197 |
+
float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };
|
198 |
+
float4 p_hom = transformPoint4x4(p_orig, projmatrix);
|
199 |
+
float p_w = 1.0f / (p_hom.w + 0.0000001f);
|
200 |
+
float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
|
201 |
+
|
202 |
+
// If 3D covariance matrix is precomputed, use it, otherwise compute
|
203 |
+
// from scaling and rotation parameters.
|
204 |
+
const float* cov3D;
|
205 |
+
if (cov3D_precomp != nullptr)
|
206 |
+
{
|
207 |
+
cov3D = cov3D_precomp + idx * 6;
|
208 |
+
}
|
209 |
+
else
|
210 |
+
{
|
211 |
+
computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6);
|
212 |
+
cov3D = cov3Ds + idx * 6;
|
213 |
+
}
|
214 |
+
|
215 |
+
// Compute 2D screen-space covariance matrix
|
216 |
+
float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix);
|
217 |
+
|
218 |
+
// Invert covariance (EWA algorithm)
|
219 |
+
float det = (cov.x * cov.z - cov.y * cov.y);
|
220 |
+
if (det == 0.0f)
|
221 |
+
return;
|
222 |
+
float det_inv = 1.f / det;
|
223 |
+
float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv };
|
224 |
+
|
225 |
+
// Compute extent in screen space (by finding eigenvalues of
|
226 |
+
// 2D covariance matrix). Use extent to compute a bounding rectangle
|
227 |
+
// of screen-space tiles that this Gaussian overlaps with. Quit if
|
228 |
+
// rectangle covers 0 tiles.
|
229 |
+
float mid = 0.5f * (cov.x + cov.z);
|
230 |
+
float lambda1 = mid + sqrt(max(0.1f, mid * mid - det));
|
231 |
+
float lambda2 = mid - sqrt(max(0.1f, mid * mid - det));
|
232 |
+
float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2)));
|
233 |
+
float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) };
|
234 |
+
uint2 rect_min, rect_max;
|
235 |
+
getRect(point_image, my_radius, rect_min, rect_max, grid);
|
236 |
+
if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0)
|
237 |
+
return;
|
238 |
+
|
239 |
+
// If colors have been precomputed, use them, otherwise convert
|
240 |
+
// spherical harmonics coefficients to RGB color.
|
241 |
+
if (colors_precomp == nullptr)
|
242 |
+
{
|
243 |
+
glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped);
|
244 |
+
rgb[idx * C + 0] = result.x;
|
245 |
+
rgb[idx * C + 1] = result.y;
|
246 |
+
rgb[idx * C + 2] = result.z;
|
247 |
+
}
|
248 |
+
|
249 |
+
// Store some useful helper data for the next steps.
|
250 |
+
depths[idx] = p_view.z;
|
251 |
+
radii[idx] = my_radius;
|
252 |
+
points_xy_image[idx] = point_image;
|
253 |
+
// Inverse 2D covariance and opacity neatly pack into one float4
|
254 |
+
conic_opacity[idx] = { conic.x, conic.y, conic.z, opacities[idx] };
|
255 |
+
tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x);
|
256 |
+
}
|
257 |
+
|
258 |
+
// Main rasterization method. Collaboratively works on one tile per
|
259 |
+
// block, each thread treats one pixel. Alternates between fetching
|
260 |
+
// and rasterizing data.
|
261 |
+
template <uint32_t CHANNELS>
|
262 |
+
__global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
|
263 |
+
renderCUDA(
|
264 |
+
const uint2* __restrict__ ranges,
|
265 |
+
const uint32_t* __restrict__ point_list,
|
266 |
+
int W, int H,
|
267 |
+
const float2* __restrict__ points_xy_image,
|
268 |
+
const float* __restrict__ features,
|
269 |
+
const float4* __restrict__ conic_opacity,
|
270 |
+
float* __restrict__ final_T,
|
271 |
+
uint32_t* __restrict__ n_contrib,
|
272 |
+
const float* __restrict__ bg_color,
|
273 |
+
float* __restrict__ out_color)
|
274 |
+
{
|
275 |
+
// Identify current tile and associated min/max pixel range.
|
276 |
+
auto block = cg::this_thread_block();
|
277 |
+
uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
|
278 |
+
uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y };
|
279 |
+
uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) };
|
280 |
+
uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y };
|
281 |
+
uint32_t pix_id = W * pix.y + pix.x;
|
282 |
+
float2 pixf = { (float)pix.x, (float)pix.y };
|
283 |
+
|
284 |
+
// Check if this thread is associated with a valid pixel or outside.
|
285 |
+
bool inside = pix.x < W&& pix.y < H;
|
286 |
+
// Done threads can help with fetching, but don't rasterize
|
287 |
+
bool done = !inside;
|
288 |
+
|
289 |
+
// Load start/end range of IDs to process in bit sorted list.
|
290 |
+
uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
|
291 |
+
const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
|
292 |
+
int toDo = range.y - range.x;
|
293 |
+
|
294 |
+
// Allocate storage for batches of collectively fetched data.
|
295 |
+
__shared__ int collected_id[BLOCK_SIZE];
|
296 |
+
__shared__ float2 collected_xy[BLOCK_SIZE];
|
297 |
+
__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
|
298 |
+
|
299 |
+
// Initialize helper variables
|
300 |
+
float T = 1.0f;
|
301 |
+
uint32_t contributor = 0;
|
302 |
+
uint32_t last_contributor = 0;
|
303 |
+
float C[CHANNELS] = { 0 };
|
304 |
+
|
305 |
+
// Iterate over batches until all done or range is complete
|
306 |
+
for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
|
307 |
+
{
|
308 |
+
// End if entire block votes that it is done rasterizing
|
309 |
+
int num_done = __syncthreads_count(done);
|
310 |
+
if (num_done == BLOCK_SIZE)
|
311 |
+
break;
|
312 |
+
|
313 |
+
// Collectively fetch per-Gaussian data from global to shared
|
314 |
+
int progress = i * BLOCK_SIZE + block.thread_rank();
|
315 |
+
if (range.x + progress < range.y)
|
316 |
+
{
|
317 |
+
int coll_id = point_list[range.x + progress];
|
318 |
+
collected_id[block.thread_rank()] = coll_id;
|
319 |
+
collected_xy[block.thread_rank()] = points_xy_image[coll_id];
|
320 |
+
collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
|
321 |
+
}
|
322 |
+
block.sync();
|
323 |
+
|
324 |
+
// Iterate over current batch
|
325 |
+
for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++)
|
326 |
+
{
|
327 |
+
// Keep track of current position in range
|
328 |
+
contributor++;
|
329 |
+
|
330 |
+
// Resample using conic matrix (cf. "Surface
|
331 |
+
// Splatting" by Zwicker et al., 2001)
|
332 |
+
float2 xy = collected_xy[j];
|
333 |
+
float2 d = { xy.x - pixf.x, xy.y - pixf.y };
|
334 |
+
float4 con_o = collected_conic_opacity[j];
|
335 |
+
float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
|
336 |
+
if (power > 0.0f)
|
337 |
+
continue;
|
338 |
+
|
339 |
+
// Eq. (2) from 3D Gaussian splatting paper.
|
340 |
+
// Obtain alpha by multiplying with Gaussian opacity
|
341 |
+
// and its exponential falloff from mean.
|
342 |
+
// Avoid numerical instabilities (see paper appendix).
|
343 |
+
float alpha = min(0.99f, con_o.w * exp(power));
|
344 |
+
if (alpha < 1.0f / 255.0f)
|
345 |
+
continue;
|
346 |
+
float test_T = T * (1 - alpha);
|
347 |
+
if (test_T < 0.0001f)
|
348 |
+
{
|
349 |
+
done = true;
|
350 |
+
continue;
|
351 |
+
}
|
352 |
+
|
353 |
+
// Eq. (3) from 3D Gaussian splatting paper.
|
354 |
+
for (int ch = 0; ch < CHANNELS; ch++)
|
355 |
+
C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T;
|
356 |
+
|
357 |
+
T = test_T;
|
358 |
+
|
359 |
+
// Keep track of last range entry to update this
|
360 |
+
// pixel.
|
361 |
+
last_contributor = contributor;
|
362 |
+
}
|
363 |
+
}
|
364 |
+
|
365 |
+
// All threads that treat valid pixel write out their final
|
366 |
+
// rendering data to the frame and auxiliary buffers.
|
367 |
+
if (inside)
|
368 |
+
{
|
369 |
+
final_T[pix_id] = T;
|
370 |
+
n_contrib[pix_id] = last_contributor;
|
371 |
+
for (int ch = 0; ch < CHANNELS; ch++)
|
372 |
+
out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch];
|
373 |
+
}
|
374 |
+
}
|
375 |
+
|
376 |
+
void FORWARD::render(
|
377 |
+
const dim3 grid, dim3 block,
|
378 |
+
const uint2* ranges,
|
379 |
+
const uint32_t* point_list,
|
380 |
+
int W, int H,
|
381 |
+
const float2* means2D,
|
382 |
+
const float* colors,
|
383 |
+
const float4* conic_opacity,
|
384 |
+
float* final_T,
|
385 |
+
uint32_t* n_contrib,
|
386 |
+
const float* bg_color,
|
387 |
+
float* out_color)
|
388 |
+
{
|
389 |
+
renderCUDA<NUM_CHANNELS> << <grid, block >> > (
|
390 |
+
ranges,
|
391 |
+
point_list,
|
392 |
+
W, H,
|
393 |
+
means2D,
|
394 |
+
colors,
|
395 |
+
conic_opacity,
|
396 |
+
final_T,
|
397 |
+
n_contrib,
|
398 |
+
bg_color,
|
399 |
+
out_color);
|
400 |
+
}
|
401 |
+
|
402 |
+
void FORWARD::preprocess(int P, int D, int M,
|
403 |
+
const float* means3D,
|
404 |
+
const glm::vec3* scales,
|
405 |
+
const float scale_modifier,
|
406 |
+
const glm::vec4* rotations,
|
407 |
+
const float* opacities,
|
408 |
+
const float* shs,
|
409 |
+
bool* clamped,
|
410 |
+
const float* cov3D_precomp,
|
411 |
+
const float* colors_precomp,
|
412 |
+
const float* viewmatrix,
|
413 |
+
const float* projmatrix,
|
414 |
+
const glm::vec3* cam_pos,
|
415 |
+
const int W, int H,
|
416 |
+
const float focal_x, float focal_y,
|
417 |
+
const float tan_fovx, float tan_fovy,
|
418 |
+
int* radii,
|
419 |
+
float2* means2D,
|
420 |
+
float* depths,
|
421 |
+
float* cov3Ds,
|
422 |
+
float* rgb,
|
423 |
+
float4* conic_opacity,
|
424 |
+
const dim3 grid,
|
425 |
+
uint32_t* tiles_touched,
|
426 |
+
bool prefiltered)
|
427 |
+
{
|
428 |
+
preprocessCUDA<NUM_CHANNELS> << <(P + 255) / 256, 256 >> > (
|
429 |
+
P, D, M,
|
430 |
+
means3D,
|
431 |
+
scales,
|
432 |
+
scale_modifier,
|
433 |
+
rotations,
|
434 |
+
opacities,
|
435 |
+
shs,
|
436 |
+
clamped,
|
437 |
+
cov3D_precomp,
|
438 |
+
colors_precomp,
|
439 |
+
viewmatrix,
|
440 |
+
projmatrix,
|
441 |
+
cam_pos,
|
442 |
+
W, H,
|
443 |
+
tan_fovx, tan_fovy,
|
444 |
+
focal_x, focal_y,
|
445 |
+
radii,
|
446 |
+
means2D,
|
447 |
+
depths,
|
448 |
+
cov3Ds,
|
449 |
+
rgb,
|
450 |
+
conic_opacity,
|
451 |
+
grid,
|
452 |
+
tiles_touched,
|
453 |
+
prefiltered
|
454 |
+
);
|
455 |
+
}
|
submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.h
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact george.drettakis@inria.fr
|
10 |
+
*/
|
11 |
+
|
12 |
+
#ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED
|
13 |
+
#define CUDA_RASTERIZER_FORWARD_H_INCLUDED
|
14 |
+
|
15 |
+
#include <cuda.h>
|
16 |
+
#include "cuda_runtime.h"
|
17 |
+
#include "device_launch_parameters.h"
|
18 |
+
#define GLM_FORCE_CUDA
|
19 |
+
#include <glm/glm.hpp>
|
20 |
+
|
21 |
+
namespace FORWARD
|
22 |
+
{
|
23 |
+
// Perform initial steps for each Gaussian prior to rasterization.
|
24 |
+
void preprocess(int P, int D, int M,
|
25 |
+
const float* orig_points,
|
26 |
+
const glm::vec3* scales,
|
27 |
+
const float scale_modifier,
|
28 |
+
const glm::vec4* rotations,
|
29 |
+
const float* opacities,
|
30 |
+
const float* shs,
|
31 |
+
bool* clamped,
|
32 |
+
const float* cov3D_precomp,
|
33 |
+
const float* colors_precomp,
|
34 |
+
const float* viewmatrix,
|
35 |
+
const float* projmatrix,
|
36 |
+
const glm::vec3* cam_pos,
|
37 |
+
const int W, int H,
|
38 |
+
const float focal_x, float focal_y,
|
39 |
+
const float tan_fovx, float tan_fovy,
|
40 |
+
int* radii,
|
41 |
+
float2* points_xy_image,
|
42 |
+
float* depths,
|
43 |
+
float* cov3Ds,
|
44 |
+
float* colors,
|
45 |
+
float4* conic_opacity,
|
46 |
+
const dim3 grid,
|
47 |
+
uint32_t* tiles_touched,
|
48 |
+
bool prefiltered);
|
49 |
+
|
50 |
+
// Main rasterization method.
|
51 |
+
void render(
|
52 |
+
const dim3 grid, dim3 block,
|
53 |
+
const uint2* ranges,
|
54 |
+
const uint32_t* point_list,
|
55 |
+
int W, int H,
|
56 |
+
const float2* points_xy_image,
|
57 |
+
const float* features,
|
58 |
+
const float4* conic_opacity,
|
59 |
+
float* final_T,
|
60 |
+
uint32_t* n_contrib,
|
61 |
+
const float* bg_color,
|
62 |
+
float* out_color);
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
#endif
|
submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact george.drettakis@inria.fr
|
10 |
+
*/
|
11 |
+
|
12 |
+
#ifndef CUDA_RASTERIZER_H_INCLUDED
|
13 |
+
#define CUDA_RASTERIZER_H_INCLUDED
|
14 |
+
|
15 |
+
#include <vector>
|
16 |
+
#include <functional>
|
17 |
+
|
18 |
+
namespace CudaRasterizer
|
19 |
+
{
|
20 |
+
class Rasterizer
|
21 |
+
{
|
22 |
+
public:
|
23 |
+
|
24 |
+
static void markVisible(
|
25 |
+
int P,
|
26 |
+
float* means3D,
|
27 |
+
float* viewmatrix,
|
28 |
+
float* projmatrix,
|
29 |
+
bool* present);
|
30 |
+
|
31 |
+
static int forward(
|
32 |
+
std::function<char* (size_t)> geometryBuffer,
|
33 |
+
std::function<char* (size_t)> binningBuffer,
|
34 |
+
std::function<char* (size_t)> imageBuffer,
|
35 |
+
const int P, int D, int M,
|
36 |
+
const float* background,
|
37 |
+
const int width, int height,
|
38 |
+
const float* means3D,
|
39 |
+
const float* shs,
|
40 |
+
const float* colors_precomp,
|
41 |
+
const float* opacities,
|
42 |
+
const float* scales,
|
43 |
+
const float scale_modifier,
|
44 |
+
const float* rotations,
|
45 |
+
const float* cov3D_precomp,
|
46 |
+
const float* viewmatrix,
|
47 |
+
const float* projmatrix,
|
48 |
+
const float* cam_pos,
|
49 |
+
const float tan_fovx, float tan_fovy,
|
50 |
+
const bool prefiltered,
|
51 |
+
float* out_color,
|
52 |
+
int* radii = nullptr,
|
53 |
+
bool debug = false);
|
54 |
+
|
55 |
+
static void backward(
|
56 |
+
const int P, int D, int M, int R,
|
57 |
+
const float* background,
|
58 |
+
const int width, int height,
|
59 |
+
const float* means3D,
|
60 |
+
const float* shs,
|
61 |
+
const float* colors_precomp,
|
62 |
+
const float* scales,
|
63 |
+
const float scale_modifier,
|
64 |
+
const float* rotations,
|
65 |
+
const float* cov3D_precomp,
|
66 |
+
const float* viewmatrix,
|
67 |
+
const float* projmatrix,
|
68 |
+
const float* campos,
|
69 |
+
const float tan_fovx, float tan_fovy,
|
70 |
+
const int* radii,
|
71 |
+
char* geom_buffer,
|
72 |
+
char* binning_buffer,
|
73 |
+
char* image_buffer,
|
74 |
+
const float* dL_dpix,
|
75 |
+
float* dL_dmean2D,
|
76 |
+
float* dL_dconic,
|
77 |
+
float* dL_dopacity,
|
78 |
+
float* dL_dcolor,
|
79 |
+
float* dL_dmean3D,
|
80 |
+
float* dL_dcov3D,
|
81 |
+
float* dL_dsh,
|
82 |
+
float* dL_dscale,
|
83 |
+
float* dL_drot,
|
84 |
+
bool debug);
|
85 |
+
};
|
86 |
+
};
|
87 |
+
|
88 |
+
#endif
|
submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu
ADDED
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact george.drettakis@inria.fr
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include "rasterizer_impl.h"
|
13 |
+
#include <iostream>
|
14 |
+
#include <fstream>
|
15 |
+
#include <algorithm>
|
16 |
+
#include <numeric>
|
17 |
+
#include <cuda.h>
|
18 |
+
#include "cuda_runtime.h"
|
19 |
+
#include "device_launch_parameters.h"
|
20 |
+
#include <cub/cub.cuh>
|
21 |
+
#include <cub/device/device_radix_sort.cuh>
|
22 |
+
#define GLM_FORCE_CUDA
|
23 |
+
#include <glm/glm.hpp>
|
24 |
+
|
25 |
+
#include <cooperative_groups.h>
|
26 |
+
#include <cooperative_groups/reduce.h>
|
27 |
+
namespace cg = cooperative_groups;
|
28 |
+
|
29 |
+
#include "auxiliary.h"
|
30 |
+
#include "forward.h"
|
31 |
+
#include "backward.h"
|
32 |
+
|
33 |
+
// Helper function to find the next-highest bit of the MSB
|
34 |
+
// on the CPU.
|
35 |
+
uint32_t getHigherMsb(uint32_t n)
|
36 |
+
{
|
37 |
+
uint32_t msb = sizeof(n) * 4;
|
38 |
+
uint32_t step = msb;
|
39 |
+
while (step > 1)
|
40 |
+
{
|
41 |
+
step /= 2;
|
42 |
+
if (n >> msb)
|
43 |
+
msb += step;
|
44 |
+
else
|
45 |
+
msb -= step;
|
46 |
+
}
|
47 |
+
if (n >> msb)
|
48 |
+
msb++;
|
49 |
+
return msb;
|
50 |
+
}
|
51 |
+
|
52 |
+
// Wrapper method to call auxiliary coarse frustum containment test.
|
53 |
+
// Mark all Gaussians that pass it.
|
54 |
+
__global__ void checkFrustum(int P,
|
55 |
+
const float* orig_points,
|
56 |
+
const float* viewmatrix,
|
57 |
+
const float* projmatrix,
|
58 |
+
bool* present)
|
59 |
+
{
|
60 |
+
auto idx = cg::this_grid().thread_rank();
|
61 |
+
if (idx >= P)
|
62 |
+
return;
|
63 |
+
|
64 |
+
float3 p_view;
|
65 |
+
present[idx] = in_frustum(idx, orig_points, viewmatrix, projmatrix, false, p_view);
|
66 |
+
}
|
67 |
+
|
68 |
+
// Generates one key/value pair for all Gaussian / tile overlaps.
|
69 |
+
// Run once per Gaussian (1:N mapping).
|
70 |
+
__global__ void duplicateWithKeys(
|
71 |
+
int P,
|
72 |
+
const float2* points_xy,
|
73 |
+
const float* depths,
|
74 |
+
const uint32_t* offsets,
|
75 |
+
uint64_t* gaussian_keys_unsorted,
|
76 |
+
uint32_t* gaussian_values_unsorted,
|
77 |
+
int* radii,
|
78 |
+
dim3 grid)
|
79 |
+
{
|
80 |
+
auto idx = cg::this_grid().thread_rank();
|
81 |
+
if (idx >= P)
|
82 |
+
return;
|
83 |
+
|
84 |
+
// Generate no key/value pair for invisible Gaussians
|
85 |
+
if (radii[idx] > 0)
|
86 |
+
{
|
87 |
+
// Find this Gaussian's offset in buffer for writing keys/values.
|
88 |
+
uint32_t off = (idx == 0) ? 0 : offsets[idx - 1];
|
89 |
+
uint2 rect_min, rect_max;
|
90 |
+
|
91 |
+
getRect(points_xy[idx], radii[idx], rect_min, rect_max, grid);
|
92 |
+
|
93 |
+
// For each tile that the bounding rect overlaps, emit a
|
94 |
+
// key/value pair. The key is | tile ID | depth |,
|
95 |
+
// and the value is the ID of the Gaussian. Sorting the values
|
96 |
+
// with this key yields Gaussian IDs in a list, such that they
|
97 |
+
// are first sorted by tile and then by depth.
|
98 |
+
for (int y = rect_min.y; y < rect_max.y; y++)
|
99 |
+
{
|
100 |
+
for (int x = rect_min.x; x < rect_max.x; x++)
|
101 |
+
{
|
102 |
+
uint64_t key = y * grid.x + x;
|
103 |
+
key <<= 32;
|
104 |
+
key |= *((uint32_t*)&depths[idx]);
|
105 |
+
gaussian_keys_unsorted[off] = key;
|
106 |
+
gaussian_values_unsorted[off] = idx;
|
107 |
+
off++;
|
108 |
+
}
|
109 |
+
}
|
110 |
+
}
|
111 |
+
}
|
112 |
+
|
113 |
+
// Check keys to see if it is at the start/end of one tile's range in
|
114 |
+
// the full sorted list. If yes, write start/end of this tile.
|
115 |
+
// Run once per instanced (duplicated) Gaussian ID.
|
116 |
+
__global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* ranges)
|
117 |
+
{
|
118 |
+
auto idx = cg::this_grid().thread_rank();
|
119 |
+
if (idx >= L)
|
120 |
+
return;
|
121 |
+
|
122 |
+
// Read tile ID from key. Update start/end of tile range if at limit.
|
123 |
+
uint64_t key = point_list_keys[idx];
|
124 |
+
uint32_t currtile = key >> 32;
|
125 |
+
if (idx == 0)
|
126 |
+
ranges[currtile].x = 0;
|
127 |
+
else
|
128 |
+
{
|
129 |
+
uint32_t prevtile = point_list_keys[idx - 1] >> 32;
|
130 |
+
if (currtile != prevtile)
|
131 |
+
{
|
132 |
+
ranges[prevtile].y = idx;
|
133 |
+
ranges[currtile].x = idx;
|
134 |
+
}
|
135 |
+
}
|
136 |
+
if (idx == L - 1)
|
137 |
+
ranges[currtile].y = L;
|
138 |
+
}
|
139 |
+
|
140 |
+
// Mark Gaussians as visible/invisible, based on view frustum testing
|
141 |
+
void CudaRasterizer::Rasterizer::markVisible(
|
142 |
+
int P,
|
143 |
+
float* means3D,
|
144 |
+
float* viewmatrix,
|
145 |
+
float* projmatrix,
|
146 |
+
bool* present)
|
147 |
+
{
|
148 |
+
checkFrustum << <(P + 255) / 256, 256 >> > (
|
149 |
+
P,
|
150 |
+
means3D,
|
151 |
+
viewmatrix, projmatrix,
|
152 |
+
present);
|
153 |
+
}
|
154 |
+
|
155 |
+
CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& chunk, size_t P)
|
156 |
+
{
|
157 |
+
GeometryState geom;
|
158 |
+
obtain(chunk, geom.depths, P, 128);
|
159 |
+
obtain(chunk, geom.clamped, P * 3, 128);
|
160 |
+
obtain(chunk, geom.internal_radii, P, 128);
|
161 |
+
obtain(chunk, geom.means2D, P, 128);
|
162 |
+
obtain(chunk, geom.cov3D, P * 6, 128);
|
163 |
+
obtain(chunk, geom.conic_opacity, P, 128);
|
164 |
+
obtain(chunk, geom.rgb, P * 3, 128);
|
165 |
+
obtain(chunk, geom.tiles_touched, P, 128);
|
166 |
+
cub::DeviceScan::InclusiveSum(nullptr, geom.scan_size, geom.tiles_touched, geom.tiles_touched, P);
|
167 |
+
obtain(chunk, geom.scanning_space, geom.scan_size, 128);
|
168 |
+
obtain(chunk, geom.point_offsets, P, 128);
|
169 |
+
return geom;
|
170 |
+
}
|
171 |
+
|
172 |
+
CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, size_t N)
|
173 |
+
{
|
174 |
+
ImageState img;
|
175 |
+
obtain(chunk, img.accum_alpha, N, 128);
|
176 |
+
obtain(chunk, img.n_contrib, N, 128);
|
177 |
+
obtain(chunk, img.ranges, N, 128);
|
178 |
+
return img;
|
179 |
+
}
|
180 |
+
|
181 |
+
CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chunk, size_t P)
|
182 |
+
{
|
183 |
+
BinningState binning;
|
184 |
+
obtain(chunk, binning.point_list, P, 128);
|
185 |
+
obtain(chunk, binning.point_list_unsorted, P, 128);
|
186 |
+
obtain(chunk, binning.point_list_keys, P, 128);
|
187 |
+
obtain(chunk, binning.point_list_keys_unsorted, P, 128);
|
188 |
+
cub::DeviceRadixSort::SortPairs(
|
189 |
+
nullptr, binning.sorting_size,
|
190 |
+
binning.point_list_keys_unsorted, binning.point_list_keys,
|
191 |
+
binning.point_list_unsorted, binning.point_list, P);
|
192 |
+
obtain(chunk, binning.list_sorting_space, binning.sorting_size, 128);
|
193 |
+
return binning;
|
194 |
+
}
|
195 |
+
|
196 |
+
// Forward rendering procedure for differentiable rasterization
|
197 |
+
// of Gaussians.
|
198 |
+
int CudaRasterizer::Rasterizer::forward(
|
199 |
+
std::function<char* (size_t)> geometryBuffer,
|
200 |
+
std::function<char* (size_t)> binningBuffer,
|
201 |
+
std::function<char* (size_t)> imageBuffer,
|
202 |
+
const int P, int D, int M,
|
203 |
+
const float* background,
|
204 |
+
const int width, int height,
|
205 |
+
const float* means3D,
|
206 |
+
const float* shs,
|
207 |
+
const float* colors_precomp,
|
208 |
+
const float* opacities,
|
209 |
+
const float* scales,
|
210 |
+
const float scale_modifier,
|
211 |
+
const float* rotations,
|
212 |
+
const float* cov3D_precomp,
|
213 |
+
const float* viewmatrix,
|
214 |
+
const float* projmatrix,
|
215 |
+
const float* cam_pos,
|
216 |
+
const float tan_fovx, float tan_fovy,
|
217 |
+
const bool prefiltered,
|
218 |
+
float* out_color,
|
219 |
+
int* radii,
|
220 |
+
bool debug)
|
221 |
+
{
|
222 |
+
const float focal_y = height / (2.0f * tan_fovy);
|
223 |
+
const float focal_x = width / (2.0f * tan_fovx);
|
224 |
+
|
225 |
+
size_t chunk_size = required<GeometryState>(P);
|
226 |
+
char* chunkptr = geometryBuffer(chunk_size);
|
227 |
+
GeometryState geomState = GeometryState::fromChunk(chunkptr, P);
|
228 |
+
|
229 |
+
if (radii == nullptr)
|
230 |
+
{
|
231 |
+
radii = geomState.internal_radii;
|
232 |
+
}
|
233 |
+
|
234 |
+
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
|
235 |
+
dim3 block(BLOCK_X, BLOCK_Y, 1);
|
236 |
+
|
237 |
+
// Dynamically resize image-based auxiliary buffers during training
|
238 |
+
size_t img_chunk_size = required<ImageState>(width * height);
|
239 |
+
char* img_chunkptr = imageBuffer(img_chunk_size);
|
240 |
+
ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height);
|
241 |
+
|
242 |
+
if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
|
243 |
+
{
|
244 |
+
throw std::runtime_error("For non-RGB, provide precomputed Gaussian colors!");
|
245 |
+
}
|
246 |
+
|
247 |
+
// Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs to RGB)
|
248 |
+
CHECK_CUDA(FORWARD::preprocess(
|
249 |
+
P, D, M,
|
250 |
+
means3D,
|
251 |
+
(glm::vec3*)scales,
|
252 |
+
scale_modifier,
|
253 |
+
(glm::vec4*)rotations,
|
254 |
+
opacities,
|
255 |
+
shs,
|
256 |
+
geomState.clamped,
|
257 |
+
cov3D_precomp,
|
258 |
+
colors_precomp,
|
259 |
+
viewmatrix, projmatrix,
|
260 |
+
(glm::vec3*)cam_pos,
|
261 |
+
width, height,
|
262 |
+
focal_x, focal_y,
|
263 |
+
tan_fovx, tan_fovy,
|
264 |
+
radii,
|
265 |
+
geomState.means2D,
|
266 |
+
geomState.depths,
|
267 |
+
geomState.cov3D,
|
268 |
+
geomState.rgb,
|
269 |
+
geomState.conic_opacity,
|
270 |
+
tile_grid,
|
271 |
+
geomState.tiles_touched,
|
272 |
+
prefiltered
|
273 |
+
), debug)
|
274 |
+
|
275 |
+
// Compute prefix sum over full list of touched tile counts by Gaussians
|
276 |
+
// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
|
277 |
+
CHECK_CUDA(cub::DeviceScan::InclusiveSum(geomState.scanning_space, geomState.scan_size, geomState.tiles_touched, geomState.point_offsets, P), debug)
|
278 |
+
|
279 |
+
// Retrieve total number of Gaussian instances to launch and resize aux buffers
|
280 |
+
int num_rendered;
|
281 |
+
CHECK_CUDA(cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost), debug);
|
282 |
+
|
283 |
+
size_t binning_chunk_size = required<BinningState>(num_rendered);
|
284 |
+
char* binning_chunkptr = binningBuffer(binning_chunk_size);
|
285 |
+
BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered);
|
286 |
+
|
287 |
+
// For each instance to be rendered, produce adequate [ tile | depth ] key
|
288 |
+
// and corresponding dublicated Gaussian indices to be sorted
|
289 |
+
duplicateWithKeys << <(P + 255) / 256, 256 >> > (
|
290 |
+
P,
|
291 |
+
geomState.means2D,
|
292 |
+
geomState.depths,
|
293 |
+
geomState.point_offsets,
|
294 |
+
binningState.point_list_keys_unsorted,
|
295 |
+
binningState.point_list_unsorted,
|
296 |
+
radii,
|
297 |
+
tile_grid)
|
298 |
+
CHECK_CUDA(, debug)
|
299 |
+
|
300 |
+
int bit = getHigherMsb(tile_grid.x * tile_grid.y);
|
301 |
+
|
302 |
+
// Sort complete list of (duplicated) Gaussian indices by keys
|
303 |
+
CHECK_CUDA(cub::DeviceRadixSort::SortPairs(
|
304 |
+
binningState.list_sorting_space,
|
305 |
+
binningState.sorting_size,
|
306 |
+
binningState.point_list_keys_unsorted, binningState.point_list_keys,
|
307 |
+
binningState.point_list_unsorted, binningState.point_list,
|
308 |
+
num_rendered, 0, 32 + bit), debug)
|
309 |
+
|
310 |
+
CHECK_CUDA(cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2)), debug);
|
311 |
+
|
312 |
+
// Identify start and end of per-tile workloads in sorted list
|
313 |
+
if (num_rendered > 0)
|
314 |
+
identifyTileRanges << <(num_rendered + 255) / 256, 256 >> > (
|
315 |
+
num_rendered,
|
316 |
+
binningState.point_list_keys,
|
317 |
+
imgState.ranges);
|
318 |
+
CHECK_CUDA(, debug)
|
319 |
+
|
320 |
+
// Let each tile blend its range of Gaussians independently in parallel
|
321 |
+
const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : geomState.rgb;
|
322 |
+
CHECK_CUDA(FORWARD::render(
|
323 |
+
tile_grid, block,
|
324 |
+
imgState.ranges,
|
325 |
+
binningState.point_list,
|
326 |
+
width, height,
|
327 |
+
geomState.means2D,
|
328 |
+
feature_ptr,
|
329 |
+
geomState.conic_opacity,
|
330 |
+
imgState.accum_alpha,
|
331 |
+
imgState.n_contrib,
|
332 |
+
background,
|
333 |
+
out_color), debug)
|
334 |
+
|
335 |
+
return num_rendered;
|
336 |
+
}
|
337 |
+
|
338 |
+
// Produce necessary gradients for optimization, corresponding
|
339 |
+
// to forward render pass
|
340 |
+
void CudaRasterizer::Rasterizer::backward(
|
341 |
+
const int P, int D, int M, int R,
|
342 |
+
const float* background,
|
343 |
+
const int width, int height,
|
344 |
+
const float* means3D,
|
345 |
+
const float* shs,
|
346 |
+
const float* colors_precomp,
|
347 |
+
const float* scales,
|
348 |
+
const float scale_modifier,
|
349 |
+
const float* rotations,
|
350 |
+
const float* cov3D_precomp,
|
351 |
+
const float* viewmatrix,
|
352 |
+
const float* projmatrix,
|
353 |
+
const float* campos,
|
354 |
+
const float tan_fovx, float tan_fovy,
|
355 |
+
const int* radii,
|
356 |
+
char* geom_buffer,
|
357 |
+
char* binning_buffer,
|
358 |
+
char* img_buffer,
|
359 |
+
const float* dL_dpix,
|
360 |
+
float* dL_dmean2D,
|
361 |
+
float* dL_dconic,
|
362 |
+
float* dL_dopacity,
|
363 |
+
float* dL_dcolor,
|
364 |
+
float* dL_dmean3D,
|
365 |
+
float* dL_dcov3D,
|
366 |
+
float* dL_dsh,
|
367 |
+
float* dL_dscale,
|
368 |
+
float* dL_drot,
|
369 |
+
bool debug)
|
370 |
+
{
|
371 |
+
GeometryState geomState = GeometryState::fromChunk(geom_buffer, P);
|
372 |
+
BinningState binningState = BinningState::fromChunk(binning_buffer, R);
|
373 |
+
ImageState imgState = ImageState::fromChunk(img_buffer, width * height);
|
374 |
+
|
375 |
+
if (radii == nullptr)
|
376 |
+
{
|
377 |
+
radii = geomState.internal_radii;
|
378 |
+
}
|
379 |
+
|
380 |
+
const float focal_y = height / (2.0f * tan_fovy);
|
381 |
+
const float focal_x = width / (2.0f * tan_fovx);
|
382 |
+
|
383 |
+
const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
|
384 |
+
const dim3 block(BLOCK_X, BLOCK_Y, 1);
|
385 |
+
|
386 |
+
// Compute loss gradients w.r.t. 2D mean position, conic matrix,
|
387 |
+
// opacity and RGB of Gaussians from per-pixel loss gradients.
|
388 |
+
// If we were given precomputed colors and not SHs, use them.
|
389 |
+
const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb;
|
390 |
+
CHECK_CUDA(BACKWARD::render(
|
391 |
+
tile_grid,
|
392 |
+
block,
|
393 |
+
imgState.ranges,
|
394 |
+
binningState.point_list,
|
395 |
+
width, height,
|
396 |
+
background,
|
397 |
+
geomState.means2D,
|
398 |
+
geomState.conic_opacity,
|
399 |
+
color_ptr,
|
400 |
+
imgState.accum_alpha,
|
401 |
+
imgState.n_contrib,
|
402 |
+
dL_dpix,
|
403 |
+
(float3*)dL_dmean2D,
|
404 |
+
(float4*)dL_dconic,
|
405 |
+
dL_dopacity,
|
406 |
+
dL_dcolor), debug)
|
407 |
+
|
408 |
+
// Take care of the rest of preprocessing. Was the precomputed covariance
|
409 |
+
// given to us or a scales/rot pair? If precomputed, pass that. If not,
|
410 |
+
// use the one we computed ourselves.
|
411 |
+
const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : geomState.cov3D;
|
412 |
+
CHECK_CUDA(BACKWARD::preprocess(P, D, M,
|
413 |
+
(float3*)means3D,
|
414 |
+
radii,
|
415 |
+
shs,
|
416 |
+
geomState.clamped,
|
417 |
+
(glm::vec3*)scales,
|
418 |
+
(glm::vec4*)rotations,
|
419 |
+
scale_modifier,
|
420 |
+
cov3D_ptr,
|
421 |
+
viewmatrix,
|
422 |
+
projmatrix,
|
423 |
+
focal_x, focal_y,
|
424 |
+
tan_fovx, tan_fovy,
|
425 |
+
(glm::vec3*)campos,
|
426 |
+
(float3*)dL_dmean2D,
|
427 |
+
dL_dconic,
|
428 |
+
(glm::vec3*)dL_dmean3D,
|
429 |
+
dL_dcolor,
|
430 |
+
dL_dcov3D,
|
431 |
+
dL_dsh,
|
432 |
+
(glm::vec3*)dL_dscale,
|
433 |
+
(glm::vec4*)dL_drot), debug)
|
434 |
+
}
|
submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact george.drettakis@inria.fr
|
10 |
+
*/
|
11 |
+
|
12 |
+
#pragma once
|
13 |
+
|
14 |
+
#include <iostream>
|
15 |
+
#include <vector>
|
16 |
+
#include "rasterizer.h"
|
17 |
+
#include <cuda_runtime_api.h>
|
18 |
+
|
19 |
+
namespace CudaRasterizer
|
20 |
+
{
|
21 |
+
template <typename T>
|
22 |
+
static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment)
|
23 |
+
{
|
24 |
+
std::size_t offset = (reinterpret_cast<std::uintptr_t>(chunk) + alignment - 1) & ~(alignment - 1);
|
25 |
+
ptr = reinterpret_cast<T*>(offset);
|
26 |
+
chunk = reinterpret_cast<char*>(ptr + count);
|
27 |
+
}
|
28 |
+
|
29 |
+
struct GeometryState
|
30 |
+
{
|
31 |
+
size_t scan_size;
|
32 |
+
float* depths;
|
33 |
+
char* scanning_space;
|
34 |
+
bool* clamped;
|
35 |
+
int* internal_radii;
|
36 |
+
float2* means2D;
|
37 |
+
float* cov3D;
|
38 |
+
float4* conic_opacity;
|
39 |
+
float* rgb;
|
40 |
+
uint32_t* point_offsets;
|
41 |
+
uint32_t* tiles_touched;
|
42 |
+
|
43 |
+
static GeometryState fromChunk(char*& chunk, size_t P);
|
44 |
+
};
|
45 |
+
|
46 |
+
struct ImageState
|
47 |
+
{
|
48 |
+
uint2* ranges;
|
49 |
+
uint32_t* n_contrib;
|
50 |
+
float* accum_alpha;
|
51 |
+
|
52 |
+
static ImageState fromChunk(char*& chunk, size_t N);
|
53 |
+
};
|
54 |
+
|
55 |
+
struct BinningState
|
56 |
+
{
|
57 |
+
size_t sorting_size;
|
58 |
+
uint64_t* point_list_keys_unsorted;
|
59 |
+
uint64_t* point_list_keys;
|
60 |
+
uint32_t* point_list_unsorted;
|
61 |
+
uint32_t* point_list;
|
62 |
+
char* list_sorting_space;
|
63 |
+
|
64 |
+
static BinningState fromChunk(char*& chunk, size_t P);
|
65 |
+
};
|
66 |
+
|
67 |
+
template<typename T>
|
68 |
+
size_t required(size_t P)
|
69 |
+
{
|
70 |
+
char* size = nullptr;
|
71 |
+
T::fromChunk(size, P);
|
72 |
+
return ((size_t)size) + 128;
|
73 |
+
}
|
74 |
+
};
|
submodules/diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact george.drettakis@inria.fr
|
10 |
+
#
|
11 |
+
|
12 |
+
from typing import NamedTuple
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch
|
15 |
+
from . import _C
|
16 |
+
|
17 |
+
def cpu_deep_copy_tuple(input_tuple):
|
18 |
+
copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple]
|
19 |
+
return tuple(copied_tensors)
|
20 |
+
|
21 |
+
def rasterize_gaussians(
|
22 |
+
means3D,
|
23 |
+
means2D,
|
24 |
+
sh,
|
25 |
+
colors_precomp,
|
26 |
+
opacities,
|
27 |
+
scales,
|
28 |
+
rotations,
|
29 |
+
cov3Ds_precomp,
|
30 |
+
raster_settings,
|
31 |
+
):
|
32 |
+
return _RasterizeGaussians.apply(
|
33 |
+
means3D,
|
34 |
+
means2D,
|
35 |
+
sh,
|
36 |
+
colors_precomp,
|
37 |
+
opacities,
|
38 |
+
scales,
|
39 |
+
rotations,
|
40 |
+
cov3Ds_precomp,
|
41 |
+
raster_settings,
|
42 |
+
)
|
43 |
+
|
44 |
+
class _RasterizeGaussians(torch.autograd.Function):
|
45 |
+
@staticmethod
|
46 |
+
def forward(
|
47 |
+
ctx,
|
48 |
+
means3D,
|
49 |
+
means2D,
|
50 |
+
sh,
|
51 |
+
colors_precomp,
|
52 |
+
opacities,
|
53 |
+
scales,
|
54 |
+
rotations,
|
55 |
+
cov3Ds_precomp,
|
56 |
+
raster_settings,
|
57 |
+
):
|
58 |
+
|
59 |
+
# Restructure arguments the way that the C++ lib expects them
|
60 |
+
args = (
|
61 |
+
raster_settings.bg,
|
62 |
+
means3D,
|
63 |
+
colors_precomp,
|
64 |
+
opacities,
|
65 |
+
scales,
|
66 |
+
rotations,
|
67 |
+
raster_settings.scale_modifier,
|
68 |
+
cov3Ds_precomp,
|
69 |
+
raster_settings.viewmatrix,
|
70 |
+
raster_settings.projmatrix,
|
71 |
+
raster_settings.tanfovx,
|
72 |
+
raster_settings.tanfovy,
|
73 |
+
raster_settings.image_height,
|
74 |
+
raster_settings.image_width,
|
75 |
+
sh,
|
76 |
+
raster_settings.sh_degree,
|
77 |
+
raster_settings.campos,
|
78 |
+
raster_settings.prefiltered,
|
79 |
+
raster_settings.debug
|
80 |
+
)
|
81 |
+
|
82 |
+
# Invoke C++/CUDA rasterizer
|
83 |
+
if raster_settings.debug:
|
84 |
+
cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
|
85 |
+
try:
|
86 |
+
num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
|
87 |
+
except Exception as ex:
|
88 |
+
torch.save(cpu_args, "snapshot_fw.dump")
|
89 |
+
print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.")
|
90 |
+
raise ex
|
91 |
+
else:
|
92 |
+
num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
|
93 |
+
|
94 |
+
# Keep relevant tensors for backward
|
95 |
+
ctx.raster_settings = raster_settings
|
96 |
+
ctx.num_rendered = num_rendered
|
97 |
+
ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
|
98 |
+
return color, radii
|
99 |
+
|
100 |
+
@staticmethod
|
101 |
+
def backward(ctx, grad_out_color, _):
|
102 |
+
|
103 |
+
# Restore necessary values from context
|
104 |
+
num_rendered = ctx.num_rendered
|
105 |
+
raster_settings = ctx.raster_settings
|
106 |
+
colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors
|
107 |
+
|
108 |
+
# Restructure args as C++ method expects them
|
109 |
+
args = (raster_settings.bg,
|
110 |
+
means3D,
|
111 |
+
radii,
|
112 |
+
colors_precomp,
|
113 |
+
scales,
|
114 |
+
rotations,
|
115 |
+
raster_settings.scale_modifier,
|
116 |
+
cov3Ds_precomp,
|
117 |
+
raster_settings.viewmatrix,
|
118 |
+
raster_settings.projmatrix,
|
119 |
+
raster_settings.tanfovx,
|
120 |
+
raster_settings.tanfovy,
|
121 |
+
grad_out_color,
|
122 |
+
sh,
|
123 |
+
raster_settings.sh_degree,
|
124 |
+
raster_settings.campos,
|
125 |
+
geomBuffer,
|
126 |
+
num_rendered,
|
127 |
+
binningBuffer,
|
128 |
+
imgBuffer,
|
129 |
+
raster_settings.debug)
|
130 |
+
|
131 |
+
# Compute gradients for relevant tensors by invoking backward method
|
132 |
+
if raster_settings.debug:
|
133 |
+
cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
|
134 |
+
try:
|
135 |
+
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
|
136 |
+
except Exception as ex:
|
137 |
+
torch.save(cpu_args, "snapshot_bw.dump")
|
138 |
+
print("\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n")
|
139 |
+
raise ex
|
140 |
+
else:
|
141 |
+
grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
|
142 |
+
|
143 |
+
grads = (
|
144 |
+
grad_means3D,
|
145 |
+
grad_means2D,
|
146 |
+
grad_sh,
|
147 |
+
grad_colors_precomp,
|
148 |
+
grad_opacities,
|
149 |
+
grad_scales,
|
150 |
+
grad_rotations,
|
151 |
+
grad_cov3Ds_precomp,
|
152 |
+
None,
|
153 |
+
)
|
154 |
+
|
155 |
+
return grads
|
156 |
+
|
157 |
+
class GaussianRasterizationSettings(NamedTuple):
|
158 |
+
image_height: int
|
159 |
+
image_width: int
|
160 |
+
tanfovx : float
|
161 |
+
tanfovy : float
|
162 |
+
bg : torch.Tensor
|
163 |
+
scale_modifier : float
|
164 |
+
viewmatrix : torch.Tensor
|
165 |
+
projmatrix : torch.Tensor
|
166 |
+
sh_degree : int
|
167 |
+
campos : torch.Tensor
|
168 |
+
prefiltered : bool
|
169 |
+
debug : bool
|
170 |
+
|
171 |
+
class GaussianRasterizer(nn.Module):
|
172 |
+
def __init__(self, raster_settings):
|
173 |
+
super().__init__()
|
174 |
+
self.raster_settings = raster_settings
|
175 |
+
|
176 |
+
def markVisible(self, positions):
|
177 |
+
# Mark visible points (based on frustum culling for camera) with a boolean
|
178 |
+
with torch.no_grad():
|
179 |
+
raster_settings = self.raster_settings
|
180 |
+
visible = _C.mark_visible(
|
181 |
+
positions,
|
182 |
+
raster_settings.viewmatrix,
|
183 |
+
raster_settings.projmatrix)
|
184 |
+
|
185 |
+
return visible
|
186 |
+
|
187 |
+
def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
|
188 |
+
|
189 |
+
raster_settings = self.raster_settings
|
190 |
+
|
191 |
+
if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
|
192 |
+
raise Exception('Please provide excatly one of either SHs or precomputed colors!')
|
193 |
+
|
194 |
+
if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None):
|
195 |
+
raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')
|
196 |
+
|
197 |
+
if shs is None:
|
198 |
+
shs = torch.Tensor([])
|
199 |
+
if colors_precomp is None:
|
200 |
+
colors_precomp = torch.Tensor([])
|
201 |
+
|
202 |
+
if scales is None:
|
203 |
+
scales = torch.Tensor([])
|
204 |
+
if rotations is None:
|
205 |
+
rotations = torch.Tensor([])
|
206 |
+
if cov3D_precomp is None:
|
207 |
+
cov3D_precomp = torch.Tensor([])
|
208 |
+
|
209 |
+
# Invoke C++/CUDA rasterization routine
|
210 |
+
return rasterize_gaussians(
|
211 |
+
means3D,
|
212 |
+
means2D,
|
213 |
+
shs,
|
214 |
+
colors_precomp,
|
215 |
+
opacities,
|
216 |
+
scales,
|
217 |
+
rotations,
|
218 |
+
cov3D_precomp,
|
219 |
+
raster_settings,
|
220 |
+
)
|
221 |
+
|
submodules/diff-gaussian-rasterization/ext.cpp
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact george.drettakis@inria.fr
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include <torch/extension.h>
|
13 |
+
#include "rasterize_points.h"
|
14 |
+
|
15 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
16 |
+
m.def("rasterize_gaussians", &RasterizeGaussiansCUDA);
|
17 |
+
m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA);
|
18 |
+
m.def("mark_visible", &markVisible);
|
19 |
+
}
|
submodules/diff-gaussian-rasterization/rasterize_points.cu
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact george.drettakis@inria.fr
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include <math.h>
|
13 |
+
#include <torch/extension.h>
|
14 |
+
#include <cstdio>
|
15 |
+
#include <sstream>
|
16 |
+
#include <iostream>
|
17 |
+
#include <tuple>
|
18 |
+
#include <stdio.h>
|
19 |
+
#include <cuda_runtime_api.h>
|
20 |
+
#include <memory>
|
21 |
+
#include "cuda_rasterizer/config.h"
|
22 |
+
#include "cuda_rasterizer/rasterizer.h"
|
23 |
+
#include <fstream>
|
24 |
+
#include <string>
|
25 |
+
#include <functional>
|
26 |
+
|
27 |
+
std::function<char*(size_t N)> resizeFunctional(torch::Tensor& t) {
|
28 |
+
auto lambda = [&t](size_t N) {
|
29 |
+
t.resize_({(long long)N});
|
30 |
+
return reinterpret_cast<char*>(t.contiguous().data_ptr());
|
31 |
+
};
|
32 |
+
return lambda;
|
33 |
+
}
|
34 |
+
|
35 |
+
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
36 |
+
RasterizeGaussiansCUDA(
|
37 |
+
const torch::Tensor& background,
|
38 |
+
const torch::Tensor& means3D,
|
39 |
+
const torch::Tensor& colors,
|
40 |
+
const torch::Tensor& opacity,
|
41 |
+
const torch::Tensor& scales,
|
42 |
+
const torch::Tensor& rotations,
|
43 |
+
const float scale_modifier,
|
44 |
+
const torch::Tensor& cov3D_precomp,
|
45 |
+
const torch::Tensor& viewmatrix,
|
46 |
+
const torch::Tensor& projmatrix,
|
47 |
+
const float tan_fovx,
|
48 |
+
const float tan_fovy,
|
49 |
+
const int image_height,
|
50 |
+
const int image_width,
|
51 |
+
const torch::Tensor& sh,
|
52 |
+
const int degree,
|
53 |
+
const torch::Tensor& campos,
|
54 |
+
const bool prefiltered,
|
55 |
+
const bool debug)
|
56 |
+
{
|
57 |
+
if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
|
58 |
+
AT_ERROR("means3D must have dimensions (num_points, 3)");
|
59 |
+
}
|
60 |
+
|
61 |
+
const int P = means3D.size(0);
|
62 |
+
const int H = image_height;
|
63 |
+
const int W = image_width;
|
64 |
+
|
65 |
+
auto int_opts = means3D.options().dtype(torch::kInt32);
|
66 |
+
auto float_opts = means3D.options().dtype(torch::kFloat32);
|
67 |
+
|
68 |
+
torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
|
69 |
+
torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
|
70 |
+
|
71 |
+
torch::Device device(torch::kCUDA);
|
72 |
+
torch::TensorOptions options(torch::kByte);
|
73 |
+
torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
|
74 |
+
torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
|
75 |
+
torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
|
76 |
+
std::function<char*(size_t)> geomFunc = resizeFunctional(geomBuffer);
|
77 |
+
std::function<char*(size_t)> binningFunc = resizeFunctional(binningBuffer);
|
78 |
+
std::function<char*(size_t)> imgFunc = resizeFunctional(imgBuffer);
|
79 |
+
|
80 |
+
int rendered = 0;
|
81 |
+
if(P != 0)
|
82 |
+
{
|
83 |
+
int M = 0;
|
84 |
+
if(sh.size(0) != 0)
|
85 |
+
{
|
86 |
+
M = sh.size(1);
|
87 |
+
}
|
88 |
+
|
89 |
+
rendered = CudaRasterizer::Rasterizer::forward(
|
90 |
+
geomFunc,
|
91 |
+
binningFunc,
|
92 |
+
imgFunc,
|
93 |
+
P, degree, M,
|
94 |
+
background.contiguous().data<float>(),
|
95 |
+
W, H,
|
96 |
+
means3D.contiguous().data<float>(),
|
97 |
+
sh.contiguous().data_ptr<float>(),
|
98 |
+
colors.contiguous().data<float>(),
|
99 |
+
opacity.contiguous().data<float>(),
|
100 |
+
scales.contiguous().data_ptr<float>(),
|
101 |
+
scale_modifier,
|
102 |
+
rotations.contiguous().data_ptr<float>(),
|
103 |
+
cov3D_precomp.contiguous().data<float>(),
|
104 |
+
viewmatrix.contiguous().data<float>(),
|
105 |
+
projmatrix.contiguous().data<float>(),
|
106 |
+
campos.contiguous().data<float>(),
|
107 |
+
tan_fovx,
|
108 |
+
tan_fovy,
|
109 |
+
prefiltered,
|
110 |
+
out_color.contiguous().data<float>(),
|
111 |
+
radii.contiguous().data<int>(),
|
112 |
+
debug);
|
113 |
+
}
|
114 |
+
return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer);
|
115 |
+
}
|
116 |
+
|
117 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
118 |
+
RasterizeGaussiansBackwardCUDA(
|
119 |
+
const torch::Tensor& background,
|
120 |
+
const torch::Tensor& means3D,
|
121 |
+
const torch::Tensor& radii,
|
122 |
+
const torch::Tensor& colors,
|
123 |
+
const torch::Tensor& scales,
|
124 |
+
const torch::Tensor& rotations,
|
125 |
+
const float scale_modifier,
|
126 |
+
const torch::Tensor& cov3D_precomp,
|
127 |
+
const torch::Tensor& viewmatrix,
|
128 |
+
const torch::Tensor& projmatrix,
|
129 |
+
const float tan_fovx,
|
130 |
+
const float tan_fovy,
|
131 |
+
const torch::Tensor& dL_dout_color,
|
132 |
+
const torch::Tensor& sh,
|
133 |
+
const int degree,
|
134 |
+
const torch::Tensor& campos,
|
135 |
+
const torch::Tensor& geomBuffer,
|
136 |
+
const int R,
|
137 |
+
const torch::Tensor& binningBuffer,
|
138 |
+
const torch::Tensor& imageBuffer,
|
139 |
+
const bool debug)
|
140 |
+
{
|
141 |
+
const int P = means3D.size(0);
|
142 |
+
const int H = dL_dout_color.size(1);
|
143 |
+
const int W = dL_dout_color.size(2);
|
144 |
+
|
145 |
+
int M = 0;
|
146 |
+
if(sh.size(0) != 0)
|
147 |
+
{
|
148 |
+
M = sh.size(1);
|
149 |
+
}
|
150 |
+
|
151 |
+
torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options());
|
152 |
+
torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options());
|
153 |
+
torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options());
|
154 |
+
torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options());
|
155 |
+
torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options());
|
156 |
+
torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options());
|
157 |
+
torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options());
|
158 |
+
torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options());
|
159 |
+
torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options());
|
160 |
+
|
161 |
+
if(P != 0)
|
162 |
+
{
|
163 |
+
CudaRasterizer::Rasterizer::backward(P, degree, M, R,
|
164 |
+
background.contiguous().data<float>(),
|
165 |
+
W, H,
|
166 |
+
means3D.contiguous().data<float>(),
|
167 |
+
sh.contiguous().data<float>(),
|
168 |
+
colors.contiguous().data<float>(),
|
169 |
+
scales.data_ptr<float>(),
|
170 |
+
scale_modifier,
|
171 |
+
rotations.data_ptr<float>(),
|
172 |
+
cov3D_precomp.contiguous().data<float>(),
|
173 |
+
viewmatrix.contiguous().data<float>(),
|
174 |
+
projmatrix.contiguous().data<float>(),
|
175 |
+
campos.contiguous().data<float>(),
|
176 |
+
tan_fovx,
|
177 |
+
tan_fovy,
|
178 |
+
radii.contiguous().data<int>(),
|
179 |
+
reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()),
|
180 |
+
reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
|
181 |
+
reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
|
182 |
+
dL_dout_color.contiguous().data<float>(),
|
183 |
+
dL_dmeans2D.contiguous().data<float>(),
|
184 |
+
dL_dconic.contiguous().data<float>(),
|
185 |
+
dL_dopacity.contiguous().data<float>(),
|
186 |
+
dL_dcolors.contiguous().data<float>(),
|
187 |
+
dL_dmeans3D.contiguous().data<float>(),
|
188 |
+
dL_dcov3D.contiguous().data<float>(),
|
189 |
+
dL_dsh.contiguous().data<float>(),
|
190 |
+
dL_dscales.contiguous().data<float>(),
|
191 |
+
dL_drotations.contiguous().data<float>(),
|
192 |
+
debug);
|
193 |
+
}
|
194 |
+
|
195 |
+
return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, dL_drotations);
|
196 |
+
}
|
197 |
+
|
198 |
+
torch::Tensor markVisible(
|
199 |
+
torch::Tensor& means3D,
|
200 |
+
torch::Tensor& viewmatrix,
|
201 |
+
torch::Tensor& projmatrix)
|
202 |
+
{
|
203 |
+
const int P = means3D.size(0);
|
204 |
+
|
205 |
+
torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool));
|
206 |
+
|
207 |
+
if(P != 0)
|
208 |
+
{
|
209 |
+
CudaRasterizer::Rasterizer::markVisible(P,
|
210 |
+
means3D.contiguous().data<float>(),
|
211 |
+
viewmatrix.contiguous().data<float>(),
|
212 |
+
projmatrix.contiguous().data<float>(),
|
213 |
+
present.contiguous().data<bool>());
|
214 |
+
}
|
215 |
+
|
216 |
+
return present;
|
217 |
+
}
|
submodules/diff-gaussian-rasterization/rasterize_points.h
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (C) 2023, Inria
|
3 |
+
* GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
* All rights reserved.
|
5 |
+
*
|
6 |
+
* This software is free for non-commercial, research and evaluation use
|
7 |
+
* under the terms of the LICENSE.md file.
|
8 |
+
*
|
9 |
+
* For inquiries contact george.drettakis@inria.fr
|
10 |
+
*/
|
11 |
+
|
12 |
+
#pragma once
|
13 |
+
#include <torch/extension.h>
|
14 |
+
#include <cstdio>
|
15 |
+
#include <tuple>
|
16 |
+
#include <string>
|
17 |
+
|
18 |
+
std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
19 |
+
RasterizeGaussiansCUDA(
|
20 |
+
const torch::Tensor& background,
|
21 |
+
const torch::Tensor& means3D,
|
22 |
+
const torch::Tensor& colors,
|
23 |
+
const torch::Tensor& opacity,
|
24 |
+
const torch::Tensor& scales,
|
25 |
+
const torch::Tensor& rotations,
|
26 |
+
const float scale_modifier,
|
27 |
+
const torch::Tensor& cov3D_precomp,
|
28 |
+
const torch::Tensor& viewmatrix,
|
29 |
+
const torch::Tensor& projmatrix,
|
30 |
+
const float tan_fovx,
|
31 |
+
const float tan_fovy,
|
32 |
+
const int image_height,
|
33 |
+
const int image_width,
|
34 |
+
const torch::Tensor& sh,
|
35 |
+
const int degree,
|
36 |
+
const torch::Tensor& campos,
|
37 |
+
const bool prefiltered,
|
38 |
+
const bool debug);
|
39 |
+
|
40 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
41 |
+
RasterizeGaussiansBackwardCUDA(
|
42 |
+
const torch::Tensor& background,
|
43 |
+
const torch::Tensor& means3D,
|
44 |
+
const torch::Tensor& radii,
|
45 |
+
const torch::Tensor& colors,
|
46 |
+
const torch::Tensor& scales,
|
47 |
+
const torch::Tensor& rotations,
|
48 |
+
const float scale_modifier,
|
49 |
+
const torch::Tensor& cov3D_precomp,
|
50 |
+
const torch::Tensor& viewmatrix,
|
51 |
+
const torch::Tensor& projmatrix,
|
52 |
+
const float tan_fovx,
|
53 |
+
const float tan_fovy,
|
54 |
+
const torch::Tensor& dL_dout_color,
|
55 |
+
const torch::Tensor& sh,
|
56 |
+
const int degree,
|
57 |
+
const torch::Tensor& campos,
|
58 |
+
const torch::Tensor& geomBuffer,
|
59 |
+
const int R,
|
60 |
+
const torch::Tensor& binningBuffer,
|
61 |
+
const torch::Tensor& imageBuffer,
|
62 |
+
const bool debug);
|
63 |
+
|
64 |
+
torch::Tensor markVisible(
|
65 |
+
torch::Tensor& means3D,
|
66 |
+
torch::Tensor& viewmatrix,
|
67 |
+
torch::Tensor& projmatrix);
|
submodules/diff-gaussian-rasterization/setup.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Copyright (C) 2023, Inria
|
3 |
+
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
4 |
+
# All rights reserved.
|
5 |
+
#
|
6 |
+
# This software is free for non-commercial, research and evaluation use
|
7 |
+
# under the terms of the LICENSE.md file.
|
8 |
+
#
|
9 |
+
# For inquiries contact george.drettakis@inria.fr
|
10 |
+
#
|
11 |
+
|
12 |
+
from setuptools import setup
|
13 |
+
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
|
14 |
+
import os
|
15 |
+
os.path.dirname(os.path.abspath(__file__))
|
16 |
+
|
17 |
+
setup(
|
18 |
+
name="diff_gaussian_rasterization",
|
19 |
+
packages=['diff_gaussian_rasterization'],
|
20 |
+
ext_modules=[
|
21 |
+
CUDAExtension(
|
22 |
+
name="diff_gaussian_rasterization._C",
|
23 |
+
sources=[
|
24 |
+
"cuda_rasterizer/rasterizer_impl.cu",
|
25 |
+
"cuda_rasterizer/forward.cu",
|
26 |
+
"cuda_rasterizer/backward.cu",
|
27 |
+
"rasterize_points.cu",
|
28 |
+
"ext.cpp"],
|
29 |
+
extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]})
|
30 |
+
],
|
31 |
+
cmdclass={
|
32 |
+
'build_ext': BuildExtension
|
33 |
+
}
|
34 |
+
)
|
submodules/diff-gaussian-rasterization/third_party/glm/.appveyor.yml
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
shallow_clone: true
|
2 |
+
|
3 |
+
platform:
|
4 |
+
- x86
|
5 |
+
- x64
|
6 |
+
|
7 |
+
configuration:
|
8 |
+
- Debug
|
9 |
+
- Release
|
10 |
+
|
11 |
+
image:
|
12 |
+
- Visual Studio 2013
|
13 |
+
- Visual Studio 2015
|
14 |
+
- Visual Studio 2017
|
15 |
+
- Visual Studio 2019
|
16 |
+
|
17 |
+
environment:
|
18 |
+
matrix:
|
19 |
+
- GLM_ARGUMENTS: -DGLM_TEST_FORCE_PURE=ON
|
20 |
+
- GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_SSE2=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON
|
21 |
+
- GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON
|
22 |
+
- GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_14=ON
|
23 |
+
- GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_17=ON
|
24 |
+
|
25 |
+
matrix:
|
26 |
+
exclude:
|
27 |
+
- image: Visual Studio 2013
|
28 |
+
GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON
|
29 |
+
- image: Visual Studio 2013
|
30 |
+
GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_14=ON
|
31 |
+
- image: Visual Studio 2013
|
32 |
+
GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_17=ON
|
33 |
+
- image: Visual Studio 2013
|
34 |
+
configuration: Debug
|
35 |
+
- image: Visual Studio 2015
|
36 |
+
GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_SSE2=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON
|
37 |
+
- image: Visual Studio 2015
|
38 |
+
GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_14=ON
|
39 |
+
- image: Visual Studio 2015
|
40 |
+
GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_17=ON
|
41 |
+
- image: Visual Studio 2015
|
42 |
+
platform: x86
|
43 |
+
- image: Visual Studio 2015
|
44 |
+
configuration: Debug
|
45 |
+
- image: Visual Studio 2017
|
46 |
+
platform: x86
|
47 |
+
- image: Visual Studio 2017
|
48 |
+
configuration: Debug
|
49 |
+
- image: Visual Studio 2019
|
50 |
+
platform: x64
|
51 |
+
|
52 |
+
branches:
|
53 |
+
only:
|
54 |
+
- master
|
55 |
+
|
56 |
+
before_build:
|
57 |
+
- ps: |
|
58 |
+
mkdir build
|
59 |
+
cd build
|
60 |
+
|
61 |
+
if ("$env:APPVEYOR_JOB_NAME" -match "Image: Visual Studio 2013") {
|
62 |
+
$env:generator="Visual Studio 12 2013"
|
63 |
+
}
|
64 |
+
if ("$env:APPVEYOR_JOB_NAME" -match "Image: Visual Studio 2015") {
|
65 |
+
$env:generator="Visual Studio 14 2015"
|
66 |
+
}
|
67 |
+
if ("$env:APPVEYOR_JOB_NAME" -match "Image: Visual Studio 2017") {
|
68 |
+
$env:generator="Visual Studio 15 2017"
|
69 |
+
}
|
70 |
+
if ("$env:APPVEYOR_JOB_NAME" -match "Image: Visual Studio 2019") {
|
71 |
+
$env:generator="Visual Studio 16 2019"
|
72 |
+
}
|
73 |
+
if ($env:PLATFORM -eq "x64") {
|
74 |
+
$env:generator="$env:generator Win64"
|
75 |
+
}
|
76 |
+
echo generator="$env:generator"
|
77 |
+
cmake .. -G "$env:generator" -DCMAKE_INSTALL_PREFIX="$env:APPVEYOR_BUILD_FOLDER/install" -DGLM_QUIET=ON -DGLM_TEST_ENABLE=ON "$env:GLM_ARGUMENTS"
|
78 |
+
|
79 |
+
build_script:
|
80 |
+
- cmake --build . --parallel --config %CONFIGURATION% -- /m /v:minimal
|
81 |
+
- cmake --build . --target install --parallel --config %CONFIGURATION% -- /m /v:minimal
|
82 |
+
|
83 |
+
test_script:
|
84 |
+
- ctest --parallel 4 --verbose -C %CONFIGURATION%
|
85 |
+
- cd ..
|
86 |
+
- ps: |
|
87 |
+
mkdir build_test_cmake
|
88 |
+
cd build_test_cmake
|
89 |
+
cmake ..\test\cmake\ -G "$env:generator" -DCMAKE_PREFIX_PATH="$env:APPVEYOR_BUILD_FOLDER/install"
|
90 |
+
- cmake --build . --parallel --config %CONFIGURATION% -- /m /v:minimal
|
91 |
+
|
92 |
+
deploy: off
|
submodules/diff-gaussian-rasterization/third_party/glm/.gitignore
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Compiled Object files
|
2 |
+
*.slo
|
3 |
+
*.lo
|
4 |
+
*.o
|
5 |
+
*.obj
|
6 |
+
|
7 |
+
# Precompiled Headers
|
8 |
+
*.gch
|
9 |
+
*.pch
|
10 |
+
|
11 |
+
# Compiled Dynamic libraries
|
12 |
+
*.so
|
13 |
+
*.dylib
|
14 |
+
*.dll
|
15 |
+
|
16 |
+
# Fortran module files
|
17 |
+
*.mod
|
18 |
+
|
19 |
+
# Compiled Static libraries
|
20 |
+
*.lai
|
21 |
+
*.la
|
22 |
+
*.a
|
23 |
+
*.lib
|
24 |
+
|
25 |
+
# Executables
|
26 |
+
*.exe
|
27 |
+
*.out
|
28 |
+
*.app
|
29 |
+
|
30 |
+
# CMake
|
31 |
+
CMakeCache.txt
|
32 |
+
CMakeFiles
|
33 |
+
cmake_install.cmake
|
34 |
+
install_manifest.txt
|
35 |
+
*.cmake
|
36 |
+
!glmConfig.cmake
|
37 |
+
!glmConfig-version.cmake
|
38 |
+
# ^ May need to add future .cmake files as exceptions
|
39 |
+
|
40 |
+
# Test logs
|
41 |
+
Testing/*
|
42 |
+
|
43 |
+
# Test input
|
44 |
+
test/gtc/*.dds
|
45 |
+
|
46 |
+
# Project Files
|
47 |
+
Makefile
|
48 |
+
*.cbp
|
49 |
+
*.user
|
50 |
+
|
51 |
+
# Misc.
|
52 |
+
*.log
|
53 |
+
|
54 |
+
# local build(s)
|
55 |
+
build*
|
56 |
+
|
57 |
+
/.vs
|
58 |
+
/.vscode
|
59 |
+
/CMakeSettings.json
|
60 |
+
.DS_Store
|
61 |
+
*.swp
|