from abc import ABCMeta, abstractmethod
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmcv import ops
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule


class BaseRoIExtractor(BaseModule, metaclass=ABCMeta):
    """Base class for RoI extractors.

    Args:
        roi_layer (:obj:`ConfigDict` or dict): Specifies the RoI layer type
            and its arguments.
        out_channels (int): Output channels of RoI layers.
        featmap_strides (list[int]): Strides of the input feature maps.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 roi_layer,
                 out_channels: int,
                 featmap_strides: List[int],
                 init_cfg=None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
        self.out_channels = out_channels
        self.featmap_strides = featmap_strides

    @property
    def num_inputs(self) -> int:
        """int: Number of input feature maps."""
        return len(self.featmap_strides)

    def build_roi_layers(self, layer_cfg,
                         featmap_strides: List[int]) -> nn.ModuleList:
        """Build RoI operators to extract features from each feature level.

        Args:
            layer_cfg (:obj:`ConfigDict` or dict): Dictionary to construct
                and configure the RoI layer operation. Options are modules
                under ``mmcv/ops``, such as ``RoIAlign``.
            featmap_strides (list[int]): Strides of the input feature maps
                w.r.t. the original image size, used to scale RoI
                coordinates from the original image coordinate system to
                the feature coordinate system.

        Returns:
            :obj:`nn.ModuleList`: The RoI extractor modules, one per
            feature level.
        """
        cfg = layer_cfg.copy()
        layer_type = cfg.pop('type')
        if isinstance(layer_type, str):
            # Look up the layer class by name under mmcv.ops.
            assert hasattr(ops, layer_type)
            layer_cls = getattr(ops, layer_type)
        else:
            layer_cls = layer_type
        # One RoI layer per level; spatial_scale maps image coordinates
        # onto that level's feature map.
        roi_layers = nn.ModuleList(
            [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides])
        return roi_layers

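    # A minimal usage sketch for ``build_roi_layers``; the config values
    # below are illustrative assumptions, not taken from a specific model:
    #
    #     cfg = dict(type='RoIAlign', output_size=7, sampling_ratio=2)
    #     layers = self.build_roi_layers(cfg, featmap_strides=[4, 8, 16, 32])
    #     # -> nn.ModuleList of four RoIAlign ops with spatial_scale
    #     #    1 / 4, 1 / 8, 1 / 16 and 1 / 32
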
    def roi_rescale(self, rois: Tensor, scale_factor: float) -> Tensor:
        """Scale RoI coordinates by a scale factor.

        Args:
            rois (Tensor): RoIs (Regions of Interest) with shape (n, 5).
            scale_factor (float): Scale factor by which the RoI width and
                height will be multiplied.

        Returns:
            Tensor: Scaled RoIs.
        """
        # Rescale each box about its center; the batch index in column 0
        # is kept untouched.
        cx = (rois[:, 1] + rois[:, 3]) * 0.5
        cy = (rois[:, 2] + rois[:, 4]) * 0.5
        w = rois[:, 3] - rois[:, 1]
        h = rois[:, 4] - rois[:, 2]
        new_w = w * scale_factor
        new_h = h * scale_factor
        x1 = cx - new_w * 0.5
        x2 = cx + new_w * 0.5
        y1 = cy - new_h * 0.5
        y2 = cy + new_h * 0.5
        new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1)
        return new_rois

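    # A worked example of ``roi_rescale`` (made-up numbers): the RoI
    # ``[0., 10., 10., 30., 30.]`` (batch id 0, a 20x20 box centered at
    # (20, 20)) with ``scale_factor=2.0`` becomes ``[0., 0., 0., 40., 40.]``;
    # the center stays fixed while the width and height double.
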
    @abstractmethod
    def forward(self,
                feats: Tuple[Tensor],
                rois: Tensor,
                roi_scale_factor: Optional[float] = None) -> Tensor:
        """Extract RoI features.

        Args:
            feats (Tuple[Tensor]): Multi-scale features.
            rois (Tensor): RoIs with shape (n, 5), where the first column
                indicates the batch id of each RoI.
            roi_scale_factor (Optional[float]): RoI scale factor.
                Defaults to None.

        Returns:
            Tensor: RoI features.
        """
        pass


class MLVLFuseModule(nn.Module):
    """Fuse multi-level feature maps by shuffling channels between
    neighboring pyramid levels and refining them with conv blocks."""

    def __init__(self, input_dims=1024, embed_dims=1024, num_levels=3,
                 num_fuse=4):
        super().__init__()
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_fuse = num_fuse
        self.input_dims = input_dims
        # A quarter of the channels is exchanged with each neighboring level.
        self.shuffle_channels = embed_dims // 4

        # For each level, record (target, one level up, one level down),
        # clamped at the pyramid boundaries.
        self.fuse_lvl_list = []
        for lvl in range(self.num_levels):
            top_lvl = min(lvl + 1, self.num_levels - 1)
            down_lvl = max(lvl - 1, 0)
            self.fuse_lvl_list.append((lvl, top_lvl, down_lvl))

        # Channels that stay on the target level.
        self.remain_chs = self.embed_dims - self.shuffle_channels * 2
        self._init_layers()

    def generate_coordinate(self, featmap_sizes, device='cuda'):
        """Generate a normalized (x, y) coordinate map in [-1, 1] for a
        feature map of shape ``featmap_sizes`` (b, c, h, w)."""
        x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device)
        y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device)
        y, x = torch.meshgrid(y_range, x_range, indexing='ij')
        y = y.expand([featmap_sizes[0], 1, -1, -1])
        x = x.expand([featmap_sizes[0], 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)
        return coord_feat

    def _init_layers(self):
        # 1x1 convs project the input features (plus the 2-channel
        # coordinate map) to the embedding dimension, one conv per level.
        self.input_conv = nn.ModuleList([
            nn.Conv2d(self.input_dims + 2, self.embed_dims, 1)
            for _ in range(self.num_levels)
        ])
        self.fuse_convs = nn.ModuleList()
        for _ in range(self.num_fuse):
            self.fuse_convs.append(
                ConvModule(self.embed_dims,
                           self.embed_dims,
                           3,
                           stride=1,
                           padding=1,
                           conv_cfg=None,
                           norm_cfg=dict(type='GN',
                                         num_groups=64,
                                         requires_grad=True)))

    def init_weights(self):
        pass

    def _single_shuffle(self, inputs, conv_module):
        if not isinstance(conv_module, (nn.ModuleList, list)):
            conv_module = [conv_module]
        for single_conv_m in conv_module:
            fused_inputs = []
            for tar_lvl, top_lvl, down_lvl in self.fuse_lvl_list:
                tar_input = inputs[tar_lvl]
                top_input = inputs[top_lvl]
                down_input = inputs[down_lvl]
                # Keep the first ``remain_chs`` channels of the target level
                # and borrow one shuffle chunk from each neighboring level,
                # resized to the target resolution.
                remain = tar_input[:, :self.remain_chs]
                from_top = top_input[:, self.remain_chs:]
                from_top = from_top[:, self.shuffle_channels:]
                from_top = F.interpolate(from_top.to(torch.float32),
                                         size=tar_input.shape[-2:],
                                         mode='bilinear',
                                         align_corners=True)
                from_down = down_input[:, self.remain_chs:]
                from_down = from_down[:, :self.shuffle_channels]
                from_down = F.interpolate(from_down.to(torch.float32),
                                          size=tar_input.shape[-2:],
                                          mode='bilinear',
                                          align_corners=True)
                fused_inputs.append(
                    torch.cat([
                        remain,
                        from_top.to(remain.dtype),
                        from_down.to(remain.dtype)
                    ], dim=1))
            fused_inputs = [single_conv_m(item) for item in fused_inputs]
            inputs = fused_inputs
        return inputs

    def forward(self, inputs):
        # Append a normalized coordinate map to every level so that the
        # following convs are position-aware.
        feat_sizes = [item.shape for item in inputs]
        new_inputs = []
        for feat, single_feat_size in zip(inputs, feat_sizes):
            coord_feat = self.generate_coordinate(
                single_feat_size, device=inputs[0].device)
            feat = torch.cat([feat, coord_feat.to(feat.dtype)], dim=1)
            new_inputs.append(feat)
        inputs = new_inputs

        # Project every level to the embedding dimension, then run the
        # shuffle-and-conv fusion ``num_fuse`` times.
        inputs = [self.input_conv[lvl](item)
                  for lvl, item in enumerate(inputs)]
        for conv_m in self.fuse_convs:
            inputs = self._single_shuffle(inputs, [conv_m])
        return inputs


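# A minimal usage sketch for ``MLVLFuseModule``; the batch size and spatial
# sizes below are made-up assumptions:
#
#     fuse = MLVLFuseModule(input_dims=1024, embed_dims=1024,
#                           num_levels=3, num_fuse=4)
#     feats = [torch.randn(2, 1024, 56, 56),
#              torch.randn(2, 1024, 28, 28),
#              torch.randn(2, 1024, 14, 14)]
#     fused = fuse(feats)  # list of 3 tensors, each (2, 1024, H_l, W_l)

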
class MlvlRoIExtractor(BaseRoIExtractor):
    """Extract RoI features from multi-level feature maps and fuse them
    into a single query feature per RoI."""

    def __init__(self,
                 roi_layer,
                 out_channels,
                 featmap_strides,
                 embed_dims=1024,
                 stride=1,
                 norm_init=True,
                 fuse_level=3,
                 finest_scale=56,
                 init_cfg=None):
        super().__init__(roi_layer, out_channels, featmap_strides, init_cfg)
        self.embed_dims = embed_dims
        self.finest_scale = finest_scale
        self.fuse_level = fuse_level
        self.norm_init = norm_init

        # One 3x3 conv per fuse level, applied before summing the levels.
        self.pconvs = nn.ModuleList(
            nn.Conv2d(self.embed_dims, self.embed_dims, 3, stride=1,
                      padding=1) for _ in range(self.fuse_level))
        # MLP that embeds the normalized box coordinates (x1, y1, x2, y2).
        self.pos_embedd = nn.Sequential(
            nn.Linear(4, 256),
            nn.ReLU(inplace=True),
            nn.LayerNorm(256),
            nn.Linear(256, 1024),
            nn.ReLU(inplace=True),
            nn.LayerNorm(1024),
        )
        self.updims = nn.Linear(1024, 4096)
        # Flatten the pooled (embed_dims, out_size, out_size) RoI feature
        # into a single vector.
        self.flatten_linear = nn.Linear(
            self.embed_dims * self.roi_layers[0].output_size[0] ** 2, 1024)

        self.norm_init_weights()

    def norm_init_weights(self):
        pass

    def forward(self, feats, rois, roi_scale_factor=None):
        """Forward function."""
        num_imgs = len(rois)

        batch_rois = torch.cat(rois, dim=0).to(feats[0].dtype)
        pos_embedd = self.pos_embedd(batch_rois)
        out_size = self.roi_layers[0].output_size
        num_levels = len(feats)
        # Token sequences (b, n, c) are reshaped back to 2-D feature maps.
        if feats[0].dim() == 3:
            h = w = int(math.sqrt(feats[0].shape[1]))
            assert h == 16
            assert w == 16
            b, c = feats[0].shape[0], feats[0].shape[-1]
            feats = [item.reshape(b, h, w, c).permute(0, 3, 1, 2)
                     for item in feats]
        # Scale normalized boxes to 224x224 image coordinates and prepend
        # the batch index expected by the RoI layers.
        new_rois = []
        for img_id, single_img_roi in enumerate(rois):
            single_img_roi = single_img_roi * 224
            roi_img_id = single_img_roi.new_ones(len(single_img_roi)) * img_id
            single_img_roi = torch.cat(
                [roi_img_id[:, None], single_img_roi], dim=1)
            new_rois.append(single_img_roi)
        rois = torch.cat(new_rois)

        roi_feats = feats[0].new_zeros(self.fuse_level, rois.size(0),
                                       self.out_channels, *out_size)

        for i in range(num_levels):
            if len(rois) > 0:
                rois_ = rois
                # The RoI layers generally require fp32 inputs; cast back
                # to the original dtype afterwards.
                ori_dtype = feats[i].dtype
                roi_feats_t = self.roi_layers[i](
                    feats[i].to(torch.float32), rois_.to(torch.float32))
                roi_feats[i] = roi_feats_t.to(ori_dtype)
            else:
                # No RoIs: add a zero-valued term that still touches every
                # parameter and feature so that gradients remain defined.
                roi_feats += sum(
                    x.view(-1)[0]
                    for x in self.parameters()) * 0. + feats[i].sum() * 0.

        # Per-level conv, sum across levels, then flatten each RoI feature
        # to a vector and add the box position embedding.
        fuse_roi_feats = []
        for i in range(self.fuse_level):
            fuse_roi_feats.append(self.pconvs[i](roi_feats[i]))
        fuse_roi_feats = sum(fuse_roi_feats)
        fuse_roi_feats = F.relu(fuse_roi_feats)
        fuse_roi_feats = fuse_roi_feats.flatten(1, -1)
        fuse_roi_feats = self.flatten_linear(fuse_roi_feats)
        fuse_roi_feats = fuse_roi_feats + pos_embedd
        fuse_roi_feats = self.updims(fuse_roi_feats)
        # Regroup the flat RoI features back into per-image lists.
        query_feats = []
        for i in range(num_imgs):
            mask = rois[:, 0] == i
            query_feats.append(fuse_roi_feats[mask])

        return query_feats


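# Input convention for ``MlvlRoIExtractor.forward`` (a sketch; the two boxes
# are made-up values): ``rois`` is a per-image list of normalized xyxy boxes,
#
#     rois = [torch.tensor([[0.1, 0.1, 0.5, 0.5]]),   # image 0, one box
#             torch.tensor([[0.2, 0.3, 0.9, 0.8]])]   # image 1, one box
#
# Each box is scaled by 224 and prefixed with its image index before being
# passed to RoIAlign, and the output is again a per-image list of features.

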
class MLVLROIQueryModule(nn.Module):
    """Turn multi-level ViT features and boxes into per-RoI query features
    by fusing the levels and RoI-aligning the fused maps."""

    def __init__(self, embed_dims=1024, out_dims=4096, num_levels=3):
        super().__init__()
        self.mlvl_fuse = MLVLFuseModule(input_dims=embed_dims,
                                        embed_dims=embed_dims,
                                        num_levels=num_levels,
                                        num_fuse=5)
        # Effective strides of the upsampled levels w.r.t. the 224x224
        # input. Note that this list has four entries, so the assert below
        # requires ``num_levels == 4`` despite the default of 3.
        strides = [14 / 8, 14 / 4, 14 / 2, 14]
        assert len(strides) == num_levels
        bbox_roi_extractor = dict(roi_layer=dict(type='RoIAlign',
                                                 output_size=14,
                                                 sampling_ratio=2),
                                  out_channels=embed_dims,
                                  embed_dims=embed_dims,
                                  fuse_level=num_levels,
                                  featmap_strides=strides)

        self.roi_align = MlvlRoIExtractor(**bbox_roi_extractor)

    def forward(self, mlvl_feats, bboxes):
        # Token sequences (b, n, c) are reshaped back to 2-D feature maps.
        if mlvl_feats[0].dim() == 3:
            h = w = int(math.sqrt(mlvl_feats[0].shape[1]))
            assert h == 24
            assert w == 24
            b, c = mlvl_feats[0].shape[0], mlvl_feats[0].shape[-1]
            mlvl_feats = [item.reshape(b, h, w, c).permute(0, 3, 1, 2)
                          for item in mlvl_feats]
        base_shape = mlvl_feats[0].shape[-2:]
        num_level = len(mlvl_feats)
        # Build a pyramid of target sizes: after reversing, the first level
        # is upsampled the most and the last keeps the base resolution.
        to_shape = [(base_shape[0] * 2 ** level, base_shape[1] * 2 ** level)
                    for level in range(num_level)]
        to_shape = to_shape[::-1]
        for level in range(num_level):
            feat = mlvl_feats[level]
            shape = to_shape[level]
            # Interpolate in fp32 (bilinear interpolation is not supported
            # for all low-precision dtypes), then cast to bfloat16, the
            # activation dtype this module assumes.
            feat = feat.to(torch.float32)
            mlvl_feats[level] = F.interpolate(
                feat, size=shape, mode='bilinear', align_corners=True)
            mlvl_feats[level] = mlvl_feats[level].to(torch.bfloat16)

        mlvl_feats = self.mlvl_fuse(mlvl_feats)

        return self.roi_align(mlvl_feats, bboxes)
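
# An end-to-end sketch; the shapes, dtypes, and ``num_levels=4`` below are
# illustrative assumptions (four levels make the stride-list assert in
# ``MLVLROIQueryModule.__init__`` pass, and 576 tokens reshape to 24x24):
#
#     module = MLVLROIQueryModule(embed_dims=1024, out_dims=4096,
#                                 num_levels=4).to(torch.bfloat16)
#     mlvl_feats = [torch.randn(2, 576, 1024, dtype=torch.bfloat16)
#                   for _ in range(4)]
#     bboxes = [torch.rand(3, 4, dtype=torch.bfloat16),  # 3 boxes, image 0
#               torch.rand(5, 4, dtype=torch.bfloat16)]  # 5 boxes, image 1
#     queries = module(mlvl_feats, bboxes)
#     # -> list of two tensors with shapes (3, 4096) and (5, 4096)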