Optimum documentation

Run Inference

You are viewing v1.15.0 version. A newer version v1.23.3 is available.


This section shows how to run inference-only workloads on Gaudi. For more advanced information about how to speed up inference, check out this guide.

With GaudiTrainer

Below is a template for running inference with a GaudiTrainer instance, computing the accuracy over the given dataset:

import evaluate
import numpy as np
from optimum.habana import GaudiTrainer

metric = evaluate.load("accuracy")

# You can define your own compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with
# `predictions` and `label_ids` fields) and must return a dictionary mapping strings to floats.
def my_compute_metrics(p):
    # Turn logits into predicted class indices, then compare them with the reference labels
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

# Trainer initialization (my_model, my_gaudi_config, my_args, eval_dataset,
# my_tokenizer and my_data_collator are assumed to be defined beforehand)
trainer = GaudiTrainer(
    model=my_model,
    gaudi_config=my_gaudi_config,
    args=my_args,
    train_dataset=None,
    eval_dataset=eval_dataset,
    compute_metrics=my_compute_metrics,
    tokenizer=my_tokenizer,
    data_collator=my_data_collator,
)

# Run inference
metrics = trainer.evaluate()
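To make the contract of my_compute_metrics concrete, here is a minimal sketch that applies the same argmax-then-accuracy logic to a toy batch, outside of any trainer. ToyEvalPrediction is a hypothetical stand-in for transformers.EvalPrediction carrying only the two fields used above, and np.argmax plus the evaluate accuracy metric are replaced with plain Python so the snippet runs on its own.

```python
from collections import namedtuple

# Hypothetical stand-in for transformers.EvalPrediction: only the two fields used above.
ToyEvalPrediction = namedtuple("ToyEvalPrediction", ["predictions", "label_ids"])


def argmax_rows(logits):
    """Return the index of the largest value in each row (like np.argmax(..., axis=1))."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def toy_compute_metrics(p):
    """Accuracy = fraction of predicted classes that match the reference labels."""
    preds = argmax_rows(p.predictions)
    correct = sum(int(pred == label) for pred, label in zip(preds, p.label_ids))
    return {"accuracy": correct / len(p.label_ids)}


# Four examples, three classes: one row of logits per example, plus the true labels.
toy = ToyEvalPrediction(
    predictions=[[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 0.9], [0.7, 0.2, 0.1]],
    label_ids=[1, 0, 2, 1],
)
print(toy_compute_metrics(toy))  # {'accuracy': 0.75}
```

The returned dictionary has the same shape as the output of the accuracy metric. In the standard Trainer evaluation loop, which GaudiTrainer builds on, each key is reported with an eval_ prefix, so the value above would appear as eval_accuracy in metrics.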

The variable my_args should contain inference-specific arguments; take a look here to see which arguments are worth setting for inference.
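As a minimal sketch, assuming optimum-habana is installed, my_args could be built from the same inference-oriented options that the example scripts accept on the command line (see the template in the examples section). Passing them as GaudiTrainingArguments fields is based on that correspondence between CLI flags and argument fields; INFERENCE_ARGS and make_inference_args are hypothetical names, and the batch size is a placeholder.

```python
# Inference-oriented options; the keys mirror the CLI flags used by the example scripts.
INFERENCE_ARGS = {
    "do_train": False,  # inference only: skip the training loop
    "do_eval": True,  # run the evaluation loop
    "per_device_eval_batch_size": 8,  # placeholder batch size
    "use_habana": True,  # run on Gaudi (HPU)
    "use_lazy_mode": True,  # use Habana's lazy execution mode
    "use_hpu_graphs_for_inference": True,  # replay captured HPU graphs to reduce host overhead
}


def make_inference_args(output_dir):
    """Build GaudiTrainingArguments for an evaluation-only run (hypothetical helper)."""
    # Imported here so this snippet can be loaded even where optimum-habana is not installed.
    from optimum.habana import GaudiTrainingArguments

    return GaudiTrainingArguments(output_dir=output_dir, **INFERENCE_ARGS)
```

Passing args=make_inference_args("path_to_my_output_dir") in place of my_args in the template above would give an evaluation-only configuration; the Gaudi configuration itself is still supplied through the gaudi_config argument of GaudiTrainer.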

In our Examples

All our examples contain instructions for running inference with a given model on a given dataset. The procedure is the same for every example: run the example script with --do_eval and --per_device_eval_batch_size, and without --do_train. A simple template is the following:

python path_to_the_example_script \
  --model_name_or_path my_model_name \
  --gaudi_config_name my_gaudi_config_name \
  --dataset_name my_dataset_name \
  --do_eval \
  --per_device_eval_batch_size my_batch_size \
  --output_dir path_to_my_output_dir \
  --use_habana \
  --use_lazy_mode \
  --use_hpu_graphs_for_inference
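When launching many such runs (for example one per model), a small Python helper can assemble the command and encode the rule stated above: --do_eval and --per_device_eval_batch_size present, --do_train absent. This is a hypothetical sketch: build_eval_command and its parameters are illustrative names, and the flags are exactly those of the template.

```python
import shlex


def build_eval_command(script, model, gaudi_config, dataset, batch_size, output_dir):
    """Assemble an inference-only command line for an example script (hypothetical helper)."""
    # --do_train is deliberately omitted: this command only evaluates.
    return [
        "python", script,
        "--model_name_or_path", model,
        "--gaudi_config_name", gaudi_config,
        "--dataset_name", dataset,
        "--do_eval",
        "--per_device_eval_batch_size", str(batch_size),
        "--output_dir", output_dir,
        "--use_habana",
        "--use_lazy_mode",
        "--use_hpu_graphs_for_inference",
    ]


cmd = build_eval_command(
    "path_to_the_example_script", "my_model_name", "my_gaudi_config_name",
    "my_dataset_name", 8, "path_to_my_output_dir",
)
print(shlex.join(cmd))  # printable shell form of the command
# To actually launch it on a Gaudi machine: subprocess.run(cmd, check=True)
```

Keeping the arguments as a list (rather than one string) avoids shell-quoting issues when the command is passed to subprocess.run.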