# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer
from mmpretrain.registry import MODELS


@MODELS.register_module()
class BEiTV2Neck(BaseModule):
"""Neck for BEiTV2 Pre-training.
This module construct the decoder for the final prediction.
Args:
num_layers (int): Number of encoder layers of neck. Defaults to 2.
early_layers (int): The layer index of the early output from the
backbone. Defaults to 9.
        backbone_arch (str | dict): Vision Transformer architecture name, or
            a dict of custom architecture settings. Defaults to 'base'.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
layer_scale_init_value (float): The initialization value for the
learnable scaling of attention and FFN. Defaults to 0.1.
        use_rel_pos_bias (bool): Whether to use a unique relative position
            bias in each layer; if False, use the shared relative position
            bias defined in the backbone. Defaults to False.
        norm_cfg (dict): Config dict for the normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        init_cfg (dict, optional): Initialization config dict. Defaults to
            ``dict(type='TruncNormal', layer='Linear', std=0.02, bias=0)``.
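
    Example:
        A minimal shape sketch (the sizes below assume a ViT-base backbone
        on 224x224 inputs, i.e. 196 patch tokens plus one cls token):

        >>> import torch
        >>> from mmpretrain.models import BEiTV2Neck
        >>> neck = BEiTV2Neck(num_layers=2, early_layers=9)
        >>> early = torch.rand(1, 197, 768)
        >>> final = torch.rand(1, 197, 768)
        >>> x, x_cls_pt = neck((early, final), rel_pos_bias=None)
        >>> x.shape
        torch.Size([1, 196, 768])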
"""
arch_zoo = {
**dict.fromkeys(
['b', 'base'], {
'embed_dims': 768,
'depth': 12,
'num_heads': 12,
'feedforward_channels': 3072,
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 1024,
'depth': 24,
'num_heads': 16,
'feedforward_channels': 4096,
}),
    }

def __init__(
self,
num_layers: int = 2,
early_layers: int = 9,
backbone_arch: str = 'base',
drop_rate: float = 0.,
drop_path_rate: float = 0.,
layer_scale_init_value: float = 0.1,
use_rel_pos_bias: bool = False,
norm_cfg: dict = dict(type='LN', eps=1e-6),
init_cfg: Optional[Union[dict, List[dict]]] = dict(
type='TruncNormal', layer='Linear', std=0.02, bias=0)
) -> None:
super().__init__(init_cfg=init_cfg)
if isinstance(backbone_arch, str):
backbone_arch = backbone_arch.lower()
assert backbone_arch in set(self.arch_zoo), \
(f'Arch {backbone_arch} is not in default archs '
f'{set(self.arch_zoo)}')
self.arch_settings = self.arch_zoo[backbone_arch]
else:
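            # A custom arch dict must provide the keys consumed below;
            # note that 'depth' is read later for the drop path schedule.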
            essential_keys = {
                'embed_dims', 'depth', 'num_heads', 'feedforward_channels'
            }
assert isinstance(backbone_arch, dict) and essential_keys <= set(
backbone_arch
), f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = backbone_arch

        self.early_layers = early_layers
        depth = self.arch_settings['depth']

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate,
                          max(depth, early_layers + num_layers))
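        # For example, with depth=12, early_layers=9, num_layers=2 and
        # drop_path_rate=0.1, dpr has 12 entries rising linearly from 0.0
        # to 0.1, and the neck layers below use dpr[9] and dpr[10].
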
self.patch_aggregation = nn.ModuleList()
for i in range(early_layers, early_layers + num_layers):
_layer_cfg = dict(
embed_dims=self.arch_settings['embed_dims'],
num_heads=self.arch_settings['num_heads'],
            feedforward_channels=self.arch_settings['feedforward_channels'],
drop_rate=drop_rate,
drop_path_rate=dpr[i],
norm_cfg=norm_cfg,
layer_scale_init_value=layer_scale_init_value,
window_size=None,
use_rel_pos_bias=use_rel_pos_bias)
self.patch_aggregation.append(
                BEiTTransformerEncoderLayer(**_layer_cfg))

        self.rescale_patch_aggregation_init_weight()

embed_dims = self.arch_settings['embed_dims']
_, norm = build_norm_layer(norm_cfg, embed_dims)
        self.add_module('norm', norm)

def rescale_patch_aggregation_init_weight(self):
"""Rescale the initialized weights."""
def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

for layer_id, layer in enumerate(self.patch_aggregation):
rescale(layer.attn.proj.weight.data,
self.early_layers + layer_id + 1)
rescale(layer.ffn.layers[1].weight.data,
                    self.early_layers + layer_id + 1)

def forward(self, inputs: Tuple[torch.Tensor], rel_pos_bias: torch.Tensor,
**kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the latent prediction and final prediction.
Args:
x (Tuple[torch.Tensor]): Features of tokens.
rel_pos_bias (torch.Tensor): Shared relative position bias table.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- ``x``: The final layer features from backbone, which are normed
in ``BEiTV2Neck``.
- ``x_cls_pt``: The early state features from backbone, which are
consist of final layer cls_token and early state patch_tokens
from backbone and sent to PatchAggregation layers in the neck.
"""
early_states, x = inputs[0], inputs[1]
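        # Pair the final-layer cls_token with the early-layer patch tokens.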
x_cls_pt = torch.cat([x[:, [0]], early_states[:, 1:]], dim=1)
for layer in self.patch_aggregation:
x_cls_pt = layer(x_cls_pt, rel_pos_bias=rel_pos_bias)
# shared norm
x, x_cls_pt = self.norm(x), self.norm(x_cls_pt)
# remove cls_token
x = x[:, 1:]
x_cls_pt = x_cls_pt[:, 1:]
return x, x_cls_pt