|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
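This script exports a pre-trained FireRedASR encoder model from PyTorch to
ONNX (FP16) and then builds a TensorRT engine from that ONNX file. If
--onnx-model-path is given, the PyTorch export step is skipped and the engine
is built directly from the provided ONNX model.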
|
|
|
|
|
Usage: |
|
|
|
|
|
python3 examples/export_encoder_tensorrt.py \ |
|
|
--model-dir /path/to/your/model_dir \ |
|
|
--tensorrt-model-dir ./tensorrt_models \ |
|
|
--trt-engine-file-name encoder.plan |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import logging |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
import tensorrt as trt |
|
|
|
|
|
from fireredasr.models.fireredasr import load_fireredasr_aed_model |
|
|
|
|
|
|
|
|
def get_parser() -> argparse.ArgumentParser:
    """Get the command-line argument parser.

    Returns:
        An ``argparse.ArgumentParser`` whose help output includes each
        option's default value.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model-dir",
        type=str,
        default=None,
        help="The model directory that contains model checkpoint.",
    )

    parser.add_argument(
        "--onnx-model-path",
        type=str,
        default=None,
        help="If specified, we will directly use this onnx model to generate "
        "the tensorrt engine",
    )

    # The default means this is never strictly "required"; when exporting from
    # --model-dir the value is taken from the checkpoint instead.
    parser.add_argument(
        "--idim",
        type=int,
        default=80,
        help="The input (feature) dimension of the model. Only used together "
        "with --onnx-model-path; when exporting from --model-dir it is read "
        "from the checkpoint instead.",
    )

    parser.add_argument(
        "--tensorrt-model-dir",
        type=str,
        default="exp",
        help="Directory to save the exported models.",
    )

    parser.add_argument(
        "--trt-engine-file-name",
        type=str,
        default="encoder.plan",
        help="The name of the TensorRT engine file.",
    )

    parser.add_argument(
        "--opset-version",
        type=int,
        default=17,
        help="ONNX opset version.",
    )

    return parser
|
|
|
|
|
|
|
|
def export_encoder_onnx(
    encoder: torch.nn.Module,
    filename: str,
    idim: int,
    opset_version: int = 17,
    *,
    seq_len: int = 400,
    batch_size: int = 1,
) -> None:
    """Export the conformer encoder model to ONNX format in FP16.

    The encoder is traced with a random FP16 dummy input. The batch and time
    axes of the inputs/outputs are declared dynamic in ``dynamic_axes``.

    Args:
        encoder: Encoder module to export. NOTE: it is converted to FP16
            *in place* via ``encoder.half()``, so the caller's module is
            modified by this call.
        filename: Destination path of the ``.onnx`` file.
        idim: Feature dimension of the dummy input (its last axis).
        opset_version: ONNX opset version to target.
        seq_len: Number of frames in the dummy input used for tracing.
        batch_size: Batch size of the dummy input used for tracing. The
            default of 1 keeps the historical behaviour; NOTE(review):
            tracing a dynamic axis at size 1 can let the exporter specialise
            it, so confirm the exported graph accepts batch > 1 (or trace
            with ``batch_size >= 2``).

    Raises:
        ValueError: If ``idim``, ``seq_len`` or ``batch_size`` is not
            positive.
    """
    for arg_name, value in (
        ("idim", idim),
        ("seq_len", seq_len),
        ("batch_size", batch_size),
    ):
        if value <= 0:
            raise ValueError(f"{arg_name} must be positive, got {value}")

    logging.info("Exporting encoder to ONNX")
    encoder.half()

    padded_input = torch.randn(batch_size, seq_len, idim, dtype=torch.float16)
    # Every utterance in the dummy batch is full length (no padding).
    input_lengths = torch.full((batch_size,), seq_len, dtype=torch.int32)

    torch.onnx.export(
        encoder,
        (padded_input, input_lengths),
        filename,
        opset_version=opset_version,
        input_names=["padded_input", "input_lengths"],
        output_names=["enc_output", "output_lengths", "src_mask"],
        dynamic_axes={
            "padded_input": {0: "batch_size", 1: "seq_len"},
            "input_lengths": {0: "batch_size"},
            "enc_output": {0: "batch_size", 1: "seq_len_out"},
            "output_lengths": {0: "batch_size"},
            "src_mask": {0: "batch_size", 2: "seq_len_out"},
        },
    )
    logging.info("Exported encoder to %s", filename)
|
|
|
|
|
|
|
|
def get_trt_kwargs_dynamic_batch(
    idim: int,
    min_batch_size: int = 1,
    opt_batch_size: int = 4,
    max_batch_size: int = 64,
    *,
    min_seq_len: int = 50,
    opt_seq_len: int = 400,
    max_seq_len: int = 3000,
) -> dict:
    """Get keyword arguments for TensorRT with dynamic batch size.

    Builds the min/opt/max shapes of a single TensorRT optimization profile
    for the two encoder inputs, ``padded_input`` with shape
    ``(batch, seq_len, idim)`` and ``input_lengths`` with shape ``(batch,)``.

    Args:
        idim: Feature dimension (last axis of ``padded_input``).
        min_batch_size: Smallest batch size the engine must accept.
        opt_batch_size: Batch size TensorRT tunes kernels for.
        max_batch_size: Largest batch size the engine must accept.
        min_seq_len: Smallest number of input frames.
        opt_seq_len: Number of input frames TensorRT tunes kernels for.
        max_seq_len: Largest number of input frames.

    Returns:
        Dict with ``min_shape``, ``opt_shape``, ``max_shape`` (one shape per
        input, in order) and ``input_names``.

    Raises:
        ValueError: If the sizes are not positive or are not ordered
            min <= opt <= max. TensorRT rejects such profiles, so failing
            here reports the problem early and clearly.
    """
    if not 0 < min_batch_size <= opt_batch_size <= max_batch_size:
        raise ValueError(
            "batch sizes must satisfy 0 < min <= opt <= max, got "
            f"{min_batch_size}, {opt_batch_size}, {max_batch_size}"
        )
    if not 0 < min_seq_len <= opt_seq_len <= max_seq_len:
        raise ValueError(
            "sequence lengths must satisfy 0 < min <= opt <= max, got "
            f"{min_seq_len}, {opt_seq_len}, {max_seq_len}"
        )

    return {
        "min_shape": [(min_batch_size, min_seq_len, idim), (min_batch_size,)],
        "opt_shape": [(opt_batch_size, opt_seq_len, idim), (opt_batch_size,)],
        "max_shape": [(max_batch_size, max_seq_len, idim), (max_batch_size,)],
        "input_names": ["padded_input", "input_lengths"],
    }
|
|
|
|
|
|
|
|
def convert_onnx_to_trt(
    trt_model: str, trt_kwargs: dict, onnx_model: str, dtype: torch.dtype = torch.float16
) -> None:
    """Convert an ONNX model to a serialized TensorRT engine.

    Args:
        trt_model: Output path of the serialized engine file.
        trt_kwargs: Optimization-profile description with keys
            ``input_names``, ``min_shape``, ``opt_shape`` and ``max_shape``;
            the i-th entry of each shape list belongs to the i-th input name.
        onnx_model: Path of the ONNX model to parse.
        dtype: When ``torch.float16``, FP16 kernels are enabled in the builder.

    Raises:
        ValueError: If the ONNX model cannot be parsed.
        RuntimeError: If TensorRT fails to build the engine. Previously such
            failures were only logged, and the caller carried on as if the
            engine had been written.
    """
    logging.info("Converting ONNX to TensorRT engine...")

    # Networks are always explicit-batch in TensorRT 10, where this flag is
    # deprecated; look it up defensively so the script keeps working if a
    # later release drops it.
    explicit_batch = getattr(
        trt.NetworkDefinitionCreationFlag, "EXPLICIT_BATCH", None
    )
    network_flags = 0 if explicit_batch is None else 1 << int(explicit_batch)

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)
    config = builder.create_builder_config()

    if dtype == torch.float16:
        config.set_flag(trt.BuilderFlag.FP16)

    with open(onnx_model, "rb") as f:
        onnx_bytes = f.read()
    if not parser.parse(onnx_bytes):
        # Report every parser error before failing so the cause is visible.
        for idx in range(parser.num_errors):
            logging.error("%s", parser.get_error(idx))
        raise ValueError(f'Failed to parse {onnx_model}')

    # A single optimization profile covering every dynamic input.
    profile = builder.create_optimization_profile()
    for i, name in enumerate(trt_kwargs['input_names']):
        profile.set_shape(
            name,
            trt_kwargs['min_shape'][i],
            trt_kwargs['opt_shape'][i],
            trt_kwargs['max_shape'][i],
        )
    config.add_optimization_profile(profile)

    try:
        engine_bytes = builder.build_serialized_network(network, config)
    except Exception as e:
        raise RuntimeError(f"TensorRT engine build failed: {e}") from e
    # build_serialized_network reports most failures by returning None (the
    # details go to the TensorRT logger) rather than by raising; writing None
    # would otherwise fail with an unrelated TypeError.
    if engine_bytes is None:
        raise RuntimeError(
            f"TensorRT engine build failed for {onnx_model}; "
            "see the TensorRT log output above for details"
        )

    with open(trt_model, "wb") as f:
        f.write(engine_bytes)
    logging.info("Successfully converted ONNX to TensorRT.")
|
|
|
|
|
|
|
|
@torch.no_grad()
def main():
    """Export the FireRedASR encoder and build a TensorRT engine from it.

    The command line selects one of two modes:

    * ``--onnx-model-path`` given: that ONNX file is used as-is, and the
      feature dimension comes from ``--idim``.
    * otherwise: ``<model-dir>/model.pth.tar`` is loaded, ``idim`` is read
      from the checkpoint's ``args``, and the encoder is exported to
      ``<tensorrt-model-dir>/encoder.fp16.onnx``.

    In both modes the engine is written to
    ``<tensorrt-model-dir>/<trt-engine-file-name>``. The function runs under
    ``torch.no_grad()``, so no autograd graph is recorded during export.

    Raises:
        ValueError: On missing or invalid command-line arguments.
        FileNotFoundError: If the ONNX model or the checkpoint does not exist.
    """
    args = get_parser().parse_args()

    tensorrt_model_dir = Path(args.tensorrt_model_dir)
    tensorrt_model_dir.mkdir(parents=True, exist_ok=True)

    if args.onnx_model_path:
        logging.info("Using provided ONNX model: %s", args.onnx_model_path)
        # `not args.idim` alone would accept negative values.
        if args.idim is None or args.idim <= 0:
            raise ValueError(
                "--idim must be a positive integer when using "
                f"--onnx-model-path, got {args.idim}"
            )
        idim = args.idim
        encoder_onnx_file = Path(args.onnx_model_path)
        if not encoder_onnx_file.is_file():
            raise FileNotFoundError(f"ONNX model not found at {encoder_onnx_file}")
    else:
        if not args.model_dir:
            raise ValueError(
                "--model-dir is required if --onnx-model-path is not provided"
            )

        logging.info("Exporting ONNX model from PyTorch checkpoint")
        model_path = Path(args.model_dir) / "model.pth.tar"
        # Same explicit existence check as the ONNX branch, instead of an
        # opaque error from inside torch.load.
        if not model_path.is_file():
            raise FileNotFoundError(f"Model checkpoint not found at {model_path}")

        # SECURITY: weights_only=False unpickles arbitrary Python objects
        # (presumably required because package["args"] is a pickled object);
        # only load checkpoints from trusted sources.
        package = torch.load(model_path, map_location="cpu", weights_only=False)
        idim = package["args"].idim
        # Drop the checkpoint before the loader below (presumably) reads it
        # again, so two copies are not held in memory at once.
        del package

        model = load_fireredasr_aed_model(str(model_path))
        encoder = model.encoder
        encoder.eval()

        encoder_onnx_file = tensorrt_model_dir / "encoder.fp16.onnx"
        export_encoder_onnx(
            encoder=encoder,
            filename=str(encoder_onnx_file),
            idim=idim,
            opset_version=args.opset_version,
        )

    trt_engine_file = tensorrt_model_dir / args.trt_engine_file_name
    trt_kwargs = get_trt_kwargs_dynamic_batch(idim=idim)
    convert_onnx_to_trt(
        trt_model=str(trt_engine_file),
        trt_kwargs=trt_kwargs,
        onnx_model=str(encoder_onnx_file),
        dtype=torch.float16,
    )

    logging.info("Done!")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Timestamped, source-located log lines at INFO level for the whole run.
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    main()
|
|
|