# mpt-7b / hf_fsdp.py
import functools
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union

# `dist` is used by `hf_get_init_device` below; assumed to be Composer's
# distributed utilities, which llm-foundry depends on.
from composer.utils import dist
from transformers import PreTrainedModel
from transformers.models.opt.modeling_opt import OPTDecoder
if TYPE_CHECKING:
from peft import PeftModel
def rhasattr(obj: Any, attr: str) -> bool:
"""A chain-able attribute version of hasattr.
For example, to check if
`obj` has the attribute `foo.bar.baz`, you can use:
`rhasattr(obj, "foo.bar.baz")`
Reference: https://stackoverflow.com/a/67303315
"""
_nested_attrs = attr.split('.')
_curr_obj = obj
for _a in _nested_attrs[:-1]:
if hasattr(_curr_obj, _a):
_curr_obj = getattr(_curr_obj, _a)
else:
return False
return hasattr(_curr_obj, _nested_attrs[-1])
def rgetattr(obj: Any, attr: str, *args: List[Any]) -> Any:
"""A chain-able attribute version of getattr.
For example, to get the attribute `foo.bar.baz` from `obj`, you can use:
`rgetattr(obj, "foo.bar.baz")`
Reference: https://stackoverflow.com/a/31174427
"""
def _getattr(obj: Any, attr: str):
return getattr(obj, attr, *args)
return functools.reduce(_getattr, [obj] + attr.split('.'))
def findattr(obj: Any, attrs: Iterable[str]) -> Optional[Any]:
    """Returns the first attribute in `attrs` found on `obj`, else None.

    Each entry may be a dotted path (e.g. `model.decoder`); attributes are
    tried in order.
    """
    for attr in attrs:
        if rhasattr(obj, attr):
            return rgetattr(obj, attr)
    return None
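

# Illustrative usage of the helpers above (hypothetical model objects, not
# part of the original module):
#   rhasattr(model, 'model.decoder.layers')            -> True for OPT-style models
#   rgetattr(model, 'model.decoder.layers')            -> the decoder layer list
#   findattr(model, ('transformer', 'model.decoder'))  -> first attribute that exists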
def hf_get_causal_base_model(model: PreTrainedModel) -> Any:
"""Returns the causal decoder backbone of the specified HuggingFace model.
Newer HF models have a `self.get_decoder()` method. Older models do not.
NOTE: Different model configurations have different causal decoder attribute
names.
- transformer: (GPT2LMHeadModel, GPTJConfig)
- model.decoder: (OPTConfig, BloomConfig)
- gpt_neox: (GPTNeoXConfig)
"""
if hasattr(model, 'get_decoder'):
return model.get_decoder()
decoder_attrs = ('transformer', 'model.decoder', 'gpt_neox', 'model.transformer')
causal_base_model = findattr(model, decoder_attrs)
if causal_base_model is None:
raise ValueError(f'Unable to FSDP-wrap model {model}. Please open a github issue to add support.')
return causal_base_model
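
# For example (illustrative): `hf_get_causal_base_model` resolves to
# `model.transformer` for GPT2LMHeadModel and to `model.model.decoder` for
# OPTForCausalLM.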
def hf_get_hidden_layers(model: PreTrainedModel) -> Any:
    """Returns the hidden layers of the specified model.

    Expects to receive the causal decoder backbone, not the XXForCausalLM wrapper.

    NOTE: Different model configurations have different hidden layer attribute names.
        - h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM)
        - decoder.layers: (OPTForCausalLM)
        - layers: (GPTNeoXForCausalLM, LlamaForCausalLM)
        - block: (T5 encoder/decoder stacks)
        - blocks: (MPTForCausalLM)
    """
hidden_layers_attrs = ('h', 'decoder.layers', 'layers', 'block', 'blocks')
layers = findattr(model, hidden_layers_attrs)
if layers is None:
raise ValueError(f'Unable to find hidden layer for {model}. Model must have one of the following attributes: {hidden_layers_attrs}')
return layers
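
# For example (illustrative): passing the GPT-2 backbone returns its `.h`
# ModuleList of GPT2Block modules; passing an MPT backbone returns `.blocks`.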
def hf_get_init_device(init_device: Optional[str]) -> Optional[str]:
"""Returns the appropriate device to initialize models."""
if init_device == 'mixed':
if dist.get_local_rank() == 0:
return 'cpu'
return 'meta'
return init_device
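

# Minimal sketch of `maybe_get_underlying_model` (referenced in
# `prepare_hf_causal_lm_model_for_fsdp` below), assuming PEFT's
# `PeftModel.get_base_model()` interface; the upstream implementation may differ.
def maybe_get_underlying_model(
    model: Union[PreTrainedModel, 'PeftModel'],
) -> Union[PreTrainedModel, 'PeftModel']:
    """Returns the wrapped HuggingFace model if `model` is a PEFT wrapper."""
    # Duck typing, mirroring the `peft_type` check used later in this file,
    # so that `peft` is not imported at runtime.
    if hasattr(model, 'peft_type') and hasattr(model, 'get_base_model'):
        return model.get_base_model()
    return model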
def prepare_hf_model_for_fsdp(model: PreTrainedModel, init_device: Optional[str]) -> None:
    """FSDP wrap a HuggingFace model.

    Dispatches to the encoder/decoder or causal-LM preparation function based
    on the model config.
    """
if model.config.is_encoder_decoder:
prepare_hf_enc_dec_model_for_fsdp(model, init_device)
else:
prepare_hf_causal_lm_model_for_fsdp(model, init_device)
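

# Illustrative usage (assumes a Composer-style trainer that honors the
# `_fsdp_wrap` hints and the `fsdp_wrap_fn`/`activation_checkpointing_fn`
# attributes set below):
#   model = AutoModelForCausalLM.from_pretrained('mosaicml/mpt-7b', trust_remote_code=True)
#   prepare_hf_model_for_fsdp(model, init_device='mixed')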
def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, 'PeftModel'], init_device: Optional[str]) -> None:
"""FSDP wrap a HuggingFace decoder.
Wrap any model for FSDP which follows one of the 3 existing conventions from
HuggingFace for decoder-only LLMs.
"""
causal_base_model = hf_get_causal_base_model(model)
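
    # OPT and OLMo keep an extra `model` module between the CausalLM wrapper
    # and the decoder; mark it so FSDP does not wrap that intermediate module
    # separately from its blocks.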
if isinstance(causal_base_model, OPTDecoder) or model.config.model_type == 'olmo':
underlying_model = maybe_get_underlying_model(model)
underlying_model.model._fsdp_wrap = False
model_block = hf_get_hidden_layers(causal_base_model)
lm_head = model.get_output_embeddings()
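
    # Prefer the input embeddings from the decoder backbone; fall back to the
    # XXForCausalLM wrapper if the backbone does not expose them.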
    try:
        tied_embeddings = causal_base_model.get_input_embeddings()
    except Exception:
        tied_embeddings = model.get_input_embeddings()
    modules = {
        'base_model': causal_base_model,
        'model_block': model_block,
        'lm_head': lm_head,
        'tied_embeddings': tied_embeddings,
    }
for mod_name, module in modules.items():
if module is None:
raise ValueError(f'Unable to FSDP-wrap this model! `{mod_name}` does not ' + 'follow common layer/weight naming conventions.')
block_type = type(model_block[0])
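
    # With `tie_word_embeddings`, the lm_head weights are tied to the input
    # embeddings, so the backbone, embeddings, and lm_head must stay together
    # in the root FSDP module; the `_fsdp_wrap = False` hint keeps the FSDP
    # auto-wrap policy from wrapping them separately.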
if model.config.tie_word_embeddings:
causal_base_model._fsdp_wrap = False
tied_embeddings._fsdp_wrap = False
lm_head._fsdp_wrap = False
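
    # Wrap active PEFT (e.g. LoRA) adapter modules individually so their
    # parameters end up in their own FSDP units.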
if hasattr(model, 'peft_type') and model.peft_type is not None:
peft_type = model.peft_type.lower()
active_adapters = [adapter.lower() for adapter in model.active_adapters]
for name, module in model.named_modules():
if peft_type in name.lower() and any((adapter in name.lower() for adapter in active_adapters)):
has_parameters = next(module.parameters(), None) is not None
has_buffers = next(module.buffers(), None) is not None
if has_parameters or has_buffers:
module._fsdp_wrap = True
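
    # FSDP-wrap and activation-checkpoint every transformer block.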
model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
model.activation_checkpointing_fn = lambda module: isinstance(module, block_type)
def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel, init_device: Optional[str]) -> None:
"""Wrap an encoder/decoder HF model.
This works for T5, BART, Pegasus, PegasusX, but not all enc/dec (ProphetNet)
You have model.shared, model.encoder, model.decoder and model.lm_head, where
model.shared are the embeddings which are tied to model.lm_head, and
model.shared == model.encoder.embed_tokens and model.shared ==
model.decoder.embed_tokens
"""
tied_embeddings = model.get_input_embeddings()
encoder = model.get_encoder()
decoder = model.get_decoder()
lm_head = model.get_output_embeddings()
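    # Some encoder/decoder models use different block classes for the encoder
    # and decoder stacks, so collect both.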
encoder_block = hf_get_hidden_layers(encoder)
decoder_block = hf_get_hidden_layers(decoder)
    modules = {
        'encoder': encoder,
        'decoder': decoder,
        'encoder_block': encoder_block,
        'decoder_block': decoder_block,
        'lm_head': lm_head,
        'tied_embeddings': tied_embeddings,
    }
for mod_name, module in modules.items():
if module is None:
raise ValueError(f'Unable to FSDP-wrap this model! `{mod_name}` does not ' + 'follow common layer/weight naming conventions.')
decoder_block_type = type(decoder_block[0])
encoder_block_type = type(encoder_block[0])
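
    # With tied word embeddings, keep the shared embeddings, encoder, decoder,
    # and lm_head together in the root FSDP module.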
if model.config.tie_word_embeddings:
tied_embeddings._fsdp_wrap = False
encoder._fsdp_wrap = False
decoder._fsdp_wrap = False
lm_head._fsdp_wrap = False
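
    # FSDP-wrap and activation-checkpoint every decoder block.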
model.fsdp_wrap_fn = lambda module: isinstance(module, decoder_block_type)
model.activation_checkpointing_fn = lambda module: isinstance(module, decoder_block_type)
if encoder_block_type == decoder_block_type:
return
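
    # If the encoder uses a different block class than the decoder, the wrap
    # and checkpoint predicates must match both block types.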
    model.fsdp_wrap_fn = lambda module: isinstance(
        module, (encoder_block_type, decoder_block_type))
    model.activation_checkpointing_fn = lambda module: isinstance(
        module, (encoder_block_type, decoder_block_type))