# mm3dtest/projects/DSVT/dsvt/res_second.py
# giantmonkeyTC
# 2344
# 34d1f8b
# modified from https://github.com/Haiyang-W/DSVT
from typing import Sequence, Tuple
from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn
from mmdet3d.registry import MODELS
from mmdet3d.utils import OptMultiConfig
class BasicResBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    The main branch is ``conv -> BN -> ReLU -> conv -> BN``. Its result is
    summed with a shortcut branch and passed through a final ReLU. The
    shortcut is either the untouched input or, when ``downsample`` is set,
    a strided 1x1 convolution followed by batch normalization.

    Args:
        inplanes (int): Number of channels of the input feature map.
        planes (int): Number of channels produced by both convolutions.
        stride (int): Stride of the first 3x3 convolution and of the 1x1
            shortcut convolution. Defaults to 1.
        padding (int): Padding of the first 3x3 convolution only; the
            second convolution always pads by 1. Defaults to 1.
        downsample (bool): Whether to build and use the projection
            shortcut. When False the input is added to the main branch
            as is, so the caller must make sure the shapes agree.
            Defaults to False.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        padding: int = 1,
        downsample: bool = False,
    ) -> None:
        super().__init__()
        # Every batch norm layer in this block shares the same settings.
        bn_kwargs = dict(eps=1e-3, momentum=0.01)

        # Submodules are created in a fixed order so that parameter
        # registration (and seeded random initialization) is stable.
        self.conv1 = nn.Conv2d(
            inplanes, planes, 3, stride=stride, padding=padding, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, **bn_kwargs)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, **bn_kwargs)
        self.relu2 = nn.ReLU()

        self.downsample = downsample
        if downsample:
            # Projection shortcut: match channel count and spatial stride
            # of the main branch.
            shortcut_conv = nn.Conv2d(
                inplanes, planes, 1, stride=stride, padding=0, bias=False)
            shortcut_bn = nn.BatchNorm2d(planes, **bn_kwargs)
            self.downsample_layer = nn.Sequential(shortcut_conv, shortcut_bn)
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Run the residual block.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: ReLU of the main branch plus the shortcut.
        """
        # Main branch first, shortcut second.
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = self.downsample_layer(x) if self.downsample else x
        out += shortcut
        return self.relu2(out)
@MODELS.register_module()
class ResSECOND(BaseModule):
    """Backbone network for DSVT.

    The difference between `ResSECOND` and `SECOND` is that the basic block
    in this module contains residual layers.

    Every stage starts with one ``BasicResBlock`` built with the stage
    stride and a projection shortcut, followed by ``blocks_nums[i]`` more
    ``BasicResBlock`` instances with default stride and no projection. A
    stage therefore holds ``blocks_nums[i] + 1`` blocks in total.

    Args:
        in_channels (int): Input channels. Defaults to 128.
        out_channels (Sequence[int]): Output channels for multi-scale
            feature maps, one entry per stage.
            Defaults to (128, 128, 256).
        blocks_nums (Sequence[int]): Number of blocks appended after the
            leading downsample block of each stage.
            Defaults to (1, 2, 2).
        layer_strides (Sequence[int]): Stride of the leading block of each
            stage. Defaults to (2, 2, 2).
        init_cfg (dict, optional): Config for weight initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int = 128,
                 out_channels: Sequence[int] = (128, 128, 256),
                 blocks_nums: Sequence[int] = (1, 2, 2),
                 layer_strides: Sequence[int] = (2, 2, 2),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        assert len(layer_strides) == len(blocks_nums), (
            'layer_strides and blocks_nums must have the same length, got '
            f'{len(layer_strides)} and {len(blocks_nums)}')
        assert len(out_channels) == len(blocks_nums), (
            'out_channels and blocks_nums must have the same length, got '
            f'{len(out_channels)} and {len(blocks_nums)}')

        stages = []
        # The first stage consumes the backbone input; every later stage
        # consumes the output of the stage before it.
        stage_in = in_channels
        for stage_out, block_num, stride in zip(out_channels, blocks_nums,
                                                layer_strides):
            # Leading block changes resolution / width, so it needs the
            # projection shortcut.
            layers = [
                BasicResBlock(
                    stage_in, stage_out, stride=stride, downsample=True)
            ]
            # Remaining blocks keep the shape and use identity shortcuts.
            layers.extend(
                BasicResBlock(stage_out, stage_out)
                for _ in range(block_num))
            stages.append(nn.Sequential(*layers))
            stage_in = stage_out

        self.blocks = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
        """Forward function.

        Args:
            x (torch.Tensor): Input with shape (N, C, H, W).

        Returns:
            tuple[torch.Tensor]: Multi-scale features, one tensor per
            stage in stage order.
        """
        outs = []
        # Stages are chained: each one consumes the previous stage's
        # output, and every intermediate result is kept.
        for stage in self.blocks:
            x = stage(x)
            outs.append(x)
        return tuple(outs)