|
|
|
from typing import Sequence, Tuple |
|
|
|
from mmengine.model import BaseModule |
|
from torch import Tensor |
|
from torch import nn as nn |
|
|
|
from mmdet3d.registry import MODELS |
|
from mmdet3d.utils import OptMultiConfig |
|
|
|
|
|
class BasicResBlock(nn.Module):
    """Residual block made of two 3x3 ``Conv2d`` + ``BatchNorm2d`` layers.

    The main branch is ``conv1 -> bn1 -> relu1 -> conv2 -> bn2``. Its output
    is summed with a shortcut branch and passed through ``relu2``. The
    shortcut is either the unmodified input or, when ``downsample`` is
    ``True``, the input projected by a 1x1 convolution followed by batch
    normalization.

    Args:
        inplanes (int): Number of channels of the input tensor.
        planes (int): Number of channels produced by both convolutions.
        stride (int): Stride of the first 3x3 convolution and of the 1x1
            shortcut convolution. Defaults to 1.
        padding (int): Padding of the first 3x3 convolution only; the second
            convolution always pads by 1. Defaults to 1.
        downsample (bool): Whether to build and apply the 1x1 projection on
            the shortcut branch. When ``False`` the raw input is added to
            the main branch, so its shape has to match the main-branch
            output. Defaults to False.
    """

    # Class-level constant; it is not read anywhere inside this class.
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        padding: int = 1,
        downsample: bool = False,
    ) -> None:
        super().__init__()
        self.stride = stride
        self.downsample = downsample

        # Every normalization layer in the block uses the same settings.
        bn_kwargs = dict(eps=1e-3, momentum=0.01)

        # Main branch, first half: the only place where the configurable
        # stride and padding are applied.
        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=padding,
            bias=False)
        self.bn1 = nn.BatchNorm2d(planes, **bn_kwargs)
        self.relu1 = nn.ReLU()

        # Main branch, second half: default stride, fixed padding of 1.
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, **bn_kwargs)
        self.relu2 = nn.ReLU()

        # Optional projection so the shortcut follows the main branch's
        # channel count and stride.
        if downsample:
            shortcut_conv = nn.Conv2d(
                inplanes,
                planes,
                kernel_size=1,
                stride=stride,
                padding=0,
                bias=False)
            shortcut_bn = nn.BatchNorm2d(planes, **bn_kwargs)
            self.downsample_layer = nn.Sequential(shortcut_conv, shortcut_bn)

    def forward(self, x: Tensor) -> Tensor:
        """Run the block on ``x``.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: ``relu2(main_branch(x) + shortcut(x))``.
        """
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = self.downsample_layer(x) if self.downsample else x

        # In-place accumulation onto the main-branch output.
        out += shortcut
        return self.relu2(out)
|
|
|
|
|
@MODELS.register_module()
class ResSECOND(BaseModule):
    """Backbone network for DSVT. The difference between `ResSECOND` and
    `SECOND` is that the basic block in this module contains residual layers.

    Each stage is an ``nn.Sequential`` that starts with one strided
    ``BasicResBlock`` whose shortcut is projected (``downsample=True``),
    followed by ``blocks_nums[i]`` stride-1 ``BasicResBlock`` layers that
    keep the channel count.

    Args:
        in_channels (int): Input channels. Defaults to 128.
        out_channels (Sequence[int]): Output channels for multi-scale
            feature maps, one entry per stage.
            Defaults to (128, 128, 256).
        blocks_nums (Sequence[int]): Number of stride-1 residual blocks
            appended after the leading strided block of each stage, i.e.
            stage ``i`` holds ``blocks_nums[i] + 1`` blocks in total.
            Defaults to (1, 2, 2).
        layer_strides (Sequence[int]): Stride of the leading block of each
            stage. Defaults to (2, 2, 2).
        init_cfg (dict, optional): Config for weight initialization,
            forwarded to ``BaseModule``. Defaults to None.
    """

    def __init__(self,
                 in_channels: int = 128,
                 out_channels: Sequence[int] = (128, 128, 256),
                 blocks_nums: Sequence[int] = (1, 2, 2),
                 layer_strides: Sequence[int] = (2, 2, 2),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        assert len(layer_strides) == len(blocks_nums), (
            'layer_strides and blocks_nums must have the same length, got '
            f'{len(layer_strides)} and {len(blocks_nums)}')
        assert len(out_channels) == len(blocks_nums), (
            'out_channels and blocks_nums must have the same length, got '
            f'{len(out_channels)} and {len(blocks_nums)}')

        # Stage i consumes the output of stage i - 1; the first stage
        # consumes the network input.
        in_filters = [in_channels, *out_channels[:-1]]

        stages = []
        for stage_in, stage_out, stride, block_num in zip(
                in_filters, out_channels, layer_strides, blocks_nums):
            # Leading block: changes channels/resolution and therefore always
            # uses the projected shortcut.
            stage_layers = [
                BasicResBlock(
                    stage_in, stage_out, stride=stride, downsample=True)
            ]
            # Remaining blocks keep shape and use the identity shortcut.
            stage_layers.extend(
                BasicResBlock(stage_out, stage_out)
                for _ in range(block_num))
            stages.append(nn.Sequential(*stage_layers))
        self.blocks = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
        """Forward function.

        Args:
            x (torch.Tensor): Input with shape (N, C, H, W).

        Returns:
            tuple[torch.Tensor]: Multi-scale features, one tensor per stage
            in stage order; each stage receives the previous stage's output.
        """
        outs = []
        for stage in self.blocks:
            x = stage(x)
            outs.append(x)
        return tuple(outs)
|
|