# tensorrt-tf / utils.py
# Author: sayakpaul (HF staff)
# Commit bf49c9f: "fix: remove benchmarking and kept it simple."
from tensorflow.python.compiler.tensorrt import trt_convert as trt
def convert_to_trt(
    input_model_path: str,
    trt_model_path: str,
    *,
    precision_mode: str = "FP32",
    max_workspace_size_bytes: int = 8_000_000_000,
    calibration_input_fn=None,
) -> None:
    """Convert a SavedModel to a TF-TRT optimized graph and save it to disk.

    With only the two positional arguments this behaves exactly as before:
    FP32 precision, an 8e9-byte workspace limit, and no calibration.

    Args:
        input_model_path: Path to the SavedModel to optimize.
        trt_model_path: Path to save the converted TensorRT graph.
        precision_mode: TensorRT precision to convert to. This is one of the
            ``trt.TrtPrecisionMode`` values ("FP32", "FP16", "INT8").
            Defaults to "FP32", which equals ``trt.TrtPrecisionMode.FP32``.
        max_workspace_size_bytes: Workspace size limit in bytes, forwarded
            to ``TrtGraphConverterV2``. Defaults to 8e9 bytes (decimal
            ~8 GB, not 8 GiB).
        calibration_input_fn: Optional callable forwarded to
            ``converter.convert()``. Presumably needed only for INT8
            conversion with calibration, and TF-TRT may reject it for other
            precisions. Confirm against the TF-TRT docs.

    Returns:
        None. The converted model is written to ``trt_model_path`` and a
        completion message is printed to stdout.
    """
    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir=input_model_path,
        precision_mode=precision_mode,
        max_workspace_size_bytes=max_workspace_size_bytes,
    )
    # Passing None here is identical to calling convert() with no arguments.
    converter.convert(calibration_input_fn=calibration_input_fn)
    converter.save(output_saved_model_dir=trt_model_path)
    # Upper-case because lower-case precision names are also accepted; with
    # the default this prints exactly "Done Converting to TF-TRT FP32".
    print(f"Done Converting to TF-TRT {str(precision_mode).upper()}")