# ml-sharp / src/sharp/models/predictor.py — commit c20d7cc ("Initial commit")
"""Contains definition of RGB-only gaussian predictor.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
import logging
import torch
from torch import nn
from sharp.models.monodepth import MonodepthWithEncodingAdaptor
from sharp.utils.gaussians import Gaussians3D
from .composer import GaussianComposer
LOGGER = logging.getLogger(__name__)
class DepthAlignment(nn.Module):
"""Depth alignment in a dedicated nn.Module.
Wrap scale_map_estimator to perform the conditional logic in a separated torch
module outside the forward of RGBGaussianPredictor. This module can be then
excluded during symbolic tracing.
"""
def __init__(self, scale_map_estimator: nn.Module | None):
"""Initialize DepthAlignmentWrapper.
Args:
scale_map_estimator: Module to align monodepth to ground truth depth.
"""
super().__init__()
self.scale_map_estimator = scale_map_estimator
def forward(
self,
monodepth: torch.Tensor,
depth: torch.Tensor,
depth_decoder_features: torch.Tensor | None = None,
):
"""Optionally align monodepth to ground truth with a local scale map.
Args:
monodepth: The monodepth model with intermediate features to use.
depth: Ground truth depth to align predicted depth to.
depth_decoder_features: The (optional) monodepth decoder features.
"""
if depth is not None and self.scale_map_estimator is not None:
depth_alignment_map = self.scale_map_estimator(
monodepth[:, 0:1], depth, depth_decoder_features
)
monodepth = depth_alignment_map * monodepth
else:
# Some losses rely on the presence of an alignment map.
# We ensure that they can be computed by creating a fake alignment map.
depth_alignment_map = torch.ones_like(monodepth)
return monodepth, depth_alignment_map
class RGBGaussianPredictor(nn.Module):
    """Predicts 3D Gaussians from images."""

    # Declared explicitly so tracing / static type checkers see the attribute.
    feature_model: nn.Module

    def __init__(
        self,
        init_model: nn.Module,
        monodepth_model: MonodepthWithEncodingAdaptor,
        feature_model: nn.Module,
        prediction_head: nn.Module,
        gaussian_composer: GaussianComposer,
        scale_map_estimator: nn.Module | None,
    ) -> None:
        """Initialize RGBGaussianPredictor.

        Args:
            init_model: A model mapping image and depth to base values.
            monodepth_model: The monodepth model with intermediate features to use.
            feature_model: The image2image model to predict Gaussians from.
            prediction_head: Head to decode image features.
            gaussian_composer: Module to compose final prediction from deltas and
                base values.
            scale_map_estimator: Module to align monodepth to ground truth depth.

        Note:
        ----
            When monodepth_model is trainable, using local depth alignment can
            result in the monodepth model losing its ability to predict shapes.
            It is hence recommended to deactivate the corresponding flag.
        """
        super().__init__()
        self.init_model = init_model
        self.feature_model = feature_model
        self.monodepth_model = monodepth_model
        self.prediction_head = prediction_head
        self.gaussian_composer = gaussian_composer
        # The conditional alignment logic lives in its own sub-module so it can
        # be excluded from symbolic tracing (see DepthAlignment).
        self.depth_alignment = DepthAlignment(scale_map_estimator)

    def forward(
        self,
        image: torch.Tensor,
        disparity_factor: torch.Tensor,
        depth: torch.Tensor | None = None,
    ) -> Gaussians3D:
        """Predict 3D Gaussians.

        Args:
            image: The image to process.
            disparity_factor: Factor to convert depth to disparities.
            depth: Ground truth depth to align predicted depth to.

        Returns:
            The predicted 3D Gaussians.

        Note:
        ----
            During training, it is recommended to feed an additional ground
            truth depth map to the network to align the predicted depth to.
            During inference, it is recommended to use depth=None and use the
            monodepth_disparity output from the model instead to compute depth.
        """
        # Estimate depth and align to ground truth (if available).
        monodepth_output = self.monodepth_model(image)
        monodepth_disparity = monodepth_output.disparity
        # Broadcast the per-sample factor over (channel, height, width).
        disparity_factor = disparity_factor[:, None, None, None]
        # Clamp disparity to avoid division by (near-)zero and overflow.
        monodepth = disparity_factor / monodepth_disparity.clamp(min=1e-4, max=1e4)
        # In the model we apply additional alignment to provided ground truth depth
        # as well as additional normalization.
        #
        # The overall graph looks as follows:
        #
        #    monodepth   depth       # Both monodepth and depth are metric here.
        #        |         |
        #        +----+----+
        #             |
        #     +-------+--------+     # Optionally align monodepth to ground truth
        #     | depth_alignment|     # with a local scale map.
        #     +-------+--------+
        #             |
        #             v
        #     monodepth (aligned)    # Monodepth is now aligned to ground truth.
        #             |
        #        +----+-----+        # Normalize depth and compute base gaussians
        #        |init_model|        # in these normalized coordinates.
        #        +----+-----+
        #             |
        #             v
        #   +------ init_output      # Init_output consists of features, base
        #   |         |              # gaussians and a global scale.
        #   |  +------+-----+
        #   |  |main network|        # Compute delta values to base gaussians.
        #   |  +------+-----+
        #   |         |
        #   |         v
        #   |   delta_values         # The delta values are computed with normalized depth.
        #   |         |
        #   |  +------+----------+
        #   +->|gaussian_composer|   # Add delta to base values and unscale gaussians.
        #      +------+----------+
        #             |
        #             v
        #         gaussians          # The final Gaussians are metric again.
        #
        # The logic to decide whether to align monodepth to the ground truth is
        # wrapped in a submodule 'DepthAlignment' to facilitate the symbolic
        # tracing of the predictor. This way, the depth alignment submodule
        # containing the conditional logic can be excluded during the tracing
        # and the graph of the predictor is static.
        monodepth, _ = self.depth_alignment(
            monodepth,
            depth,
            monodepth_output.decoder_features,
        )
        init_output = self.init_model(image, monodepth)
        image_features = self.feature_model(
            init_output.feature_input, encodings=monodepth_output.output_features
        )
        delta_values = self.prediction_head(image_features)
        gaussians = self.gaussian_composer(
            delta=delta_values,
            base_values=init_output.gaussian_base_values,
            global_scale=init_output.global_scale,
        )
        return gaussians

    def internal_resolution(self) -> int:
        """Internal resolution (delegates to the monodepth model)."""
        return self.monodepth_model.internal_resolution()

    @property
    def output_resolution(self) -> int:
        """Output resolution of Gaussians (half the internal resolution)."""
        return self.internal_resolution() // 2