levit-256-onnx / README.md
Felix Marty
metadata
license: apache-2.0
tags:
  - vision
  - image-classification
datasets:
  - imagenet-1k

This model is a fork of facebook/levit-256, where:

  • nn.BatchNorm2d and nn.Conv2d are fused
  • nn.BatchNorm1d and nn.Linear are fused

and the optimized model is then converted to the ONNX format.
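
As an illustration of what "fused" means here, the following is a minimal, framework-free sketch of folding an eval-mode batch normalization into the linear layer that precedes it. It is not the script used to produce this checkpoint. The helper names (`fold_batchnorm_into_linear`, `linear`, `batchnorm_eval`) and the toy values are hypothetical, and `eps=1e-5` is assumed because it is the PyTorch default for batch normalization layers. The same per-output-channel scaling applies to nn.Conv2d followed by nn.BatchNorm2d.

```python
import math


def fold_batchnorm_into_linear(weight, bias, gamma, beta, running_mean, running_var, eps=1e-5):
    """Return (weight', bias') so that linear' (x) equals batchnorm(linear(x)) in eval mode."""
    folded_weight, folded_bias = [], []
    for row, b, g, be, mu, var in zip(weight, bias, gamma, beta, running_mean, running_var):
        scale = g / math.sqrt(var + eps)  # one scale factor per output feature
        folded_weight.append([w * scale for w in row])
        folded_bias.append((b - mu) * scale + be)
    return folded_weight, folded_bias


def linear(x, weight, bias):
    """y = W x + b for a single input vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batchnorm_eval(y, gamma, beta, running_mean, running_var, eps=1e-5):
    """Eval-mode batch norm: normalize with the stored running statistics."""
    return [
        (yi - mu) / math.sqrt(var + eps) * g + be
        for yi, g, be, mu, var in zip(y, gamma, beta, running_mean, running_var)
    ]


# Toy layer with 3 input features and 2 output features (hypothetical values).
weight = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
bias = [0.1, -0.2]
gamma, beta = [1.2, 0.8], [0.05, -0.3]
running_mean, running_var = [0.4, -0.1], [2.0, 0.5]
x = [1.0, 2.0, 3.0]

# Two-step reference: linear layer, then batch normalization.
reference = batchnorm_eval(linear(x, weight, bias), gamma, beta, running_mean, running_var)

# One-step fused layer: a single linear layer with folded parameters.
fused_w, fused_b = fold_batchnorm_into_linear(weight, bias, gamma, beta, running_mean, running_var)
fused = linear(x, fused_w, fused_b)

assert all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(reference, fused))
```

Because the folded layer computes the same function with one operation instead of two, folding is only valid at inference time, when the batch normalization statistics are frozen.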

How to use

import torch
from optimum.onnxruntime.modeling_ort import ORTModelForImageClassification
from transformers import AutoModelForImageClassification

pt_model = AutoModelForImageClassification.from_pretrained("facebook/levit-256")
pt_model.eval()

ort_model = ORTModelForImageClassification.from_pretrained("fxmarty/levit-256-onnx")

inp = {"pixel_values": torch.rand(1, 3, 224, 224)}

with torch.no_grad():
    res = pt_model(**inp)
res_ort = ort_model(**inp)

assert torch.allclose(res.logits, res_ort.logits, atol=1e-4)

Benchmarking

More than 2x throughput (106.9 vs. 44.6 forward passes per second at batch size 1) with batch normalization folding and ONNX Runtime 🔥 Latencies below are in milliseconds.

PyTorch runtime:

{'latency_50': 22.3024695,
 'latency_90': 23.1230725,
 'latency_95': 23.2653985,
 'latency_99': 23.60095705,
 'latency_999': 23.865580469999998,
 'latency_mean': 22.442956878923766,
 'latency_std': 0.46544295612971265,
 'nb_forwards': 446,
 'throughput': 44.6}

Optimum-onnxruntime runtime:

{'latency_50': 9.302445,
 'latency_90': 9.782875,
 'latency_95': 9.9071944,
 'latency_99': 11.084606999999997,
 'latency_999': 12.035858692000001,
 'latency_mean': 9.357703552853133,
 'latency_std': 0.4018553286992142,
 'nb_forwards': 1069,
 'throughput': 106.9}
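
# Benchmark script that produced the numbers above.
# It reuses pt_model and ort_model from the "How to use" snippet.
from pprint import pprint

import torch
from optimum.runs_base import TimeBenchmark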

time_benchmark_ort = TimeBenchmark(
    model=ort_model,
    batch_size=1,
    input_length=224,
    model_input_names={"pixel_values"},
    warmup_runs=10,
    duration=10
)

results_ort = time_benchmark_ort.execute()

with torch.no_grad():
    time_benchmark_pt = TimeBenchmark(
        model=pt_model,
        batch_size=1,
        input_length=224,
        model_input_names={"pixel_values"},
        warmup_runs=10,
        duration=10
    )

    results_pt = time_benchmark_pt.execute()

print("PyTorch runtime:\n")
pprint(results_pt)

print("\nOptimum-onnxruntime runtime:\n")
pprint(results_ort)
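
The headline speedup can be recomputed from the reported figures. The short sketch below copies a few values from the two result dictionaries and checks that they are mutually consistent. It assumes `duration=10` means 10 seconds and that latencies are in milliseconds; both assumptions are supported by the numbers themselves but are not stated by the benchmark script.

```python
# Values copied from the results reported in this README.
pytorch = {"latency_50": 22.3024695, "latency_mean": 22.442956878923766,
           "nb_forwards": 446, "throughput": 44.6}
onnxruntime = {"latency_50": 9.302445, "latency_mean": 9.357703552853133,
               "nb_forwards": 1069, "throughput": 106.9}

duration_s = 10  # assumption: `duration=10` in the benchmark script is in seconds

# Throughput matches the number of forward passes divided by the duration.
for result in (pytorch, onnxruntime):
    assert abs(result["nb_forwards"] / duration_s - result["throughput"]) < 1e-9

# Mean latency times the number of passes should roughly equal the duration,
# which holds if latencies are in milliseconds (assumption).
elapsed_s = {
    name: result["nb_forwards"] * result["latency_mean"] / 1000
    for name, result in (("pytorch", pytorch), ("onnxruntime", onnxruntime))
}

throughput_speedup = onnxruntime["throughput"] / pytorch["throughput"]
median_latency_speedup = pytorch["latency_50"] / onnxruntime["latency_50"]

print(f"throughput speedup:     {throughput_speedup:.2f}x")
print(f"median latency speedup: {median_latency_speedup:.2f}x")
```

Both ratios come out a little under 2.4, which is what "more than 2x" refers to. These figures are for batch size 1 on the hardware used for this run and may differ elsewhere.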