File size: 11,619 Bytes
e2ebf5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
import sys
sys.path.append('./gaussian_splatting')
import os
import torch
import plotly.graph_objs as go
from sugar.gaussian_splatting.scene.gaussian_model import GaussianModel
from sugar.gaussian_splatting.gaussian_renderer import render as gs_render
from sugar.gaussian_splatting.scene.dataset_readers import fetchPly
from sugar.sugar_utils.spherical_harmonics import SH2RGB
from sugar.sugar_scene.cameras import CamerasWrapper, load_gs_cameras


class ModelParams:
    """Default model settings for a Gaussian Splatting scene.

    The attribute set follows the original 3D Gaussian Splatting code base:
    https://github.com/graphdeco-inria/gaussian-splatting
    Every attribute is a plain value that callers may overwrite after construction.
    """
    def __init__(self):
        # Spherical-harmonics degree; GaussianModel is built from this value.
        self.sh_degree: int = 3
        # Dataset / output locations (empty until set by the caller).
        self.source_path: str = ""
        self.model_path: str = ""
        # Name of the image sub-folder.
        self.images: str = "images"
        self.resolution: int = -1
        self.white_background: bool = False
        self.data_device: str = "cuda"
        self.eval: bool = False
    
        
class PipelineParams:
    """Default rendering-pipeline switches for Gaussian Splatting.

    The attribute set follows the original 3D Gaussian Splatting code base:
    https://github.com/graphdeco-inria/gaussian-splatting
    All three flags start disabled; instances are passed to the renderer as-is.
    """
    def __init__(self):
        self.convert_SHs_python: bool = False
        self.compute_cov3D_python: bool = False
        self.debug: bool = False


class OptimizationParams:
    """Default optimization hyper-parameters for Gaussian Splatting.

    The attribute names and values follow the original 3D Gaussian Splatting code base:
    https://github.com/graphdeco-inria/gaussian-splatting
    Every attribute is a plain value that callers may overwrite after construction.
    """
    def __init__(self):
        self.iterations: int = 30_000
        # Position learning-rate settings.
        self.position_lr_init: float = 1.6e-4
        self.position_lr_final: float = 1.6e-6
        self.position_lr_delay_mult: float = 1e-2
        self.position_lr_max_steps: int = 30_000
        # Per-attribute learning rates.
        self.feature_lr: float = 2.5e-3
        self.opacity_lr: float = 5e-2
        self.scaling_lr: float = 5e-3
        self.rotation_lr: float = 1e-3
        self.percent_dense: float = 1e-2
        self.lambda_dssim: float = 0.2
        # Densification / opacity-reset scheduling (in iterations).
        self.densification_interval: int = 100
        self.opacity_reset_interval: int = 3000
        self.densify_from_iter: int = 500
        self.densify_until_iter: int = 15_000
        self.densify_grad_threshold: float = 2e-4


class GaussianSplattingWrapper:
    """Class to wrap original Gaussian Splatting models and facilitate both usage and integration with PyTorch3D.

    On construction, the scene cameras are loaded (optionally split into a training and an
    evaluation set) and the Gaussians saved at
    ``<output_path>/point_cloud/iteration_<iteration_to_load>/point_cloud.ply`` are loaded.
    """
    def __init__(self, 
                 source_path: str,
                 output_path: str,
                 iteration_to_load:int=30_000,
                 model_params: ModelParams=None,
                 pipeline_params: PipelineParams=None,
                 opt_params: OptimizationParams=None,
                 load_gt_images=True,
                 eval_split=False,
                 eval_split_interval=8,
                 ) -> None:
        """Initialize the Gaussian Splatting model wrapper.
        
        Args:
            source_path (str): Path to the directory containing the source images.
            output_path (str): Path to the directory containing the output of the Gaussian Splatting optimization.
            iteration_to_load (int, optional): Checkpoint to load. Should be 7000 or 30_000. Defaults to 30_000.
            model_params (ModelParams, optional): Model parameters. Defaults to None (a fresh ModelParams()).
            pipeline_params (PipelineParams, optional): Pipeline parameters. Defaults to None (a fresh PipelineParams()).
            opt_params (OptimizationParams, optional): Optimization parameters. Defaults to None (a fresh OptimizationParams()).
            load_gt_images (bool, optional): If True, will load all GT images in the source folder.
                Useful for evaluating the model, but loading can take a few minutes. Defaults to True.
            eval_split (bool, optional): If True, will split images and cameras into a training set and an evaluation set. 
                Defaults to False.
            eval_split_interval (int, optional): When eval_split is True, every camera whose index is a
                multiple of eval_split_interval (starting with index 0) goes to the evaluation set.
                Must be >= 1. Defaults to 8 (following standard practice).

        Raises:
            ValueError: If eval_split is True and eval_split_interval is smaller than 1.
        """
        # Validate before the potentially slow camera/image loading below; an interval of 0
        # would otherwise surface later as a ZeroDivisionError in the modulo.
        if eval_split and eval_split_interval < 1:
            raise ValueError(
                f"eval_split_interval must be >= 1 when eval_split is True, got {eval_split_interval}."
            )

        self.source_path = source_path
        self.output_path = output_path
        self.loaded_iteration = iteration_to_load
        
        self.model_params = model_params if model_params is not None else ModelParams()
        self.pipeline_params = pipeline_params if pipeline_params is not None else PipelineParams()
        self.opt_params = opt_params if opt_params is not None else OptimizationParams()
        
        # Value of the degree-0 real spherical harmonic, 1 / (2 * sqrt(pi)).
        # Not used inside this class; kept as a public-ish attribute for existing callers.
        self._C0 = 0.28209479177387814
        
        cam_list = load_gs_cameras(
            source_path=source_path,
            gs_output_path=output_path,
            load_gt_images=load_gt_images,
            )
        
        if eval_split:
            self.cam_list = []
            self.test_cam_list = []
            for i, cam in enumerate(cam_list):
                if i % eval_split_interval == 0:
                    self.test_cam_list.append(cam)
                else:
                    self.cam_list.append(cam)
            self.test_cameras = CamerasWrapper(self.test_cam_list)
        else:
            self.cam_list = cam_list
            self.test_cam_list = None
            self.test_cameras = None
            
        self.training_cameras = CamerasWrapper(self.cam_list)
            
        self.gaussians = GaussianModel(self.model_params.sh_degree)
        self.gaussians.load_ply(
            os.path.join(
                output_path,
                "point_cloud",
                f"iteration_{iteration_to_load}",
                "point_cloud.ply",
                )
            )

    @property
    def device(self):
        """Device on which the Gaussian centers (``gaussians.get_xyz``) are stored."""
        return self.gaussians.get_xyz.device
    
    @property
    def image_height(self):
        """Image height of the first training camera."""
        return self.cam_list[0].image_height
    
    @property
    def image_width(self):
        """Image width of the first training camera."""
        return self.cam_list[0].image_width
    
    def render_image(
        self,
        nerf_cameras:CamerasWrapper=None, 
        camera_indices:int=0,
        return_whole_package=False):
        """Render an image with Gaussian Splatting rasterizer.

        Args:
            nerf_cameras (CamerasWrapper, optional): Set of cameras. 
                If None, uses the training cameras, but can be any set of cameras. Defaults to None.
            camera_indices (int, optional): Index of the camera to render in the set of cameras. 
                Defaults to 0.
            return_whole_package (bool, optional): If True, returns the whole output package 
                as computed in the original rasterizer from 3D Gaussian Splatting paper. Defaults to False.

        Returns:
            Tensor or Dict: The rendered image with its first dimension moved last
                (presumably CHW -> HWC), or the whole output package.
        """
        if nerf_cameras is None:
            gs_cameras = self.cam_list
        else:
            gs_cameras = nerf_cameras.gs_cameras
        
        camera = gs_cameras[camera_indices]
        # Black background, allocated on the same device as the Gaussians rather than a
        # hard-coded 'cuda' so the tensors handed to the renderer agree on device.
        render_pkg = gs_render(camera, self.gaussians, 
                            self.pipeline_params, 
                            bg_color=torch.zeros(3, device=self.device))
        
        if return_whole_package:
            return render_pkg
        return render_pkg["render"].permute(1, 2, 0)

    @staticmethod
    def _gt_image_from(cam_list, camera_indices, to_cuda):
        """Fetch ``original_image`` of one camera, optionally on GPU, with its first dimension moved last."""
        gt_image = cam_list[camera_indices].original_image
        if to_cuda:
            gt_image = gt_image.cuda()
        return gt_image.permute(1, 2, 0)
    
    def get_gt_image(self, camera_indices:int, to_cuda=False):
        """Returns the ground truth image corresponding to the training camera at the given index.

        Args:
            camera_indices (int): Index of the camera in the set of training cameras.
            to_cuda (bool, optional): If True, moves the image to GPU. Defaults to False.

        Returns:
            Tensor: The ground truth image, with its first dimension moved last (presumably CHW -> HWC).
        """
        return self._gt_image_from(self.cam_list, camera_indices, to_cuda)
    
    def get_test_gt_image(self, camera_indices:int, to_cuda=False):
        """Returns the ground truth image corresponding to the test camera at the given index.
        
        Args:
            camera_indices (int): Index of the camera in the set of test cameras.
            to_cuda (bool, optional): If True, moves the image to GPU. Defaults to False.
        
        Returns:
            Tensor: The ground truth image, with its first dimension moved last (presumably CHW -> HWC).

        Raises:
            RuntimeError: If the wrapper was created with eval_split=False (no test cameras exist).
        """
        if self.test_cam_list is None:
            raise RuntimeError(
                "No test cameras available: the wrapper was created with eval_split=False."
            )
        return self._gt_image_from(self.test_cam_list, camera_indices, to_cuda)
    
    def downscale_output_resolution(self, downscale_factor):
        """Downscale the output resolution of the Gaussian Splatting model.

        Only ``self.training_cameras`` is rescaled; ``self.test_cameras`` (if any) is left unchanged.

        Args:
            downscale_factor (float): Factor by which to downscale the resolution.
        """
        self.training_cameras.rescale_output_resolution(1.0 / downscale_factor)
    
    def generate_point_cloud(self):
        """Generate a point cloud from the Gaussian Splatting model.

        Returns:
            (Tensor, Tensor): The Gaussian centers (``gaussians.get_xyz``) and their colors,
                obtained with SH2RGB from the first spherical-harmonics coefficient of each Gaussian.
        """
        with torch.no_grad():
            points = self.gaussians.get_xyz
            colors = SH2RGB(self.gaussians.get_features[:, 0])
            
        return points, colors
    
    def plot_point_cloud(
        self,
        points=None,
        colors=None,
        n_points_to_plot: int = 50000,
        width=1000,
        height=500,
    ):
        """Plot the generated 3D point cloud with plotly.

        Args:
            points (Tensor, optional): Point positions, indexed along the first dimension and with
                x, y, z in columns 0, 1, 2. If None, points and colors are taken from
                generate_point_cloud() and the ``colors`` argument is ignored. Defaults to None.
            colors (Tensor, optional): Per-point colors aligned with ``points``. Required when
                ``points`` is given. Defaults to None.
            n_points_to_plot (int, optional): Maximum number of randomly sampled points to plot.
                Defaults to 50000.
            width (int, optional): Figure width in pixels. Defaults to 1000.
            height (int, optional): Figure height in pixels. Defaults to 500.

        Raises:
            ValueError: If ``points`` is given but ``colors`` is None.

        Returns:
            go.Figure: The plotly figure.
        """
        with torch.no_grad():
            if points is None:
                points, colors = self.generate_point_cloud()
            elif colors is None:
                raise ValueError("`colors` must be provided when `points` is provided.")

            # Random subset; the permutation lives on the points' device to avoid a cross-device index.
            points_idx = torch.randperm(points.shape[0], device=points.device)[:n_points_to_plot]
            points_to_plot = points[points_idx].cpu()
            colors_to_plot = colors[points_idx].cpu()

            trace = go.Scatter3d(
                x=points_to_plot[:, 0],
                y=points_to_plot[:, 1],
                z=points_to_plot[:, 2],
                mode="markers",
                marker=dict(
                    size=3,
                    color=colors_to_plot,
                ),
            )
            layout = go.Layout(
                scene=dict(bgcolor="white", aspectmode="data"),
                template="none",
                width=width,
                height=height,
            )
            return go.Figure(data=[trace], layout=layout)