"""Contains Dense Transformer Prediction architecture. |
|
|
|
|
|
Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413 |
|
|
|
|
|
For licensing see accompanying LICENSE file. |
|
|
Copyright (C) 2025 Apple Inc. All Rights Reserved. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations

import torch
import torch.nn as nn

from sharp.models.presets import (
    MONODEPTH_ENCODER_DIMS_MAP,
    MONODEPTH_HOOK_IDS_MAP,
    ViTPreset,
)

from .base_encoder import BaseEncoder
from .spn_encoder import SlidingPyramidNetwork
from .vit_encoder import create_vit


def create_monodepth_encoder(
    patch_encoder_preset: ViTPreset,
    image_encoder_preset: ViTPreset,
    use_patch_overlap: bool = True,
    last_encoder: int = 256,
) -> SlidingPyramidNetwork:
    """Creates a monodepth encoder as a SlidingPyramidNetwork (SPN).

    Args:
        patch_encoder_preset: The preset patch encoder architecture in SPN.
        image_encoder_preset: The preset image encoder architecture in SPN.
        use_patch_overlap: Whether to use overlap between patches in SPN.
        last_encoder: Number of features of the last encoder level.
    """
    # Per-level feature dimensions: the last encoder dimension followed by the
    # preset dimensions for the chosen patch encoder.
    dims_encoder = [last_encoder] + MONODEPTH_ENCODER_DIMS_MAP[patch_encoder_preset]
    patch_encoder_block_ids = MONODEPTH_HOOK_IDS_MAP[patch_encoder_preset]

    # The patch encoder exposes intermediate activations at the preset hook
    # ids; the image encoder contributes only its final features.
    patch_encoder = create_vit(
        preset=patch_encoder_preset,
        intermediate_features_ids=patch_encoder_block_ids,
    )
    image_encoder = create_vit(
        preset=image_encoder_preset,
        intermediate_features_ids=None,
    )

    encoder = SlidingPyramidNetwork(
        dims_encoder=dims_encoder,
        patch_encoder=patch_encoder,
        image_encoder=image_encoder,
        use_patch_overlap=use_patch_overlap,
    )

    return encoder
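

# Usage sketch (illustrative only): `some_preset` stands in for a concrete
# ViTPreset value from sharp.models.presets and is not defined here.
#
#     encoder = create_monodepth_encoder(
#         patch_encoder_preset=some_preset,
#         image_encoder_preset=some_preset,
#     )
#     encodings = encoder(images)  # per-level features; see
#                                  # SlidingPyramidNetwork for the exact output

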
class ProjectionModule(nn.Module):
    """Applies a 1x1 convolutional projection to each feature level."""

    def __init__(self, dims_in: list[int], dims_out: list[int]) -> None:
        """Initialize projection module."""
        super().__init__()
        if len(dims_in) != len(dims_out):
            raise ValueError("Length of dims_in must be the same as length of dims_out.")
        # One 1x1 convolution per feature level, mapping dim_in to dim_out.
        self.convs = nn.ModuleList(
            [nn.Conv2d(dim_in, dim_out, 1) for dim_in, dim_out in zip(dims_in, dims_out)]
        )

    def forward(self, encodings: list[torch.Tensor]) -> list[torch.Tensor]:
        """Apply projection module."""
        if len(encodings) != len(self.convs):
            raise ValueError("Number of encodings must be equal to number of projections.")
        return [conv(encoding) for conv, encoding in zip(self.convs, encodings)]
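

# Minimal sketch of ProjectionModule (shapes illustrative): each level keeps
# its spatial size; only the channel dimension changes (1x1 convolution).
#
#     proj = ProjectionModule(dims_in=[256, 512], dims_out=[128, 128])
#     feats = [torch.randn(1, 256, 64, 64), torch.randn(1, 512, 32, 32)]
#     out = proj(feats)  # shapes: (1, 128, 64, 64) and (1, 128, 32, 32)

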
class MonodepthFeatureEncoder(BaseEncoder):
    """A wrapper around a monodepth network to extract multi-resolution features."""

    def __init__(
        self,
        monodepth_encoder: SlidingPyramidNetwork,
        output_dims: list[int] | None = None,
        freeze_projection: bool = False,
    ) -> None:
        """Initialize MonodepthFeatureEncoder."""
        super().__init__()

        self.encoder = monodepth_encoder

        monodepth_dims = self.encoder.dims_encoder

        if output_dims is not None:
            if len(output_dims) != len(monodepth_dims):
                raise ValueError(
                    "When set, the number of output dimensions must equal the "
                    "number of encoder dimensions of the monodepth model "
                    f"({len(monodepth_dims)})."
                )

            # Project each encoder level to the requested output dimension.
            self.projection = ProjectionModule(monodepth_dims, output_dims)
            self.output_dims = output_dims
        else:
            # No projection requested; pass the encoder features through as-is.
            self.projection = nn.Identity()
            self.output_dims = monodepth_dims

        if freeze_projection:
            self.projection.requires_grad_(False)

    def forward(self, input_features: torch.Tensor) -> list[torch.Tensor]:
        """Extract multi-resolution features."""
        # The monodepth encoder consumes only the first three (RGB) channels.
        encodings = self.encoder(input_features[:, :3].contiguous())
        return self.projection(encodings)

    def internal_resolution(self) -> int:
        """Internal resolution of the encoder."""
        return self.encoder.internal_resolution()
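

# Sketch of the intended composition (illustrative; `some_preset` is a
# placeholder ViTPreset, and output_dims, when set, must have one entry per
# encoder level, i.e. len(spn.dims_encoder) entries):
#
#     spn = create_monodepth_encoder(
#         patch_encoder_preset=some_preset,
#         image_encoder_preset=some_preset,
#     )
#     feature_encoder = MonodepthFeatureEncoder(spn, output_dims=None)
#     features = feature_encoder(images)  # images: (B, C>=3, H, W); only the
#                                         # first three (RGB) channels are used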