---
license: apache-2.0
datasets:
  - sergiopaniego/CarlaFollowLanePreviousV
---

OptimizedPilotNet

This repository contains a baseline PilotNet model for end-to-end control of an autonomous ego vehicle in CARLA, together with optimized versions of it for both PyTorch and TensorFlow.

  • GPU: NVIDIA GeForce RTX 3090
  • CUDA: 11.6
  • Driver version: 510.54
  • Input shape: (200, 66, 3)
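The input shape above can be traced through the network with the valid-convolution formula out = (in - kernel) // stride + 1. The sketch below assumes the five convolutional layers of NVIDIA's original PilotNet (24, 36 and 48 filters with 5x5 kernels and stride 2, then two 64-filter 3x3 layers with stride 1) and reads (200, 66, 3) as width 200, height 66 and 3 channels. The models in this repository may use a different layer configuration, and the helper names are hypothetical.

```python
# Assumed layer configuration: the five convolutions of NVIDIA's original PilotNet
# with 'valid' (no) padding. The repository's model may differ.
PILOTNET_CONVS = [
    # (filters, kernel_size, stride)
    (24, 5, 2),
    (36, 5, 2),
    (48, 5, 2),
    (64, 3, 1),
    (64, 3, 1),
]


def conv_output_size(size, kernel, stride):
    """Output length of an unpadded convolution along one spatial axis."""
    return (size - kernel) // stride + 1


def feature_map_shapes(height=66, width=200, channels=3, convs=PILOTNET_CONVS):
    """Return the (height, width, channels) shape after each convolution."""
    shapes = []
    h, w, c = height, width, channels
    for filters, kernel, stride in convs:
        h = conv_output_size(h, kernel, stride)
        w = conv_output_size(w, kernel, stride)
        c = filters
        shapes.append((h, w, c))
    return shapes


shapes = feature_map_shapes()
for shape in shapes:
    print(shape)
h, w, c = shapes[-1]
print("flattened features:", h * w * c)  # 1 * 18 * 64 = 1152
```

Under these assumptions the last feature map is 1 x 18 x 64, so 1152 values feed the fully connected layers.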

TensorFlow

  • TensorFlow version: 2.7.0 / 2.11.0
  • TensorRT version: 7.2.2.1
  • Docker image: nvcr.io/nvidia/tensorflow:20.12-tf2-py3
  • nvidia-tensorrt: 7.2.2.1

| Optimization | Model size (MB) | MSE | Inference time (s/frame) | Filename |
|---|---|---|---|---|
| Original | 19 | 0.018 | 0.022 | pilotnet.h5 |
| Baseline | 6.0925140380859375 | 0.010881431312199034 | 0.0016004319190979005 | pilotnet_model.tflite |
| Dynamic Range Quantization | 1.5377578735351562 | 0.010803998294344926 | 0.0008851253986358643 | pilotnet_dynamic_quant.tflite |
| Integer Quantization | 1.5389328002929688 | 0.01102226436099348 | 0.0008868560791015625 | pilotnet_int_quant.tflite |
| Integer (float fallback) Quantization | 1.5389175415039062 | 0.0008868560791015625 | 0.0008031470775604248 | pilotnet_intflt_quant.tflite |
| Float16 Quantization | 3.0508956909179688 | 0.010804510797606127 | 0.0013616561889648437 | pilotnet_float16_quant.tflite |
| Quantization Aware Training | 1.5446319580078125 | 0.0115418379596583 | 0.0008456888198852539 | pilotnet_quant_aware.tflite |
| (random sparse) Weight pruning | 6.0925140380859375 | 0.011697137610230973 | 0.0016570956707000733 | pilotnet_pruned.tflite |
| (random sparse) Weight pruning + Quantization | 1.536590576171875 | 0.011635421636510991 | 0.0012711701393127441 | pilotnet_pruned_quan.tflite |
| Cluster preserving Quantization Aware Training (CQAT) | 1.5446319580078125 | 0.010546523951115492 | 0.0008221814632415771 | pilotnet_cqat_model.tflite |
| Pruning preserving Quantization Aware Training (PQAT) | 1.5446319580078125 | 0.010758002372154884 | 0.0008252830505371093 | pilotnet_pqat_model.tflite |
| Sparsity and cluster preserving Quantization Aware Training (PCQAT) | 1.5446319580078125 | 0.008262857163545972 | 0.0008286898136138916 | pilotnet_pcqat_model.tflite |
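The post-training variants in the table correspond to standard `tf.lite.TFLiteConverter` options. The sketch below is a minimal illustration, not the repository's conversion script: the mode names are hypothetical, it assumes a trained Keras `model` and a list of preprocessed float32 calibration frames of shape (1, H, W, 3), and it assumes that "Integer Quantization" means integer-only conversion with uint8 input and output.

```python
MODES = ("baseline", "dynamic_range", "float16", "int_float_fallback", "int")


def representative_dataset(frames, num_samples=100):
    """Wrap calibration frames in the zero-argument generator the converter calls.

    Each frame is assumed to be a float32 array of shape (1, H, W, 3);
    pass a list (not a one-shot iterator) so the generator can be re-run.
    """
    def gen():
        for i, frame in enumerate(frames):
            if i >= num_samples:
                break
            yield [frame]  # one list entry per model input
    return gen


def make_tflite_converter(model, mode, frames=None):
    """Configure a TFLiteConverter for one post-training variant (hypothetical mode names)."""
    if mode not in MODES:
        raise ValueError(f"unknown mode {mode!r}; expected one of {MODES}")
    import tensorflow as tf  # lazy import: representative_dataset works without TensorFlow

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    if mode == "baseline":
        return converter  # plain float32 conversion
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    if mode == "dynamic_range":
        return converter  # int8 weights; activations quantized on the fly where supported
    if mode == "float16":
        converter.target_spec.supported_types = [tf.float16]
        return converter
    if frames is None:
        raise ValueError(f"mode {mode!r} needs calibration frames")
    converter.representative_dataset = representative_dataset(frames)
    if mode == "int":
        # Assumed meaning of "Integer Quantization": integer-only kernels and uint8 I/O.
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.uint8
    # "int_float_fallback": ops without int8 kernels keep their float implementations.
    return converter
```

For example, `make_tflite_converter(model, "float16").convert()` returns the serialized model as bytes, ready to be written to a `.tflite` file. Storing int8 weights takes 1 byte per value instead of 4, which is consistent with the roughly 4x drop from 6.09 MB to about 1.54 MB in the table; float16 weights roughly halve it to about 3.05 MB. The training-time variants (QAT, pruning, clustering) rely on the TensorFlow Model Optimization Toolkit and are not covered by this sketch.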

TensorRT-TensorFlow:

To run inference, install TensorRT, check that it can create a builder, add its libraries to LD_LIBRARY_PATH and confirm that TensorFlow sees the GPU:

  pip install nvidia-tensorrt==7.2.2.1
  python3 -c "import tensorrt; print(tensorrt.__version__); assert tensorrt.Builder(tensorrt.Logger())"
  export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/python3.8/site-packages/tensorrt
  python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
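Once TensorRT is visible to TensorFlow, folders like the ones in the table below can be produced with TF-TRT's `TrtGraphConverterV2`. This is a hedged sketch rather than the authors' script: it assumes the Keras model was first exported as a SavedModel (the input path in the usage note is hypothetical) and that calibration frames are float32 arrays of shape (1, H, W, 3). The `conversion_params` form matches older TF 2.x releases; newer releases also accept `precision_mode` directly and may warn that `conversion_params` is deprecated.

```python
def calibration_input_fn(frames, num_batches=50):
    """Return the generator function TF-TRT calls during INT8 calibration.

    Each frame is assumed to be a float32 array of shape (1, H, W, 3).
    """
    def gen():
        for i, frame in enumerate(frames):
            if i >= num_batches:
                break
            yield (frame,)  # a tuple with one element per model input
    return gen


def convert_with_tftrt(saved_model_dir, output_dir, precision="FP16", frames=None):
    """Convert a SavedModel with TF-TRT; precision is "FP32", "FP16" or "INT8"."""
    # Lazy import: requires a TensorFlow build with TensorRT support.
    from tensorflow.python.compiler.tensorrt import trt_convert as trt

    params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
        precision_mode=getattr(trt.TrtPrecisionMode, precision)
    )
    converter = trt.TrtGraphConverterV2(
        input_saved_model_dir=saved_model_dir, conversion_params=params
    )
    if precision == "INT8":
        if frames is None:
            raise ValueError("INT8 conversion needs calibration frames")
        converter.convert(calibration_input_fn=calibration_input_fn(frames))
    else:
        converter.convert()
    converter.save(output_dir)
```

For example, `convert_with_tftrt("pilotnet_savedmodel", "pilotnet_tftrt_fp16", precision="FP16")`, where `pilotnet_savedmodel` is a hypothetical export of the Keras model.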
| Optimization | Model size (MB) | MSE | Inference time (s/frame) | Folder |
|---|---|---|---|---|
| Float32 Quantization | 0.00390625 | 0.010798301750717706 | 0.00038761067390441896 | pilotnet_tftrt_fp32 |
| Float16 Quantization | 0.00390625 | 0.010798278900279191 | 0.00042218327522277834 | pilotnet_tftrt_fp16 |
| Int8 Quantization | 0.00390625 | 0.04791482252948612 | 0.0003384373188018799 | pilotnet_tftrt_int8 |
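The README does not show how the MSE, latency and size columns were measured. One plausible, framework-agnostic way to compute them is sketched below; `predict` stands for any callable that runs one preprocessed frame through a model (a TFLite interpreter, a TF-TRT signature or a TorchScript module), and all function names are hypothetical. For SavedModel folders such as pilotnet_tftrt_fp32, the size has to be summed over the files inside, because `os.path.getsize` on a directory returns only the size of the directory entry itself (commonly 4096 bytes, i.e. 0.00390625 MB, on ext4).

```python
import os
import time


def model_size_mb(path):
    """Size on disk in MB (2**20 bytes); sums all files when `path` is a folder."""
    if not os.path.isdir(path):
        return os.path.getsize(path) / 2**20
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total / 2**20


def mse(predictions, targets):
    """Mean squared error between two equally long sequences of numbers."""
    if len(predictions) != len(targets) or not predictions:
        raise ValueError("need two non-empty sequences of equal length")
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(predictions)


def seconds_per_frame(predict, frames, warmup=10):
    """Average wall-clock seconds per call of `predict`, after untimed warm-up calls.

    Returns (latency, outputs) so the outputs can be reused for the MSE.
    """
    frames = list(frames)
    if not frames:
        raise ValueError("need at least one frame")
    for frame in frames[:warmup]:
        predict(frame)  # warm-up: lets lazy initialisation and engine builds finish
    start = time.perf_counter()
    outputs = [predict(frame) for frame in frames]
    return (time.perf_counter() - start) / len(frames), outputs
```

For GPU back ends, `predict` should return results copied back to the host, so that the timer measures completed work rather than queued kernels.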

PyTorch

  • PyTorch version: 1.13.1+cu116
  • TensorRT version: 8.5.5
  • Docker image: nvcr.io/nvidia/pytorch:22.12-py3
  • torch-tensorrt: 1.3.0
| Optimization | Model size (MB) | MSE | Inference time (s/frame) | Filename |
|---|---|---|---|---|
| Original | 6.1217 | 0.03524 | - | pilotnet_model.pth |
| Dynamic Range Quantization | 1.9493608474731445 | 0.012065857842182075 | 0.001480283498764038 | 24_05_dynamic_quan.pth |
| Static Quantization | 1.6071176528930664 | 0.012072610909984047 | 0.0007314345836639404 | 24_05_static_quan.pth |
| Quantization Aware Training | 1.6069536209106445 | 0.01109830549109022 | 0.0011710402965545653 | 24_05_quan_aware.pth |
| Local Prune | 6.122584342956543 | 0.010850968803449539 | 0.0014387350082397461 | 24_05_local_prune.pth |
| Global Prune | 6.122775077819824 | 0.010964057565769462 | 0.0014179635047912597 | 24_05_global_prune.pth |
| Prune + Quantization | 1.6067094802856445 | 0.010949893930274941 | 0.0011728739738464356 | 24_05_prune_quan.pth |
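Local and global pruning differ only in where the magnitude threshold is applied. The conceptual sketch below works on plain Python lists and assumes unstructured L1 (magnitude) pruning, which the README does not confirm; the helper names are hypothetical. In PyTorch the same two behaviours are provided by `torch.nn.utils.prune.l1_unstructured` (per layer) and `torch.nn.utils.prune.global_unstructured` with `pruning_method=prune.L1Unstructured` (across layers).

```python
def local_prune(layers, amount):
    """Zero the `amount` fraction of smallest-magnitude weights in *each* layer."""
    pruned = []
    for weights in layers:
        k = int(round(amount * len(weights)))
        by_magnitude = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
        smallest = set(by_magnitude[:k])
        pruned.append([0.0 if i in smallest else w for i, w in enumerate(weights)])
    return pruned


def global_prune(layers, amount):
    """Zero the `amount` fraction of smallest-magnitude weights across *all* layers."""
    flat = [(abs(w), li, wi) for li, weights in enumerate(layers) for wi, w in enumerate(weights)]
    k = int(round(amount * len(flat)))
    drop = {(li, wi) for _, li, wi in sorted(flat)[:k]}
    return [
        [0.0 if (li, wi) in drop else w for wi, w in enumerate(weights)]
        for li, weights in enumerate(layers)
    ]


layers = [[0.1, -0.2, 0.3, 0.4], [2.0, -3.0, 1.5, 0.05]]
print(local_prune(layers, 0.5))   # [[0.0, 0.0, 0.3, 0.4], [2.0, -3.0, 0.0, 0.0]]
print(global_prune(layers, 0.5))  # [[0.0, 0.0, 0.0, 0.4], [2.0, -3.0, 1.5, 0.0]]
```

Global pruning removes more from the layer with small weights and keeps the large ones elsewhere. Because pruned weights are zeroed rather than removed, a dense checkpoint keeps its size, which is consistent with the pruned models above staying at about 6.12 MB while the quantized ones shrink.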

TensorRT-PyTorch:

To run inference, install Torch-TensorRT:

  pip install torch-tensorrt==1.3.0
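With Torch-TensorRT installed, TorchScript modules like the ones in the table below can be produced with `torch_tensorrt.compile`. The sketch assumes `model` is the trained PilotNet `nn.Module`, that it is scriptable, and that it takes NCHW float input built from the README's (200, 66, 3) shape read as width, height and channels; `compile_and_save` and `nchw_input_shape` are hypothetical names. INT8 additionally needs a calibrator and is not shown.

```python
def nchw_input_shape(width=200, height=66, channels=3, batch=1):
    """PyTorch NCHW shape for the README's (width, height, channels) = (200, 66, 3)."""
    return (batch, channels, height, width)


def compile_and_save(model, path, half=False):
    """Compile `model` with Torch-TensorRT and save the result as TorchScript."""
    import torch  # lazy imports: nchw_input_shape works without them
    import torch_tensorrt

    model = model.eval().cuda()
    trt_module = torch_tensorrt.compile(
        model,
        inputs=[torch_tensorrt.Input(shape=nchw_input_shape(), dtype=torch.float)],
        enabled_precisions={torch.half} if half else {torch.float},
    )
    torch.jit.save(trt_module, path)
    return trt_module
```

For example, `compile_and_save(model, "trt_mod_float16.jit.pt", half=True)`. To load a saved module in a later session, import `torch_tensorrt` before calling `torch.jit.load` so that the TensorRT runtime ops are registered.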
| Optimization | Model size (MB) | MSE | Inference time (s/frame) | Filename |
|---|---|---|---|---|
| Float32 Quantization | 6.121363639831543 | 0.009570527376262128 | 0.0002284455299377441 | trt_mod_float32.jit.pt |
| Float16 Quantization | 6.121363639831543 | 0.009571507916721152 | 0.000250823974609375 | trt_mod_float16.jit.pt |
| Int8 Quantization | 6.181861877441406 | 0.00969304293365875 | 0.0002463934421539307 | trt_mod_int8.jit.pt |
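The integer rows across these tables store values as 8-bit integers plus a scale. TFLite's int8 specification uses the affine mapping real_value = (q - zero_point) * scale, while TensorRT uses the symmetric special case zero_point = 0. The pure-Python sketch below illustrates that mapping on single values; picking the range from a plain min/max is a simplifying assumption (calibrators may choose ranges differently), and the helper names are hypothetical.

```python
def affine_params(x_min, x_max, q_min=-128, q_max=127):
    """Scale and zero point that map [x_min, x_max] onto the int8 range."""
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)  # the range must contain 0.0
    if x_max == x_min:
        return 1.0, 0  # degenerate all-zero range
    scale = (x_max - x_min) / (q_max - q_min)
    zero_point = int(q_min - round(x_min / scale))
    return scale, zero_point


def quantize(x, scale, zero_point, q_min=-128, q_max=127):
    """Round to the nearest integer step and clamp to the int8 range."""
    q = round(x / scale) + zero_point
    return max(q_min, min(q_max, q))


def dequantize(q, scale, zero_point):
    """Map an int8 value back to the real line."""
    return (q - zero_point) * scale


scale, zero_point = affine_params(-1.0, 1.0)
q = quantize(0.5, scale, zero_point)
print(scale, zero_point, q, dequantize(q, scale, zero_point))
```

Inside the calibrated range the round-trip error is at most scale / 2; values outside it are clamped, which is one reason calibration data matters for the Int8 rows.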