| |
| import torch.nn as nn |
| from mmcv.cnn import build_norm_layer |
|
|
| from mmseg.registry import MODELS |
|
|
|
|
@MODELS.register_module()
class Feature2Pyramid(nn.Module):
    """Feature2Pyramid.

    A neck structure that connects a ViT backbone to decoder heads.

    Every input feature map is resampled by the matching entry of
    ``rescales``: ``rescales[i]`` is applied to ``inputs[i]``.

    Args:
        embed_dim (int): Embedding dimension. Used as both the input and
            the output channel count of the transposed convolutions.
        rescales (list[float]): Sampling multiple applied to each input
            level. Supported values are 4, 2, 1, 0.5 and 0.25.
            Default: [4, 2, 1, 0.5].
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='SyncBN', requires_grad=True).

    Raises:
        KeyError: If ``rescales`` contains an unsupported value.
    """

    def __init__(self,
                 embed_dim,
                 rescales=[4, 2, 1, 0.5],
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()
        # Copy so an instance never aliases the shared default list.
        self.rescales = list(rescales)
        # Stays None when no 4x level is requested.
        self.upsample_4x = None
        # Attribute name of the resampling op for each level, in the same
        # order as ``rescales``. Names (not modules) are stored so the
        # submodules keep their original attribute names / state_dict keys.
        self._op_names = []
        for k in self.rescales:
            if k == 4:
                name = 'upsample_4x'
                # Two stride-2 transposed convs give the overall 4x upsample.
                op = nn.Sequential(
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2),
                    build_norm_layer(norm_cfg, embed_dim)[1],
                    nn.GELU(),
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2),
                )
            elif k == 2:
                name = 'upsample_2x'
                op = nn.Sequential(
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2))
            elif k == 1:
                name = 'identity'
                op = nn.Identity()
            elif k == 0.5:
                name = 'downsample_2x'
                op = nn.MaxPool2d(kernel_size=2, stride=2)
            elif k == 0.25:
                name = 'downsample_4x'
                op = nn.MaxPool2d(kernel_size=4, stride=4)
            else:
                raise KeyError(f'invalid {k} for feature2pyramid')
            setattr(self, name, op)
            self._op_names.append(name)

    def forward(self, inputs):
        """Resample each input level by its configured scale.

        Args:
            inputs (Sequence): One feature map per entry of ``rescales``.

        Returns:
            tuple: The resampled feature maps, in the same order as
            ``inputs``.
        """
        assert len(inputs) == len(self.rescales), (
            f'expected {len(self.rescales)} inputs, got {len(inputs)}')
        # Look the ops up by name so only modules that were actually built
        # in __init__ are touched; this supports any valid ``rescales``.
        return tuple(
            getattr(self, name)(feat)
            for name, feat in zip(self._op_names, inputs))
|
|