Optimum documentation

Adding BetterTransformer support for new architectures

You are viewing v1.10.1 version. A newer version v1.23.3 is available.

Want to add a new model to BetterTransformer, the fast path of the PyTorch Transformer API? Follow this guide.

Models that should be supported

In theory, any model that has a transformer encoder layer similar to the classic encoder described in the “Attention Is All You Need” paper should be supported. More specifically, a model that has an encoder block with a multi-head attention module (with pre- or post-attention layer norm) should be convertible to its BetterTransformer equivalent. The conditions can be summarized as follows:

  • Use a classic multi-head attention module (for example, DeBERTa cannot be supported)
  • Use either a gelu or relu activation function
  • Have an even number of attention heads
  • Do not use any attention bias (for example, T5 uses attention bias and therefore cannot be supported)
  • Use the same eps for the first and second layer norms of each layer
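
The conditions above can be sketched as a quick pre-check. This is only an illustration: LayerSpec, its field names and unmet_conditions are hypothetical and do not correspond to any Optimum or transformers API, and the example values are made up.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class LayerSpec:
    """Hypothetical summary of one encoder layer (not an Optimum class)."""

    classic_mha: bool          # uses a standard multi-head attention module
    activation: str            # e.g. "gelu", "relu", "silu"
    num_heads: int
    uses_attention_bias: bool  # e.g. a learned bias added to attention scores
    norm1_eps: float
    norm2_eps: float


def unmet_conditions(spec: LayerSpec) -> List[str]:
    """Return the conditions from this guide that `spec` does not satisfy."""
    problems = []
    if not spec.classic_mha:
        problems.append("attention is not classic multi-head attention")
    if spec.activation not in ("gelu", "relu"):
        problems.append(f"unsupported activation: {spec.activation}")
    if spec.num_heads % 2 != 0:
        problems.append("number of attention heads is odd")
    if spec.uses_attention_bias:
        problems.append("attention bias is used")
    if spec.norm1_eps != spec.norm2_eps:
        problems.append("layer norm eps values differ")
    return problems


# Illustrative values only.
encoder_like = LayerSpec(True, "gelu", 12, False, 1e-12, 1e-12)
with_bias = LayerSpec(True, "relu", 8, True, 1e-6, 1e-6)

print(unmet_conditions(encoder_like))  # []
print(unmet_conditions(with_bias))     # ['attention bias is used']
```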

How to convert a model into its BetterTransformer format?

Step 1: Identifying the source layer to change

First, go to optimum/bettertransformer/__init__.py, where you will see the dictionary BetterTransformerManager.MODEL_MAPPING. It maps each model type to a Tuple[str, BetterTransformerBaseLayer]: the name of the nn.Module that can be converted, and the BetterTransformer layer class that replaces it.

Let us go through it step by step for Bert. First, we need to identify the layers that need to be replaced:

>>> from transformers import AutoModel

>>> model = AutoModel.from_pretrained("bert-base-uncased")
>>> print(model)
...
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (11): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)

You can clearly see that the layers that need to be replaced are the BertLayer modules since they contain the whole encoder layer module.
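
Instead of reading the printed model by eye, the candidate layers can also be listed programmatically. The sketch below uses a small stand-in model so that it runs without downloading weights; ToyLayer, ToyModel and find_modules_by_class_name are hypothetical names, and for Bert the class name to search for would be "BertLayer".

```python
import torch.nn as nn


class ToyLayer(nn.Module):
    """Stand-in for an encoder layer such as BertLayer (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)


class ToyModel(nn.Module):
    """Stand-in for a full model with a stack of encoder layers."""

    def __init__(self, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList([ToyLayer() for _ in range(num_layers)])
        self.pooler = nn.Linear(8, 8)


def find_modules_by_class_name(model, class_name):
    """Return the qualified names of submodules whose class name matches."""
    return [
        name
        for name, module in model.named_modules()
        if module.__class__.__name__ == class_name
    ]


names = find_modules_by_class_name(ToyModel(), "ToyLayer")
print(names)  # ['layers.0', 'layers.1', 'layers.2']
```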

Step 2: Building the xxxLayerBetterTransformer module

Check that the identified module is not already copied from another module: inspect the source code in transformers and verify that the class definition does not start with # Copied from .... If it does not, create a class in bettertransformer/models/encoder_model.py. Start with these lines:

import torch
import torch.nn as nn

from ..base import BetterTransformerBaseLayer


class BertLayerBetterTransformer(BetterTransformerBaseLayer):
    def __init__(self, bert_layer, config):
...

Now, make sure to fill in all the necessary attributes. They are:

  • in_proj_weight
  • in_proj_bias
  • out_proj_weight
  • out_proj_bias
  • linear1_weight
  • linear1_bias
  • linear2_weight
  • linear2_bias
  • norm1_eps
  • norm1_weight
  • norm1_bias
  • norm2_weight
  • norm2_bias
  • num_heads
  • embed_dim

Note that these attributes correspond to all the components needed to run a Transformer encoder module; see Figure 1 of the “Attention Is All You Need” paper.
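
As an illustration of how in_proj_weight and in_proj_bias might be filled, the sketch below fuses separate query, key and value projections (as in BertSelfAttention) into a single weight and bias. It assumes the fused layout is query, then key, then value, which is the layout torch.nn.MultiheadAttention uses; the layer sizes are arbitrary, and other architectures may store their projections differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim = 16  # arbitrary size for the example

# Separate projections, as in BertSelfAttention (query, key, value).
query = nn.Linear(embed_dim, embed_dim)
key = nn.Linear(embed_dim, embed_dim)
value = nn.Linear(embed_dim, embed_dim)

# Stack them along the output dimension: rows are [query; key; value].
in_proj_weight = nn.Parameter(torch.cat([query.weight, key.weight, value.weight]))
in_proj_bias = nn.Parameter(torch.cat([query.bias, key.bias, value.bias]))

# One fused projection, split back into three equal chunks.
x = torch.randn(2, 5, embed_dim)
q, k, v = F.linear(x, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

print(in_proj_weight.shape)  # torch.Size([48, 16])
print(in_proj_bias.shape)    # torch.Size([48])
```

Splitting the fused output into three chunks should give the same tensors as calling the three original layers separately, up to floating-point error.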

Sometimes the query, key and value layers need to be “contigufied”; check the modeling_encoder.py file to understand more.

Once you have filled in all these attributes, make sure to also add the lines:

self.is_last_layer = False
self.validate_bettertransformer()

Step 3: Building the forward pass

Start the forward pass with the line super().forward_checker(), so that the parent class can run all its safety checks first.

After the first forward pass, the hidden states need to be nested using the attention mask. Once they are nested, the attention mask is no longer needed and can be set to None. This is how the forward pass is built for Bert. These lines should stay largely the same across models, although the shape of the attention mask sometimes differs between models.

super().forward_checker()

if hidden_states.is_nested:
    attention_mask = None

if attention_mask is not None:
    # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask
    # 0->false->keep this token -inf->true->mask this token
    attention_mask = attention_mask.bool()
    attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
    seqlen = attention_mask.shape[1]
    lengths = torch.sum(~attention_mask, 1)
    if not all([l == seqlen for l in lengths]):
        hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
    attention_mask = None
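
To make the mask handling concrete, the sketch below runs the same conversion on a tiny hand-built mask. It assumes the incoming mask is additive with shape (batch, 1, 1, seq_len), where 0 keeps a token and a very large negative value masks it; the variable names are illustrative, and the private nesting call itself is left out.

```python
import torch

# Padding mask as a tokenizer would produce it: 1 = real token, 0 = padding.
padding_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

# Additive mask of shape (batch, 1, 1, seq_len) (assumed input format).
min_value = torch.finfo(torch.float32).min
additive_mask = (1.0 - padding_mask[:, None, None, :].float()) * min_value

# Same steps as in the forward pass above.
bool_mask = additive_mask.bool()  # True = masked token
bool_mask = torch.reshape(bool_mask, (bool_mask.shape[0], bool_mask.shape[-1]))
seqlen = bool_mask.shape[1]
lengths = torch.sum(~bool_mask, 1)  # number of real tokens per sequence

# Nesting is only worthwhile when at least one sequence is padded.
needs_nesting = not all(l == seqlen for l in lengths)

print(bool_mask.shape)   # torch.Size([2, 4])
print(lengths.tolist())  # [3, 2]
print(needs_nesting)     # True
```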

Once the hidden_states are nested, call torch._transformer_encoder_layer_fwd using the right arguments as follows:

hidden_states = torch._transformer_encoder_layer_fwd(
    hidden_states,
    self.embed_dim,
    self.num_heads,
    self.in_proj_weight,
    self.in_proj_bias,
    self.out_proj_weight,
    self.out_proj_bias,
    self.use_gelu,
    self.norm_first,
    self.norm1_eps,
    self.norm1_weight,
    self.norm1_bias,
    self.norm2_weight,
    self.norm2_bias,
    self.linear1_weight,
    self.linear1_bias,
    self.linear2_weight,
    self.linear2_bias,
    attention_mask,
)
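
To see what each argument stands for, the same quantities can be read off a stock torch.nn.TransformerEncoderLayer, whose parameters follow the same naming. This is a sketch for orientation only: the attrs dictionary is a hypothetical helper, the layer sizes are arbitrary, and the fused private function is not called here.

```python
import torch.nn as nn

layer = nn.TransformerEncoderLayer(
    d_model=16, nhead=4, dim_feedforward=64, activation="gelu", batch_first=True
)

# Hypothetical dictionary keyed by the attribute names used in this guide.
attrs = {
    "embed_dim": layer.self_attn.embed_dim,
    "num_heads": layer.self_attn.num_heads,
    "in_proj_weight": layer.self_attn.in_proj_weight,
    "in_proj_bias": layer.self_attn.in_proj_bias,
    "out_proj_weight": layer.self_attn.out_proj.weight,
    "out_proj_bias": layer.self_attn.out_proj.bias,
    "norm_first": layer.norm_first,
    "norm1_eps": layer.norm1.eps,
    "norm1_weight": layer.norm1.weight,
    "norm1_bias": layer.norm1.bias,
    "norm2_weight": layer.norm2.weight,
    "norm2_bias": layer.norm2.bias,
    "linear1_weight": layer.linear1.weight,
    "linear1_bias": layer.linear1.bias,
    "linear2_weight": layer.linear2.weight,
    "linear2_bias": layer.linear2.bias,
}

# Print tensor shapes for parameters and plain values for scalars.
for name, value in attrs.items():
    print(name, tuple(value.shape) if hasattr(value, "shape") else value)
```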

At the last layer, it is important to “un-nest” the hidden_states so that they can be processed by the following modules. This is done in these lines:

if hidden_states.is_nested and self.is_last_layer:
    hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states,)

Also make sure to return a tuple to follow the convention of transformers.
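
The un-nesting step can be tried in isolation. The sketch below builds a small nested tensor with the public torch.nested API, which is assumed to be available in the installed PyTorch version and is still marked as a prototype feature, then pads it back to a regular tensor. The is_last_layer variable stands in for self.is_last_layer.

```python
import torch

# Two sequences of different lengths (3 and 2 tokens), hidden size 4.
nested = torch.nested.nested_tensor([torch.ones(3, 4), torch.ones(2, 4)])

is_last_layer = True  # stand-in for self.is_last_layer
hidden_states = nested
if hidden_states.is_nested and is_last_layer:
    # Pad the shorter sequence with zeros to get a regular (batch, seq, dim) tensor.
    hidden_states = hidden_states.to_padded_tensor(0.0)

outputs = (hidden_states,)  # tuple, following the transformers convention

print(hidden_states.shape)               # torch.Size([2, 3, 4])
print(hidden_states[1, 2].sum().item())  # 0.0
```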

The best way to reproduce this on your own model is to take inspiration from the provided modeling scripts. Of course, we will be happy to help you convert your model if you open an issue or a Pull Request on optimum!

Step 4: Sanity check!

As a last step, make sure to update the BetterTransformerManager.MODEL_MAPPING dictionary in optimum/bettertransformer/__init__.py with the correct names, and you should be ready to convert your model. For example, for Bert that would be:

MODEL_MAPPING = {
  ...
  "bert": ("BertLayer", BertLayerBetterTransformer),
  ...
}
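
To give an idea of how a mapping of this shape could drive a conversion, here is a minimal sketch that swaps every matching layer in a toy model. It is not Optimum's implementation: ToyLayer, ToyLayerFast, TOY_MAPPING and replace_layers are all hypothetical, and the real conversion may differ in many details.

```python
import torch.nn as nn


class ToyLayer(nn.Module):
    """Stand-in for an original encoder layer (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)


class ToyLayerFast(nn.Module):
    """Hypothetical replacement that reuses the original layer's weights."""

    def __init__(self, original, config=None):
        super().__init__()
        self.weight = original.dense.weight
        self.bias = original.dense.bias


# Same shape as MODEL_MAPPING: model type -> (layer class name, new class).
TOY_MAPPING = {"toy": ("ToyLayer", ToyLayerFast)}


def replace_layers(module, target_name, new_cls):
    """Recursively swap children whose class name equals `target_name`."""
    for child_name, child in list(module.named_children()):
        if child.__class__.__name__ == target_name:
            setattr(module, child_name, new_cls(child))
        else:
            replace_layers(child, target_name, new_cls)
    return module


model = nn.Sequential(ToyLayer(), nn.Sequential(ToyLayer(), nn.ReLU()))
target_name, new_cls = TOY_MAPPING["toy"]
replace_layers(model, target_name, new_cls)

class_names = [type(m).__name__ for m in model.modules()]
print(class_names)
```

Matching on the class name rather than on the class object mirrors the string stored in the mapping, so the mapping does not need to import the original modeling classes.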

Try it out with the conversion method presented in the tutorials section!