Optimization

You are viewing the main version, which requires installation from source. If you'd like a regular pip install, check out the latest stable version (v1.19.0).

🤗 Optimum Intel provides an openvino package that lets you apply a variety of model compression methods, such as quantization and pruning, to many models hosted on the 🤗 Hub, using the NNCF framework.

Post-training

Quantization is a technique to reduce the computational and memory costs of running inference by representing the weights and/or the activations with lower-precision data types, such as 8-bit or 4-bit integers.
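The sketch below shows the core idea in plain Python. It uses a simple symmetric, per-tensor int8 scheme chosen for illustration, not NNCF's actual implementation. Each float weight is stored as an 8-bit integer plus one shared scale, and dequantizing introduces a rounding error of at most half a quantization step.

```python
def quantize_int8(weights):
    """Symmetric per-tensor int8 quantization: w ≈ q * scale, with q in [-127, 127]."""
    # One scale for the whole tensor, chosen so the largest |w| maps to 127.
    scale = max(abs(w) for w in weights) / 127 or 1.0  # `or 1.0` guards an all-zero tensor
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q, scale):
    """Recover approximate float values from the stored integers."""
    return [x * scale for x in q]


weights = [0.42, -1.27, 0.003, 0.9]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
# Each integer needs 1 byte instead of the 4 bytes of an fp32 value; the price is a
# rounding error of at most scale / 2 per weight.
max_error = max(abs(w - r) for w, r in zip(weights, restored))
```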

Weight-only quantization

Quantization can be applied on the model’s Linear, Convolutional and Embedding layers, enabling the loading of large models on memory-limited devices. For example, when applying 8-bit quantization, the resulting model will be x4 smaller than its fp32 counterpart. For 4-bit quantization, the reduction in memory could theoretically reach x8, but is closer to x6 in practice.
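The x4 and x8 figures follow from bytes per parameter: fp32 uses 4 bytes, int8 uses 1, and int4 uses 0.5. The rough sketch below adds two overheads that are assumptions chosen for illustration, not measured values: a share of weights kept in 8-bit, and one fp16 scale stored per group of weights. Together they show why the practical reduction for 4-bit falls short of x8.

```python
FP32_BYTES = 4


def compression_ratio(bits, int8_fraction=0.0, group_size=None, scale_bytes=2):
    """Approximate fp32-size / quantized-size, computed per parameter.

    bits: precision of the primary weights (e.g. 8 or 4).
    int8_fraction: share of weights kept in 8-bit instead (illustrative assumption).
    group_size: if set, one scale of `scale_bytes` bytes is stored per group of weights.
    """
    bytes_per_param = (1 - int8_fraction) * bits / 8 + int8_fraction * 1.0
    if group_size is not None:
        bytes_per_param += scale_bytes / group_size  # per-group scale overhead
    return FP32_BYTES / bytes_per_param


print(compression_ratio(8))  # 4.0
print(compression_ratio(4))  # 8.0, the theoretical limit
print(compression_ratio(4, int8_fraction=0.2, group_size=128))  # roughly 6.5
```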

8-bit

For the 8-bit weight quantization you can set load_in_8bit=True to load your model’s weights in 8-bit:

from optimum.intel import OVModelForCausalLM

model_id = "helenai/gpt2-ov"
model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)

# Save the int8 model, which will be x4 smaller than its fp32 counterpart
saving_directory = "gpt2-ov-int8"  # hypothetical output directory
model.save_pretrained(saving_directory)

load_in_8bit is enabled by default for models larger than 1 billion parameters. You can disable it with load_in_8bit=False.

You can also provide a quantization_config instead to specify additional optimization parameters.

4-bit

For the 4-bit weight quantization, you need a quantization_config to define the optimization parameters, for example:

from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

model_id = "helenai/gpt2-ov"  # the same example model as in the 8-bit snippet
quantization_config = OVWeightQuantizationConfig(bits=4)
model = OVModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)

You can tune the quantization parameters to achieve a better performance/accuracy trade-off, as follows:

quantization_config = OVWeightQuantizationConfig(bits=4, sym=False, ratio=0.8, dataset="ptb")

By default, the quantization scheme is asymmetric. To make it symmetric, set sym=True.
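To see what the sym option changes, here is a simplified pure-Python comparison of the two schemes at 4 bits. It is an illustration, not NNCF's exact integer ranges. The symmetric variant fixes the zero point at 0, while the asymmetric one shifts the grid with a zero point. For skewed data (for example, all-positive values) the asymmetric grid therefore has a finer step.

```python
def quantize_symmetric(values, bits=4):
    """Signed symmetric scheme: zero point fixed at 0, integers in [-(2**(bits-1) - 1), 2**(bits-1) - 1]."""
    qmax = 2 ** (bits - 1) - 1
    scale = max(abs(v) for v in values) / qmax or 1.0
    q = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return q, scale


def quantize_asymmetric(values, bits=4):
    """Unsigned asymmetric scheme: a zero point shifts the grid so [min, max] maps onto [0, 2**bits - 1]."""
    qmax = 2 ** bits - 1
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)  # the grid must contain 0 exactly
    scale = (hi - lo) / qmax or 1.0
    zero_point = round(-lo / scale)
    q = [max(0, min(qmax, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point


# Skewed, all-positive weights: the symmetric grid wastes its negative half.
skewed = [0.0, 0.5, 1.0, 2.0]
_, sym_scale = quantize_symmetric(skewed)
_, asym_scale, asym_zero_point = quantize_asymmetric(skewed)
print(sym_scale, asym_scale)  # asymmetric step is roughly half the symmetric one here
```

The asymmetric scheme costs an extra zero point per scale and a little more arithmetic at inference time. That is one reason symmetric quantization remains available as an option.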

For 4-bit quantization, you can also specify the following arguments in the quantization configuration:

  • The group_size parameter defines the group size to use for quantization; setting it to -1 results in per-column quantization.
  • The ratio parameter controls the ratio between 4-bit and 8-bit quantization. If set to 0.9, 90% of the layers will be quantized to int4 and 10% to int8.
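The effect of group_size can be sketched on a single weight row. In this simplified illustration, one row stands in for one output channel, and group_size=-1 gives one scale for the whole row, which is how the sketch models the per-column case. It uses a symmetric [-7, 7] grid rather than NNCF's exact 4-bit format. A single outlier inflates a row-wide scale, so small weights collapse to zero. Smaller groups each get their own scale and preserve those weights.

```python
def quantize_groupwise(row, bits=4, group_size=4):
    """Symmetric quantization of one weight row with one scale per group of `group_size` values.

    group_size=-1 uses a single scale for the whole row.
    """
    qmax = 2 ** (bits - 1) - 1  # 7 for 4-bit in this simplified scheme
    if group_size == -1:
        group_size = len(row)
    q, scales = [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        scale = max(abs(v) for v in group) / qmax or 1.0
        scales.append(scale)
        q.extend(max(-qmax, min(qmax, round(v / scale))) for v in group)
    return q, scales


def dequantize_groupwise(q, scales, group_size):
    if group_size == -1:
        group_size = len(q)
    return [x * scales[i // group_size] for i, x in enumerate(q)]


def total_error(row, group_size):
    q, scales = quantize_groupwise(row, group_size=group_size)
    restored = dequantize_groupwise(q, scales, group_size)
    return sum(abs(a - b) for a, b in zip(row, restored)), len(scales)


row = [0.01, -0.02, 0.015, 0.005, 2.0, -0.5, 0.3, 0.1]  # one outlier dominates the row
err_row, n_scales_row = total_error(row, group_size=-1)
err_grouped, n_scales_grouped = total_error(row, group_size=4)
```

More groups mean more stored scales. This is the model-size cost of a smaller group_size that the next paragraph mentions.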

Smaller group_size and ratio values usually improve accuracy at the expense of model size and inference latency.
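The ratio parameter can be pictured as a split of the layers between two precisions. NNCF decides internally which layers stay in int8. The sketch below replaces that decision with a hypothetical per-layer sensitivity score and keeps the most sensitive layers in int8. The scores and layer names are made up for illustration.

```python
def assign_precisions(sensitivity, ratio):
    """Assign 'int4' to `ratio` of the layers and 'int8' to the rest.

    sensitivity: hypothetical per-layer scores; the most sensitive layers stay in int8.
    """
    n_int4 = round(ratio * len(sensitivity))
    ranked = sorted(sensitivity, key=sensitivity.get)  # least sensitive first
    return {name: ("int4" if i < n_int4 else "int8") for i, name in enumerate(ranked)}


sensitivity = {f"layer{i}": i / 10 for i in range(10)}  # hypothetical scores
plan = assign_precisions(sensitivity, ratio=0.9)
```

Lowering ratio moves more layers to int8. Accuracy usually improves, but the model gets larger.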

Static quantization

When applying post-training static quantization, both the weights and the activations are quantized. Quantizing the activations requires an additional calibration step, which consists of feeding a calibration_dataset to the network in order to estimate the activation quantization parameters.
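The following pure-Python sketch shows what the calibration step estimates. It assumes the simplest range estimator, a running min/max, and a made-up layer. NNCF offers more elaborate estimators. The observer records the range of a layer's outputs over the calibration batches and turns it into an asymmetric uint8 scale and zero point.

```python
class MinMaxObserver:
    """Tracks the running min/max of activations seen during calibration."""

    def __init__(self):
        self.lo = float("inf")
        self.hi = float("-inf")

    def update(self, activations):
        self.lo = min(self.lo, min(activations))
        self.hi = max(self.hi, max(activations))

    def quant_params(self, num_bits=8):
        """Asymmetric scale / zero point covering the observed range (and 0)."""
        qmax = 2 ** num_bits - 1
        lo, hi = min(self.lo, 0.0), max(self.hi, 0.0)
        scale = (hi - lo) / qmax or 1.0
        zero_point = round(-lo / scale)
        return scale, zero_point


def toy_layer(x):
    """Hypothetical layer standing in for a real network: ReLU(2x - 1)."""
    return [max(0.0, 2 * v - 1) for v in x]


calibration_batches = [[0.1, 0.6, 1.2], [0.9, 2.5, 0.3], [1.7, 0.0, 0.8]]
observer = MinMaxObserver()
for batch in calibration_batches:
    observer.update(toy_layer(batch))  # run calibration data through the layer
scale, zero_point = observer.quant_params()
```

The calibration data should resemble real inputs. Ranges estimated on unrepresentative samples produce scales that clip or waste resolution at inference time.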

Here is how to apply static quantization on a fine-tuned DistilBERT given your own calibration_dataset:

from transformers import AutoTokenizer
from optimum.intel import OVQuantizer, OVModelForSequenceClassification

model_id = "distilbert-base-uncased-finetuned-sst-2-english"
model = OVModelForSequenceClassification.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# The directory where the quantized model will be saved
save_dir = "ptq_model"

quantizer = OVQuantizer.from_pretrained(model)

# Apply static quantization and export the resulting quantized model to OpenVINO IR format
quantizer.quantize(calibration_dataset=calibration_dataset, save_directory=save_dir)
# Save the tokenizer
tokenizer.save_pretrained(save_dir)

The calibration dataset can also be created easily using your OVQuantizer:

from functools import partial

def preprocess_function(examples, tokenizer):
    return tokenizer(examples["sentence"], padding="max_length", max_length=128, truncation=True)

# Create the calibration dataset used to perform static quantization
calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=partial(preprocess_function, tokenizer=tokenizer),
    num_samples=300,
    dataset_split="train",
)

The quantize() method applies post-training static quantization and exports the resulting quantized model to the OpenVINO Intermediate Representation (IR). The resulting graph is represented by two files: an XML file describing the network topology and a binary file containing the weights. The resulting model can be run on any target Intel device.

Hybrid quantization

Traditional optimization methods such as post-training 8-bit quantization do not work well for Stable Diffusion (SD) models and can lead to poor generation results. On the other hand, weight compression alone does not improve performance significantly when applied to Stable Diffusion models, because the size of the activations is comparable to that of the weights. The U-Net component accounts for most of the pipeline's overall execution time. Optimizing just this one component can therefore bring substantial inference speedups while keeping accuracy acceptable without fine-tuning. Quantizing the rest of the diffusion pipeline does not significantly improve inference performance, but could lead to substantial accuracy degradation. The proposal is therefore to apply quantization in hybrid mode to the U-Net model and weight-only quantization to the rest of the pipeline components:

  • U-Net: quantization applied to both the weights and the activations
  • Text encoder and VAE encoder/decoder: quantization applied to the weights only

The hybrid mode involves the quantization of weights in MatMul and Embedding layers, and activations of other layers, facilitating accuracy preservation post-optimization while reducing the model size.

The quantization_config is utilized to define optimization parameters for optimizing the SD pipeline. To enable hybrid quantization, specify the quantization dataset in the quantization_config. If the dataset is not defined, weight-only quantization will be applied on all components.

from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig

model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"  # example checkpoint (assumption); any SD model on the Hub
model = OVStableDiffusionPipeline.from_pretrained(
    model_id,
    export=True,
    quantization_config=OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions"),
)

For more details, please refer to the corresponding NNCF documentation.
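The decision rules described in this section can be summarized in a small sketch. The component names and mode labels below are hypothetical, and this is not how optimum.intel is implemented internally. It encodes only what the text states: with a dataset, the U-Net gets hybrid quantization and the other components get weight-only quantization. Without a dataset, every component gets weight-only quantization. Inside the U-Net, MatMul and Embedding layers have their weights compressed, while other layers have their activations quantized.

```python
PIPELINE_COMPONENTS = ["unet", "text_encoder", "vae_encoder", "vae_decoder"]  # hypothetical names


def plan_quantization(components, dataset=None):
    """Per-component mode: hybrid for the U-Net when a dataset is given, weight-only otherwise."""
    return {
        name: "hybrid" if (dataset is not None and name == "unet") else "weight_only"
        for name in components
    }


def hybrid_unet_layer_mode(op_type):
    """Hybrid mode inside the U-Net: compress MatMul/Embedding weights, quantize other layers' activations."""
    return "compress_weights" if op_type in {"MatMul", "Embedding"} else "quantize_activations"


with_dataset = plan_quantization(PIPELINE_COMPONENTS, dataset="conceptual_captions")
without_dataset = plan_quantization(PIPELINE_COMPONENTS)
```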

Training-time

Apart from optimizing a model after training, as with the post-training quantization above, optimum.openvino also provides optimization methods that run during training: Quantization-Aware Training (QAT) and Joint Pruning, Quantization and Distillation (JPQD).

Quantization-Aware Training (QAT)

QAT simulates the effects of quantization during training in order to reduce their impact on the model’s accuracy. It is recommended when post-training quantization causes high accuracy degradation. Here is an example of how to fine-tune a DistilBERT on the SST-2 task while applying QAT.

  import evaluate
  import numpy as np
  from transformers import (
      AutoModelForSequenceClassification,
      AutoTokenizer,
      TrainingArguments,
      default_data_collator,
  )
  from datasets import load_dataset
- from transformers import Trainer
+ from optimum.intel import OVConfig, OVTrainer, OVModelForSequenceClassification

  model_id = "distilbert-base-uncased-finetuned-sst-2-english"
  model = AutoModelForSequenceClassification.from_pretrained(model_id)
  tokenizer = AutoTokenizer.from_pretrained(model_id)
  # The directory where the quantized model will be saved
  save_dir = "qat_model"
  dataset = load_dataset("glue", "sst2")
  dataset = dataset.map(
      lambda examples: tokenizer(examples["sentence"], padding=True), batched=True
  )
  metric = evaluate.load("glue", "sst2")

  def compute_metrics(eval_preds):
      preds = np.argmax(eval_preds.predictions, axis=1)
      return metric.compute(predictions=preds, references=eval_preds.label_ids)

  # Load the default quantization configuration detailing the quantization we wish to apply
+ ov_config = OVConfig()

- trainer = Trainer(
+ trainer = OVTrainer(
      model=model,
      args=TrainingArguments(save_dir, num_train_epochs=1.0, do_train=True, do_eval=True),
      train_dataset=dataset["train"].select(range(300)),
      eval_dataset=dataset["validation"],
      compute_metrics=compute_metrics,
      tokenizer=tokenizer,
      data_collator=default_data_collator,
+     ov_config=ov_config,
+     task="text-classification",
  )

  # Train the model while applying quantization
  train_result = trainer.train()
  metrics = trainer.evaluate()
  # Export the quantized model to OpenVINO IR format and save it
  trainer.save_model()

  # Load the resulting quantized model
- model = AutoModelForSequenceClassification.from_pretrained(save_dir)
+ model = OVModelForSequenceClassification.from_pretrained(save_dir)
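"Simulating the effects of quantization" usually means inserting fake-quantization steps into the forward pass. A fake-quantization step quantizes and immediately dequantizes a value, so later layers see the rounding error. In the backward pass, gradients flow through this step with a straight-through estimator. The pure-Python sketch below illustrates the idea with a fixed scale. It is an illustration of the general technique only. NNCF's FakeQuantize operations also learn their ranges during training.

```python
def fake_quantize(x, scale, qmin=-128, qmax=127):
    """QAT forward pass: quantize then immediately dequantize, keeping values in float."""
    return [max(qmin, min(qmax, round(v / scale))) * scale for v in x]


def fake_quantize_grad(x, grad_out, scale, qmin=-128, qmax=127):
    """Straight-through estimator: pass the gradient unchanged inside the range, zero it outside."""
    lo, hi = qmin * scale, qmax * scale
    return [g if lo <= v <= hi else 0.0 for v, g in zip(x, grad_out)]


x = [0.04, 0.26, 20.0]                  # the last value lies outside the int8 range
y = fake_quantize(x, scale=0.1)         # roughly [0.0, 0.3, 12.7]
grads = fake_quantize_grad(x, [1.0, 1.0, 1.0], scale=0.1)
```

Rounding has a zero gradient almost everywhere. The straight-through estimator is what lets the weights keep learning while the loss already reflects quantization error.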

Joint Pruning, Quantization and Distillation (JPQD)

Besides quantization, compression methods like pruning and distillation are commonly used to further improve task performance and efficiency. Structured pruning slims a model down to lower its computational demands, while distillation leverages the knowledge of a teacher model, usually a larger one, to improve the model's predictions. Combining these methods with quantization can yield an optimized model with significant efficiency gains while retaining good task accuracy. In optimum.openvino, OVTrainer provides the capability to jointly prune, quantize and distill a model during training. The following is an example of how to perform this optimization on BERT-base for the SST-2 task.

First, we create a config to specify the target algorithms. As optimum.openvino relies on NNCF as its backend, the config format follows the NNCF specifications (see here). In the example config below, we specify pruning and quantization as a list of compression algorithms with their hyperparameters. The pruning method closely resembles the work of Lagunas et al., 2021, Block Pruning For Faster Transformers, whereas the quantization refers to QAT. With this configuration, the model under optimization will be initialized with pruning and quantization operators at the beginning of training.

# Each list entry configures one NNCF compression algorithm
compression_config = [
    {
        "algorithm": "movement_sparsity",
        "params": {
            "warmup_start_epoch": 1,
            "warmup_end_epoch": 4,
            "importance_regularization_factor": 0.01,
            "enable_structured_masking": True,
        },
        "sparse_structure_by_scopes": [
            {"mode": "block",   "sparse_factors": [32, 32], "target_scopes": "{re}.*BertAttention.*"},
            {"mode": "per_dim", "axis": 0,                  "target_scopes": "{re}.*BertIntermediate.*"},
            {"mode": "per_dim", "axis": 1,                  "target_scopes": "{re}.*BertOutput.*"},
        ],
        "ignored_scopes": ["{re}.*NNCFEmbedding", "{re}.*pooler.*", "{re}.*LayerNorm.*"],
    },
    {
        "algorithm": "quantization",
        "weights": {"mode": "symmetric"},
        "activations": {"mode": "symmetric"},
    },
]
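To give an intuition for the "block" entry in sparse_structure_by_scopes, here is a toy sketch of block-structured masking. In this illustration, whole blocks of a weight matrix are kept or dropped together based on their mean importance score. Movement sparsity learns these importance scores during training, with regularization and a warmup schedule. Here the scores and the threshold are simply given, and the matrix uses 2x2 blocks instead of 32x32. The "per_dim" entries apply the same idea to entire rows or columns.

```python
def block_mask(importance, block_rows, block_cols, threshold):
    """Keep (1) or drop (0) whole blocks depending on their mean importance score."""
    n_rows, n_cols = len(importance), len(importance[0])
    mask = [[0] * n_cols for _ in range(n_rows)]
    for r0 in range(0, n_rows, block_rows):
        for c0 in range(0, n_cols, block_cols):
            cells = [(r, c)
                     for r in range(r0, min(r0 + block_rows, n_rows))
                     for c in range(c0, min(c0 + block_cols, n_cols))]
            mean = sum(importance[r][c] for r, c in cells) / len(cells)
            keep = 1 if mean >= threshold else 0
            for r, c in cells:
                mask[r][c] = keep
    return mask


importance = [  # hypothetical learned scores for a 4x4 weight matrix
    [0.9, 0.8, 0.1, 0.0],
    [0.7, 0.6, 0.2, 0.1],
    [0.0, 0.1, 0.9, 0.9],
    [0.2, 0.1, 0.8, 0.7],
]
mask = block_mask(importance, block_rows=2, block_cols=2, threshold=0.5)
```

Dropping whole blocks, rather than scattered individual weights, is what lets the pruned layers actually shrink and run faster on dense hardware.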

Known limitation: structured pruning with movement sparsity currently supports only the BERT, Wav2vec2 and Swin model families. See here for more information.

Once the config is ready, we can start developing the training pipeline, as in the snippet below. Since we are customizing joint compression with the config above, note that OVConfig is initialized with that config (parsing JSON into Python objects is skipped for brevity). For distillation, users need to load the teacher model, which works just like loading any other model with the transformers API. OVTrainingArguments extends transformers’ TrainingArguments with distillation hyperparameters, namely the distillation weight and temperature. The snippet below shows how to load a teacher model and create training arguments with OVTrainingArguments. The teacher model, the instantiated OVConfig and the OVTrainingArguments are then passed to OVTrainer. Voilà! That is all we need; the rest of the pipeline is identical to native transformers training.

- from transformers import Trainer, TrainingArguments
+ from optimum.intel import OVConfig, OVTrainer, OVTrainingArguments

  # Load teacher model
+ teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_model_or_path)

- ov_config = OVConfig()
+ ov_config = OVConfig(compression=compression_config)

  trainer = OVTrainer(
      model=model,
+     teacher_model=teacher_model,
-     args=TrainingArguments(save_dir, num_train_epochs=1.0, do_train=True, do_eval=True),
+     args=OVTrainingArguments(save_dir, num_train_epochs=1.0, do_train=True, do_eval=True, distillation_temperature=3, distillation_weight=0.9),
      train_dataset=dataset["train"].select(range(300)),
      eval_dataset=dataset["validation"],
      compute_metrics=compute_metrics,
      tokenizer=tokenizer,
      data_collator=default_data_collator,
+     ov_config=ov_config,
      task="text-classification",
  )

  # Train the model like usual, internally the training is applied with pruning, quantization and distillation
  train_result = trainer.train()
  metrics = trainer.evaluate()
  # Export the quantized model to OpenVINO IR format and save it
  trainer.save_model()
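The roles of distillation_temperature and distillation_weight can be illustrated with the standard knowledge-distillation loss of Hinton et al. OVTrainer's exact formula may differ in its details, so treat this pure-Python version as a sketch. The temperature softens both models' output distributions before they are compared. The weight then blends that teacher-matching term with the ordinary cross-entropy on the true label.

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, label, temperature=3.0, weight=0.9):
    """weight * T^2 * KL(teacher_T || student_T) + (1 - weight) * cross-entropy(student, label)."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student) if pt > 0)
    ce = -math.log(softmax(student_logits)[label])
    # T^2 keeps the soft-target gradients on a comparable scale as the temperature changes
    return weight * temperature ** 2 * kl + (1 - weight) * ce


matched = distillation_loss([2.0, 0.5], [2.0, 0.5], label=0)     # student agrees with teacher
mismatched = distillation_loss([0.5, 2.0], [2.0, 0.5], label=0)  # student disagrees
```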

For a description of movement sparsity and how to configure it, see the NNCF documentation here.

For more on the algorithms available in NNCF, see the documentation here.

For complete JPQD scripts, please refer to examples provided here.

Quantization-Aware Training (QAT) and knowledge distillation can also be combined in order to optimize Stable Diffusion models while maintaining accuracy. For more details, take a look at this blog post.

Inference with Transformers pipeline

After applying quantization on our model, we can then easily load it with our OVModelFor<Task> classes and perform inference with OpenVINO Runtime using the Transformers pipelines.

from transformers import AutoTokenizer, pipeline
from optimum.intel import OVModelForSequenceClassification

model_id = "helenai/distilbert-base-uncased-finetuned-sst-2-english-ov-int8"
ov_model = OVModelForSequenceClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
cls_pipe = pipeline("text-classification", model=ov_model, tokenizer=tokenizer)
text = "He's a dreadful magician."
outputs = cls_pipe(text)

[{'label': 'NEGATIVE', 'score': 0.9840195178985596}]