"""
Source url: https://github.com/lukemelas/EfficientNet-PyTorch
Modified by Min Seok Lee, Wooseok Shin, Nikita Selin
License: Apache License 2.0
Changes:
    - Added support for extracting edge features
    - Added support for extracting object features at different levels
    - Refactored the code
"""
from typing import Any, List

import torch
from torch import nn
from torch.nn import functional as F

from carvekit.ml.arch.tracerb7.effi_utils import (
    get_same_padding_conv2d,
    calculate_output_image_size,
    MemoryEfficientSwish,
    drop_connect,
    round_filters,
    round_repeats,
    Swish,
    create_block_args,
)


class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block.

    Args:
        block_args (namedtuple): BlockArgs, defined in utils.py.
        global_params (namedtuple): GlobalParam, defined in utils.py.
        image_size (tuple or list): [image_height, image_width].

    References:
        [1] https://arxiv.org/abs/1704.04861 (MobileNet v1)
        [2] https://arxiv.org/abs/1801.04381 (MobileNet v2)
        [3] https://arxiv.org/abs/1905.02244 (MobileNet v3)
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        # PyTorch's BatchNorm momentum is 1 - TensorFlow's momentum.
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (
            0 < self._block_args.se_ratio <= 1
        )
        # Whether to use the skip connection and drop connect.
        self.id_skip = block_args.id_skip

        # Expansion phase (Inverted Bottleneck)
        inp = self._block_args.input_filters  # number of input channels
        # number of expanded channels
        oup = self._block_args.input_filters * self._block_args.expand_ratio
        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(
                in_channels=inp, out_channels=oup, kernel_size=1, bias=False
            )
            self._bn0 = nn.BatchNorm2d(
                num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
            )
            # The 1x1 expansion conv has stride 1, so image_size is unchanged here.

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup,
            out_channels=oup,
            groups=oup,  # groups makes it depthwise
            kernel_size=k,
            stride=s,
            bias=False,
        )
        self._bn1 = nn.BatchNorm2d(
            num_features=oup, momentum=self._bn_mom, eps=self._bn_eps
        )
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            # SE operates on the globally pooled (1x1) map.
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(
                1, int(self._block_args.input_filters * self._block_args.se_ratio)
            )
            self._se_reduce = Conv2d(
                in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1
            )
            self._se_expand = Conv2d(
                in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1
            )

        # Pointwise convolution phase
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(
            in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False
        )
        self._bn2 = nn.BatchNorm2d(
            num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps
        )
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """MBConvBlock's forward function.

        Args:
            inputs (tensor): Input tensor.
            drop_connect_rate (float or None): Drop connect rate (between 0 and 1).
                A falsy value disables drop connect.

        Returns:
            Output of this block after processing.
        """
        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._expand_conv(inputs)
            x = self._bn0(x)
            x = self._swish(x)

        x = self._depthwise_conv(x)
        x = self._bn1(x)
        x = self._swish(x)

        # Squeeze and Excitation
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_reduce(x_squeezed)
            x_squeezed = self._swish(x_squeezed)
            x_squeezed = self._se_expand(x_squeezed)
            x = torch.sigmoid(x_squeezed) * x

        # Pointwise Convolution
        x = self._project_conv(x)
        x = self._bn2(x)

        # Skip connection and drop connect
        input_filters, output_filters = (
            self._block_args.input_filters,
            self._block_args.output_filters,
        )
        if (
            self.id_skip
            and self._block_args.stride == 1
            and input_filters == output_filters
        ):
            # The combination of skip connection and drop connect brings about
            # stochastic depth.
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()


class EfficientNet(nn.Module):
    """EfficientNet backbone: stem convolution followed by a stack of MBConv blocks.

    Unlike upstream EfficientNet-PyTorch, this refactored version builds no
    head convolution, pooling or classifier in ``__init__``.

    Args:
        blocks_args (list): Non-empty list of BlockArgs namedtuples.
        global_params (namedtuple): GlobalParams (batch-norm settings, image
            size, drop connect rate, width/depth scaling).
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), "blocks_args should be a list"
        assert len(blocks_args) > 0, "block args must be greater than 0"
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(32, self._global_params)  # stem output channels
        self._conv_stem = Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, bias=False
        )
        self._bn0 = nn.BatchNorm2d(
            num_features=out_channels, momentum=bn_mom, eps=bn_eps
        )
        image_size = calculate_output_image_size(image_size, 2)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(
                    block_args.input_filters, self._global_params
                ),
                output_filters=round_filters(
                    block_args.output_filters, self._global_params
                ),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params),
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(
                MBConvBlock(block_args, self._global_params, image_size=image_size)
            )
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:
                # Repeats keep the same output size and channel count.
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, stride=1
                )
            for _ in range(block_args.num_repeat - 1):
                # stride == 1 for repeats, so image_size does not change.
                self._blocks.append(
                    MBConvBlock(block_args, self._global_params, image_size=image_size)
                )

        self._swish = MemoryEfficientSwish()

    def _drop_connect_rate_for(self, idx):
        """Return the drop-connect rate for block ``idx``, scaled linearly with depth.

        A falsy global rate (``None`` or ``0``) is returned unchanged.
        """
        rate = self._global_params.drop_connect_rate
        if rate:
            rate *= float(idx) / len(self._blocks)
        return rate

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export).

        Args:
            memory_efficient (bool): Whether to use memory-efficient version of swish.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_endpoints(self, inputs):
        """Collect the feature map preceding every spatial reduction.

        A map is stored under ``"reduction_<i>"`` each time a block shrinks
        the height dimension (``size(2)``). The output of the last stage is
        stored as the final ``"reduction_<i>"`` entry.

        Args:
            inputs (tensor): Input image batch.

        Returns:
            dict: ``"reduction_1"``, ``"reduction_2"``, ... -> feature tensors.
        """
        endpoints = dict()

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x

        # Blocks
        for idx, block in enumerate(self._blocks):
            x = block(x, drop_connect_rate=self._drop_connect_rate_for(idx))
            if prev_x.size(2) > x.size(2):
                endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
            prev_x = x

        # Head. This refactored network does not create `_conv_head` / `_bn1`
        # in __init__, so apply them only when they exist. The previous
        # unconditional call always raised AttributeError.
        conv_head = getattr(self, "_conv_head", None)
        bn_head = getattr(self, "_bn1", None)
        if conv_head is not None and bn_head is not None:
            x = self._swish(bn_head(conv_head(x)))
        endpoints["reduction_{}".format(len(endpoints) + 1)] = x
        return endpoints

    def _change_in_channels(self, in_channels):
        """Adjust model's first convolution layer to in_channels, if in_channels not equals 3.

        Args:
            in_channels (int): Input data's channel number.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(
                in_channels, out_channels, kernel_size=3, stride=2, bias=False
            )


class EfficientEncoderB7(EfficientNet):
    """EfficientNet-B7 encoder that returns multi-level block features.

    ``block_idx`` selects which block outputs are returned by ``forward``.
    ``channels`` lists their channel counts in the same order.
    """

    def __init__(self):
        super().__init__(
            *create_block_args(
                width_coefficient=2.0,
                depth_coefficient=3.1,
                dropout_rate=0.5,
                image_size=600,
            )
        )
        # in_channels == 3, so this leaves the RGB stem untouched.
        self._change_in_channels(3)
        self.block_idx = [10, 17, 37, 54]
        self.channels = [48, 80, 224, 640]

    def initial_conv(self, inputs):
        """Apply the stem: conv -> batch norm -> swish."""
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        return x

    def get_blocks(self, x, H, W, block_idx):
        """Run ``x`` through every MBConv block and collect selected outputs.

        Args:
            x (tensor): Stem output.
            H, W: Input height/width. Unused; kept for interface compatibility.
            block_idx (sequence of int): Indices of blocks whose outputs to
                return. Any length is accepted. An index listed more than once
                yields one clone per occurrence.

        Returns:
            list: Clones of the selected block outputs, in block order.
        """
        features = []
        for idx, block in enumerate(self._blocks):
            x = block(x, drop_connect_rate=self._drop_connect_rate_for(idx))
            features.extend(x.clone() for target in block_idx if target == idx)
        return features

    def forward(self, inputs: torch.Tensor) -> List[Any]:
        """Return the backbone features selected by ``self.block_idx``."""
        B, C, H, W = inputs.size()
        x = self.initial_conv(inputs)  # Prepare input for the backbone
        # Get backbone features and edge maps
        return self.get_blocks(x, H, W, block_idx=self.block_idx)