Towsif7's picture
firrst commit
59e40e1
raw
history blame
No virus
12 kB
"""
Source url: https://github.com/lukemelas/EfficientNet-PyTorch
Modified by Min Seok Lee, Wooseok Shin, Nikita Selin
License: Apache License 2.0
Changes:
- Added support for extracting edge features
- Added support for extracting object features at different levels
- Refactored the code
"""
from typing import Any, List
import torch
from torch import nn
from torch.nn import functional as F
from carvekit.ml.arch.tracerb7.effi_utils import (
get_same_padding_conv2d,
calculate_output_image_size,
MemoryEfficientSwish,
drop_connect,
round_filters,
round_repeats,
Swish,
create_block_args,
)
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block.

    Pipeline: optional 1x1 expansion conv -> depthwise conv -> optional
    squeeze-and-excitation -> 1x1 projection conv. An identity skip
    connection (optionally combined with drop connect) is added when
    ``id_skip`` is set, the stride is 1 and input/output filters match.

    Args:
        block_args (namedtuple): BlockArgs describing this block
            (kernel_size, stride, expand_ratio, input/output filters,
            se_ratio, id_skip).
        global_params (namedtuple): GlobalParams (batch-norm momentum/epsilon).
        image_size (tuple or list): [image_height, image_width] of the input,
            used to build "same"-padded convolutions.

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        # PyTorch's BatchNorm momentum convention differs from TensorFlow's,
        # hence the 1 - momentum conversion.
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        # Squeeze-and-excitation is enabled only for a ratio in (0, 1].
        self.has_se = (self._block_args.se_ratio is not None) and (
            0 < self._block_args.se_ratio <= 1
        )
        # Whether to use the skip connection and drop connect at all.
        self.id_skip = block_args.id_skip

        # Expansion phase (inverted bottleneck)
        inp = self._block_args.input_filters  # number of input channels
        # Number of channels after expansion.
        oup = self._block_args.input_filters * self._block_args.expand_ratio
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(
                in_channels=inp, out_channels=oup, kernel_size=1, bias=False
            )
            self._bn0 = nn.BatchNorm2d(
                num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
            )
            # The 1x1 expansion conv uses the default stride, so image_size
            # is deliberately not updated here.

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup,
            out_channels=oup,
            groups=oup,  # groups == channels makes it depthwise
            kernel_size=k,
            stride=s,
            bias=False,
        )
        self._bn1 = nn.BatchNorm2d(
            num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
        )
        # The depthwise conv is the only strided layer in the block.
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            # SE operates on globally pooled (1x1) maps.
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(
                1, int(self._block_args.input_filters * self._block_args.se_ratio)
            )
            self._se_reduce = Conv2d(
                in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1
            )
            self._se_expand = Conv2d(
                in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1
            )

        # Pointwise (projection) convolution phase
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(
            in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
        )
        self._bn2 = nn.BatchNorm2d(
            num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
        )
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """MBConvBlock's forward function.

        Args:
            inputs (tensor): Input tensor.
            drop_connect_rate (float or None): Drop connect rate, between 0
                and 1. It is passed to ``drop_connect`` only on the
                skip-connection path; a falsy value disables drop connect.

        Returns:
            Output tensor of this block.
        """
        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._expand_conv(inputs)
            x = self._bn0(x)
            x = self._swish(x)
        x = self._depthwise_conv(x)
        x = self._bn1(x)
        x = self._swish(x)

        # Squeeze and Excitation: channel-wise gating from pooled statistics.
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_reduce(x_squeezed)
            x_squeezed = self._swish(x_squeezed)
            x_squeezed = self._se_expand(x_squeezed)
            x = torch.sigmoid(x_squeezed) * x

        # Pointwise Convolution (no activation after the projection).
        x = self._project_conv(x)
        x = self._bn2(x)

        # Skip connection and drop connect, only when shapes are preserved.
        input_filters, output_filters = (
            self._block_args.input_filters,
            self._block_args.output_filters,
        )
        if (
            self.id_skip
            and self._block_args.stride == 1
            and input_filters == output_filters
        ):
            # The combination of skip connection and drop connect brings
            # about stochastic depth.
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
class EfficientNet(nn.Module):
    """EfficientNet backbone: stem convolution followed by MBConv blocks.

    Only the stem (``_conv_stem``/``_bn0``) and the ``_blocks`` stack are
    built; no head layers (``_conv_head``/``_bn1``), pooling or classifier
    are created by this class.

    Args:
        blocks_args (list[namedtuple]): Per-stage BlockArgs.
        global_params (namedtuple): GlobalParams (batch-norm settings,
            image_size, drop_connect_rate, scaling coefficients).
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), "blocks_args should be a list"
        assert len(blocks_args) > 0, "block args must be greater than 0"
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters (PyTorch momentum = 1 - TensorFlow momentum)
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, bias=False
        )
        self._bn0 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            # Scale filters and repeats with the global multipliers
            # (see round_filters / round_repeats).
            block_args = block_args._replace(
                input_filters=round_filters(
                    block_args.input_filters, self._global_params
                ),
                output_filters=round_filters(
                    block_args.output_filters, self._global_params
                ),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params),
            )

            # The first block of a stage handles the stride and filter-size change.
            self._blocks.append(
                MBConvBlock(block_args, self._global_params, image_size=image_size)
            )
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:
                # Repeated blocks keep the output size: stride 1, in == out filters.
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, stride=1
                )
            for _ in range(block_args.num_repeat - 1):
                # Stride is 1 here, so image_size stays unchanged.
                self._blocks.append(
                    MBConvBlock(block_args, self._global_params, image_size=image_size)
                )

        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Also propagates the choice to every block.

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_endpoints(self, inputs):
        """Collect feature maps at every spatial reduction of the network.

        Whenever a block's output has a smaller size along dim 2 than the
        previous output, the previous (pre-reduction) map is recorded. The
        final block output is recorded last. If head layers ``_conv_head``
        and ``_bn1`` have been attached, the head is applied to it first.

        Args:
            inputs (tensor): Input batch; assumed (N, C, H, W) since dim 2 is
                compared as the spatial height.

        Returns:
            dict: Ordered ``{"reduction_1": tensor, "reduction_2": tensor, ...}``.
        """
        endpoints = dict()

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                # Linearly scale the rate with block depth.
                drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
            if prev_x.size(2) > x.size(2):
                endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
            prev_x = x

        # Head: __init__ does not build _conv_head/_bn1, so calling them
        # unconditionally raised AttributeError. Apply them only if present.
        conv_head = getattr(self, "_conv_head", None)
        bn_head = getattr(self, "_bn1", None)
        if conv_head is not None and bn_head is not None:
            x = self._swish(bn_head(conv_head(x)))
        endpoints["reduction_{}".format(len(endpoints) + 1)] = x
        return endpoints

    def _change_in_channels(self, in_channels):
        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.

        The replacement stem conv is newly constructed, so any previously
        loaded stem weights are discarded.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, bias=False
            )
class EfficientEncoderB7(EfficientNet):
    """EfficientNet-B7 encoder that returns intermediate block features.

    Built via ``create_block_args`` with width coefficient 2.0, depth
    coefficient 3.1, dropout rate 0.5 and image size 600. ``forward``
    returns clones of the outputs of the blocks listed in ``block_idx``.
    """

    def __init__(self):
        super().__init__(
            *create_block_args(
                width_coefficient=2.0,
                depth_coefficient=3.1,
                dropout_rate=0.5,
                image_size=600,
            )
        )
        # _change_in_channels only rebuilds the stem when in_channels != 3,
        # so this call leaves the RGB stem unchanged.
        self._change_in_channels(3)
        # Indices into self._blocks whose outputs forward() returns.
        self.block_idx = [10, 17, 37, 54]
        # Presumably the channel counts of the features at block_idx,
        # consumed by downstream modules. TODO confirm against block args.
        self.channels = [48, 80, 224, 640]

    def initial_conv(self, inputs):
        """Apply the stem: conv -> batch norm -> swish."""
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        return x

    def get_blocks(self, x, H, W, block_idx):
        """Run all MBConv blocks and collect the outputs at ``block_idx``.

        Args:
            x (tensor): Stem output.
            H (int): Input height (unused; kept for interface compatibility).
            W (int): Input width (unused; kept for interface compatibility).
            block_idx (sequence of int): Block indices whose outputs are
                returned. Every entry is honoured, not just the first four.
                An index listed k times yields k copies.

        Returns:
            list: Cloned feature tensors, ordered by block index.
        """
        features = []
        num_blocks = len(self._blocks)
        base_rate = self._global_params.drop_connect_rate
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = base_rate
            if drop_connect_rate:
                # Linearly scale the rate with block depth.
                drop_connect_rate *= float(idx) / num_blocks
            x = block(x, drop_connect_rate=drop_connect_rate)
            hits = sum(1 for wanted in block_idx if wanted == idx)
            for _ in range(hits):
                # Clone so later in-place ops cannot alter collected features.
                features.append(x.clone())
        return features

    def forward(self, inputs: torch.Tensor) -> List[Any]:
        """Encode ``inputs`` and return the backbone features at ``self.block_idx``.

        Args:
            inputs (torch.Tensor): 4-D input batch (unpacked as B, C, H, W).

        Returns:
            List[Any]: Feature tensors from ``get_blocks``.
        """
        B, C, H, W = inputs.size()
        x = self.initial_conv(inputs)  # Prepare input for the backbone
        return self.get_blocks(
            x, H, W, block_idx=self.block_idx
        )  # Get backbone features