# Source: MultiMAE / multimae/output_adapters.py -- commit 3b49518 ("Initial commit", Bachmann Roman Christian), 31.8 kB
# Copyright (c) EPFL VILAB.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv MAE, DPT and ConvNeXt code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/facebookresearch/moco-v3
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/BUPT-PRIV/MAE-priv
# https://github.com/facebookresearch/mae
# https://github.com/isl-org/DPT
# https://github.com/facebookresearch/ConvNeXt
# --------------------------------------------------------
from functools import partial
from typing import Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from .multimae_utils import (Block, CrossAttention, Mlp,
build_2d_sincos_posemb, pair, trunc_normal_)
from .output_adapter_utils import (ConvNeXtBlock, Interpolate,
make_fusion_block, make_scratch)
class SpatialOutputAdapter(nn.Module):
    """Cross-attention adapter for spatial outputs, like images or feature maps.

    :param num_channels: Number of input channels of the image/feature map
    :param stride_level: Stride level compared to the full-sized image.
        E.g. 4 for 1/4th the size of the image.
    :param patch_size_full: Int or tuple of the patch size over the full image size.
        Patch size for smaller inputs will be computed accordingly.
    :param dim_tokens_enc: Dimension of tokens coming from encoder. Can be set using init method.
    :param dim_tokens: Dimension of decoder tokens
    :param depth: Number of additional (full self-attention) transformer layers after initial cross attention and MLP
    :param learnable_pos_emb: Set to True to learn positional embeddings instead of using
        the fixed (non-trainable) 2D sin-cos ones
    :param image_size: Default image size. Used to initialize size of positional embeddings.
    :param mlp_ratio: MLP hidden dim ratio
    :param num_heads: Number of attention heads
    :param qkv_bias: Set to True to enable bias
    :param drop_rate: Probability of dropping attention layer outputs
    :param attn_drop_rate: Probability of dropping attention matrix elements
    :param drop_path_rate: DropPath drop rate
    :param norm_layer: Type of normalization layer
    :param use_task_queries: When set to True, adds task specific tokens from encoder (if available)
        to the corresponding query entries
    :param task: Task for which encoder tokens are added to the queries of the decoder (e.g. RGB if decoder is used for RGB)
    :param context_tasks: Tasks / modalities from the encoder. Used to create learned embeddings for each task.
    :param use_xattn: When set to True, attend to the tokens from the encoder through a cross-attention layer
    """

    def __init__(self,
                 num_channels: int,
                 stride_level: int,
                 patch_size_full: Union[int, Tuple[int, int]],
                 dim_tokens_enc: Optional[int] = None,
                 dim_tokens: int = 256,
                 depth: int = 0,
                 learnable_pos_emb: bool = False,
                 image_size: Union[int, Tuple[int, int]] = 224,
                 mlp_ratio: float = 4.0,
                 num_heads: int = 8,
                 qkv_bias: bool = True,
                 drop_rate: float = 0.0,
                 attn_drop_rate: float = 0.0,
                 drop_path_rate: float = 0.0,
                 norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
                 use_task_queries: bool = True,
                 task: Optional[str] = None,
                 context_tasks: Optional[list] = None,
                 use_xattn: bool = True
                 ):
        super().__init__()
        self.num_channels = num_channels
        self.stride_level = stride_level
        self.patch_size_full = pair(patch_size_full)
        self.dim_tokens_enc = dim_tokens_enc
        self.dim_tokens = dim_tokens
        self.learnable_pos_emb = learnable_pos_emb
        self.image_size = pair(image_size)
        self.use_task_queries = use_task_queries
        self.task = task
        self.use_xattn = use_xattn

        # Actual patch height and width, taking into account stride of input
        self.P_H = max(1, self.patch_size_full[0] // stride_level)
        self.P_W = max(1, self.patch_size_full[1] // stride_level)

        # `task_embeddings` must always exist: other methods test
        # `self.task_embeddings is not None`, which raises AttributeError on an
        # nn.Module if the attribute was never assigned.
        if context_tasks is not None:
            self.task_embeddings = nn.ParameterDict(
                {context_task: nn.Parameter(torch.zeros(1, 1, self.dim_tokens))
                 for context_task in context_tasks})
            for embedding in self.task_embeddings.values():
                trunc_normal_(embedding, std=0.02)
        else:
            self.task_embeddings = None

        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.dim_tokens))

        # Fixed-size positional embeddings. Can be interpolated to different input sizes
        h_posemb = self.image_size[0] // (self.stride_level * self.P_H)
        w_posemb = self.image_size[1] // (self.stride_level * self.P_W)
        if not self.learnable_pos_emb:
            self.pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens)
            self.pos_emb = nn.Parameter(self.pos_emb, requires_grad=False)
        else:
            # Channels-first layout (1, D, H, W): every consumer resizes pos_emb with
            # F.interpolate over the last two dims and then rearranges it with
            # 'b d nh nw -> b (nh nw) d'.
            self.pos_emb = nn.Parameter(torch.zeros(1, self.dim_tokens, h_posemb, w_posemb))
            trunc_normal_(self.pos_emb, std=0.02)

        # One cross attention layer followed by MLP block, an optional transformer, and an output projection
        if self.use_xattn:
            self.decoder = CrossAttention(
                dim=self.dim_tokens, num_heads=num_heads, qkv_bias=qkv_bias,
                attn_drop=attn_drop_rate, proj_drop=drop_rate)
            self.context_norm = norm_layer(self.dim_tokens)
            self.query_norm = norm_layer(self.dim_tokens)
            self.out_norm = norm_layer(self.dim_tokens)

            mlp_hidden_dim = int(self.dim_tokens * mlp_ratio)
            self.mlp = Mlp(in_features=self.dim_tokens, hidden_features=mlp_hidden_dim)

        # Optional full self-attention transformer layers
        if depth > 0:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
            self.decoder_transformer = nn.Sequential(*[
                Block(dim=self.dim_tokens, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                      attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
                for i in range(depth)
            ])
        else:
            self.decoder_transformer = nn.Identity()

        # Each decoder token is projected to one flattened (C * P_H * P_W) output patch
        self.dim_patch = self.num_channels * self.P_H * self.P_W
        self.out_proj = nn.Linear(self.dim_tokens, self.dim_patch)

        if self.dim_tokens_enc is not None:
            self.init(dim_tokens_enc=dim_tokens_enc)

    def init(self, dim_tokens_enc: int = 768):
        '''
        Initialize parts of decoder that are dependent on dimension of encoder tokens.
        Should be called when setting up MultiMAE.

        :param dim_tokens_enc: Dimension of tokens coming from encoder
        '''
        self.dim_tokens_enc = dim_tokens_enc

        # Projection of encoder tokens to the decoder token dimension
        self.proj_context = nn.Linear(self.dim_tokens_enc, self.dim_tokens)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of the parameters that should be excluded from weight decay."""
        return {'pos_emb', 'mask_token', 'task_embeddings'}

    def generate_context_embeddings(self, input_info,
                                    bs: int,
                                    size: Tuple[int, int],
                                    device: Optional[torch.device] = None):
        """Build the additive task / positional embeddings for all context tokens.

        For every task in ``input_info['tasks']`` (in iteration order) this creates
        ``info['num_tokens']`` embeddings: the learned task embedding if one exists for
        that task (zeros otherwise), plus the interpolated positional embedding when
        ``info['has_2d_posemb']`` is truthy.

        :param input_info: Dictionary with information about the input modalities
        :param bs: Batch size
        :param size: (height, width) in patches to which the positional embedding is resized
        :param device: Device on which the zero embeddings are created
        :return: Per-task embeddings concatenated along the token dimension (dim=1)
        """
        context_embeddings = []
        for task, info in input_info["tasks"].items():
            if self.task_embeddings is not None and task in self.task_embeddings:
                task_emb = repeat(self.task_embeddings[task], '() () d -> b n d', b=bs, n=info['num_tokens'])
            else:
                task_emb = torch.zeros((bs, info['num_tokens'], self.dim_tokens), device=device)

            if info['has_2d_posemb']:
                pos_emb = F.interpolate(self.pos_emb, size=size, mode='bilinear', align_corners=False)
                pos_emb = rearrange(pos_emb, 'b d nh nw -> b (nh nw) d')
                assert info['num_tokens'] == pos_emb.shape[1]
                task_emb = task_emb + pos_emb

            context_embeddings.append(task_emb)

        context_embeddings = torch.cat(context_embeddings, dim=1)

        return context_embeddings

    def get_queries_and_context(self, context_tokens, input_info, ids_keep, ids_restore):
        """Assemble the decoder queries and the embedded context tokens.

        :param context_tokens: Encoder tokens already projected to the decoder dimension
        :param input_info: Dictionary with information about the input modalities
        :param ids_keep: IDs of unmasked tokens (tokens given to the encoder)
        :param ids_restore: IDs to unshuffle tokens
        :return: Tuple ``(queries, context_tokens)``
        """
        B = context_tokens.shape[0]
        H, W = input_info['image_size']
        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        # Global tokens sit at the end of the sequence. Split with an explicit index
        # instead of `[:, :-n]` / `[:, -n:]`: for n == 0 those slices would yield an
        # empty context and a duplicate of the whole sequence, respectively.
        num_global = input_info.get('num_global_tokens', 0)
        num_context = max(context_tokens.shape[1] - num_global, 0)
        context_tokens_without_global = context_tokens[:, :num_context]
        global_tokens = context_tokens[:, num_context:]

        # Add mask tokens
        mask_tokens = repeat(self.mask_token, '() () d -> b n d', b=B,
                             n=input_info['num_task_tokens'] - context_tokens_without_global.shape[1])
        context_with_mask = torch.cat([context_tokens_without_global, mask_tokens], dim=1)

        # Unshuffle context_with_mask
        context_with_mask = torch.gather(context_with_mask, dim=1,
                                         index=ids_restore.unsqueeze(-1).repeat(1, 1, context_with_mask.shape[2]))

        # Generate context_emb and add them to context
        context_emb = self.generate_context_embeddings(input_info=input_info, bs=B, size=(N_H, N_W),
                                                       device=context_tokens.device)
        context_with_mask = context_with_mask + context_emb

        # Generate queries
        if self.use_task_queries and self.task in input_info['tasks']:
            # Re-use this task's (already embedded) slice of the full token sequence
            start_idx = input_info['tasks'][self.task]['start_idx']
            end_idx = input_info['tasks'][self.task]['end_idx']
            queries = context_with_mask[:, start_idx:end_idx]
        else:
            # Otherwise start from mask tokens and add positional / task embeddings
            queries = repeat(self.mask_token, '() () d -> b n d', b=B, n=N_H * N_W)
            queries_pos_emb = F.interpolate(self.pos_emb, size=(N_H, N_W), mode='bilinear', align_corners=False)
            queries_pos_emb = rearrange(queries_pos_emb, 'b d nh nw -> b (nh nw) d')
            queries = queries + queries_pos_emb
            if self.task_embeddings is not None and self.task in self.task_embeddings:
                queries_task_emb = repeat(self.task_embeddings[self.task], '() () d -> b n d', b=B, n=N_H * N_W)
                queries = queries + queries_task_emb

        # Unshuffle context and keep only initial context (yes, again)
        context_tokens_without_global = torch.gather(context_with_mask, dim=1,
                                                     index=ids_keep.unsqueeze(-1).repeat(1, 1, context_with_mask.shape[2]))

        # Add back global tokens
        if num_global > 0:
            context_tokens = torch.cat([context_tokens_without_global, global_tokens], dim=1)
        else:
            context_tokens = context_tokens_without_global

        return queries, context_tokens

    def forward(self,
                encoder_tokens: torch.Tensor,
                input_info: Dict,
                ids_keep: torch.Tensor,
                ids_restore: torch.Tensor,
                ):
        """
        Forward pass taking output tokens from encoder and optionally a subset of them corresponding
        to this output adapter's task (needs an additional mask describing position of these tokens in the queries).

        :param encoder_tokens: Output of encoder
        :param input_info: Dictionary with information about the input modalities
        :param ids_keep: IDs of unmasked tokens (tokens given to the encoder)
        :param ids_restore: IDs to unshuffle tokens
        :return: Decoded patches reassembled into a ``b c (nh ph) (nw pw)`` tensor
        """
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
        H, W = input_info['image_size']
        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        # Project encoder tokens to decoder tokens
        context_tokens = self.proj_context(encoder_tokens)

        # Get queries and context
        queries, context_tokens = self.get_queries_and_context(context_tokens, input_info, ids_keep, ids_restore)

        # Perform cross attention of queries to context tokens, followed by an MLP
        if self.use_xattn:
            x = self.decoder(self.query_norm(queries), self.context_norm(context_tokens))
            x = x + self.mlp(self.out_norm(x))
        else:
            x = queries

        # Optional transformer layers if depth > 0
        x = self.decoder_transformer(x)

        # Project each token to (C * P_H * P_W)
        x = self.out_proj(x)

        # Reshape sequence of patches into image
        x = rearrange(
            x, 'b (nh nw) (c ph pw) -> b c (nh ph) (nw pw)',
            nh=N_H, nw=N_W, ph=self.P_H, pw=self.P_W, c=self.num_channels
        )

        return x
class LinearOutputAdapter(nn.Module):
    """
    Linear output adapter.

    :param num_classes: Number of classes. If not positive, the head is an identity mapping.
    :param dim_tokens_enc: Dimension of tokens from the encoder
    :param use_mean_pooling: When set to True, uses mean pooling before linear classification head.
        Otherwise, use last token (usually the global token)
    :param norm_layer: Normalization layer
    :param init_scale: Initialization scale for linear classification head
    """

    def __init__(self,
                 num_classes: int,
                 dim_tokens_enc: Optional[int] = None,
                 use_mean_pooling: bool = True,
                 norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
                 init_scale: float = 1.0):
        super().__init__()
        self.num_classes = num_classes
        self.dim_tokens_enc = dim_tokens_enc
        self.use_mean_pooling = use_mean_pooling
        self.norm_layer = norm_layer
        self.init_scale = init_scale

        if self.dim_tokens_enc is not None:
            self.init(dim_tokens_enc=dim_tokens_enc)

    def init(self, dim_tokens_enc: int = 768):
        """
        Initialize parts of decoder that are dependent on dimension of encoder tokens.
        Should be called when setting up MultiMAE.

        :param dim_tokens_enc: Dimension of tokens coming from encoder
        """
        self.dim_tokens_enc = dim_tokens_enc

        self.norm = self.norm_layer(self.dim_tokens_enc)
        self.head = nn.Linear(dim_tokens_enc, self.num_classes) if self.num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)
        # Only a real linear head has weights to rescale; nn.Identity has neither
        # `weight` nor `bias`, so scaling it unconditionally would raise.
        if isinstance(self.head, nn.Linear):
            self.head.weight.data.mul_(self.init_scale)
            self.head.bias.data.mul_(self.init_scale)

    def _init_weights(self, m):
        """Truncated-normal init for linear weights, constant init for biases and LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Change the number of classes and rebuild the norm / head layers.

        :param num_classes: New number of classes
        :param global_pool: Unused by this method
        """
        self.num_classes = num_classes
        # Without a known encoder dimension there is nothing to rebuild yet; a later
        # init(dim_tokens_enc) call will pick up the new number of classes.
        if self.dim_tokens_enc is not None:
            self.init(dim_tokens_enc=self.dim_tokens_enc)

    def forward(self,
                encoder_tokens: torch.Tensor,
                **kwargs):
        """Pool the encoder tokens, normalize them and apply the head.

        :param encoder_tokens: Output of encoder; pooled over / indexed along dim 1
        :param kwargs: Ignored
        :return: Output of the head applied to the normalized pooled token
        """
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'

        if self.use_mean_pooling:
            x = encoder_tokens.mean(1)
        else:
            # Global token is added at the end
            x = encoder_tokens[:, -1]

        x = self.head(self.norm(x))
        return x
class SegmenterMaskTransformerAdapter(nn.Module):
    """Output adapter inspired by the Segmenter-Mask architecture

    This head is the implementation of `Segmenter: <https://arxiv.org/abs/2105.05633>`_.

    :param num_classes: Number of classes
    :param depth: Depth of decoder
    :param num_heads: Number of attention heads
    :param embed_dim: Dimension of decoder tokens
    :param mlp_ratio: MLP hidden dim ratio
    :param drop_path_rate: DropPath drop rate
    :param drop_rate: Dropout after MLPs and Attention
    :param attn_drop_rate: Attention matrix drop rate
    :param qkv_bias: Set to False to disable bias
    :param main_tasks: Tasks to use for the adapter. Only tokens coming from these tasks are kept.
        A single task name given as a string is treated as a one-element tuple.
    :param patch_size: Size of patches
    :param norm_layer: Type of normalization layer
    :param kwargs: Ignored
    """

    def __init__(
            self,
            num_classes,
            depth: int = 2,
            num_heads: int = 12,
            embed_dim: int = 768,
            mlp_ratio=4,
            drop_path_rate=0.1,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            qkv_bias=True,
            main_tasks: Iterable[str] = ('rgb',),
            patch_size: int = 16,
            norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
            **kwargs,
    ):
        super().__init__()
        # A bare string would otherwise be iterated character by character
        self.main_tasks = (main_tasks,) if isinstance(main_tasks, str) else main_tasks
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_classes = num_classes

        # One learned embedding per class, appended to the patch tokens in forward()
        self.cls_emb = nn.Parameter(torch.zeros(1, num_classes, embed_dim))
        trunc_normal_(self.cls_emb, std=0.02)

        self.patch_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.classes_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                  attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)
        ])

        self.decoder_norm = norm_layer(embed_dim)
        self.mask_norm = norm_layer(num_classes)
        self.apply(self._init_weights)

    def init(self, dim_tokens_enc: int = 768):
        """
        Initialize parts of decoder that are dependent on dimension of encoder tokens.
        Should be called when setting up MultiMAE.

        :param dim_tokens_enc: Dimension of tokens coming from encoder
        """
        # Tokens of all main tasks are concatenated along the feature dimension
        self.in_channels = dim_tokens_enc * len(self.main_tasks)

        # Projection of encoder tokens to the decoder token dimension
        self.proj_dec = nn.Linear(self.in_channels, self.embed_dim)
        self._init_weights(self.proj_dec)

    def _init_weights(self, m):
        """Truncated-normal init for linear weights, constant init for biases and LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def adapt_tokens(self, encoder_tokens, input_info):
        """Slice out the tokens of every main task and concatenate them along the last dim.

        :param encoder_tokens: Output of encoder
        :param input_info: Dictionary with information about the input modalities
        :return: Concatenation (dim=-1) of each main task's ``start_idx:end_idx`` token slice
        """
        task_tokens = [
            encoder_tokens[:, input_info['tasks'][task]['start_idx']:input_info['tasks'][task]['end_idx']]
            for task in self.main_tasks
        ]
        return torch.cat(task_tokens, dim=-1)

    def forward(self, encoder_tokens: torch.Tensor, input_info: Dict):
        """Predict per-class masks and upsample them to the input image size.

        :param encoder_tokens: Output of encoder
        :param input_info: Dictionary with information about the input modalities
        :return: Class masks bilinearly interpolated to ``input_info['image_size']``
        """
        assert hasattr(self, 'proj_dec'), 'Need to call init(dim_tokens_enc) function first'
        H, W = input_info['image_size']
        N_H, N_W = H // self.patch_size, W // self.patch_size

        x = self.adapt_tokens(encoder_tokens, input_info)

        x = self.proj_dec(x)
        # Append the class embeddings after the patch tokens
        cls_emb = self.cls_emb.expand(x.shape[0], -1, -1)
        x = torch.cat((x, cls_emb), 1)

        for blk in self.blocks:
            x = blk(x)

        x = self.decoder_norm(x)

        # The last `num_classes` tokens are the class tokens, the rest are patches
        patches = self.patch_proj(x[:, :-self.num_classes])
        cls_seg_feat = self.classes_proj(x[:, -self.num_classes:])

        patches = F.normalize(patches, dim=2, p=2)
        cls_seg_feat = F.normalize(cls_seg_feat, dim=2, p=2)

        # Similarity between each (L2-normalized) patch and each class token
        masks = patches @ cls_seg_feat.transpose(1, 2)
        masks = self.mask_norm(masks)
        masks = rearrange(masks, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W)

        # Interpolate to semseg res
        masks = F.interpolate(masks, size=(H, W), mode="bilinear")

        return masks
class ConvNeXtAdapter(nn.Module):
    """Output adapter with ConvNext blocks for semantic segmentation

    :param num_classes: Number of classes
    :param embed_dim: Token dimension after projection, and before reshaping operation.
        Must be divisible by preds_per_patch.
    :param preds_per_patch: Increases size of feature map by reshaping each patch. Each patch gets reshaped
        from embed_dim x 1 x 1 to (embed_dim / preds_per_patch) x (preds_per_patch ** 0.5) x (preds_per_patch ** 0.5).
        Must be a positive perfect square.
    :param main_tasks: Tasks to use for the adapter. Only tokens coming from these tasks are kept.
        A single task name given as a string is treated as a one-element tuple.
    :param patch_size: Size of patches
    :param depth: Number of ConvNeXt blocks
    :param interpolate_mode: Interpolation mode for final upsampling
    :param kwargs: Ignored
    :raises ValueError: If preds_per_patch is not a positive perfect square or does not divide embed_dim
    """

    def __init__(
            self,
            num_classes,
            embed_dim: int = 6144,
            preds_per_patch: int = 16,
            main_tasks: Iterable[str] = ('rgb',),
            patch_size: int = 16,
            depth: int = 4,
            interpolate_mode: str = 'bilinear',
            **kwargs,
    ):
        super().__init__()

        # Reject configurations that forward() could never reshape, instead of
        # failing later inside the rearrange calls.
        if preds_per_patch < 1:
            raise ValueError(f'preds_per_patch must be a positive perfect square, got {preds_per_patch}.')
        side = int(round(preds_per_patch ** 0.5))
        if side * side != preds_per_patch:
            raise ValueError(f'preds_per_patch must be a positive perfect square, got {preds_per_patch}.')
        if embed_dim % preds_per_patch != 0:
            raise ValueError(
                f'embed_dim ({embed_dim}) must be divisible by preds_per_patch ({preds_per_patch}).')

        # A bare string would otherwise be iterated character by character
        self.main_tasks = (main_tasks,) if isinstance(main_tasks, str) else main_tasks
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.preds_per_patch = preds_per_patch
        self.class_dim = embed_dim // preds_per_patch
        self.num_classes = num_classes
        self.interpolate_mode = interpolate_mode

        self.blocks = nn.Sequential(*[
            ConvNeXtBlock(dim=self.class_dim)
            for _ in range(depth)
        ])
        self.final_layer = nn.Conv2d(self.class_dim, self.num_classes, 1)
        self.apply(self._init_weights)

    def init(self, dim_tokens_enc: int = 768):
        """
        Initialize parts of decoder that are dependent on dimension of encoder tokens.
        Should be called when setting up MultiMAE.

        :param dim_tokens_enc: Dimension of tokens coming from encoder
        """
        # Tokens of all main tasks are concatenated along the feature dimension
        self.in_channels = dim_tokens_enc * len(self.main_tasks)

        # Projection of encoder tokens to the patch dimension
        self.proj_dec = nn.Linear(self.in_channels, self.embed_dim)
        self._init_weights(self.proj_dec)

    def _init_weights(self, m):
        """Truncated-normal init for linear weights, constant init for biases and LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def adapt_tokens(self, encoder_tokens, input_info):
        """Slice out the tokens of every main task and concatenate them along the last dim.

        :param encoder_tokens: Output of encoder
        :param input_info: Dictionary with information about the input modalities
        :return: Concatenation (dim=-1) of each main task's ``start_idx:end_idx`` token slice
        """
        task_tokens = [
            encoder_tokens[:, input_info['tasks'][task]['start_idx']:input_info['tasks'][task]['end_idx']]
            for task in self.main_tasks
        ]
        return torch.cat(task_tokens, dim=-1)

    def forward(self, encoder_tokens: torch.Tensor, input_info: Dict):
        """Decode encoder tokens into a segmentation map at the input image size.

        :param encoder_tokens: Output of encoder
        :param input_info: Dictionary with information about the input modalities
        :return: Class logits interpolated to ``input_info['image_size']``
        """
        assert hasattr(self, 'proj_dec'), 'Need to call init(dim_tokens_enc) function first'
        H, W = input_info['image_size']
        N_H, N_W = H // self.patch_size, W // self.patch_size
        # Side length of the square grid of predictions each patch is unfolded into
        preds_side = int(round(self.preds_per_patch ** 0.5))

        x = self.adapt_tokens(encoder_tokens, input_info)

        x = self.proj_dec(x)
        # Split every token into preds_per_patch sub-tokens of class_dim channels ...
        x = rearrange(x, "b n (p c) -> b (n p) c", n=N_H * N_W, p=self.preds_per_patch, c=self.class_dim)
        # ... and lay them out as a feature map that is preds_side times larger per axis
        x = rearrange(x, "b (nh nw ph pw) c -> b c (nh ph) (nw pw)",
                      nh=N_H, nw=N_W,
                      ph=preds_side,
                      pw=preds_side)

        x = self.blocks(x)
        x = self.final_layer(x)

        # Interpolate to semseg res
        x = F.interpolate(x, size=(H, W), mode=self.interpolate_mode)

        return x
class DPTOutputAdapter(nn.Module):
    """DPT output adapter.

    :param num_classes: Number of output channels
    :param stride_level: Stride level compared to the full-sized image.
        E.g. 4 for 1/4th the size of the image.
    :param patch_size: Int or tuple of the patch size over the full image size.
        Patch size for smaller inputs will be computed accordingly.
    :param main_tasks: Tasks to use for the adapter. Only tokens coming from these tasks are kept.
        A single task name given as a string is treated as a one-element tuple.
    :param hooks: Index of intermediate layers
    :param layer_dims: Dimension of intermediate layers
    :param feature_dim: Feature dimension
    :param use_bn: If set to True, activates batch norm
    :param dim_tokens_enc: Dimension of tokens coming from encoder
    :param head_type: Either 'regression' or 'semseg'; selects the output head
    :param kwargs: Ignored
    :raises ValueError: If head_type is neither 'regression' nor 'semseg'
    """

    def __init__(self,
                 num_classes: int = 3,
                 stride_level: int = 1,
                 patch_size: Union[int, Tuple[int, int]] = 16,
                 main_tasks: Iterable[str] = ('rgb',),
                 hooks: List[int] = [2, 5, 8, 11],
                 layer_dims: List[int] = [96, 192, 384, 768],
                 feature_dim: int = 256,
                 use_bn: bool = False,
                 dim_tokens_enc: Optional[int] = None,
                 head_type: str = 'regression',
                 **kwargs):
        super().__init__()
        self.num_channels = num_classes
        self.stride_level = stride_level
        self.patch_size = pair(patch_size)
        # A bare string would otherwise be iterated character by character
        self.main_tasks = (main_tasks,) if isinstance(main_tasks, str) else main_tasks
        # Copy the lists: the defaults are mutable objects shared by every call, so
        # storing them directly would let one instance's edits leak into all others.
        self.hooks = list(hooks)
        self.layer_dims = list(layer_dims)
        self.feature_dim = feature_dim
        self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None
        self.head_type = head_type

        # Actual patch height and width, taking into account stride of input
        self.P_H = max(1, self.patch_size[0] // stride_level)
        self.P_W = max(1, self.patch_size[1] // stride_level)

        self.scratch = make_scratch(self.layer_dims, feature_dim, groups=1, expand=False)

        self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn)
        self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn)
        self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn)
        self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn)

        if self.head_type == 'regression':
            # The "DPTDepthModel" head
            self.head = nn.Sequential(
                nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
                Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(feature_dim // 2, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Conv2d(32, self.num_channels, kernel_size=1, stride=1, padding=0)
            )
        elif self.head_type == 'semseg':
            # The "DPTSegmentationModel" head
            self.head = nn.Sequential(
                nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
                nn.ReLU(True),
                nn.Dropout(0.1, False),
                nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
                Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            )
        else:
            raise ValueError('DPT head_type must be "regression" or "semseg".')

        if self.dim_tokens_enc is not None:
            # init() applies the len(main_tasks) factor itself, so pass the raw value
            self.init(dim_tokens_enc=dim_tokens_enc)

    def init(self, dim_tokens_enc: int = 768):
        """
        Initialize parts of decoder that are dependent on dimension of encoder tokens.
        Should be called when setting up MultiMAE.

        :param dim_tokens_enc: Dimension of tokens coming from encoder
        """
        # Tokens of all main tasks are concatenated along the feature dimension
        self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks)

        # Set up activation postprocessing layers

        # Layer 1: 1x1 projection, then a stride-4 transposed convolution
        self.act_1_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc,
                out_channels=self.layer_dims[0],
                kernel_size=1, stride=1, padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=self.layer_dims[0],
                out_channels=self.layer_dims[0],
                kernel_size=4, stride=4, padding=0,
                bias=True, dilation=1, groups=1,
            )
        )

        # Layer 2: 1x1 projection, then a stride-2 transposed convolution
        self.act_2_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc,
                out_channels=self.layer_dims[1],
                kernel_size=1, stride=1, padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=self.layer_dims[1],
                out_channels=self.layer_dims[1],
                kernel_size=2, stride=2, padding=0,
                bias=True, dilation=1, groups=1,
            )
        )

        # Layer 3: 1x1 projection only
        self.act_3_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc,
                out_channels=self.layer_dims[2],
                kernel_size=1, stride=1, padding=0,
            )
        )

        # Layer 4: 1x1 projection, then a stride-2 3x3 convolution
        self.act_4_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc,
                out_channels=self.layer_dims[3],
                kernel_size=1, stride=1, padding=0,
            ),
            nn.Conv2d(
                in_channels=self.layer_dims[3],
                out_channels=self.layer_dims[3],
                kernel_size=3, stride=2, padding=1,
            )
        )

        self.act_postprocess = nn.ModuleList([
            self.act_1_postprocess,
            self.act_2_postprocess,
            self.act_3_postprocess,
            self.act_4_postprocess
        ])

    def adapt_tokens(self, encoder_tokens, input_info):
        """Slice out the tokens of every main task and concatenate them along the last dim.

        :param encoder_tokens: Tokens of one encoder layer
        :param input_info: Dictionary with information about the input modalities
        :return: Concatenation (dim=-1) of each main task's ``start_idx:end_idx`` token slice
        """
        task_tokens = [
            encoder_tokens[:, input_info['tasks'][task]['start_idx']:input_info['tasks'][task]['end_idx']]
            for task in self.main_tasks
        ]
        return torch.cat(task_tokens, dim=-1)

    def forward(self, encoder_tokens: List[torch.Tensor], input_info: Dict):
        """Fuse four intermediate encoder layers and apply the output head.

        :param encoder_tokens: Per-layer encoder outputs, indexed by the entries of ``self.hooks``
        :param input_info: Dictionary with information about the input modalities
        :return: Output of the regression / semseg head
        """
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
        H, W = input_info['image_size']
        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        # Hook decoder onto 4 layers from specified ViT layers
        layers = [encoder_tokens[hook] for hook in self.hooks]

        # Extract only task-relevant tokens and ignore global tokens.
        layers = [self.adapt_tokens(tokens, input_info) for tokens in layers]

        # Reshape tokens to spatial representation
        layers = [rearrange(tokens, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for tokens in layers]

        # Postprocess activations
        layers = [self.act_postprocess[idx](feat) for idx, feat in enumerate(layers)]

        # Project layers to chosen feature dim
        layers = [self.scratch.layer_rn[idx](feat) for idx, feat in enumerate(layers)]

        # Fuse layers using refinement stages
        path_4 = self.scratch.refinenet4(layers[3])
        path_3 = self.scratch.refinenet3(path_4, layers[2])
        path_2 = self.scratch.refinenet2(path_3, layers[1])
        path_1 = self.scratch.refinenet1(path_2, layers[0])

        # Output head
        out = self.head(path_1)

        return out