# Provenance (Hugging Face Hub viewer header, converted to a comment so the
# file is valid Python): uploaded by "werning", commit "Add code and weights"
# (f3b0a90), 8.04 kB.
import torch
import torch.nn as nn
from torchvision.ops.misc import Conv2dNormActivation
from .helpers.utils import make_divisible
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
def initialize_weights(module):
    """Initialize one module in place; intended for use with ``nn.Module.apply``.

    Conv layers get Kaiming-normal (fan-out) weights, norm layers get
    weight=1 / bias=0, and linear layers get N(0, 0.01) weights. Biases,
    where present, are zeroed. Any other module type is left untouched.
    """
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode="fan_out")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        return
    if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
        return
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, 0, 0.01)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
class Block(nn.Module):
    """Inverted-bottleneck residual block: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    A residual shortcut is used only when ``in_channels == out_channels``;
    if the block is also strided, the shortcut average-pools the input so
    spatial sizes match. A ReLU is applied after the (optional) addition.
    """

    def __init__(self, in_channels, out_channels, expansion_rate, stride):
        super().__init__()
        hidden_channels = make_divisible(in_channels * expansion_rate, 8)
        # 1x1 pointwise expansion (BN + ReLU).
        expand = Conv2dNormActivation(
            in_channels,
            hidden_channels,
            kernel_size=1,
            stride=1,
            norm_layer=nn.BatchNorm2d,
            activation_layer=nn.ReLU,
            inplace=False,
        )
        # 3x3 depthwise conv; carries the block's stride.
        depthwise = Conv2dNormActivation(
            hidden_channels,
            hidden_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=hidden_channels,
            norm_layer=nn.BatchNorm2d,
            activation_layer=nn.ReLU,
            inplace=False,
        )
        # 1x1 linear projection (BN, no activation).
        project = Conv2dNormActivation(
            hidden_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            norm_layer=nn.BatchNorm2d,
            activation_layer=None,
            inplace=False,
        )
        self.after_block_activation = nn.ReLU()
        self.use_shortcut = in_channels == out_channels
        if self.use_shortcut:
            if stride in (1, (1, 1)):
                # Identity shortcut: shapes already match.
                self.shortcut = nn.Sequential()
            else:
                # Strided block: average-pool the input to match spatial size.
                self.shortcut = nn.Sequential(
                    nn.AvgPool2d(kernel_size=3, stride=stride, padding=1),
                    nn.Sequential(),
                )
        self.block = nn.Sequential(expand, depthwise, project)

    def forward(self, x):
        out = self.block(x)
        if self.use_shortcut:
            out = out + self.shortcut(x)
        return self.after_block_activation(out)
class NetworkConfig(PretrainedConfig):
    """Hyperparameter container for ``Network`` (Hugging Face ``PretrainedConfig``).

    Parameters
    ----------
    n_classes: number of output classes.
    in_channels: input channels (1 by default, e.g. single-channel audio features).
    base_channels: channel width after the input stem (rounded to a multiple of 8
        by the model).
    channels_multiplier: per-stage width growth factor.
    expansion_rate: expansion factor inside the inverted-bottleneck blocks.
    n_blocks: number of blocks per stage.
    strides: mapping from global block name ("b1", "b2", ...) to that block's
        stride; blocks not listed default to (1, 1). ``None`` selects the
        standard schedule below.
    add_feats: if True, ``Network.forward`` also returns pre-head feature maps.
    """

    def __init__(
        self,
        n_classes=10,
        in_channels=1,
        base_channels=32,
        channels_multiplier=2.3,
        expansion_rate=3.0,
        n_blocks=(3, 2, 1),
        strides=None,
        add_feats=False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # BUGFIX: the default was a mutable dict shared across every call
        # (classic mutable-default-argument pitfall); use a None sentinel
        # instead, matching the pattern already used in get_model().
        if strides is None:
            strides = dict(b2=(1, 1), b3=(1, 2), b4=(2, 1))
        self.n_classes = n_classes
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.channels_multiplier = channels_multiplier
        self.expansion_rate = expansion_rate
        self.n_blocks = n_blocks
        self.strides = strides
        self.add_feats = add_feats
class Network(PreTrainedModel):
    """CNN classifier: stem -> inverted-bottleneck stages -> 1x1-conv head.

    Input is expected to be a 4-D tensor (batch, in_channels, H, W); with the
    default ``in_channels=1`` this presumably is an audio spectrogram — TODO
    confirm against callers. Output is ``(batch, n_classes)`` logits (plus the
    pre-head feature map when ``config.add_feats`` is set).
    """

    # Tells transformers' from_pretrained() which config class to build.
    config_class = NetworkConfig
    def __init__(self, config):
        super().__init__(config)
        # Unpack hyperparameters from the config.
        n_classes = config.n_classes
        in_channels = config.in_channels
        base_channels = config.base_channels
        channels_multiplier = config.channels_multiplier
        expansion_rate = config.expansion_rate
        n_blocks = config.n_blocks
        strides = config.strides
        n_stages = len(n_blocks)
        self.add_feats = config.add_feats
        # Round the stem width to a multiple of 8, then grow it by
        # channels_multiplier per stage (also rounded to multiples of 8).
        base_channels = make_divisible(base_channels, 8)
        channels_per_stage = [base_channels] + [
            make_divisible(base_channels * channels_multiplier**stage_id, 8)
            for stage_id in range(n_stages)
        ]
        # Global block counter so block names ("b1", "b2", ...) are unique
        # across stages and can be addressed by the `strides` dict.
        self.total_block_count = 0
        # Input stem: two stride-2 convs => 4x spatial downsampling; the first
        # runs at a quarter of the stem width to keep early layers cheap.
        self.in_c = nn.Sequential(
            Conv2dNormActivation(
                in_channels,
                channels_per_stage[0] // 4,
                activation_layer=torch.nn.ReLU,
                kernel_size=3,
                stride=2,
                inplace=False,
            ),
            Conv2dNormActivation(
                channels_per_stage[0] // 4,
                channels_per_stage[0],
                activation_layer=torch.nn.ReLU,
                kernel_size=3,
                stride=2,
                inplace=False,
            ),
        )
        # Build the stages; each stage maps channels_per_stage[i] -> [i+1].
        self.stages = nn.Sequential()
        for stage_id in range(n_stages):
            stage = self._make_stage(
                channels_per_stage[stage_id],
                channels_per_stage[stage_id + 1],
                n_blocks[stage_id],
                strides=strides,
                expansion_rate=expansion_rate,
            )
            self.stages.add_module(f"s{stage_id + 1}", stage)
        # Classification head: 1x1 conv to n_classes (no bias; BN supplies the
        # affine shift), then global average pooling to 1x1.
        ff_list = []
        ff_list += [
            nn.Conv2d(
                channels_per_stage[-1],
                n_classes,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(n_classes),
        ]
        ff_list.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.feed_forward = nn.Sequential(*ff_list)
        # Recursively initialize all submodules (see initialize_weights).
        self.apply(initialize_weights)
    def _make_stage(self, in_channels, out_channels, n_blocks, strides, expansion_rate):
        """Build one stage of `n_blocks` blocks; only the blocks named in
        `strides` get a non-unit stride. The first block changes the channel
        count; subsequent blocks keep out_channels -> out_channels."""
        stage = nn.Sequential()
        for index in range(n_blocks):
            # Block names continue across stages: b1, b2, ... bN.
            block_id = self.total_block_count + 1
            bname = f"b{block_id}"
            self.total_block_count = self.total_block_count + 1
            if bname in strides:
                stride = strides[bname]
            else:
                stride = (1, 1)
            block = self._make_block(
                in_channels, out_channels, stride=stride, expansion_rate=expansion_rate
            )
            stage.add_module(bname, block)
            # After the first block the stage runs at out_channels.
            in_channels = out_channels
        return stage
    def _make_block(self, in_channels, out_channels, stride, expansion_rate):
        """Construct a single inverted-bottleneck Block."""
        block = Block(in_channels, out_channels, expansion_rate, stride)
        return block
    def _forward_conv(self, x):
        """Stem + stages only (no classification head)."""
        x = self.in_c(x)
        x = self.stages(x)
        return x
    def forward(self, x):
        """Return (batch, n_classes) logits; with add_feats also the pre-head
        feature map of shape (batch, C_last, H', W')."""
        y = self._forward_conv(x)
        x = self.feed_forward(y)
        # Head output is (batch, n_classes, 1, 1); drop the spatial dims.
        logits = x.squeeze(2).squeeze(2)
        if self.add_feats:
            return logits, y
        else:
            return logits
def get_model(
    n_classes=10,
    in_channels=1,
    base_channels=32,
    channels_multiplier=2.3,
    expansion_rate=3.0,
    n_blocks=(3, 2, 1),
    strides=None,
    add_feats=False,
):
    """
    @param n_classes: number of the classes to predict
    @param in_channels: input channels to the network, for audio it is by default 1
    @param base_channels: number of channels after in_conv
    @param channels_multiplier: controls the increase in the width of the network after each stage
    @param expansion_rate: determines the expansion rate in inverted bottleneck blocks
    @param n_blocks: number of blocks that should exist in each stage
    @param strides: default value set below
    @return: full neural network model based on the specified configs
    """
    # None sentinel avoids a shared mutable default dict.
    if strides is None:
        strides = dict(b2=(1, 1), b3=(1, 2), b4=(2, 1))
    config = NetworkConfig(
        n_classes=n_classes,
        in_channels=in_channels,
        base_channels=base_channels,
        channels_multiplier=channels_multiplier,
        expansion_rate=expansion_rate,
        n_blocks=n_blocks,
        strides=strides,
        add_feats=add_feats,
    )
    return Network(config)