# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Union

import torch
import torch.nn as nn
from mmdet.utils import ConfigType, OptMultiConfig
from mmengine.model import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

from mmyolo.registry import MODELS


@MODELS.register_module()
class BaseYOLONeck(BaseModule, metaclass=ABCMeta):
"""Base neck used in YOLO series.
.. code:: text
P5 neck model structure diagram
+--------+ +-------+
|top_down|----------+--------->| out |---> output0
| layer1 | | | layer0|
+--------+ | +-------+
stride=8 ^ |
idx=0 +------+ +--------+ |
-----> |reduce|--->| cat | |
|layer0| +--------+ |
+------+ ^ v
+--------+ +-----------+
|upsample| |downsample |
| layer1 | | layer0 |
+--------+ +-----------+
^ |
+--------+ v
|top_down| +-----------+
| layer2 |--->| cat |
+--------+ +-----------+
stride=16 ^ v
idx=1 +------+ +--------+ +-----------+ +-------+
-----> |reduce|--->| cat | | bottom_up |--->| out |---> output1
|layer1| +--------+ | layer0 | | layer1|
+------+ ^ +-----------+ +-------+
| v
+--------+ +-----------+
|upsample| |downsample |
| layer2 | | layer1 |
stride=32 +--------+ +-----------+
idx=2 +------+ ^ v
-----> |reduce| | +-----------+
|layer2|---------+------->| cat |
+------+ +-----------+
v
+-----------+ +-------+
| bottom_up |--->| out |---> output2
| layer1 | | layer2|
+-----------+ +-------+

    .. code:: text

P6 neck model structure diagram
+--------+ +-------+
|top_down|----------+--------->| out |---> output0
| layer1 | | | layer0|
+--------+ | +-------+
stride=8 ^ |
idx=0 +------+ +--------+ |
-----> |reduce|--->| cat | |
|layer0| +--------+ |
+------+ ^ v
+--------+ +-----------+
|upsample| |downsample |
| layer1 | | layer0 |
+--------+ +-----------+
^ |
+--------+ v
|top_down| +-----------+
| layer2 |--->| cat |
+--------+ +-----------+
stride=16 ^ v
idx=1 +------+ +--------+ +-----------+ +-------+
-----> |reduce|--->| cat | | bottom_up |--->| out |---> output1
|layer1| +--------+ | layer0 | | layer1|
+------+ ^ +-----------+ +-------+
| v
+--------+ +-----------+
|upsample| |downsample |
| layer2 | | layer1 |
+--------+ +-----------+
^ |
+--------+ v
|top_down| +-----------+
| layer3 |--->| cat |
+--------+ +-----------+
stride=32 ^ v
idx=2 +------+ +--------+ +-----------+ +-------+
-----> |reduce|--->| cat | | bottom_up |--->| out |---> output2
|layer2| +--------+ | layer1 | | layer2|
+------+ ^ +-----------+ +-------+
| v
+--------+ +-----------+
|upsample| |downsample |
| layer3 | | layer2 |
+--------+ +-----------+
stride=64 ^ v
idx=3 +------+ | +-----------+
-----> |reduce|---------+------->| cat |
|layer3| +-----------+
+------+ v
+-----------+ +-------+
| bottom_up |--->| out |---> output3
| layer2 | | layer3|
+-----------+ +-------+

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (Union[int, List[int]]): Number of output channels
            (used at each scale).
        deepen_factor (float): Depth multiplier, multiply the number of
            blocks in CSP layers by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply the number of
            channels in each layer by this amount. Defaults to 1.0.
        upsample_feats_cat_first (bool): Whether the upsampled features
            are concatenated first in the top-down module. Defaults to
            True. Currently only YOLOv7 sets this to False.
        freeze_all (bool): Whether to freeze all layers of the model.
            Defaults to False.
        norm_cfg (dict): Config dict for the normalization layer.
            Defaults to None.
        act_cfg (dict): Config dict for the activation layer.
            Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization config
            dict. Defaults to None.
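
    Examples:
        A minimal usage sketch. The channel counts and feature sizes here
        are hypothetical (a 640x640 input at strides 8/16/32);
        ``YOLOv5PAFPN`` is one concrete subclass shipped with mmyolo::

            >>> import torch
            >>> from mmyolo.models.necks import YOLOv5PAFPN
            >>> neck = YOLOv5PAFPN(
            ...     in_channels=[256, 512, 1024],
            ...     out_channels=[256, 512, 1024])
            >>> feats = [
            ...     torch.rand(1, 256, 80, 80),
            ...     torch.rand(1, 512, 40, 40),
            ...     torch.rand(1, 1024, 20, 20)
            ... ]
            >>> outs = neck(feats)
            >>> assert len(outs) == 3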
"""
def __init__(self,
in_channels: List[int],
out_channels: Union[int, List[int]],
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
upsample_feats_cat_first: bool = True,
freeze_all: bool = False,
norm_cfg: ConfigType = None,
act_cfg: ConfigType = None,
init_cfg: OptMultiConfig = None,
**kwargs):
super().__init__(init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.deepen_factor = deepen_factor
self.widen_factor = widen_factor
self.upsample_feats_cat_first = upsample_feats_cat_first
self.freeze_all = freeze_all
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg

        # build reduce layers
        self.reduce_layers = nn.ModuleList()
        for idx in range(len(in_channels)):
            self.reduce_layers.append(self.build_reduce_layer(idx))

        # build top-down blocks
        self.upsample_layers = nn.ModuleList()
        self.top_down_layers = nn.ModuleList()
        for idx in range(len(in_channels) - 1, 0, -1):
            self.upsample_layers.append(
                self.build_upsample_layer(
                    idx=idx, n_layers=len(in_channels)))
            self.top_down_layers.append(self.build_top_down_layer(idx))

        # build bottom-up blocks
        self.downsample_layers = nn.ModuleList()
        self.bottom_up_layers = nn.ModuleList()
        for idx in range(len(in_channels) - 1):
            self.downsample_layers.append(self.build_downsample_layer(idx))
            self.bottom_up_layers.append(self.build_bottom_up_layer(idx))

        # build out layers
        self.out_layers = nn.ModuleList()
        for idx in range(len(in_channels)):
            self.out_layers.append(self.build_out_layer(idx))

    @abstractmethod
    def build_reduce_layer(self, idx: int):
        """Build reduce layer."""
        pass

    @abstractmethod
    def build_upsample_layer(self, idx: int, n_layers: int):
        """Build upsample layer."""
        pass

    @abstractmethod
    def build_top_down_layer(self, idx: int):
        """Build top-down layer."""
        pass

    @abstractmethod
    def build_downsample_layer(self, idx: int):
        """Build downsample layer."""
        pass

    @abstractmethod
    def build_bottom_up_layer(self, idx: int):
        """Build bottom-up layer."""
        pass

    @abstractmethod
    def build_out_layer(self, idx: int):
        """Build out layer."""
        pass
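
    # A minimal, hypothetical subclass sketch of the builder hooks above,
    # assuming every level carries the same channel count ``c`` (illustration
    # only; concrete necks such as YOLOv5PAFPN build these hooks from
    # ConvModule / CSP blocks instead):
    #
    #     class TrivialNeck(BaseYOLONeck):
    #
    #         def build_reduce_layer(self, idx):
    #             return nn.Identity()
    #
    #         def build_upsample_layer(self, idx, n_layers):
    #             return nn.Upsample(scale_factor=2, mode='nearest')
    #
    #         def build_top_down_layer(self, idx):
    #             c = self.in_channels[0]
    #             return nn.Conv2d(2 * c, c, 1)  # fuse cat(upsampled, low)
    #
    #         def build_downsample_layer(self, idx):
    #             c = self.in_channels[0]
    #             return nn.Conv2d(c, c, 3, stride=2, padding=1)
    #
    #         def build_bottom_up_layer(self, idx):
    #             c = self.in_channels[0]
    #             return nn.Conv2d(2 * c, c, 1)  # fuse cat(downsampled, high)
    #
    #         def build_out_layer(self, idx):
    #             return nn.Identity()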

    def _freeze_all(self):
        """Freeze the model."""
        for m in self.modules():
            if isinstance(m, _BatchNorm):
                m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Convert the model into training mode while keeping the
        normalization layers frozen."""
        super().train(mode)
        if self.freeze_all:
            self._freeze_all()
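
    # Note: with ``freeze_all=True`` the BatchNorm layers are put back into
    # eval mode on every call to ``train()``, so their running statistics
    # are not updated during fine-tuning.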

    def forward(self, inputs: List[torch.Tensor]) -> tuple:
        """Forward function.

        Args:
            inputs (List[torch.Tensor]): Multi-scale features from the
                backbone, ordered from high resolution (small stride) to
                low resolution (large stride).

        Returns:
            tuple: Multi-scale output features, one tensor per level.
        """
        assert len(inputs) == len(self.in_channels)

# reduce layers
reduce_outs = []
for idx in range(len(self.in_channels)):
reduce_outs.append(self.reduce_layers[idx](inputs[idx]))
# top-down path
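        # The top-down layers were appended while iterating ``idx`` from high
        # to low, so position 0 in ``upsample_layers`` / ``top_down_layers``
        # belongs to the lowest-resolution level; hence the
        # ``len(self.in_channels) - 1 - idx`` indexing below.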
inner_outs = [reduce_outs[-1]]
for idx in range(len(self.in_channels) - 1, 0, -1):
feat_high = inner_outs[0]
feat_low = reduce_outs[idx - 1]
            upsample_feat = self.upsample_layers[
                len(self.in_channels) - 1 - idx](feat_high)
if self.upsample_feats_cat_first:
top_down_layer_inputs = torch.cat([upsample_feat, feat_low], 1)
else:
top_down_layer_inputs = torch.cat([feat_low, upsample_feat], 1)
inner_out = self.top_down_layers[len(self.in_channels) - 1 - idx](
top_down_layer_inputs)
inner_outs.insert(0, inner_out)
# bottom-up path
outs = [inner_outs[0]]
for idx in range(len(self.in_channels) - 1):
feat_low = outs[-1]
feat_high = inner_outs[idx + 1]
downsample_feat = self.downsample_layers[idx](feat_low)
out = self.bottom_up_layers[idx](
torch.cat([downsample_feat, feat_high], 1))
outs.append(out)
# out_layers
results = []
for idx in range(len(self.in_channels)):
results.append(self.out_layers[idx](outs[idx]))
return tuple(results)
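

# Forward data-flow sketch for a 3-level P5 neck (hypothetical 640x640
# input, strides 8/16/32; B = batch size, Ci = per-level channels):
#
#     inputs:      [(B, C0, 80, 80), (B, C1, 40, 40), (B, C2, 20, 20)]
#     reduce_outs: same resolutions, channels adapted by ``reduce_layers``
#     inner_outs:  top-down pass, filled from the 20x20 level upwards
#     outs:        bottom-up pass, filled from the 80x80 level downwards
#     results:     one tensor per level, same resolutions as the inputs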