Accelerated inference on AMD GPUs supported by ROCm
By default, ONNX Runtime runs inference on CPU devices. However, it is possible to place supported operations on an AMD Instinct GPU while leaving any unsupported ones on CPU. In most cases, this allows costly operations to be placed on GPU, which can significantly accelerate inference.
Our testing involved AMD Instinct GPUs. For specific GPU compatibility, please refer to the official list of supported GPUs available here.
This guide shows you how to run inference with ROCMExecutionProvider, the execution provider that ONNX Runtime supports for AMD GPUs.
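ONNX Runtime tries execution providers in the order they are listed, so a common pattern is to request the ROCm provider first and keep the CPU provider last as the fallback for unsupported operations. The sketch below builds such a priority list. `order_providers` is a hypothetical helper, not part of ONNX Runtime or Optimum, and it assumes you pass in the names returned by `onnxruntime.get_available_providers()`.

```python
from typing import List, Sequence

# Preferred order: ROCm first, CPU last as the fallback for unsupported nodes.
PREFERRED = ("ROCMExecutionProvider", "CPUExecutionProvider")


def order_providers(available: Sequence[str]) -> List[str]:
    """Return the preferred providers that are available, in priority order.

    Hypothetical helper: `available` is assumed to come from
    onnxruntime.get_available_providers().
    """
    ordered = [name for name in PREFERRED if name in available]
    # Always keep the CPU provider so unsupported operations have somewhere to run.
    if "CPUExecutionProvider" not in ordered:
        ordered.append("CPUExecutionProvider")
    return ordered
```

With ONNX Runtime installed, the result could be passed as `onnxruntime.InferenceSession("model.onnx", providers=order_providers(onnxruntime.get_available_providers()))`, where `model.onnx` is a placeholder path.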
Installation
The following setup installs ONNX Runtime with the ROCm Execution Provider for ROCm 6.0.
1 ROCm Installation
Refer to the ROCm installation guide to install ROCm 6.0.
2 Installing onnxruntime-rocm
Please use the provided Dockerfile example or do a local installation from source, since pip wheels are currently unavailable.
Docker Installation:
docker build -f Dockerfile -t ort/rocm .
Local Installation Steps:
2.1 PyTorch with ROCm Support
The Optimum ONNX Runtime integration relies on some functionalities of Transformers that require PyTorch. For now, we recommend using PyTorch compiled against ROCm 6.0, which can be installed by following the PyTorch installation guide:
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0
# Use 'rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2' as the preferred base image when using Docker for PyTorch installation.
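Before building ONNX Runtime, it can help to confirm that the installed PyTorch is actually a ROCm build. The sketch below relies only on `torch.version.hip` (a version string on ROCm builds, `None` otherwise) and `torch.cuda.is_available()`, which ROCm builds use to report AMD GPUs. `describe_torch_build` is a hypothetical helper that takes both values as arguments so the logic can run without a GPU.

```python
def describe_torch_build(hip_version, gpu_available):
    """Summarize a PyTorch install from torch.version.hip and torch.cuda.is_available().

    Hypothetical helper: the two values are passed in rather than read from torch.
    """
    if hip_version is None:
        # CPU-only and CUDA builds of PyTorch report torch.version.hip as None.
        return "not a ROCm build"
    # ROCm builds expose AMD GPUs through the torch.cuda API.
    status = "GPU visible" if gpu_available else "no GPU visible"
    return f"ROCm build (HIP {hip_version}), {status}"
```

Usage: after `import torch`, call `print(describe_torch_build(torch.version.hip, torch.cuda.is_available()))`.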
2.2 ONNX Runtime with ROCm Execution Provider
# pre-requisites
pip install -U pip
pip install cmake onnx
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
# Install ONNXRuntime from source
git clone --single-branch --branch main --recursive https://github.com/Microsoft/onnxruntime onnxruntime
cd onnxruntime
./build.sh --config Release --build_wheel --allow_running_as_root --update --build --parallel --cmake_extra_defines CMAKE_HIP_ARCHITECTURES=gfx90a,gfx942 ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm
pip install build/Linux/Release/dist/*
Note: These instructions build ORT for MI210/MI250/MI300 GPUs (the gfx90a and gfx942 targets). To support other architectures, please update CMAKE_HIP_ARCHITECTURES in the build command.
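The value of CMAKE_HIP_ARCHITECTURES is a comma-separated list of LLVM gfx targets. As a sketch, the mapping below covers only the GPUs named in this guide (MI210 and MI250 use gfx90a, MI300 uses gfx942); `GFX_TARGETS` and `hip_architectures_define` are hypothetical names, and the gfx target for any other GPU should be taken from the ROCm supported-GPU list or from `rocminfo` output on the target machine.

```python
# Hypothetical lookup table: GPU model -> gfx target for CMAKE_HIP_ARCHITECTURES.
# Only the GPUs mentioned in this guide are listed; extend it for other hardware.
GFX_TARGETS = {
    "MI210": "gfx90a",
    "MI250": "gfx90a",
    "MI300": "gfx942",
}


def hip_architectures_define(gpus):
    """Build the CMAKE_HIP_ARCHITECTURES=... entry for build.sh --cmake_extra_defines."""
    targets = []
    for gpu in gpus:
        target = GFX_TARGETS[gpu]  # raises KeyError for GPUs missing from the table
        if target not in targets:
            targets.append(target)  # keep first-seen order without duplicates
    return "CMAKE_HIP_ARCHITECTURES=" + ",".join(targets)
```

For the three GPUs above, this produces the same `CMAKE_HIP_ARCHITECTURES=gfx90a,gfx942` entry used in the build command.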
Checking that the ROCm installation is successful
Before going further, run the following sample code to check whether the install was successful:
>>> from optimum.onnxruntime import ORTModelForSequenceClassification
>>> from transformers import AutoTokenizer
>>> ort_model = ORTModelForSequenceClassification.from_pretrained(
... "philschmid/tiny-bert-sst2-distilled",
... export=True,
... provider="ROCMExecutionProvider",
... )
>>> tokenizer = AutoTokenizer.from_pretrained("philschmid/tiny-bert-sst2-distilled")
>>> inputs = tokenizer("expectations were low, actual enjoyment was high", return_tensors="pt", padding=True)
>>> outputs = ort_model(**inputs)
>>> assert ort_model.providers == ["ROCMExecutionProvider", "CPUExecutionProvider"]
If this code runs without errors, congratulations, the installation is successful! If you encounter the following error or a similar one,
ValueError: Asked to use ROCMExecutionProvider as an ONNX Runtime execution provider, but the available execution providers are ['CPUExecutionProvider'].
then something is wrong with the ROCm or ONNX Runtime installation.
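A quicker first check, before loading any model, is to ask ONNX Runtime which execution providers the installed package reports. The sketch below uses `onnxruntime.get_available_providers()`; `rocm_provider_available` is a hypothetical helper, and the import is done lazily so the function can also be called with an explicit list.

```python
def rocm_provider_available(available=None):
    """Return True if ONNX Runtime reports the ROCm execution provider.

    Hypothetical helper: pass `available` explicitly, or leave it as None to
    query the installed onnxruntime package.
    """
    if available is None:
        import onnxruntime  # lazy import so the check can run without ORT installed

        available = onnxruntime.get_available_providers()
    return "ROCMExecutionProvider" in available
```

If `rocm_provider_available()` returns False, the onnxruntime package in the current environment was most likely not built with `--use_rocm`.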
Use the ROCm Execution Provider with ORT models
For ORT models, usage is straightforward: simply specify the provider argument in the ORTModel.from_pretrained() method. Here's an example:
>>> from optimum.onnxruntime import ORTModelForSequenceClassification
>>> ort_model = ORTModelForSequenceClassification.from_pretrained(
... "distilbert-base-uncased-finetuned-sst-2-english",
... export=True,
... provider="ROCMExecutionProvider",
... )
The model can then be used with the common 🤗 Transformers API for inference and evaluation, such as pipelines.
When using a Transformers pipeline, note that the device argument should be set to perform pre- and post-processing on GPU, as in the example below:
>>> from optimum.pipelines import pipeline
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>>> pipe = pipeline(task="text-classification", model=ort_model, tokenizer=tokenizer, device="cuda:0")
>>> result = pipe("Both the music and visual were astounding, not to mention the actors performance.")
>>> print(result)
# printing: [{'label': 'POSITIVE', 'score': 0.9997727274894714}]
Additionally, you can pass the session option log_severity_level = 0 (verbose) to check whether all nodes are indeed placed on the ROCm execution provider:
>>> import onnxruntime
>>> from optimum.onnxruntime import ORTModelForSequenceClassification
>>> session_options = onnxruntime.SessionOptions()
>>> session_options.log_severity_level = 0
>>> ort_model = ORTModelForSequenceClassification.from_pretrained(
... "distilbert-base-uncased-finetuned-sst-2-english",
... export=True,
... provider="ROCMExecutionProvider",
... session_options=session_options
... )
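Verbose logging is noisy, so once node placement has been confirmed you may want to return to a quieter level. ONNX Runtime's `SessionOptions.log_severity_level` takes 0 (verbose), 1 (info), 2 (warning, the default), 3 (error) or 4 (fatal). The sketch below maps readable names to those integers; `SEVERITY_LEVELS`, `severity_level` and `make_session_options` are hypothetical helpers, and only the last one needs ONNX Runtime installed.

```python
# ONNX Runtime log severity levels accepted by SessionOptions.log_severity_level.
SEVERITY_LEVELS = {"VERBOSE": 0, "INFO": 1, "WARNING": 2, "ERROR": 3, "FATAL": 4}


def severity_level(name):
    """Map a level name (case-insensitive) to the integer ONNX Runtime expects."""
    return SEVERITY_LEVELS[name.upper()]


def make_session_options(level_name="VERBOSE"):
    """Create SessionOptions with the requested log level (requires onnxruntime)."""
    import onnxruntime  # lazy import so severity_level() works without ORT

    options = onnxruntime.SessionOptions()
    options.log_severity_level = severity_level(level_name)
    return options
```

The returned object can be passed as `session_options` to `ORTModelForSequenceClassification.from_pretrained()` exactly as in the example above.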
Observed time gains
Coming soon!