# Snapshot metadata: commit b7d8df8, file size 1,815 bytes.
"""Re-usable :class:`.ComposerModel` for LLM HF Models."""
from __future__ import annotations
from collections import UserDict
from typing import TYPE_CHECKING, List, Mapping, Optional
import transformers
from torchmetrics import Metric
from transformers import PreTrainedTokenizerBase
from transformers.utils.generic import ModelOutput
from .hf_fsdp import prepare_hf_model_for_fsdp
if TYPE_CHECKING:
    from peft import PeftConfig
_HF_IGNORE_INDEX = -100

# NOTE(review): ``HuggingFaceModel`` is not imported anywhere in this file as
# shown; presumably ``from composer.models.huggingface import HuggingFaceModel``
# is missing from the import block — confirm.
class HuggingFaceModelWithFSDP(HuggingFaceModel):
    """Wrapper around HuggingFaceModel.

    Handles preparation for FSDP wrapping: the wrapped model is prepared in
    place by ``prepare_hf_model_for_fsdp`` and given a ``param_init_fn``
    attribute that delegates to the model's own ``_init_weights``.
    """

    def __init__(
        self,
        model: transformers.PreTrainedModel,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        metrics: Optional[List[Metric]] = None,
        eval_metrics: Optional[List[Metric]] = None,
        shift_labels: bool = False,
        init_device: Optional[str] = None,
        peft_config: Optional['PeftConfig'] = None,
    ):
        """Wrap ``model`` and prepare it for FSDP.

        Args:
            model: The Hugging Face model to wrap.
            tokenizer: Optional tokenizer, forwarded to ``HuggingFaceModel``.
            metrics: Training metrics, forwarded to ``HuggingFaceModel``.
            eval_metrics: Evaluation metrics, forwarded to ``HuggingFaceModel``.
            shift_labels: Forwarded to ``HuggingFaceModel``.
            init_device: Forwarded to ``prepare_hf_model_for_fsdp``.
            peft_config: Optional PEFT config, forwarded to
                ``HuggingFaceModel``; ``should_save_peft_only=True`` is
                always passed alongside it.
        """
        super().__init__(
            model,
            tokenizer,
            use_logits=True,
            metrics=metrics,
            eval_metrics=eval_metrics,
            shift_labels=shift_labels,
            peft_config=peft_config,
            should_save_peft_only=True,
        )
        prepare_hf_model_for_fsdp(self.model, init_device)
        # Expose the model's own weight initializer as ``param_init_fn``;
        # presumably consumed by the FSDP wrapping code (not visible here).
        # The lambda looks up ``self.model`` at call time rather than binding
        # ``_init_weights`` now.
        self.model.param_init_fn = lambda module: self.model._init_weights(module)

    def forward(self, batch: Mapping):
        """Run the wrapped model on the batch entries its forward accepts.

        Keys not present in ``self.model_forward_args`` are dropped before the
        call, so only arguments the model's forward understands are passed.

        Args:
            batch: A ``dict`` or ``UserDict`` of model inputs.

        Returns:
            Whatever the wrapped model returns.

        Raises:
            ValueError: If ``batch`` is neither a ``dict`` nor a ``UserDict``.
        """
        if not isinstance(batch, (dict, UserDict)):
            raise ValueError('Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model')
        model_inputs = {k: v for k, v in batch.items() if k in self.model_forward_args}
        return self.model(**model_inputs)

    def loss(self, outputs: ModelOutput, batch: Mapping):
        """Extract the loss from the wrapped model's ``outputs``.

        Args:
            outputs: The value returned by :meth:`forward`.
            batch: Unused; kept for the ``ComposerModel.loss`` signature.

        Returns:
            ``outputs['loss']`` when the model config uses return dicts,
            otherwise the first element of the output tuple.
        """
        if self.config.use_return_dict:
            return outputs['loss']
        # Loss is at index 0 of the output tuple and logits at index 1. Return
        # only the loss, as the parent HuggingFaceModel.loss does; returning
        # ``outputs[:2]`` would hand the logits to the trainer too, which may
        # treat each tuple element as a loss term to be summed.
        return outputs[0]