MobileVit-Small First commit
Browse files- .gitattributes +1 -0
- README.md +104 -20
- embedl_mobilevit_small_int8.onnx +3 -0
- embedl_mobilevit_small_int8.pt2 +3 -0
- infer_pt2.py +65 -0
- infer_trt.py +139 -0
.gitattributes
CHANGED
|
@@ -20,6 +20,7 @@
|
|
| 20 |
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 20 |
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pt2 filter=lfs diff=lfs merge=lfs -text
|
| 24 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 25 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 26 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -3,30 +3,114 @@ license: other
|
|
| 3 |
license_name: embedl-models-community-licence-1.0
|
| 4 |
license_link: https://github.com/embedl/embedl-models/blob/main/LICENSE
|
| 5 |
base_model:
|
| 6 |
-
-
|
| 7 |
quantized_from:
|
| 8 |
-
-
|
| 9 |
tags:
|
| 10 |
-
-
|
| 11 |
-
-
|
| 12 |
-
-
|
| 13 |
-
-
|
| 14 |
-
-
|
| 15 |
-
-
|
| 16 |
-
- edge
|
| 17 |
-
- embedl
|
| 18 |
gated: true
|
| 19 |
-
extra_gated_heading: "Access Embedl
|
| 20 |
-
extra_gated_description:
|
| 21 |
-
To access this model, please review and accept the terms below.
|
| 22 |
-
Your contact information is collected solely to manage access and,
|
| 23 |
-
with your explicit consent, to notify you about updated or new
|
| 24 |
-
optimized models from Embedl. You can withdraw consent at any time
|
| 25 |
-
by contacting us (see Contact section below). See our license for full terms.
|
| 26 |
extra_gated_button_content: "Agree and request access"
|
| 27 |
-
extra_gated_prompt: "By requesting access you agree to the Embedl Models Community Licence and the upstream
|
| 28 |
extra_gated_fields:
|
| 29 |
Company: text
|
| 30 |
-
I agree to the Embedl Models Community Licence and upstream
|
| 31 |
I consent to being contacted by Embedl about products and services (optional): checkbox
|
| 32 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
license_name: embedl-models-community-licence-1.0
|
| 4 |
license_link: https://github.com/embedl/embedl-models/blob/main/LICENSE
|
| 5 |
base_model:
|
| 6 |
+
- apple/mobilevit-small
|
| 7 |
quantized_from:
|
| 8 |
+
- apple/mobilevit-small
|
| 9 |
tags:
|
| 10 |
+
- image-classification
|
| 11 |
+
- quantization
|
| 12 |
+
- onnx
|
| 13 |
+
- tensorrt
|
| 14 |
+
- edge
|
| 15 |
+
- embedl
|
|
|
|
|
|
|
| 16 |
gated: true
|
| 17 |
+
extra_gated_heading: "Access Embedl Mobilevit Small"
|
| 18 |
+
extra_gated_description: "To access this model, please review and accept the terms below. Your contact information is collected solely to manage access and, with your explicit consent, to notify you about updated or new optimized models from Embedl."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
extra_gated_button_content: "Agree and request access"
|
| 20 |
+
extra_gated_prompt: "By requesting access you agree to the Embedl Models Community Licence and the upstream Mobilevit Small License"
|
| 21 |
extra_gated_fields:
|
| 22 |
Company: text
|
| 23 |
+
I agree to the Embedl Models Community Licence and upstream Mobilevit Small License: checkbox
|
| 24 |
I consent to being contacted by Embedl about products and services (optional): checkbox
|
| 25 |
+
---
|
| 26 |
+
|
| 27 |
+
# Embedl Mobilevit Small (Quantized for TensorRT)
|
| 28 |
+
|
| 29 |
+
Deployable INT8-quantized version of [`apple/mobilevit-small`](https://huggingface.co/apple/mobilevit-small),
|
| 30 |
+
optimized with [embedl-deploy](https://github.com/embedl/embedl-deploy)
|
| 31 |
+
for low-latency NVIDIA TensorRT inference on edge GPUs.
|
| 32 |
+
|
| 33 |
+
## Highlights
|
| 34 |
+
|
| 35 |
+
- **Mixed-precision INT8/FP16 quantization** with hardware-aware
|
| 36 |
+
optimizations from [embedl-deploy](https://github.com/embedl/embedl-deploy).
|
| 37 |
+
- **Drop-in replacement** for `apple/mobilevit-small` in TensorRT pipelines —
|
| 38 |
+
same input shape (256×256), same output
|
| 39 |
+
semantics.
|
| 40 |
+
- **Validated accuracy** within 3.30 pp of the FP32
|
| 41 |
+
baseline on ImageNet (see Accuracy table below).
|
| 42 |
+
- **Faster than `trtexec --best`** on supported NVIDIA hardware
|
| 43 |
+
(see Performance table below).
|
| 44 |
+
- Includes both **ONNX** (for TensorRT) and **PT2**
|
| 45 |
+
(`torch.export`-loadable) artifacts plus runnable inference scripts.
|
| 46 |
+
|
| 47 |
+
## Quick Start
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
pip install huggingface_hub onnxruntime-gpu pillow numpy
|
| 51 |
+
python -c "from huggingface_hub import snapshot_download; snapshot_download('embedl/mobilevit-small-quantized', local_dir='.')"
|
| 52 |
+
python infer_trt.py --image path/to/image.jpg # TensorRT
|
| 53 |
+
# or
|
| 54 |
+
python infer_pt2.py --image path/to/image.jpg # pure PyTorch via torch.export
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
## Files
|
| 58 |
+
|
| 59 |
+
| File | Purpose |
|
| 60 |
+
|---|---|
|
| 61 |
+
| `embedl_mobilevit_small_int8.onnx` | INT8-quantized ONNX with Q/DQ nodes — feed to TensorRT. |
|
| 62 |
+
| `embedl_mobilevit_small_int8.pt2` | INT8-quantized `torch.export` ExportedProgram. |
|
| 63 |
+
| `infer_trt.py` | Build a TRT engine from the ONNX and run sample inference. |
|
| 64 |
+
| `infer_pt2.py` | Load the `.pt2` with `torch.export.load` and run sample inference. |
|
| 65 |
+
| `latency_comparison.png` | Latency comparison across precisions and devices. |
|
| 66 |
+
|
| 67 |
+
## Performance
|
| 68 |
+
|
| 69 |
+
Latency measured with TensorRT + `trtexec`, GPU compute time only
|
| 70 |
+
(`--noDataTransfers`), CUDA Graph + Spin Wait enabled, clocks locked
|
| 71 |
+
(`nvpmodel -m 0 && jetson_clocks` on Jetson). See
|
| 72 |
+
`latency_comparison.png` for a visual summary.
|
| 73 |
+
|
| 74 |
+

|
| 75 |
+
|
| 76 |
+
### NVIDIA Jetson AGX Orin
|
| 77 |
+
|
| 78 |
+
| Configuration | Mean Latency | Speedup vs FP16 |
|
| 79 |
+
|---|---|---|
|
| 80 |
+
| TensorRT FP16 | 1.28 ms | 1.00x |
|
| 81 |
+
| TensorRT --best (unconstrained) | 1.09 ms | 1.17x |
|
| 82 |
+
| **Embedl Deploy INT8** | **1.09 ms** | **1.17x** |
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
## Accuracy
|
| 86 |
+
|
| 87 |
+
Evaluated on the ImageNet validation split. The quantized model
|
| 88 |
+
retains nearly all of the FP32 accuracy with a small tolerance.
|
| 89 |
+
|
| 90 |
+
| Model | Top-1 | Top-5 |
|
| 91 |
+
|---|---|---|
|
| 92 |
+
| `apple/mobilevit-small` FP32 (ours) | 78.14% | 94.08% |
|
| 93 |
+
| **Embedl Mobilevit Small INT8** | **74.83%** | **92.28%** |
|
| 94 |
+
|
| 95 |
+
## Creating Your Own Optimized Models
|
| 96 |
+
|
| 97 |
+
This artifact was produced with
|
| 98 |
+
[embedl-deploy](https://github.com/embedl/embedl-deploy),
|
| 99 |
+
Embedl's open-source PyTorch → TensorRT deployment library. You can
|
| 100 |
+
apply the same workflow to your own models — see
|
| 101 |
+
[the documentation](https://github.com/embedl/embedl-deploy#readme)
|
| 102 |
+
for installation and usage.
|
| 103 |
+
|
| 104 |
+
## License
|
| 105 |
+
|
| 106 |
+
| Component | License |
|
| 107 |
+
|---|---|
|
| 108 |
+
| Optimized model artifacts (this repo) | [Embedl Models Community Licence v1.0](https://github.com/embedl/embedl-models/blob/main/LICENSE) — no redistribution as a hosted service |
|
| 109 |
+
| Upstream architecture and weights | [Mobilevit Small License](https://huggingface.co/apple/mobilevit-small) |
|
| 110 |
+
|
| 111 |
+
## Contact
|
| 112 |
+
|
| 113 |
+
We offer engineering support for on-prem/edge deployments and partner
|
| 114 |
+
co-marketing opportunities. Reach out at
|
| 115 |
+
[contact@embedl.com](mailto:contact@embedl.com), or open an issue on
|
| 116 |
+
[GitHub](https://github.com/embedl/embedl-deploy).
|
embedl_mobilevit_small_int8.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:320b83fa490c5ec19a549ff4bfcb4d7413fb9100b3c2e485cd3e016a81389bf1
|
| 3 |
+
size 22518832
|
embedl_mobilevit_small_int8.pt2
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f674b91e7aac2c7ad0a84516108285dd079d50f670aec553b9887ad573e4fd04
|
| 3 |
+
size 46826712
|
infer_pt2.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2026 Embedl AB
|
| 2 |
+
"""Run inference on the Embedl Mobilevit Small INT8 model via torch.export.
|
| 3 |
+
|
| 4 |
+
This script loads the shipped ``embedl_mobilevit_small_int8.pt2``
|
| 5 |
+
artifact with ``torch.export.load`` and runs a single image through
|
| 6 |
+
it. No TensorRT or ONNX runtime is required — just PyTorch.
|
| 7 |
+
|
| 8 |
+
Usage::
|
| 9 |
+
|
| 10 |
+
python infer_pt2.py --image path/to/image.jpg
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
from PIL import Image
|
| 19 |
+
|
| 20 |
+
PT2_PATH = Path(__file__).with_name("embedl_mobilevit_small_int8.pt2")
|
| 21 |
+
INPUT_SIZE = (256, 256)
|
| 22 |
+
MEAN = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
| 23 |
+
STD = np.array([1.0, 1.0, 1.0], dtype=np.float32)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def preprocess(image_path: Path) -> torch.Tensor:
|
| 27 |
+
# MobileViT-Small uses BGR channel order, [0, 1] range, NO mean/std
|
| 28 |
+
# normalization (matches the upstream HF processor: do_normalize=None).
|
| 29 |
+
image = Image.open(image_path).convert("RGB").resize(INPUT_SIZE)
|
| 30 |
+
arr = np.asarray(image, dtype=np.float32) / 255.0
|
| 31 |
+
arr = (arr - MEAN) / STD
|
| 32 |
+
arr = arr[..., ::-1].copy() # RGB -> BGR
|
| 33 |
+
arr = arr.transpose(2, 0, 1)[None] # NCHW
|
| 34 |
+
return torch.from_numpy(arr)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main() -> None:
|
| 38 |
+
parser = argparse.ArgumentParser(description=__doc__)
|
| 39 |
+
parser.add_argument("--image", required=True, type=Path)
|
| 40 |
+
parser.add_argument("--topk", type=int, default=5)
|
| 41 |
+
args = parser.parse_args()
|
| 42 |
+
|
| 43 |
+
if not PT2_PATH.exists():
|
| 44 |
+
raise SystemExit(
|
| 45 |
+
f"Expected {PT2_PATH.name} next to this script. "
|
| 46 |
+
"Did you `huggingface-cli download` the repo?"
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# The ExportedProgram captured the model in eval mode at export
|
| 50 |
+
# time, so no further .eval() / no_grad toggling is needed (and
|
| 51 |
+
# neither is supported on the .module() wrapper).
|
| 52 |
+
model = torch.export.load(str(PT2_PATH)).module()
|
| 53 |
+
|
| 54 |
+
x = preprocess(args.image)
|
| 55 |
+
logits = model(x)
|
| 56 |
+
probs = torch.softmax(logits, dim=-1).squeeze(0)
|
| 57 |
+
|
| 58 |
+
topk_vals, topk_idx = probs.topk(args.topk)
|
| 59 |
+
print(f"Top-{args.topk} predictions for {args.image}:")
|
| 60 |
+
for rank, (idx, val) in enumerate(zip(topk_idx.tolist(), topk_vals.tolist()), 1):
|
| 61 |
+
print(f" {rank}. class {idx:>5d} ({val * 100:5.2f}%)")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if __name__ == "__main__":
|
| 65 |
+
main()
|
infer_trt.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2026 Embedl AB
|
| 2 |
+
"""Run inference on the Embedl Mobilevit Small INT8 model via TensorRT.
|
| 3 |
+
|
| 4 |
+
This script builds a TensorRT engine from the shipped
|
| 5 |
+
``embedl_mobilevit_small_int8.onnx`` artifact (Q/DQ nodes baked in by
|
| 6 |
+
embedl-deploy) and runs a single image through it. The first run
|
| 7 |
+
caches the engine to ``embedl_mobilevit_small_int8.engine`` so reuse is
|
| 8 |
+
fast.
|
| 9 |
+
|
| 10 |
+
Requires TensorRT >= 10.1 and pycuda (or cuda-python). Tested on
|
| 11 |
+
NVIDIA Jetson AGX Orin (JetPack 6) and discrete GPUs with CUDA 12.
|
| 12 |
+
|
| 13 |
+
Usage::
|
| 14 |
+
|
| 15 |
+
python infer_trt.py --image path/to/image.jpg
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import time
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import tensorrt as trt
|
| 24 |
+
from PIL import Image
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
import pycuda.autoinit # noqa: F401 (initializes CUDA context)
|
| 28 |
+
import pycuda.driver as cuda
|
| 29 |
+
except ImportError as exc: # pragma: no cover
|
| 30 |
+
raise SystemExit(
|
| 31 |
+
"pycuda is required. Install with: pip install pycuda"
|
| 32 |
+
) from exc
|
| 33 |
+
|
| 34 |
+
ONNX_PATH = Path(__file__).with_name("embedl_mobilevit_small_int8.onnx")
|
| 35 |
+
ENGINE_PATH = Path(__file__).with_name("embedl_mobilevit_small_int8.engine")
|
| 36 |
+
INPUT_SIZE = (256, 256)
|
| 37 |
+
MEAN = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
| 38 |
+
STD = np.array([1.0, 1.0, 1.0], dtype=np.float32)
|
| 39 |
+
|
| 40 |
+
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def build_engine() -> bytes:
|
| 44 |
+
builder = trt.Builder(TRT_LOGGER)
|
| 45 |
+
network = builder.create_network(
|
| 46 |
+
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
| 47 |
+
)
|
| 48 |
+
parser = trt.OnnxParser(network, TRT_LOGGER)
|
| 49 |
+
with open(ONNX_PATH, "rb") as f:
|
| 50 |
+
if not parser.parse(f.read()):
|
| 51 |
+
for i in range(parser.num_errors):
|
| 52 |
+
print(parser.get_error(i))
|
| 53 |
+
raise RuntimeError("ONNX parse failed.")
|
| 54 |
+
config = builder.create_builder_config()
|
| 55 |
+
config.set_flag(trt.BuilderFlag.FP16)
|
| 56 |
+
config.set_flag(trt.BuilderFlag.INT8)
|
| 57 |
+
config.builder_optimization_level = 5
|
| 58 |
+
serialized = builder.build_serialized_network(network, config)
|
| 59 |
+
if serialized is None:
|
| 60 |
+
raise RuntimeError("Engine build failed.")
|
| 61 |
+
return bytes(serialized)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def load_or_build_engine() -> trt.ICudaEngine:
|
| 65 |
+
if ENGINE_PATH.exists():
|
| 66 |
+
data = ENGINE_PATH.read_bytes()
|
| 67 |
+
else:
|
| 68 |
+
print(f"Building engine (first run) → {ENGINE_PATH.name} …")
|
| 69 |
+
data = build_engine()
|
| 70 |
+
ENGINE_PATH.write_bytes(data)
|
| 71 |
+
runtime = trt.Runtime(TRT_LOGGER)
|
| 72 |
+
return runtime.deserialize_cuda_engine(data)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def preprocess(image_path: Path) -> np.ndarray:
|
| 76 |
+
# MobileViT-Small uses BGR channel order, [0, 1] range, NO mean/std
|
| 77 |
+
# normalization (matches the upstream HF processor: do_normalize=None).
|
| 78 |
+
image = Image.open(image_path).convert("RGB").resize(INPUT_SIZE)
|
| 79 |
+
arr = np.asarray(image, dtype=np.float32) / 255.0
|
| 80 |
+
arr = (arr - MEAN) / STD
|
| 81 |
+
arr = arr[..., ::-1] # RGB -> BGR
|
| 82 |
+
return np.ascontiguousarray(arr.transpose(2, 0, 1)[None])
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def main() -> None:
|
| 86 |
+
parser = argparse.ArgumentParser(description=__doc__)
|
| 87 |
+
parser.add_argument("--image", required=True, type=Path)
|
| 88 |
+
parser.add_argument("--topk", type=int, default=5)
|
| 89 |
+
args = parser.parse_args()
|
| 90 |
+
|
| 91 |
+
if not ONNX_PATH.exists():
|
| 92 |
+
raise SystemExit(
|
| 93 |
+
f"Expected {ONNX_PATH.name} next to this script. "
|
| 94 |
+
"Did you download the HF repo?"
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
engine = load_or_build_engine()
|
| 98 |
+
context = engine.create_execution_context()
|
| 99 |
+
|
| 100 |
+
input_name = engine.get_tensor_name(0)
|
| 101 |
+
output_name = engine.get_tensor_name(1)
|
| 102 |
+
out_shape = tuple(engine.get_tensor_shape(output_name))
|
| 103 |
+
|
| 104 |
+
x = preprocess(args.image)
|
| 105 |
+
h_out = np.empty(out_shape, dtype=np.float32)
|
| 106 |
+
|
| 107 |
+
d_in = cuda.mem_alloc(x.nbytes)
|
| 108 |
+
d_out = cuda.mem_alloc(h_out.nbytes)
|
| 109 |
+
stream = cuda.Stream()
|
| 110 |
+
|
| 111 |
+
cuda.memcpy_htod_async(d_in, x, stream)
|
| 112 |
+
context.set_tensor_address(input_name, int(d_in))
|
| 113 |
+
context.set_tensor_address(output_name, int(d_out))
|
| 114 |
+
|
| 115 |
+
# Warm-up + timed run.
|
| 116 |
+
for _ in range(5):
|
| 117 |
+
context.execute_async_v3(stream.handle)
|
| 118 |
+
stream.synchronize()
|
| 119 |
+
t0 = time.perf_counter()
|
| 120 |
+
context.execute_async_v3(stream.handle)
|
| 121 |
+
stream.synchronize()
|
| 122 |
+
latency_ms = (time.perf_counter() - t0) * 1000.0
|
| 123 |
+
|
| 124 |
+
cuda.memcpy_dtoh_async(h_out, d_out, stream)
|
| 125 |
+
stream.synchronize()
|
| 126 |
+
|
| 127 |
+
logits = h_out.reshape(-1)
|
| 128 |
+
probs = np.exp(logits - logits.max())
|
| 129 |
+
probs /= probs.sum()
|
| 130 |
+
top = probs.argsort()[::-1][: args.topk]
|
| 131 |
+
|
| 132 |
+
print(f"Latency (single-run, GPU compute): {latency_ms:.2f} ms")
|
| 133 |
+
print(f"Top-{args.topk} predictions for {args.image}:")
|
| 134 |
+
for rank, idx in enumerate(top, 1):
|
| 135 |
+
print(f" {rank}. class {idx:>5d} ({probs[idx] * 100:5.2f}%)")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
if __name__ == "__main__":
|
| 139 |
+
main()
|