import os
from typing import Union

import torch
import torch.nn as nn
from tqdm import tqdm
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.model import MBConvBlock
from efficientnet_pytorch import utils as efficientnet_utils
from efficientnet_pytorch.utils import (
    round_filters,
    round_repeats,
    get_same_padding_conv2d,
    calculate_output_image_size,
    MemoryEfficientSwish,
)
from einops import rearrange, reduce
from torch.hub import load_state_dict_from_url

model_dir = os.getcwd()


class _EffiNet(nn.Module):
    """A proxy for EfficientNet models, pruned or unpruned."""

    def __init__(self,
                 blocks_args=None,
                 global_params=None,
                 prune_start_layer: int = 0,
                 prune_se: bool = True,
                 prune_ratio: float = 0.0) -> None:
        super().__init__()
        if prune_ratio > 0:
            self.eff_net = EfficientNetB2Pruned(blocks_args=blocks_args,
                                                global_params=global_params,
                                                prune_start_layer=prune_start_layer,
                                                prune_se=prune_se,
                                                prune_ratio=prune_ratio)
        else:
            self.eff_net = EfficientNet(blocks_args=blocks_args,
                                        global_params=global_params)

    def forward(self, x: torch.Tensor):
        # (batch, freq, time) -> (batch, 1, freq, time): add a channel axis
        x = rearrange(x, 'b f t -> b 1 f t')
        x = self.eff_net.extract_features(x)
        # average over frequency, keeping a sequence of per-frame embeddings
        return reduce(x, 'b c f t -> b t c', 'mean')


def get_model(pretrained=True) -> _EffiNet:
    blocks_args, global_params = efficientnet_utils.get_model_params(
        'efficientnet-b2', {'include_top': False})
    model = _EffiNet(blocks_args=blocks_args, global_params=global_params)
    model.eff_net._change_in_channels(1)
    if pretrained:
        model_path = os.path.join(model_dir, "effb2.pt")
        if not os.path.exists(model_path):
            state_dict = load_state_dict_from_url(
                'https://github.com/richermans/HEAR2021_EfficientLatent/releases/download/v0.0.1/effb2.pt',
                progress=True,
                model_dir=model_dir)
        else:
            state_dict = torch.load(model_path)
        # drop the checkpoint's front-end (waveform -> spectrogram) weights
        del_keys = [key for key in state_dict if key.startswith("front_end")]
        for key in del_keys:
            del state_dict[key]
        model.eff_net.load_state_dict(state_dict)
    return model
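
# Usage sketch (illustrative, not enforced by this module): `get_model` expects
# a log-mel spectrogram batch of shape (batch, n_mels, n_frames); the 64-band
# front end and the frame count below are assumptions.
#
#     model = get_model(pretrained=True)
#     mel = torch.randn(4, 64, 501)   # hypothetical batch of spectrograms
#     emb = model(mel)                # -> (4, ~n_frames / 32, 1408) for B2
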
class MBConvBlockPruned(MBConvBlock):

    def __init__(self, block_args, global_params, image_size=None,
                 prune_ratio=0.5, prune_se=True):
        super(MBConvBlock, self).__init__()
        self._block_args = block_args
        self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect

        # Expansion phase (Inverted Bottleneck)
        inp = self._block_args.input_filters  # number of input channels
        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
        if self._block_args.expand_ratio != 1:
            oup = int(oup * (1 - prune_ratio))
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
            # image_size = calculate_output_image_size(image_size, 1)  # <-- this wouldn't modify image_size

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            if prune_se:
                num_squeezed_channels = int(num_squeezed_channels * (1 - prune_ratio))
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Pointwise convolution phase
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()
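
# Width sketch (illustrative): with the efficientnet-b2 stage arguments used in
# this module, the third stage has input_filters=24 and expand_ratio=6, so the
# unpruned expanded width is 144 channels; prune_ratio=0.5 shrinks it to
# int(144 * (1 - 0.5)) = 72.
#
#     blocks_args, global_params = efficientnet_utils.get_model_params(
#         'efficientnet-b2', {'include_top': False})
#     blk = MBConvBlockPruned(blocks_args[2], global_params, prune_ratio=0.5)
#     blk._expand_conv.out_channels  # 72
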
class EfficientNetB2Pruned(EfficientNet):

    def __init__(self, blocks_args=None, global_params=None,
                 prune_start_layer=0, prune_ratio=0.5, prune_se=True):
        super(EfficientNet, self).__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'blocks_args must not be empty'
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Batch norm parameters
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Get stem static or dynamic convolution depending on image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        n_build_blks = 0

        # Stem
        in_channels = 1  # spectrogram
        p = 0.0 if n_build_blks < prune_start_layer else prune_ratio
        out_channels = round_filters(32 * (1 - p), self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        image_size = calculate_output_image_size(image_size, 2)
        n_build_blks += 1

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            p = 0.0 if n_build_blks < prune_start_layer else prune_ratio
            orig_input_filters = block_args.input_filters
            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(
                    block_args.input_filters * (1 - p), self._global_params),
                output_filters=round_filters(
                    block_args.output_filters * (1 - p), self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )
            # the first pruned stage still receives the unpruned channel count
            if n_build_blks == prune_start_layer:
                block_args = block_args._replace(input_filters=round_filters(
                    orig_input_filters, self._global_params))

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(MBConvBlockPruned(block_args, self._global_params,
                                                  image_size=image_size,
                                                  prune_ratio=p, prune_se=prune_se))
            n_build_blks += 1
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:  # modify block_args to keep same output size
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlockPruned(block_args, self._global_params,
                                                      image_size=image_size,
                                                      prune_ratio=p, prune_se=prune_se))
                # image_size = calculate_output_image_size(image_size, block_args.stride)  # stride = 1

        # Head
        in_channels = block_args.output_filters  # output of final block
        p = 0.0 if n_build_blks < prune_start_layer else prune_ratio
        out_channels = round_filters(1280 * (1 - p), self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        if self._global_params.include_top:
            self._dropout = nn.Dropout(self._global_params.dropout_rate)
            self._fc = nn.Linear(out_channels, self._global_params.num_classes)

        # set activation to memory efficient swish by default
        self._swish = MemoryEfficientSwish()
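
# Construction sketch (illustrative): with prune_start_layer=0 every width is
# scaled by (1 - prune_ratio) before the usual EfficientNet rounding, so a
# 50%-pruned B2 stem keeps round_filters(16) = 16 filters instead of 32.
#
#     blocks_args, global_params = efficientnet_utils.get_model_params(
#         'efficientnet-b2', {'include_top': False})
#     net = EfficientNetB2Pruned(blocks_args, global_params,
#                                prune_start_layer=0, prune_ratio=0.5)
#     net._conv_stem.out_channels  # 16 (vs. 32 unpruned)
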
def get_pruned_model(pretrained: Union[bool, str] = True,
                     prune_ratio: float = 0.5,
                     prune_start_layer: int = 0,
                     prune_se: bool = True,
                     prune_method: str = "operator_norm") -> _EffiNet:
    """Build a pruned EfficientNet-B2 and initialize it from pretrained weights.

    `pretrained` is either a bool (load the released effb2.pt checkpoint) or a
    path to a captioning checkpoint whose encoder backbone wraps this model.
    """
    import captioning.models.conv_filter_pruning as pruning_lib

    blocks_args, global_params = efficientnet_utils.get_model_params(
        'efficientnet-b2', {'include_top': False})
    model = _EffiNet(blocks_args=blocks_args,
                     global_params=global_params,
                     prune_start_layer=prune_start_layer,
                     prune_se=prune_se,
                     prune_ratio=prune_ratio)

    # select the criterion that ranks filters by importance
    if prune_method == "operator_norm":
        filter_pruning = pruning_lib.operator_norm_pruning
    elif prune_method == "interspeech":
        filter_pruning = pruning_lib.cs_interspeech
    elif prune_method == "iclr_l1":
        filter_pruning = pruning_lib.iclr_l1
    elif prune_method == "iclr_gm":
        filter_pruning = pruning_lib.iclr_gm
    elif prune_method == "cs_waspaa":
        filter_pruning = pruning_lib.cs_waspaa
    else:
        raise ValueError(f"unknown prune_method: {prune_method}")

    if isinstance(pretrained, str):
        ckpt = torch.load(pretrained, "cpu")
        state_dict = {}
        for key in ckpt["model"].keys():
            if key.startswith("model.encoder.backbone"):
                state_dict[key[len("model.encoder.backbone.eff_net."):]] = ckpt["model"][key]
    elif isinstance(pretrained, bool):
        model_path = os.path.join(model_dir, "effb2.pt")
        if not os.path.exists(model_path):
            state_dict = load_state_dict_from_url(
                'https://github.com/richermans/HEAR2021_EfficientLatent/releases/download/v0.0.1/effb2.pt',
                progress=True,
                model_dir=model_dir)
        else:
            state_dict = torch.load(model_path)
        del_keys = [key for key in state_dict if key.startswith("front_end")]
        for key in del_keys:
            del state_dict[key]

    # load pretrained weights into the corresponding kept filters
    # rule:
    #   * depthwise_conv: in_ch_idx = out_ch_idx = prev_conv_idx
    mod_dep_path = ["_conv_stem"]
    conv_to_bn = {"_conv_stem": "_bn0"}
    # blocks 0 and 1 have expand_ratio = 1, so they have no _expand_conv
    for i in range(2):
        mod_dep_path.extend([
            f"_blocks.{i}._depthwise_conv",
            f"_blocks.{i}._se_reduce",
            f"_blocks.{i}._se_expand",
            f"_blocks.{i}._project_conv",
        ])
        conv_to_bn[f"_blocks.{i}._depthwise_conv"] = f"_blocks.{i}._bn1"
        conv_to_bn[f"_blocks.{i}._project_conv"] = f"_blocks.{i}._bn2"
    for i in range(2, 23):
        mod_dep_path.extend([
            f"_blocks.{i}._expand_conv",
            f"_blocks.{i}._depthwise_conv",
            f"_blocks.{i}._se_reduce",
            f"_blocks.{i}._se_expand",
            f"_blocks.{i}._project_conv",
        ])
        conv_to_bn[f"_blocks.{i}._expand_conv"] = f"_blocks.{i}._bn0"
        conv_to_bn[f"_blocks.{i}._depthwise_conv"] = f"_blocks.{i}._bn1"
        conv_to_bn[f"_blocks.{i}._project_conv"] = f"_blocks.{i}._bn2"
    mod_dep_path.append("_conv_head")
    conv_to_bn["_conv_head"] = "_bn1"

    # for every conv on the dependency path, decide which output filters to keep
    key_to_w_b_idx = {}
    model_dict = model.eff_net.state_dict()
    for conv_key in tqdm(mod_dep_path):
        weight = state_dict[f"{conv_key}.weight"]
        ptr_n_filter = weight.size(0)
        model_n_filter = model_dict[f"{conv_key}.weight"].size(0)
        if model_n_filter < ptr_n_filter:  # this layer is pruned
            key_to_w_b_idx[conv_key] = filter_pruning(weight.numpy())[:model_n_filter]
        else:
            key_to_w_b_idx[conv_key] = slice(None)

    # copy the kept filters (and the matching input channels) into a new state dict
    pruned_state_dict = {}
    for conv_key, prev_conv_key in zip(mod_dep_path, [None] + mod_dep_path[:-1]):
        for sub_key in ["weight", "bias"]:
            # adjust the conv layer
            cur_key = f"{conv_key}.{sub_key}"
            if cur_key not in state_dict:
                continue
            if prev_conv_key is None or conv_key.endswith("_depthwise_conv"):
                conv_in_idx = slice(None)
            else:
                conv_in_idx = key_to_w_b_idx[prev_conv_key]
            # the first pruned layer keeps all of its input channels
            if model_dict[cur_key].ndim > 1 and model_dict[cur_key].size(1) == state_dict[cur_key].size(1):
                conv_in_idx = slice(None)
            if conv_key.endswith("_depthwise_conv"):
                conv_out_idx = key_to_w_b_idx[prev_conv_key]
            else:
                conv_out_idx = key_to_w_b_idx[conv_key]
            if sub_key == "weight":
                pruned_state_dict[cur_key] = state_dict[cur_key][
                    conv_out_idx, ...][:, conv_in_idx, ...]
            else:
                pruned_state_dict[cur_key] = state_dict[cur_key][conv_out_idx, ...]
        if conv_key in conv_to_bn:
            # adjust the corresponding bn layer
            for sub_key in ["weight", "bias", "running_mean", "running_var"]:
                cur_key = f"{conv_to_bn[conv_key]}.{sub_key}"
                if cur_key not in state_dict:
                    continue
                pruned_state_dict[cur_key] = state_dict[cur_key][
                    key_to_w_b_idx[conv_key], ...]

    model.eff_net.load_state_dict(pruned_state_dict)
    return model
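
if __name__ == "__main__":
    # Minimal smoke test (illustrative): builds the unpruned backbone without
    # downloading weights and checks the output layout. The 64-band input and
    # frame count are arbitrary assumptions; the pruned variant additionally
    # needs the checkpoint and the captioning.models.conv_filter_pruning module.
    model = get_model(pretrained=False).eval()
    mel = torch.randn(2, 64, 301)  # hypothetical (batch, n_mels, n_frames) batch
    with torch.no_grad():
        emb = model(mel)
    print(emb.shape)  # expected: (2, ~301 / 32, 1408) = torch.Size([2, 10, 1408])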