# NeMo/examples/asr/quantization/speech_to_text_quant_infer.py
# (web-page header residue kept for provenance: "camenduru's picture",
#  "thanks to NVIDIA ❤", revision 7934b29)
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
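Script for evaluating a post-training-quantized ASR model: it reports WER (or CER),
can run a per-layer quantization sensitivity analysis, and can export the model to ONNX.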
"""
import collections
from argparse import ArgumentParser
from pprint import pprint
import torch
from omegaconf import open_dict
from nemo.collections.asr.metrics.wer import WER, CTCDecoding, CTCDecodingConfig, word_error_rate
from nemo.collections.asr.models import EncDecCTCModel
from nemo.utils import logging
# pytorch-quantization provides the TensorQuantizer modules this script enables
# and disables; fail early with install instructions when it is missing.
try:
    from pytorch_quantization import nn as quant_nn
    from pytorch_quantization import quant_modules
except ImportError as err:
    # Chain explicitly so the original import failure stays attached as the cause.
    raise ImportError(
        "pytorch-quantization is not installed. Install from "
        "https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
    ) from err
# Use torch's CUDA autocast when it can be imported; otherwise define a
# do-nothing context manager under the same name so ``with autocast():``
# keeps working.
try:
    from torch.cuda.amp import autocast
except ImportError:
    from contextlib import contextmanager

    @contextmanager
    def autocast(enabled=None):
        """No-op stand-in for ``autocast``; ``enabled`` is accepted and ignored."""
        yield
# Evaluated once at import time; decides whether the model and the test
# batches are moved to CUDA.
can_gpu = torch.cuda.is_available()
def _parse_args():
    """Build the command-line parser for this script and parse the process arguments."""
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: 'QuartzNet15x5Base-En'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    # NOTE(review): --wer_target is parsed but never read in this script.
    parser.add_argument("--wer_target", type=float, default=None, help="used by test")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    # store_true (not store_false): with default=False, store_false made the flag
    # a no-op, so normalization could never be switched.
    parser.add_argument(
        "--dont_normalize_text",
        default=False,
        action='store_true',
        help="Turn off trasnscript normalization. Recommended for non-English.",
    )
    parser.add_argument(
        "--use_cer", default=False, action='store_true', help="Use Character Error Rate as the evaluation metric"
    )
    parser.add_argument('--sensitivity', action="store_true", help="Perform sensitivity analysis")
    parser.add_argument('--onnx', action="store_true", help="Export to ONNX")
    parser.add_argument('--quant-disable-keyword', type=str, nargs='+', help='disable quantizers by keyword')
    return parser.parse_args()


def _load_quantized_model(model_name):
    """Load an ``EncDecCTCModel`` whose config has ``encoder.quantize`` set to True.

    A name ending in ``.nemo`` is restored from that local file; anything else
    is fetched with ``from_pretrained``.
    """
    if model_name.endswith('.nemo'):
        logging.info(f"Using local ASR model from {model_name}")
        asr_model_cfg = EncDecCTCModel.restore_from(restore_path=model_name, return_config=True)
        with open_dict(asr_model_cfg):
            asr_model_cfg.encoder.quantize = True
        return EncDecCTCModel.restore_from(restore_path=model_name, override_config_path=asr_model_cfg)

    logging.info(f"Using NGC cloud ASR model {model_name}")
    asr_model_cfg = EncDecCTCModel.from_pretrained(model_name=model_name, return_config=True)
    with open_dict(asr_model_cfg):
        asr_model_cfg.encoder.quantize = True
    return EncDecCTCModel.from_pretrained(model_name=model_name, override_config_path=asr_model_cfg)


def _named_quantizers(asr_model):
    """Yield ``(name, module)`` for every ``TensorQuantizer`` inside ``asr_model``."""
    for name, module in asr_model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            yield name, module


def main():
    """Evaluate a quantized ASR model and optionally analyse layer sensitivity.

    Steps:
      1. Load the model with quantization enabled and set up the test data.
      2. Optionally disable quantizers whose name contains a given keyword.
      3. Measure the error rate of the quantized model.
      4. With ``--sensitivity``: measure each quantized layer in isolation, then
         disable layers from most to least sensitive until the error rate is
         within ``--wer_tolerance``.
      5. Export to ONNX when ``--onnx`` is given.

    Raises:
        ValueError: if the tolerance is still not met after every quantized
            layer has been disabled during the sensitivity analysis.
    """
    args = _parse_args()
    torch.set_grad_enabled(False)

    quant_modules.initialize()

    asr_model = _load_quantized_model(args.asr_model)
    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            # Normalize unless the user explicitly opted out.
            'normalize_transcripts': not args.dont_normalize_text,
        }
    )
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    if args.quant_disable_keyword:
        for name, module in _named_quantizers(asr_model):
            for keyword in args.quant_disable_keyword:
                if keyword in name:
                    logging.warning(F"Disable {name}")
                    module.disable()

    labels_map = dict(enumerate(asr_model.decoder.vocabulary))
    decoding_cfg = CTCDecodingConfig()
    char_decoding = CTCDecoding(decoding_cfg, vocabulary=labels_map)
    wer = WER(char_decoding, use_cer=args.use_cer)
    wer_quant = evaluate(asr_model, labels_map, wer)
    logging.info(f'Got WER of {wer_quant}. Tolerance was {args.wer_tolerance}')

    if not args.sensitivity:
        export_onnx(args, asr_model)
        return

    if wer_quant < args.wer_tolerance:
        logging.info("Tolerance is already met. Skip sensitivity analyasis.")
        # Still honour --onnx: previously this early return skipped the export.
        export_onnx(args, asr_model)
        return

    # Turn every quantizer off and collect the distinct layers that own one.
    quant_layer_names = []
    for name, module in _named_quantizers(asr_model):
        module.disable()
        layer_name = name.replace("._input_quantizer", "").replace("._weight_quantizer", "")
        if layer_name not in quant_layer_names:
            quant_layer_names.append(layer_name)
    logging.info(F"{len(quant_layer_names)} quantized layers found.")

    # Build sensitivity profile: quantize one layer at a time and record how
    # much headroom (tolerance - WER) is left with only that layer quantized.
    quant_layer_sensitivity = {}
    for quant_layer in quant_layer_names:
        logging.info(F"Enable {quant_layer}")
        for name, module in _named_quantizers(asr_model):
            if quant_layer in name:
                module.enable()
                logging.info(F"{name:40}: {module}")

        # Eval the model
        wer_value = evaluate(asr_model, labels_map, wer)
        logging.info(F"WER: {wer_value}")
        quant_layer_sensitivity[quant_layer] = args.wer_tolerance - wer_value

        for name, module in _named_quantizers(asr_model):
            if quant_layer in name:
                module.disable()
                logging.info(F"{name:40}: {module}")

    # Skip most sensitive layers until WER target is met.
    # NOTE(review): this re-enables every quantizer, including any that were
    # disabled above through --quant-disable-keyword.
    for name, module in _named_quantizers(asr_model):
        module.enable()
    # Ascending headroom puts the layers with the highest WER first.
    quant_layer_sensitivity = collections.OrderedDict(sorted(quant_layer_sensitivity.items(), key=lambda x: x[1]))
    pprint(quant_layer_sensitivity)
    skipped_layers = []
    for quant_layer in quant_layer_sensitivity:
        for name, module in _named_quantizers(asr_model):
            if quant_layer in name:
                logging.info(F"Disable {name}")
                if quant_layer not in skipped_layers:
                    skipped_layers.append(quant_layer)
                module.disable()
        wer_value = evaluate(asr_model, labels_map, wer)
        if wer_value <= args.wer_tolerance:
            logging.info(
                F"WER tolerance {args.wer_tolerance} is met by skipping {len(skipped_layers)} sensitive layers."
            )
            print(skipped_layers)
            export_onnx(args, asr_model)
            return
    raise ValueError(f"WER tolerance {args.wer_tolerance} can not be met with any layer quantized!")
def export_onnx(args, asr_model):
    """Export ``asr_model`` to an ONNX file when ``args.onnx`` is set.

    The output path is derived from ``args.asr_model``: a trailing ``.nemo``
    is swapped for ``.onnx``; any other name (e.g. a pretrained model name)
    gets ``.onnx`` appended.

    Args:
        args: Parsed command-line arguments; ``onnx`` and ``asr_model`` are read.
        asr_model: Model object providing ``export(path, onnx_opset_version=...)``.
    """
    if not args.onnx:
        return

    nemo_suffix = ".nemo"
    if args.asr_model.endswith(nemo_suffix):
        # Replace only the trailing suffix; str.replace would also rewrite
        # ".nemo" occurring earlier in the path.
        onnx_name = args.asr_model[: -len(nemo_suffix)] + ".onnx"
    else:
        # NOTE(review): NeMo's export is understood to choose the format from
        # the output file extension, so a bare model name needs one — confirm.
        onnx_name = args.asr_model + ".onnx"
    logging.info(F"Export to {onnx_name}")

    quant_nn.TensorQuantizer.use_fb_fake_quant = True
    try:
        asr_model.export(onnx_name, onnx_opset_version=13)
    finally:
        # Always restore the class-level switch, even if the export fails.
        quant_nn.TensorQuantizer.use_fb_fake_quant = False
def evaluate(asr_model, labels_map, wer):
    """Transcribe the model's test set and score it against the references.

    Args:
        asr_model: Model exposing ``test_dataloader()`` and callable with
            ``input_signal`` / ``input_signal_length``; the third value it
            returns is decoded as the prediction.
        labels_map: Mapping from token id to its text symbol, used to rebuild
            the reference strings.
        wer: Metric object; ``decoding.ctc_decoder_predictions_tensor`` and
            ``use_cer`` are read from it.

    Returns:
        The result of ``word_error_rate`` over all batches, with ``use_cer``
        taken from ``wer``.
    """
    predicted_texts = []
    target_texts = []
    for batch in asr_model.test_dataloader():
        if can_gpu:
            batch = [item.cuda() for item in batch]
        with autocast():
            # Only the third output is needed for scoring.
            _, _, greedy_predictions = asr_model(input_signal=batch[0], input_signal_length=batch[1])
        predicted_texts.extend(wer.decoding.ctc_decoder_predictions_tensor(greedy_predictions)[0])
        for row in range(greedy_predictions.shape[0]):
            # batch[2] is used as the token ids and batch[3] as how many of
            # them are valid for this row.
            target_len = batch[3][row].cpu().detach().numpy()
            target_ids = batch[2][row].cpu().detach().numpy()
            target_texts.append(''.join(labels_map[token] for token in target_ids[0:target_len]))
        del batch
    return word_error_rate(hypotheses=predicted_texts, references=target_texts, use_cer=wer.use_cer)
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter