# EdgeTA/utils/dl/common/pruning.py
# Uploaded by LINC-BIT in commit b84549f ("Upload 1912 files", verified).
import copy
import os
import sys
from tabnanny import verbose
from typing import List, Optional, Tuple
import torch
from ...third_party.nni_new.algorithms.compression.pytorch.pruning import L1FilterPruner
from ...third_party.nni_new.compression.pytorch.speedup import ModelSpeedup
from ...common.others import get_cur_time_str
def _prune_module(model, pruner, model_input_size, device, verbose=False, need_return_mask=False):
    """Apply the masks computed by ``pruner`` and physically shrink ``model``.

    The pruner's masks are exported to a temporary file. NNI's ``ModelSpeedup``
    is then run on ``model`` with that mask file.

    Args:
        model (torch.nn.Module): Model to shrink. It is switched to eval mode and
            passed to ``ModelSpeedup``, which rewrites it in place.
        pruner: An NNI pruner (e.g. ``L1FilterPruner``) already configured with
            the layers and sparsity to prune.
        model_input_size (Tuple[int, ...]): Shape of the random dummy input passed
            to ``ModelSpeedup``.
        device (str): Device on which the dummy input is created.
        verbose (bool, optional): Currently has no effect; kept for API
            compatibility.
        need_return_mask (bool, optional): If True, also return the object
            returned by ``ModelSpeedup.speedup_model()``.

    Returns:
        The pruned model, or ``(pruned_model, mask)`` when ``need_return_mask``
        is True.
    """
    import tempfile  # function-scope import keeps this block self-contained

    pruner.compress()

    # Export into a private temporary directory rather than the CWD. The
    # directory, and every file in it, is removed even if export or speedup
    # raises, so no stray .pth files are left behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_model_path = os.path.join(tmp_dir, 'tmp_weight.pth')
        tmp_mask_path = os.path.join(tmp_dir, 'tmp_mask.pth')
        pruner.export_model(model_path=tmp_model_path, mask_path=tmp_mask_path)

        # Speed up: ModelSpeedup needs a dummy input and the exported masks.
        dummy_input = torch.rand(model_input_size).to(device)
        model.eval()
        model_speedup = ModelSpeedup(model, dummy_input, tmp_mask_path, device)
        fixed_mask = model_speedup.speedup_model()

    if need_return_mask:
        return model, fixed_mask
    return model
def l1_prune_model(model: torch.nn.Module, pruned_layers_name: Optional[List[str]], sparsity: float,
                   model_input_size: Tuple[int, ...], device: str, verbose=False, need_return_mask=False,
                   dep_aware=False):
    """Get the pruned model via L1 Filter Pruning.

    Reference:
        Li H, Kadav A, Durdanovic I, et al. Pruning filters for efficient convnets[J].
        arXiv preprint arXiv:1608.08710, 2016.

    Args:
        model (torch.nn.Module): A PyTorch model. It is deep-copied, so the
            caller's instance is never modified.
        pruned_layers_name (Optional[List[str]]): Which layers will be pruned. If it's
            `None`, all `Conv2d` / `ConvTranspose2d` layers will be pruned.
        sparsity (float): Target sparsity. The pruned model is smaller if sparsity is higher.
            A sparsity of exactly 0 skips pruning and returns an unpruned copy.
        model_input_size (Tuple[int, ...]): Typically `(1, 3, 32, 32)` or `(1, 3, 224, 224)`.
        device (str): Typically 'cpu' or 'cuda'.
        verbose (bool, optional): Intended to enable verbose logging, but currently has
            no effect (known issue); kept for API compatibility. Defaults to False.
        need_return_mask (bool, optional): Also return the fine-grained mask generated by
            the NNI framework, for debugging. Defaults to False.
        dep_aware (bool, optional): Refers to the argument `dependency_aware` in NNI
            framework. Defaults to False.

    Returns:
        torch.nn.Module: Pruned model, or a `(pruned_model, mask)` tuple when
        `need_return_mask` is True. If `sparsity == 0`, the mask is `None`.
    """
    # Work on a private copy so the caller's model is left untouched.
    model = copy.deepcopy(model).to(device)
    if sparsity == 0:
        # Nothing to prune. Still honour the need_return_mask contract so callers
        # can always unpack `(model, mask)`.
        return (model, None) if need_return_mask else model

    # `model` is handed to the pruner to compute masks, and a second copy is
    # what actually gets shrunk by ModelSpeedup.
    pruned_model = copy.deepcopy(model).to(device)

    # Generate masks.
    model.eval()
    config = {
        'op_types': ['Conv2d', 'ConvTranspose2d'],
        'sparsity': sparsity,
    }
    if pruned_layers_name is not None:
        # Restrict pruning to the requested layers only.
        config['op_names'] = pruned_layers_name
    config_list = [config]

    # A dummy input is only needed for dependency-aware pruning.
    dummy_input = torch.rand(model_input_size).to(device) if dep_aware else None
    pruner = L1FilterPruner(model, config_list, dependency_aware=dep_aware,
                            dummy_input=dummy_input)
    return _prune_module(pruned_model, pruner, model_input_size, device, verbose, need_return_mask)