import contextlib
import copy
import logging
import math
import os
import re
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Union

import torch
from composer.core import Callback, Event, State, Time, TimeUnit
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger, MLFlowLogger
from composer.models import HuggingFaceModel
from composer.utils import (
    dist,
    format_name_with_dist_and_time,
    maybe_create_remote_uploader_downloader_from_uri,
    parse_uri,
)
from composer.utils.misc import create_interval_scheduler
from mlflow.transformers import _fetch_model_card, _write_license_information
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from .mpt import MPTConfig, MPTForCausalLM
from .utils import init_empty_weights
from .huggingface_hub_utils import edit_files_for_hf_compatibility

log = logging.getLogger(__name__)

_LICENSE_FILE_PATTERN = re.compile(r'license(\.[a-z]+|$)', re.IGNORECASE)


def _maybe_get_license_filename(
    local_dir: str,
    pretrained_model_name: Optional[str] = None,
) -> Optional[str]:
    """Returns the name of the license file if it exists in the local_dir.

    Note: This is intended to be consistent with the code in MLflow.
    https://github.com/mlflow/mlflow/blob/5d13d6ec620a02de9a5e31201bf1becdb9722ea5/mlflow/transformers/__init__.py#L1152

    Since LLM Foundry supports local model files being used rather than
    fetching the files from the Hugging Face Hub, MLflow's logic to fetch and
    write the license information on model save is not applicable; it will try
    to search for a Hugging Face repo named after the local path. However, the
    user can provide the original pretrained model name, in which case this
    function will use that to fetch the correct license information.

    If the license file does not exist, returns None.
    """
    try:
        license_filename = next(
            file for file in os.listdir(local_dir)
            if _LICENSE_FILE_PATTERN.search(file)
        )

        # If a pretrained model name is provided, replace the license file
        # written for the local path with the license fetched from the
        # Hugging Face Hub for the original model.
        if pretrained_model_name is not None:
            log.info(
                f'Overwriting license file {license_filename} with license '
                f'info for model {pretrained_model_name} from Hugging Face Hub',
            )
            os.remove(os.path.join(local_dir, license_filename))
            model_card = _fetch_model_card(pretrained_model_name)
            local_dir_path = Path(local_dir).absolute()
            _write_license_information(
                pretrained_model_name,
                model_card,
                local_dir_path,
            )
            license_filename = next(
                file for file in os.listdir(local_dir)
                if _LICENSE_FILE_PATTERN.search(file)
            )

        return license_filename
    except StopIteration:
        return None
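
# A minimal usage sketch for the helper above (the directory and model name
# are hypothetical):
#
#     license_file = _maybe_get_license_filename(
#         local_dir='/tmp/mlflow_save_0',
#         pretrained_model_name='mosaicml/mpt-7b',
#     )
#     if license_file is not None:
#         print(f'License file to log as an artifact: {license_file}')
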
class HuggingFaceCheckpointer(Callback):
    """Save a huggingface formatted checkpoint during training.

    Args:
        save_folder (str): Top level folder to save checkpoints to (can be a
            URI). It is likely that this would be the same as your save_folder.
        save_interval (Union[str, int, Time]): The interval describing how
            often checkpoints should be saved. If an integer, it will be
            assumed to be in :attr:`.TimeUnit.EPOCH`. Otherwise, the unit must
            be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
            :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
        huggingface_folder_name (str): Folder to save each checkpoint under
            (can be a format string). Default is ``ba{batch}``.
        precision (str): The precision to save the model in. Default is
            ``float32``. Options are ``bfloat16``, ``float16``, or
            ``float32``.
        overwrite (bool): Whether to overwrite previous checkpoints.
        mlflow_registered_model_name (Optional[str]): The name to register the
            model under in the MLflow model registry. If ``None``, the model
            will not be registered. Default is ``None``.
        mlflow_logging_config (Optional[dict]): A dictionary of config
            arguments that will get passed along to the MLflow ``save_model``
            call. Expected to contain ``metadata`` and ``task`` keys. If
            either is unspecified, the defaults are ``{}`` and
            ``'llm/v1/completions'`` respectively. A default input example
            intended for text generation is also included under the key
            ``input_example``.
        flatten_imports (Sequence[str]): A sequence of import prefixes that
            will be flattened when editing MPT files.
    """

    def __init__(
        self,
        save_folder: str,
        save_interval: Union[str, int, Time],
        huggingface_folder_name: str = 'ba{batch}',
        precision: str = 'float32',
        overwrite: bool = True,
        mlflow_registered_model_name: Optional[str] = None,
        mlflow_logging_config: Optional[dict] = None,
        flatten_imports: Sequence[str] = ('llmfoundry',),
    ):
        _, _, self.save_dir_format_str = parse_uri(save_folder)
        self.overwrite = overwrite
        self.precision = precision
        self.dtype = {
            'float32': torch.float32,
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
        }[precision]
        self.flatten_imports = flatten_imports

        # MLflow logging setup
        self.mlflow_registered_model_name = mlflow_registered_model_name
        if mlflow_logging_config is None:
            mlflow_logging_config = {}
        if self.mlflow_registered_model_name is not None:
            import numpy as np

            # Both the metadata and the task are needed for MLflow and
            # Databricks optimized model serving to work.
            passed_metadata = mlflow_logging_config.get('metadata', {})
            mlflow_logging_config['metadata'] = passed_metadata
            mlflow_logging_config.setdefault('task', 'llm/v1/completions')

            # Default input example, adjusted for chat-style tasks.
            default_input_example = {
                'prompt': np.array(['What is Machine Learning?']),
            }
            is_chat = mlflow_logging_config['task'].endswith('chat') or \
                mlflow_logging_config['metadata'].get('task', '').endswith('chat')
            if is_chat:
                default_input_example = {
                    'messages': np.array([{
                        'role': 'user',
                        'content': 'What is Machine Learning?',
                    }]),
                }
            mlflow_logging_config.setdefault('example_no_conversion', True)
            mlflow_logging_config.setdefault('input_example', default_input_example)

        self.mlflow_logging_config = mlflow_logging_config

        self.huggingface_folder_name_fstr = os.path.join(
            'huggingface',
            huggingface_folder_name,
        )

        self.save_interval: Time = Time.from_input(save_interval, TimeUnit.EPOCH)
        self.check_interval = create_interval_scheduler(
            self.save_interval,
            include_end_of_training=True,
        )
        self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
            save_folder,
            loggers=[],
        )
        if self.remote_ud is not None:
            self.remote_ud._num_concurrent_uploads = 4

        self.last_checkpoint_batch: Optional[Time] = None
        self.mlflow_loggers = []
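
    # A minimal construction sketch (the values are illustrative, not
    # recommendations; ``Trainer`` is composer's):
    #
    #     checkpointer = HuggingFaceCheckpointer(
    #         save_folder='s3://my-bucket/my-run/checkpoints',
    #         save_interval='5000ba',
    #         precision='bfloat16',
    #     )
    #     trainer = Trainer(model=model, callbacks=[checkpointer], ...)
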
    def run_event(self, event: Event, state: State, logger: Logger) -> None:
        # Save on the configured interval, guarding against saving twice for
        # the same batch.
        if state.get_elapsed_duration() is not None and self.check_interval(
            state,
            event,
        ) and self.last_checkpoint_batch != state.timestamp.batch:
            self._save_checkpoint(state, logger)
        elif event == Event.INIT:
            if not isinstance(state.model, HuggingFaceModel):
                raise ValueError(
                    '`HuggingFaceCheckpointer` is only compatible with '
                    f'`HuggingFaceModel`s. Got {type(state.model)} instead.',
                )
            if self.remote_ud is not None:
                self.remote_ud.init(state, logger)
                state.callbacks.append(self.remote_ud)

            if self.mlflow_registered_model_name is not None:
                self.mlflow_loggers = [
                    logger_destination
                    for logger_destination in logger.destinations
                    if isinstance(logger_destination, MLFlowLogger)
                ]
                if len(self.mlflow_loggers) == 0:
                    raise ValueError(
                        '`mlflow_registered_model_name` was set, but no '
                        '`MLFlowLogger` was found in the `logger.destinations` '
                        'list. Please add an `MLFlowLogger` or set '
                        '`mlflow_registered_model_name` to `None`.',
                    )

                import mlflow
                mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
                    '5GB',
                )

    def _is_last_batch(self, state: State):
        elapsed_duration = state.get_elapsed_duration()
        if elapsed_duration is not None and elapsed_duration >= 1.0:
            return True

        assert state.max_duration is not None  # for type checking

        # Special case: if the save interval is the full training duration but
        # the max duration is specified in epochs, detect the last batch by
        # comparing the batch count against the total number of batches.
        if self.save_interval.unit == TimeUnit.DURATION and \
                self.save_interval.value == 1 and \
                state.max_duration.unit == TimeUnit.EPOCH:
            assert state.dataloader_len is not None
            return int(state.timestamp.batch) % math.ceil(
                state.max_duration.value * state.dataloader_len,
            ) == 0

        return False
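
    # Worked example for the special case above (numbers are illustrative):
    # with ``save_interval='1dur'``, ``max_duration='2ep'``, and a dataloader
    # of 100 batches per epoch, ``math.ceil(2 * 100) == 200``, so the check
    # reduces to ``batch % 200 == 0`` and is satisfied on the final batch of
    # training (batch 200).
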
    def _save_checkpoint(self, state: State, logger: Logger):
        del logger  # unused

        self.last_checkpoint_batch = state.timestamp.batch

        log.info('Saving HuggingFace formatted checkpoint')

        # Register MPT with the transformers Auto* classes so the saved
        # config and weights can be loaded back with AutoModelForCausalLM.
        from transformers.models.auto.configuration_auto import CONFIG_MAPPING
        CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
        MPTConfig.register_for_auto_class()
        MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

        save_dir = format_name_with_dist_and_time(
            str(Path(self.save_dir_format_str) / self.huggingface_folder_name_fstr),
            state.run_name,
            state.timestamp,
        )

        # Use a temporary directory when uploading to a remote backend;
        # otherwise write directly to the save directory.
        dir_context_mgr = tempfile.TemporaryDirectory(
        ) if self.remote_ud is not None else contextlib.nullcontext(
            enter_result=save_dir,
        )

        with dir_context_mgr as temp_save_dir:
            assert isinstance(temp_save_dir, str)  # for type checking

            log.debug('Gathering state dict')
            from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

            # Unwrap DDP/FSDP to get the underlying Hugging Face model, the
            # module whose state dict should be gathered, and the tokenizer.
            if state.is_model_ddp:
                composer_model = state.model.module
                original_model: PreTrainedModel = state.model.module.model
                state_dict_model = state.model.module.model
                original_tokenizer = state.model.module.tokenizer
            elif isinstance(state.model.model, FSDP):
                composer_model = state.model
                original_model: PreTrainedModel = state.model.model.module
                state_dict_model = state.model.model
                original_tokenizer = state.model.tokenizer
            else:
                composer_model = state.model
                original_model: PreTrainedModel = state.model.model
                state_dict_model = state.model.model
                original_tokenizer = state.model.tokenizer

            state_dict_context = fsdp_state_dict_type_context(
                original_model,
                state_dict_type='full',
            ) if (not state.is_model_ddp and
                  isinstance(state_dict_model, FSDP)) else contextlib.nullcontext()

            with state_dict_context:
                state_dict = state_dict_model.state_dict()

                # Convert the state dict to the requested precision.
                for k, v in state_dict.items():
                    if isinstance(v, torch.Tensor):
                        state_dict[k] = v.to(dtype=self.dtype)

            if dist.get_global_rank() == 0:
                log.debug('Saving Hugging Face checkpoint in global rank 0')

                copied_config = copy.deepcopy(original_model.config)
                if copied_config.model_type == 'mpt':
                    copied_config.attn_config['attn_impl'] = 'torch'
                    copied_config.init_device = 'cpu'

                log.debug('Creating new model instance')

                if composer_model.using_peft:
                    # For PEFT models, recreate the base model and rewrap it
                    # with the active adapter config.
                    active_adapter = original_model.active_adapter
                    base_model = original_model.get_base_model()
                    new_base_model_instance = type(base_model)(copied_config)

                    new_model_instance = type(original_model)(
                        new_base_model_instance,
                        original_model.peft_config[active_adapter],
                    )
                    new_model_instance.to(dtype=self.dtype)
                else:
                    # Initialize on the meta device to avoid allocating and
                    # initializing weights that are immediately overwritten,
                    # then load the gathered state dict with assign=True.
                    with init_empty_weights():
                        new_model_instance = type(original_model)(copied_config)
                    new_model_instance.load_state_dict(state_dict, assign=True)
                del state_dict

                log.debug('Saving Hugging Face checkpoint to disk')
                new_model_instance.save_pretrained(temp_save_dir)
                if original_tokenizer is not None:
                    assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
                    original_tokenizer.save_pretrained(temp_save_dir)

                # Only MPT needs file edits, because it ships custom modeling
                # code alongside the checkpoint.
                if original_model.config.model_type == 'mpt':
                    log.debug('Editing MPT files for HuggingFace compatibility')
                    edit_files_for_hf_compatibility(
                        temp_save_dir,
                        self.flatten_imports,
                    )

                if self.remote_ud is not None:
                    for filename in os.listdir(temp_save_dir):
                        remote_file_name = os.path.join(save_dir, filename)
                        remote_file_uri = self.remote_ud.remote_backend.get_uri(
                            remote_file_name,
                        )
                        log.info(
                            f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}',
                        )
                        self.remote_ud.upload_file(
                            state=state,
                            remote_file_name=remote_file_name,
                            file_path=Path(os.path.join(temp_save_dir, filename)),
                            overwrite=self.overwrite,
                        )

                if self.mlflow_registered_model_name and self._is_last_batch(state):
                    components = {'model': new_model_instance}
                    if original_tokenizer is not None:
                        components['tokenizer'] = original_tokenizer

                    log.debug('Logging Hugging Face model to MLFlow')
                    for i, mlflow_logger in enumerate(self.mlflow_loggers):
                        log.debug(
                            f'Registering model to UC at '
                            f'{mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}',
                        )
                        local_save_path = str(Path(temp_save_dir) / f'mlflow_save_{i}')

                        # Stub out MLflow's Unity Catalog feature-dependency
                        # lookup before saving.
                        import mlflow
                        mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''

                        model_saving_kwargs: Dict[str, Any] = {
                            'path': local_save_path,
                        }
                        if composer_model.using_peft:
                            model_saving_kwargs['flavor'] = 'peft'
                            model_saving_kwargs['save_pretrained_dir'] = temp_save_dir
                            model_saving_kwargs['metadata'] = \
                                self.mlflow_logging_config['metadata']
                        else:
                            model_saving_kwargs['flavor'] = 'transformers'
                            model_saving_kwargs['transformers_model'] = components
                            model_saving_kwargs.update(self.mlflow_logging_config)

                        mlflow_logger.save_model(**model_saving_kwargs)

                        # Log the license file generated by MLflow during
                        # model saving as a run artifact.
                        license_filename = _maybe_get_license_filename(
                            local_save_path,
                            self.mlflow_logging_config['metadata'].get(
                                'pretrained_model_name',
                                None,
                            ),
                        )
                        if license_filename is not None:
                            mlflow_logger._mlflow_client.log_artifact(
                                mlflow_logger._run_id,
                                os.path.join(local_save_path, license_filename),
                            )

                        mlflow_logger.register_model_with_run_id(
                            model_uri=local_save_path,
                            name=self.mlflow_registered_model_name,
                            await_creation_for=3600,
                        )
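
# A sketch of loading a checkpoint written by this callback (the path is
# hypothetical; ``trust_remote_code=True`` is needed for MPT's custom code):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     ckpt_dir = '/path/to/save_folder/huggingface/ba5000'
#     model = AutoModelForCausalLM.from_pretrained(ckpt_dir, trust_remote_code=True)
#     tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)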