ironjr committed on
Commit f8a5306
1 Parent(s): a462822

Delete scene/.ipynb_checkpoints

scene/.ipynb_checkpoints/__init__-checkpoint.py DELETED
@@ -1,41 +0,0 @@
- ###
- # Copyright (C) 2023, Computer Vision Lab, Seoul National University, https://cv.snu.ac.kr
- # For permission requests, please contact robot0321@snu.ac.kr, esw0116@snu.ac.kr, namhj28@gmail.com, jarin.lee@gmail.com.
- # All rights reserved.
- ###
- import os
- import random
-
- from arguments import GSParams
- from utils.system import searchForMaxIteration
- from scene.dataset_readers import readDataInfo
- from scene.gaussian_model import GaussianModel
-
-
- class Scene:
-     gaussians: GaussianModel
-
-     def __init__(self, traindata, gaussians: GaussianModel, opt: GSParams):
-         self.traindata = traindata
-         self.gaussians = gaussians
-
-         info = readDataInfo(traindata, opt.white_background)
-         random.shuffle(info.train_cameras)  # Multi-res consistent random shuffling
-         self.cameras_extent = info.nerf_normalization["radius"]
-
-         print("Loading Training Cameras")
-         self.train_cameras = info.train_cameras
-         print("Loading Preset Cameras")
-         self.preset_cameras = {}
-         for campath in info.preset_cameras.keys():
-             self.preset_cameras[campath] = info.preset_cameras[campath]
-
-         self.gaussians.create_from_pcd(info.point_cloud, self.cameras_extent)
-         self.gaussians.training_setup(opt)
-
-     def getTrainCameras(self):
-         return self.train_cameras
-
-     def getPresetCameras(self, preset):
-         assert preset in self.preset_cameras
-         return self.preset_cameras[preset]
scene/.ipynb_checkpoints/cameras-checkpoint.py DELETED
@@ -1,76 +0,0 @@
- #
- # Copyright (C) 2023, Inria
- # GRAPHDECO research group, https://team.inria.fr/graphdeco
- # All rights reserved.
- #
- # This software is free for non-commercial, research and evaluation use
- # under the terms of the LICENSE.md file.
- #
- # For inquiries contact george.drettakis@inria.fr
- #
- import numpy as np
-
- import torch
- from torch import nn
-
- from utils.graphics import getWorld2View2, getProjectionMatrix
- from utils.loss import image2canny
-
-
- class Camera(nn.Module):
-     def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
-                  image_name, uid,
-                  trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device="cuda"
-                  ):
-         super(Camera, self).__init__()
-
-         self.uid = uid
-         self.colmap_id = colmap_id
-         self.R = R
-         self.T = T
-         self.FoVx = FoVx
-         self.FoVy = FoVy
-         self.image_name = image_name
-
-         try:
-             self.data_device = torch.device(data_device)
-         except Exception as e:
-             print(e)
-             print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device")
-             self.data_device = torch.device("cuda")
-
-         self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
-         self.canny_mask = image2canny(self.original_image.permute(1, 2, 0), 50, 150, isEdge1=False).detach().to(self.data_device)
-         self.image_width = self.original_image.shape[2]
-         self.image_height = self.original_image.shape[1]
-
-         if gt_alpha_mask is not None:
-             self.original_image *= gt_alpha_mask.to(self.data_device)
-         else:
-             self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
-
-         self.zfar = 100.0
-         self.znear = 0.01
-
-         self.trans = trans
-         self.scale = scale
-
-         self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
-         self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0, 1).cuda()
-         self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
-         self.camera_center = self.world_view_transform.inverse()[3, :3]
-
-
- class MiniCam:
-     def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
-         self.image_width = width
-         self.image_height = height
-         self.FoVy = fovy
-         self.FoVx = fovx
-         self.znear = znear
-         self.zfar = zfar
-         self.world_view_transform = world_view_transform
-         self.full_proj_transform = full_proj_transform
-         view_inv = torch.inverse(self.world_view_transform)
-         self.camera_center = view_inv[3][:3]
-
scene/.ipynb_checkpoints/colmap_loader-checkpoint.py DELETED
@@ -1,301 +0,0 @@
- #
- # Copyright (C) 2023, Inria
- # GRAPHDECO research group, https://team.inria.fr/graphdeco
- # All rights reserved.
- #
- # This software is free for non-commercial, research and evaluation use
- # under the terms of the LICENSE.md file.
- #
- # For inquiries contact george.drettakis@inria.fr
- #
- import numpy as np
- import collections
- import struct
-
-
- CameraModel = collections.namedtuple(
-     "CameraModel", ["model_id", "model_name", "num_params"])
- Camera = collections.namedtuple(
-     "Camera", ["id", "model", "width", "height", "params"])
- BaseImage = collections.namedtuple(
-     "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
- Point3D = collections.namedtuple(
-     "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
- CAMERA_MODELS = {
-     CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
-     CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
-     CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
-     CameraModel(model_id=3, model_name="RADIAL", num_params=5),
-     CameraModel(model_id=4, model_name="OPENCV", num_params=8),
-     CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
-     CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
-     CameraModel(model_id=7, model_name="FOV", num_params=5),
-     CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
-     CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
-     CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
- }
- CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
-                          for camera_model in CAMERA_MODELS])
- CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
-                            for camera_model in CAMERA_MODELS])
-
-
- def qvec2rotmat(qvec):
-     return np.array([
-         [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
-          2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
-          2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
-         [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
-          1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
-          2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
-         [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
-          2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
-          1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
-
-
- def rotmat2qvec(R):
-     Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
-     K = np.array([
-         [Rxx - Ryy - Rzz, 0, 0, 0],
-         [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
-         [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
-         [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
-     eigvals, eigvecs = np.linalg.eigh(K)
-     qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
-     if qvec[0] < 0:
-         qvec *= -1
-     return qvec
-
-
- class Image(BaseImage):
-     def qvec2rotmat(self):
-         return qvec2rotmat(self.qvec)
-
-
- def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
-     """Read and unpack the next bytes from a binary file.
-     :param fid:
-     :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
-     :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
-     :param endian_character: Any of {@, =, <, >, !}
-     :return: Tuple of read and unpacked values.
-     """
-     data = fid.read(num_bytes)
-     return struct.unpack(endian_character + format_char_sequence, data)
-
-
- def read_points3D_text(path):
-     """
-     see: src/base/reconstruction.cc
-         void Reconstruction::ReadPoints3DText(const std::string& path)
-         void Reconstruction::WritePoints3DText(const std::string& path)
-     """
-     xyzs = None
-     rgbs = None
-     errors = None
-     num_points = 0
-     with open(path, "r") as fid:
-         while True:
-             line = fid.readline()
-             if not line:
-                 break
-             line = line.strip()
-             if len(line) > 0 and line[0] != "#":
-                 num_points += 1
-
-
-     xyzs = np.empty((num_points, 3))
-     rgbs = np.empty((num_points, 3))
-     errors = np.empty((num_points, 1))
-     count = 0
-     with open(path, "r") as fid:
-         while True:
-             line = fid.readline()
-             if not line:
-                 break
-             line = line.strip()
-             if len(line) > 0 and line[0] != "#":
-                 elems = line.split()
-                 xyz = np.array(tuple(map(float, elems[1:4])))
-                 rgb = np.array(tuple(map(int, elems[4:7])))
-                 error = np.array(float(elems[7]))
-                 xyzs[count] = xyz
-                 rgbs[count] = rgb
-                 errors[count] = error
-                 count += 1
-
-     return xyzs, rgbs, errors
-
-
- def read_points3D_binary(path_to_model_file):
-     """
-     see: src/base/reconstruction.cc
-         void Reconstruction::ReadPoints3DBinary(const std::string& path)
-         void Reconstruction::WritePoints3DBinary(const std::string& path)
-     """
-
-
-     with open(path_to_model_file, "rb") as fid:
-         num_points = read_next_bytes(fid, 8, "Q")[0]
-
-         xyzs = np.empty((num_points, 3))
-         rgbs = np.empty((num_points, 3))
-         errors = np.empty((num_points, 1))
-
-         for p_id in range(num_points):
-             binary_point_line_properties = read_next_bytes(
-                 fid, num_bytes=43, format_char_sequence="QdddBBBd")
-             xyz = np.array(binary_point_line_properties[1:4])
-             rgb = np.array(binary_point_line_properties[4:7])
-             error = np.array(binary_point_line_properties[7])
-             track_length = read_next_bytes(
-                 fid, num_bytes=8, format_char_sequence="Q")[0]
-             track_elems = read_next_bytes(
-                 fid, num_bytes=8*track_length,
-                 format_char_sequence="ii"*track_length)
-             xyzs[p_id] = xyz
-             rgbs[p_id] = rgb
-             errors[p_id] = error
-     return xyzs, rgbs, errors
-
-
- def read_intrinsics_text(path):
-     """
-     Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
-     """
-     cameras = {}
-     with open(path, "r") as fid:
-         while True:
-             line = fid.readline()
-             if not line:
-                 break
-             line = line.strip()
-             if len(line) > 0 and line[0] != "#":
-                 elems = line.split()
-                 camera_id = int(elems[0])
-                 model = elems[1]
-                 assert model == "PINHOLE", "While the loader supports other types, the rest of the code assumes PINHOLE"
-                 width = int(elems[2])
-                 height = int(elems[3])
-                 params = np.array(tuple(map(float, elems[4:])))
-                 cameras[camera_id] = Camera(id=camera_id, model=model,
-                                             width=width, height=height,
-                                             params=params)
-     return cameras
-
-
- def read_extrinsics_binary(path_to_model_file):
-     """
-     see: src/base/reconstruction.cc
-         void Reconstruction::ReadImagesBinary(const std::string& path)
-         void Reconstruction::WriteImagesBinary(const std::string& path)
-     """
-     images = {}
-     with open(path_to_model_file, "rb") as fid:
-         num_reg_images = read_next_bytes(fid, 8, "Q")[0]
-         for _ in range(num_reg_images):
-             binary_image_properties = read_next_bytes(
-                 fid, num_bytes=64, format_char_sequence="idddddddi")
-             image_id = binary_image_properties[0]
-             qvec = np.array(binary_image_properties[1:5])
-             tvec = np.array(binary_image_properties[5:8])
-             camera_id = binary_image_properties[8]
-             image_name = ""
-             current_char = read_next_bytes(fid, 1, "c")[0]
-             while current_char != b"\x00":  # look for the ASCII 0 entry
-                 image_name += current_char.decode("utf-8")
-                 current_char = read_next_bytes(fid, 1, "c")[0]
-             num_points2D = read_next_bytes(fid, num_bytes=8,
-                                            format_char_sequence="Q")[0]
-             x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
-                                        format_char_sequence="ddq"*num_points2D)
-             xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
-                                    tuple(map(float, x_y_id_s[1::3]))])
-             point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
-             images[image_id] = Image(
-                 id=image_id, qvec=qvec, tvec=tvec,
-                 camera_id=camera_id, name=image_name,
-                 xys=xys, point3D_ids=point3D_ids)
-     return images
-
-
- def read_intrinsics_binary(path_to_model_file):
-     """
-     see: src/base/reconstruction.cc
-         void Reconstruction::WriteCamerasBinary(const std::string& path)
-         void Reconstruction::ReadCamerasBinary(const std::string& path)
-     """
-     cameras = {}
-     with open(path_to_model_file, "rb") as fid:
-         num_cameras = read_next_bytes(fid, 8, "Q")[0]
-         for _ in range(num_cameras):
-             camera_properties = read_next_bytes(
-                 fid, num_bytes=24, format_char_sequence="iiQQ")
-             camera_id = camera_properties[0]
-             model_id = camera_properties[1]
-             model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
-             width = camera_properties[2]
-             height = camera_properties[3]
-             num_params = CAMERA_MODEL_IDS[model_id].num_params
-             params = read_next_bytes(fid, num_bytes=8*num_params,
-                                      format_char_sequence="d"*num_params)
-             cameras[camera_id] = Camera(id=camera_id,
-                                         model=model_name,
-                                         width=width,
-                                         height=height,
-                                         params=np.array(params))
-         assert len(cameras) == num_cameras
-     return cameras
-
-
- def read_extrinsics_text(path):
-     """
-     Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
-     """
-     images = {}
-     with open(path, "r") as fid:
-         while True:
-             line = fid.readline()
-             if not line:
-                 break
-             line = line.strip()
-             if len(line) > 0 and line[0] != "#":
-                 elems = line.split()
-                 image_id = int(elems[0])
-                 qvec = np.array(tuple(map(float, elems[1:5])))
-                 tvec = np.array(tuple(map(float, elems[5:8])))
-                 camera_id = int(elems[8])
-                 image_name = elems[9]
-                 elems = fid.readline().split()
-                 xys = np.column_stack([tuple(map(float, elems[0::3])),
-                                        tuple(map(float, elems[1::3]))])
-                 point3D_ids = np.array(tuple(map(int, elems[2::3])))
-                 images[image_id] = Image(
-                     id=image_id, qvec=qvec, tvec=tvec,
-                     camera_id=camera_id, name=image_name,
-                     xys=xys, point3D_ids=point3D_ids)
-     return images
-
-
- def read_colmap_bin_array(path):
-     """
-     Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
-
-     :param path: path to the colmap binary file.
-     :return: nd array with the floating point values in the file
-     """
-     with open(path, "rb") as fid:
-         width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
-                                                 usecols=(0, 1, 2), dtype=int)
-         fid.seek(0)
-         num_delimiter = 0
-         byte = fid.read(1)
-         while True:
-             if byte == b"&":
-                 num_delimiter += 1
-                 if num_delimiter >= 3:
-                     break
-             byte = fid.read(1)
-         array = np.fromfile(fid, np.float32)
-     array = array.reshape((width, height, channels), order="F")
-     return np.transpose(array, (1, 0, 2)).squeeze()
scene/.ipynb_checkpoints/dataset_readers-checkpoint.py DELETED
@@ -1,434 +0,0 @@
- #
- # Copyright (C) 2023, Inria
- # GRAPHDECO research group, https://team.inria.fr/graphdeco
- # All rights reserved.
- #
- # This software is free for non-commercial, research and evaluation use
- # under the terms of the LICENSE.md file.
- #
- # For inquiries contact george.drettakis@inria.fr
- #
- import os
- import sys
- import json
- from typing import NamedTuple
- from pathlib import Path
-
- import imageio
- import torch
- import numpy as np
- from PIL import Image
- from plyfile import PlyData, PlyElement
-
- from scene.gaussian_model import BasicPointCloud
- from scene.cameras import MiniCam, Camera
- from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
-     read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
- from utils.graphics import getWorld2View2, focal2fov, fov2focal
- from utils.graphics import getProjectionMatrix
- from utils.trajectory import get_camerapaths
- from utils.sh import SH2RGB
-
-
- class CameraInfo(NamedTuple):
-     uid: int
-     R: np.array
-     T: np.array
-     FovY: np.array
-     FovX: np.array
-     image: np.array
-     image_path: str
-     image_name: str
-     width: int
-     height: int
-
-
- class SceneInfo(NamedTuple):
-     point_cloud: BasicPointCloud
-     train_cameras: list
-     test_cameras: list
-     preset_cameras: list
-     nerf_normalization: dict
-     ply_path: str
-
-
- def getNerfppNorm(cam_info):
-     def get_center_and_diag(cam_centers):
-         cam_centers = np.hstack(cam_centers)
-         avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
-         center = avg_cam_center
-         dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
-         diagonal = np.max(dist)
-         return center.flatten(), diagonal
-
-     cam_centers = []
-
-     for cam in cam_info:
-         W2C = getWorld2View2(cam.R, cam.T)
-         C2W = np.linalg.inv(W2C)
-         cam_centers.append(C2W[:3, 3:4])
-
-     center, diagonal = get_center_and_diag(cam_centers)
-     radius = diagonal * 1.1
-
-     translate = -center
-
-     return {"translate": translate, "radius": radius}
-
-
- def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
-     cam_infos = []
-     for idx, key in enumerate(cam_extrinsics):
-         sys.stdout.write('\r')
-         # the exact output you're looking for:
-         sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
-         sys.stdout.flush()
-
-         extr = cam_extrinsics[key]
-         intr = cam_intrinsics[extr.camera_id]
-         height = intr.height
-         width = intr.width
-
-         uid = intr.id
-         R = np.transpose(qvec2rotmat(extr.qvec))
-         T = np.array(extr.tvec)
-
-         if intr.model == "SIMPLE_PINHOLE":
-             focal_length_x = intr.params[0]
-             FovY = focal2fov(focal_length_x, height)
-             FovX = focal2fov(focal_length_x, width)
-         elif intr.model == "PINHOLE":
-             focal_length_x = intr.params[0]
-             focal_length_y = intr.params[1]
-             FovY = focal2fov(focal_length_y, height)
-             FovX = focal2fov(focal_length_x, width)
-         else:
-             assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
-
-         image_path = os.path.join(images_folder, os.path.basename(extr.name))
-         image_name = os.path.basename(image_path).split(".")[0]
-         image = Image.open(image_path)
-
-         cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
-                               image_path=image_path, image_name=image_name, width=width, height=height)
-         cam_infos.append(cam_info)
-     sys.stdout.write('\n')
-     return cam_infos
-
-
- def fetchPly(path):
-     plydata = PlyData.read(path)
-     vertices = plydata['vertex']
-     idx = np.random.choice(len(vertices['x']), size=(min(len(vertices['x']), 100_000),), replace=False)
-     positions = np.vstack([vertices['x'][idx], vertices['y'][idx], vertices['z'][idx]]).T if 'x' in vertices else None
-     colors = np.vstack([vertices['red'][idx], vertices['green'][idx], vertices['blue'][idx]]).T / 255.0 if 'red' in vertices else None
-     normals = np.vstack([vertices['nx'][idx], vertices['ny'][idx], vertices['nz'][idx]]).T if 'nx' in vertices else None
-     return BasicPointCloud(points=positions, colors=colors, normals=normals)
-
-
- def storePly(path, xyz, rgb):
-     # Define the dtype for the structured array
-     dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
-              ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
-              ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
-
-     normals = np.zeros_like(xyz)
-
-     elements = np.empty(xyz.shape[0], dtype=dtype)
-     attributes = np.concatenate((xyz, normals, rgb), axis=1)
-     elements[:] = list(map(tuple, attributes))
-
-     # Create the PlyData object and write to file
-     vertex_element = PlyElement.describe(elements, 'vertex')
-     ply_data = PlyData([vertex_element])
-     ply_data.write(path)
-
-
- def readColmapSceneInfo(path, images, eval, preset=None, llffhold=8):
-     try:
-         cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
-         cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
-         cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
-         cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
-     except:
-         cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
-         cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
-         cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
-         cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
-
-     reading_dir = "images" if images is None else images
-     cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
-     cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name)
-
-     if eval:
-         # train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
-         # test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
-         train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % 5 == 2 or idx % 5 == 0]
-         test_cam_infos = [c for idx, c in enumerate(cam_infos) if not (idx % 5 == 2 or idx % 5 == 0)]
-     else:
-         train_cam_infos = cam_infos
-         test_cam_infos = []
-
-     nerf_normalization = getNerfppNorm(train_cam_infos)
-
-     ply_path = os.path.join(path, "sparse/0/points3D.ply")
-     bin_path = os.path.join(path, "sparse/0/points3D.bin")
-     txt_path = os.path.join(path, "sparse/0/points3D.txt")
-     if not os.path.exists(ply_path):
-         print("Converting points3D.bin to .ply, will happen only the first time you open the scene.")
-         try:
-             xyz, rgb, _ = read_points3D_binary(bin_path)
-         except:
-             xyz, rgb, _ = read_points3D_text(txt_path)
-         storePly(ply_path, xyz, rgb)
-     try:
-         pcd = fetchPly(ply_path)
-     except:
-         pcd = None
-
-     if preset:
-         preset_cam_infos = readCamerasFromPreset('/home/chung/workspace/gaussian-splatting/poses_supplementary', f"{preset}.json")
-     else:
-         preset_cam_infos = None
-
-     scene_info = SceneInfo(point_cloud=pcd,
-                            train_cameras=train_cam_infos,
-                            test_cameras=test_cam_infos,
-                            preset_cameras=preset_cam_infos,
-                            nerf_normalization=nerf_normalization,
-                            ply_path=ply_path)
-     return scene_info
-
-
- def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
-     cam_infos = []
-
-     with open(os.path.join(path, transformsfile)) as json_file:
-         contents = json.load(json_file)
-         fovx = contents["camera_angle_x"]
-
-         frames = contents["frames"]
-         for idx, frame in enumerate(frames):
-             cam_name = os.path.join(path, frame["file_path"] + extension)
-
-             # NeRF 'transform_matrix' is a camera-to-world transform
-             c2w = np.array(frame["transform_matrix"])
-             # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
-             c2w[:3, 1:3] *= -1
-
-             # get the world-to-camera transform and set R, T
-             w2c = np.linalg.inv(c2w)
-             R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
-             T = w2c[:3, 3]
-
-             image_path = os.path.join(path, cam_name)
-             image_name = Path(cam_name).stem
-             image = Image.open(image_path)
-
-             # if os.path.exists(os.path.join(path, frame["file_path"].replace("/train/", "/depths_train/")+'.npy')):
-             #     depth = np.load(os.path.join(path, frame["file_path"].replace("/train/", "/depths_train/")+'.npy'))
-             #     if os.path.exists(os.path.join(path, frame["file_path"].replace("/train/", "/masks_train/")+'.png')):
-             #         mask = imageio.v3.imread(os.path.join(path, frame["file_path"].replace("/train/", "/masks_train/")+'.png'))[:,:,0]/255.
-             #     else:
-             #         mask = np.ones_like(depth)
-             #     final_depth = depth*mask
-             # else:
-             #     final_depth = None
-
-             im_data = np.array(image.convert("RGBA"))
-
-             bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])
-
-             norm_data = im_data / 255.0
-             arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
-             image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
-
-             fovy = focal2fov(fov2focal(fovx, image.size[1]), image.size[0])
-             FovY = fovy
-             FovX = fovx
-
-             cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
-                                         image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
-
-     return cam_infos
-
-
- def readCamerasFromPreset(path, transformsfile):
-     cam_infos = []
-
-     with open(os.path.join(path, transformsfile)) as json_file:
-         contents = json.load(json_file)
-         FOV = contents["camera_angle_x"]*1.2
-
-         frames = contents["frames"]
-         for idx, frame in enumerate(frames):
-             # NeRF 'transform_matrix' is a camera-to-world transform
-             c2w = np.array(frame["transform_matrix"])
-             # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
-             c2w[:3, 1:3] *= -1
-
-             # get the world-to-camera transform and set R, T
-             w2c = np.linalg.inv(np.concatenate((c2w, np.array([0, 0, 0, 1]).reshape(1, 4)), axis=0))
-             R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
-             T = w2c[:3, 3]
-             # R = c2w[:3,:3]
-             # T = - np.transpose(R).dot(c2w[:3,3])
-
-             image = Image.fromarray(np.zeros((512, 512)), "RGB")
-             FovY = focal2fov(fov2focal(FOV, 512), image.size[0])
-             FovX = focal2fov(fov2focal(FOV, 512), image.size[1])
-             # FovX, FovY = contents["camera_angle_x"], contents["camera_angle_x"]
-
-             cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
-                                         image_path='None', image_name='None', width=image.size[1], height=image.size[0]))
-
-     return cam_infos
-
-
- def readNerfSyntheticInfo(path, white_background, eval, preset=None, extension=".png"):
-     print("Reading Training Transforms")
-     train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
-     print("Reading Test Transforms")
-     test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
-
-     if preset:
-         preset_cam_infos = readCamerasFromPreset('/home/chung/workspace/gaussian-splatting/poses_supplementary', f"{preset}.json")
-     else:
-         preset_cam_infos = None
-
-     if not eval:
-         train_cam_infos.extend(test_cam_infos)
-         test_cam_infos = []
-
-     nerf_normalization = getNerfppNorm(train_cam_infos)
-
-     ply_path = os.path.join(path, "points3d.ply")
-     if not os.path.exists(ply_path):
-         # Since this data set has no colmap data, we start with random points
-         num_pts = 100_000
-         print(f"Generating random point cloud ({num_pts})...")
-
-         # We create random points inside the bounds of the synthetic Blender scenes
-         xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
-         shs = np.random.random((num_pts, 3)) / 255.0
-         pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
-
-         storePly(ply_path, xyz, SH2RGB(shs) * 255)
-
-     try:
-         pcd = fetchPly(ply_path)
-     except:
-         pcd = None
-
-     scene_info = SceneInfo(point_cloud=pcd,
-                            train_cameras=train_cam_infos,
-                            test_cameras=test_cam_infos,
-                            preset_cameras=preset_cam_infos,
-                            nerf_normalization=nerf_normalization,
-                            ply_path=ply_path)
-     return scene_info
-
-
- def loadCamerasFromData(traindata, white_background):
-     cameras = []
-
-     fovx = traindata["camera_angle_x"]
-     frames = traindata["frames"]
-     for idx, frame in enumerate(frames):
-         # NeRF 'transform_matrix' is a camera-to-world transform
-         c2w = np.array(frame["transform_matrix"])
-         # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
-         c2w[:3, 1:3] *= -1
-
-         # get the world-to-camera transform and set R, T
-         w2c = np.linalg.inv(c2w)
-         R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
-         T = w2c[:3, 3]
-
-         image = frame["image"] if "image" in frame else None
-         im_data = np.array(image.convert("RGBA"))
-
-         bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])
-
-         norm_data = im_data / 255.0
-         arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
-         image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
-         loaded_mask = np.ones_like(norm_data[:, :, 3:4])
-
-         fovy = focal2fov(fov2focal(fovx, image.size[1]), image.size[0])
-         FovY = fovy
-         FovX = fovx
-
-         image = torch.Tensor(arr).permute(2, 0, 1)
-         loaded_mask = None  # torch.Tensor(loaded_mask).permute(2, 0, 1)
-
-         ### TODO: convert to torch
-         cameras.append(Camera(colmap_id=idx, R=R, T=T, FoVx=FovX, FoVy=FovY, image=image,
-                               gt_alpha_mask=loaded_mask, image_name='', uid=idx, data_device='cuda'))
-
-     return cameras
-
-
- def loadCameraPreset(traindata, presetdata):
-     cam_infos = {}
-     ## camera setting (for H, W and focal)
-     fovx = traindata["camera_angle_x"] * 1.2
-     W, H = traindata["frames"][0]["image"].size
-     # W, H = traindata["W"], traindata["H"]
-
-     for camkey in presetdata:
-         cam_infos[camkey] = []
-         for idx, frame in enumerate(presetdata[camkey]["frames"]):
-             # NeRF 'transform_matrix' is a camera-to-world transform
-             c2w = np.array(frame["transform_matrix"])
-             # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
-             c2w[:3, 1:3] *= -1
-
-             # get the world-to-camera transform and set R, T
-             w2c = np.linalg.inv(c2w)
-             R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
-             T = w2c[:3, 3]
-
-             fovy = focal2fov(fov2focal(fovx, W), H)
-             FovY = fovy
-             FovX = fovx
-
-             znear, zfar = 0.01, 100
-             world_view_transform = torch.tensor(getWorld2View2(R, T, np.array([0.0, 0.0, 0.0]), 1.0)).transpose(0, 1).cuda()
-             projection_matrix = getProjectionMatrix(znear=znear, zfar=zfar, fovX=FovX, fovY=FovY).transpose(0, 1).cuda()
-             full_proj_transform = (world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))).squeeze(0)
-
-             cam_infos[camkey].append(MiniCam(width=W, height=H, fovy=FovY, fovx=FovX, znear=znear, zfar=zfar,
-                                              world_view_transform=world_view_transform, full_proj_transform=full_proj_transform))
-
-     return cam_infos
-
-
- def readDataInfo(traindata, white_background):
-     print("Reading Training Transforms")
-
-     train_cameras = loadCamerasFromData(traindata, white_background)
-     preset_minicams = loadCameraPreset(traindata, presetdata=get_camerapaths())
-
-     # if not eval:
-     #     train_cam_infos.extend(test_cam_infos)
-     #     test_cam_infos = []
-
-     nerf_normalization = getNerfppNorm(train_cameras)
-
-     pcd = BasicPointCloud(points=traindata['pcd_points'].T, colors=traindata['pcd_colors'], normals=None)
-
-     scene_info = SceneInfo(point_cloud=pcd,
-                            train_cameras=train_cameras,
-                            test_cameras=[],
-                            preset_cameras=preset_minicams,
-                            nerf_normalization=nerf_normalization,
-                            ply_path='')
-     return scene_info
-
-
- sceneLoadTypeCallbacks = {
-     "Colmap": readColmapSceneInfo,
-     "Blender": readNerfSyntheticInfo
- }
scene/.ipynb_checkpoints/gaussian_model-checkpoint.py DELETED
@@ -1,407 +0,0 @@
- #
- # Copyright (C) 2023, Inria
- # GRAPHDECO research group, https://team.inria.fr/graphdeco
- # All rights reserved.
- #
- # This software is free for non-commercial, research and evaluation use
- # under the terms of the LICENSE.md file.
- #
- # For inquiries contact george.drettakis@inria.fr
- #
- import os
-
- import numpy as np
- from plyfile import PlyData, PlyElement
-
- import torch
- from torch import nn
-
- from simple_knn._C import distCUDA2
- from utils.general import inverse_sigmoid, get_expon_lr_func, build_rotation
- from utils.system import mkdir_p
- from utils.sh import RGB2SH
- from utils.graphics import BasicPointCloud
- from utils.general import strip_symmetric, build_scaling_rotation
-
-
- class GaussianModel:
-     def setup_functions(self):
-         def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
-             L = build_scaling_rotation(scaling_modifier * scaling, rotation)
-             actual_covariance = L @ L.transpose(1, 2)
-             symm = strip_symmetric(actual_covariance)
-             return symm
-
-         self.scaling_activation = torch.exp
-         self.scaling_inverse_activation = torch.log
-
-         self.covariance_activation = build_covariance_from_scaling_rotation
-
-         self.opacity_activation = torch.sigmoid
-         self.inverse_opacity_activation = inverse_sigmoid
-
-         self.rotation_activation = torch.nn.functional.normalize
-
-     def __init__(self, sh_degree: int):
-         self.active_sh_degree = 0
-         self.max_sh_degree = sh_degree
-         self._xyz = torch.empty(0)
-         self._features_dc = torch.empty(0)
-         self._features_rest = torch.empty(0)
-         self._scaling = torch.empty(0)
-         self._rotation = torch.empty(0)
-         self._opacity = torch.empty(0)
-         self.max_radii2D = torch.empty(0)
-         self.xyz_gradient_accum = torch.empty(0)
-         self.denom = torch.empty(0)
-         self.optimizer = None
-         self.percent_dense = 0
-         self.spatial_lr_scale = 0
-         self.setup_functions()
-
-     def capture(self):
-         return (
-             self.active_sh_degree,
-             self._xyz,
-             self._features_dc,
-             self._features_rest,
-             self._scaling,
-             self._rotation,
-             self._opacity,
-             self.max_radii2D,
-             self.xyz_gradient_accum,
-             self.denom,
-             self.optimizer.state_dict(),
-             self.spatial_lr_scale,
-         )
-
-     def restore(self, model_args, training_args):
-         (self.active_sh_degree,
-          self._xyz,
-          self._features_dc,
-          self._features_rest,
-          self._scaling,
-          self._rotation,
-          self._opacity,
-          self.max_radii2D,
-          xyz_gradient_accum,
-          denom,
-          opt_dict,
-          self.spatial_lr_scale) = model_args
-         self.training_setup(training_args)
-         self.xyz_gradient_accum = xyz_gradient_accum
-         self.denom = denom
-         self.optimizer.load_state_dict(opt_dict)
-
-     @property
-     def get_scaling(self):
-         return self.scaling_activation(self._scaling)
-
-     @property
-     def get_rotation(self):
-         return self.rotation_activation(self._rotation)
-
-     @property
-     def get_xyz(self):
-         return self._xyz
-
-     @property
-     def get_features(self):
-         features_dc = self._features_dc
-         features_rest = self._features_rest
-         return torch.cat((features_dc, features_rest), dim=1)
-
-     @property
-     def get_opacity(self):
-         return self.opacity_activation(self._opacity)
-
-     def get_covariance(self, scaling_modifier=1):
-         return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
-
-     def oneupSHdegree(self):
-         if self.active_sh_degree < self.max_sh_degree:
-             self.active_sh_degree += 1
-
-     def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):
-         self.spatial_lr_scale = spatial_lr_scale
-         fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
-         fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
-         features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
-         features[:, :3, 0] = fused_color
-         features[:, 3:, 1:] = 0.0
-
-         print("Number of points at initialisation : ", fused_point_cloud.shape[0])
-
-         dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
-         scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
-         rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
-         rots[:, 0] = 1
-
-         opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
-
-         self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
-         self._features_dc = nn.Parameter(features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True))
-         self._features_rest = nn.Parameter(features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True))
-         self._scaling = nn.Parameter(scales.requires_grad_(True))
-         self._rotation = nn.Parameter(rots.requires_grad_(True))
-         self._opacity = nn.Parameter(opacities.requires_grad_(True))
-         self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
-
-     def training_setup(self, training_args):
-         self.percent_dense = training_args.percent_dense
-         self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
-         self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
-
-         l = [
-             {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
-             {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
-             {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
-             {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
-             {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
-             {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}
-         ]
-
-         self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
-         self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
-                                                     lr_final=training_args.position_lr_final*self.spatial_lr_scale,
-                                                     lr_delay_mult=training_args.position_lr_delay_mult,
-                                                     max_steps=training_args.position_lr_max_steps)
-
-     def update_learning_rate(self, iteration):
-         ''' Learning rate scheduling per step '''
-         for param_group in self.optimizer.param_groups:
-             if param_group["name"] == "xyz":
-                 lr = self.xyz_scheduler_args(iteration)
-                 param_group['lr'] = lr
-                 return lr
-
-     def construct_list_of_attributes(self):
-         l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
-         # All channels except the 3 DC
-         for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
-             l.append('f_dc_{}'.format(i))
-         for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
-             l.append('f_rest_{}'.format(i))
-         l.append('opacity')
-         for i in range(self._scaling.shape[1]):
-             l.append('scale_{}'.format(i))
-         for i in range(self._rotation.shape[1]):
-             l.append('rot_{}'.format(i))
-         return l
-
-     def save_ply(self, filepath):
-         xyz = self._xyz.detach().cpu().numpy()
-         normals = np.zeros_like(xyz)
-         f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
-         f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
-         opacities = self._opacity.detach().cpu().numpy()
-         scale = self._scaling.detach().cpu().numpy()
-         rotation = self._rotation.detach().cpu().numpy()
-
-         dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
-
-         elements = np.empty(xyz.shape[0], dtype=dtype_full)
-         attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
-         elements[:] = list(map(tuple, attributes))
-         el = PlyElement.describe(elements, 'vertex')
-         PlyData([el]).write(filepath)
-
-     def reset_opacity(self):
-         opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
-         optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
-         self._opacity = optimizable_tensors["opacity"]
-
-     def load_ply(self, path):
-         plydata = PlyData.read(path)
-
-         xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
-                         np.asarray(plydata.elements[0]["y"]),
-                         np.asarray(plydata.elements[0]["z"])), axis=1)
-         opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
-
-         features_dc = np.zeros((xyz.shape[0], 3, 1))
-         features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
-         features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
-         features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
-
-         extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
-         extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split('_')[-1]))
-         assert len(extra_f_names) == 3*(self.max_sh_degree + 1) ** 2 - 3
-         features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
-         for idx, attr_name in enumerate(extra_f_names):
-             features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
-         # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
-         features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
-
-         scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
-         scale_names = sorted(scale_names, key=lambda x: int(x.split('_')[-1]))
-         scales = np.zeros((xyz.shape[0], len(scale_names)))
-         for idx, attr_name in enumerate(scale_names):
-             scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
-
-         rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
-         rot_names = sorted(rot_names, key=lambda x: int(x.split('_')[-1]))
-         rots = np.zeros((xyz.shape[0], len(rot_names)))
-         for idx, attr_name in enumerate(rot_names):
-             rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
-
-         self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
-         self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
-         self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
-         self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
-         self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
-         self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
-
-         self.active_sh_degree = self.max_sh_degree
-
-     def replace_tensor_to_optimizer(self, tensor, name):
-         optimizable_tensors = {}
-         for group in self.optimizer.param_groups:
-             if group["name"] == name:
-                 stored_state = self.optimizer.state.get(group['params'][0], None)
-                 stored_state["exp_avg"] = torch.zeros_like(tensor)
-                 stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
-
-                 del self.optimizer.state[group['params'][0]]
-                 group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
-                 self.optimizer.state[group['params'][0]] = stored_state
-
-                 optimizable_tensors[group["name"]] = group["params"][0]
-         return optimizable_tensors
-
-     def _prune_optimizer(self, mask):
-         optimizable_tensors = {}
-         for group in self.optimizer.param_groups:
-             stored_state = self.optimizer.state.get(group['params'][0], None)
-             if stored_state is not None:
-                 stored_state["exp_avg"] = stored_state["exp_avg"][mask]
-                 stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
-
-                 del self.optimizer.state[group['params'][0]]
-                 group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
-                 self.optimizer.state[group['params'][0]] = stored_state
-
-                 optimizable_tensors[group["name"]] = group["params"][0]
-             else:
-                 group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
-                 optimizable_tensors[group["name"]] = group["params"][0]
-         return optimizable_tensors
-
-     def prune_points(self, mask):
-         valid_points_mask = ~mask
-         optimizable_tensors = self._prune_optimizer(valid_points_mask)
-
-         self._xyz = optimizable_tensors["xyz"]
-         self._features_dc = optimizable_tensors["f_dc"]
-         self._features_rest = optimizable_tensors["f_rest"]
-         self._opacity = optimizable_tensors["opacity"]
-         self._scaling = optimizable_tensors["scaling"]
-         self._rotation = optimizable_tensors["rotation"]
-
-         self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
-
-         self.denom = self.denom[valid_points_mask]
-         self.max_radii2D = self.max_radii2D[valid_points_mask]
-
-     def cat_tensors_to_optimizer(self, tensors_dict):
-         optimizable_tensors = {}
-         for group in self.optimizer.param_groups:
-             assert len(group["params"]) == 1
-             extension_tensor = tensors_dict[group["name"]]
-             stored_state = self.optimizer.state.get(group['params'][0], None)
-             if stored_state is not None:
-                 stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
-                 stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
-
-                 del self.optimizer.state[group['params'][0]]
-                 group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
-                 self.optimizer.state[group['params'][0]] = stored_state
-
-                 optimizable_tensors[group["name"]] = group["params"][0]
-             else:
-                 group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
-                 optimizable_tensors[group["name"]] = group["params"][0]
-
-         return optimizable_tensors
-
-     def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
-         d = {"xyz": new_xyz,
-              "f_dc": new_features_dc,
-              "f_rest": new_features_rest,
-              "opacity": new_opacities,
-              "scaling": new_scaling,
-              "rotation": new_rotation}
-
-         optimizable_tensors = self.cat_tensors_to_optimizer(d)
-         self._xyz = optimizable_tensors["xyz"]
-         self._features_dc = optimizable_tensors["f_dc"]
-         self._features_rest = optimizable_tensors["f_rest"]
-         self._opacity = optimizable_tensors["opacity"]
-         self._scaling = optimizable_tensors["scaling"]
-         self._rotation = optimizable_tensors["rotation"]
-
-         self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
-         self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
-         self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
-
-     def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
-         n_init_points = self.get_xyz.shape[0]
-         # Extract points that satisfy the gradient condition
-         padded_grad = torch.zeros((n_init_points), device="cuda")
-         padded_grad[:grads.shape[0]] = grads.squeeze()
-         selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
-         selected_pts_mask = torch.logical_and(selected_pts_mask,
-                                               torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
-
-         stds = self.get_scaling[selected_pts_mask].repeat(N, 1)
-         means = torch.zeros((stds.size(0), 3), device="cuda")
-         samples = torch.normal(mean=means, std=stds)
-         rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1)
-         new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
-         new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8*N))
-         new_rotation = self._rotation[selected_pts_mask].repeat(N, 1)
-         new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1)
-         new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1)
-         new_opacity = self._opacity[selected_pts_mask].repeat(N, 1)
-
-         self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
-
-         prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
-         self.prune_points(prune_filter)
-
-     def densify_and_clone(self, grads, grad_threshold, scene_extent):
-         # Extract points that satisfy the gradient condition
-         selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
-         selected_pts_mask = torch.logical_and(selected_pts_mask,
-                                               torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
-
-         new_xyz = self._xyz[selected_pts_mask]
-         new_features_dc = self._features_dc[selected_pts_mask]
-         new_features_rest = self._features_rest[selected_pts_mask]
-         new_opacities = self._opacity[selected_pts_mask]
-         new_scaling = self._scaling[selected_pts_mask]
-         new_rotation = self._rotation[selected_pts_mask]
-
-         self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
-
-     def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
-         grads = self.xyz_gradient_accum / self.denom
-         grads[grads.isnan()] = 0.0
-
-         self.densify_and_clone(grads, max_grad, extent)
-         self.densify_and_split(grads, max_grad, extent)
-
-         prune_mask = (self.get_opacity < min_opacity).squeeze()
-         if max_screen_size:
-             big_points_vs = self.max_radii2D > max_screen_size
-             big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
-             prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
-         self.prune_points(prune_mask)
-
-         torch.cuda.empty_cache()
-
-     def add_densification_stats(self, viewspace_point_tensor, update_filter):
-         self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True)
-         self.denom[update_filter] += 1