|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Script for calibrating a pretrained ASR model for quantization |
|
""" |
|
|
|
from argparse import ArgumentParser |
|
|
|
import torch |
|
from omegaconf import open_dict |
|
|
|
from nemo.collections.asr.models import EncDecCTCModel |
|
from nemo.utils import logging |
|
|
|
try: |
|
from pytorch_quantization import calib |
|
from pytorch_quantization import nn as quant_nn |
|
from pytorch_quantization import quant_modules |
|
from pytorch_quantization.tensor_quant import QuantDescriptor |
|
except ImportError: |
|
raise ImportError( |
|
"pytorch-quantization is not installed. Install from " |
|
"https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization." |
|
) |
|
|
|
try:
    from torch.cuda.amp import autocast
except ImportError:
    # torch.cuda.amp could not be imported (presumably an older torch build); substitute a
    # context manager that does nothing so `with autocast():` still works without AMP.
    from contextlib import contextmanager

    @contextmanager
    def autocast(enabled=None):
        """No-op stand-in for ``torch.cuda.amp.autocast``; ``enabled`` is accepted and ignored."""
        del enabled
        yield None


# Evaluated once at import time; decides whether the model and batches are moved to the GPU.
can_gpu = bool(torch.cuda.is_available())
|
|
|
|
|
def _build_parser():
    """Create the command-line parser for this calibration script."""
    parser = ArgumentParser()
    # --asr_model is mandatory, so a default value would never be used.
    parser.add_argument(
        "--asr_model", type=str, required=True, help="Pass: 'QuartzNet15x5Base-En'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=256)
    # store_true: the value is False unless the flag is given, so transcripts are
    # normalized by default and the flag turns normalization off.
    parser.add_argument(
        "--dont_normalize_text",
        action='store_true',
        help="Turn off transcript normalization. Recommended for non-English.",
    )
    parser.add_argument('--num_calib_batch', default=1, type=int, help="Number of batches for calibration.")
    parser.add_argument('--calibrator', type=str, choices=["max", "histogram"], default="max")
    parser.add_argument('--percentile', nargs='+', type=float, default=[99.9, 99.99, 99.999, 99.9999])
    parser.add_argument("--amp", action="store_true", help="Use AMP in calibration.")
    return parser


def _parse_args(argv=None):
    """Parse and validate command-line arguments (``argv=None`` reads ``sys.argv``)."""
    parser = _build_parser()
    args = parser.parse_args(argv)
    if args.num_calib_batch < 1:
        # With zero batches no statistics would be collected for compute_amax to use.
        parser.error("--num_calib_batch must be at least 1")
    return args


def _model_stem(asr_model):
    """Return ``asr_model`` with a single trailing ``.nemo`` extension removed, if present."""
    suffix = ".nemo"
    return asr_model[: -len(suffix)] if asr_model.endswith(suffix) else asr_model


def _load_quantized_model(asr_model):
    """Load an ``EncDecCTCModel`` whose config has ``encoder.quantize`` forced to True.

    A value ending in ``.nemo`` is restored from that local file; any other value is
    passed to ``from_pretrained`` as a model name. The config is fetched first, patched,
    and then used as ``override_config_path`` to build the model.
    """
    from functools import partial

    if asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {asr_model}")
        load = partial(EncDecCTCModel.restore_from, restore_path=asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {asr_model}")
        load = partial(EncDecCTCModel.from_pretrained, model_name=asr_model)

    model_cfg = load(return_config=True)
    with open_dict(model_cfg):
        model_cfg.encoder.quantize = True
    return load(override_config_path=model_cfg)


def _enable_calibration(model):
    """Put every ``TensorQuantizer`` in ``model`` into statistics-collection mode.

    Quantizers that own a calibrator stop quantizing and start calibrating; quantizers
    without a calibrator are disabled entirely.
    """
    for module in model.modules():
        if not isinstance(module, quant_nn.TensorQuantizer):
            continue
        if module._calibrator is not None:
            module.disable_quant()
            module.enable_calib()
        else:
            module.disable()


def _collect_stats(model, num_batches, use_amp):
    """Run exactly ``num_batches`` test batches through ``model`` so calibrators see data.

    Outputs are discarded; only the calibrators' side effects matter.
    """
    from itertools import islice

    # islice stops after num_batches, which also avoids loading an extra batch.
    for batch in islice(model.test_dataloader(), num_batches):
        if can_gpu:
            batch = [x.cuda() for x in batch]
        if use_amp:
            with autocast():
                model(input_signal=batch[0], input_signal_length=batch[1])
        else:
            model(input_signal=batch[0], input_signal_length=batch[1])


def _save_calibrated_models(model, stem, calibrator, percentiles, num_samples):
    """Compute amax with each applicable method and save one ``.nemo`` file per result.

    Max calibration yields a single file. Histogram calibration yields one file per
    requested percentile plus one each for the "mse" and "entropy" methods.
    """
    if calibrator != "histogram":
        compute_amax(model, method="max")
        model.save_to(f"{stem}-max-{num_samples}.nemo")
        return

    for percentile in percentiles:
        logging.info(f"{percentile} percentile calibration")
        # The percentile must be forwarded explicitly; otherwise every iteration would
        # compute the same (default) percentile despite the differing file names.
        compute_amax(model, method="percentile", percentile=percentile)
        model.save_to(f"{stem}-percentile-{percentile}-{num_samples}.nemo")

    for method in ("mse", "entropy"):
        logging.info(f"{method} calibration")
        compute_amax(model, method=method)
        model.save_to(f"{stem}-{method}-{num_samples}.nemo")


def main():
    """Calibrate a pretrained CTC ASR model for quantization and save the calibrated models.

    Steps:
    1. Parse the CLI arguments.
    2. Set the default input QuantDescriptor for the quantized layer types.
    3. Load the model with a quantized encoder.
    4. Run ``--num_calib_batch`` shuffled batches from ``--dataset`` through it with
       calibrators enabled.
    5. Compute amax with the chosen method(s) and save one ``.nemo`` per method.
    """
    args = _parse_args()
    torch.set_grad_enabled(False)

    # Set before the model is built: the quantized layers presumably read these
    # class-level defaults when they are constructed.
    quant_desc_input = QuantDescriptor(calib_method=args.calibrator)
    quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantConvTranspose2d.set_default_quant_desc_input(quant_desc_input)
    quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

    asr_model = _load_quantized_model(args.asr_model)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': not args.dont_normalize_text,
            'shuffle': True,
        }
    )
    # Presumably makes feature extraction deterministic (no dither noise, no extra padding).
    asr_model.preprocessor.featurizer.dither = 0.0
    asr_model.preprocessor.featurizer.pad_to = 0
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()

    _enable_calibration(asr_model)
    _collect_stats(asr_model, args.num_calib_batch, args.amp)

    # Nominal calibration sample count; it is embedded in the output file names.
    num_samples = args.num_calib_batch * args.batch_size
    _save_calibrated_models(
        asr_model, _model_stem(args.asr_model), args.calibrator, args.percentile, num_samples,
    )
|
|
|
|
|
def compute_amax(model, **kwargs):
    """Load calibrated ``amax`` values into every calibrated quantizer of ``model``.

    Every ``TensorQuantizer`` submodule is visited and printed. Those that own a
    calibrator have ``load_calib_amax`` called on them:
    - a ``MaxCalibrator`` is called with no arguments;
    - any other calibrator receives ``kwargs`` (e.g. ``method=...``, ``percentile=...``).

    When a GPU is available, the model is moved to it afterwards.
    """
    for qname, quantizer in model.named_modules():
        if not isinstance(quantizer, quant_nn.TensorQuantizer):
            continue
        calibrator = quantizer._calibrator
        if calibrator is not None:
            # The method options are withheld from MaxCalibrator and forwarded to all others.
            calib_kwargs = {} if isinstance(calibrator, calib.MaxCalibrator) else kwargs
            quantizer.load_calib_amax(**calib_kwargs)
        print(f"{qname:40}: {quantizer}")
    # Presumably re-homes any amax tensors that calibration produced off-GPU.
    if can_gpu:
        model.cuda()
|
|
|
|
|
if __name__ == "__main__":
    # Run calibration only when executed as a script, not when imported.
    main()
|
|