# Copyright (c) EPFL VILAB.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv and MAE code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/facebookresearch/moco-v3
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/BUPT-PRIV/MAE-priv
# https://github.com/facebookresearch/mae
# --------------------------------------------------------
import itertools
import math
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Optional, Union

import torch
from einops import rearrange, repeat
from torch import nn
from torch.distributions.dirichlet import Dirichlet

from utils.registry import register_model

from .multimae_utils import Block, trunc_normal_

__all__ = [
    'pretrain_multimae_base',
    'pretrain_multimae_large',
    'multivit_base',
    'multivit_large',
]

class MultiMAE(nn.Module):
    """MultiMAE: Multi-task Multi-modal Masked Autoencoder

    This module performs masking in its forward pass.
    The MultiViT module defined below inherits from this module and performs a regular forward pass,
    and should be used instead for downstream tasks.

    :param input_adapters: Dictionary of task -> input adapters
    :param output_adapters: Optional dictionary of task -> output adapters
    :param num_global_tokens: Number of additional global tokens to add (like cls tokens), default is 1
    :param dim_tokens: Dimension of encoder tokens
    :param depth: Depth of encoder
    :param num_heads: Number of attention heads
    :param mlp_ratio: MLP hidden dim ratio
    :param qkv_bias: Set to False to disable bias
    :param drop_rate: Dropout after MLPs and Attention
    :param attn_drop_rate: Attention matrix drop rate
    :param drop_path_rate: DropPath drop rate
    :param norm_layer: Type of normalization layer
    """
    def __init__(self,
                 input_adapters: Dict[str, nn.Module],
                 output_adapters: Optional[Dict[str, nn.Module]],
                 num_global_tokens: int = 1,
                 dim_tokens: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.0,
                 qkv_bias: bool = True,
                 drop_rate: float = 0.0,
                 attn_drop_rate: float = 0.0,
                 drop_path_rate: float = 0.0,
                 norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()

        # Initialize input and output adapters
        for adapter in input_adapters.values():
            adapter.init(dim_tokens=dim_tokens)
        self.input_adapters = nn.ModuleDict(input_adapters)
        if output_adapters is not None:
            for adapter in output_adapters.values():
                adapter.init(dim_tokens_enc=dim_tokens)
            self.output_adapters = nn.ModuleDict(output_adapters)
        else:
            self.output_adapters = None

        # Additional learnable tokens that can be used by encoder to process/store global information
        self.num_global_tokens = num_global_tokens
        self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, dim_tokens))
        trunc_normal_(self.global_tokens, std=0.02)

        # Transformer encoder
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.encoder = nn.Sequential(*[
            Block(dim=dim_tokens, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                  drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)
        ])

        self.apply(self._init_weights)
        for name, m in self.named_modules():
            if isinstance(m, nn.Linear):
                if 'qkv' in name:
                    # treat the weights of Q, K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                elif 'kv' in name:
                    # treat the weights of K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)

            if isinstance(m, nn.Conv2d):
                if '.proj' in name:
                    # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
                    w = m.weight.data
                    nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.encoder)

    def no_weight_decay(self):
        no_wd_set = {'global_tokens'}

        for task, adapter in self.input_adapters.items():
            if hasattr(adapter, 'no_weight_decay'):
                to_skip = adapter.no_weight_decay()
                to_skip = set([f'input_adapters.{task}.{name}' for name in to_skip])
                no_wd_set = no_wd_set | to_skip

        # Guard against output_adapters=None (e.g. encoder-only fine-tuning)
        if self.output_adapters is not None:
            for task, adapter in self.output_adapters.items():
                if hasattr(adapter, 'no_weight_decay'):
                    to_skip = adapter.no_weight_decay()
                    to_skip = set([f'output_adapters.{task}.{name}' for name in to_skip])
                    no_wd_set = no_wd_set | to_skip

        return no_wd_set

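    # Hedged sketch of how the returned names are typically consumed; the optimizer setup
    # below is illustrative and not part of this file:
    #   skip = model.no_weight_decay()
    #   decay = [p for n, p in model.named_parameters() if n not in skip and p.ndim > 1]
    #   no_decay = [p for n, p in model.named_parameters() if n in skip or p.ndim <= 1]
    #   optimizer = torch.optim.AdamW(
    #       [{'params': decay, 'weight_decay': 0.05},
    #        {'params': no_decay, 'weight_decay': 0.0}])
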
    def sample_alphas(self, B: int, n_tasks: int, alphas: float = 1.0, eps: float = 1e-5):
        """
        Sample alphas for Dirichlet sampling such that tasks are first uniformly chosen and then Dirichlet sampling
        is performed over the chosen ones.

        :param B: Batch size
        :param n_tasks: Number of input tasks
        :param alphas: Float or list to multiply task choices {0,1} by
        :param eps: Small constant since Dirichlet alphas need to be positive
        """
        valid_task_choices = torch.Tensor([list(i) for i in itertools.product([0, 1], repeat=n_tasks)][1:])
        rand_per_sample_choice = torch.randint(0, len(valid_task_choices), (B,))
        alphas_tensor = torch.index_select(valid_task_choices, 0, rand_per_sample_choice)
        alphas_tensor = alphas_tensor * torch.tensor(alphas) + eps
        return alphas_tensor

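    # Minimal sketch of what sample_alphas produces (values below are illustrative, not from
    # the source). For n_tasks=2, valid_task_choices enumerates the non-empty task subsets
    # [[0, 1], [1, 0], [1, 1]], one of which is drawn per sample. With alphas=1.0 and
    # eps=1e-5, a batch of B=3 might yield
    #   tensor([[1e-05,       1.0 + 1e-05],
    #           [1.0 + 1e-05, 1.0 + 1e-05],
    #           [1.0 + 1e-05, 1e-05      ]])
    # so the Dirichlet draw in generate_random_masks puts (almost) no probability mass on
    # the tasks that were not selected for that sample.
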
    def generate_random_masks(self,
                              input_tokens: Dict[str, torch.Tensor],
                              num_encoded_tokens: int,
                              alphas: Union[float, List[float]] = 1.0,
                              sample_tasks_uniformly: bool = False):
        """
        Sample a total of num_encoded_tokens from different tasks using Dirichlet sampling.

        :param input_tokens: Dictionary of tensors to sample num_encoded_tokens from
        :param num_encoded_tokens: Number of tokens to select
        :param alphas: Dirichlet distribution parameter alpha. Lower alpha = harder,
            less uniform sampling. Can be float or list of floats.
        :param sample_tasks_uniformly: Set to True to first sample 1-n_tasks uniformly at random
            for each sample in the batch. Dirichlet sampling is then done over selected subsets.
        """
        B = list(input_tokens.values())[0].shape[0]
        device = list(input_tokens.values())[0].device

        alphas = [alphas] * len(input_tokens) if isinstance(alphas, float) else alphas
        if sample_tasks_uniformly:
            alphas = self.sample_alphas(B, len(input_tokens), alphas=alphas)
            task_sampling_dist = Dirichlet(alphas).sample().to(device)
        else:
            task_sampling_dist = Dirichlet(torch.Tensor(alphas)).sample((B,)).to(device)

        samples_per_task = (task_sampling_dist * num_encoded_tokens).round().long()

        task_masks = []
        num_tokens_per_task = [task_tokens.shape[1] for task_tokens in input_tokens.values()]
        for i, num_tokens in enumerate(num_tokens_per_task):
            # Use noise to shuffle arange
            noise = torch.rand(B, num_tokens, device=device)  # noise in [0, 1]
            ids_arange_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
            mask = torch.arange(num_tokens, device=device).unsqueeze(0).expand(B, -1)
            mask = torch.gather(mask, dim=1, index=ids_arange_shuffle)
            # 0 is keep (unmasked), 1 is remove (masked)
            mask = torch.where(mask < samples_per_task[:, i].unsqueeze(1), 0, 1)
            task_masks.append(mask)

        mask_all = torch.cat(task_masks, dim=1)
        ids_shuffle = torch.argsort(mask_all + torch.rand_like(mask_all.float()), dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :num_encoded_tokens]

        # Update binary mask to adjust for task rounding
        mask_all = torch.ones_like(mask_all)
        mask_all[:, :num_encoded_tokens] = 0
        # Unshuffle to get the binary mask
        mask_all = torch.gather(mask_all, dim=1, index=ids_restore)
        # Split to get task masks
        task_masks = torch.split(mask_all, num_tokens_per_task, dim=1)
        # Convert to dict
        task_masks = {domain: mask for domain, mask in zip(input_tokens.keys(), task_masks)}

        return task_masks, ids_keep, ids_restore

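    # Hedged shape sketch of the values returned above (task names and sizes are assumptions):
    # for two tasks with 196 tokens each, B=2 and num_encoded_tokens=128,
    #   task_masks  -> {'rgb': (2, 196), 'depth': (2, 196)} binary masks, 0 = keep, 1 = drop
    #   ids_keep    -> (2, 128) indices into the concatenated (2, 392) token sequence
    #   ids_restore -> (2, 392) permutation that undoes the shuffle for the decoders
    # Each row of the concatenated mask ends up with exactly num_encoded_tokens zeros,
    # regardless of how the per-task Dirichlet proportions were rounded.
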
    @staticmethod
    def make_mask(N_H, N_W, xy_idxs, full_tasks=[], indicate_visible=True, flatten=True, device='cuda'):
        """
        Creates masks for each task, given lists of un-masked x,y coordinates.
        """
        xy_idxs = {
            k: torch.LongTensor(v)
            for k, v in xy_idxs.items()
        }

        task_masks = {
            k: torch.ones(N_H, N_W).to(device)
            for k in xy_idxs.keys()
        }

        for k in xy_idxs.keys():
            if len(xy_idxs[k]) > 0:
                task_masks[k][xy_idxs[k][:, 1], xy_idxs[k][:, 0]] = 0

        for task in full_tasks:
            task_masks[task][:] = 0

        if not indicate_visible:
            task_masks = {k: 1 - v for k, v in task_masks.items()}

        if flatten:
            task_masks = {k: v.flatten().unsqueeze(0) for k, v in task_masks.items()}

        return task_masks

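    # Minimal usage sketch for make_mask (task names and coordinates are made up):
    #   masks = MultiMAE.make_mask(
    #       N_H=14, N_W=14,
    #       xy_idxs={'rgb': [(0, 0), (1, 0)], 'depth': []},
    #       full_tasks=['depth'],
    #       device='cpu',
    #   )
    # With indicate_visible=True (the default), each returned mask is flattened to shape
    # (1, 14 * 14) with 0 at visible patches and 1 everywhere else; tasks listed in
    # full_tasks are fully visible. The result can be passed as task_masks to forward()
    # below to control masking manually instead of sampling it.
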
    def generate_input_info(self, input_task_tokens, image_size):
        input_info = OrderedDict()
        i = 0
        input_info['tasks'] = {}
        for domain, tensor in input_task_tokens.items():
            num_tokens = tensor.shape[1]
            d = {
                'num_tokens': num_tokens,
                'has_2d_posemb': True,  # TODO: Modify when adding non-2D tasks
                'start_idx': i,
                'end_idx': i + num_tokens,
            }
            i += num_tokens
            input_info['tasks'][domain] = d

        input_info['image_size'] = image_size
        input_info['num_task_tokens'] = i
        input_info['num_global_tokens'] = self.num_global_tokens

        return input_info

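    # Sketch of the resulting dict, assuming two 224x224 inputs patchified into 16x16 patches
    # (so 14 * 14 = 196 tokens per task); the task names are illustrative:
    #   {
    #       'tasks': {
    #           'rgb':   {'num_tokens': 196, 'has_2d_posemb': True, 'start_idx': 0,   'end_idx': 196},
    #           'depth': {'num_tokens': 196, 'has_2d_posemb': True, 'start_idx': 196, 'end_idx': 392},
    #       },
    #       'image_size': (224, 224),
    #       'num_task_tokens': 392,
    #       'num_global_tokens': 1,
    #   }
    # Output adapters rely on start_idx/end_idx to slice their task's tokens out of the
    # concatenated encoder sequence.
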
    def forward(self,
                x: Union[Dict[str, torch.Tensor], torch.Tensor],
                mask_inputs: bool = True,
                task_masks: Dict[str, torch.Tensor] = None,
                num_encoded_tokens: int = 128,
                alphas: Union[float, List[float]] = 1.0,
                sample_tasks_uniformly: bool = False,
                fp32_output_adapters: List[str] = []):
        """
        Forward pass through input adapters, transformer encoder and output adapters.
        If specified, will randomly drop input tokens.

        :param x: Input tensor or dictionary of tensors
        :param mask_inputs: Set to True to enable random masking of input patches
        :param task_masks: Optional dictionary of task->mask pairs.
        :param num_encoded_tokens: Number of tokens to randomly select for encoder.
            Only used if mask_inputs is True.
        :param alphas: Dirichlet distribution parameter alpha for task sampling.
            Lower alpha = harder, less uniform sampling. Can be float or list of floats.
        :param sample_tasks_uniformly: Set to True if tasks should be uniformly presampled,
            before Dirichlet sampling decides share of masked tokens between them.
        :param fp32_output_adapters: List of task identifiers to force output adapters to
            run with mixed precision turned off for stability reasons.
        """

        ## Processing input modalities
        # If input x is a Tensor, assume it's RGB
        x = {'rgb': x} if isinstance(x, torch.Tensor) else x

        # Need image size for tokens->image reconstruction
        # We assume that at least one of rgb or semseg is given as input before masking
        if 'rgb' in x:
            B, C, H, W = x['rgb'].shape
        elif 'semseg' in x:
            B, H, W = x['semseg'].shape
            H *= self.input_adapters['semseg'].stride_level
            W *= self.input_adapters['semseg'].stride_level
        else:
            B, C, H, W = list(x.values())[0].shape  # TODO: Deal with case where not all have same shape

        # Encode selected inputs to tokens
        input_task_tokens = {
            domain: self.input_adapters[domain](tensor)
            for domain, tensor in x.items()
            if domain in self.input_adapters
        }

        input_info = self.generate_input_info(input_task_tokens=input_task_tokens, image_size=(H, W))

        # Select random subset of tokens from the chosen input tasks and concatenate them
        if mask_inputs:
            num_encoded_tokens = num_encoded_tokens if num_encoded_tokens is not None else self.num_encoded_tokens
        else:
            num_encoded_tokens = sum([tensor.shape[1] for tensor in input_task_tokens.values()])

        ## Generating masks
        if task_masks is None:
            task_masks, ids_keep, ids_restore = self.generate_random_masks(
                input_task_tokens,
                num_encoded_tokens,
                alphas=alphas,
                sample_tasks_uniformly=sample_tasks_uniformly
            )
        else:
            mask_all = torch.cat([task_masks[task] for task in input_task_tokens.keys()], dim=1)
            ids_shuffle = torch.argsort(mask_all, dim=1)
            ids_restore = torch.argsort(ids_shuffle, dim=1)
            ids_keep = ids_shuffle[:, :(mask_all == 0).sum()]

        input_tokens = torch.cat([task_tokens for task_tokens in input_task_tokens.values()], dim=1)

        # Apply mask
        input_tokens = torch.gather(input_tokens, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, input_tokens.shape[2]))

        # Add global tokens to input tokens
        global_tokens = repeat(self.global_tokens, '() n d -> b n d', b=B)
        input_tokens = torch.cat([input_tokens, global_tokens], dim=1)

        ## Transformer forward pass
        encoder_tokens = self.encoder(input_tokens)

        ## Output decoders
        if self.output_adapters is None:
            return encoder_tokens, task_masks

        # Decode tokens for each task using task-specific output adapters
        preds = {
            domain: self.output_adapters[domain](
                encoder_tokens=encoder_tokens,
                input_info=input_info,
                ids_keep=ids_keep,
                ids_restore=ids_restore,
            )
            for domain in self.output_adapters
            if domain not in fp32_output_adapters
        }

        # Force running selected output adapters in fp32 mode
        with torch.cuda.amp.autocast(enabled=False):
            for domain in fp32_output_adapters:
                if domain not in self.output_adapters:
                    continue
                preds[domain] = self.output_adapters[domain](
                    encoder_tokens=encoder_tokens.float(),
                    input_info=input_info,
                    ids_keep=ids_keep,
                    ids_restore=ids_restore,
                )

        return preds, task_masks

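# Hedged pre-training call sketch (tensor sizes and task names are assumptions, and the
# adapter dictionaries come from elsewhere in the code base):
#   model = pretrain_multimae_base(input_adapters, output_adapters)
#   preds, masks = model(
#       {'rgb': torch.randn(2, 3, 224, 224), 'depth': torch.randn(2, 1, 224, 224)},
#       num_encoded_tokens=128, alphas=1.0,
#   )
# preds maps each output-adapter task to its reconstruction, and masks maps each input task
# to its binary 0 = keep / 1 = drop mask, which the pre-training objective can use to
# restrict the loss to masked patches.
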
@register_model
def pretrain_multimae_base(
        input_adapters: Dict[str, nn.Module],
        output_adapters: Optional[Dict[str, nn.Module]],
        **kwargs):
    model = MultiMAE(
        input_adapters=input_adapters,
        output_adapters=output_adapters,
        dim_tokens=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model

@register_model
def pretrain_multimae_large(
        input_adapters: Dict[str, nn.Module],
        output_adapters: Optional[Dict[str, nn.Module]],
        **kwargs):
    model = MultiMAE(
        input_adapters=input_adapters,
        output_adapters=output_adapters,
        dim_tokens=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model

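# Hedged construction sketch. The adapter classes and their arguments are assumed to be
# provided by the accompanying input_adapters / output_adapters modules of this code base
# (not shown in this file), so treat names and parameters here as assumptions:
#   input_adapters = {
#       'rgb': PatchedInputAdapter(num_channels=3, stride_level=1, patch_size_full=16),
#       'depth': PatchedInputAdapter(num_channels=1, stride_level=1, patch_size_full=16),
#   }
#   output_adapters = {
#       'rgb': SpatialOutputAdapter(num_channels=3, stride_level=1, patch_size_full=16),
#       'depth': SpatialOutputAdapter(num_channels=1, stride_level=1, patch_size_full=16),
#   }
#   model = pretrain_multimae_base(input_adapters=input_adapters, output_adapters=output_adapters)
# The factories only fix the ViT-B / ViT-L encoder hyperparameters; everything
# modality-specific comes from the adapter dictionaries.
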
class MultiViT(MultiMAE):
    """MultiViT: Multi-modal Vision Transformer

    This is MultiMAE without masking and with a simplified / faster forward pass.

    :param input_adapters: Dictionary of task -> input adapters
    :param output_adapters: Optional dictionary of task -> output adapters
    :param num_global_tokens: Number of additional global tokens to add (like cls tokens), default is 1
    :param dim_tokens: Dimension of encoder tokens
    :param depth: Depth of encoder
    :param num_heads: Number of attention heads
    :param mlp_ratio: MLP hidden dim ratio
    :param qkv_bias: Set to False to disable bias
    :param drop_rate: Dropout after MLPs and Attention
    :param attn_drop_rate: Attention matrix drop rate
    :param drop_path_rate: DropPath drop rate
    :param norm_layer: Type of normalization layer
    """
    def process_input(self, x):
        # If input x is a Tensor, assume it's RGB
        x = {'rgb': x} if isinstance(x, torch.Tensor) else x

        # Need image size for tokens->image reconstruction
        if 'rgb' in x:
            B, _, H, W = x['rgb'].shape
        elif 'semseg' in x:
            B, H, W = x['semseg'].shape
            H *= self.input_adapters['semseg'].stride_level
            W *= self.input_adapters['semseg'].stride_level
        else:
            B, _, H, W = list(x.values())[0].shape  # TODO: Deal with case where not all have same shape

        # Encode selected inputs to tokens
        input_task_tokens = {
            domain: self.input_adapters[domain](tensor)
            for domain, tensor in x.items()
            if domain in self.input_adapters
        }

        input_info = self.generate_input_info(input_task_tokens=input_task_tokens, image_size=(H, W))
        input_tokens = torch.cat([task_tokens for task_tokens in input_task_tokens.values()], dim=1)

        # Add global tokens to input tokens
        global_tokens = repeat(self.global_tokens, '() n d -> b n d', b=B)
        input_tokens = torch.cat([input_tokens, global_tokens], dim=1)

        return input_tokens, input_info

    def forward(self, x: Union[Dict[str, torch.Tensor], torch.Tensor], return_all_layers=False, **kwargs):
        """
        Forward pass through input adapters, transformer encoder and output adapters.

        :param x: Input tensor or dictionary of tensors
        :param return_all_layers: Set to True to return all transformer layers
        """
        input_tokens, input_info = self.process_input(x)

        # Pass tokens through Transformer
        if not return_all_layers:
            encoder_tokens = self.encoder(input_tokens)
        else:
            # Optionally access every intermediate layer
            encoder_tokens = []
            tokens = input_tokens
            for block in self.encoder:
                tokens = block(tokens)
                encoder_tokens.append(tokens)

        if self.output_adapters is None:
            return encoder_tokens

        # Decode tokens for each task using task-specific output adapters
        preds = {
            domain: self.output_adapters[domain](
                encoder_tokens=encoder_tokens,
                input_info=input_info,
            )
            for domain in self.output_adapters
        }

        return preds

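# Hedged downstream sketch (rgb_adapter is a placeholder input adapter; shapes assume
# 224x224 inputs patchified at 16x16 with the ViT-B token dimension of 768):
#   model = multivit_base(input_adapters={'rgb': rgb_adapter}, output_adapters=None)
#   feats = model(torch.randn(2, 3, 224, 224), return_all_layers=True)
# With output_adapters=None this returns the raw encoder tokens: a single tensor of shape
# (2, 196 + num_global_tokens, 768) when return_all_layers=False, or a list with one such
# tensor per transformer block when return_all_layers=True, which is useful for dense heads
# that tap several encoder depths.
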
@register_model
def multivit_base(
        input_adapters: Dict[str, nn.Module],
        output_adapters: Optional[Dict[str, nn.Module]],
        **kwargs):
    model = MultiViT(
        input_adapters=input_adapters,
        output_adapters=output_adapters,
        dim_tokens=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model

@register_model
def multivit_large(
        input_adapters: Dict[str, nn.Module],
        output_adapters: Optional[Dict[str, nn.Module]],
        **kwargs):
    model = MultiViT(
        input_adapters=input_adapters,
        output_adapters=output_adapters,
        dim_tokens=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model