Transformers documentation

Optimize inference using torch.compile()

You are viewing v4.36.1 version. A newer version v4.46.3 is available.
Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

Optimize inference using torch.compile()

このガイドは、torch.compile() を使用した推論速度の向上に関するベンチマークを提供することを目的としています。これは、🤗 Transformers のコンピュータビジョンモデル向けのものです。

Benefits of torch.compile

torch.compile()の利点 モデルとGPUによっては、torch.compile()は推論時に最大30%の高速化を実現します。 torch.compile()を使用するには、バージョン2.0以上のtorchをインストールするだけです。

モデルのコンパイルには時間がかかるため、毎回推論するのではなく、モデルを1度だけコンパイルする場合に役立ちます。 任意のコンピュータビジョンモデルをコンパイルするには、以下のようにモデルにtorch.compile()を呼び出します:

from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained(MODEL_ID).to("cuda")
+ model = torch.compile(model)

compile() は、コンパイルに関する異なるモードを備えており、基本的にはコンパイル時間と推論のオーバーヘッドが異なります。max-autotunereduce-overhead よりも時間がかかりますが、推論速度が速くなります。デフォルトモードはコンパイルにおいては最速ですが、推論時間においては reduce-overhead に比べて効率が良くありません。このガイドでは、デフォルトモードを使用しました。詳細については、こちら を参照してください。

torch バージョン 2.0.1 で異なるコンピュータビジョンモデル、タスク、ハードウェアの種類、およびバッチサイズを使用して torch.compile をベンチマークしました。

Benchmarking code

以下に、各タスクのベンチマークコードを示します。推論前にGPUをウォームアップし、毎回同じ画像を使用して300回の推論の平均時間を取得します。

Image Classification with ViT

from PIL import Image
import requests
import numpy as np
from transformers import AutoImageProcessor, AutoModelForImageClassification

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224").to("cuda")
model = torch.compile(model)

processed_input = processor(image, return_tensors='pt').to(device="cuda")

with torch.no_grad():
    _ = model(**processed_input)

Object Detection with DETR

from transformers import AutoImageProcessor, AutoModelForObjectDetection

processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50").to("cuda")
model = torch.compile(model)

texts = ["a photo of a cat", "a photo of a dog"]
inputs = processor(text=texts, images=image, return_tensors="pt").to("cuda")

with torch.no_grad():
    _ = model(**inputs)

Image Segmentation with Segformer

from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512").to("cuda")
model = torch.compile(model)
seg_inputs = processor(images=image, return_tensors="pt").to("cuda")

with torch.no_grad():
    _ = model(**seg_inputs)

以下は、私たちがベンチマークを行ったモデルのリストです。

Image Classification

Image Segmentation

Object Detection

以下は、torch.compile()を使用した場合と使用しない場合の推論時間の可視化と、異なるハードウェアとバッチサイズの各モデルに対するパフォーマンス向上の割合です。

Duration Comparison on V100 with Batch Size of 1

Percentage Improvement on T4 with Batch Size of 4

下記は、各モデルについてcompile()を使用した場合と使用しなかった場合の推論時間(ミリ秒単位)です。なお、OwlViTは大きなバッチサイズでの使用時にメモリ不足(OOM)が発生することに注意してください。

A100 (batch size: 1)

Task/Model torch 2.0 -
no compile
torch 2.0 -
compile
Image Classification/ViT 9.325 7.584
Image Segmentation/Segformer 11.759 10.500
Object Detection/OwlViT 24.978 18.420
Image Classification/BeiT 11.282 8.448
Object Detection/DETR 34.619 19.040
Image Classification/ConvNeXT 10.410 10.208
Image Classification/ResNet 6.531 4.124
Image Segmentation/Mask2former 60.188 49.117
Image Segmentation/Maskformer 75.764 59.487
Image Segmentation/MobileNet 8.583 3.974
Object Detection/Resnet-101 36.276 18.197
Object Detection/Conditional-DETR 31.219 17.993

A100 (batch size: 4)

Task/Model torch 2.0 -
no compile
torch 2.0 -
compile
Image Classification/ViT 14.832 14.499
Image Segmentation/Segformer 18.838 16.476
Image Classification/BeiT 13.205 13.048
Object Detection/DETR 48.657 32.418
Image Classification/ConvNeXT 22.940 21.631
Image Classification/ResNet 6.657 4.268
Image Segmentation/Mask2former 74.277 61.781
Image Segmentation/Maskformer 180.700 159.116
Image Segmentation/MobileNet 14.174 8.515
Object Detection/Resnet-101 68.101 44.998
Object Detection/Conditional-DETR 56.470 35.552

A100 (batch size: 16)

Task/Model torch 2.0 -
no compile
torch 2.0 -
compile
Image Classification/ViT 40.944 40.010
Image Segmentation/Segformer 37.005 31.144
Image Classification/BeiT 41.854 41.048
Object Detection/DETR 164.382 161.902
Image Classification/ConvNeXT 82.258 75.561
Image Classification/ResNet 7.018 5.024
Image Segmentation/Mask2former 178.945 154.814
Image Segmentation/Maskformer 638.570 579.826
Image Segmentation/MobileNet 51.693 30.310
Object Detection/Resnet-101 232.887 155.021
Object Detection/Conditional-DETR 180.491 124.032

V100 (batch size: 1)

Task/Model torch 2.0 -
no compile
torch 2.0 -
compile
Image Classification/ViT 10.495 6.00
Image Segmentation/Segformer 13.321 5.862
Object Detection/OwlViT 25.769 22.395
Image Classification/BeiT 11.347 7.234
Object Detection/DETR 33.951 19.388
Image Classification/ConvNeXT 11.623 10.412
Image Classification/ResNet 6.484 3.820
Image Segmentation/Mask2former 64.640 49.873
Image Segmentation/Maskformer 95.532 72.207
Image Segmentation/MobileNet 9.217 4.753
Object Detection/Resnet-101 52.818 28.367
Object Detection/Conditional-DETR 39.512 20.816

V100 (batch size: 4)

Task/Model torch 2.0 -
no compile
torch 2.0 -
compile
Image Classification/ViT 15.181 14.501
Image Segmentation/Segformer 16.787 16.188
Image Classification/BeiT 15.171 14.753
Object Detection/DETR 88.529 64.195
Image Classification/ConvNeXT 29.574 27.085
Image Classification/ResNet 6.109 4.731
Image Segmentation/Mask2former 90.402 76.926
Image Segmentation/Maskformer 234.261 205.456
Image Segmentation/MobileNet 24.623 14.816
Object Detection/Resnet-101 134.672 101.304
Object Detection/Conditional-DETR 97.464 69.739

V100 (batch size: 16)

Task/Model torch 2.0 -
no compile
torch 2.0 -
compile
Image Classification/ViT 52.209 51.633
Image Segmentation/Segformer 61.013 55.499
Image Classification/BeiT 53.938 53.581
Object Detection/DETR OOM OOM
Image Classification/ConvNeXT 109.682 100.771
Image Classification/ResNet 14.857 12.089
Image Segmentation/Mask2former 249.605 222.801
Image Segmentation/Maskformer 831.142 743.645
Image Segmentation/MobileNet 93.129 55.365
Object Detection/Resnet-101 482.425 361.843
Object Detection/Conditional-DETR 344.661 255.298

T4 (batch size: 1)

Task/Model torch 2.0 -
no compile
torch 2.0 -
compile
Image Classification/ViT 16.520 15.786
Image Segmentation/Segformer 16.116 14.205
Object Detection/OwlViT 53.634 51.105
Image Classification/BeiT 16.464 15.710
Object Detection/DETR 73.100 53.99
Image Classification/ConvNeXT 32.932 30.845
Image Classification/ResNet 6.031 4.321
Image Segmentation/Mask2former 79.192 66.815
Image Segmentation/Maskformer 200.026 188.268
Image Segmentation/MobileNet 18.908 11.997
Object Detection/Resnet-101 106.622 82.566
Object Detection/Conditional-DETR 77.594 56.984

T4 (batch size: 4)

Task/Model torch 2.0 -
no compile
torch 2.0 -
compile
Image Classification/ViT 43.653 43.626
Image Segmentation/Segformer 45.327 42.445
Image Classification/BeiT 52.007 51.354
Object Detection/DETR 277.850 268.003
Image Classification/ConvNeXT 119.259 105.580
Image Classification/ResNet 13.039 11.388
Image Segmentation/Mask2former 201.540 184.670
Image Segmentation/Maskformer 764.052 711.280
Image Segmentation/MobileNet 74.289 48.677
Object Detection/Resnet-101 421.859 357.614
Object Detection/Conditional-DETR 289.002 226.945

T4 (batch size: 16)

Task/Model torch 2.0 -
no compile
torch 2.0 -
compile
Image Classification/ViT 163.914 160.907
Image Segmentation/Segformer 192.412 163.620
Image Classification/BeiT 188.978 187.976
Object Detection/DETR OOM OOM
Image Classification/ConvNeXT 422.886 388.078
Image Classification/ResNet 44.114 37.604
Image Segmentation/Mask2former 756.337 695.291
Image Segmentation/Maskformer 2842.940 2656.88
Image Segmentation/MobileNet 299.003 201.942
Object Detection/Resnet-101 1619.505 1262.758
Object Detection/Conditional-DETR 1137.513 897.390

PyTorch Nightly

また、PyTorchのナイトリーバージョン(2.1.0dev)でのベンチマークを行い、コンパイルされていないモデルとコンパイル済みモデルの両方でレイテンシーの向上を観察しました。ホイールはこちらから入手できます。

A100

Task/Model Batch Size torch 2.0 - no compile torch 2.0 -
compile
Image Classification/BeiT Unbatched 12.462 6.954
Image Classification/BeiT 4 14.109 12.851
Image Classification/BeiT 16 42.179 42.147
Object Detection/DETR Unbatched 30.484 15.221
Object Detection/DETR 4 46.816 30.942
Object Detection/DETR 16 163.749 163.706

T4

Task/Model Batch Size torch 2.0 -
no compile
torch 2.0 -
compile
Image Classification/BeiT Unbatched 14.408 14.052
Image Classification/BeiT 4 47.381 46.604
Image Classification/BeiT 16 42.179 42.147
Object Detection/DETR Unbatched 68.382 53.481
Object Detection/DETR 4 269.615 204.785
Object Detection/DETR 16 OOM OOM

### V100

Task/Model Batch Size torch 2.0 -
no compile
torch 2.0 -
compile
Image Classification/BeiT Unbatched 13.477 7.926
Image Classification/BeiT 4 15.103 14.378
Image Classification/BeiT 16 52.517 51.691
Object Detection/DETR Unbatched 28.706 19.077
Object Detection/DETR 4 88.402 62.949
Object Detection/DETR 16 OOM OOM

Reduce Overhead

NightlyビルドでA100およびT4向けの reduce-overhead コンパイルモードをベンチマークしました。

A100

Task/Model Batch Size torch 2.0 -
no compile
torch 2.0 -
compile
Image Classification/ConvNeXT Unbatched 11.758 7.335
Image Classification/ConvNeXT 4 23.171 21.490
Image Classification/ResNet Unbatched 7.435 3.801
Image Classification/ResNet 4 7.261 2.187
Object Detection/Conditional-DETR Unbatched 32.823 11.627
Object Detection/Conditional-DETR 4 50.622 33.831
Image Segmentation/MobileNet Unbatched 9.869 4.244
Image Segmentation/MobileNet 4 14.385 7.946

T4

Task/Model Batch Size torch 2.0 -
no compile
torch 2.0 -
compile
Image Classification/ConvNeXT Unbatched 32.137 31.84
Image Classification/ConvNeXT 4 120.944 110.209
Image Classification/ResNet Unbatched 9.761 7.698
Image Classification/ResNet 4 15.215 13.871
Object Detection/Conditional-DETR Unbatched 72.150 57.660
Object Detection/Conditional-DETR 4 301.494 247.543
Image Segmentation/MobileNet Unbatched 22.266 19.339
Image Segmentation/MobileNet 4 78.311 50.983