Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2 | |
import os | |
from typing import List, Dict, Tuple, Union | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from .head_act import activate_head | |
from .utils import create_uv_grid, position_grid_to_embed | |
class DPTHead(nn.Module):
    """
    DPT head for dense prediction tasks.

    This implementation follows the architecture described in
    "Vision Transformers for Dense Prediction" (https://arxiv.org/abs/2103.13413).
    The head takes token features from several layers of a vision transformer
    backbone, turns them into a pyramid of feature maps and fuses them into a
    dense prediction.

    In this variant the prediction path was modified for normal estimation:
    the fused feature maps of all frames in a chunk are stacked along the
    channel axis and passed through a small convolutional head that always
    emits 4 channels, which ``activate_head`` splits into a prediction and a
    confidence. The prediction is L2-normalised over its channel axis.

    Args:
        dim_in (int): Input dimension (channels) of the transformer tokens.
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
            NOTE: accepted for interface compatibility; the prediction path
            currently always builds a 4-channel head.
        activation (str, optional): Activation type. Default is "linear".
            NOTE: stored on the instance, but the prediction path currently
            always calls ``activate_head`` with ``activation='linear'``.
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the last
            several layers and activation head. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "linear",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: List[int] = [256, 512, 1024, 1024],
        intermediate_layer_idx: List[int] = [4, 11, 17, 23],
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super().__init__()
        # Copy the list arguments so the instance never aliases the shared
        # default objects from the signature (or a list owned by the caller).
        out_channels = list(out_channels)
        intermediate_layer_idx = list(intermediate_layer_idx)

        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = intermediate_layer_idx

        self.norm = nn.LayerNorm(dim_in)

        # Projection layers for each output channel from tokens.
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=oc,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for oc in out_channels
            ]
        )

        # Resize layers: x4 and x2 upsampling, identity, and x2 downsampling,
        # which turns the four same-resolution token maps into a pyramid.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )

        self.scratch = _make_scratch(
            out_channels,
            features,
            expand=False,
        )

        # Attach additional modules to scratch.
        self.scratch.stem_transpose = None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        # feature_only keeps the full feature width; otherwise it is halved
        # before the prediction head.
        output_conv1_channels = features if feature_only else features // 2
        self.scratch.output_conv1 = nn.Conv2d(
            features, output_conv1_channels, kernel_size=3, stride=1, padding=1
        )

        # Prediction heads, created lazily and keyed by their input width.
        # The width is (frames in chunk) * (feature channels), so it is only
        # known at forward time. Caching the heads here registers their
        # parameters with the module instead of re-creating them, randomly
        # initialised, on every forward pass.
        # NOTE: parameters created lazily are not seen by an optimizer that was
        # built before the first forward pass; run a warm-up forward first.
        self.output_convs = nn.ModuleDict()

    def make_output_conv(
        self,
        in_channels: int,
        output_dim: int = 4,
        device: Union[torch.device, str, None] = None,
        dtype: torch.dtype = torch.bfloat16,
    ) -> nn.Sequential:
        """
        Build a new prediction head: three 3x3 conv + ReLU stages and a 1x1 conv.

        Args:
            in_channels (int): Number of input channels.
            output_dim (int, optional): Number of output channels. Default is 4.
            device (torch.device | str | None, optional): Device for the new
                head. ``None`` selects CUDA when available and CPU otherwise.
            dtype (torch.dtype, optional): Parameter dtype. Default is bfloat16.

        Returns:
            nn.Sequential: A freshly initialised head on ``device`` / ``dtype``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        conv_mid1 = in_channels // 2
        conv_mid2 = conv_mid1
        conv_mid3 = in_channels // 4
        layers = nn.Sequential(
            # Stage 1, 3x3: in_channels -> conv_mid1
            nn.Conv2d(in_channels, conv_mid1, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            # Stage 2, 3x3: conv_mid1 -> conv_mid2
            nn.Conv2d(conv_mid1, conv_mid2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            # Stage 3, 3x3: conv_mid2 -> conv_mid3
            nn.Conv2d(conv_mid2, conv_mid3, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            # Stage 4, 1x1: conv_mid3 -> output_dim
            nn.Conv2d(conv_mid3, output_dim, kernel_size=1, stride=1, padding=0),
        )
        return layers.to(device=device, dtype=dtype)

    def _get_output_conv(self, in_channels: int, reference: torch.Tensor) -> nn.Module:
        """Return the cached 4-channel head for ``in_channels``, creating it on first use."""
        key = str(in_channels)
        if key not in self.output_convs:
            self.output_convs[key] = self.make_output_conv(
                in_channels, output_dim=4, device=reference.device, dtype=reference.dtype
            )
        return self.output_convs[key]

    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: int = 10,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the DPT head, supports processing by chunking frames.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            patch_start_idx (int): Starting index for patch tokens in the token sequence.
                Used to separate patch tokens from other tokens (e.g., camera or register tokens).
            frames_chunk_size (int, optional): Maximum number of frames to process at once.
                If None or not smaller than S, all frames are processed together. Default: 10.

        Returns:
            Tensor or Tuple[Tensor, Tensor]:
                - If feature_only=True: feature maps with shape [B, S, C, H, W].
                - Otherwise: tuple of (predictions, confidence). Every chunk
                  yields one prediction/confidence pair for all of its frames,
                  and the per-chunk results are concatenated along dim 1.

        Raises:
            ValueError: If chunking is required and ``frames_chunk_size`` is not positive.
        """
        S = images.shape[1]

        # If frames_chunk_size is not specified or not smaller than S, process all frames at once.
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)

        # Otherwise, process frames in chunks to manage memory usage.
        if frames_chunk_size <= 0:
            raise ValueError(f"frames_chunk_size must be positive, got {frames_chunk_size}")

        chunk_outputs = [
            self._forward_impl(
                aggregated_tokens_list,
                images,
                patch_start_idx,
                frames_start_idx,
                min(frames_start_idx + frames_chunk_size, S),
            )
            for frames_start_idx in range(0, S, frames_chunk_size)
        ]

        # Concatenate results along dim 1.
        if self.feature_only:
            return torch.cat(chunk_outputs, dim=1)
        all_preds, all_conf = zip(*chunk_outputs)
        return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)

    def _forward_impl(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_start_idx: Union[int, None] = None,
        frames_end_idx: Union[int, None] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Implementation of the forward pass through the DPT head.

        This method processes a specific chunk of frames from the sequence.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W].
            patch_start_idx (int): Starting index for patch tokens.
            frames_start_idx (int, optional): Starting index for frames to process.
            frames_end_idx (int, optional): Ending index for frames to process.
                Frames are only sliced when both indices are given.

        Returns:
            Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
        """
        chunked = frames_start_idx is not None and frames_end_idx is not None
        if chunked:
            images = images[:, frames_start_idx:frames_end_idx].contiguous()

        # From here on S is the number of frames in this chunk.
        B, S, _, H, W = images.shape
        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        # Tokens from several transformer depths all have the same spatial size;
        # the resize layers turn them into the multi-scale pyramid that a
        # convolutional backbone would provide.
        # NOTE(review): aggregated_tokens_list is presumably one
        # [B, S, num_tokens, C] tensor per transformer block - confirm with the caller.
        out = []
        for dpt_idx, layer_idx in enumerate(self.intermediate_layer_idx):
            # Drop the leading non-patch tokens.
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]

            # Select frames if processing a chunk.
            if chunked:
                x = x[:, frames_start_idx:frames_end_idx]

            x = x.reshape(B * S, -1, x.shape[-1])
            x = self.norm(x)
            # [B*S, N, C] -> [B*S, C, patch_h, patch_w]
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H).to(torch.bfloat16)
            x = self.resize_layers[dpt_idx](x)

            out.append(x)

        # Fuse features from multiple layers.
        out = self.scratch_forward(out)

        # Interpolate fused output to match target image resolution.
        # The interpolation runs in float32 and the result is cast back to bfloat16.
        out = custom_interpolate(
            out.float(),
            (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
            mode="bilinear",
            align_corners=True,
        ).to(torch.bfloat16)

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H).to(torch.bfloat16)

        out = out.reshape(B, S, *out.shape[1:])  # [B, S, C, H, W]
        if self.feature_only:
            return out

        # Stack the frames of the chunk along the channel axis: [B, S*C, H, W].
        out = out.flatten(1, 2)
        out = self._get_output_conv(out.shape[1], out)(out)  # [B, 4, H, W]

        # The prediction activation is fixed to 'linear' here (self.activation is
        # not used) because the prediction is L2-normalised right below.
        preds, conf = activate_head(out, activation='linear', conf_activation=self.conf_activation)

        # NOTE(review): activate_head presumably returns channel-last
        # predictions [B, H, W, C] and confidence [B, H, W] - confirm.
        preds = preds.permute(0, 3, 1, 2)
        preds = F.normalize(preds, dim=1, p=2, eps=1e-12)
        return preds, conf.unsqueeze(1)

    def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """
        Apply positional embedding to tensor x.

        A UV grid matching the spatial size of ``x`` is built with the image
        aspect ratio ``W / H``, embedded to ``x.shape[1]`` channels, scaled by
        ``ratio`` and added to ``x``.

        Args:
            x (Tensor): Feature map whose last two dimensions are spatial and
                whose dimension 1 is the channel dimension.
            W (int): Image width, used only for the aspect ratio.
            H (int): Image height, used only for the aspect ratio.
            ratio (float, optional): Scale applied to the embedding. Default is 0.1.

        Returns:
            Tensor: ``x`` plus the scaled positional embedding.
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        # Move the last axis to the front (channels-first), then broadcast over the batch.
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.

        Args:
            features (List[Tensor]): Exactly four feature maps from different layers.

        Returns:
            Tensor: Fused feature map.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Fuse from the coarsest level upwards, resizing to the next level each
        # time. Local references are dropped as soon as they are consumed.
        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out
################################################################################ | |
# Modules | |
################################################################################ | |
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
    """
    Build a ``FeatureFusionBlock`` with the fixed configuration used by the DPT head.

    The block always uses an in-place ReLU, no deconvolution, no batch norm,
    no channel expansion and ``align_corners=True``.

    Args:
        features (int): Number of feature channels.
        size (int, optional): Fixed output size forwarded to the block.
        has_residual (bool, optional): Whether the block fuses a second input. Default is True.
        groups (int, optional): Convolution groups. Default is 1.

    Returns:
        nn.Module: The configured fusion block.
    """
    fixed_options = dict(deconv=False, bn=False, expand=False, align_corners=True)
    relu = nn.ReLU(inplace=True)
    return FeatureFusionBlock(
        features,
        relu,
        size=size,
        has_residual=has_residual,
        groups=groups,
        **fixed_options,
    )
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module: | |
scratch = nn.Module() | |
out_shape1 = out_shape | |
out_shape2 = out_shape | |
out_shape3 = out_shape | |
if len(in_shape) >= 4: | |
out_shape4 = out_shape | |
if expand: | |
out_shape1 = out_shape | |
out_shape2 = out_shape * 2 | |
out_shape3 = out_shape * 4 | |
if len(in_shape) >= 4: | |
out_shape4 = out_shape * 8 | |
scratch.layer1_rn = nn.Conv2d( | |
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups | |
) | |
scratch.layer2_rn = nn.Conv2d( | |
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups | |
) | |
scratch.layer3_rn = nn.Conv2d( | |
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups | |
) | |
if len(in_shape) >= 4: | |
scratch.layer4_rn = nn.Conv2d( | |
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups | |
) | |
return scratch | |
class ResidualConvUnit(nn.Module):
    """Pre-activation residual unit: two (activation -> 3x3 conv) stages plus a skip connection."""

    def __init__(self, features, activation, bn, groups=1):
        """Init.

        Args:
            features (int): number of input and output channels
            activation (callable): activation applied before each convolution
            bn: stored on the instance; no normalisation layers are created
                (``norm1`` and ``norm2`` are always ``None``)
            groups (int): convolution groups
        """
        super().__init__()

        self.bn = bn
        self.groups = groups

        conv_options = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv1 = nn.Conv2d(features, features, **conv_options)
        self.conv2 = nn.Conv2d(features, features, **conv_options)

        self.norm1 = None
        self.norm2 = None

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Note that when ``activation`` works in place, ``x`` itself is modified
        by the first activation, so the skip connection adds the activated input.

        Args:
            x (tensor): input

        Returns:
            tensor: output, same shape as the input
        """
        branch = x
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            branch = conv(self.activation(branch))
            if norm is not None:
                branch = norm(branch)

        return self.skip_add.add(branch, x)
class FeatureFusionBlock(nn.Module):
    """Feature fusion block: optional residual merge, refinement, bilinear resize, 1x1 projection."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """Init.

        Args:
            features (int): number of features
            activation (callable): activation handed to the residual units
            deconv (bool): stored on the instance only
            bn (bool): forwarded to the residual units
            expand (bool): if True, the output projection halves the channel count
            align_corners (bool): forwarded to the bilinear interpolation
            size: default output size used when ``forward`` gets no ``size``
            has_residual (bool): whether a second input is merged in ``forward``
            groups (int): convolution groups
        """
        super().__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand

        # Equality (not truthiness) is kept on purpose to match the established behaviour.
        out_features = features // 2 if self.expand == True else features  # noqa: E712

        self.out_conv = nn.Conv2d(
            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
        )

        # The first residual unit only exists when a second input is fused.
        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
        self.has_residual = has_residual

        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
        self.skip_add = nn.quantized.FloatFunctional()
        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Args:
            *xs (tensor): ``xs[0]`` is the running feature map; ``xs[1]`` is the
                map merged into it and is read only when ``has_residual`` is True.
            size: target spatial size. Falls back to ``self.size``; if that is
                also None the map is upsampled by a factor of 2.

        Returns:
            tensor: output
        """
        fused = xs[0]
        if self.has_residual:
            fused = self.skip_add.add(fused, self.resConfUnit1(xs[1]))

        fused = self.resConfUnit2(fused)

        if size is not None:
            resize_kwargs = {"size": size}
        elif self.size is not None:
            resize_kwargs = {"size": self.size}
        else:
            resize_kwargs = {"scale_factor": 2}

        # Interpolate in float32, then cast to bfloat16 for the projection.
        resized = custom_interpolate(
            fused.float(), **resize_kwargs, mode="bilinear", align_corners=self.align_corners
        )
        return self.out_conv(resized.to(torch.bfloat16))
def custom_interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.

    When the output would hold more than 1610612736 (1.5 * 2**30) elements, the
    input is split along dim 0, each piece is interpolated separately and the
    results are concatenated.

    Args:
        x (Tensor): Input; dims 0 and 1 are counted as batch and channels and
            the last two dims as the spatial size.
        size (Tuple[int, int], optional): Target spatial size. When None it is
            derived from ``scale_factor`` by truncating to int.
        scale_factor (float, optional): Scale used only when ``size`` is None.
        mode (str, optional): Interpolation mode. Default is "bilinear".
        align_corners (bool, optional): Forwarded to ``interpolate``. Default is True.

    Returns:
        Tensor: The resized tensor.
    """
    if size is None:
        in_h, in_w = x.shape[-2], x.shape[-1]
        size = (int(in_h * scale_factor), int(in_w * scale_factor))

    def _resize(t: torch.Tensor) -> torch.Tensor:
        return nn.functional.interpolate(t, size=size, mode=mode, align_corners=align_corners)

    element_budget = 1610612736
    output_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    # Common case: small enough to resize in one call.
    if output_elements <= element_budget:
        return _resize(x)

    num_pieces = (output_elements // element_budget) + 1
    resized_pieces = [_resize(piece) for piece in torch.chunk(x, chunks=num_pieces, dim=0)]
    return torch.cat(resized_pieces, dim=0).contiguous()