import functools
from functools import partial

import torch
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from .load import get_no_split_modules

# Non-reentrant checkpointing is the implementation recommended for use with FSDP.
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)


def apply_fsdp_checkpointing(model, no_split_modules, p=1):
    """Apply selective activation checkpointing to ``model`` in place.

    Roughly a fraction ``p`` of the transformer blocks (modules whose class is
    listed in ``no_split_modules``) are wrapped with the non-reentrant
    checkpoint wrapper. Returns None; the model is updated directly.
    """
    print("--> applying fsdp activation checkpointing...")
    block_idx = 0
    cut_off = 1 / 2
    # `p` may be passed as a string expression such as "1/3"; evaluate it to a number.
    p = eval(p) if isinstance(p, str) else p

    def selective_checkpointing(submodule):
        nonlocal block_idx
        nonlocal cut_off

        if isinstance(submodule, no_split_modules):
            block_idx += 1
            if block_idx * p >= cut_off:
                cut_off += 1
                return True
        return False

    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=selective_checkpointing,
    )
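
# Usage sketch (illustrative; `fsdp_model` and `no_split_modules` are assumed to
# come from the caller's FSDP setup, e.g. get_dit_fsdp_kwargs below):
#
#     apply_fsdp_checkpointing(fsdp_model, no_split_modules, p="1/3")
#
# p=1 checkpoints every matching block; p=1/3 checkpoints roughly every third one.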


def get_mixed_precision(master_weight_type="fp32"):
    weight_type = torch.float32 if master_weight_type == "fp32" else torch.bfloat16
    mixed_precision = MixedPrecision(
        param_dtype=weight_type,
        reduce_dtype=weight_type,
        buffer_dtype=weight_type,
        cast_forward_inputs=False,
    )
    return mixed_precision
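
# Example (illustrative): any value other than "fp32" is treated as bf16 above, so
#
#     get_mixed_precision("bf16").param_dtype is torch.bfloat16  # -> True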


def get_dit_fsdp_kwargs(
    transformer,
    sharding_strategy,
    use_lora=False,
    cpu_offload=False,
    master_weight_type="fp32",
):
    no_split_modules = get_no_split_modules(transformer)
    if use_lora:
        # PEFT's auto-wrap policy keeps trainable LoRA layers in their own FSDP
        # units, separate from the frozen base weights.
        auto_wrap_policy = fsdp_auto_wrap_policy
    else:
        # Wrap each transformer block (the model's no-split modules) as its own FSDP unit.
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=no_split_modules,
        )

    mixed_precision = get_mixed_precision(master_weight_type)

    # Map the string option to the corresponding FSDP sharding strategy.
    if sharding_strategy == "full":
        sharding_strategy = ShardingStrategy.FULL_SHARD
    elif sharding_strategy == "hybrid_full":
        sharding_strategy = ShardingStrategy.HYBRID_SHARD
    elif sharding_strategy == "none":
        sharding_strategy = ShardingStrategy.NO_SHARD
        auto_wrap_policy = None
    elif sharding_strategy == "hybrid_zero2":
        sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
    elif sharding_strategy == "shard_grad_op":
        sharding_strategy = ShardingStrategy.SHARD_GRAD_OP

    device_id = torch.cuda.current_device()
    cpu_offload = (
        torch.distributed.fsdp.CPUOffload(offload_params=True) if cpu_offload else None
    )
    fsdp_kwargs = {
        "auto_wrap_policy": auto_wrap_policy,
        "mixed_precision": mixed_precision,
        "sharding_strategy": sharding_strategy,
        "device_id": device_id,
        "limit_all_gathers": True,
        "cpu_offload": cpu_offload,
    }

    # LoRA-specific options: use flattened params (and sync module states) when
    # transformer blocks are auto-wrapped; otherwise keep the original parameters.
    if len(no_split_modules) != 0 and use_lora:
        fsdp_kwargs.update(
            {
                "use_orig_params": False,
                "sync_module_states": True,
            }
        )
    elif len(no_split_modules) == 0 and use_lora:
        fsdp_kwargs.update({"use_orig_params": True})

    return fsdp_kwargs, no_split_modules
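
# Usage sketch (illustrative only; `transformer` is assumed to be the loaded DiT):
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#     fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(transformer, "full")
#     transformer = FSDP(transformer, **fsdp_kwargs)
#     apply_fsdp_checkpointing(transformer, no_split_modules, p=1)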


def get_discriminator_fsdp_kwargs(master_weight_type="fp32"):
    auto_wrap_policy = None
    mixed_precision = get_mixed_precision(master_weight_type)
    sharding_strategy = ShardingStrategy.NO_SHARD
    device_id = torch.cuda.current_device()
    fsdp_kwargs = {
        "auto_wrap_policy": auto_wrap_policy,
        "mixed_precision": mixed_precision,
        "sharding_strategy": sharding_strategy,
        "device_id": device_id,
        "limit_all_gathers": True,
    }

    return fsdp_kwargs


def get_vae_fsdp_kwargs(master_weight_type="fp32", cpu_offload=False):
    auto_wrap_policy = None
    mixed_precision = get_mixed_precision(master_weight_type)
    sharding_strategy = ShardingStrategy.FULL_SHARD
    device_id = torch.cuda.current_device()
    cpu_offload = (
        torch.distributed.fsdp.CPUOffload(offload_params=True) if cpu_offload else None
    )

    fsdp_kwargs = {
        "auto_wrap_policy": auto_wrap_policy,
        "mixed_precision": mixed_precision,
        "sharding_strategy": sharding_strategy,
        "device_id": device_id,
        "limit_all_gathers": True,
        "cpu_offload": cpu_offload,
        "use_orig_params": True,
    }

    return fsdp_kwargs
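
# Usage sketch (illustrative; `vae` is assumed to be a loaded autoencoder, with FSDP
# imported as in the sketch above):
#
#     vae_kwargs = get_vae_fsdp_kwargs(master_weight_type="bf16", cpu_offload=True)
#     vae = FSDP(vae, **vae_kwargs)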