# mpt-7b-storywriter / model_wrapper.py
# Uploaded by irenedea: "LLM-foundry update March 26, 2024 23:50:31"
# Commit b7d8df8 (verified), file size 1.82 kB
"""Re-usable :class:`.ComposerModel` for LLM HF Models."""
from __future__ import annotations
from collections import UserDict
from typing import TYPE_CHECKING, List, Mapping, Optional
import transformers
from torchmetrics import Metric
from transformers import PreTrainedTokenizerBase
from transformers.utils.generic import ModelOutput
from .hf_fsdp import prepare_hf_model_for_fsdp
if TYPE_CHECKING:
from peft import PeftConfig
_HF_IGNORE_INDEX = -100
class HuggingFaceModelWithFSDP(HuggingFaceModel):
    """HuggingFaceModel subclass that readies the wrapped model for FSDP.

    On construction the underlying model is passed through
    ``prepare_hf_model_for_fsdp`` and given a ``param_init_fn`` attribute
    that forwards to the model's own ``_init_weights``.
    """

    def __init__(
        self,
        model: transformers.PreTrainedModel,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        metrics: Optional[List[Metric]] = None,
        eval_metrics: Optional[List[Metric]] = None,
        shift_labels: bool = False,
        init_device: Optional[str] = None,
        peft_config: Optional['PeftConfig'] = None,
    ):
        """Build the wrapper and prepare ``self.model`` for FSDP.

        Args:
            model: The Hugging Face model to wrap.
            tokenizer: Optional tokenizer forwarded to the base class.
            metrics: Optional metrics forwarded to the base class.
            eval_metrics: Optional evaluation metrics forwarded to the
                base class.
            shift_labels: Forwarded unchanged to the base class.
            init_device: Forwarded to ``prepare_hf_model_for_fsdp``.
            peft_config: Optional PEFT configuration forwarded to the
                base class.
        """
        # ``use_logits`` and ``should_save_peft_only`` are fixed to True here
        # and are not configurable by callers of this wrapper.
        super().__init__(
            model,
            tokenizer,
            use_logits=True,
            metrics=metrics,
            eval_metrics=eval_metrics,
            shift_labels=shift_labels,
            peft_config=peft_config,
            should_save_peft_only=True,
        )
        prepare_hf_model_for_fsdp(self.model, init_device)

        def _init_module_weights(module):
            # Resolve ``self.model`` at call time (not at definition time),
            # exactly as a lambda closing over ``self`` would.
            return self.model._init_weights(module)

        self.model.param_init_fn = _init_module_weights

    def forward(self, batch: Mapping):
        """Run the wrapped model on the recognised entries of ``batch``.

        Only keys present in ``self.model_forward_args`` (presumably set up
        by the base class -- not visible in this file) are passed to the
        model as keyword arguments.

        Args:
            batch: A ``dict`` or ``UserDict`` of model inputs.

        Returns:
            Whatever ``self.model`` returns for the filtered inputs.

        Raises:
            ValueError: If ``batch`` is neither a ``dict`` nor a ``UserDict``.
        """
        if not isinstance(batch, (dict, UserDict)):
            raise ValueError('Unexpected batch type. Expected a dictionary with keys corresponding to the inputs to the forward function of the Huggingface model')

        accepted_inputs = {
            name: value
            for name, value in batch.items()
            if name in self.model_forward_args
        }
        return self.model(**accepted_inputs)

    def loss(self, outputs: ModelOutput, batch: Mapping):
        """Extract the loss from the model outputs.

        Args:
            outputs: The value returned by :meth:`forward`.
            batch: The input batch (unused here).

        Returns:
            ``outputs['loss']`` when ``self.config.use_return_dict`` is
            truthy; otherwise the first two elements of ``outputs``.
        """
        # NOTE(review): the non-dict branch returns a two-element slice rather
        # than a single item -- confirm that callers expect this shape.
        return outputs['loss'] if self.config.use_return_dict else outputs[:2]