from __future__ import annotations

import math
from typing import Iterable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class DepthProEncoder(nn.Module):
    """DepthPro Encoder.

    An encoder aimed at creating multi-resolution encodings from Vision Transformers.
    """

    def __init__(
        self,
        dims_encoder: Iterable[int],
        patch_encoder: nn.Module,
        image_encoder: nn.Module,
        hook_block_ids: Iterable[int],
        decoder_features: int,
    ):
        """Initialize DepthProEncoder.

        The framework
            1. creates an image pyramid,
            2. generates overlapping patches with a sliding window at each pyramid level,
            3. creates batched encodings via vision transformer backbones,
            4. produces multi-resolution encodings.

        Args:
        ----
            dims_encoder: Dimensions of the encoder at different layers.
            patch_encoder: Backbone used for patches.
            image_encoder: Backbone used for the global image.
            hook_block_ids: Indices of the patch-encoder blocks whose intermediate
                features are captured via forward hooks.
            decoder_features: Number of output features in the decoder.

        """
        super().__init__()

        self.dims_encoder = list(dims_encoder)
        self.patch_encoder = patch_encoder
        self.image_encoder = image_encoder
        self.hook_block_ids = list(hook_block_ids)

        patch_encoder_embed_dim = patch_encoder.embed_dim
        image_encoder_embed_dim = image_encoder.embed_dim

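        # Side length of the token grid produced by the patch encoder for one
        # backbone-resolution input (e.g. 384 // 16 = 24 for a ViT with 16x16 patches).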
        self.out_size = int(
            patch_encoder.patch_embed.img_size[0] // patch_encoder.patch_embed.patch_size[0]
        )

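        # Local helper: a 1x1 projection followed by `upsample_layers` stride-2
        # transposed convolutions, each doubling the spatial resolution.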
        def _create_project_upsample_block(
            dim_in: int,
            dim_out: int,
            upsample_layers: int,
            dim_int: Optional[int] = None,
        ) -> nn.Module:
            if dim_int is None:
                dim_int = dim_out

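            # Projection: 1x1 convolution from dim_in to the intermediate dimension.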
            blocks = [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=dim_int,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=False,
                )
            ]

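            # Upsampling: each transposed convolution doubles the spatial resolution.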
            blocks += [
                nn.ConvTranspose2d(
                    in_channels=dim_int if i == 0 else dim_out,
                    out_channels=dim_out,
                    kernel_size=2,
                    stride=2,
                    padding=0,
                    bias=False,
                )
                for i in range(upsample_layers)
            ]

            return nn.Sequential(*blocks)

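        # Project-and-upsample blocks: the latent (hooked) features are upsampled
        # 8x and 4x, the three pyramid levels 2x each.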
        self.upsample_latent0 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim,
            dim_int=self.dims_encoder[0],
            dim_out=decoder_features,
            upsample_layers=3,
        )
        self.upsample_latent1 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[0], upsample_layers=2
        )

        self.upsample0 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[1], upsample_layers=1
        )
        self.upsample1 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[2], upsample_layers=1
        )
        self.upsample2 = _create_project_upsample_block(
            dim_in=patch_encoder_embed_dim, dim_out=self.dims_encoder[3], upsample_layers=1
        )

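        # The global (image-encoder) branch is upsampled 2x and fused with the
        # lowest-resolution patch features through a 1x1 convolution.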
        self.upsample_lowres = nn.ConvTranspose2d(
            in_channels=image_encoder_embed_dim,
            out_channels=self.dims_encoder[3],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
        )
        self.fuse_lowres = nn.Conv2d(
            in_channels=(self.dims_encoder[3] + self.dims_encoder[3]),
            out_channels=self.dims_encoder[3],
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

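        # Forward hooks capture intermediate activations of two patch-encoder
        # blocks; they are read back in forward() as high-resolution latents.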
        self.patch_encoder.blocks[self.hook_block_ids[0]].register_forward_hook(
            self._hook0
        )
        self.patch_encoder.blocks[self.hook_block_ids[1]].register_forward_hook(
            self._hook1
        )

    def _hook0(self, model, input, output):
        self.backbone_highres_hook0 = output

    def _hook1(self, model, input, output):
        self.backbone_highres_hook1 = output

    @property
    def img_size(self) -> int:
        """Return the full image size of the network."""
        return self.patch_encoder.patch_embed.img_size[0] * 4

    def _create_pyramid(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Create a 3-level image pyramid."""
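        # Pyramid levels: full, half, and quarter resolution (1536, 768 and 384
        # pixels per side for the default 1536x1536 input).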
        x0 = x

        x1 = F.interpolate(
            x, size=None, scale_factor=0.5, mode="bilinear", align_corners=False
        )

        x2 = F.interpolate(
            x, size=None, scale_factor=0.25, mode="bilinear", align_corners=False
        )

        return x0, x1, x2

    def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor:
        """Split the input into small patches with a sliding window."""
        patch_size = 384
        patch_stride = int(patch_size * (1 - overlap_ratio))

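        # Number of window positions per axis, e.g. 5 for the 1536-pixel level with
        # 0.25 overlap (stride 288) and 3 for the 768-pixel level with 0.5 overlap
        # (stride 192).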
        image_size = x.shape[-1]
        steps = int(math.ceil((image_size - patch_size) / patch_stride)) + 1

        x_patch_list = []
        for j in range(steps):
            j0 = j * patch_stride
            j1 = j0 + patch_size

            for i in range(steps):
                i0 = i * patch_stride
                i1 = i0 + patch_size
                x_patch_list.append(x[..., j0:j1, i0:i1])

        return torch.cat(x_patch_list, dim=0)

    def merge(self, x: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor:
        """Merge the patched input into an image with a sliding window."""
        steps = int(math.sqrt(x.shape[0] // batch_size))

        idx = 0

        output_list = []
        for j in range(steps):
            output_row_list = []
            for i in range(steps):
                output = x[batch_size * idx : batch_size * (idx + 1)]

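                # Trim the duplicated overlap border on every side that touches a
                # neighboring patch.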
                if j != 0:
                    output = output[..., padding:, :]
                if i != 0:
                    output = output[..., :, padding:]
                if j != steps - 1:
                    output = output[..., :-padding, :]
                if i != steps - 1:
                    output = output[..., :, :-padding]

                output_row_list.append(output)
                idx += 1

            output_row = torch.cat(output_row_list, dim=-1)
            output_list.append(output_row)
        output = torch.cat(output_list, dim=-2)
        return output

    def reshape_feature(
        self, embeddings: torch.Tensor, width, height, cls_token_offset=1
    ):
        """Discard class token and reshape 1D feature map to a 2D grid."""
        b, hw, c = embeddings.shape

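        # Drop the leading class token(s) before reshaping.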
        if cls_token_offset > 0:
            embeddings = embeddings[:, cls_token_offset:, :]

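        # (B, H*W, C) -> (B, C, H, W)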
        embeddings = embeddings.reshape(b, height, width, c).permute(0, 3, 1, 2)
        return embeddings

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Encode input at multiple resolutions.

        Args:
        ----
            x (torch.Tensor): Input image.

        Returns:
        -------
            Multi-resolution encoded features.

        """
        batch_size = x.shape[0]

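        # Step 0: create a 3-level image pyramid.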
        x0, x1, x2 = self._create_pyramid(x)

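        # Step 1: split each level into overlapping 384x384 patches with a sliding
        # window: 5x5 patches at full resolution (overlap 0.25), 3x3 at half
        # resolution (overlap 0.5); the quarter-resolution image is used as-is.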
        x0_patches = self.split(x0, overlap_ratio=0.25)
        x1_patches = self.split(x1, overlap_ratio=0.5)
        x2_patches = x2

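        # Concatenate all sliding-window patches into one batch
        # (25 + 9 + 1 = 35 patches per input image).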
        x_pyramid_patches = torch.cat(
            (x0_patches, x1_patches, x2_patches),
            dim=0,
        )

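        # Step 2: run the shared patch encoder once over the whole pyramid batch and
        # reshape each token sequence into an out_size x out_size feature grid.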
        x_pyramid_encodings = self.patch_encoder(x_pyramid_patches)
        x_pyramid_encodings = self.reshape_feature(
            x_pyramid_encodings, self.out_size, self.out_size
        )

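        # Step 3: merge the hooked intermediate encodings back into spatial maps.
        # Only the first batch_size * 5 * 5 entries belong to the full-resolution
        # patches, so the hook outputs are sliced before merging.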
        x_latent0_encodings = self.reshape_feature(
            self.backbone_highres_hook0,
            self.out_size,
            self.out_size,
        )
        x_latent0_features = self.merge(
            x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
        )

        x_latent1_encodings = self.reshape_feature(
            self.backbone_highres_hook1,
            self.out_size,
            self.out_size,
        )
        x_latent1_features = self.merge(
            x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
        )

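        # Split the batched pyramid encodings back into the 5x5, 3x3 and 1x1 groups.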
        x0_encodings, x1_encodings, x2_encodings = torch.split(
            x_pyramid_encodings,
            [len(x0_patches), len(x1_patches), len(x2_patches)],
            dim=0,
        )

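        # Merge the overlapping patch encodings into per-level feature maps; the
        # trimmed border is 3 tokens for 0.25 overlap and 6 tokens for 0.5 overlap
        # (96x96 and 48x48 maps for 24x24 patch grids).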
        x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=3)

        x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=6)

        x2_features = x2_encodings

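        # Encode the quarter-resolution image with the separate image encoder to
        # capture global context.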
        x_global_features = self.image_encoder(x2_patches)
        x_global_features = self.reshape_feature(
            x_global_features, self.out_size, self.out_size
        )

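        # Project and upsample every scale to its decoder resolution.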
        x_latent0_features = self.upsample_latent0(x_latent0_features)
        x_latent1_features = self.upsample_latent1(x_latent1_features)

        x0_features = self.upsample0(x0_features)
        x1_features = self.upsample1(x1_features)
        x2_features = self.upsample2(x2_features)

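        # Upsample the global features 2x and fuse them with the quarter-resolution
        # patch features through a 1x1 convolution.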
        x_global_features = self.upsample_lowres(x_global_features)
        x_global_features = self.fuse_lowres(
            torch.cat((x2_features, x_global_features), dim=1)
        )

        return [
            x_latent0_features,
            x_latent1_features,
            x0_features,
            x1_features,
            x_global_features,
        ]
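

# Usage sketch (illustrative only, not part of this module): the backbones are
# assumed to be timm-style ViTs that expose `embed_dim`, `patch_embed.img_size`,
# `patch_embed.patch_size` and `blocks`, and whose forward pass returns token
# embeddings of shape (B, 1 + N, C); the dimension and hook values below are
# placeholders, not the repository defaults.
#
#   import timm
#
#   patch_encoder = timm.create_model("vit_large_patch16_384", pretrained=False)
#   image_encoder = timm.create_model("vit_large_patch16_384", pretrained=False)
#   # reshape_feature() expects token embeddings, not classification logits.
#   patch_encoder.forward = patch_encoder.forward_features
#   image_encoder.forward = image_encoder.forward_features
#
#   encoder = DepthProEncoder(
#       dims_encoder=[256, 512, 1024, 1024],   # placeholder channel widths
#       patch_encoder=patch_encoder,
#       image_encoder=image_encoder,
#       hook_block_ids=[5, 11],                # placeholder block indices
#       decoder_features=256,
#   )
#   features = encoder(torch.rand(1, 3, encoder.img_size, encoder.img_size))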