|
|
"""Contains decoder head for direct prediction of delta values. |
|
|
|
|
|
For licensing see accompanying LICENSE file. |
|
|
Copyright (C) 2025 Apple Inc. All Rights Reserved. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
from .gaussian_decoder import ImageFeatures |
|
|
|
|
|
|
|
|
class DirectPredictionHead(nn.Module):
    """Decode image features into per-layer delta values with 1x1 convolutions.

    Two independent 1x1 convolution heads are used: one reads the geometry
    features and one reads the texture features. Both heads start with all
    weights and biases at zero, so a freshly constructed module predicts
    all-zero deltas for any input.
    """

    # Channel budget per predicted layer. The geometry head produces the first
    # 3 channels and the texture head the remaining 11, for 14 in total.
    # NOTE(review): presumably 14 is the number of attributes of one 3D
    # Gaussian; the meaning of individual channels is not defined in this
    # module -- confirm against the consumer of the deltas.
    _NUM_DELTA_CHANNELS = 14
    _NUM_GEOMETRY_CHANNELS = 3
    _NUM_TEXTURE_CHANNELS = _NUM_DELTA_CHANNELS - _NUM_GEOMETRY_CHANNELS

    def __init__(self, feature_dim: int, num_layers: int) -> None:
        """Initialize DirectPredictionHead.

        Args:
            feature_dim: Number of input feature channels fed to each head.
            num_layers: The number of layers of Gaussians to predict.
        """
        super().__init__()
        self.num_layers = num_layers

        self.geometry_prediction_head = self._zero_initialized_conv(
            feature_dim, self._NUM_GEOMETRY_CHANNELS * num_layers
        )
        self.texture_prediction_head = self._zero_initialized_conv(
            feature_dim, self._NUM_TEXTURE_CHANNELS * num_layers
        )

    @staticmethod
    def _zero_initialized_conv(in_channels: int, out_channels: int) -> nn.Conv2d:
        """Build a 1x1 convolution whose weight and bias are all zeros."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
        nn.init.zeros_(conv.weight)
        # ``bias`` is typed Optional on Conv2d; it is always present here
        # because the layer is built with ``bias=True``.
        if conv.bias is not None:
            nn.init.zeros_(conv.bias)
        return conv

    def forward(self, image_features: ImageFeatures) -> torch.Tensor:
        """Predict deltas for 3D Gaussians.

        Args:
            image_features: Image features from decoder. Its
                ``geometry_features`` and ``texture_features`` attributes are
                read; both must be batched ``(B, feature_dim, H, W)`` tensors
                because the channel axis is taken to be dimension 1.

        Returns:
            The predicted deltas for Gaussian attributes, of shape
            ``(B, 14, num_layers, H, W)``: the 3 geometry channels come first
            along dimension 1, followed by the 11 texture channels.
        """
        geometry_deltas = self.geometry_prediction_head(image_features.geometry_features)
        texture_deltas = self.texture_prediction_head(image_features.texture_features)

        # Split the flat channel axis into (attribute, layer). The layer index
        # varies fastest, i.e. flat channel ``c`` maps to
        # ``(c // num_layers, c % num_layers)``.
        geometry_deltas = geometry_deltas.unflatten(
            1, (self._NUM_GEOMETRY_CHANNELS, self.num_layers)
        )
        texture_deltas = texture_deltas.unflatten(
            1, (self._NUM_TEXTURE_CHANNELS, self.num_layers)
        )

        # Stack along the attribute axis: geometry first, then texture.
        return torch.cat([geometry_deltas, texture_deltas], dim=1)
|
|
|