Handling big models for inference

One of the biggest advancements 🤗 Accelerate provides is big model inference, which lets you run inference with models that cannot fully fit on your graphics card.

This tutorial is broken down into two parts, showing how to use both 🤗 Accelerate and 🤗 Transformers (a higher-level API) to make use of this idea.

Using 🤗 Accelerate

For these tutorials, we’ll assume a typical workflow for loading your model, such as:

import torch

# ModelClass and checkpoint_file are placeholders for your own model class and checkpoint path
my_model = ModelClass(...)
state_dict = torch.load(checkpoint_file)
my_model.load_state_dict(state_dict)

Here we assume that ModelClass is a model that needs more memory than your device (CUDA or MPS) has available.

The first step is to initialize an empty skeleton of the model, which won’t take up any RAM, using the init_empty_weights() context manager:

from accelerate import init_empty_weights

# Parameters are created on the meta device, so no memory is allocated for the weights
with init_empty_weights():
    my_model = ModelClass(...)

At this point my_model is “parameterless”: its parameters live on PyTorch’s meta device and have no storage behind them, so the model has a much smaller footprint than it would if you loaded it onto the CPU directly.
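Even without storage, the empty skeleton still records every parameter’s shape, so you can already tell how large the model is. The pure-Python sketch below illustrates that arithmetic; the parameter names and shapes are hypothetical and do not come from any real model.

```python
from math import prod

# Hypothetical parameter shapes, as an empty (meta-device) model would still report them
PARAM_SHAPES = {
    "embed.weight": (32000, 4096),
    "layers.0.weight": (4096, 4096),
    "layers.0.bias": (4096,),
    "lm_head.weight": (32000, 4096),
}


def count_parameters(shapes: dict[str, tuple[int, ...]]) -> int:
    """Count parameters from shapes alone, without allocating any tensor storage."""
    return sum(prod(shape) for shape in shapes.values())


print(count_parameters(PARAM_SHAPES))  # 278925312
```

With an actual empty PyTorch model, the same count is typically obtained with sum(p.numel() for p in my_model.parameters()), which works on meta tensors because it only needs their shapes.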

Next, we need to load the weights into the model so we can perform inference.

For this we will use load_checkpoint_and_dispatch(), which, as the name implies, loads a checkpoint into your empty model and dispatches the weights of each layer across all the devices you have available (GPU/MPS and CPU RAM).

To decide how the weights are dispatched, specifying device_map="auto" is generally good enough: 🤗 Accelerate will first fill all the available space on your GPU(s), then place the remaining weights in CPU RAM, and finally, if there is not enough RAM, offload the rest to disk (by far the slowest option).
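The fill order described above can be pictured as a greedy pass over the model’s layers. The sketch below is a simplified, hypothetical illustration (pure Python, made-up layer sizes and memory budgets) and not 🤗 Accelerate’s actual placement algorithm, which handles more cases, such as modules that must not be split.

```python
from typing import Union

Device = Union[int, str]  # a GPU index, "cpu", or "disk"


def greedy_device_map(layer_sizes: dict[str, int], gpu_budgets: list[int], cpu_budget: int) -> dict[str, Device]:
    """Assign layers, in order, to GPUs first, then the CPU, then disk.

    Simplification: a layer goes to the device currently being filled if it fits;
    otherwise filling moves on to the next device and never comes back.
    """
    devices: list[tuple[Device, int]] = list(enumerate(gpu_budgets)) + [("cpu", cpu_budget)]
    device_map: dict[str, Device] = {}
    current = 0
    remaining = devices[0][1]
    for name, size in layer_sizes.items():
        # Move on to the next device until this layer fits or no device is left
        while current < len(devices) and size > remaining:
            current += 1
            if current < len(devices):
                remaining = devices[current][1]
        if current < len(devices):
            device_map[name] = devices[current][0]
            remaining -= size
        else:
            device_map[name] = "disk"  # no GPU or CPU space left: slowest option
    return device_map


# Hypothetical layer sizes and budgets (e.g. in GB): one 8 GB GPU and 6 GB of CPU RAM
layers = {"embed": 4, "block.0": 3, "block.1": 3, "block.2": 3, "head": 4}
print(greedy_device_map(layers, gpu_budgets=[8], cpu_budget=6))
# {'embed': 0, 'block.0': 0, 'block.1': 'cpu', 'block.2': 'cpu', 'head': 'disk'}
```

The resulting dictionary has the same general shape as a device map: module names mapped to a GPU index, "cpu", or "disk".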

For more details on designing your own device map, see the corresponding section of the concept guide.

See an example below:

from accelerate import load_checkpoint_and_dispatch

# Load the checkpoint into the empty model and spread its weights across the available devices
model = load_checkpoint_and_dispatch(
    my_model, checkpoint=checkpoint_file, device_map="auto"
)

If there are certain “chunks” of layers that shouldn’t be split across devices, you can pass their module class names as no_split_module_classes. Read more about it in the concept guide.
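To picture what no_split_module_classes changes, the sketch below flattens a hypothetical module tree into the units a placement pass would assign to devices: modules whose class is listed as unsplittable are kept whole, while other containers are broken down into their children. This is a simplified pure-Python illustration, not 🤗 Accelerate’s implementation; the Module class, the tree, the class names (such as "DecoderBlock"), and the sizes are all made up.

```python
from dataclasses import dataclass, field


@dataclass
class Module:
    """A hypothetical, minimal stand-in for a model's module tree."""

    name: str
    cls: str  # class name, e.g. "DecoderBlock"
    size: int = 0  # size of the parameters owned directly by this module
    children: list["Module"] = field(default_factory=list)

    def total_size(self) -> int:
        return self.size + sum(child.total_size() for child in self.children)


def placement_units(module: Module, no_split: set[str]) -> list[tuple[str, int]]:
    """Return the (name, size) units that a placement pass would assign to devices."""
    # Leaves and unsplittable modules are kept whole
    if not module.children or module.cls in no_split:
        return [(module.name, module.total_size())]
    units = []
    if module.size:
        units.append((module.name, module.size))  # parameters owned directly by the container
    for child in module.children:
        units.extend(placement_units(child, no_split))
    return units


def make_block(i: int) -> Module:
    return Module(f"blocks.{i}", "DecoderBlock", children=[
        Module(f"blocks.{i}.attn", "Attention", size=2),
        Module(f"blocks.{i}.mlp", "MLP", size=3),
    ])


model = Module("model", "MyModel", children=[Module("embed", "Embedding", size=4), make_block(0), make_block(1)])

print(placement_units(model, set()))
# [('embed', 4), ('blocks.0.attn', 2), ('blocks.0.mlp', 3), ('blocks.1.attn', 2), ('blocks.1.mlp', 3)]
print(placement_units(model, {"DecoderBlock"}))
# [('embed', 4), ('blocks.0', 5), ('blocks.1', 5)]
```

With the block class marked as unsplittable, each block can only land on a single device as a whole, so its submodules are never separated.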

To save even more memory (for example, when the full state_dict will not fit in RAM), a model’s weights can be split across multiple checkpoint files. Read more about it in the concept guide.
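A sharded checkpoint is usually accompanied by an index file that maps each weight name to the file storing it (the "weight_map" entry in 🤗 Transformers’ *.index.json files). The pure-Python sketch below shows the idea by grouping weights, in order, into shards under a size limit and building such a map. The shard_weights function, the file-naming scheme, and the sizes are hypothetical, and nothing is written to disk.

```python
def shard_weights(weight_sizes: dict[str, int], max_shard_size: int) -> dict:
    """Group weights (in order) into shards of at most max_shard_size and build an index.

    A single weight larger than max_shard_size gets a shard of its own.
    """
    shards: list[list[str]] = []
    current: list[str] = []
    current_size = 0
    for name, size in weight_sizes.items():
        # Start a new shard if adding this weight would exceed the limit
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += size
    if current:
        shards.append(current)

    total = len(shards)
    weight_map = {}
    for i, names in enumerate(shards, start=1):
        filename = f"model-{i:05d}-of-{total:05d}.bin"  # hypothetical naming scheme
        for name in names:
            weight_map[name] = filename
    return {"metadata": {"total_size": sum(weight_sizes.values())}, "weight_map": weight_map}


index = shard_weights({"embed": 4, "layers.0": 3, "layers.1": 3, "head": 4}, max_shard_size=6)
print(index["weight_map"])
# {'embed': 'model-00001-of-00003.bin', 'layers.0': 'model-00002-of-00003.bin', 'layers.1': 'model-00002-of-00003.bin', 'head': 'model-00003-of-00003.bin'}
```

Because each shard is small, only one shard’s worth of weights needs to be held in RAM at a time while loading.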

Now that the model is fully dispatched, you can perform inference with it as normal:

inputs = torch.randn(2, 3).to("cuda")  # example input; use whatever shape your model expects
with torch.no_grad():
    output = model(inputs)

Each time the input reaches a layer whose weights were offloaded, those weights are moved from the CPU to the GPU (or from disk to CPU to GPU), the output is computed, and the weights are then removed from the GPU again before moving on to the next layer. While this adds some overhead to inference, it makes it possible to run a model of any size on your system, as long as the largest layer fits on your GPU.
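The requirement that the largest layer fit on the GPU can be checked with a little arithmetic. The sketch below simulates one forward pass under simplified, hypothetical assumptions: layers placed on the GPU stay resident, each offloaded layer occupies the GPU only while it runs, and activation memory is ignored. It illustrates the idea and is not how 🤗 Accelerate’s hooks are implemented; the layer names and sizes are made up.

```python
def simulate_forward(layers: list[tuple[str, int, str]]) -> tuple[list[str], int]:
    """Simulate one forward pass and return (event log, peak GPU memory).

    Each layer is (name, size, home), where home is "gpu", "cpu", or "disk".
    """
    resident = sum(size for _, size, home in layers if home == "gpu")
    peak = resident
    log = []
    for name, size, home in layers:
        if home == "gpu":
            log.append(f"{name}: compute on gpu")
            continue
        # Offloaded weights: bring them onto the GPU, compute, then release them
        path = "disk -> cpu -> gpu" if home == "disk" else "cpu -> gpu"
        log.append(f"{name}: load {path}, compute, release")
        peak = max(peak, resident + size)
    return log, peak


layers = [("embed", 4, "gpu"), ("block.0", 3, "cpu"), ("block.1", 5, "disk"), ("head", 2, "cpu")]
log, peak = simulate_forward(layers)
for line in log:
    print(line)
print("peak GPU memory:", peak)  # peak GPU memory: 9
```

When every layer is offloaded, the resident size is zero and the peak is simply the size of the largest layer.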

Multiple GPUs can be used, but this is considered “model parallelism”: only one GPU is active at any given moment, while the next one waits for the previous one to send it its output. Launch your script normally with python; you do not need torchrun, accelerate launch, etc.
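Because the layers run one after another, splitting a model across GPUs this way does not speed up a single forward pass: at each step exactly one device is working. The sketch below walks a hypothetical device map in execution order and records which GPU is active at each step and where activations are handed over to the next GPU; the layer names and assignments are made up.

```python
def execution_trace(device_map: dict[str, int]) -> tuple[list[int], list[tuple[int, int]]]:
    """Return the active GPU at each step and the GPU-to-GPU handoffs.

    device_map lists layers in execution order, mapping each to a GPU index.
    """
    active = []
    handoffs = []
    previous = None
    for device in device_map.values():
        if previous is not None and device != previous:
            handoffs.append((previous, device))  # activations are sent to the next GPU
        active.append(device)  # only this GPU does work during this step
        previous = device
    return active, handoffs


active, handoffs = execution_trace({"embed": 0, "block.0": 0, "block.1": 1, "block.2": 1, "head": 2})
print(active)  # [0, 0, 1, 1, 2]
print(handoffs)  # [(0, 1), (1, 2)]
```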

For a visual representation of this, check out the animation below:

Complete Example

Below is the full example combining the steps above:

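import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# ModelClass and checkpoint_file are placeholders for your own model class and checkpoint path
with init_empty_weights():
    model = ModelClass(...)

model = load_checkpoint_and_dispatch(
    model, checkpoint=checkpoint_file, device_map="auto"
)

inputs = torch.randn(2, 3).to("cuda")  # example input; use whatever shape your model expects
with torch.no_grad():
    output = model(inputs)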

Using 🤗 Transformers, 🤗 Diffusers, and other 🤗 Open Source Libraries

Libraries that support 🤗 Accelerate big model inference include all of the earlier logic in their from_pretrained constructors.

You use them by passing the name of the model to download from the 🤗 Hub and setting device_map="auto", along with a few optional parameters.

As a brief example, we will use 🤗 Transformers to load BigScience’s T0pp model.

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", device_map="auto")

After loading, all of the preparation steps from the previous section have already been done, and the model is ready to make use of all the resources on your machine. These constructors also let you save memory by choosing the precision the model is loaded in through the torch_dtype parameter, for example:

import torch
from transformers import AutoModelForSeq2SeqLM

# Load the weights in half precision to halve their memory footprint
model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/T0pp", device_map="auto", torch_dtype=torch.float16)
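To see why torch_dtype matters, a rough estimate of the weight memory is the parameter count multiplied by the bytes per parameter. The sketch below applies that to T0pp’s roughly 11 billion parameters; it ignores activations, buffers, and framework overhead, so actual memory usage will be higher.

```python
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def weights_gb(num_params: int, dtype: str) -> float:
    """Rough size of the weights alone, in gigabytes (10**9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 10**9


T0PP_PARAMS = 11_000_000_000  # approximately 11B parameters

for dtype in ("float32", "float16"):
    print(f"{dtype}: ~{weights_gb(T0PP_PARAMS, dtype):.0f} GB")
# float32: ~44 GB
# float16: ~22 GB
```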

To learn more about these options, see the 🤗 Transformers documentation.

Where to go from here

For a much more detailed look at big model inference, check out the conceptual guide on it.