feat: selective activation checkpointing
This PR hasn't been tested yet.
This PR adds selective activation checkpointing to the BERT model.
By passing `activation_checkpoint_lvl` in the config, you can set how many of the BERT layers will be checkpointed once `gradient_checkpointing_enable()` is called. Reducing this number saves computation at the cost of increased VRAM usage. Checkpointing does not take effect until `gradient_checkpointing_enable()` is called.
By default, the value is `100`, which means that for any reasonable architecture all layers will be checkpointed. For the base model it might make sense to set this to something like `6` to checkpoint half of the layers.
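
As a usage sketch (the `BertConfig`/`BertForMaskedLM` classes below are stand-ins for whatever config and model classes this repo exposes; only `activation_checkpoint_lvl` and `gradient_checkpointing_enable()` come from this PR):

```python
# Hypothetical usage sketch -- the config/model classes here are placeholders;
# only activation_checkpoint_lvl and gradient_checkpointing_enable() come from this PR.
from transformers import BertConfig, BertForMaskedLM

config = BertConfig(num_hidden_layers=12)
# New field from this PR: number of layers to checkpoint once
# gradient_checkpointing_enable() is called (default 100 = effectively all layers).
config.activation_checkpoint_lvl = 6  # checkpoint half of a 12-layer base model

model = BertForMaskedLM(config)
model.gradient_checkpointing_enable()  # checkpointing only takes effect after this call
```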
We enforce that MLP checkpointing cannot occur within a checkpointed layer.
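
To illustrate the guard (a purely illustrative toy, not the PR's implementation; the class and attribute names are made up), the idea is that when a whole layer is checkpointed, its MLP skips its own checkpointing so `torch.utils.checkpoint.checkpoint` calls never nest:

```python
import torch
from torch.utils.checkpoint import checkpoint


class ToyLayer(torch.nn.Module):
    """Illustrative layer showing the nesting guard, not the PR's actual code."""

    def __init__(self, dim: int, checkpoint_mlp: bool = True):
        super().__init__()
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.GELU())
        self.checkpoint_mlp = checkpoint_mlp  # inner, MLP-level checkpointing
        self.checkpoint_layer = False         # True for the first activation_checkpoint_lvl layers

    def _mlp_forward(self, x):
        # Only checkpoint the MLP when the surrounding layer is NOT checkpointed,
        # so checkpoint() calls never nest.
        if self.checkpoint_mlp and not self.checkpoint_layer:
            return checkpoint(self.mlp, x, use_reentrant=False)
        return self.mlp(x)

    def forward(self, x):
        if self.checkpoint_layer:
            return checkpoint(self._mlp_forward, x, use_reentrant=False)
        return self._mlp_forward(x)
```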
For pretraining, I think it would make sense to set this parameter to `0`; nothing should happen before `gradient_checkpointing_enable()` is called anyway, but better safe than sorry.