File size: 14,772 Bytes
3ff9962
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import contextlib
import copy
import logging
import math
import os
import re
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Union
import torch
from mlflow.transformers import _fetch_model_card, _write_license_information
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from .mpt import MPTConfig, MPTForCausalLM
from .utils import init_empty_weights
from .huggingface_hub_utils import edit_files_for_hf_compatibility
log = logging.getLogger(name=__name__)
# Case-insensitive match for file names containing "license" followed either by
# an extension (e.g. ``LICENSE.md``, ``license.txt``) or by the end of the name
# (``LICENSE``). Used with ``.search`` to locate a model's license file.
_LICENSE_FILE_PATTERN = re.compile(pattern='license(\\.[a-z]+|$)', flags=re.IGNORECASE)

def _maybe_get_license_filename(local_dir: str, pretrained_model_name: Optional[str]=None) -> Optional[str]:
    """Returns the name of the license file if it exists in the local_dir.

    Note: This is intended to be consistent with the code in MLflow.
    https://github.com/mlflow/mlflow/blob/5d13d6ec620a02de9a5e31201bf1becdb9722ea5/mlflow/transformers/__init__.py#L1152

    Since LLM Foundry supports local model files being used rather than fetching the files from the Hugging Face Hub,
    MLflow's logic to fetch and write the license information on model save is not applicable; it will try to search for
    a Hugging Face repo named after the local path. However, the user can provide the original pretrained model name,
    in which case this function will use that to fetch the correct license information.

    Args:
        local_dir: Directory whose top-level entries are searched for a name
            matching ``_LICENSE_FILE_PATTERN``.
        pretrained_model_name: If given AND a license file was found, that file is
            deleted and replaced by license information written via MLflow's
            ``_write_license_information`` for this Hugging Face model.

    Returns:
        The bare file name (not a full path) of the license file, or ``None`` if
        no matching file exists (either initially, or after the rewrite).
    """

    def _find_license_file() -> Optional[str]:
        # First matching entry (os.listdir order), or None when nothing matches.
        return next((file for file in os.listdir(local_dir) if _LICENSE_FILE_PATTERN.search(file)), None)

    license_filename = _find_license_file()
    if license_filename is None or pretrained_model_name is None:
        return license_filename

    log.info(f'Overwriting license file {license_filename} with license info for model {pretrained_model_name} from Hugging Face Hub')
    os.remove(os.path.join(local_dir, license_filename))
    model_card = _fetch_model_card(pretrained_model_name)
    local_dir_path = Path(local_dir).absolute()
    _write_license_information(pretrained_model_name, model_card, local_dir_path)
    # Re-scan rather than reuse the old name: the rewritten file is not assumed
    # to have the same name as the one that was removed.
    return _find_license_file()

class HuggingFaceCheckpointer(Callback):
    """Save a huggingface formatted checkpoint during training.

    Args:
        save_folder (str): Top level folder to save checkpoints to (can be a
            URI). It is likely that this would be the same as your save_folder.
        save_interval: Union[str, int, Time]: The interval describing how often
            checkpoints should be saved. If an integer, it will be assumed to be
            in :attr:`.TimeUnit.EPOCH`. Otherwise, the unit must be either
            :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
            :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
        huggingface_folder_name (str): Folder to save each checkpoint under (can
            be a format string). Default is ``ba{batch}``.
        precision: The precision to save the model in. Default is ``float32``.
            Options are ``bfloat16``, ``float16``, or ``float32``.
        overwrite (bool): Whether uploads to a remote ``save_folder`` may
            overwrite existing files. Default is ``True``.
        mlflow_registered_model_name (Optional[str]): The name to register the
            model under in the MLflow model registry. If ``None``, the model
            will not be registered. Registration only happens for the checkpoint
            saved on the final batch of training, and requires an
            ``MLFlowLogger`` among the logger destinations. Default is ``None``.
        mlflow_logging_config (Optional[dict]): A dictionary of config arguments
            that will get passed along to the MLflow ``save_model`` call. When
            ``mlflow_registered_model_name`` is set, missing entries are filled
            in: ``metadata`` defaults to ``{}`` and ``task`` defaults to
            ``'llm/v1/completions'``. A default ``input_example`` is added as
            well: a chat-style ``messages`` example if ``task`` or
            ``metadata['task']`` ends with ``'chat'`` (in which case
            ``example_no_conversion`` also defaults to ``True``), otherwise a
            text-generation ``prompt`` example. The dict passed in is copied and
            is not modified.
        flatten_imports (Sequence[str]): A sequence of import prefixes that will
            be flattened when editing MPT files.
    """

    def __init__(
        self,
        save_folder: str,
        save_interval: Union[str, int, Time],
        huggingface_folder_name: str = 'ba{batch}',
        precision: str = 'float32',
        overwrite: bool = True,
        mlflow_registered_model_name: Optional[str] = None,
        mlflow_logging_config: Optional[dict] = None,
        flatten_imports: Sequence[str] = ('llmfoundry',),
    ):
        _, _, self.save_dir_format_str = parse_uri(save_folder)
        self.overwrite = overwrite
        self.precision = precision
        # An unsupported ``precision`` raises KeyError here.
        self.dtype = {
            'float32': torch.float32,
            'float16': torch.float16,
            'bfloat16': torch.bfloat16,
        }[precision]
        self.flatten_imports = flatten_imports
        self.mlflow_registered_model_name = mlflow_registered_model_name

        # Shallow-copy so the defaults filled in below are never written back
        # into the dict owned by the caller.
        mlflow_logging_config = {} if mlflow_logging_config is None else dict(mlflow_logging_config)
        if self.mlflow_registered_model_name is not None:
            import numpy as np

            # Treat an explicit ``metadata: None`` like a missing key so the
            # ``.get('task', '')`` lookup below cannot fail.
            mlflow_logging_config['metadata'] = mlflow_logging_config.get('metadata') or {}
            mlflow_logging_config.setdefault('task', 'llm/v1/completions')

            default_input_example = {'prompt': np.array(['What is Machine Learning?'])}
            is_chat = (
                mlflow_logging_config['task'].endswith('chat')
                or mlflow_logging_config['metadata'].get('task', '').endswith('chat')
            )
            if is_chat:
                default_input_example = {'messages': np.array([{'role': 'user', 'content': 'What is Machine Learning?'}])}
                mlflow_logging_config.setdefault('example_no_conversion', True)
            mlflow_logging_config.setdefault('input_example', default_input_example)
        self.mlflow_logging_config = mlflow_logging_config

        self.huggingface_folder_name_fstr = os.path.join('huggingface', huggingface_folder_name)
        self.save_interval: Time = Time.from_input(save_interval, TimeUnit.EPOCH)
        self.check_interval = create_interval_scheduler(self.save_interval, include_end_of_training=True)
        self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(save_folder, loggers=[])
        if self.remote_ud is not None:
            self.remote_ud._num_concurrent_uploads = 4

        # Batch at which the most recent checkpoint was saved; prevents saving
        # twice for the same batch.
        self.last_checkpoint_batch: Optional[Time] = None
        self.mlflow_loggers = []

    def run_event(self, event: Event, state: State, logger: Logger) -> None:
        """Save a checkpoint when the interval fires; validate and wire up on ``Event.INIT``.

        On ``Event.INIT`` this checks that the model is a ``HuggingFaceModel``,
        initializes and registers the remote uploader (if any), and, when a
        registered model name is configured, collects the ``MLFlowLogger``
        destinations (raising ``ValueError`` if there are none).
        """
        if state.get_elapsed_duration() is not None and self.check_interval(state, event) and (self.last_checkpoint_batch != state.timestamp.batch):
            self._save_checkpoint(state, logger)
        elif event == Event.INIT:
            if not isinstance(state.model, HuggingFaceModel):
                raise ValueError(f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. ' + f'Got {type(state.model)} instead.')
            if self.remote_ud is not None:
                self.remote_ud.init(state, logger)
                state.callbacks.append(self.remote_ud)

            if self.mlflow_registered_model_name is not None:
                self.mlflow_loggers = [logger_destination for logger_destination in logger.destinations if isinstance(logger_destination, MLFlowLogger)]
                if len(self.mlflow_loggers) == 0:
                    raise ValueError(f'`mlflow_registered_model_name` was set, but no `MLFlowLogger` was found in the `logger.destinations` list. ' + 'Please add an `MLFlowLogger` or set `mlflow_registered_model_name` to `None`.')

                import mlflow
                mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set('5GB')

    def _is_last_batch(self, state: State):
        """Return whether the current batch is the final batch of training.

        True once the elapsed duration reaches 1.0. As a fallback, when saving
        with an interval of exactly ``1`` duration against an epoch-based
        ``max_duration``, the batch count is checked for being a multiple of
        ``ceil(max_duration.value * dataloader_len)``.
        """
        elapsed_duration = state.get_elapsed_duration()
        if elapsed_duration is not None and elapsed_duration >= 1.0:
            return True

        assert state.max_duration is not None
        if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and (state.max_duration.unit == TimeUnit.EPOCH):
            assert state.dataloader_len is not None
            return int(state.timestamp.batch) % math.ceil(state.max_duration.value * state.dataloader_len) == 0

        return False

    def _save_checkpoint(self, state: State, logger: Logger):
        """Write a Hugging Face ``save_pretrained`` checkpoint for the current state.

        Every rank takes part in building the state dict (a ``'full'`` state dict
        when the model is FSDP-wrapped). Only global rank 0 then constructs a
        fresh model instance, saves it (plus the tokenizer, if any), uploads the
        files when a remote ``save_folder`` is configured, and — on the last
        batch with ``mlflow_registered_model_name`` set — saves, logs and
        registers the model with each MLflow logger.
        """
        del logger
        self.last_checkpoint_batch = state.timestamp.batch
        log.info('Saving HuggingFace formatted checkpoint')

        # Register the MPT classes with transformers' auto-class machinery so
        # the saved config/model reference them.
        from transformers.models.auto.configuration_auto import CONFIG_MAPPING
        CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
        MPTConfig.register_for_auto_class()
        MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

        save_dir = format_name_with_dist_and_time(str(Path(self.save_dir_format_str) / self.huggingface_folder_name_fstr), state.run_name, state.timestamp)
        # With a remote uploader, write into a throwaway local directory and
        # upload from it; otherwise write directly into ``save_dir``.
        dir_context_mgr = tempfile.TemporaryDirectory() if self.remote_ud is not None else contextlib.nullcontext(enter_result=save_dir)
        with dir_context_mgr as temp_save_dir:
            assert isinstance(temp_save_dir, str)

            log.debug('Gathering state dict')
            from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

            # Unwrap DDP / FSDP: ``original_model`` supplies the class and config
            # to rebuild, ``state_dict_model`` is the module whose state_dict()
            # is collected.
            original_model: PreTrainedModel
            if state.is_model_ddp:
                composer_model = state.model.module
                original_model = state.model.module.model
                state_dict_model = state.model.module.model
                original_tokenizer = state.model.module.tokenizer
            elif isinstance(state.model.model, FSDP):
                composer_model = state.model
                original_model = state.model.model.module
                state_dict_model = state.model.model
                original_tokenizer = state.model.tokenizer
            else:
                composer_model = state.model
                original_model = state.model.model
                state_dict_model = state.model.model
                original_tokenizer = state.model.tokenizer

            state_dict_context = fsdp_state_dict_type_context(original_model, state_dict_type='full') if not state.is_model_ddp and isinstance(state_dict_model, FSDP) else contextlib.nullcontext()
            with state_dict_context:
                state_dict = state_dict_model.state_dict()
                # Cast tensors to the requested save precision.
                for k, v in state_dict.items():
                    if isinstance(v, torch.Tensor):
                        state_dict[k] = v.to(dtype=self.dtype)

            if dist.get_global_rank() == 0:
                log.debug('Saving Hugging Face checkpoint in global rank 0')

                # Save-time config tweaks are applied to a copy so the live
                # model's config is left untouched.
                copied_config = copy.deepcopy(original_model.config)
                if copied_config.model_type == 'mpt':
                    copied_config.attn_config['attn_impl'] = 'torch'
                    copied_config.init_device = 'cpu'

                log.debug('Creating new model instance')
                if composer_model.using_peft:
                    # Rebuild the PEFT wrapper around a fresh base model using
                    # the active adapter's config.
                    active_adapter = original_model.active_adapter
                    base_model = original_model.get_base_model()
                    new_base_model_instance = type(base_model)(copied_config)
                    new_model_instance = type(original_model)(new_base_model_instance, original_model.peft_config[active_adapter])
                    new_model_instance.to(dtype=self.dtype)
                else:
                    # NOTE(review): init_empty_weights is a project helper; presumably
                    # it avoids materializing weights that load_state_dict replaces.
                    with init_empty_weights():
                        new_model_instance = type(original_model)(copied_config)

                # assign=True makes the module take the state_dict tensors
                # directly instead of copying into existing parameters.
                new_model_instance.load_state_dict(state_dict, assign=True)
                del state_dict

                log.debug('Saving Hugging Face checkpoint to disk')
                new_model_instance.save_pretrained(temp_save_dir)
                if original_tokenizer is not None:
                    assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
                    original_tokenizer.save_pretrained(temp_save_dir)

                if original_model.config.model_type == 'mpt':
                    log.debug('Editing MPT files for HuggingFace compatibility')
                    edit_files_for_hf_compatibility(temp_save_dir, self.flatten_imports)

                if self.remote_ud is not None:
                    # Upload every top-level file in the local save directory.
                    for filename in os.listdir(temp_save_dir):
                        remote_file_name = os.path.join(save_dir, filename)
                        remote_file_uri = self.remote_ud.remote_backend.get_uri(remote_file_name)
                        log.info(f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}')
                        self.remote_ud.upload_file(state=state, remote_file_name=remote_file_name, file_path=Path(os.path.join(temp_save_dir, filename)), overwrite=self.overwrite)

                if self.mlflow_registered_model_name and self._is_last_batch(state):
                    components = {'model': new_model_instance}
                    if original_tokenizer is not None:
                        components['tokenizer'] = original_tokenizer

                    log.debug('Logging Hugging Face model to MLFlow')
                    # NOTE(review): stubs out a private MLflow Unity Catalog helper
                    # (get_feature_dependencies); presumably a workaround — confirm
                    # before removing. Loop-invariant, so done once.
                    import mlflow
                    mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
                    for i, mlflow_logger in enumerate(self.mlflow_loggers):
                        log.debug(f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}')
                        local_save_path = str(Path(temp_save_dir) / f'mlflow_save_{i}')

                        model_saving_kwargs: Dict[str, Any] = {'path': local_save_path}
                        if composer_model.using_peft:
                            # PEFT models are saved from the save_pretrained output
                            # directory written above.
                            model_saving_kwargs['flavor'] = 'peft'
                            model_saving_kwargs['save_pretrained_dir'] = temp_save_dir
                            model_saving_kwargs['metadata'] = self.mlflow_logging_config['metadata']
                        else:
                            model_saving_kwargs['flavor'] = 'transformers'
                            model_saving_kwargs['transformers_model'] = components
                            model_saving_kwargs.update(self.mlflow_logging_config)

                        mlflow_logger.save_model(**model_saving_kwargs)

                        license_filename = _maybe_get_license_filename(local_save_path, self.mlflow_logging_config['metadata'].get('pretrained_model_name', None))
                        if license_filename is not None:
                            mlflow_logger._mlflow_client.log_artifact(mlflow_logger._run_id, os.path.join(local_save_path, license_filename))

                        mlflow_logger.register_model_with_run_id(model_uri=local_save_path, name=self.mlflow_registered_model_name, await_creation_for=3600)