Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import torch.nn as nn | |
from tqdm import tqdm | |
from efficientnet_pytorch import EfficientNet | |
from efficientnet_pytorch.model import MBConvBlock | |
from efficientnet_pytorch import utils as efficientnet_utils | |
from efficientnet_pytorch.utils import ( | |
round_filters, | |
round_repeats, | |
get_same_padding_conv2d, | |
calculate_output_image_size, | |
MemoryEfficientSwish, | |
) | |
from einops import rearrange, reduce | |
from torch.hub import load_state_dict_from_url | |
# Directory searched for a cached "effb2.pt" checkpoint and used as the
# download target for it by get_model / get_pruned_model. Fixed to the current
# working directory at import time.
model_dir = os.getcwd()
class _EffiNet(nn.Module):
    """A proxy for efficient net models.

    Wraps either a stock ``EfficientNet`` (when ``prune_ratio <= 0``) or an
    ``EfficientNetB2Pruned`` (when ``prune_ratio > 0``) as ``self.eff_net`` and
    turns its feature map into a sequence of per-frame feature vectors.
    """

    def __init__(self,
                 blocks_args=None,
                 global_params=None,
                 prune_start_layer: int = 0,
                 prune_se: bool = True,
                 prune_ratio: float = 0.0
                 ) -> None:
        super().__init__()
        wants_pruning = prune_ratio > 0
        if not wants_pruning:
            # Full-width backbone; the pruning options are irrelevant here.
            self.eff_net = EfficientNet(blocks_args=blocks_args,
                                        global_params=global_params)
            return
        self.eff_net = EfficientNetB2Pruned(blocks_args=blocks_args,
                                            global_params=global_params,
                                            prune_start_layer=prune_start_layer,
                                            prune_se=prune_se,
                                            prune_ratio=prune_ratio)

    def forward(self, x: torch.Tensor):
        """Encode ``x`` into a ``(b, t, c)`` feature sequence.

        ``x`` must be rank 3, laid out as ``'b f t'`` (presumably batch,
        frequency, time of a spectrogram -- confirm against callers). A
        singleton channel axis is inserted, features are extracted by the
        backbone, and the ``f`` axis of the resulting ``'b c f t'`` map is
        averaged away.
        """
        single_channel = rearrange(x, 'b f t -> b 1 f t')
        feature_map = self.eff_net.extract_features(single_channel)
        return reduce(feature_map, 'b c f t -> b t c', 'mean')
def get_model(pretrained=True) -> _EffiNet:
    """Build the full-width EfficientNet-B2 backbone for single-channel input.

    Args:
        pretrained: if truthy, load the ``effb2.pt`` weights from
            ``model_dir``, downloading them from the HEAR2021_EfficientLatent
            release into ``model_dir`` when the file is not there yet.

    Returns:
        An ``_EffiNet`` wrapping a stock ``EfficientNet`` (no classifier head)
        whose stem has been changed to accept 1 input channel.
    """
    blocks_args, global_params = efficientnet_utils.get_model_params(
        'efficientnet-b2', {'include_top': False})
    model = _EffiNet(blocks_args=blocks_args,
                     global_params=global_params)
    model.eff_net._change_in_channels(1)
    if pretrained:
        model_path = os.path.join(model_dir, "effb2.pt")
        # Map tensors to CPU: the model built above lives on CPU, and this
        # avoids a deserialization failure on CPU-only hosts if the
        # checkpoint was saved from a GPU.
        if not os.path.exists(model_path):
            state_dict = load_state_dict_from_url(
                'https://github.com/richermans/HEAR2021_EfficientLatent/releases/download/v0.0.1/effb2.pt',
                progress=True,
                model_dir=model_dir,
                map_location="cpu")
        else:
            state_dict = torch.load(model_path, map_location="cpu")
        # Remove "front_end*" entries before the strict load into eff_net.
        # Deleted in place so the loaded dict (and any metadata it carries)
        # is otherwise passed through untouched.
        del_keys = [key for key in state_dict if key.startswith("front_end")]
        for key in del_keys:
            del state_dict[key]
        model.eff_net.load_state_dict(state_dict)
    return model
class MBConvBlockPruned(MBConvBlock):
    """MBConvBlock whose expanded (and optionally SE) width is pruned.

    Re-implements ``MBConvBlock.__init__`` so that the expanded channel count
    and, when ``prune_se`` is set, the squeeze-excitation bottleneck are scaled
    by ``1 - prune_ratio``. Only ``__init__`` is overridden; the forward pass
    is inherited from ``MBConvBlock``.

    Args:
        block_args: efficientnet_pytorch block arguments; ``input_filters`` and
            ``output_filters`` are used exactly as given.
        global_params: efficientnet_pytorch global parameters (batch-norm
            momentum/epsilon are read from it).
        image_size: input spatial size, used to pick the "same"-padding conv.
        prune_ratio: fraction of expanded / SE channels removed.
        prune_se: whether the SE bottleneck is pruned as well.
    """

    def __init__(self, block_args, global_params, image_size=None, prune_ratio=0.5, prune_se=True):
        # Deliberately skip MBConvBlock.__init__ (it would build full-width
        # layers); only nn.Module's initialiser runs.
        super(MBConvBlock, self).__init__()
        self._block_args = block_args
        self._bn_mom = 1 - global_params.batch_norm_momentum  # pytorch's difference from tensorflow
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # whether to use skip connection and drop connect

        # Expansion phase (Inverted Bottleneck)
        inp = self._block_args.input_filters  # number of input channels
        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
        if self._block_args.expand_ratio != 1:
            # Only an expansion layer can be pruned: with expand_ratio == 1 the
            # depthwise conv consumes the block input directly (oup == inp).
            # Keep at least one channel so an extreme ratio can't yield a
            # zero-width conv.
            oup = max(1, int(oup * (1 - prune_ratio)))
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
            # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            if prune_se:
                # Re-apply the ">= 1" floor after pruning: small bottlenecks
                # (e.g. 2 channels at prune_ratio=0.6) would otherwise truncate
                # to a zero-width SE layer.
                num_squeezed_channels = max(1, int(num_squeezed_channels * (1 - prune_ratio)))
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Pointwise convolution phase
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()
class EfficientNetB2Pruned(EfficientNet):
    """EfficientNet whose layer widths are scaled by ``1 - prune_ratio``.

    Mirrors ``EfficientNet.__init__`` but uses a single input channel for the
    stem and shrinks layers whose index is ``>= prune_start_layer``. Layer
    indices are: 0 for the stem, 1..len(blocks_args) for the stages (all
    repeats of one stage share an index), then len(blocks_args) + 1 for the
    head. Only ``__init__`` is overridden; everything else is inherited from
    ``EfficientNet``.
    """

    def __init__(self, blocks_args=None, global_params=None,
                 prune_start_layer=0, prune_ratio=0.5, prune_se=True):
        # Intentionally bypass EfficientNet.__init__ (which would build the
        # full-width network); only nn.Module state is initialised here.
        super(EfficientNet, self).__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args
        gp = self._global_params

        def ratio_for(layer_idx):
            # Layers before `prune_start_layer` stay at full width.
            return 0.0 if layer_idx < prune_start_layer else prune_ratio

        momentum = 1 - gp.batch_norm_momentum
        epsilon = gp.batch_norm_epsilon

        # Stem: single-channel (spectrogram) input, stride 2.
        image_size = gp.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        stem_width = round_filters(32 * (1 - ratio_for(0)), gp)
        self._conv_stem = Conv2d(1, stem_width, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=stem_width, momentum=momentum, eps=epsilon)
        image_size = calculate_output_image_size(image_size, 2)

        # Stages: one entry of blocks_args per layer index.
        self._blocks = nn.ModuleList([])
        layer_idx = 1
        for stage_args in self._blocks_args:
            ratio = ratio_for(layer_idx)
            scaled_args = stage_args._replace(
                input_filters=round_filters(stage_args.input_filters * (1 - ratio), gp),
                output_filters=round_filters(stage_args.output_filters * (1 - ratio), gp),
                num_repeat=round_repeats(stage_args.num_repeat, gp),
            )
            if layer_idx == prune_start_layer:
                # First pruned stage: its input comes from a full-width layer,
                # so its input width must stay unpruned.
                scaled_args = scaled_args._replace(
                    input_filters=round_filters(stage_args.input_filters, gp))
            # The first block of a stage applies the stride and width change.
            self._blocks.append(MBConvBlockPruned(
                scaled_args, gp, image_size=image_size,
                prune_ratio=ratio, prune_se=prune_se))
            layer_idx += 1
            image_size = calculate_output_image_size(image_size, scaled_args.stride)
            if scaled_args.num_repeat > 1:
                # Remaining repeats keep the stage's output width and size.
                scaled_args = scaled_args._replace(
                    input_filters=scaled_args.output_filters, stride=1)
            for _ in range(scaled_args.num_repeat - 1):
                self._blocks.append(MBConvBlockPruned(
                    scaled_args, gp, image_size=image_size,
                    prune_ratio=ratio, prune_se=prune_se))

        # Head: 1x1 conv on the last stage's output.
        head_in = scaled_args.output_filters
        head_out = round_filters(1280 * (1 - ratio_for(layer_idx)), gp)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(head_in, head_out, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=head_out, momentum=momentum, eps=epsilon)

        # Final linear layer (only when include_top is set).
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        if gp.include_top:
            self._dropout = nn.Dropout(gp.dropout_rate)
            self._fc = nn.Linear(head_out, gp.num_classes)

        # set activation to memory efficient swish by default
        self._swish = MemoryEfficientSwish()
def _resolve_filter_pruning(prune_method: str):
    """Return the filter-ranking function selected by ``prune_method``.

    Raises:
        ValueError: if ``prune_method`` is not a supported name.
    """
    method_to_attr = {
        "operator_norm": "operator_norm_pruning",
        "interspeech": "cs_interspeech",
        "iclr_l1": "iclr_l1",
        "iclr_gm": "iclr_gm",
        "cs_waspaa": "cs_waspaa",
    }
    if prune_method not in method_to_attr:
        raise ValueError(
            f"Unknown prune_method {prune_method!r}; "
            f"expected one of {sorted(method_to_attr)}")
    import captioning.models.conv_filter_pruning as pruning_lib
    return getattr(pruning_lib, method_to_attr[prune_method])


def _load_source_state_dict(pretrained):
    """Load the state dict whose filters will be selected for the pruned model.

    A str / PathLike ``pretrained`` is read as a checkpoint whose ``"model"``
    entry holds ``model.encoder.backbone.eff_net.*`` keys (the prefix is
    stripped). Anything else loads the ``effb2.pt`` release weights from
    ``model_dir`` (downloading them there if absent), minus ``front_end*`` keys.
    """
    if isinstance(pretrained, (str, os.PathLike)):
        ckpt = torch.load(pretrained, map_location="cpu")
        prefix = "model.encoder.backbone.eff_net."
        return {key[len(prefix):]: value
                for key, value in ckpt["model"].items()
                if key.startswith(prefix)}
    model_path = os.path.join(model_dir, "effb2.pt")
    if os.path.exists(model_path):
        state_dict = torch.load(model_path, map_location="cpu")
    else:
        state_dict = load_state_dict_from_url(
            'https://github.com/richermans/HEAR2021_EfficientLatent/releases/download/v0.0.1/effb2.pt',
            progress=True,
            model_dir=model_dir,
            map_location="cpu")
    return {key: value for key, value in state_dict.items()
            if not key.startswith("front_end")}


def _conv_dependency_chain(eff_net):
    """Describe the conv data-flow of ``eff_net`` for index propagation.

    Returns:
        chain: conv names in forward order; each conv's input channels are the
            output channels of the previous entry.
        conv_to_bn: conv name -> batch-norm normalising that conv's output.
        shared_idx: conv name -> conv whose channel indices it must reuse.
            Depthwise convs keep their input channels, and SE-expand gates
            multiply the depthwise output channel-wise, so both follow the
            conv that feeds the depthwise conv.
    """
    chain = ["_conv_stem"]
    conv_to_bn = {"_conv_stem": "_bn0"}
    shared_idx = {}
    for i, block in enumerate(eff_net._blocks):
        prefix = f"_blocks.{i}"
        if hasattr(block, "_expand_conv"):
            chain.append(f"{prefix}._expand_conv")
            conv_to_bn[f"{prefix}._expand_conv"] = f"{prefix}._bn0"
        feeding_conv = chain[-1]
        chain.append(f"{prefix}._depthwise_conv")
        conv_to_bn[f"{prefix}._depthwise_conv"] = f"{prefix}._bn1"
        shared_idx[f"{prefix}._depthwise_conv"] = feeding_conv
        if getattr(block, "has_se", False):
            chain.extend([f"{prefix}._se_reduce", f"{prefix}._se_expand"])
            shared_idx[f"{prefix}._se_expand"] = feeding_conv
        chain.append(f"{prefix}._project_conv")
        conv_to_bn[f"{prefix}._project_conv"] = f"{prefix}._bn2"
    chain.append("_conv_head")
    conv_to_bn["_conv_head"] = "_bn1"
    return chain, conv_to_bn, shared_idx


def get_pruned_model(pretrained: "bool | str | os.PathLike" = True,
                     prune_ratio: float = 0.5,
                     prune_start_layer: int = 0,
                     prune_se: bool = True,
                     prune_method: str = "operator_norm") -> _EffiNet:
    """Build an EfficientNet-B2 backbone with pruned widths, initialised by
    selecting filters from a larger source state dict.

    For each conv that is narrower than its source counterpart, the first
    ``n`` indices returned by the ``prune_method`` function (presumably
    ranked by importance -- see conv_filter_pruning) are kept. The same
    indices are applied to everything that depends on those channels:
    - the input channels of the next conv;
    - the batch-norm parameters and statistics;
    - the depthwise conv;
    - the SE gates.

    Args:
        pretrained: a checkpoint path (str / PathLike) whose ``"model"`` entry
            holds ``model.encoder.backbone.eff_net.*`` weights; another truthy
            value to use the released ``effb2.pt`` weights; a falsy value to
            return the model without loading any weights.
        prune_ratio: fraction of filters removed; ``<= 0`` builds the stock,
            full-width network (with a 1-channel stem).
        prune_start_layer: first pruned layer, counting the stem as 0, each
            stage of the B2 config as one layer, and the head last.
        prune_se: also prune the squeeze-excitation bottlenecks.
        prune_method: one of "operator_norm", "interspeech", "iclr_l1",
            "iclr_gm", "cs_waspaa".

    Raises:
        ValueError: if ``prune_method`` is not recognised.
    """
    filter_pruning = _resolve_filter_pruning(prune_method)
    blocks_args, global_params = efficientnet_utils.get_model_params(
        'efficientnet-b2', {'include_top': False})
    model = _EffiNet(blocks_args=blocks_args,
                     global_params=global_params,
                     prune_start_layer=prune_start_layer,
                     prune_se=prune_se,
                     prune_ratio=prune_ratio)
    if not prune_ratio > 0:
        # _EffiNet then builds a stock EfficientNet with a 3-channel stem;
        # the source weights (like get_model) use a 1-channel stem.
        model.eff_net._change_in_channels(1)
    if not pretrained:
        return model

    state_dict = _load_source_state_dict(pretrained)
    model_dict = model.eff_net.state_dict()
    conv_chain, conv_to_bn, shared_idx = _conv_dependency_chain(model.eff_net)

    # For every conv: indices of the source output filters that are kept.
    keep_idx = {}
    for conv_key in tqdm(conv_chain):
        if conv_key in shared_idx:
            # Channel-wise coupled layers must reuse the feeding conv's
            # selection, otherwise BN / SE / projection inputs misalign.
            keep_idx[conv_key] = keep_idx[shared_idx[conv_key]]
            continue
        weight = state_dict[f"{conv_key}.weight"]
        n_keep = model_dict[f"{conv_key}.weight"].size(0)
        if n_keep < weight.size(0):
            keep_idx[conv_key] = filter_pruning(weight.numpy())[:n_keep]
        else:
            keep_idx[conv_key] = slice(None)

    pruned_state_dict = {}
    for conv_key, prev_conv_key in zip(conv_chain, [None] + conv_chain[:-1]):
        out_idx = keep_idx[conv_key]
        for sub_key in ("weight", "bias"):  # adjust the conv layer
            cur_key = f"{conv_key}.{sub_key}"
            if cur_key not in state_dict:
                continue
            source = state_dict[cur_key]
            if sub_key == "bias":
                pruned_state_dict[cur_key] = source[out_idx, ...]
                continue
            # Depthwise weights have a singleton input axis; a conv whose
            # input width is unchanged (e.g. the first pruned layer) keeps
            # all input channels.
            if (prev_conv_key is None
                    or conv_key.endswith("_depthwise_conv")
                    or model_dict[cur_key].size(1) == source.size(1)):
                in_idx = slice(None)
            else:
                in_idx = keep_idx[prev_conv_key]
            pruned_state_dict[cur_key] = source[out_idx, ...][:, in_idx, ...]
        bn_key = conv_to_bn.get(conv_key)
        if bn_key is not None:  # adjust the corresponding bn layer
            for sub_key in ("weight", "bias", "running_mean", "running_var"):
                cur_key = f"{bn_key}.{sub_key}"
                if cur_key in state_dict:
                    pruned_state_dict[cur_key] = state_dict[cur_key][out_idx, ...]

    model.eff_net.load_state_dict(pruned_state_dict)
    return model