# DenseLabelDev/projects/glamm/models/region_encoder.py
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Tuple
from torch import Tensor
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv import ops
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
class BaseRoIExtractor(BaseModule, metaclass=ABCMeta):
"""Base class for RoI extractor.
Args:
roi_layer (:obj:`ConfigDict` or dict): Specify RoI layer type and
arguments.
out_channels (int): Output channels of RoI layers.
featmap_strides (list[int]): Strides of input feature maps.
init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
dict], optional): Initialization config dict. Defaults to None.
"""
def __init__(self,
roi_layer,
out_channels: int,
featmap_strides: List[int],
init_cfg=None) -> None:
super().__init__(init_cfg=init_cfg)
self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
self.out_channels = out_channels
self.featmap_strides = featmap_strides
@property
def num_inputs(self) -> int:
"""int: Number of input feature maps."""
return len(self.featmap_strides)
def build_roi_layers(self, layer_cfg,
featmap_strides: List[int]) -> nn.ModuleList:
"""Build RoI operator to extract feature from each level feature map.
Args:
layer_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
config RoI layer operation. Options are modules under
``mmcv/ops`` such as ``RoIAlign``.
            featmap_strides (list[int]): The stride of each input feature map
                w.r.t. the original image size, used to scale RoI coordinates
                from the image coordinate system to the feature coordinate
                system.
Returns:
:obj:`nn.ModuleList`: The RoI extractor modules for each level
feature map.
"""
cfg = layer_cfg.copy()
layer_type = cfg.pop('type')
if isinstance(layer_type, str):
assert hasattr(ops, layer_type)
layer_cls = getattr(ops, layer_type)
else:
layer_cls = layer_type
roi_layers = nn.ModuleList(
[layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides])
return roi_layers
def roi_rescale(self, rois: Tensor, scale_factor: float) -> Tensor:
"""Scale RoI coordinates by scale factor.
Args:
rois (Tensor): RoI (Region of Interest), shape (n, 5)
scale_factor (float): Scale factor that RoI will be multiplied by.
Returns:
Tensor: Scaled RoI.
"""
cx = (rois[:, 1] + rois[:, 3]) * 0.5
cy = (rois[:, 2] + rois[:, 4]) * 0.5
w = rois[:, 3] - rois[:, 1]
h = rois[:, 4] - rois[:, 2]
new_w = w * scale_factor
new_h = h * scale_factor
x1 = cx - new_w * 0.5
x2 = cx + new_w * 0.5
y1 = cy - new_h * 0.5
y2 = cy + new_h * 0.5
new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1)
return new_rois
@abstractmethod
def forward(self,
feats: Tuple[Tensor],
rois: Tensor,
roi_scale_factor: Optional[float] = None) -> Tensor:
"""Extractor ROI feats.
Args:
feats (Tuple[Tensor]): Multi-scale features.
rois (Tensor): RoIs with the shape (n, 5) where the first
column indicates batch id of each RoI.
roi_scale_factor (Optional[float]): RoI scale factor.
Defaults to None.
Returns:
Tensor: RoI feature.
"""
pass
class MLVLFuseModule(nn.Module):
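    """Fuse multi-level feature maps by appending CoordConv-style coordinate
    channels and repeatedly shuffling channel slices between neighbouring
    pyramid levels."""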
def __init__(self, input_dims=1024, embed_dims=1024, num_levels=3, num_fuse=4):
        super().__init__()
self.embed_dims = embed_dims
self.num_levels = num_levels
self.num_fuse = num_fuse
self.input_dims = input_dims
        self.shuffle_channels = embed_dims // 4
        # tuples of (target, top, down) level indices that take part in the
        # cross-level channel exchange
        self.fuse_lvl_list = []
        for lvl in range(self.num_levels):
            top_lvl = min(lvl + 1, self.num_levels - 1)
            down_lvl = max(lvl - 1, 0)
            tar_lvl = lvl
            self.fuse_lvl_list.append((tar_lvl, top_lvl, down_lvl))
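        # e.g. num_levels == 3 yields fuse_lvl_list ==
        # [(0, 1, 0), (1, 2, 0), (2, 2, 1)]: every level exchanges channel
        # slices with its immediate neighbours, clamped at the boundaries.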
        self.remain_chs = self.embed_dims - self.shuffle_channels * 2
self._init_layers()
def generate_coordinate(self, featmap_sizes, device='cuda'):
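        # Builds a CoordConv-style 2-channel map of normalised (x, y)
        # coordinates in [-1, 1]; featmap_sizes is an (N, C, H, W) shape
        # tuple and the result is broadcast over the batch dimension.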
x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device)
y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device)
        y, x = torch.meshgrid(y_range, x_range, indexing='ij')
y = y.expand([featmap_sizes[0], 1, -1, -1])
x = x.expand([featmap_sizes[0], 1, -1, -1])
coord_feat = torch.cat([x, y], 1)
return coord_feat
def _init_layers(self):
self.input_conv = nn.ModuleList([nn.Conv2d(self.input_dims + 2,
self.embed_dims, 1)
for _ in range(self.num_levels)])
self.fuse_convs = nn.ModuleList()
for i in range(self.num_fuse):
self.fuse_convs.append(
ConvModule(self.embed_dims,
self.embed_dims,
3,
stride=1,
padding=3 // 2,
conv_cfg=None,
norm_cfg=dict(type='GN',
num_groups=64,
requires_grad=True)
))
def init_weights(self):
pass
def _single_shuffle(self, inputs, conv_module):
if not isinstance(conv_module, (nn.ModuleList, list)):
conv_module = [conv_module]
for single_conv_m in conv_module:
fused_inputs = []
for fuse_lvl_tuple in self.fuse_lvl_list:
                tar_lvl, top_lvl, down_lvl = fuse_lvl_tuple
                tar_input = inputs[tar_lvl]
                top_input = inputs[top_lvl]
                down_input = inputs[down_lvl]
                # keep the first remain_chs channels of the target level and
                # splice in shuffled channel slices from the two neighbours
                remain = tar_input[:, :self.remain_chs]
                from_top = top_input[:, self.remain_chs:][:, self.shuffle_channels:]
                from_top = F.interpolate(from_top.to(torch.float32),
                                         size=tar_input.shape[-2:],
                                         mode='bilinear',
                                         align_corners=True)
                from_down = down_input[:, self.remain_chs:][:, :self.shuffle_channels]
from_down = F.interpolate(from_down.to(torch.float32),
size=tar_input.shape[-2:],
mode='bilinear',
align_corners=True)
fused_inputs.append(
torch.cat([remain, from_top.to(remain.dtype), from_down.to(remain.dtype)], dim=1))
fused_inputs = [single_conv_m(item) for item in fused_inputs]
inputs = fused_inputs
return inputs
    def forward(self, inputs):
feat_size = [item.shape for item in inputs]
new_inputs = []
for feat, single_feat_size in zip(inputs, feat_size):
coord_feat = self.generate_coordinate(
single_feat_size, device=inputs[0].device)
feat = torch.cat([feat, coord_feat.to(feat.dtype)], dim=1)
new_inputs.append(feat)
inputs = new_inputs
inputs = [self.input_conv[lvl](item)
for lvl, item in enumerate(inputs)]
for conv_m in self.fuse_convs:
inputs = self._single_shuffle(inputs, [conv_m])
return inputs
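

# A minimal usage sketch for MLVLFuseModule (batch size, channel count and
# spatial sizes are illustrative assumptions; embed_dims must be divisible by
# the 64 GroupNorm groups used in the fuse convs):
#
#   fuse = MLVLFuseModule(input_dims=64, embed_dims=64, num_levels=3)
#   feats = [torch.randn(2, 64, 96 // 2 ** lvl, 96 // 2 ** lvl)
#            for lvl in range(3)]
#   outs = fuse(feats)  # three tensors, input spatial sizes, 64 channels each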
class MlvlRoIExtractor(BaseRoIExtractor):
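    """RoI extractor that pools each pyramid level with RoIAlign, fuses the
    levels with per-level convs, and projects every box into a 4096-d query
    feature."""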
def __init__(self,
roi_layer,
out_channels,
featmap_strides,
embed_dims=1024,
stride=1,
norm_init=True,
fuse_level=3,
finest_scale=56,
init_cfg=None):
        super().__init__(roi_layer, out_channels, featmap_strides, init_cfg)
self.embed_dims = embed_dims
self.finest_scale = finest_scale
self.fuse_level = fuse_level
self.norm_init = norm_init
self.pconvs = nn.ModuleList(
nn.Conv2d(self.embed_dims, self.embed_dims, 3, stride=1, padding=1)
for _ in range(self.fuse_level))
        self.pos_embed = nn.Sequential(
nn.Linear(4, 256),
nn.ReLU(inplace=True),
nn.LayerNorm(256),
nn.Linear(256, 1024),
nn.ReLU(inplace=True),
nn.LayerNorm(1024),
)
self.updims = nn.Linear(1024, 4096)
        # flattens an (embed_dims, out_size, out_size) RoI feature map into a
        # single 1024-d vector
        self.flatten_linear = nn.Linear(
            self.embed_dims * self.roi_layers[0].output_size[0] ** 2, 1024)
self.norm_init_weights()
def norm_init_weights(self):
pass
    def forward(self, feats, rois, roi_scale_factor=None):
        """Extract per-box query features.

        Args:
            feats (tuple[Tensor]): Multi-level feature maps.
            rois (list[Tensor]): Per-image boxes in normalised xyxy format,
                each of shape (n_i, 4).
            roi_scale_factor (float, optional): Unused in this extractor.
                Defaults to None.

        Returns:
            list[Tensor]: Per-image query features of shape (n_i, 4096).
        """
num_imgs = len(rois)
batch_rois = torch.cat(rois, dim=0).to(feats[0].dtype)
        pos_embed = self.pos_embed(batch_rois)
out_size = self.roi_layers[0].output_size
num_levels = len(feats)
if feats[0].dim() == 3:
h = w = int(math.sqrt(feats[0].shape[1]))
assert h == 16
assert w == 16
b, c = feats[0].shape[0], feats[0].shape[-1]
feats = [item.reshape(b, h, w, c).permute(0, 3, 1, 2)
for item in feats]
new_rois = []
for img_id, single_img_roi in enumerate(rois):
            # boxes arrive normalised to [0, 1]; rescale them to the
            # 224 x 224 input image resolution
            single_img_roi = single_img_roi * 224
roi_img_id = single_img_roi.new_ones(len(single_img_roi)) * img_id
single_img_roi = torch.cat(
[roi_img_id[:, None], single_img_roi], dim=1)
new_rois.append(single_img_roi)
rois = torch.cat(new_rois)
        # one slot per pyramid level; assumes fuse_level == len(feats)
        roi_feats = feats[0].new_zeros(self.fuse_level,
                                       rois.size(0), self.out_channels, *out_size)
for i in range(num_levels):
if len(rois) > 0:
rois_ = rois
ori_dtype = feats[i].dtype
roi_feats_t = self.roi_layers[i](feats[i].to(
torch.float32), rois_.to(torch.float32))
roi_feats[i] = roi_feats_t.to(ori_dtype)
            else:
                # no boxes: add a zero-valued term that still touches every
                # parameter and input so that gradient bookkeeping (e.g. DDP
                # unused-parameter checks) does not break
                roi_feats += sum(
                    x.view(-1)[0]
                    for x in self.parameters()) * 0. + feats[i].sum() * 0.
fuse_roi_feats = []
for i in range(self.fuse_level):
fuse_roi_feats.append(self.pconvs[i](roi_feats[i]))
fuse_roi_feats = sum(fuse_roi_feats)
fuse_roi_feats = F.relu(fuse_roi_feats)
fuse_roi_feats = fuse_roi_feats.flatten(1, -1)
fuse_roi_feats = self.flatten_linear(fuse_roi_feats)
        fuse_roi_feats = fuse_roi_feats + pos_embed
fuse_roi_feats = self.updims(fuse_roi_feats)
query_feats = []
for i in range(num_imgs):
mask = rois[:, 0] == i
query_feats.append(fuse_roi_feats[mask])
return query_feats
class MLVLROIQueryModule(nn.Module):
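    """Top-level region encoder: upsample ViT token maps into a feature
    pyramid, fuse the levels with MLVLFuseModule, and encode per-image boxes
    into query features with MlvlRoIExtractor."""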
def __init__(self, embed_dims=1024, out_dims=4096,
num_levels=3):
        super().__init__()
        # NOTE: out_dims is currently unused; the 1024 -> 4096 projection is
        # hard-coded in MlvlRoIExtractor.updims.
self.mlvl_fuse = MLVLFuseModule(input_dims=embed_dims,
embed_dims=embed_dims,
num_levels=num_levels,
num_fuse=5)
        # four pyramid strides in image pixels per feature cell (the 14-pixel
        # ViT patch upsampled by 8x, 4x, 2x, 1x); with four entries the
        # assert below requires num_levels == 4
        strides = [14 / 8, 14 / 4, 14 / 2, 14]
        assert len(strides) == num_levels
bbox_roi_extractor = dict(roi_layer=dict(type='RoIAlign',
output_size=14,
sampling_ratio=2),
out_channels=embed_dims,
embed_dims=embed_dims,
fuse_level=num_levels,
                                  featmap_strides=strides)
self.roi_align = MlvlRoIExtractor(**bbox_roi_extractor)
def forward(self, mlvl_feats, bboxes):
if mlvl_feats[0].dim() == 3:
h = w = int(math.sqrt(mlvl_feats[0].shape[1]))
assert h == 24
assert w == 24
b, c = mlvl_feats[0].shape[0], mlvl_feats[0].shape[-1]
mlvl_feats = [item.reshape(b, h, w, c).permute(0, 3, 1, 2) for item in mlvl_feats]
base_shape = mlvl_feats[0].shape[-2:]
num_level = len(mlvl_feats)
to_shape = [(base_shape[0] * 2 ** level, base_shape[1] * 2 ** level)
for level in range(num_level)]
to_shape = to_shape[::-1]
for level in range(num_level):
feat = mlvl_feats[level]
shape = to_shape[level]
            # temporary fix: "upsample_bilinear2d_out_frame" is not
            # implemented for 'BFloat16', so interpolate in float32 and cast
            # back to the incoming dtype afterwards
            ori_dtype = feat.dtype
            mlvl_feats[level] = F.interpolate(
                feat.to(torch.float32), size=shape, mode='bilinear',
                align_corners=True).to(ori_dtype)
mlvl_feats = self.mlvl_fuse(mlvl_feats)
return self.roi_align(mlvl_feats, bboxes)
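

if __name__ == '__main__':
    # Minimal CPU smoke test; every size below is an illustrative assumption,
    # not a value mandated by this file. embed_dims must be divisible by the
    # 64 GroupNorm groups in MLVLFuseModule, the strides table above has four
    # entries (so num_levels must be 4), and each token sequence must be
    # 24 * 24 = 576 long to satisfy the assert in forward().
    module = MLVLROIQueryModule(embed_dims=64, num_levels=4)
    mlvl_feats = [torch.randn(2, 576, 64) for _ in range(4)]
    # two images with two boxes each, in normalised xyxy format
    bboxes = [torch.tensor([[0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.9, 0.9]])
              for _ in range(2)]
    for query in module(mlvl_feats, bboxes):
        print(query.shape)  # expected: torch.Size([2, 4096])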