feat: selective activation checkpointing

#16

This PR hasn't been tested yet.

This PR adds selective activation checkpointing to the BERT model.
By setting activation_checkpoint_lvl in the config, you can choose how many of the BERT layers are checkpointed once gradient_checkpointing_enable() is called. Nothing is checkpointed before that call. Lowering this number saves computation, since fewer layers have their forward pass recomputed during the backward pass, at the cost of higher VRAM usage.
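A minimal sketch of the selection rule, written in plain Python rather than taken from the actual modeling code. It assumes that the first activation_checkpoint_lvl layers are the ones checkpointed (the PR does not say whether counting starts from the first or the last layer) and that values above the layer count are clamped. `checkpointed_layers` and its arguments are hypothetical names.

```python
from typing import Set


def checkpointed_layers(num_layers: int, activation_checkpoint_lvl: int,
                        gradient_checkpointing: bool) -> Set[int]:
    """Return the indices of encoder layers that would be checkpointed.

    Hypothetical helper; assumes the first `activation_checkpoint_lvl`
    layers are the checkpointed ones.
    """
    if not gradient_checkpointing:
        # gradient_checkpointing_enable() has not been called: no layer is checkpointed.
        return set()
    # Clamp to [0, num_layers]; a large value such as 100 selects every layer.
    n = max(0, min(activation_checkpoint_lvl, num_layers))
    return set(range(n))


# A 12-layer encoder (the size of BERT-base).
print(len(checkpointed_layers(12, 100, True)))   # 12
print(len(checkpointed_layers(12, 6, True)))     # 6
print(len(checkpointed_layers(12, 100, False)))  # 0
```

Each layer in the returned set keeps only its input during the forward pass and reruns its forward pass during backward, which is where the extra computation comes from; layers outside the set keep their intermediate activations in memory instead.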

By default, the value is 100, which means that all layers will be checkpointed for any reasonably sized architecture. For the base model, it might make sense to set this to something like 6 to checkpoint half of its 12 layers.
We enforce that MLP checkpointing cannot occur within a checkpointed layer.
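The no-nesting rule can be sketched as a per-layer plan. `checkpoint_plan` and the `checkpoint_mlp` flag are hypothetical stand-ins for the model's real options. The sketch assumes gradient checkpointing is already enabled, reuses the "first N layers" assumption, and further assumes that layers outside the checkpointed range may still use MLP checkpointing, which the PR does not state.

```python
from typing import List


def checkpoint_plan(num_layers: int, activation_checkpoint_lvl: int,
                    checkpoint_mlp: bool) -> List[str]:
    """Per-layer decision: "layer", "mlp", or "none".

    Hypothetical helper; `checkpoint_mlp` stands in for whatever option
    turns on MLP-only checkpointing.
    """
    n = max(0, min(activation_checkpoint_lvl, num_layers))
    plan = []
    for i in range(num_layers):
        if i < n:
            # The whole layer is checkpointed, so MLP checkpointing is
            # skipped here to avoid nesting one checkpoint inside another.
            plan.append("layer")
        elif checkpoint_mlp:
            # Not a checkpointed layer: MLP-only checkpointing may apply.
            plan.append("mlp")
        else:
            plan.append("none")
    return plan


print(checkpoint_plan(4, 2, True))  # ['layer', 'layer', 'mlp', 'mlp']
```

Nesting would make the backward pass run the MLP forward an extra time inside the layer's own recomputation while saving little additional memory, which is presumably why the two options are kept mutually exclusive.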

For pretraining, I think it would make sense to set this parameter to 0. Nothing should be checkpointed before gradient_checkpointing_enable() is called anyway, but better safe than sorry.
