"""Utilities for applying full, op-level (selective), and block-skipping activation checkpointing to diffusers transformer models."""

import collections
from enum import Enum
from typing import Optional

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

from ._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES


class CheckpointType(str, Enum):
    FULL = "full"
    OPS = "ops"
    BLOCK_SKIP = "block_skip"
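

# Ops whose outputs are saved (rather than recomputed) under selective ("ops")
# checkpointing; the policy in _apply_activation_checkpointing_ops additionally
# recomputes every other matmul.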
_SELECTIVE_ACTIVATION_CHECKPOINTING_OPS = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
}


def apply_activation_checkpointing(
    module: torch.nn.Module, checkpointing_type: str = CheckpointType.FULL, n_layer: int = 1
) -> torch.nn.Module:
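    """Wrap ``module`` for activation checkpointing.

    ``FULL`` wraps every transformer block, ``OPS`` applies selective op-level
    checkpointing to the whole module, and ``BLOCK_SKIP`` wraps every
    ``n_layer``-th transformer block.

    A minimal usage sketch (assuming the model exposes one of the block
    attributes listed in ``DIFFUSERS_TRANSFORMER_BLOCK_NAMES``):

        transformer = apply_activation_checkpointing(transformer, CheckpointType.BLOCK_SKIP, n_layer=2)
    """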
    if checkpointing_type == CheckpointType.FULL:
        module = _apply_activation_checkpointing_blocks(module)
    elif checkpointing_type == CheckpointType.OPS:
        module = _apply_activation_checkpointing_ops(module, _SELECTIVE_ACTIVATION_CHECKPOINTING_OPS)
    elif checkpointing_type == CheckpointType.BLOCK_SKIP:
        module = _apply_activation_checkpointing_blocks(module, n_layer)
    else:
        raise ValueError(
            f"Checkpointing type '{checkpointing_type}' is not supported. Supported types are {list(CheckpointType.__members__.keys())}."
        )
    return module


def _apply_activation_checkpointing_blocks(module: torch.nn.Module, n_layer: Optional[int] = None) -> torch.nn.Module:
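    """Wrap transformer blocks in ``checkpoint_wrapper``; when ``n_layer`` is given, only every ``n_layer``-th block is wrapped."""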
    for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
        blocks: torch.nn.Module = getattr(module, transformer_block_name, None)
        if blocks is None:
            continue
        for index, (layer_id, block) in enumerate(blocks.named_children()):
            if n_layer is None or index % n_layer == 0:
                block = checkpoint_wrapper(block, preserve_rng_state=False)
                blocks.register_module(layer_id, block)
    return module


def _apply_activation_checkpointing_ops(module: torch.nn.Module, ops) -> torch.nn.Module:
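    """Apply selective activation checkpointing: outputs of ``ops`` are saved (every other matmul excepted) and everything else is recomputed in the backward pass."""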
    from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

    def _get_custom_policy(meta):
        def _custom_policy(ctx, func, *args, **kwargs):
            mode = "recompute" if ctx.is_recompute else "forward"
            mm_count_key = f"{mode}_mm_count"
            if func == torch.ops.aten.mm.default:
                meta[mm_count_key] += 1
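
            # Save the outputs of the selected ops, but only every other matmul,
            # trading some recompute for a smaller activation memory footprint.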
            to_save = func in ops and not (func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0)
            return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE

        return _custom_policy

    def selective_checkpointing_context_fn():
        meta = collections.defaultdict(int)
        return create_selective_checkpoint_contexts(_get_custom_policy(meta))

    return checkpoint_wrapper(module, context_fn=selective_checkpointing_context_fn, preserve_rng_state=False)