| |
import copy
from typing import Optional, Tuple

import torch.nn.functional as F
from torch import Tensor

from mmdet.registry import MODELS
from .fpn import FPN
|
|
|
|
@MODELS.register_module()
class FPN_DropBlock(FPN):
    """FPN that applies a plugin module after each top-down fusion step.

    Identical to :class:`FPN` except for one step in the top-down pathway.
    After a merged feature ``laterals[i - 1] + upsample(laterals[i])`` is
    formed, it is passed through ``self.plugin`` when a plugin is configured.
    The top-most lateral is never a merge target, so it never passes
    through the plugin.

    Args:
        *args: Positional arguments forwarded unchanged to :class:`FPN`.
        plugin (dict, optional): Config of the module applied to every
            merged lateral. It is built with ``MODELS.build``. Defaults to
            ``dict(type='DropBlock', drop_prob=0.3, block_size=3,
            warmup_iters=0)``. Pass ``None`` to build no plugin.
        **kwargs: Keyword arguments forwarded unchanged to :class:`FPN`.
    """

    def __init__(self,
                 *args,
                 plugin: Optional[dict] = dict(
                     type='DropBlock',
                     drop_prob=0.3,
                     block_size=3,
                     warmup_iters=0),
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.plugin = None
        if plugin is not None:
            # The default ``plugin`` dict is evaluated once at definition
            # time and shared by every instance. Build from a private copy
            # so neither the registry nor later code can mutate that
            # shared default (or a caller-owned config).
            self.plugin = MODELS.build(copy.deepcopy(plugin))

    def forward(self, inputs: Tuple[Tensor]) -> tuple:
        """Forward function.

        Args:
            inputs (tuple[Tensor]): Features from the upstream network, each
                is a 4D-tensor. Must hold exactly one entry per element of
                ``self.in_channels``.

        Returns:
            tuple: Feature maps, each is a 4D-tensor. There is one map per
            lateral conv. If ``self.num_outs`` is larger, extra levels are
            appended until there are ``self.num_outs`` maps.
        """
        assert len(inputs) == len(self.in_channels), (
            f'expected {len(self.in_channels)} input feature maps, '
            f'got {len(inputs)}')

        # Build the lateral features, starting from ``self.start_level``.
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # Top-down pathway: walk from the coarsest level to the finest.
        # Each coarser lateral is upsampled and added into the next finer
        # one.
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            # F.interpolate does not accept ``size`` together with
            # ``scale_factor``, so only pass the target size when the
            # config does not already fix a scale factor.
            if 'scale_factor' in self.upsample_cfg:
                upsampled = F.interpolate(laterals[i], **self.upsample_cfg)
            else:
                upsampled = F.interpolate(
                    laterals[i],
                    size=laterals[i - 1].shape[2:],
                    **self.upsample_cfg)
            # Deliberately out-of-place (not ``+=``). The in-place form is
            # reportedly problematic for autograd on some PyTorch versions
            # -- NOTE(review): confirm before changing.
            laterals[i - 1] = laterals[i - 1] + upsampled

            # Regularize the freshly merged feature. The top-most lateral
            # (laterals[-1]) is never a merge target and is never passed
            # through the plugin.
            if self.plugin is not None:
                laterals[i - 1] = self.plugin(laterals[i - 1])

        # Per-level output convs.
        outs = [
            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
        ]

        # Optionally add extra, coarser levels on top of the outputs.
        if self.num_outs > len(outs):
            if not self.add_extra_convs:
                # A 1x1 max-pool with stride 2 simply subsamples the previous
                # level by 2 in each spatial dimension.
                for _ in range(self.num_outs - used_backbone_levels):
                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
            else:
                # Choose the feature the first extra conv is applied to.
                if self.add_extra_convs == 'on_input':
                    extra_source = inputs[self.backbone_end_level - 1]
                elif self.add_extra_convs == 'on_lateral':
                    extra_source = laterals[-1]
                elif self.add_extra_convs == 'on_output':
                    extra_source = outs[-1]
                else:
                    raise NotImplementedError(
                        'unsupported add_extra_convs mode: '
                        f'{self.add_extra_convs!r}')
                outs.append(self.fpn_convs[used_backbone_levels](extra_source))
                # Each remaining extra conv consumes the previous extra
                # output, optionally after a ReLU.
                for i in range(used_backbone_levels + 1, self.num_outs):
                    if self.relu_before_extra_convs:
                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
                    else:
                        outs.append(self.fpn_convs[i](outs[-1]))
        return tuple(outs)
|
|