|
"""Copyright (C) 2024 Apple Inc. All Rights Reserved. |
|
|
|
Dense Prediction Transformer Decoder architecture. |
|
|
|
Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413 |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
from typing import Iterable |
|
|
|
import torch |
|
from torch import nn |
|
|
|
|
|
class MultiresConvDecoder(nn.Module):
    """Decoder for multi-resolution encodings."""

    def __init__(
        self,
        dims_encoder: Iterable[int],
        dim_decoder: int,
    ):
        """Initialize multiresolution convolutional decoder.

        Args:
        ----
            dims_encoder: Expected dims at each level from the encoder.
            dim_decoder: Dim of decoder features.

        """
        super().__init__()
        self.dims_encoder = list(dims_encoder)
        self.dim_decoder = dim_decoder
        self.dim_out = dim_decoder

        num_encoders = len(self.dims_encoder)

        # Level 0 (highest resolution): a 1x1 projection is only needed when the
        # encoder dim differs from the decoder dim; otherwise pass through.
        conv0 = (
            nn.Conv2d(self.dims_encoder[0], dim_decoder, kernel_size=1, bias=False)
            if self.dims_encoder[0] != dim_decoder
            else nn.Identity()
        )

        # Lower-resolution levels use 3x3 convs to project to the decoder dim.
        convs = [conv0]
        for i in range(1, num_encoders):
            convs.append(
                nn.Conv2d(
                    self.dims_encoder[i],
                    dim_decoder,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                )
            )

        self.convs = nn.ModuleList(convs)

        # One fusion block per level; all but level 0 upsample (deconv) so the
        # fused features match the next-higher level's resolution.
        fusions = []
        for i in range(num_encoders):
            fusions.append(
                FeatureFusionBlock2d(
                    num_features=dim_decoder,
                    deconv=(i != 0),
                    batch_norm=False,
                )
            )
        self.fusions = nn.ModuleList(fusions)

    def forward(
        self, encodings: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode the multi-resolution encodings.

        Args:
        ----
            encodings: One feature map per encoder level, ordered from the
                highest resolution (level 0) to the lowest.

        Returns:
        -------
            Tuple of (fused features at the level-0 resolution, projected
            features of the lowest-resolution level before fusion).

        Raises:
        ------
            ValueError: If the number of encodings does not match the number
                of encoder levels configured at construction.

        """
        num_levels = len(encodings)
        num_encoders = len(self.dims_encoder)

        if num_levels != num_encoders:
            # BUGFIX: report num_encoders (the value actually compared against),
            # not num_encoders + 1.
            raise ValueError(
                f"Got encoder output levels={num_levels}, expected levels={num_encoders}."
            )

        # Start from the lowest-resolution level and fuse upwards to level 0.
        features = self.convs[-1](encodings[-1])
        lowres_features = features
        features = self.fusions[-1](features)
        for i in range(num_levels - 2, -1, -1):
            features_i = self.convs[i](encodings[i])
            features = self.fusions[i](features, features_i)
        return features, lowres_features
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Residual block with a caller-supplied branch and optional shortcut.

    A generic residual block in the spirit of
    He et al. - Identity Mappings in Deep Residual Networks (2016),
    https://arxiv.org/abs/1603.05027; the layers on both paths are injected,
    so the block can be customized via factory functions.
    """

    def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
        """Store the residual branch and the optional shortcut module."""
        super().__init__()
        # Module applied on the residual path; its output is added to the input.
        self.residual = residual
        # Optional transform for the skip path; None means identity skip.
        self.shortcut = shortcut

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return shortcut(x) + residual(x), with an identity skip when no shortcut is set."""
        update = self.residual(x)
        skip = x if self.shortcut is None else self.shortcut(x)
        return skip + update
|
|
|
|
|
class FeatureFusionBlock2d(nn.Module):
    """Feature fusion for DPT."""

    def __init__(
        self,
        num_features: int,
        deconv: bool = False,
        batch_norm: bool = False,
    ):
        """Initialize feature fusion block.

        Args:
        ----
            num_features: Input and output dimensions.
            deconv: Whether to use deconv before the final output conv.
            batch_norm: Whether to use batch normalization in resnet blocks.

        """
        super().__init__()

        # Two residual stages: resnet1 preprocesses the lateral input before
        # the skip addition, resnet2 refines the fused features.
        self.resnet1 = self._residual_block(num_features, batch_norm)
        self.resnet2 = self._residual_block(num_features, batch_norm)

        self.use_deconv = deconv
        if deconv:
            # Learned 2x upsampling applied before the output projection.
            self.deconv = nn.ConvTranspose2d(
                in_channels=num_features,
                out_channels=num_features,
                kernel_size=2,
                stride=2,
                padding=0,
                bias=False,
            )

        # Final 1x1 projection at constant width.
        self.out_conv = nn.Conv2d(
            num_features,
            num_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        # Quantization-aware wrapper for the skip addition.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
        """Fuse x1 (if given) into x0, refine, and optionally upsample."""
        out = x0
        if x1 is not None:
            out = self.skip_add.add(out, self.resnet1(x1))
        out = self.resnet2(out)
        if self.use_deconv:
            out = self.deconv(out)
        return self.out_conv(out)

    @staticmethod
    def _residual_block(num_features: int, batch_norm: bool):
        """Create a residual block of two (ReLU -> 3x3 conv [-> BN]) stages."""
        layers: list[nn.Module] = []
        for _ in range(2):
            layers.append(nn.ReLU(False))
            layers.append(
                nn.Conv2d(
                    num_features,
                    num_features,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    # BatchNorm supplies the bias when enabled.
                    bias=not batch_norm,
                )
            )
            if batch_norm:
                layers.append(nn.BatchNorm2d(num_features))
        return ResidualBlock(nn.Sequential(*layers))
|
|