"""Contains the definition of the RGB-only Gaussian predictor.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

import logging

import torch
from torch import nn

from sharp.models.monodepth import MonodepthWithEncodingAdaptor
from sharp.utils.gaussians import Gaussians3D

from .composer import GaussianComposer

LOGGER = logging.getLogger(__name__)


class DepthAlignment(nn.Module):
    """Depth alignment in a dedicated nn.Module.

    Wraps scale_map_estimator to perform the conditional logic in a separate torch
    module outside the forward of RGBGaussianPredictor. This module can then be
    excluded during symbolic tracing.
    """

    def __init__(self, scale_map_estimator: nn.Module | None):
        """Initialize DepthAlignment.

        Args:
            scale_map_estimator: Module to align monodepth to ground truth depth.
        """
        super().__init__()
        self.scale_map_estimator = scale_map_estimator

    def forward(
        self,
        monodepth: torch.Tensor,
        depth: torch.Tensor | None,
        depth_decoder_features: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Optionally align monodepth to ground truth with a local scale map.

        Args:
            monodepth: The predicted monodepth map to (optionally) align.
            depth: Ground truth depth to align the predicted depth to, or None to
                skip alignment.
            depth_decoder_features: The (optional) monodepth decoder features.

        Returns:
            The (optionally aligned) monodepth and the depth alignment map.
        """
        if depth is not None and self.scale_map_estimator is not None:
            depth_alignment_map = self.scale_map_estimator(
                monodepth[:, 0:1], depth, depth_decoder_features
            )
            monodepth = depth_alignment_map * monodepth
        else:
            depth_alignment_map = torch.ones_like(monodepth)
        return monodepth, depth_alignment_map
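
# NOTE: A minimal sketch (an assumption, not part of this repository's API) of how
# DepthAlignment could be excluded during symbolic tracing, as its docstring
# suggests, by treating it as a leaf module in a custom torch.fx tracer:
#
#     class _LeafDepthAlignmentTracer(torch.fx.Tracer):
#         def is_leaf_module(self, module, module_qualified_name):
#             # Keep the data-dependent branch inside DepthAlignment out of the
#             # traced graph by treating the module as an opaque call.
#             if isinstance(module, DepthAlignment):
#                 return True
#             return super().is_leaf_module(module, module_qualified_name)
#
#     # Hypothetical usage on an RGBGaussianPredictor instance `predictor`:
#     graph = _LeafDepthAlignmentTracer().trace(predictor)
#     traced = torch.fx.GraphModule(predictor, graph)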


class RGBGaussianPredictor(nn.Module):
    """Predicts 3D Gaussians from images."""

    feature_model: nn.Module

    def __init__(
        self,
        init_model: nn.Module,
        monodepth_model: MonodepthWithEncodingAdaptor,
        feature_model: nn.Module,
        prediction_head: nn.Module,
        gaussian_composer: GaussianComposer,
        scale_map_estimator: nn.Module | None,
    ) -> None:
        """Initialize RGBGaussianPredictor.

        Args:
            init_model: A model mapping image and depth to base values.
            monodepth_model: The monodepth model with intermediate features to use.
            feature_model: The image2image model to predict Gaussians from.
            prediction_head: Head to decode image features.
            gaussian_composer: Module to compose the final prediction from deltas
                and base values.
            scale_map_estimator: Module to align monodepth to ground truth depth.

        Note:
        ----
            When monodepth_model is trainable, using local depth alignment can
            result in the monodepth model losing its ability to predict shapes. It
            is hence recommended to deactivate the corresponding flag.
        """
        super().__init__()
        self.init_model = init_model
        self.feature_model = feature_model
        self.monodepth_model = monodepth_model
        self.prediction_head = prediction_head
        self.gaussian_composer = gaussian_composer
        self.depth_alignment = DepthAlignment(scale_map_estimator)

    def forward(
        self,
        image: torch.Tensor,
        disparity_factor: torch.Tensor,
        depth: torch.Tensor | None = None,
    ) -> Gaussians3D:
        """Predict 3D Gaussians.

        Args:
            image: The image to process.
            disparity_factor: Factor to convert depth to disparities.
            depth: Ground truth depth to align predicted depth to.

        Returns:
            The predicted 3D Gaussians.

        Note:
        ----
            During training, it is recommended to feed an additional ground truth
            depth map to the network to align the predicted depth to. During
            inference, it is recommended to pass depth=None and to compute depth
            from the model's monodepth_disparity output instead.
        """
        monodepth_output = self.monodepth_model(image)
        monodepth_disparity = monodepth_output.disparity

        # Convert predicted disparity to depth; the clamp guards against division
        # by (near-)zero disparities.
        disparity_factor = disparity_factor[:, None, None, None]
        monodepth = disparity_factor / monodepth_disparity.clamp(min=1e-4, max=1e4)

        # Optionally align the predicted monodepth to the ground truth depth
        # (identity when depth is None or no scale_map_estimator is configured).
        monodepth, _ = self.depth_alignment(
            monodepth,
            depth,
            monodepth_output.decoder_features,
        )

        # Predict per-pixel deltas from image features and compose them with the
        # Gaussian base values into the final prediction.
        init_output = self.init_model(image, monodepth)
        image_features = self.feature_model(
            init_output.feature_input, encodings=monodepth_output.output_features
        )
        delta_values = self.prediction_head(image_features)
        gaussians = self.gaussian_composer(
            delta=delta_values,
            base_values=init_output.gaussian_base_values,
            global_scale=init_output.global_scale,
        )
        return gaussians

    def internal_resolution(self) -> int:
        """Internal resolution."""
        return self.monodepth_model.internal_resolution()

    @property
    def output_resolution(self) -> int:
        """Output resolution of Gaussians."""
        return self.internal_resolution() // 2
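

# NOTE: A minimal usage sketch (assumptions, not part of this module): the submodule
# arguments below are hypothetical placeholders for whatever the surrounding
# codebase provides, and the tensor shapes are assumed, not prescribed here.
#
#     predictor = RGBGaussianPredictor(
#         init_model=my_init_model,                # hypothetical
#         monodepth_model=my_monodepth_model,      # hypothetical
#         feature_model=my_feature_model,          # hypothetical
#         prediction_head=my_prediction_head,      # hypothetical
#         gaussian_composer=my_gaussian_composer,  # hypothetical
#         scale_map_estimator=None,                # skip local depth alignment
#     )
#     image = torch.rand(1, 3, 512, 512)           # assumed (B, 3, H, W) layout
#     disparity_factor = torch.ones(1)             # one factor per batch element
#     # Training: pass ground truth depth so the predicted depth can be aligned.
#     # Inference: pass depth=None (the default).
#     gaussians = predictor(image, disparity_factor, depth=None)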