# mpt-30b-instruct / hf_checkpointer.py
# irenedea's picture
# LLM-foundry update March 26, 2024 23:50:31
# ef9200f verified
# raw
# history blame
# 14.8 kB
import contextlib
import copy
import logging
import math
import os
import re
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Union
import torch
from mlflow.transformers import _fetch_model_card, _write_license_information
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from .mpt import MPTConfig, MPTForCausalLM
from .utils import init_empty_weights
from .huggingface_hub_utils import edit_files_for_hf_compatibility
log = logging.getLogger(__name__)

# Searched (not anchored at the start) against each file name: matches
# "license" followed by a lowercase/uppercase alphabetic extension, or
# "license" at the very end of the name. Case-insensitive.
_LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE)


def _find_license_filename(local_dir: str) -> Optional[str]:
    """Return the first license-like file name in ``local_dir``, or ``None``.

    Names are examined in sorted order so the result does not depend on the
    arbitrary ordering of ``os.listdir``.
    """
    return next(
        (name for name in sorted(os.listdir(local_dir))
         if _LICENSE_FILE_PATTERN.search(name)),
        None,
    )


def _maybe_get_license_filename(
        local_dir: str,
        pretrained_model_name: Optional[str] = None) -> Optional[str]:
    """Returns the name of the license file if it exists in the local_dir.

    Note: This is intended to be consistent with the code in MLflow.
    https://github.com/mlflow/mlflow/blob/5d13d6ec620a02de9a5e31201bf1becdb9722ea5/mlflow/transformers/__init__.py#L1152

    Since LLM Foundry supports local model files being used rather than
    fetching the files from the Hugging Face Hub, MLflow's logic to fetch and
    write the license information on model save is not applicable; it will try
    to search for a Hugging Face repo named after the local path. However, the
    user can provide the original pretrained model name, in which case this
    function will use that to fetch the correct license information.

    Args:
        local_dir (str): Directory to search for a license file.
        pretrained_model_name (Optional[str]): If given *and* a license file
            already exists in ``local_dir``, that file is removed and the
            license information is re-written via MLflow's
            ``_fetch_model_card`` / ``_write_license_information`` helpers.

    Returns:
        Optional[str]: The license file name (relative to ``local_dir``), or
        ``None`` if no matching file exists (either initially, or after the
        license information was re-written).
    """
    license_filename = _find_license_filename(local_dir)

    # Nothing to report and nothing to overwrite.
    if license_filename is None:
        return None

    if pretrained_model_name is None:
        return license_filename

    # A pretrained model name was provided: replace the existing license file
    # with the license info for that model.
    log.info(
        f'Overwriting license file {license_filename} with license info for model {pretrained_model_name} from Hugging Face Hub'
    )
    os.remove(os.path.join(local_dir, license_filename))
    model_card = _fetch_model_card(pretrained_model_name)

    local_dir_path = Path(local_dir).absolute()
    _write_license_information(pretrained_model_name, model_card,
                               local_dir_path)

    # The helper decides the new file name, so look it up again.
    return _find_license_filename(local_dir)
class HuggingFaceCheckpointer(Callback):
    """Save a huggingface formatted checkpoint during training.

    Args:
        save_folder (str): Top level folder to save checkpoints to (can be a
            URI). It is likely that this would be the same as your save_folder.
        save_interval: Union[str, int, Time]: The interval describing how often
            checkpoints should be saved. If an integer, it will be assumed to be
            in :attr:`.TimeUnit.EPOCH`. Otherwise, the unit must be either
            :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
            :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
        huggingface_folder_name (str): Folder to save each checkpoint under (can
            be a format string). Default is ``ba{batch}``.
        precision: The precision to save the model in. Default is ``float32``.
            Options are ``bfloat16``, ``float16``, or ``float32``. Any other
            value raises ``KeyError``.
        overwrite (bool): Whether to overwrite previous checkpoints.
        mlflow_registered_model_name (Optional[str]): The name to register the
            model under in the MLflow model registry. If ``None``, the model
            will not be registered. Default is ``None``.
        mlflow_logging_config (Optional[dict]): A dictionary of config arguments
            that will get passed along to the MLflow ``save_model`` call. The
            passed dictionary is shallow-copied, not modified. When
            ``mlflow_registered_model_name`` is set, missing entries are filled
            in: ``metadata`` defaults to ``{}``, ``task`` defaults to
            ``'llm/v1/completions'``, and ``input_example`` defaults to a text
            prompt example (or to a chat ``messages`` example, together with
            ``example_no_conversion=True``, when ``task`` or
            ``metadata['task']`` ends with ``'chat'``).
        flatten_imports (Sequence[str]): A sequence of import prefixes that will
            be flattened when editing MPT files.
    """

    # NOTE(review): ``Callback``, ``Time``, ``TimeUnit``, ``Event``, ``State``,
    # ``Logger``, ``MLFlowLogger``, ``HuggingFaceModel``, ``parse_uri``,
    # ``dist``, ``format_name_with_dist_and_time``,
    # ``fsdp_state_dict_type_context``, ``create_interval_scheduler`` and
    # ``maybe_create_remote_uploader_downloader_from_uri`` are used below but
    # are not among this file's visible imports -- confirm they are provided.

    def __init__(
        self,
        save_folder: str,
        save_interval: Union[str, int, Time],
        huggingface_folder_name: str = 'ba{batch}',
        precision: str = 'float32',
        overwrite: bool = True,
        mlflow_registered_model_name: Optional[str] = None,
        mlflow_logging_config: Optional[dict] = None,
        flatten_imports: Sequence[str] = ('llmfoundry',),
    ):
        _, _, self.save_dir_format_str = parse_uri(save_folder)
        self.overwrite = overwrite
        self.precision = precision
        self.dtype = {
            'float32': torch.float32,
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
        }[precision]
        self.flatten_imports = flatten_imports

        # MLflow config setup. Work on a shallow copy so that the defaults
        # filled in below never leak into the caller's dictionary.
        self.mlflow_registered_model_name = mlflow_registered_model_name
        if mlflow_logging_config is None:
            mlflow_logging_config = {}
        else:
            mlflow_logging_config = copy.copy(mlflow_logging_config)
        if self.mlflow_registered_model_name is not None:
            import numpy as np

            # Make sure both ``metadata`` and ``task`` are always present.
            passed_metadata = mlflow_logging_config.get('metadata', {})
            mlflow_logging_config['metadata'] = passed_metadata
            mlflow_logging_config.setdefault('task', 'llm/v1/completions')

            default_input_example = {
                'prompt': np.array(['What is Machine Learning?'])
            }
            # A chat task may be declared either at the top level or inside
            # the metadata.
            is_chat = mlflow_logging_config['task'].endswith(
                'chat') or mlflow_logging_config['metadata'].get(
                    'task', '').endswith('chat')
            if is_chat:
                default_input_example = {
                    'messages':
                        np.array([{
                            'role': 'user',
                            'content': 'What is Machine Learning?'
                        }])
                }
                mlflow_logging_config.setdefault('example_no_conversion', True)
            mlflow_logging_config.setdefault('input_example',
                                             default_input_example)

        self.mlflow_logging_config = mlflow_logging_config

        # All checkpoints live under a ``huggingface`` sub-folder.
        self.huggingface_folder_name_fstr = os.path.join(
            'huggingface', huggingface_folder_name)

        self.save_interval: Time = Time.from_input(save_interval,
                                                   TimeUnit.EPOCH)
        self.check_interval = create_interval_scheduler(
            self.save_interval, include_end_of_training=True)
        # ``None`` unless ``save_folder`` resolves to a remote destination.
        self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
            save_folder, loggers=[])
        if self.remote_ud is not None:
            self.remote_ud._num_concurrent_uploads = 4

        # Batch timestamp of the most recent save; used to avoid saving twice
        # for the same batch.
        self.last_checkpoint_batch: Optional[Time] = None
        # Populated on ``Event.INIT`` when a registered model name is set.
        self.mlflow_loggers = []

    def run_event(self, event: Event, state: State, logger: Logger) -> None:
        """Save a checkpoint when the interval fires; set up on ``Event.INIT``.

        Raises:
            ValueError: On ``Event.INIT`` if ``state.model`` is not a
                ``HuggingFaceModel``, or if ``mlflow_registered_model_name``
                is set but ``logger.destinations`` has no ``MLFlowLogger``.
        """
        should_save = (state.get_elapsed_duration() is not None and
                       self.check_interval(state, event) and
                       self.last_checkpoint_batch != state.timestamp.batch)
        if should_save:
            self._save_checkpoint(state, logger)
        elif event == Event.INIT:
            if not isinstance(state.model, HuggingFaceModel):
                raise ValueError(
                    f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. '
                    + f'Got {type(state.model)} instead.')
            if self.remote_ud is not None:
                self.remote_ud.init(state, logger)
                state.callbacks.append(self.remote_ud)

            if self.mlflow_registered_model_name is not None:
                self.mlflow_loggers = [
                    logger_destination
                    for logger_destination in logger.destinations
                    if isinstance(logger_destination, MLFlowLogger)
                ]
                if not self.mlflow_loggers:
                    raise ValueError(
                        f'`mlflow_registered_model_name` was set, but no `MLFlowLogger` was found in the `logger.destinations` list. '
                        +
                        'Please add an `MLFlowLogger` or set `mlflow_registered_model_name` to `None`.'
                    )

                import mlflow
                mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
                    '5GB')

    def _is_last_batch(self, state: State) -> bool:
        """Return whether ``state`` is at the final batch of training.

        True when the elapsed duration is at least ``1.0``. Additionally
        handles a save interval of exactly one ``TimeUnit.DURATION`` combined
        with a max duration expressed in epochs, by checking whether the
        current batch count is a multiple of the total number of batches.
        """
        elapsed_duration = state.get_elapsed_duration()
        if elapsed_duration is not None and elapsed_duration >= 1.0:
            return True

        assert state.max_duration is not None
        interval_is_full_duration = (
            self.save_interval.unit == TimeUnit.DURATION and
            self.save_interval.value == 1)
        if interval_is_full_duration and state.max_duration.unit == TimeUnit.EPOCH:
            assert state.dataloader_len is not None
            total_batches = math.ceil(state.max_duration.value *
                                      state.dataloader_len)
            return int(state.timestamp.batch) % total_batches == 0

        return False

    def _save_checkpoint(self, state: State, logger: Logger):
        """Write a Hugging Face checkpoint and optionally register it in MLflow.

        Every rank gathers the state dict; only global rank 0 builds a fresh
        model instance, saves it (plus the tokenizer, if any), uploads the
        files when a remote uploader is configured, and -- on the last batch,
        when ``mlflow_registered_model_name`` is set -- saves and registers
        the model with each ``MLFlowLogger``.
        """
        del logger  # unused

        self.last_checkpoint_batch = state.timestamp.batch

        log.info('Saving HuggingFace formatted checkpoint')

        # Register the MPT config/model classes with the transformers
        # auto-class machinery before saving.
        from transformers.models.auto.configuration_auto import CONFIG_MAPPING
        CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
        MPTConfig.register_for_auto_class()
        MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

        save_dir = format_name_with_dist_and_time(
            str(
                Path(self.save_dir_format_str) /
                self.huggingface_folder_name_fstr), state.run_name,
            state.timestamp)

        # With a remote destination, stage files in a temporary directory and
        # upload from there; otherwise write straight into ``save_dir``.
        dir_context_mgr = tempfile.TemporaryDirectory(
        ) if self.remote_ud is not None else contextlib.nullcontext(
            enter_result=save_dir)

        with dir_context_mgr as temp_save_dir:
            assert isinstance(temp_save_dir, str)

            log.debug('Gathering state dict')
            from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

            # Unwrap the model depending on how it is wrapped (DDP / FSDP /
            # not wrapped).
            if state.is_model_ddp:
                composer_model = state.model.module
                original_model: PreTrainedModel = state.model.module.model
                state_dict_model = state.model.module.model
                original_tokenizer = state.model.module.tokenizer
            elif isinstance(state.model.model, FSDP):
                composer_model = state.model
                original_model: PreTrainedModel = state.model.model.module
                state_dict_model = state.model.model
                original_tokenizer = state.model.tokenizer
            else:
                composer_model = state.model
                original_model: PreTrainedModel = state.model.model
                state_dict_model = state.model.model
                original_tokenizer = state.model.tokenizer

            # For FSDP, request the full state dict type while reading it.
            state_dict_context = fsdp_state_dict_type_context(
                original_model, state_dict_type='full') if (
                    (not state.is_model_ddp) and
                    isinstance(state_dict_model,
                               FSDP)) else contextlib.nullcontext()

            with state_dict_context:
                state_dict = state_dict_model.state_dict()

                # Convert tensors to the requested precision. Only existing
                # keys are reassigned, so iterating while assigning is safe.
                for k, v in state_dict.items():
                    if isinstance(v, torch.Tensor):
                        state_dict[k] = v.to(dtype=self.dtype)

            if dist.get_global_rank() == 0:
                log.debug('Saving Hugging Face checkpoint in global rank 0')

                # Edit a copy so the live training config is left untouched.
                copied_config = copy.deepcopy(original_model.config)
                if copied_config.model_type == 'mpt':
                    # NOTE(review): presumably so the exported config does not
                    # depend on the training-time attention implementation or
                    # device -- confirm.
                    copied_config.attn_config['attn_impl'] = 'torch'
                    copied_config.init_device = 'cpu'

                log.debug('Creating new model instance')

                if composer_model.using_peft:
                    # Rebuild the base model, then wrap it with the config of
                    # the currently active adapter.
                    active_adapter = original_model.active_adapter
                    base_model = original_model.get_base_model()
                    new_base_model_instance = type(base_model)(copied_config)

                    new_model_instance = type(original_model)(
                        new_base_model_instance,
                        original_model.peft_config[active_adapter])
                    new_model_instance.to(dtype=self.dtype)
                else:
                    # Instantiate under ``init_empty_weights``; the weights
                    # are supplied by ``load_state_dict`` below.
                    with init_empty_weights():
                        new_model_instance = type(original_model)(
                            copied_config)

                # ``assign=True`` makes the module take the state dict's
                # tensors rather than copying into its existing parameters.
                new_model_instance.load_state_dict(state_dict, assign=True)
                del state_dict

                log.debug('Saving Hugging Face checkpoint to disk')
                new_model_instance.save_pretrained(temp_save_dir)
                if original_tokenizer is not None:
                    assert isinstance(original_tokenizer,
                                      PreTrainedTokenizerBase)
                    original_tokenizer.save_pretrained(temp_save_dir)

                if original_model.config.model_type == 'mpt':
                    log.debug('Editing MPT files for HuggingFace compatibility')
                    edit_files_for_hf_compatibility(temp_save_dir,
                                                    self.flatten_imports)

                if self.remote_ud is not None:
                    for filename in os.listdir(temp_save_dir):
                        remote_file_name = os.path.join(save_dir, filename)
                        remote_file_uri = self.remote_ud.remote_backend.get_uri(
                            remote_file_name)
                        log.info(
                            f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}'
                        )
                        self.remote_ud.upload_file(
                            state=state,
                            remote_file_name=remote_file_name,
                            file_path=Path(os.path.join(temp_save_dir,
                                                        filename)),
                            overwrite=self.overwrite,
                        )

                if self.mlflow_registered_model_name and self._is_last_batch(
                        state):
                    components = {'model': new_model_instance}
                    if original_tokenizer is not None:
                        components['tokenizer'] = original_tokenizer

                    log.debug('Logging Hugging Face model to MLFlow')

                    # Replace MLflow's feature-dependency lookup with a stub
                    # that returns an empty string. Loop-invariant, so it is
                    # done once up front.
                    # NOTE(review): workaround; the reason is not visible in
                    # this file.
                    import mlflow
                    mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''

                    for i, mlflow_logger in enumerate(self.mlflow_loggers):
                        log.debug(
                            f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}'
                        )
                        # One sub-directory per logger inside the save dir.
                        local_save_path = str(
                            Path(temp_save_dir) / f'mlflow_save_{i}')

                        model_saving_kwargs: Dict[str, Any] = {
                            'path': local_save_path
                        }
                        if composer_model.using_peft:
                            # PEFT models are saved from the files already
                            # written to ``temp_save_dir``; only the metadata
                            # from the logging config is forwarded.
                            model_saving_kwargs['flavor'] = 'peft'
                            model_saving_kwargs[
                                'save_pretrained_dir'] = temp_save_dir
                            model_saving_kwargs[
                                'metadata'] = self.mlflow_logging_config[
                                    'metadata']
                        else:
                            model_saving_kwargs['flavor'] = 'transformers'
                            model_saving_kwargs[
                                'transformers_model'] = components
                            model_saving_kwargs.update(
                                self.mlflow_logging_config)

                        mlflow_logger.save_model(**model_saving_kwargs)

                        # Upload the license file, if any, as a run artifact.
                        license_filename = _maybe_get_license_filename(
                            local_save_path,
                            self.mlflow_logging_config['metadata'].get(
                                'pretrained_model_name', None))
                        if license_filename is not None:
                            mlflow_logger._mlflow_client.log_artifact(
                                mlflow_logger._run_id,
                                os.path.join(local_save_path,
                                             license_filename),
                            )

                        mlflow_logger.register_model_with_run_id(
                            model_uri=local_save_path,
                            name=self.mlflow_registered_model_name,
                            await_creation_for=3600,
                        )