Optimize inference using torch.compile()

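This guide provides a benchmark of the inference speed-ups that torch.compile() brings to computer vision models in 🤗 Transformers.

Benefits of torch.compile

Depending on the model and the GPU, torch.compile() yields up to a 30% speed-up during inference. To use torch.compile(), simply install torch version 2.0 or later.

Compiling a model takes time, so it pays off when you compile the model once and then run inference many times, rather than compiling before every inference. To compile any computer vision model of your choice, call torch.compile() on the model as shown below: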

import torch
from transformers import AutoModelForImageClassification

# MODEL_ID is a placeholder for the checkpoint of your choice
model = AutoModelForImageClassification.from_pretrained(MODEL_ID).to("cuda")
model = torch.compile(model)

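compile() comes with multiple compilation modes, which differ in compilation time and inference overhead. max-autotune takes longer to compile than reduce-overhead but results in faster inference. The default mode is the fastest to compile but is less efficient at inference time than reduce-overhead. In this guide, we used the default mode. See the PyTorch documentation for torch.compile to learn more.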
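
As a minimal sketch of how a mode is selected, the hypothetical helper below (`compile_with_mode` is not part of Transformers or PyTorch) forwards one of the three modes named above to the `mode` argument of `torch.compile`. It assumes torch 2.0 or later is installed; the import sits inside the function only so that the mode check can run without torch.

```python
# Modes discussed in this guide; "default" is the one used for the main benchmarks.
COMPILE_MODES = ("default", "reduce-overhead", "max-autotune")


def compile_with_mode(model, mode="default"):
    """Wrap `model` with torch.compile using one of the modes named above.

    Hypothetical helper, for illustration only.
    """
    if mode not in COMPILE_MODES:
        raise ValueError(f"unknown mode {mode!r}; expected one of {COMPILE_MODES}")
    import torch  # imported lazily; requires torch >= 2.0

    # Compilation itself is deferred until the first forward pass.
    return torch.compile(model, mode=mode)
```

For example, `model = compile_with_mode(model, mode="reduce-overhead")` would accept a longer compilation in exchange for lower inference overhead than the default mode.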

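We benchmarked torch.compile with different computer vision models, tasks, types of hardware, and batch sizes on PyTorch 2.0.1.

Benchmarking code

Below you can find the benchmarking code for each task. We warm up the GPU before inference and take the mean time of 300 inferences, using the same image each time.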

Image Classification with ViT

import torch
from PIL import Image
import requests
import numpy as np
from transformers import AutoImageProcessor, AutoModelForImageClassification

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224").to("cuda")
model = torch.compile(model)

processed_input = processor(image, return_tensors='pt').to(device="cuda")

with torch.no_grad():
    _ = model(**processed_input)

Object Detection with DETR

from transformers import AutoImageProcessor, AutoModelForObjectDetection

processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50").to("cuda")
model = torch.compile(model)

# DETR's image processor takes images only; `image` is the one loaded in the ViT example
inputs = processor(images=image, return_tensors="pt").to("cuda")

with torch.no_grad():
    _ = model(**inputs)

Image Segmentation with Segformer

from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512").to("cuda")
model = torch.compile(model)
seg_inputs = processor(images=image, return_tensors="pt").to("cuda")

with torch.no_grad():
    _ = model(**seg_inputs)
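
The snippets above each run a single forward pass. The warm-up and 300-run averaging described at the start of this section could be wrapped around them roughly as follows. This is a sketch rather than the exact script behind the tables: the name `benchmark_ms`, the warm-up count of 10, and the `sync` hook are assumptions made for illustration.

```python
import time


def benchmark_ms(fn, n_warmup=10, n_runs=300, sync=None):
    """Return the mean duration of `fn()` in milliseconds over `n_runs` calls.

    `sync` is an optional zero-argument callable run before reading the clock;
    on GPU, pass torch.cuda.synchronize so queued kernels are included.
    """
    # Warm-up: the first calls of a compiled model trigger compilation.
    for _ in range(n_warmup):
        fn()
    if sync is not None:
        sync()

    start = time.perf_counter()
    for _ in range(n_runs):
        fn()
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start
    return elapsed / n_runs * 1000.0
```

With the ViT example, this might be called inside a `torch.no_grad()` block as `benchmark_ms(lambda: model(**processed_input), sync=torch.cuda.synchronize)`.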

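Below is the list of models we benchmarked, named as they appear in the result tables.

Image classification: ViT, BeiT, ConvNeXT, ResNet

Image segmentation: Segformer, Mask2former, Maskformer, MobileNet

Object detection: OwlViT, DETR, Resnet-101, Conditional-DETR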

[Figure: Duration comparison on V100 with batch size of 1]

[Figure: Percentage improvement on T4 with batch size of 4]

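Below you can find the inference durations in milliseconds for each model, with and without compile(). Note that OwlViT runs out of memory (OOM) at larger batch sizes.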

A100 (batch size: 1)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|:---:|:---:|:---:|
| Image Classification/ViT | 9.325 | 7.584 |
| Image Segmentation/Segformer | 11.759 | 10.500 |
| Object Detection/OwlViT | 24.978 | 18.420 |
| Image Classification/BeiT | 11.282 | 8.448 |
| Object Detection/DETR | 34.619 | 19.040 |
| Image Classification/ConvNeXT | 10.410 | 10.208 |
| Image Classification/ResNet | 6.531 | 4.124 |
| Image Segmentation/Mask2former | 60.188 | 49.117 |
| Image Segmentation/Maskformer | 75.764 | 59.487 |
| Image Segmentation/MobileNet | 8.583 | 3.974 |
| Object Detection/Resnet-101 | 36.276 | 18.197 |
| Object Detection/Conditional-DETR | 31.219 | 17.993 |
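
To compare rows across tables, it can help to turn the two latencies into a relative improvement, in the spirit of the percentage-improvement figure. A small sketch (the function name `improvement_pct` is illustrative; the numbers are taken from the A100, batch size 1 table above):

```python
def improvement_pct(no_compile_ms, compile_ms):
    """Percentage reduction in latency when going from uncompiled to compiled."""
    return (no_compile_ms - compile_ms) / no_compile_ms * 100.0


# Values from the A100 (batch size: 1) table, in milliseconds.
a100_bs1 = {
    "ViT": (9.325, 7.584),
    "DETR": (34.619, 19.040),
    "ConvNeXT": (10.410, 10.208),
}

for name, (base, compiled) in a100_bs1.items():
    print(f"{name}: {improvement_pct(base, compiled):.1f}% lower latency")
```

On these three rows the reduction ranges from about 2% (ConvNeXT) to about 45% (DETR), which illustrates how strongly the benefit depends on the model.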

A100 (batch size: 4)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|:---:|:---:|:---:|
| Image Classification/ViT | 14.832 | 14.499 |
| Image Segmentation/Segformer | 18.838 | 16.476 |
| Image Classification/BeiT | 13.205 | 13.048 |
| Object Detection/DETR | 48.657 | 32.418 |
| Image Classification/ConvNeXT | 22.940 | 21.631 |
| Image Classification/ResNet | 6.657 | 4.268 |
| Image Segmentation/Mask2former | 74.277 | 61.781 |
| Image Segmentation/Maskformer | 180.700 | 159.116 |
| Image Segmentation/MobileNet | 14.174 | 8.515 |
| Object Detection/Resnet-101 | 68.101 | 44.998 |
| Object Detection/Conditional-DETR | 56.470 | 35.552 |

A100 (batch size: 16)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|:---:|:---:|:---:|
| Image Classification/ViT | 40.944 | 40.010 |
| Image Segmentation/Segformer | 37.005 | 31.144 |
| Image Classification/BeiT | 41.854 | 41.048 |
| Object Detection/DETR | 164.382 | 161.902 |
| Image Classification/ConvNeXT | 82.258 | 75.561 |
| Image Classification/ResNet | 7.018 | 5.024 |
| Image Segmentation/Mask2former | 178.945 | 154.814 |
| Image Segmentation/Maskformer | 638.570 | 579.826 |
| Image Segmentation/MobileNet | 51.693 | 30.310 |
| Object Detection/Resnet-101 | 232.887 | 155.021 |
| Object Detection/Conditional-DETR | 180.491 | 124.032 |

V100 (batch size: 1)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|:---:|:---:|:---:|
| Image Classification/ViT | 10.495 | 6.00 |
| Image Segmentation/Segformer | 13.321 | 5.862 |
| Object Detection/OwlViT | 25.769 | 22.395 |
| Image Classification/BeiT | 11.347 | 7.234 |
| Object Detection/DETR | 33.951 | 19.388 |
| Image Classification/ConvNeXT | 11.623 | 10.412 |
| Image Classification/ResNet | 6.484 | 3.820 |
| Image Segmentation/Mask2former | 64.640 | 49.873 |
| Image Segmentation/Maskformer | 95.532 | 72.207 |
| Image Segmentation/MobileNet | 9.217 | 4.753 |
| Object Detection/Resnet-101 | 52.818 | 28.367 |
| Object Detection/Conditional-DETR | 39.512 | 20.816 |

V100 (batch size: 4)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|:---:|:---:|:---:|
| Image Classification/ViT | 15.181 | 14.501 |
| Image Segmentation/Segformer | 16.787 | 16.188 |
| Image Classification/BeiT | 15.171 | 14.753 |
| Object Detection/DETR | 88.529 | 64.195 |
| Image Classification/ConvNeXT | 29.574 | 27.085 |
| Image Classification/ResNet | 6.109 | 4.731 |
| Image Segmentation/Mask2former | 90.402 | 76.926 |
| Image Segmentation/Maskformer | 234.261 | 205.456 |
| Image Segmentation/MobileNet | 24.623 | 14.816 |
| Object Detection/Resnet-101 | 134.672 | 101.304 |
| Object Detection/Conditional-DETR | 97.464 | 69.739 |

V100 (batch size: 16)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|:---:|:---:|:---:|
| Image Classification/ViT | 52.209 | 51.633 |
| Image Segmentation/Segformer | 61.013 | 55.499 |
| Image Classification/BeiT | 53.938 | 53.581 |
| Object Detection/DETR | OOM | OOM |
| Image Classification/ConvNeXT | 109.682 | 100.771 |
| Image Classification/ResNet | 14.857 | 12.089 |
| Image Segmentation/Mask2former | 249.605 | 222.801 |
| Image Segmentation/Maskformer | 831.142 | 743.645 |
| Image Segmentation/MobileNet | 93.129 | 55.365 |
| Object Detection/Resnet-101 | 482.425 | 361.843 |
| Object Detection/Conditional-DETR | 344.661 | 255.298 |

T4 (batch size: 1)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|:---:|:---:|:---:|
| Image Classification/ViT | 16.520 | 15.786 |
| Image Segmentation/Segformer | 16.116 | 14.205 |
| Object Detection/OwlViT | 53.634 | 51.105 |
| Image Classification/BeiT | 16.464 | 15.710 |
| Object Detection/DETR | 73.100 | 53.99 |
| Image Classification/ConvNeXT | 32.932 | 30.845 |
| Image Classification/ResNet | 6.031 | 4.321 |
| Image Segmentation/Mask2former | 79.192 | 66.815 |
| Image Segmentation/Maskformer | 200.026 | 188.268 |
| Image Segmentation/MobileNet | 18.908 | 11.997 |
| Object Detection/Resnet-101 | 106.622 | 82.566 |
| Object Detection/Conditional-DETR | 77.594 | 56.984 |

T4 (batch size: 4)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|:---:|:---:|:---:|
| Image Classification/ViT | 43.653 | 43.626 |
| Image Segmentation/Segformer | 45.327 | 42.445 |
| Image Classification/BeiT | 52.007 | 51.354 |
| Object Detection/DETR | 277.850 | 268.003 |
| Image Classification/ConvNeXT | 119.259 | 105.580 |
| Image Classification/ResNet | 13.039 | 11.388 |
| Image Segmentation/Mask2former | 201.540 | 184.670 |
| Image Segmentation/Maskformer | 764.052 | 711.280 |
| Image Segmentation/MobileNet | 74.289 | 48.677 |
| Object Detection/Resnet-101 | 421.859 | 357.614 |
| Object Detection/Conditional-DETR | 289.002 | 226.945 |

T4 (batch size: 16)

| Task/Model | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|:---:|:---:|:---:|
| Image Classification/ViT | 163.914 | 160.907 |
| Image Segmentation/Segformer | 192.412 | 163.620 |
| Image Classification/BeiT | 188.978 | 187.976 |
| Object Detection/DETR | OOM | OOM |
| Image Classification/ConvNeXT | 422.886 | 388.078 |
| Image Classification/ResNet | 44.114 | 37.604 |
| Image Segmentation/Mask2former | 756.337 | 695.291 |
| Image Segmentation/Maskformer | 2842.940 | 2656.88 |
| Image Segmentation/MobileNet | 299.003 | 201.942 |
| Object Detection/Resnet-101 | 1619.505 | 1262.758 |
| Object Detection/Conditional-DETR | 1137.513 | 897.390 |

PyTorch Nightly

We also benchmarked on PyTorch Nightly (2.1.0dev, installed from the nightly wheels) and observed improved latency for both uncompiled and compiled models.

A100

| Task/Model | Batch Size | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|:---:|:---:|:---:|:---:|
| Image Classification/BeiT | Unbatched | 12.462 | 6.954 |
| Image Classification/BeiT | 4 | 14.109 | 12.851 |
| Image Classification/BeiT | 16 | 42.179 | 42.147 |
| Object Detection/DETR | Unbatched | 30.484 | 15.221 |
| Object Detection/DETR | 4 | 46.816 | 30.942 |
| Object Detection/DETR | 16 | 163.749 | 163.706 |

T4

| Task/Model | Batch Size | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|:---:|:---:|:---:|:---:|
| Image Classification/BeiT | Unbatched | 14.408 | 14.052 |
| Image Classification/BeiT | 4 | 47.381 | 46.604 |
| Image Classification/BeiT | 16 | 42.179 | 42.147 |
| Object Detection/DETR | Unbatched | 68.382 | 53.481 |
| Object Detection/DETR | 4 | 269.615 | 204.785 |
| Object Detection/DETR | 16 | OOM | OOM |

V100

| Task/Model | Batch Size | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|:---:|:---:|:---:|:---:|
| Image Classification/BeiT | Unbatched | 13.477 | 7.926 |
| Image Classification/BeiT | 4 | 15.103 | 14.378 |
| Image Classification/BeiT | 16 | 52.517 | 51.691 |
| Object Detection/DETR | Unbatched | 28.706 | 19.077 |
| Object Detection/DETR | 4 | 88.402 | 62.949 |
| Object Detection/DETR | 16 | OOM | OOM |

Reduce Overhead

We benchmarked the reduce-overhead compilation mode on A100 and T4 using PyTorch Nightly.

A100

| Task/Model | Batch Size | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|:---:|:---:|:---:|:---:|
| Image Classification/ConvNeXT | Unbatched | 11.758 | 7.335 |
| Image Classification/ConvNeXT | 4 | 23.171 | 21.490 |
| Image Classification/ResNet | Unbatched | 7.435 | 3.801 |
| Image Classification/ResNet | 4 | 7.261 | 2.187 |
| Object Detection/Conditional-DETR | Unbatched | 32.823 | 11.627 |
| Object Detection/Conditional-DETR | 4 | 50.622 | 33.831 |
| Image Segmentation/MobileNet | Unbatched | 9.869 | 4.244 |
| Image Segmentation/MobileNet | 4 | 14.385 | 7.946 |

T4

| Task/Model | Batch Size | torch 2.0 - no compile (ms) | torch 2.0 - compile (ms) |
|:---:|:---:|:---:|:---:|
| Image Classification/ConvNeXT | Unbatched | 32.137 | 31.84 |
| Image Classification/ConvNeXT | 4 | 120.944 | 110.209 |
| Image Classification/ResNet | Unbatched | 9.761 | 7.698 |
| Image Classification/ResNet | 4 | 15.215 | 13.871 |
| Object Detection/Conditional-DETR | Unbatched | 72.150 | 57.660 |
| Object Detection/Conditional-DETR | 4 | 301.494 | 247.543 |
| Image Segmentation/MobileNet | Unbatched | 22.266 | 19.339 |
| Image Segmentation/MobileNet | 4 | 78.311 | 50.983 |