tensorrt-tf / main.py
sayakpaul's picture
sayakpaul HF staff
chore: shortened URL.
1c156e9
from typing import List
import gradio as gr
import tensorflow as tf
from huggingface_hub import HfApi, create_repo
from utils import convert_to_trt
# Markdown shown at the top of the Gradio interface (passed as `description=`).
# Paragraphs, lists and headings are separated by blank lines so that the
# trailing "As a consequence..." sentence is not folded into the last bullet.
DESCRIPTION = """
This Space shows how to easily optimize a [ResNet50 model from Keras](https://keras.io/api/applications/) with [TensorRT](https://developer.nvidia.com/tensorrt). TensorRT is a framework to optimize deep learning models specifically for NVIDIA hardware.

This Space does the following things:

* Loads a ResNet50 model from `tf.keras.applications` and serializes it as a SavedModel.
* Performs optimizations with TensorRT.
* Displays precomputed benchmarks comparing the throughputs of the native TensorFlow SavedModel and its TensorRT-optimized variant.
* Pushes the optimized model to a repository on the Hugging Face Hub. For this to work, one must provide a write-access token (from hf.co/settings/tokens) to `your_hf_token`.

As a consequence, you might have to wait a few minutes to see the results.

## Notes (important)

* For this Space to work, having access to a GPU (at least T4) is a must.
* This Space makes use of the [Docker x Space integration](https://huggingface.co/docs/hub/spaces-sdks-docker) to perform the TensorRT optimizations.
* The default TensorFlow installation doesn't come loaded with a correctly compiled TensorRT. This is why it's recommended to use an [NVIDIA container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tensorflow) to perform your TensorRT-related work. This is also why the Docker x Space integration was used in this Space.
* To get the maximum performance, one must use the same hardware for inference as the one used for running the optimizations. For example, if you used a T4-based machine to perform the optimizations, ensure that you're using the same GPU while running inference with your optimized model.
* One can use this Space to optimize the other models provided in [tf.keras.applications](https://keras.io/api/applications/).
* One is encouraged to try out different forms of post-training quantization as shown in [this notebook](https://github.com/tensorflow/tensorrt/blob/master/tftrt/benchmarking-python/image_classification/NGC-TFv2-TF-TRT-inference-from-Keras-saved-model.ipynb) to squeeze out the maximum performance using NVIDIA hardware and TensorRT.
"""
# Build ResNet50 with ImageNet weights once, at import time. `run` saves and
# converts this same module-level instance on every request rather than
# constructing a new model per call.
print("Loading ResNet50 model.")
model = tf.keras.applications.ResNet50(
    weights="imagenet",
)
def push_to_hub(hf_token: str, push_dir: str) -> str:
    """Upload ``push_dir`` to a Hub repository owned by the token's user.

    The repository id is ``"<user>/<push_dir>"``. It is created if missing and
    reused if it already exists, so the Space can be run more than once with
    the same token.

    Args:
        hf_token: Hugging Face access token with write access. An empty,
            whitespace-only, or ``None`` value skips the push entirely, so no
            ambient credential is ever used by accident.
        push_dir: Local folder to upload; also used as the repository name.

    Returns:
        A Markdown status message. Any error raised while talking to the Hub
        is returned as its message text instead of being raised, so it shows
        up in the Gradio output.
    """
    # Tolerate None and stray whitespace from the text box.
    token = (hf_token or "").strip()
    if not token:
        return "No HF token provided. Model won't be pushed."
    try:
        hf_api = HfApi(token=token)
        user = hf_api.whoami()["name"]
        repo_id = f"{user}/{push_dir}"
        # exist_ok=True: without it every run after the first one fails
        # because the repository already exists, and nothing gets pushed.
        create_repo(repo_id=repo_id, token=token, exist_ok=True)
        url = hf_api.upload_folder(folder_path=push_dir, repo_id=repo_id)
    except Exception as e:  # UI boundary: report any Hub failure as text.
        return f"{e}"
    return f"Model successfully pushed: [{url}]({url})"
def post_optimization(list_of_strs: List[str]) -> str:
    """Render the benchmark section of the report as Markdown.

    Args:
        list_of_strs: Exactly two entries, the native TensorFlow throughput
            text followed by the TensorRT-optimized throughput text.

    Returns:
        A Markdown string that starts with a newline. It contains one heading
        per throughput, a fixed description of the benchmarking environment,
        and ends with a "### (TensorRT) model push" heading followed by a
        newline; `run` appends the push status beneath that heading.
    """
    native, optimized = list_of_strs
    report_lines = [
        "",
        "### TensorFlow",
        f"{native}",
        "### TensorRT-optimized",
        f"{optimized}",
        "### Benchmarking information",
        "* OS: Ubuntu 20.04.5",
        "* Python: 3.8.10",
        "* CUDA: 11.8",
        "* TensorFlow: 2.10.1",
        "* TensorRT: 8.5.1",
        "* GPU: T4",
        "* Benchmarking Script: [Link](https://bit.ly/3vbhnD6)",
        "### (TensorRT) model push",
        "",
    ]
    return "\n".join(report_lines)
def run(hf_token: str) -> str:
    """Gradio callback: save the ResNet50, convert it, and build the report.

    Saves the module-level ``model`` as a SavedModel in
    ``resnet50_saved_model``, converts it with ``convert_to_trt`` into
    ``trt_resnet50_keras`` (both paths relative to the working directory),
    and hands the converted folder to ``push_to_hub``.

    Args:
        hf_token: Value typed into the ``your_hf_token`` box, forwarded to
            ``push_to_hub``.

    Returns:
        The Markdown benchmark section followed by the push status message.
    """
    saved_model_dir = "resnet50_saved_model"
    trt_model_dir = "trt_resnet50_keras"

    print("Serializing the ResNet50 as a SavedModel.")
    model.save(saved_model_dir)

    print("Converting to TensorRT.")
    convert_to_trt(saved_model_dir, trt_model_dir)

    # NOTE: these throughputs are fixed strings, not measured during this
    # call -- presumably obtained offline with the benchmarking script that
    # the report links to (confirm before relying on them).
    throughputs = ["Throughput: 89 images/s", "Throughput: 497 images/s"]
    report = post_optimization(throughputs)
    return report + push_to_hub(hf_token, trt_model_dir)
def launch_gradio():
    """Build the single-input Gradio interface around ``run`` and serve it.

    The interface takes one text box (the HF token) and renders the returned
    Markdown report. The server listens on 0.0.0.0:7860, as passed to
    ``launch``.
    """
    token_box = gr.Text(max_lines=1, label="your_hf_token")
    report_view = gr.Markdown(label="output")
    interface_kwargs = dict(
        fn=run,
        inputs=[token_box],
        outputs=[report_view],
        title="Optimize a ResNet50 model from Keras with TensorRT",
        description=DESCRIPTION,
        allow_flagging="never",
    )
    demo = gr.Interface(**interface_kwargs)
    demo.launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
launch_gradio()