"""Re-usable :class:`.ComposerModel` for LLM HF Models.""" from __future__ import annotations from collections import UserDict from typing import TYPE_CHECKING, List, Mapping, Optional import transformers from torchmetrics import Metric from transformers import PreTrainedTokenizerBase from transformers.utils.generic import ModelOutput from .hf_fsdp import prepare_hf_model_for_fsdp if TYPE_CHECKING: from peft import PeftConfig _HF_IGNORE_INDEX = -100 class HuggingFaceModelWithFSDP(HuggingFaceModel): """Wrapper around HuggingFaceModel. Handles preparation for FSDP wrapping. """ def __init__(self, model: transformers.PreTrainedModel, tokenizer: Optional[PreTrainedTokenizerBase]=None, metrics: Optional[List[Metric]]=None, eval_metrics: Optional[List[Metric]]=None, shift_labels: bool=False, init_device: Optional[str]=None, peft_config: Optional['PeftConfig']=None): super().__init__(model, tokenizer, use_logits=True, metrics=metrics, eval_metrics=eval_metrics, shift_labels=shift_labels, peft_config=peft_config, should_save_peft_only=True) prepare_hf_model_for_fsdp(self.model, init_device) self.model.param_init_fn = lambda module: self.model._init_weights(module) def forward(self, batch: Mapping): if isinstance(batch, dict) or isinstance(batch, UserDict): batch = {k: v for k, v in batch.items() if k in self.model_forward_args} output = self.model(**batch) else: raise ValueError('Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model') return output def loss(self, outputs: ModelOutput, batch: Mapping): if self.config.use_return_dict: return outputs['loss'] return outputs[:2]