# Release model: vit_eva02_1600x640_trainval_future (#46)
# Author: Haisong Liu (commit 102ac67)
import logging
import torch
import torch.nn as nn
from mmcv.runner.checkpoint import load_state_dict
from mmdet.models.builder import BACKBONES
from .vit import ViT, SimpleFeaturePyramid, partial
from .fpn import LastLevelMaxPool
@BACKBONES.register_module()
class EVA02(nn.Module):
    """EVA-02 backbone: a ViT encoder wrapped in a SimpleFeaturePyramid.

    Builds a :class:`ViT` trunk and feeds its single output feature map
    ("last_feat" by default) into a :class:`SimpleFeaturePyramid` that
    produces multi-scale features via the given ``fpn_scale_factors``.
    Optionally loads pretrained weights from a checkpoint file.

    Args:
        img_size (int): Nominal input size used by the ViT for positional
            embedding layout.
        real_img_size (tuple[int, int]): Actual (H, W) of the input images.
        patch_size (int): ViT patch size.
        in_chans (int): Number of input image channels.
        embed_dim (int): ViT embedding dimension.
        depth (int): Number of transformer blocks.
        num_heads (int): Number of attention heads per block.
        mlp_ratio (float): MLP hidden-dim ratio (EVA-02 uses 4*2/3 for SwiGLU).
        qkv_bias (bool): Whether QKV projections have bias terms.
        drop_path_rate (float): Stochastic depth rate.
        norm_layer (callable): Normalization layer constructor.
        use_abs_pos (bool): Use absolute positional embeddings.
        pt_hw_seq_len (int): Pretraining sequence length per spatial side
            (used for rotary/frequency interpolation).
        intp_freq (bool): Interpolate rotary frequencies to the new size.
        window_size (int): Window size for windowed-attention blocks
            (0 disables windowing).
        window_block_indexes (tuple[int]): Indices of blocks that use
            windowed attention.
        residual_block_indexes (tuple[int]): Indices of blocks with extra
            residual conv paths.
        use_act_checkpoint (bool): Enable activation checkpointing.
        pretrain_img_size (int): Image size used during pretraining.
        pretrain_use_cls_token (bool): Whether pretraining used a CLS token.
        out_feature (str): Name of the ViT output feature map.
        xattn (bool): Use memory-efficient cross-attention kernels.
        frozen_blocks (int): Freeze blocks up to this index (-1 = none).
        fpn_in_feature (str): ViT feature consumed by the pyramid; should
            match ``out_feature``.
        fpn_out_channels (int): Channel count of every pyramid level.
        fpn_scale_factors (tuple[float]): Up/down-sampling factor per level.
        fpn_top_block (bool): Append a max-pool top level if True.
        fpn_norm (str): Norm type used inside the pyramid ("LN").
        fpn_square_pad (int): Pad inputs to a square of this size (0 = off).
        pretrained (str | None): Path to a checkpoint whose ``'model'``
            entry holds the pretrained state dict; None skips loading.
    """

    def __init__(
        self,
        # args for ViT
        img_size=1024,
        real_img_size=(256, 704),
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4*2/3,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_abs_pos=True,
        pt_hw_seq_len=16,
        intp_freq=True,
        window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        use_act_checkpoint=False,
        pretrain_img_size=224,
        pretrain_use_cls_token=True,
        out_feature="last_feat",
        xattn=False,
        frozen_blocks=-1,
        # args for simple FPN
        fpn_in_feature="last_feat",
        fpn_out_channels=256,
        fpn_scale_factors=(4.0, 2.0, 1.0, 0.5),
        fpn_top_block=False,
        fpn_norm="LN",
        fpn_square_pad=0,
        pretrained=None
    ):
        super().__init__()
        self.backbone = SimpleFeaturePyramid(
            ViT(
                img_size=img_size,
                real_img_size=real_img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
                depth=depth,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path_rate=drop_path_rate,
                norm_layer=norm_layer,
                use_abs_pos=use_abs_pos,
                pt_hw_seq_len=pt_hw_seq_len,
                intp_freq=intp_freq,
                window_size=window_size,
                window_block_indexes=window_block_indexes,
                residual_block_indexes=residual_block_indexes,
                use_act_checkpoint=use_act_checkpoint,
                pretrain_img_size=pretrain_img_size,
                pretrain_use_cls_token=pretrain_use_cls_token,
                out_feature=out_feature,
                xattn=xattn,
                frozen_blocks=frozen_blocks,
            ),
            in_feature=fpn_in_feature,
            out_channels=fpn_out_channels,
            scale_factors=fpn_scale_factors,
            top_block=LastLevelMaxPool() if fpn_top_block else None,
            norm=fpn_norm,
            square_pad=fpn_square_pad,
        )
        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        """Load pretrained weights from ``pretrained``; no-op if None.

        The checkpoint is expected to contain its state dict under the
        ``'model'`` key. Loading is non-strict so partially matching
        checkpoints (e.g. classification pretraining) still apply.
        """
        if pretrained is None:
            return
        # Lazy %-style args: the message is only formatted if the log
        # level is enabled.
        logging.info('Loading pretrained weights from %s', pretrained)
        # map_location='cpu' makes loading work on CPU-only machines and
        # avoids pinning checkpoint tensors to the device they were saved
        # on; tensors move to the right device with the module later.
        state_dict = torch.load(pretrained, map_location='cpu')['model']
        load_state_dict(self, state_dict, strict=False)

    def forward(self, x):
        """Run the feature pyramid and return its levels as a list.

        Args:
            x (Tensor): Input image batch.

        Returns:
            list[Tensor]: Pyramid feature maps, in the pyramid's own
            key order (dict values of the SimpleFeaturePyramid output).
        """
        outs = self.backbone(x)
        return list(outs.values())