File size: 7,560 Bytes
3ff9962
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import functools
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
from transformers import PreTrainedModel
from transformers.models.opt.modeling_opt import OPTDecoder
if TYPE_CHECKING:
    from peft import PeftModel

def rhasattr(obj: Any, attr: str) -> bool:
    """Report whether a dotted attribute path resolves on ``obj``.

    This is ``hasattr`` generalised to chains: ``rhasattr(obj, "foo.bar.baz")``
    is True only if ``obj.foo``, ``obj.foo.bar`` and ``obj.foo.bar.baz`` can
    all be looked up. The walk stops and returns False at the first
    intermediate link that is missing.

    Reference: https://stackoverflow.com/a/67303315
    """
    *parents, leaf = attr.split('.')
    current = obj
    for name in parents:
        if not hasattr(current, name):
            return False
        current = getattr(current, name)
    return hasattr(current, leaf)

def rgetattr(obj: Any, attr: str, *args: Any) -> Any:
    """Resolve a dotted attribute path on ``obj`` (``getattr`` for chains).

    ``rgetattr(obj, "foo.bar.baz")`` returns ``obj.foo.bar.baz``. An optional
    default may be passed positionally, exactly as with ``getattr``. It is
    forwarded to every hop, so once a link is missing the default is returned
    for that hop and any remaining hops are looked up on the default itself.

    Reference: https://stackoverflow.com/a/31174427

    Raises:
        AttributeError: if a link is missing and no default was given.
    """
    # ``*args`` annotates each *individual* extra positional argument (the
    # optional getattr default), so the element type is ``Any``, not a list.

    def _getattr(current: Any, name: str) -> Any:
        return getattr(current, name, *args)

    return functools.reduce(_getattr, attr.split('.'), obj)

def findattr(obj: Any, attrs: Iterable[str]) -> Optional[Any]:
    """Return the value of the first dotted path in ``attrs`` present on ``obj``.

    Candidates are tried lazily in iteration order with ``rhasattr``. The first
    one that resolves is fetched with ``rgetattr`` and no further candidates
    are examined. Returns None when no candidate resolves (or ``attrs`` is
    empty).
    """
    return next(
        (rgetattr(obj, path) for path in attrs if rhasattr(obj, path)),
        None,
    )

def hf_get_causal_base_model(model: PreTrainedModel) -> Any:
    """Return the causal decoder backbone of a HuggingFace model.

    Newer HF models expose ``get_decoder()``; when it is present it is used
    directly. Older models do not, so the first of these attribute paths found
    on the model is returned instead:
        - transformer: (GPT2LMHeadModel, GPTJConfig)
        - model.decoder: (OPTConfig, BloomConfig)
        - gpt_neox: (GPTNeoXConfig)
        - model.transformer

    Raises:
        ValueError: if the model has no ``get_decoder`` and none of the
            attribute paths above exist.
    """
    if not hasattr(model, 'get_decoder'):
        backbone = findattr(model, ('transformer', 'model.decoder', 'gpt_neox', 'model.transformer'))
        if backbone is None:
            raise ValueError(f'Unable to FSDP-wrap model {model}. Please open a github issue to add support.')
        return backbone
    return model.get_decoder()

def hf_get_hidden_layers(model: PreTrainedModel) -> Any:
    """Return the container of repeated hidden layers of ``model``.

    ``model`` should be the causal decoder backbone (or an encoder / decoder
    stack of an enc-dec model), not the XXForCausalLM wrapper around it.

    The first of these attribute paths present on ``model`` is returned:
        - h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM)
        - decoder.layers: (OPTForCausalLM)
        - layers: (GPTNeoXForCausalLM, LlaMaForCausalLM)
        - block: presumably T5-style encoder/decoder stacks -- TODO confirm
        - blocks: (MPTForCausalLM)

    Raises:
        ValueError: if none of the attribute paths exist on ``model``.
    """
    hidden_layers_attrs = ('h', 'decoder.layers', 'layers', 'block', 'blocks')
    found = findattr(model, hidden_layers_attrs)
    if found is not None:
        return found
    raise ValueError(f'Unable to find hidden layer for {model}. Model must have one of the following attributes: {hidden_layers_attrs}')

def hf_get_init_device(init_device: Optional[str]) -> Optional[str]:
    """Pick the device on which this process should initialize model weights.

    Any value other than ``'mixed'`` (including None) is returned unchanged.
    For ``'mixed'``, local rank 0 gets ``'cpu'`` and every other local rank
    gets ``'meta'``.
    """
    if init_device != 'mixed':
        return init_device
    # NOTE(review): ``dist`` is not imported anywhere in this file as shown.
    # It presumably should be a distributed-utilities module exposing
    # ``get_local_rank()`` (e.g. composer.utils.dist) -- confirm, otherwise this
    # 'mixed' path raises NameError.
    return 'cpu' if dist.get_local_rank() == 0 else 'meta'

def prepare_hf_model_for_fsdp(model: PreTrainedModel, init_device: Optional[str]) -> None:
    """Prepare a HuggingFace model for FSDP wrapping, in place.

    Dispatches on ``model.config.is_encoder_decoder``: encoder/decoder models
    go to ``prepare_hf_enc_dec_model_for_fsdp``, everything else to
    ``prepare_hf_causal_lm_model_for_fsdp``. ``init_device`` is forwarded
    unchanged.
    """
    if not model.config.is_encoder_decoder:
        prepare_hf_causal_lm_model_for_fsdp(model, init_device)
        return
    prepare_hf_enc_dec_model_for_fsdp(model, init_device)

def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, 'PeftModel'], init_device: Optional[str]) -> None:
    """Prepare a HuggingFace decoder-only (causal LM) model for FSDP, in place.

    Locates the decoder backbone, its hidden-layer blocks, the output
    embeddings (``lm_head``) and the input embeddings, following the common
    HuggingFace attribute conventions, then sets FSDP hints:

    - ``model.fsdp_wrap_fn`` and ``model.activation_checkpointing_fn``:
      predicates selecting every module of the same type as the first hidden
      block.
    - ``_fsdp_wrap = False`` on the backbone, the input embeddings and
      ``lm_head`` when ``model.config.tie_word_embeddings`` is set, and on the
      inner ``.model`` of the underlying model for OPT decoders / OLMo.
    - ``_fsdp_wrap = True`` on PEFT submodules whose name contains the PEFT
      type and an active adapter name and that hold at least one parameter or
      buffer, when ``model`` carries a non-None ``peft_type``.

    Args:
        model: the causal LM, optionally wrapped by PEFT.
        init_device: accepted for signature parity with the encoder/decoder
            variant; not used here.

    Raises:
        ValueError: if a required component cannot be located, or the model
            has no hidden layers.
    """
    causal_base_model = hf_get_causal_base_model(model)

    # OPT decoders and OLMo models: explicitly exclude the inner `.model` of
    # the underlying model from FSDP wrapping.
    if isinstance(causal_base_model, OPTDecoder) or model.config.model_type == 'olmo':
        underlying_model = maybe_get_underlying_model(model)
        underlying_model.model._fsdp_wrap = False

    model_block = hf_get_hidden_layers(causal_base_model)
    lm_head = model.get_output_embeddings()

    # Prefer the backbone's own input embeddings. Fall back to the top-level
    # model when the backbone lacks the method (AttributeError) or the HF base
    # implementation declines to provide it (NotImplementedError). Narrowed
    # from a bare `except:` so interrupts and unrelated errors are not hidden.
    try:
        tied_embeddings = causal_base_model.get_input_embeddings()
    except (AttributeError, NotImplementedError):
        tied_embeddings = model.get_input_embeddings()

    modules = {'base_model': causal_base_model, 'model_block': model_block, 'lm_head': lm_head, 'tied_embeddings': tied_embeddings}
    for mod_name, module in modules.items():
        if module is None:
            raise ValueError(f'Unable to FSDP-wrap this model! `{mod_name}` does not ' + 'follow common layer/weight naming conventions.')

    # Fail with a clear message instead of an IndexError on `model_block[0]`.
    if len(model_block) == 0:
        raise ValueError(f'Unable to FSDP-wrap this model! `model_block` of {type(causal_base_model).__name__} contains no hidden layers.')
    block_type = type(model_block[0])

    if model.config.tie_word_embeddings:
        causal_base_model._fsdp_wrap = False
        tied_embeddings._fsdp_wrap = False
        lm_head._fsdp_wrap = False

    if hasattr(model, 'peft_type') and model.peft_type is not None:
        peft_type = model.peft_type.lower()
        active_adapters = [adapter.lower() for adapter in model.active_adapters]
        for name, module in model.named_modules():
            lowered_name = name.lower()
            if peft_type not in lowered_name:
                continue
            if not any(adapter in lowered_name for adapter in active_adapters):
                continue
            # Only mark adapter submodules that hold at least one parameter or
            # buffer (parameters()/buffers() recurse into children by default).
            has_parameters = next(module.parameters(), None) is not None
            has_buffers = next(module.buffers(), None) is not None
            if has_parameters or has_buffers:
                module._fsdp_wrap = True

    model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
    model.activation_checkpointing_fn = lambda module: isinstance(module, block_type)

def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel, init_device: Optional[str]) -> None:
    """Prepare a HuggingFace encoder/decoder model for FSDP, in place.

    This works for T5, BART, Pegasus, PegasusX, but not all enc/dec
    (ProphetNet). Such models have model.shared, model.encoder, model.decoder
    and model.lm_head. model.shared holds the embeddings tied to
    model.lm_head, and model.shared == model.encoder.embed_tokens ==
    model.decoder.embed_tokens.

    Sets ``model.fsdp_wrap_fn`` and ``model.activation_checkpointing_fn`` to
    predicates selecting every encoder block AND every decoder block. When
    ``model.config.tie_word_embeddings`` is set, the input embeddings, the
    encoder, the decoder and ``lm_head`` are marked ``_fsdp_wrap = False``.

    Args:
        model: the encoder/decoder model.
        init_device: accepted for signature parity with the causal variant;
            not used here.

    Raises:
        ValueError: if a required component cannot be located, or the
            encoder or decoder has no hidden layers.
    """
    tied_embeddings = model.get_input_embeddings()
    encoder = model.get_encoder()
    decoder = model.get_decoder()
    lm_head = model.get_output_embeddings()
    encoder_block = hf_get_hidden_layers(encoder)
    decoder_block = hf_get_hidden_layers(decoder)

    modules = {'encoder': encoder, 'decoder': decoder, 'encoder_block': encoder_block, 'decoder_block': decoder_block, 'lm_head': lm_head, 'tied_embeddings': tied_embeddings}
    for mod_name, module in modules.items():
        if module is None:
            raise ValueError(f'Unable to FSDP-wrap this model! `{mod_name}` does not ' + 'follow common layer/weight naming conventions.')

    # Fail with a clear message instead of an IndexError on `block[0]`.
    for block_name, block in (('encoder_block', encoder_block), ('decoder_block', decoder_block)):
        if len(block) == 0:
            raise ValueError(f'Unable to FSDP-wrap this model! `{block_name}` contains no hidden layers.')

    decoder_block_type = type(decoder_block[0])
    encoder_block_type = type(encoder_block[0])

    if model.config.tie_word_embeddings:
        tied_embeddings._fsdp_wrap = False
        encoder._fsdp_wrap = False
        decoder._fsdp_wrap = False
        lm_head._fsdp_wrap = False

    # Encoder and decoder stacks may use different block classes. Match either
    # class so that neither stack's blocks are left out of FSDP wrapping and
    # activation checkpointing. When both are the same class, this is
    # identical to matching that single class.
    block_types = (encoder_block_type, decoder_block_type)
    model.fsdp_wrap_fn = lambda module: isinstance(module, block_types)
    model.activation_checkpointing_fn = lambda module: isinstance(module, block_types)