|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
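Script for post-training quantization of NeMo CTC ASR models: evaluates the WER of the
quantized model, optionally disables the most sensitive quantized layers until a WER
tolerance is met, and optionally exports the result to ONNX.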
|
""" |
|
|
|
import collections |
|
from argparse import ArgumentParser |
|
from pprint import pprint |
|
|
|
import torch |
|
from omegaconf import open_dict |
|
|
|
from nemo.collections.asr.metrics.wer import WER, CTCDecoding, CTCDecodingConfig, word_error_rate |
|
from nemo.collections.asr.models import EncDecCTCModel |
|
from nemo.utils import logging |
|
|
|
try: |
|
from pytorch_quantization import nn as quant_nn |
|
from pytorch_quantization import quant_modules |
|
except ImportError: |
|
raise ImportError( |
|
"pytorch-quantization is not installed. Install from " |
|
"https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization." |
|
) |
|
|
|
|
|
try: |
|
from torch.cuda.amp import autocast |
|
except ImportError: |
|
from contextlib import contextmanager |
|
|
|
@contextmanager |
|
def autocast(enabled=None): |
|
yield |
|
|
|
|
|
# Queried once at import time; the model and each test batch are moved to CUDA only when this is True.
can_gpu: bool = torch.cuda.is_available()
|
|
|
|
|
def main():
    """Quantize a CTC ASR model, evaluate it, and optionally tune and export it.

    Steps:
      1. Parse the command line (see ``_parse_args``).
      2. Call ``quant_modules.initialize()`` and load the ``EncDecCTCModel`` with
         ``encoder.quantize = True``.
      3. Disable every quantizer whose module name contains one of
         ``--quant-disable-keyword``.
      4. Evaluate WER (CER with ``--use_cer``) on ``--dataset``.
      5. With ``--sensitivity`` and a WER above ``--wer_tolerance``: disable the
         most WER-sensitive quantized layers until the tolerance is met.
      6. Export to ONNX when ``--onnx`` is given.

    Raises:
        ValueError: if ``--sensitivity`` is set and the tolerance cannot be met
            even after disabling every quantized layer.
    """
    args = _parse_args()
    torch.set_grad_enabled(False)

    # Patch torch.nn layers with their pytorch-quantization counterparts before the model is built.
    quant_modules.initialize()

    asr_model = _load_quantized_model(args)

    if args.quant_disable_keyword:
        for name, module in _quantizers(asr_model):
            if any(keyword in name for keyword in args.quant_disable_keyword):
                logging.warning(F"Disable {name}")
                module.disable()

    labels_map = dict(enumerate(asr_model.decoder.vocabulary))
    decoding_cfg = CTCDecodingConfig()
    char_decoding = CTCDecoding(decoding_cfg, vocabulary=labels_map)
    wer = WER(char_decoding, use_cer=args.use_cer)
    wer_quant = evaluate(asr_model, labels_map, wer)
    logging.info(f'Got WER of {wer_quant}. Tolerance was {args.wer_tolerance}')

    if args.sensitivity:
        # Same "met" criterion as the layer-skipping loop (<=), and no early return:
        # an already-acceptable model must still reach the ONNX export below.
        if wer_quant <= args.wer_tolerance:
            logging.info("Tolerance is already met. Skip sensitivity analysis.")
        else:
            _run_sensitivity_analysis(args, asr_model, labels_map, wer)

    export_onnx(args, asr_model)


def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model", type=str, required=True, help="Pass: 'QuartzNet15x5Base-En'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--wer_target", type=float, default=None, help="used by test")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    # A plain on/off switch: passing the flag sets it to True, and the loader
    # normalizes transcripts only when it is False.
    parser.add_argument(
        "--dont_normalize_text",
        default=False,
        action='store_true',
        help="Turn off transcript normalization. Recommended for non-English.",
    )
    parser.add_argument(
        "--use_cer", default=False, action='store_true', help="Use Character Error Rate as the evaluation metric"
    )
    parser.add_argument('--sensitivity', action="store_true", help="Perform sensitivity analysis")
    parser.add_argument('--onnx', action="store_true", help="Export to ONNX")
    parser.add_argument('--quant-disable-keyword', type=str, nargs='+', help='disable quantizers by keyword')
    return parser.parse_args()


def _load_quantized_model(args):
    """Load ``args.asr_model`` with encoder quantization enabled and ready for evaluation.

    A value ending in ``.nemo`` is restored from a local file; anything else is
    passed to ``from_pretrained`` as a model name. The test dataloader is set up
    from ``args``, and the model is moved to GPU (when available) and put in eval mode.
    """
    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")

        def load(**kwargs):
            return EncDecCTCModel.restore_from(restore_path=args.asr_model, **kwargs)

    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")

        def load(**kwargs):
            return EncDecCTCModel.from_pretrained(model_name=args.asr_model, **kwargs)

    # Fetch the config alone first so quantization can be switched on before the model is instantiated.
    asr_model_cfg = load(return_config=True)
    with open_dict(asr_model_cfg):
        asr_model_cfg.encoder.quantize = True
    asr_model = load(override_config_path=asr_model_cfg)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': not args.dont_normalize_text,
        }
    )
    # No dithering and no pad_to, so the evaluation features do not vary between runs.
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()
    return asr_model


def _quantizers(asr_model, layer_name=None):
    """Yield ``(name, module)`` for each ``TensorQuantizer`` inside ``asr_model``.

    When ``layer_name`` is given, only that layer's quantizers are yielded: those
    whose name equals ``layer_name`` or starts with ``layer_name + "."``. A plain
    substring test would also select unrelated layers, e.g. ``block.10`` for ``block.1``.
    """
    prefix = None if layer_name is None else layer_name + "."
    for name, module in asr_model.named_modules():
        if not isinstance(module, quant_nn.TensorQuantizer):
            continue
        if layer_name is None or name == layer_name or name.startswith(prefix):
            yield name, module


def _layer_name(quantizer_name):
    """Strip the input/weight quantizer suffix to get the owning layer's module name."""
    return quantizer_name.replace("._input_quantizer", "").replace("._weight_quantizer", "")


def _run_sensitivity_analysis(args, asr_model, labels_map, wer):
    """Disable the most WER-sensitive quantized layers until ``args.wer_tolerance`` is met.

    Each layer is first quantized on its own (all other quantizers disabled) to
    score it by ``wer_tolerance - WER``. Then all quantizers are re-enabled, and
    layers are disabled one at a time, lowest score (largest WER) first, until the
    WER is within tolerance. The model's quantizers are left in that final state.

    Returns:
        The list of layer names whose quantizers were disabled.

    Raises:
        ValueError: if the tolerance is not met even with every layer disabled.
    """
    for _, module in _quantizers(asr_model):
        module.disable()
    # dict.fromkeys de-duplicates while keeping module order.
    quant_layer_names = list(dict.fromkeys(_layer_name(name) for name, _ in _quantizers(asr_model)))
    logging.info(F"{len(quant_layer_names)} quantized layers found.")

    quant_layer_sensitivity = {}
    for quant_layer in quant_layer_names:
        logging.info(F"Enable {quant_layer}")
        for name, module in _quantizers(asr_model, quant_layer):
            module.enable()
            logging.info(F"{name:40}: {module}")

        wer_value = evaluate(asr_model, labels_map, wer)
        logging.info(F"WER: {wer_value}")
        quant_layer_sensitivity[quant_layer] = args.wer_tolerance - wer_value

        for name, module in _quantizers(asr_model, quant_layer):
            module.disable()
            logging.info(F"{name:40}: {module}")

    for _, module in _quantizers(asr_model):
        module.enable()
    quant_layer_sensitivity = collections.OrderedDict(sorted(quant_layer_sensitivity.items(), key=lambda x: x[1]))
    pprint(quant_layer_sensitivity)

    skipped_layers = []
    for quant_layer in quant_layer_sensitivity:
        skipped_layers.append(quant_layer)
        for name, module in _quantizers(asr_model, quant_layer):
            logging.info(F"Disable {name}")
            module.disable()
        wer_value = evaluate(asr_model, labels_map, wer)
        if wer_value <= args.wer_tolerance:
            logging.info(
                F"WER tolerance {args.wer_tolerance} is met by skipping {len(skipped_layers)} sensitive layers."
            )
            print(skipped_layers)
            return skipped_layers
    raise ValueError(f"WER tolerance {args.wer_tolerance} can not be met with any layer quantized!")
|
|
|
|
|
def onnx_output_path(asr_model_name):
    """Return the ONNX file path to export to for a given ``--asr_model`` value.

    A local ``.nemo`` checkpoint is exported next to itself with only the
    trailing ``.nemo`` swapped for ``.onnx``. A bare pretrained-model name gets
    ``.onnx`` appended so the output file carries an ONNX extension.
    """
    if asr_model_name.endswith(".nemo"):
        return asr_model_name[: -len(".nemo")] + ".onnx"
    return asr_model_name + ".onnx"


def export_onnx(args, asr_model):
    """Export ``asr_model`` to ONNX if ``args.onnx`` is set; otherwise do nothing.

    Args:
        args: parsed CLI arguments; reads ``onnx`` and ``asr_model``.
        asr_model: the (quantized) model to export.
    """
    if not args.onnx:
        return
    onnx_name = onnx_output_path(args.asr_model)
    logging.info(F"Export to {onnx_name}")
    # Switch TensorQuantizer to torch's native fake-quant ops for the export only,
    # and restore it even if the export fails. Opset 13 is presumably chosen for
    # per-channel QuantizeLinear/DequantizeLinear support.
    quant_nn.TensorQuantizer.use_fb_fake_quant = True
    try:
        asr_model.export(onnx_name, onnx_opset_version=13)
    finally:
        quant_nn.TensorQuantizer.use_fb_fake_quant = False
|
|
|
|
|
def evaluate(asr_model, labels_map, wer):
    """Run the model over its test dataloader and return the error rate.

    Args:
        asr_model: CTC model whose test dataloader was configured with ``setup_test_data``.
        labels_map: mapping from token id to label string, used to rebuild the
            reference transcripts from the target ids in each batch.
        wer: WER metric object. Its ``decoding`` turns greedy predictions into
            text, and its ``use_cer`` flag selects CER instead of WER.

    Returns:
        The value of ``word_error_rate`` over all hypotheses and references in the test set.
    """
    hypotheses = []
    references = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        # Indices 0-3 are read as (signal, signal length, target ids, target length) — assumed batch layout.
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1]
            )
        # Give the decoder the valid encoder lengths so padded frames of shorter
        # utterances in the batch are not decoded into spurious tokens.
        hypotheses += wer.decoding.ctc_decoder_predictions_tensor(greedy_predictions, decoder_lengths=encoded_len)[0]
        for seq_ids, seq_len in zip(test_batch[2], test_batch[3]):
            references.append(''.join(labels_map[c] for c in seq_ids[: int(seq_len)].tolist()))
        del test_batch
    return word_error_rate(hypotheses=hypotheses, references=references, use_cer=wer.use_cer)
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|
|