# File: mini-omni-s2s/slam_llm/utils/fsdp_utils.py
# Hosting-page header: xcczach -- "Upload 73 files" -- commit 35c1cfd (verified)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
def fsdp_auto_wrap_policy(model, transformer_layer_name):
    """Return an FSDP auto-wrap policy combining two wrapping rules.

    The result is ``_or_policy`` pre-bound to a list of two sub-policies:

    * a ``lambda_auto_wrap_policy`` whose predicate accepts a module that has
      no child modules and has a non-``None`` ``weight`` attribute whose
      ``requires_grad`` is truthy;
    * a ``transformer_auto_wrap_policy`` configured with ``PrefixEncoder``,
      ``PromptEncoder``, ``PromptEmbedding`` and ``transformer_layer_name``.

    Args:
        model: Not read anywhere in this function; it is only part of the
            signature.
        transformer_layer_name: Forwarded unchanged as the last entry of
            ``transformer_layer_cls``.
            NOTE(review): despite the name, this presumably needs to be the
            layer class itself rather than a string -- confirm with callers.

    Returns:
        A ``functools.partial`` object wrapping ``_or_policy``.
    """
    from functools import partial

    from torch.distributed.fsdp.wrap import (
        _or_policy,
        lambda_auto_wrap_policy,
        transformer_auto_wrap_policy,
    )

    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

    def _is_trainable_leaf(module):
        # Anything that owns at least one child module is not a leaf.
        if next(iter(module.named_children()), None) is not None:
            return False
        weight = getattr(module, "weight", None)
        if weight is None:
            return False
        return bool(weight.requires_grad)

    # Module types handed to the transformer policy, in the original order.
    wrapped_layer_classes = (
        PrefixEncoder,
        PromptEncoder,
        PromptEmbedding,
        transformer_layer_name,
    )

    sub_policies = [
        partial(lambda_auto_wrap_policy, lambda_fn=_is_trainable_leaf),
        partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=wrapped_layer_classes,
        ),
    ]
    return partial(_or_policy, policies=sub_policies)