Adding BetterTransformer support for new architectures
Want to add a new model to BetterTransformer, the fast path of the PyTorch Transformer API? Follow this guide!
Models that should be supported
In theory, any model that has a transformer encoder layer similar to the classic encoder described in the “Attention Is All You Need” paper should be supported.
More specifically, a model that has an encoder block with a multi-head attention module (with pre- or post-attention layer norm) should be convertible to its BetterTransformer equivalent. The conditions can be summarized as follows:
- Use a classic multi-head attention module (for example, DeBERTa cannot be supported)
- Use either the gelu or relu activation function
- Have an even number of attention heads
- Do not use any attention bias (for example, T5 uses attention bias and therefore cannot be supported)
- eps must be equal between the first and second layer norms for each layer
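The conditions above that depend only on hyperparameters can be sketched as a small pre-check. This is a rough illustration, not part of optimum: EncoderSpec, its field names, and unsupported_reasons are hypothetical, and whether a model uses a classic multi-head attention module still has to be verified by reading its modeling code.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class EncoderSpec:
    """Hypothetical summary of one encoder layer (not an optimum class)."""
    activation: str
    num_heads: int
    uses_attention_bias: bool
    norm1_eps: float
    norm2_eps: float


def unsupported_reasons(spec: EncoderSpec) -> List[str]:
    """Return the list of conditions from the guide that the layer violates."""
    reasons = []
    if spec.activation not in ("gelu", "relu"):
        reasons.append(f"activation {spec.activation!r} is neither gelu nor relu")
    if spec.num_heads % 2 != 0:
        reasons.append(f"odd number of attention heads ({spec.num_heads})")
    if spec.uses_attention_bias:
        reasons.append("attention bias is used")
    if spec.norm1_eps != spec.norm2_eps:
        reasons.append("eps differs between the two layer norms")
    return reasons


# A BERT-like layer passes every check; the second spec fails all four.
bert_like = EncoderSpec("gelu", 12, False, 1e-12, 1e-12)
problematic = EncoderSpec("silu", 3, True, 1e-5, 1e-6)
print(unsupported_reasons(bert_like))
print(len(unsupported_reasons(problematic)))
```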
How to convert a model into its BetterTransformer format?
Step 1: Identifying the source layer to change
First, go to optimum/bettertransformer/__init__.py and you’ll see the dictionary BetterTransformerManager.MODEL_MAPPING. This should contain the mapping between a model type and a Tuple[str, BetterTransformerBaseLayer] composed of the name of the nn.Module that can be converted to its BetterTransformer equivalent, and the equivalent BetterTransformer layer class.
Let us do it step by step for Bert. First, we need to identify the layers that need to be replaced:
>>> from transformers import AutoModel
>>> model = AutoModel.from_pretrained("bert-base-uncased")
>>> print(model)
...
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(11): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
You can clearly see that the layers that need to be replaced are the BertLayer modules, since they contain the whole encoder layer module.
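Instead of reading the printed tree by eye, the candidate layers can also be located by class name. The sketch below uses a toy model so that it runs without downloading weights; ToyLayer, ToyEncoder, and find_modules_by_class_name are hypothetical names, and only nn.Module.named_modules is an actual PyTorch API.

```python
import torch.nn as nn


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a full encoder layer such as BertLayer."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.attention = nn.Linear(dim, dim)
        self.output = nn.Linear(dim, dim)


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a model holding a stack of encoder layers."""

    def __init__(self, num_layers: int = 3):
        super().__init__()
        self.layer = nn.ModuleList([ToyLayer() for _ in range(num_layers)])


def find_modules_by_class_name(model: nn.Module, class_name: str):
    """Return the dotted names of all submodules whose class is `class_name`."""
    return [
        name
        for name, module in model.named_modules()
        if module.__class__.__name__ == class_name
    ]


names = find_modules_by_class_name(ToyEncoder(), "ToyLayer")
print(names)  # ['layer.0', 'layer.1', 'layer.2']
```

For the Bert model printed above, the same helper would be called with "BertLayer" as the class name.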
Step 2: Building the xxxLayerBetterTransformer module
Check that the identified module is not already copied from another module (by inspecting the source code in transformers and checking that the class definition does not start with # Copied from ...). If it is not, create a class in bettertransformer/models/encoder_model.py.
Start with these lines:
import torch
import torch.nn as nn

from ..base import BetterTransformerBaseLayer


class BertLayerBetterTransformer(BetterTransformerBaseLayer):
    def __init__(self, bert_layer, config):
        ...
Now, make sure to fill in all the necessary attributes. The list of attributes is:
in_proj_weight
in_proj_bias
out_proj_weight
out_proj_bias
linear1_weight
linear1_bias
linear2_weight
linear2_bias
norm1_eps
norm1_weight
norm1_bias
norm2_weight
norm2_bias
num_heads
embed_dim
Note that these attributes correspond to all the components that are necessary to run a Transformer Encoder module; see Figure 1 of the “Attention Is All You Need” paper.
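To get a feel for the expected shapes, it can help to look at the corresponding parameters of torch.nn.TransformerEncoderLayer. The mapping below is an assumption based on the attribute names, not something taken from the optimum source, and the sizes are arbitrary.

```python
import torch.nn as nn

embed_dim, num_heads, ffn_dim = 16, 4, 64
layer = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=num_heads, dim_feedforward=ffn_dim, activation="gelu"
)

# Assumed correspondence between the attribute list and the PyTorch layer.
reference = {
    "in_proj_weight": layer.self_attn.in_proj_weight,    # fused q/k/v projection
    "in_proj_bias": layer.self_attn.in_proj_bias,
    "out_proj_weight": layer.self_attn.out_proj.weight,  # attention output projection
    "out_proj_bias": layer.self_attn.out_proj.bias,
    "linear1_weight": layer.linear1.weight,              # feed-forward, first linear
    "linear1_bias": layer.linear1.bias,
    "linear2_weight": layer.linear2.weight,              # feed-forward, second linear
    "linear2_bias": layer.linear2.bias,
    "norm1_weight": layer.norm1.weight,
    "norm1_bias": layer.norm1.bias,
    "norm2_weight": layer.norm2.weight,
    "norm2_bias": layer.norm2.bias,
}
shapes = {name: tuple(param.shape) for name, param in reference.items()}

for name, shape in shapes.items():
    print(f"{name}: {shape}")
print("norm1_eps:", layer.norm1.eps)
print("num_heads:", layer.self_attn.num_heads)
print("embed_dim:", layer.self_attn.embed_dim)
```

Note that in_proj_weight has three times as many rows as embed_dim, because the query, key and value projections are stacked into one matrix.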
Once you have filled in all these attributes (sometimes the query, key and value layers need to be made contiguous; check the modeling_encoder.py file to understand more), make sure to also add the lines:
self.is_last_layer = False
self.validate_bettertransformer()
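As an illustration of how in_proj_weight and in_proj_bias can be filled, here is a minimal sketch that stacks separate query, key and value projections into one fused projection. The standalone nn.Linear modules are placeholders for the ones found in the source layer; treat this as one plausible approach and check the provided modeling scripts for the exact code used for each architecture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim = 8

# Placeholders for the query/key/value projections of the source layer.
query = nn.Linear(embed_dim, embed_dim)
key = nn.Linear(embed_dim, embed_dim)
value = nn.Linear(embed_dim, embed_dim)

# Stack the three projections along the output dimension.
in_proj_weight = nn.Parameter(
    torch.cat([query.weight, key.weight, value.weight]).detach().contiguous()
)
in_proj_bias = nn.Parameter(
    torch.cat([query.bias, key.bias, value.bias]).detach().contiguous()
)

# One fused matmul now yields q, k and v at once.
x = torch.randn(2, 5, embed_dim)
q, k, v = F.linear(x, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

print(in_proj_weight.shape, in_proj_weight.is_contiguous())
print(torch.allclose(q, query(x), atol=1e-5))
```

The order of the concatenation matters: the fused projection is split back into query, key and value in that same order.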
Step 3: Building the forward pass
First of all, start with the line super().forward_checker(). This is needed so that the parent class can run all the safety checks beforehand.
After the first forward pass, the hidden states need to be nested using the attention mask. Once they are nested, the attention mask is no longer needed and can therefore be set to None. This is how the forward pass is built for Bert. These lines should remain largely the same across models, although the shape of the attention mask sometimes differs between models.
super().forward_checker()

if hidden_states.is_nested:
    attention_mask = None

if attention_mask is not None:
    # attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask
    # 0->false->keep this token -inf->true->mask this token
    attention_mask = attention_mask.bool()
    attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
    seqlen = attention_mask.shape[1]
    lengths = torch.sum(~attention_mask, 1)
    if not all([l == seqlen for l in lengths]):
        hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
    attention_mask = None
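The mask handling above can be tried in isolation on a tiny batch. The sketch below assumes an additive mask of shape (batch, 1, 1, seq_len) holding 0 for kept tokens and -inf for padded ones; the exact shape may differ for your model, and the variable names are only illustrative.

```python
import torch

# Two sequences padded to length 4; the second one has only 2 real tokens.
keep = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])

# Build the additive mask: 0 where a token is kept, -inf where it is padding.
additive_mask = torch.zeros(2, 1, 1, 4)
additive_mask.masked_fill_(keep[:, None, None, :] == 0, float("-inf"))

# Same conversion as in the forward pass: -inf -> True (masked), 0 -> False (kept).
bool_mask = additive_mask.bool()
bool_mask = torch.reshape(bool_mask, (bool_mask.shape[0], bool_mask.shape[-1]))

seqlen = bool_mask.shape[1]
lengths = torch.sum(~bool_mask, 1)  # number of real tokens per sequence
has_padding = not all(int(length) == seqlen for length in lengths)

print(lengths.tolist(), has_padding)  # [4, 2] True
```

Nesting only pays off when has_padding is true; if every sequence already fills the full length, the hidden states can stay as a regular tensor.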
Once the hidden_states are nested, call torch._transformer_encoder_layer_fwd with the right arguments, as follows:
hidden_states = torch._transformer_encoder_layer_fwd(
    hidden_states,
    self.embed_dim,
    self.num_heads,
    self.in_proj_weight,
    self.in_proj_bias,
    self.out_proj_weight,
    self.out_proj_bias,
    self.use_gelu,
    self.norm_first,
    self.norm1_eps,
    self.norm1_weight,
    self.norm1_bias,
    self.norm2_weight,
    self.norm2_bias,
    self.linear1_weight,
    self.linear1_bias,
    self.linear2_weight,
    self.linear2_bias,
    attention_mask,
)
At the last layer, it is important to “un-nest” the hidden_states so that they can be processed by the next modules. This is done in these lines:
if hidden_states.is_nested and self.is_last_layer:
    hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states,)
Also make sure to return a tuple to follow the convention of transformers.
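To see what nesting and un-nesting do to the shapes, here is a small round trip. It uses the public torch.nested functions rather than the private torch._nested_tensor_from_mask helper used in the forward pass, so it is an approximation of the same idea, not the code path optimum takes.

```python
import torch

hidden_dim = 3
seq_a = torch.ones(4, hidden_dim)  # sequence with 4 real tokens
seq_b = torch.ones(2, hidden_dim)  # sequence with 2 real tokens

# A nested tensor stores sequences of different lengths without padding.
nested = torch.nested.nested_tensor([seq_a, seq_b])

# Un-nest: pad every sequence back to the longest length, filling with 0.0.
padded = torch.nested.to_padded_tensor(nested, 0.0)

# Follow the transformers convention of returning a tuple.
outputs = (padded,)

print(nested.is_nested, tuple(padded.shape))  # True (2, 4, 3)
```

The padded positions of the second sequence come back as zeros, which is why downstream modules can consume the result as an ordinary batch.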
The best way to reproduce this experiment on your own model is to take some inspiration from the provided modeling scripts. Of course, we will be happy to help you convert your model if you open an issue or a Pull Request on optimum!
Step 4: Sanity check!
As a last step, make sure to update the BetterTransformerManager.MODEL_MAPPING dictionary in optimum/bettertransformer/__init__.py with the correct names, and you should be ready to convert your model. For example, for Bert that would be:
MODEL_MAPPING = {
    ...
    "bert": ("BertLayer", BertLayerBetterTransformer),
    ...
}
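The mapping entry can be read as "for this model type, replace modules of this class name with this layer class". The lookup below is a self-contained sketch of that idea; the stub classes and resolve_conversion are hypothetical and only mimic the shape of the real dictionary.

```python
class StubBaseLayer:
    """Hypothetical stand-in for BetterTransformerBaseLayer."""


class StubBertLayerBetterTransformer(StubBaseLayer):
    """Hypothetical stand-in for BertLayerBetterTransformer."""


# Same structure as the real mapping: model type -> (module name, layer class).
STUB_MODEL_MAPPING = {
    "bert": ("BertLayer", StubBertLayerBetterTransformer),
}


def resolve_conversion(model_type: str):
    """Return (source module name, replacement class) for a model type."""
    if model_type not in STUB_MODEL_MAPPING:
        raise KeyError(f"model type {model_type!r} has no registered conversion")
    return STUB_MODEL_MAPPING[model_type]


source_name, target_class = resolve_conversion("bert")
print(source_name, target_class.__name__)
```

A model type that is missing from the dictionary fails early with a clear error, which is a quick way to confirm that a new entry was registered under the right key.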
Try it out with the conversion method presented in the tutorials section!