|
import torch |
|
from torch import nn |
|
from torch.nn import functional as F |
|
|
|
from .utils import ( |
|
Swish, MemoryEfficientSwish, drop_connect, round_filters, round_repeats, |
|
get_model_params, efficientnet_params, get_same_padding_conv2d, |
|
load_pretrained_weights, calculate_output_image_size |
|
) |
|
from ..build import BACKBONE_REGISTRY |
|
from ..backbone import Backbone |
|
|
|
|
|
class MBConvBlock(nn.Module):
    """Mobile inverted residual bottleneck (MBConv) block.

    Pipeline: optional 1x1 expansion conv -> depthwise conv -> optional
    squeeze-and-excitation (SE) gate -> 1x1 projection conv. A residual
    connection (optionally preceded by drop-connect) is added when the
    block's ``id_skip`` flag is set, its stride equals 1 and the input and
    output filter counts match.

    Args:
        block_args (namedtuple): per-block configuration. Fields read here:
            ``input_filters``, ``output_filters``, ``expand_ratio``,
            ``kernel_size``, ``stride``, ``se_ratio`` and ``id_skip``.
        global_params (namedtuple): model-wide settings. Fields read here:
            ``batch_norm_momentum`` and ``batch_norm_epsilon``.
        image_size: spatial size of the block input; forwarded to
            ``get_same_padding_conv2d`` / ``calculate_output_image_size``
            (its exact format is defined by those helpers).

    Attributes:
        has_se (bool): True when ``se_ratio`` is set and lies in (0, 1].
        id_skip: the configured ``id_skip`` flag; gates the residual path.
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        # BatchNorm2d is given the complement of the configured momentum.
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        ratio = block_args.se_ratio
        self.has_se = ratio is not None and 0 < ratio <= 1
        self.id_skip = block_args.id_skip

        in_ch = block_args.input_filters
        mid_ch = in_ch * block_args.expand_ratio
        bn_kwargs = dict(momentum=self._bn_mom, eps=self._bn_eps)

        # Expansion phase: a 1x1 conv, only present when expand_ratio != 1.
        if block_args.expand_ratio != 1:
            conv_cls = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = conv_cls(
                in_channels=in_ch, out_channels=mid_ch, kernel_size=1, bias=False
            )
            self._bn0 = nn.BatchNorm2d(num_features=mid_ch, **bn_kwargs)

        # Depthwise phase: groups == channels gives one filter per channel.
        stride = block_args.stride
        conv_cls = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = conv_cls(
            in_channels=mid_ch,
            out_channels=mid_ch,
            groups=mid_ch,
            kernel_size=block_args.kernel_size,
            stride=stride,
            bias=False,
        )
        self._bn1 = nn.BatchNorm2d(num_features=mid_ch, **bn_kwargs)
        out_size = calculate_output_image_size(image_size, stride)

        # SE gate: 1x1 convs applied to a globally pooled (1x1) feature map.
        if self.has_se:
            se_conv_cls = get_same_padding_conv2d(image_size=(1, 1))
            squeezed = max(1, int(in_ch * block_args.se_ratio))
            self._se_reduce = se_conv_cls(
                in_channels=mid_ch, out_channels=squeezed, kernel_size=1
            )
            self._se_expand = se_conv_cls(
                in_channels=squeezed, out_channels=mid_ch, kernel_size=1
            )

        # Projection phase: 1x1 conv down to the block's output width.
        conv_cls = get_same_padding_conv2d(image_size=out_size)
        self._project_conv = conv_cls(
            in_channels=mid_ch,
            out_channels=block_args.output_filters,
            kernel_size=1,
            bias=False,
        )
        self._bn2 = nn.BatchNorm2d(
            num_features=block_args.output_filters, **bn_kwargs
        )
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """Run the block.

        :param inputs: input tensor
        :param drop_connect_rate: drop-connect probability for the residual
            branch; only used when the residual path is active and the value
            is truthy.
        :return: output tensor of the block
        """
        args = self._block_args
        out = inputs
        if args.expand_ratio != 1:
            out = self._swish(self._bn0(self._expand_conv(inputs)))
        out = self._swish(self._bn1(self._depthwise_conv(out)))

        if self.has_se:
            gate = F.adaptive_avg_pool2d(out, 1)
            gate = self._se_reduce(gate)
            gate = self._se_expand(self._swish(gate))
            out = torch.sigmoid(gate) * out

        out = self._bn2(self._project_conv(out))

        use_skip = (
            self.id_skip
            and args.stride == 1
            and args.input_filters == args.output_filters
        )
        if not use_skip:
            return out
        if drop_connect_rate:
            out = drop_connect(out, p=drop_connect_rate, training=self.training)
        return out + inputs

    def set_swish(self, memory_efficient=True):
        """Select the swish implementation: MemoryEfficientSwish when
        ``memory_efficient`` is truthy, otherwise the plain Swish module."""
        if memory_efficient:
            self._swish = MemoryEfficientSwish()
        else:
            self._swish = Swish()
|
|
|
|
|
class EfficientNet(Backbone):
    """EfficientNet feature extractor without a classification layer.

    Usually constructed through :meth:`from_name` or :meth:`from_pretrained`.
    :meth:`forward` returns the globally average-pooled output of the head
    convolution, after dropout; its channel count is stored in
    ``self._out_features``.

    Args:
        blocks_args (list): BlockArgs, one entry per stage of MBConv blocks.
        global_params (namedtuple): GlobalParams shared by all blocks.

    Example:
        model = EfficientNet.from_pretrained('efficientnet-b0')
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), "blocks_args should be a list"
        assert len(blocks_args) > 0, "block args must be greater than 0"
        self._global_params = global_params
        self._blocks_args = blocks_args

        # nn.BatchNorm2d is given the complement of the configured momentum.
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # The spatial size is tracked through the network so that each conv
        # layer is built by get_same_padding_conv2d for its own input size.
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Stem: stride-2 3x3 conv on a 3-channel input.
        in_channels = 3
        out_channels = round_filters(32, self._global_params)
        self._conv_stem = Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, bias=False
        )
        self._bn0 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )
        image_size = calculate_output_image_size(image_size, 2)

        # Body: each BlockArgs entry expands into num_repeat MBConv blocks.
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            # Rescale filter and repeat counts using this model's global params.
            block_args = block_args._replace(
                input_filters=round_filters(
                    block_args.input_filters, self._global_params
                ),
                output_filters=round_filters(
                    block_args.output_filters, self._global_params
                ),
                num_repeat=round_repeats(
                    block_args.num_repeat, self._global_params
                ),
            )

            # The first block of a stage uses the configured stride and
            # input width.
            self._blocks.append(
                MBConvBlock(block_args, self._global_params, image_size=image_size)
            )
            image_size = calculate_output_image_size(
                image_size, block_args.stride
            )
            if block_args.num_repeat > 1:
                # Repeated blocks map output_filters -> output_filters at
                # stride 1.
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, stride=1
                )
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(
                    MBConvBlock(
                        block_args, self._global_params, image_size=image_size
                    )
                )

        # Head: 1x1 conv. block_args here is the last (rescaled) stage config.
        in_channels = block_args.output_filters
        out_channels = round_filters(1280, self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(
            in_channels, out_channels, kernel_size=1, bias=False
        )
        self._bn1 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )

        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._swish = MemoryEfficientSwish()

        # Channel count of the feature vector returned by forward().
        self._out_features = out_channels

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_features(self, inputs):
        """Return the activated output of the final (head) convolution.

        The drop-connect rate grows linearly with block depth:
        block ``idx`` receives ``base_rate * idx / num_blocks`` when the
        configured ``drop_connect_rate`` is truthy, and the configured value
        unchanged otherwise.
        """
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        base_rate = self._global_params.drop_connect_rate
        num_blocks = len(self._blocks)
        for idx, block in enumerate(self._blocks):
            rate = base_rate
            if rate:
                rate *= float(idx) / num_blocks
            x = block(x, drop_connect_rate=rate)

        x = self._swish(self._bn1(self._conv_head(x)))
        return x

    def forward(self, inputs):
        """Return pooled feature vectors; no classification layer is applied.

        Runs :meth:`extract_features`, average-pools the result to 1x1,
        flattens it to ``(batch, self._out_features)`` and applies dropout.
        """
        x = self.extract_features(inputs)
        x = self._avg_pooling(x)
        # torch.flatten does not depend on contiguity or on inferring -1 from
        # the batch size (which view(bs, -1) cannot do for an empty batch).
        x = torch.flatten(x, 1)
        x = self._dropout(x)
        return x

    @classmethod
    def from_name(cls, model_name, override_params=None):
        """Build ``model_name`` from its config, without loading any weights.

        Raises:
            ValueError: if ``model_name`` is not an ``efficientnet-b0..b8`` name.
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(
            model_name, override_params
        )
        return cls(blocks_args, global_params)

    @classmethod
    def from_pretrained(
        cls, model_name, advprop=False, num_classes=1000, in_channels=3
    ):
        """Build ``model_name`` and load its pretrained weights.

        ``load_fc`` is requested only when ``num_classes == 1000``; note that
        this class itself defines no fully connected layer. When
        ``in_channels != 3`` the stem is replaced afterwards by a freshly
        initialized conv (see :meth:`_change_in_channels`).
        """
        model = cls.from_name(
            model_name, override_params={"num_classes": num_classes}
        )
        load_pretrained_weights(
            model, model_name, load_fc=(num_classes == 1000), advprop=advprop
        )
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Return the resolution reported by ``efficientnet_params``."""
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validates model name."""
        valid_models = ["efficientnet-b" + str(i) for i in range(9)]
        if model_name not in valid_models:
            raise ValueError(
                "model_name should be one of: " + ", ".join(valid_models)
            )

    def _change_in_channels(self, in_channels):
        """Rebuild the stem conv to take ``in_channels`` input channels.

        No-op for 3 channels. Otherwise the new stem is freshly constructed,
        so any previously loaded stem weights are discarded.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(
                image_size=self._global_params.image_size
            )
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, bias=False
            )
|
|
|
|
|
def build_efficientnet(name, pretrained):
    """Create the ``efficientnet-<name>`` model.

    Args:
        name (str): variant suffix, e.g. ``"b0"``.
        pretrained (bool): truthy -> ``EfficientNet.from_pretrained``;
            falsy -> ``EfficientNet.from_name`` (no weights loaded).
    """
    model_name = "efficientnet-{}".format(name)
    factory = EfficientNet.from_pretrained if pretrained else EfficientNet.from_name
    return factory(model_name)
|
|
|
|
|
@BACKBONE_REGISTRY.register()
def efficientnet_b0(pretrained=True, **kwargs):
    """Registered EfficientNet-B0 backbone.

    Args:
        pretrained (bool): load pretrained weights when truthy.
        **kwargs: accepted but ignored.
    """
    variant = "b0"
    return build_efficientnet(variant, pretrained=pretrained)
|
|
|
|
|
@BACKBONE_REGISTRY.register()
def efficientnet_b1(pretrained=True, **kwargs):
    """Registered EfficientNet-B1 backbone.

    Args:
        pretrained (bool): load pretrained weights when truthy.
        **kwargs: accepted but ignored.
    """
    variant = "b1"
    return build_efficientnet(variant, pretrained=pretrained)
|
|
|
|
|
@BACKBONE_REGISTRY.register()
def efficientnet_b2(pretrained=True, **kwargs):
    """Registered EfficientNet-B2 backbone.

    Args:
        pretrained (bool): load pretrained weights when truthy.
        **kwargs: accepted but ignored.
    """
    variant = "b2"
    return build_efficientnet(variant, pretrained=pretrained)
|
|
|
|
|
@BACKBONE_REGISTRY.register()
def efficientnet_b3(pretrained=True, **kwargs):
    """Registered EfficientNet-B3 backbone.

    Args:
        pretrained (bool): load pretrained weights when truthy.
        **kwargs: accepted but ignored.
    """
    variant = "b3"
    return build_efficientnet(variant, pretrained=pretrained)
|
|
|
|
|
@BACKBONE_REGISTRY.register()
def efficientnet_b4(pretrained=True, **kwargs):
    """Registered EfficientNet-B4 backbone.

    Args:
        pretrained (bool): load pretrained weights when truthy.
        **kwargs: accepted but ignored.
    """
    variant = "b4"
    return build_efficientnet(variant, pretrained=pretrained)
|
|
|
|
|
@BACKBONE_REGISTRY.register()
def efficientnet_b5(pretrained=True, **kwargs):
    """Registered EfficientNet-B5 backbone.

    Args:
        pretrained (bool): load pretrained weights when truthy.
        **kwargs: accepted but ignored.
    """
    variant = "b5"
    return build_efficientnet(variant, pretrained=pretrained)
|
|
|
|
|
@BACKBONE_REGISTRY.register()
def efficientnet_b6(pretrained=True, **kwargs):
    """Registered EfficientNet-B6 backbone.

    Args:
        pretrained (bool): load pretrained weights when truthy.
        **kwargs: accepted but ignored.
    """
    variant = "b6"
    return build_efficientnet(variant, pretrained=pretrained)
|
|
|
|
|
@BACKBONE_REGISTRY.register()
def efficientnet_b7(pretrained=True, **kwargs):
    """Registered EfficientNet-B7 backbone.

    Args:
        pretrained (bool): load pretrained weights when truthy.
        **kwargs: accepted but ignored.
    """
    variant = "b7"
    return build_efficientnet(variant, pretrained=pretrained)
|
|