Spaces:
Running
Running
import copy | |
import os | |
import sys | |
from tabnanny import verbose | |
from typing import List, Optional, Tuple | |
import torch | |
from ...third_party.nni_new.algorithms.compression.pytorch.pruning import L1FilterPruner | |
from ...third_party.nni_new.compression.pytorch.speedup import ModelSpeedup | |
from ...common.others import get_cur_time_str | |
def _prune_module(model, pruner, model_input_size, device, verbose=False, need_return_mask=False):
    """Generate masks with ``pruner`` and physically shrink ``model`` with NNI's ``ModelSpeedup``.

    The pruner's masks are exported to a temporary mask file. ``ModelSpeedup`` then reads that
    file and applies it to ``model``.

    Args:
        model (torch.nn.Module): Model to shrink. It is switched to eval mode, passed to
            ``ModelSpeedup`` and returned. It is the same object, not a copy.
        pruner: NNI pruner whose ``compress()`` / ``export_model()`` produce the masks.
        model_input_size: Size of the random dummy input used by ``ModelSpeedup``.
        device: Device the dummy input is moved to; also forwarded to ``ModelSpeedup``.
        verbose (bool): Currently unused; kept for API compatibility.
        need_return_mask (bool): If True, also return the value produced by ``speedup_model()``.

    Returns:
        ``model``, or ``(model, mask)`` when ``need_return_mask`` is True.
    """
    import tempfile  # local import keeps the module-level import block untouched

    pruner.compress()
    # Use a private temporary directory rather than the CWD. This avoids polluting the working
    # directory and avoids name clashes between concurrent calls. The temp files are also always
    # cleaned up, even if export or speedup raises.
    with tempfile.TemporaryDirectory(prefix='l1_prune-') as tmp_dir:
        tmp_model_path = os.path.join(tmp_dir, 'tmp_weight.pth')
        tmp_mask_path = os.path.join(tmp_dir, 'tmp_mask.pth')
        # Only the mask file is consumed below; the exported weights are discarded with tmp_dir.
        pruner.export_model(model_path=tmp_model_path, mask_path=tmp_mask_path)

        # speed up
        dummy_input = torch.rand(model_input_size).to(device)
        model.eval()
        model_speedup = ModelSpeedup(model, dummy_input, tmp_mask_path, device)
        fixed_mask = model_speedup.speedup_model()

    if need_return_mask:
        return model, fixed_mask
    return model
def l1_prune_model(model: torch.nn.Module, pruned_layers_name: Optional[List[str]], sparsity: float,
                   model_input_size: Tuple[int, ...], device: str, verbose=False, need_return_mask=False,
                   dep_aware=False):
    """Get the pruned model via L1 Filter Pruning.

    Reference:
        Li H, Kadav A, Durdanovic I, et al. Pruning filters for efficient convnets[J].
        arXiv preprint arXiv:1608.08710, 2016.

    Args:
        model (torch.nn.Module): A PyTorch model. It is deep-copied, so the caller's model is
            not modified.
        pruned_layers_name (Optional[List[str]]): Names of the layers to prune. If it's `None`,
            no `op_names` filter is applied and every `Conv2d` / `ConvTranspose2d` layer is
            a candidate.
        sparsity (float): Target sparsity. The pruned model is smaller if sparsity is higher.
            `0` returns an unpruned copy of `model`.
        model_input_size (Tuple[int, ...]): Typically `(1, 3, 32, 32)` or `(1, 3, 224, 224)`.
        device (str): Typically 'cpu' or 'cuda'.
        verbose (bool, optional): Currently has no effect; kept for API compatibility.
            Defaults to False.
        need_return_mask (bool, optional): Also return the mask produced by the NNI speedup
            step, for debugging. Defaults to False.
        dep_aware (bool, optional): Refers to the argument `dependency_aware` in NNI framework.
            Defaults to False.

    Returns:
        torch.nn.Module: Pruned model, or `(pruned_model, mask)` when `need_return_mask` is True.
        When `sparsity == 0` no pruning is performed and `mask` is `None`.
    """
    model = copy.deepcopy(model).to(device)
    if sparsity == 0:
        # Nothing to prune, but keep the tuple contract promised by `need_return_mask`.
        return (model, None) if need_return_mask else model

    # `model` is handed to the pruner for mask generation; the masks are then applied to this
    # separate, untouched copy. `model` is already on `device`, and deepcopy keeps it there.
    pruned_model = copy.deepcopy(model)

    # generate mask
    model.eval()
    config = {'op_types': ['Conv2d', 'ConvTranspose2d'], 'sparsity': sparsity}
    if pruned_layers_name is not None:
        config['op_names'] = pruned_layers_name
    config_list = [config]

    # A dummy input is only needed when NNI has to trace channel dependencies.
    dummy_input = torch.rand(model_input_size).to(device) if dep_aware else None
    pruner = L1FilterPruner(model, config_list, dependency_aware=dep_aware, dummy_input=dummy_input)
    return _prune_module(pruned_model, pruner, model_input_size, device, verbose, need_return_mask)