from fractions import Fraction
from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
|
|
|
|
# `checkpoint_wrapper` preconfigured to use the non-reentrant
# checkpointing implementation; passed to
# `apply_activation_checkpointing` as the wrapper factory below.
non_reentrant_wrapper = partial(
    checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
)
|
|
|
|
def apply_checkpointing(model, block, p):
    """
    Apply selective activation checkpointing (ac) to ``model`` in place.

    Selectivity is defined as a percentage p, which means we apply ac
    on p of the total blocks. p is a floating number in the range of
    [0, 1].

    Some examples:
      p = 0: no ac for all blocks. same as `fsdp_activation_checkpointing=False`
      p = 1: apply ac on every block. i.e. "full ac".
      p = 1/2: [ac, no-ac, ac, no-ac, ...]
      p = 1/3: [no-ac, ac, no-ac, no-ac, ac, no-ac, ...]
      p = 2/3: [ac, no-ac, ac, ac, no-ac, ac, ...]
    Since blocks are homogeneous, we make ac blocks evenly spaced among
    all blocks.

    Implementation:
    For a given ac ratio p, we should essentially apply ac on every "1/p"
    blocks. The first ac block can be as early as the 0th block, or as
    late as the "1/p"th block, and we pick the middle one: (0.5/p)th block.
    Therefore, we are essentially to apply ac on:
    (0.5/p)th block, (1.5/p)th block, (2.5/p)th block, etc., and of course,
    with these values rounding to integers.
    Since ac is applied recursively, we can simply use the following math
    in the code to apply ac on corresponding blocks.

    Args:
        model: module whose submodules are (recursively) wrapped in place.
        block: class of the homogeneous transformer blocks eligible for ac.
        p: ac ratio in [0, 1]; may also be a string such as "1/3" or "0.5".

    Raises:
        ValueError: if p (after parsing) is outside [0, 1].
    """
    block_idx = 0
    cut_off = 1 / 2

    # Parse string ratios like "1/3" or "0.5" safely. The previous
    # `eval(p)` would execute arbitrary code if p came from an
    # untrusted config value.
    p = float(Fraction(p)) if isinstance(p, str) else p
    if not 0 <= p <= 1:
        raise ValueError(f"selective ac ratio p must be in [0, 1], got {p}")

    def selective_checkpointing(submodule):
        # True exactly for the evenly spaced ~p fraction of `block`
        # instances: the k-th True fires at the first block index where
        # block_idx * p crosses k - 0.5.
        nonlocal block_idx
        nonlocal cut_off

        if isinstance(submodule, block):
            block_idx += 1
            if block_idx * p >= cut_off:
                cut_off += 1
                return True
        return False

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=selective_checkpointing,
    )
|
|