# multimae/multimae.py
# Copyright (c) EPFL VILAB.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv and MAE code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/facebookresearch/moco-v3
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/BUPT-PRIV/MAE-priv
# https://github.com/facebookresearch/mae
# --------------------------------------------------------
import itertools
import math
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Optional, Union
import torch
from einops import rearrange, repeat
from torch import nn
from torch.distributions.dirichlet import Dirichlet
from utils.registry import register_model
from .multimae_utils import Block, trunc_normal_
__all__ = [
'pretrain_multimae_base',
'pretrain_multimae_large',
'multivit_base',
'multivit_large',
]
class MultiMAE(nn.Module):
"""MultiMAE: Multi-task Multi-modal Masked Autoencoder
This module performs masking in its forward pass.
    The MultiViT module defined below inherits from this module, performs a regular (non-masked)
    forward pass, and should be used instead for downstream tasks.
:param input_adapters: Dictionary of task -> input adapters
:param output_adapters: Optional dictionary of task -> output adapters
:param num_global_tokens: Number of additional global tokens to add (like cls tokens), default is 1
:param dim_tokens: Dimension of encoder tokens
:param depth: Depth of encoder
:param num_heads: Number of attention heads
:param mlp_ratio: MLP hidden dim ratio
:param qkv_bias: Set to False to disable bias
:param drop_rate: Dropout after MLPs and Attention
:param attn_drop_rate: Attention matrix drop rate
:param drop_path_rate: DropPath drop rate
:param norm_layer: Type of normalization layer
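
    Example (illustrative sketch; the concrete adapters are defined in the accompanying
    input/output adapter modules, and `rgb_adapter` / `depth_adapter` here are placeholders):
        model = MultiMAE(input_adapters={'rgb': rgb_adapter, 'depth': depth_adapter},
                         output_adapters=None)
        x = {'rgb': rgb, 'depth': depth}  # modality name -> (B, C, H, W) tensor
        encoder_tokens, task_masks = model(x, num_encoded_tokens=98, alphas=1.0)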
"""
def __init__(self,
input_adapters: Dict[str, nn.Module],
output_adapters: Optional[Dict[str, nn.Module]],
num_global_tokens: int = 1,
dim_tokens: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6)):
super().__init__()
# Initialize input and output adapters
for adapter in input_adapters.values():
adapter.init(dim_tokens=dim_tokens)
self.input_adapters = nn.ModuleDict(input_adapters)
if output_adapters is not None:
for adapter in output_adapters.values():
adapter.init(dim_tokens_enc=dim_tokens)
self.output_adapters = nn.ModuleDict(output_adapters)
else:
self.output_adapters = None
# Additional learnable tokens that can be used by encoder to process/store global information
self.num_global_tokens = num_global_tokens
self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, dim_tokens))
trunc_normal_(self.global_tokens, std=0.02)
# Transformer encoder
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.encoder = nn.Sequential(*[
Block(dim=dim_tokens, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
for i in range(depth)
])
self.apply(self._init_weights)
for name, m in self.named_modules():
if isinstance(m, nn.Linear):
if 'qkv' in name:
# treat the weights of Q, K, V separately
val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
nn.init.uniform_(m.weight, -val, val)
elif 'kv' in name:
# treat the weights of K, V separately
val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
nn.init.uniform_(m.weight, -val, val)
if isinstance(m, nn.Conv2d):
if '.proj' in name:
# From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
w = m.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.encoder)
@torch.jit.ignore
def no_weight_decay(self):
no_wd_set = {'global_tokens'}
for task, adapter in self.input_adapters.items():
if hasattr(adapter, 'no_weight_decay'):
to_skip = adapter.no_weight_decay()
to_skip = set([f'input_adapters.{task}.{name}' for name in to_skip])
no_wd_set = no_wd_set | to_skip
        if self.output_adapters is not None:
            for task, adapter in self.output_adapters.items():
                if hasattr(adapter, 'no_weight_decay'):
                    to_skip = adapter.no_weight_decay()
                    to_skip = set([f'output_adapters.{task}.{name}' for name in to_skip])
                    no_wd_set = no_wd_set | to_skip
return no_wd_set
    def sample_alphas(self, B: int, n_tasks: int, alphas: Union[float, List[float]] = 1.0, eps: float = 1e-5):
"""
Sample alphas for Dirichlet sampling such that tasks are first uniformly chosen and then Dirichlet sampling
is performed over the chosen ones.
:param B: Batch size
:param n_tasks: Number of input tasks
:param alphas: Float or list to multiply task choices {0,1} by
:param eps: Small constant since Dirichlet alphas need to be positive
"""
valid_task_choices = torch.Tensor([list(i) for i in itertools.product([0, 1], repeat=n_tasks)][1:])
rand_per_sample_choice = torch.randint(0, len(valid_task_choices), (B,))
alphas_tensor = torch.index_select(valid_task_choices, 0, rand_per_sample_choice)
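        # Tasks not chosen for a sample get alpha ~= eps, i.e. (almost) zero Dirichlet weight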
alphas_tensor = alphas_tensor * torch.tensor(alphas) + eps
return alphas_tensor
def generate_random_masks(self,
input_tokens: Dict[str, torch.Tensor],
num_encoded_tokens: int,
alphas: Union[float, List[float]] = 1.0,
                              sample_tasks_uniformly: bool = False):
"""
Sample a total of num_encoded_tokens from different tasks using Dirichlet sampling.
:param input_tokens: Dictionary of tensors to sample num_encoded_tokens from
:param num_encoded_tokens: Number of tokens to select
:param alphas: Dirichlet distribution parameter alpha. Lower alpha = harder,
less uniform sampling. Can be float or list of floats.
        :param sample_tasks_uniformly: Set to True to first choose a random non-empty subset of tasks
            for each sample in the batch. Dirichlet sampling is then performed over the chosen subset.
"""
B = list(input_tokens.values())[0].shape[0]
device = list(input_tokens.values())[0].device
alphas = [alphas] * len(input_tokens) if isinstance(alphas, float) else alphas
if sample_tasks_uniformly:
alphas = self.sample_alphas(B, len(input_tokens), alphas=alphas)
task_sampling_dist = Dirichlet(alphas).sample().to(device)
else:
task_sampling_dist = Dirichlet(torch.Tensor(alphas)).sample((B,)).to(device)
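        # Per-task token budgets; due to rounding their sum may differ slightly from num_encoded_tokens (adjusted below)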
samples_per_task = (task_sampling_dist * num_encoded_tokens).round().long()
task_masks = []
num_tokens_per_task = [task_tokens.shape[1] for task_tokens in input_tokens.values()]
for i, num_tokens in enumerate(num_tokens_per_task):
            # Randomly permute token indices by argsorting uniform noise
            noise = torch.rand(B, num_tokens, device=device)  # noise in [0, 1]
            ids_arange_shuffle = torch.argsort(noise, dim=1)  # random permutation per sample
mask = torch.arange(num_tokens, device=device).unsqueeze(0).expand(B, -1)
mask = torch.gather(mask, dim=1, index=ids_arange_shuffle)
# 0 is keep (unmasked), 1 is remove (masked)
mask = torch.where(mask < samples_per_task[:, i].unsqueeze(1), 0, 1)
task_masks.append(mask)
mask_all = torch.cat(task_masks, dim=1)
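        # Sort so visible tokens (0) come first; the added noise randomly permutes tokens within the visible and masked groups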
ids_shuffle = torch.argsort(mask_all + torch.rand_like(mask_all.float()), dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_keep = ids_shuffle[:, :num_encoded_tokens]
# Update binary mask to adjust for task rounding
mask_all = torch.ones_like(mask_all)
mask_all[:, :num_encoded_tokens] = 0
# Unshuffle to get the binary mask
mask_all = torch.gather(mask_all, dim=1, index=ids_restore)
# Split to get task masks
task_masks = torch.split(mask_all, num_tokens_per_task, dim=1)
# Convert to dict
task_masks = {domain: mask for domain, mask in zip(input_tokens.keys(), task_masks)}
return task_masks, ids_keep, ids_restore
@staticmethod
def make_mask(N_H, N_W, xy_idxs, full_tasks=[], indicate_visible=True, flatten=True, device='cuda'):
"""
Creates masks for each task, given lists of un-masked x,y coordinates.
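
        Example (illustrative): on a 14x14 token grid, reveal two rgb patches and the full
        semseg map, while depth stays fully masked:

            masks = MultiMAE.make_mask(
                N_H=14, N_W=14,
                xy_idxs={'rgb': [(0, 0), (3, 1)], 'depth': [], 'semseg': []},
                full_tasks=['semseg'], device='cpu')
            # masks['rgb'] has shape (1, 196); 0 = visible, 1 = masked

        Coordinates are given as (x, y) pairs of visible patches.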
"""
xy_idxs = {
k: torch.LongTensor(v)
for k, v in xy_idxs.items()
}
task_masks = {
k: torch.ones(N_H, N_W).to(device)
for k in xy_idxs.keys()
}
for k in xy_idxs.keys():
if len(xy_idxs[k]) > 0:
task_masks[k][xy_idxs[k][:, 1], xy_idxs[k][:, 0]] = 0
for task in full_tasks:
task_masks[task][:] = 0
if not indicate_visible:
task_masks = {k: 1 - v for k, v in task_masks.items()}
if flatten:
task_masks = {k: v.flatten().unsqueeze(0) for k, v in task_masks.items()}
return task_masks
def generate_input_info(self, input_task_tokens, image_size):
input_info = OrderedDict()
i = 0
input_info['tasks'] = {}
for domain, tensor in input_task_tokens.items():
num_tokens = tensor.shape[1]
d = {
'num_tokens': num_tokens,
'has_2d_posemb': True, # TODO: Modify when adding non-2D tasks
'start_idx': i,
'end_idx': i + num_tokens,
}
i += num_tokens
input_info['tasks'][domain] = d
input_info['image_size'] = image_size
input_info['num_task_tokens'] = i
input_info['num_global_tokens'] = self.num_global_tokens
return input_info
def forward(self,
x: Union[Dict[str, torch.Tensor], torch.Tensor],
mask_inputs: bool = True,
task_masks: Dict[str, torch.Tensor] = None,
num_encoded_tokens: int = 128,
alphas: Union[float, List[float]] = 1.0,
sample_tasks_uniformly: bool = False,
fp32_output_adapters: List[str] = []):
"""
Forward pass through input adapters, transformer encoder and output adapters.
If specified, will randomly drop input tokens.
:param x: Input tensor or dictionary of tensors
:param mask_inputs: Set to True to enable random masking of input patches
:param task_masks: Optional dictionary of task->mask pairs.
:param num_encoded_tokens: Number of tokens to randomly select for encoder.
Only used if mask_inputs is True.
        :param alphas: Dirichlet distribution parameter alpha for task sampling.
            Lower alpha = harder, less uniform sampling. Can be float or list of floats.
        :param sample_tasks_uniformly: Set to True to first sample a random subset of tasks per batch
            element, before Dirichlet sampling decides the share of encoded tokens among them.
:param fp32_output_adapters: List of task identifiers to force output adapters to
run with mixed precision turned off for stability reasons.
"""
## Processing input modalities
# If input x is a Tensor, assume it's RGB
x = {'rgb': x} if isinstance(x, torch.Tensor) else x
# Need image size for tokens->image reconstruction
# We assume that at least one of rgb or semseg is given as input before masking
if 'rgb' in x:
B, C, H, W = x['rgb'].shape
elif 'semseg' in x:
B, H, W = x['semseg'].shape
H *= self.input_adapters['semseg'].stride_level
W *= self.input_adapters['semseg'].stride_level
else:
B, C, H, W = list(x.values())[0].shape # TODO: Deal with case where not all have same shape
# Encode selected inputs to tokens
input_task_tokens = {
domain: self.input_adapters[domain](tensor)
for domain, tensor in x.items()
if domain in self.input_adapters
}
input_info = self.generate_input_info(input_task_tokens=input_task_tokens, image_size=(H, W))
# Select random subset of tokens from the chosen input tasks and concatenate them
if mask_inputs:
num_encoded_tokens = num_encoded_tokens if num_encoded_tokens is not None else self.num_encoded_tokens
else:
num_encoded_tokens = sum([tensor.shape[1] for tensor in input_task_tokens.values()])
## Generating masks
if task_masks is None:
task_masks, ids_keep, ids_restore = self.generate_random_masks(
input_task_tokens,
num_encoded_tokens,
alphas=alphas,
sample_tasks_uniformly=sample_tasks_uniformly
)
else:
mask_all = torch.cat([task_masks[task] for task in input_task_tokens.keys()], dim=1)
ids_shuffle = torch.argsort(mask_all, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
            # Keep the visible (0) tokens; assumes every sample in the batch has the same number of visible tokens
            ids_keep = ids_shuffle[:, :int((mask_all[0] == 0).sum())]
input_tokens = torch.cat([task_tokens for task_tokens in input_task_tokens.values()], dim=1)
# Apply mask
input_tokens = torch.gather(input_tokens, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, input_tokens.shape[2]))
# Add global tokens to input tokens
global_tokens = repeat(self.global_tokens, '() n d -> b n d', b=B)
input_tokens = torch.cat([input_tokens, global_tokens], dim=1)
## Transformer forward pass
encoder_tokens = self.encoder(input_tokens)
## Output decoders
if self.output_adapters is None:
return encoder_tokens, task_masks
# Decode tokens for each task using task-specific output adapters
preds = {
domain: self.output_adapters[domain](
encoder_tokens=encoder_tokens,
input_info=input_info,
ids_keep=ids_keep,
ids_restore=ids_restore,
)
for domain in self.output_adapters
if domain not in fp32_output_adapters
}
# Force running selected output adapters in fp32 mode
with torch.cuda.amp.autocast(enabled=False):
for domain in fp32_output_adapters:
if domain not in self.output_adapters:
continue
preds[domain] = self.output_adapters[domain](
encoder_tokens=encoder_tokens.float(),
input_info=input_info,
ids_keep=ids_keep,
ids_restore=ids_restore,
)
return preds, task_masks
@register_model
def pretrain_multimae_base(
input_adapters: Dict[str, nn.Module],
output_adapters: Optional[Dict[str, nn.Module]],
**kwargs):
model = MultiMAE(
input_adapters=input_adapters,
output_adapters=output_adapters,
dim_tokens=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs
)
return model
@register_model
def pretrain_multimae_large(
input_adapters: Dict[str, nn.Module],
output_adapters: Optional[Dict[str, nn.Module]],
**kwargs):
model = MultiMAE(
input_adapters=input_adapters,
output_adapters=output_adapters,
dim_tokens=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs
)
return model
class MultiViT(MultiMAE):
"""MultiViT: Multi-modal Vision Transformer
This is MultiMAE without masking and with a simplified / faster forward pass
:param input_adapters: Dictionary of task -> input adapters
:param output_adapters: Optional dictionary of task -> output adapters
:param num_global_tokens: Number of additional global tokens to add (like cls tokens), default is 1
:param dim_tokens: Dimension of encoder tokens
:param depth: Depth of encoder
:param num_heads: Number of attention heads
:param mlp_ratio: MLP hidden dim ratio
:param qkv_bias: Set to False to disable bias
:param drop_rate: Dropout after MLPs and Attention
:param attn_drop_rate: Attention matrix drop rate
:param drop_path_rate: DropPath drop rate
:param norm_layer: Type of normalization layer
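
    Example (illustrative sketch; `rgb_adapter` and `semseg_head` are placeholders):
        model = MultiViT(input_adapters={'rgb': rgb_adapter},
                         output_adapters={'semseg': semseg_head})
        preds = model({'rgb': rgb})  # dict of task -> prediction, here {'semseg': ...}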
"""
def process_input(self, x):
# If input x is a Tensor, assume it's RGB
x = {'rgb': x} if isinstance(x, torch.Tensor) else x
# Need image size for tokens->image reconstruction
if 'rgb' in x:
B, _, H, W = x['rgb'].shape
elif 'semseg' in x:
B, H, W = x['semseg'].shape
H *= self.input_adapters['semseg'].stride_level
W *= self.input_adapters['semseg'].stride_level
else:
B, _, H, W = list(x.values())[0].shape # TODO: Deal with case where not all have same shape
# Encode selected inputs to tokens
input_task_tokens = {
domain: self.input_adapters[domain](tensor)
for domain, tensor in x.items()
if domain in self.input_adapters
}
input_info = self.generate_input_info(input_task_tokens=input_task_tokens, image_size=(H, W))
input_tokens = torch.cat([task_tokens for task_tokens in input_task_tokens.values()], dim=1)
# Add global tokens to input tokens
global_tokens = repeat(self.global_tokens, '() n d -> b n d', b=B)
input_tokens = torch.cat([input_tokens, global_tokens], dim=1)
return input_tokens, input_info
def forward(self, x: Union[Dict[str, torch.Tensor], torch.Tensor], return_all_layers=False, **kwargs):
"""
Forward pass through input adapters, transformer encoder and output adapters.
:param x: Input tensor or dictionary of tensors
        :param return_all_layers: Set to True to return the token outputs of all transformer blocks as a list
"""
input_tokens, input_info = self.process_input(x)
# Pass tokens through Transformer
if not return_all_layers:
encoder_tokens = self.encoder(input_tokens)
else:
# Optionally access every intermediate layer
encoder_tokens = []
tokens = input_tokens
for block in self.encoder:
tokens = block(tokens)
encoder_tokens.append(tokens)
if self.output_adapters is None:
return encoder_tokens
# Decode tokens for each task using task-specific output adapters
preds = {
domain: self.output_adapters[domain](
encoder_tokens=encoder_tokens,
input_info=input_info,
)
for domain in self.output_adapters
}
return preds
@register_model
def multivit_base(
input_adapters: Dict[str, nn.Module],
output_adapters: Optional[Dict[str, nn.Module]],
**kwargs):
model = MultiViT(
input_adapters=input_adapters,
output_adapters=output_adapters,
dim_tokens=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs
)
return model
@register_model
def multivit_large(
input_adapters: Dict[str, nn.Module],
output_adapters: Optional[Dict[str, nn.Module]],
**kwargs):
model = MultiViT(
input_adapters=input_adapters,
output_adapters=output_adapters,
dim_tokens=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs
)
return model
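

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch, not part of the original training code).
    # It avoids depending on the package's real adapter modules by defining a toy
    # patch-embedding input adapter that exposes the interface MultiMAE expects:
    # an `init(dim_tokens)` hook and a forward returning (B, N, D) tokens.
    # Run it from the repo root, e.g. `python -m multimae.multimae`, so the relative imports resolve.
    class ToyPatchAdapter(nn.Module):
        """Toy input adapter: non-overlapping patch embedding, no positional embedding."""
        def __init__(self, num_channels: int, patch_size: int = 16):
            super().__init__()
            self.num_channels = num_channels
            self.patch_size = patch_size

        def init(self, dim_tokens: int):
            # Called by MultiMAE.__init__ once the token dimension is known
            self.proj = nn.Conv2d(self.num_channels, dim_tokens,
                                  kernel_size=self.patch_size, stride=self.patch_size)

        def forward(self, x):
            # (B, C, H, W) -> (B, N, D) patch tokens
            return self.proj(x).flatten(2).transpose(1, 2)

    model = MultiMAE(
        input_adapters={'rgb': ToyPatchAdapter(3), 'depth': ToyPatchAdapter(1)},
        output_adapters=None,  # encoder only: forward returns (encoder_tokens, task_masks)
        dim_tokens=64, depth=2, num_heads=2,  # tiny config, just to exercise the code path
    )
    x = {'rgb': torch.randn(2, 3, 224, 224), 'depth': torch.randn(2, 1, 224, 224)}
    encoder_tokens, task_masks = model(x, num_encoded_tokens=98, alphas=1.0)
    print(encoder_tokens.shape)  # torch.Size([2, 99, 64]) with the default single global token
    print({k: v.shape for k, v in task_masks.items()})  # per-task binary masks, 0 = visible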