# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Union

import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS
from ..utils import build_norm_layer


def is_pow2n(x):
    """Check whether ``x`` is a positive power of two.

    A power of two has exactly one bit set, so ``x & (x - 1)`` clears
    that bit and leaves zero.
    """
    return x > 0 and (x & (x - 1) == 0)
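

# Tiny sanity check for the helper above (an illustrative addition, not
# part of the upstream file): powers of two pass, while zero and numbers
# with more than one set bit do not.
def _demo_is_pow2n():
    assert is_pow2n(1) and is_pow2n(32)
    assert not is_pow2n(0) and not is_pow2n(12)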


class ConvBlock2x(BaseModule):
    """Two stacked conv-norm-act layers that keep the spatial size."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 mid_channels: int,
                 norm_cfg: dict,
                 act_cfg: dict,
                 last_act: bool,
                 init_cfg: Optional[dict] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, 1, 1, bias=False)
        self.norm1 = build_norm_layer(norm_cfg, mid_channels)
        self.activate1 = MODELS.build(act_cfg)
        self.conv2 = nn.Conv2d(
            mid_channels, out_channels, 3, 1, 1, bias=False)
        self.norm2 = build_norm_layer(norm_cfg, out_channels)
        self.activate2 = MODELS.build(act_cfg) if last_act else nn.Identity()

    def forward(self, x: torch.Tensor):
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.activate1(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.activate2(out)
        return out
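

# Illustrative shape check (an addition, not part of the upstream file):
# a ConvBlock2x maps in_channels to out_channels without changing the
# spatial size. Plain BN and ReLU6 configs are assumed here so the sketch
# runs in a single process, unlike the SyncBN defaults used below.
def _demo_conv_block():
    block = ConvBlock2x(
        in_channels=64,
        out_channels=32,
        mid_channels=64,
        norm_cfg=dict(type='BN'),
        act_cfg=dict(type='ReLU6'),
        last_act=True)
    x = torch.randn(2, 64, 16, 16)
    assert block(x).shape == (2, 32, 16, 16)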


class DecoderConvModule(BaseModule):
    """The convolution module of the decoder: upsampling followed by
    convolution blocks."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 mid_channels: int,
                 kernel_size: int = 4,
                 scale_factor: int = 2,
                 num_conv_blocks: int = 1,
                 norm_cfg: dict = dict(type='SyncBN'),
                 act_cfg: dict = dict(type='ReLU6'),
                 last_act: bool = True,
                 init_cfg: Optional[dict] = None):
        super().__init__(init_cfg=init_cfg)

        assert (kernel_size - scale_factor >= 0) and \
            (kernel_size - scale_factor) % 2 == 0, \
            f'kernel_size should be greater than or equal to scale_factor ' \
            f'and (kernel_size - scale_factor) should be even, while the ' \
            f'kernel_size is {kernel_size} and the scale_factor is ' \
            f'{scale_factor}.'

        # This padding makes the ConvTranspose2d output exactly
        # scale_factor times larger than its input.
        padding = (kernel_size - scale_factor) // 2
        self.upsample = nn.ConvTranspose2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=scale_factor,
            padding=padding,
            bias=True)

        # Only the first block changes the channel number, so stacks with
        # num_conv_blocks > 1 chain correctly.
        conv_blocks_list = [
            ConvBlock2x(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                mid_channels=mid_channels,
                norm_cfg=norm_cfg,
                last_act=last_act,
                act_cfg=act_cfg) for i in range(num_conv_blocks)
        ]
        self.conv_blocks = nn.Sequential(*conv_blocks_list)

    def forward(self, x):
        x = self.upsample(x)
        return self.conv_blocks(x)
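

# Illustrative shape check (an addition, not part of the upstream file):
# with the default kernel_size=4 and scale_factor=2, the module doubles
# the spatial size, e.g. 8x8 -> 16x16. Plain BN is assumed so that no
# distributed process group is needed.
def _demo_decoder_conv_module():
    module = DecoderConvModule(
        in_channels=64,
        out_channels=32,
        mid_channels=64,
        norm_cfg=dict(type='BN'))
    x = torch.randn(2, 64, 8, 8)
    assert module(x).shape == (2, 32, 16, 16)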


@MODELS.register_module()
class SparKLightDecoder(BaseModule):
    """The decoder for SparK, which upsamples the feature maps.

    Args:
        feature_dim (int): The dimension of the input feature map.
        upsample_ratio (int): The upsampling ratio, equal to the
            downsample_ratio of the algorithm. Must be a power of two.
        mid_channels (int): The middle channels of `DecoderConvModule`.
            Defaults to 0, which means using the stage's input channels.
        kernel_size (int): The kernel size of `ConvTranspose2d` in
            `DecoderConvModule`. Defaults to 4.
        scale_factor (int): The scale_factor of `ConvTranspose2d` in
            `DecoderConvModule`. Defaults to 2.
        num_conv_blocks (int): The number of convolution blocks in
            `DecoderConvModule`. Defaults to 1.
        norm_cfg (dict): Normalization config. Defaults to
            dict(type='SyncBN').
        act_cfg (dict): Activation config. Defaults to dict(type='ReLU6').
        last_act (bool): Whether to apply the last activation in
            `DecoderConvModule`. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(
        self,
        feature_dim: int,
        upsample_ratio: int,
        mid_channels: int = 0,
        kernel_size: int = 4,
        scale_factor: int = 2,
        num_conv_blocks: int = 1,
        norm_cfg: dict = dict(type='SyncBN'),
        act_cfg: dict = dict(type='ReLU6'),
        last_act: bool = False,
        init_cfg: Optional[Union[dict, List[dict]]] = [
            dict(type='Kaiming', layer=['Conv2d', 'ConvTranspose2d']),
            dict(type='TruncNormal', std=0.02, layer=['Linear']),
            dict(
                type='Constant',
                val=1,
                layer=['_BatchNorm', 'LayerNorm', 'SyncBatchNorm'])
        ],
    ):
        super().__init__(init_cfg=init_cfg)
        self.feature_dim = feature_dim

        assert is_pow2n(upsample_ratio), \
            f'upsample_ratio must be a power of two, got {upsample_ratio}.'
        n = round(math.log2(upsample_ratio))
        # Each of the n stages doubles the resolution and halves the
        # channels: feature_dim, feature_dim / 2, ..., feature_dim / 2^n.
        channels = [feature_dim // 2**i for i in range(n + 1)]

        self.decoder = nn.ModuleList([
            DecoderConvModule(
                in_channels=c_in,
                out_channels=c_out,
                mid_channels=c_in if mid_channels == 0 else mid_channels,
                kernel_size=kernel_size,
                scale_factor=scale_factor,
                num_conv_blocks=num_conv_blocks,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                last_act=last_act)
            for (c_in, c_out) in zip(channels[:-1], channels[1:])
        ])
        # Project the finest decoder features back to a 3-channel image.
        self.proj = nn.Conv2d(
            channels[-1], 3, kernel_size=1, stride=1, bias=True)

    def forward(self, to_dec):
        x = 0
        for i, d in enumerate(self.decoder):
            # Add the i-th multi-scale feature (coarsest first) before
            # each upsampling stage, skipping missing entries.
            if i < len(to_dec) and to_dec[i] is not None:
                x = x + to_dec[i]
            x = d(x)
        return self.proj(x)
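

# Illustrative end-to-end sketch (an addition, not part of the upstream
# file): feature_dim=256 with upsample_ratio=4 builds two upsampling
# stages (channels 256 -> 128 -> 64) followed by a 1x1 projection to a
# 3-channel image. `to_dec` lists the multi-scale features, coarsest
# first; later entries may be None. Plain BN replaces the SyncBN default
# so the sketch runs in a single process.
def _demo_spark_light_decoder():
    decoder = SparKLightDecoder(
        feature_dim=256,
        upsample_ratio=4,
        norm_cfg=dict(type='BN'))
    to_dec = [torch.randn(2, 256, 7, 7), None]
    out = decoder(to_dec)
    assert out.shape == (2, 3, 28, 28)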