import copy
import os
from typing import List, Optional, Tuple

import torch

from ...third_party.nni_new.algorithms.compression.pytorch.pruning import L1FilterPruner
from ...third_party.nni_new.compression.pytorch.speedup import ModelSpeedup
from ...common.others import get_cur_time_str


def _prune_module(model, pruner, model_input_size, device, verbose=False, need_return_mask=False):
    # Run the pruner to compute the pruning masks, then export them to
    # temporary files. Only the mask file is needed afterwards, so the
    # exported weight file is removed immediately.
    pruner.compress()
    pid = os.getpid()
    timestamp = get_cur_time_str()
    tmp_model_path = './tmp_weight-{}-{}.pth'.format(pid, timestamp)
    tmp_mask_path = './tmp_mask-{}-{}.pth'.format(pid, timestamp)
    pruner.export_model(model_path=tmp_model_path, mask_path=tmp_mask_path)
    os.remove(tmp_model_path)

    # Apply the exported masks to `model` (a clean, unwrapped copy of the
    # network the pruner compressed) and physically shrink the masked layers.
    dummy_input = torch.rand(model_input_size).to(device)
    pruned_model = model
    pruned_model.eval()
    model_speedup = ModelSpeedup(pruned_model, dummy_input, tmp_mask_path, device)
    fixed_mask = model_speedup.speedup_model()
    os.remove(tmp_mask_path)

    if verbose:
        print(pruned_model)

    if need_return_mask:
        return pruned_model, fixed_mask
    return pruned_model


def l1_prune_model(model: torch.nn.Module, pruned_layers_name: Optional[List[str]], sparsity: float,
                   model_input_size: Tuple[int], device: str, verbose=False, need_return_mask=False,
                   dep_aware=False):
    """Get the pruned model via L1 Filter Pruning.

    Reference:
        Li H, Kadav A, Durdanovic I, et al. Pruning filters for efficient ConvNets[J]. arXiv preprint arXiv:1608.08710, 2016.

    Args:
        model (torch.nn.Module): A PyTorch model.
        pruned_layers_name (Optional[List[str]]): Names of the layers to prune. If `None`, all Conv2d and ConvTranspose2d layers are pruned.
        sparsity (float): Target sparsity. The higher the sparsity, the smaller the pruned model.
        model_input_size (Tuple[int]): Size of the model input, typically `(1, 3, 32, 32)` or `(1, 3, 224, 224)`.
        device (str): Typically 'cpu' or 'cuda'.
        verbose (bool, optional): Whether to print the pruned model. Defaults to False.
        need_return_mask (bool, optional): Whether to also return the fine-grained mask generated by the NNI framework (for debugging). Defaults to False.
        dep_aware (bool, optional): Refers to the argument `dependency_aware` in the NNI framework. Defaults to False.

    Returns:
        torch.nn.Module: The pruned model, or the tuple `(pruned_model, mask)` if `need_return_mask` is True.
    """
    model = copy.deepcopy(model).to(device)
    if sparsity == 0:
        return model
    # Keep a clean copy: the pruner wraps the modules of `model` in place,
    # while ModelSpeedup is applied to `pruned_model` via the exported masks.
    pruned_model = copy.deepcopy(model).to(device)

    model.eval()
    if pruned_layers_name is not None:
        config_list = [{
            'op_types': ['Conv2d', 'ConvTranspose2d'],
            'op_names': pruned_layers_name,
            'sparsity': sparsity
        }]
    else:
        config_list = [{
            'op_types': ['Conv2d', 'ConvTranspose2d'],
            'sparsity': sparsity
        }]

    pruner = L1FilterPruner(model, config_list, dependency_aware=dep_aware,
                            dummy_input=torch.rand(model_input_size).to(device) if dep_aware else None)
    return _prune_module(pruned_model, pruner, model_input_size, device, verbose, need_return_mask)
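

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the module's API. Assumptions: torchvision
# is installed, and the layer name 'layer1.0.conv1' exists in torchvision's
# resnet18 (adjust the input size, device, and layer names for your model).
# Because this file uses relative imports, run it as a module, e.g.
# `python -m <package>.<this_module>`.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torchvision.models import resnet18

    model = resnet18()

    # Prune all Conv2d/ConvTranspose2d layers to 50% sparsity; dep_aware=True
    # makes the pruner respect channel dependencies such as residual additions.
    pruned_model = l1_prune_model(model, None, 0.5, (1, 3, 224, 224), 'cpu',
                                  dep_aware=True)

    # Prune a single named layer and also fetch the fine-grained NNI mask.
    pruned_model2, mask = l1_prune_model(model, ['layer1.0.conv1'], 0.5,
                                         (1, 3, 224, 224), 'cpu',
                                         need_return_mask=True)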