# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import gc
import itertools
import os
import re
import shutil
import tempfile
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import PreTrainedModel
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateNonBeamOutput
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import (
    _load_state_dict_into_meta_model,
    expand_device_map,
    get_disk_only_shard_files,
    is_fsdp_enabled,
    is_local_dist_rank_0,
    load_state_dict,
    set_initialized_submodules,
)
from transformers.pytorch_utils import id_tensor_storage
from transformers.utils import ModelOutput, is_accelerate_available, logging

if is_accelerate_available():
    from accelerate.utils import (
        find_tied_parameters,
        load_offloaded_weights,
        save_offload_index,
        set_module_tensor_to_device,
    )

logger = logging.get_logger(__name__)

PARAM_RENAME_WARNING = (
    "A parameter name that contains `{}` will be renamed internally to `{}`. "
    "Please use a different name to suppress this warning."
)


@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
    """
    Outputs of decoder-only generation models, when using non-beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
            Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
            tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally, if
            `config.is_encoder_decoder=True`, 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
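
# Illustrative sketch (not executed): when `generate()` is called with
# `return_dict_in_generate=True`, the `_sample` loop below returns this dataclass, and the
# per-step scores/logits can be inspected afterwards. `model` and `inputs` are placeholder names:
#
#   out = model.generate(**inputs, max_new_tokens=8, do_sample=True,
#                        return_dict_in_generate=True, output_scores=True)
#   assert isinstance(out, GenerateDecoderOnlyOutput)
#   step0_scores = out.scores[0]                                   # (batch_size, vocab_size)
#   new_tokens = out.sequences[:, inputs["input_ids"].shape[1]:]   # generated continuation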


def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
    # Convert old format to new format if needed from a PyTorch state_dict
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key and ("vision_tower.vision_tower" not in key and "dav2_model" not in key):
            logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
            new_key = key.replace("gamma", "weight")
        if "beta" in key and "vision_tower.vision_tower" not in key:
            logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
            new_key = key.replace("beta", "bias")
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    # copy state_dict so `_load_from_state_dict` can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants,
    # so we need to apply the function recursively.
    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            if is_deepspeed_zero3_enabled():
                import deepspeed

                # In sharded models, each shard has only part of the full state_dict, so only gather
                # parameters that are in the current state_dict.
                named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
                params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
                if len(params_to_gather) > 0:
                    # Because ZeRO-3 puts placeholders in model params, this context manager gathers
                    # (unpartitions) the params of the current layer, then loads from the state dict
                    # and then re-partitions them again.
                    with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                        if torch.distributed.get_rank() == 0:
                            module._load_from_state_dict(*args)
            else:
                module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".", assign_to_params_buffers)

    load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
    # Delete `state_dict` so it can be collected by GC earlier. Note that `state_dict` is a copy of the argument,
    # so it's safe to delete it.
    del state_dict

    return error_msgs
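
# Illustrative sketch (not executed): the renaming above only touches legacy `gamma`/`beta`
# parameter names and deliberately skips the vision tower / depth model, e.g.
#
#   "model.layers.0.ln.gamma"                  -> "model.layers.0.ln.weight"
#   "model.layers.0.ln.beta"                   -> "model.layers.0.ln.bias"
#   "model.vision_tower.vision_tower.ln.gamma" -> unchanged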


def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
    """
    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the
    model's parameters.

    Note: We fully disable this if we are using `deepspeed`.
    """
    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
        return False

    if is_deepspeed_zero3_enabled():
        return False

    # Some models explicitly do not support param buffer assignment
    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
        logger.debug(
            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
        )
        return False

    # If the model does, the incoming `state_dict` and the `model_to_load` must be of the same dtype
    first_key = list(model_to_load.state_dict().keys())[0]
    if start_prefix + first_key in state_dict:
        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype

    # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`)
    return False
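
# Illustrative sketch (not executed): buffer assignment is only used when the checkpoint dtype
# matches the instantiated model; otherwise loading falls back to the slower copy path, e.g.
#
#   check_support_param_buffer_assignment(fp32_model, fp16_state_dict)  # -> False (dtype mismatch)
#   check_support_param_buffer_assignment(fp16_model, fp16_state_dict)  # -> True (assign in place)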


class BaseCausalLM(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList] = None,
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            generation_config ([`~generation.GenerationConfig`]):
                The generation configuration to be used as parametrization of the decoding method.
            synced_gpus (`bool`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            logits_warper (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                to warp the prediction score distribution of the language modeling head applied before multinomial
                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
                `generation_config`).
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.
        """
        # init values
        pad_token_id = generation_config.pad_token_id
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample
        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
            raise ValueError(
                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
                f"{logits_warper})."
            )
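
        # Note (illustrative, not executed): `logits_processor` applies hard constraints
        # (e.g. `RepetitionPenaltyLogitsProcessor`), while `logits_warper` reshapes the sampling
        # distribution (e.g. `TemperatureLogitsWarper`, `TopPLogitsWarper`); both lists are built
        # by `GenerationMixin.generate()` from `generation_config` before `_sample` is called.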

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        batch_size = input_ids.shape[0]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            if do_sample:
                next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return input_ids
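
    # Illustrative sketch (not executed): `_sample` is not called directly; `GenerationMixin.generate()`
    # builds the processor/warper/stopping-criteria lists from a `GenerationConfig` and then dispatches
    # here for greedy and multinomial-sampling decoding, e.g.
    #
    #   output_ids = model.generate(input_ids, do_sample=True, temperature=0.7, top_p=0.9,
    #                               max_new_tokens=64)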

    @classmethod
    def _load_pretrained_model(
        cls,
        model,
        state_dict,
        loaded_keys,
        resolved_archive_file,
        pretrained_model_name_or_path,
        ignore_mismatched_sizes=False,
        sharded_metadata=None,
        _fast_init=True,
        low_cpu_mem_usage=False,
        device_map=None,
        offload_folder=None,
        offload_state_dict=None,
        dtype=None,
        hf_quantizer=None,
        keep_in_fp32_modules=None,
        gguf_path=None,
    ):
        is_safetensors = False
        is_quantized = hf_quantizer is not None
        state_dict_folder = None
        state_dict_index = None

        if device_map is not None and "disk" in device_map.values():
            archive_file = (
                resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file
            )
            is_safetensors = archive_file.endswith(".safetensors")
            if offload_folder is None and not is_safetensors:
                raise ValueError(
                    "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
                    " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
                    " offers the weights in this format."
                )
            if offload_folder is not None:
                os.makedirs(offload_folder, exist_ok=True)
            if offload_state_dict is None:
                offload_state_dict = True

        is_sharded_safetensors = is_safetensors and sharded_metadata is not None

        # Materialize any leftover meta parameters on GPU so the state dict below holds real tensors.
        for key, param in model.state_dict().items():
            if param.device == torch.device("meta"):
                try:
                    set_module_tensor_to_device(model, key, "cuda", torch.empty(*param.size(), dtype=dtype))
                except Exception:
                    # Some entries cannot be set this way (e.g. non-persistent buffers); skip them.
                    pass

        # tie the model weights before retrieving the state_dict
        model.tie_weights()

        # Retrieve missing & unexpected_keys
        model_state_dict = model.state_dict()
        expected_keys = list(model_state_dict.keys())
        prefix = model.base_model_prefix

        def _fix_key(key):
            if "beta" in key and "vision_tower.vision_tower" not in key:
                return key.replace("beta", "bias")
            if "gamma" in key and ("vision_tower.vision_tower" not in key and "dav2_model" not in key):
                return key.replace("gamma", "weight")
            return key

        original_loaded_keys = loaded_keys
        loaded_keys = [_fix_key(key) for key in loaded_keys]

        if len(prefix) > 0:
            has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
            expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
        else:
            has_prefix_module = False
            expects_prefix_module = False

        # key re-naming operations are never done on the keys
        # that are loaded, but always on the keys of the newly initialized model
        remove_prefix_from_model = not has_prefix_module and expects_prefix_module
        add_prefix_to_model = has_prefix_module and not expects_prefix_module

        if remove_prefix_from_model:
            _prefix = f"{prefix}."
            expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)]
            expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys]
        elif add_prefix_to_model:
            expected_keys = [".".join([prefix, s]) for s in expected_keys]
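
        # Illustrative sketch (not executed), assuming `base_model_prefix == "model"`: loading a bare
        # backbone checkpoint ("layers.0.self_attn.q_proj.weight", ...) into a model with a head
        # ("model.layers.0.self_attn.q_proj.weight", ...) sets `remove_prefix_from_model`, while the
        # opposite direction sets `add_prefix_to_model`, so the key comparison below is always done
        # in the checkpoint's namespace.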

        missing_keys = sorted(set(expected_keys) - set(loaded_keys))
        unexpected_keys = set(loaded_keys) - set(expected_keys)

        # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
        # buffers
        model_buffers = {n for n, _ in model.named_buffers()}
        if remove_prefix_from_model:
            model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers}
        elif add_prefix_to_model:
            model_buffers = {".".join([prefix, key]) for key in model_buffers}
        unexpected_keys = sorted(unexpected_keys - model_buffers)

        model.tie_weights()
        if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
            ptrs = collections.defaultdict(list)
            for name, tensor in model.state_dict().items():
                id_tensor = id_tensor_storage(tensor)
                ptrs[id_tensor].append(name)

            # These are all the pointers of shared tensors.
            tied_params = [names for _, names in ptrs.items() if len(names) > 1]
        else:
            # The `id` function doesn't work for meta tensors, so we need this helper instead.
            tied_params = find_tied_parameters(model)

        for group in tied_params:
            if remove_prefix_from_model:
                group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group]
            elif add_prefix_to_model:
                group = [".".join([prefix, key]) for key in group]
            missing_in_group = [k for k in missing_keys if k in group]
            if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
                missing_keys = [k for k in missing_keys if k not in missing_in_group]

        # Some models may have keys that are not in the state by design, removing them before needlessly warning
        # the user.
        if cls._keys_to_ignore_on_load_missing is not None:
            for pat in cls._keys_to_ignore_on_load_missing:
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

        if cls._keys_to_ignore_on_load_unexpected is not None:
            for pat in cls._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if hf_quantizer is not None:
            missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
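
        # Illustrative sketch (not executed): if `lm_head.weight` is tied to
        # `model.embed_tokens.weight` and only the embedding appears in the checkpoint, the tied-group
        # handling above drops `lm_head.weight` from `missing_keys` instead of warning about it.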

        # Retrieve weights on meta device and put them back on CPU.
        # This is not ideal in terms of memory, but if we don't do that now, we can't initialize them in the next step.
        if low_cpu_mem_usage:
            for key in missing_keys:
                if key in list(model_state_dict.keys()):
                    key = key
                elif f"{prefix}.{key}" in list(model_state_dict.keys()):
                    key = f"{prefix}.{key}"
                elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()):
                    key = ".".join(key.split(".")[1:])
                param = model_state_dict[key]

                # upcast in fp32 if any
                target_dtype = dtype
                if (
                    keep_in_fp32_modules is not None
                    and dtype == torch.float16
                    and any(
                        module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
                    )
                ):
                    target_dtype = torch.float32

                if param.device == torch.device("meta"):
                    value = torch.empty(*param.size(), dtype=target_dtype)
                    if (
                        not is_quantized
                        or getattr(hf_quantizer, "requires_parameters_quantization", False)
                        or not hf_quantizer.check_quantized_param(
                            model, param_value=value, param_name=key, state_dict={}
                        )
                    ):
                        set_module_tensor_to_device(model, key, "cpu", value)
                    else:
                        hf_quantizer.create_quantized_param(model, value, key, "cpu", state_dict, unexpected_keys)

        # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
        if _fast_init:
            if not ignore_mismatched_sizes:
                if remove_prefix_from_model:
                    _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
                elif add_prefix_to_model:
                    _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
                else:
                    _loaded_keys = loaded_keys
                not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
                # If we're about to tie the output embeds to the input embeds we don't need to init them
                if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings:
                    output_embeddings = model.get_output_embeddings()
                    if output_embeddings is not None:
                        # Still need to initialize if there is a bias term since biases are not tied.
                        if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None:
                            output_embeddings._is_hf_initialized = True
            else:
                not_initialized_submodules = dict(model.named_modules())
            # This will only initialize submodules that are not marked as initialized by the line above.
            if is_deepspeed_zero3_enabled() and not is_quantized:
                import deepspeed

                not_initialized_parameters = list(
                    set(
                        itertools.chain.from_iterable(
                            submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values()
                        )
                    )
                )
                with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
                    model.apply(model._initialize_weights)
            else:
                model.apply(model._initialize_weights)

        # Set some modules to fp32 if any
        if keep_in_fp32_modules is not None:
            for name, param in model.named_parameters():
                if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
                    # `param = param.to(torch.float32)` does not work here as it only rebinds the name in the local scope.
                    param.data = param.data.to(torch.float32)
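
        # Illustrative sketch (not executed): `keep_in_fp32_modules` typically comes from the class
        # attribute `_keep_in_fp32_modules`, e.g. a hypothetical `["layer_norm"]`, so an fp16 load
        # would still keep `...layer_norm.weight` in torch.float32 for numerical stability.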

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        model_to_load = model
        if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
            start_prefix = cls.base_model_prefix + "."
        if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
            model_to_load = getattr(model, cls.base_model_prefix)
            base_model_expected_keys = list(model_to_load.state_dict().keys())
            if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
                raise ValueError(
                    "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
                    "properly saved?"
                )
            if device_map is not None:
                device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}

        def _find_mismatched_keys(
            state_dict,
            model_state_dict,
            loaded_keys,
            add_prefix_to_model,
            remove_prefix_from_model,
            ignore_mismatched_sizes,
        ):
            mismatched_keys = []
            if ignore_mismatched_sizes:
                for checkpoint_key in loaded_keys:
                    # If the checkpoint is sharded, we may not have the key here.
                    if checkpoint_key not in state_dict:
                        continue
                    model_key = checkpoint_key
                    if remove_prefix_from_model:
                        # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
                        model_key = f"{prefix}.{checkpoint_key}"
                    elif add_prefix_to_model:
                        # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
                        model_key = ".".join(checkpoint_key.split(".")[1:])

                    if (
                        model_key in model_state_dict
                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
                    ):
                        if (
                            state_dict[checkpoint_key].shape[-1] == 1
                            and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
                        ):
                            # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container,
                            # causing size differences. Without matching on module type or parameter type, this is a
                            # practical way to detect valid 4-bit weights.
                            pass
                        else:
                            mismatched_keys.append(
                                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                            )
                            del state_dict[checkpoint_key]
            return mismatched_keys

        if resolved_archive_file is not None:
            folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
        else:
            folder = None

        if device_map is not None and is_safetensors:
            param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
            str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
            if sharded_metadata is None:
                archive_file = (
                    resolved_archive_file[0]
                    if isinstance(resolved_archive_file, (list, tuple))
                    else resolved_archive_file
                )
                weight_map = {p: archive_file for p in original_loaded_keys}
            else:
                weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
            offload_index = {
                p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
                for p, f in weight_map.items()
                if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
            }
        else:
            offload_index = None
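
        # Illustrative sketch (not executed) of the 4-bit skip above: a bitsandbytes-packed weight
        # stored as a uint8 tensor of shape (N, 1) holds two 4-bit values per byte, so for a model
        # weight with 2*N elements the shapes differ even though the checkpoint is valid and should
        # not be reported as mismatched.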

        if state_dict is not None:
            # Whole checkpoint
            mismatched_keys = _find_mismatched_keys(
                state_dict,
                model_state_dict,
                original_loaded_keys,
                add_prefix_to_model,
                remove_prefix_from_model,
                ignore_mismatched_sizes,
            )

            # For GGUF models `state_dict` is never set to None as the state dict is always small
            if gguf_path:
                error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
                    model_to_load,
                    state_dict,
                    loaded_keys,
                    start_prefix,
                    expected_keys,
                    device_map=device_map,
                    offload_folder=offload_folder,
                    offload_index=offload_index,
                    state_dict_folder=state_dict_folder,
                    state_dict_index=state_dict_index,
                    dtype=dtype,
                    hf_quantizer=hf_quantizer,
                    is_safetensors=is_safetensors,
                    keep_in_fp32_modules=keep_in_fp32_modules,
                    unexpected_keys=unexpected_keys,
                )
            else:
                # Sharded checkpoint or whole but low_cpu_mem_usage==True
                assign_to_params_buffers = check_support_param_buffer_assignment(
                    model_to_load, state_dict, start_prefix
                )
                error_msgs = _load_state_dict_into_model(
                    model_to_load, state_dict, start_prefix, assign_to_params_buffers
                )
        else:
            # This should always be a list, but check just to be sure.
            if not isinstance(resolved_archive_file, list):
                resolved_archive_file = [resolved_archive_file]

            error_msgs = []
            mismatched_keys = []
            if not is_safetensors:
                offload_index = {} if device_map is not None and "disk" in device_map.values() else None
            if offload_state_dict:
                state_dict_folder = tempfile.mkdtemp()
                state_dict_index = {}
            else:
                state_dict_folder = None
                state_dict_index = None

            if is_sharded_safetensors:
                disk_only_shard_files = get_disk_only_shard_files(
                    device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
                )
                disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
            else:
                disk_only_shard_files = []

            if len(resolved_archive_file) > 1:
                resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
            assign_to_params_buffers = None
            for shard_file in resolved_archive_file:
                # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
                if shard_file in disk_only_shard_files:
                    continue
                state_dict = load_state_dict(shard_file, is_quantized=is_quantized)

                # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                # matching the weights in the model.
                mismatched_keys += _find_mismatched_keys(
                    state_dict,
                    model_state_dict,
                    original_loaded_keys,
                    add_prefix_to_model,
                    remove_prefix_from_model,
                    ignore_mismatched_sizes,
                )
                if low_cpu_mem_usage:
                    if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
                        for key, param in model_to_load.state_dict().items():
                            if param.device == torch.device("meta"):
                                set_module_tensor_to_device(
                                    model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
                                )
                    else:
                        new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
                            model_to_load,
                            state_dict,
                            loaded_keys,
                            start_prefix,
                            expected_keys,
                            device_map=device_map,
                            offload_folder=offload_folder,
                            offload_index=offload_index,
                            state_dict_folder=state_dict_folder,
                            state_dict_index=state_dict_index,
                            dtype=dtype,
                            hf_quantizer=hf_quantizer,
                            is_safetensors=is_safetensors,
                            keep_in_fp32_modules=keep_in_fp32_modules,
                            unexpected_keys=unexpected_keys,
                        )
                        error_msgs += new_error_msgs
                else:
                    # Sharded checkpoint or whole but low_cpu_mem_usage==True
                    if assign_to_params_buffers is None:
                        assign_to_params_buffers = check_support_param_buffer_assignment(
                            model_to_load, state_dict, start_prefix
                        )
                    error_msgs += _load_state_dict_into_model(
                        model_to_load, state_dict, start_prefix, assign_to_params_buffers
                    )

                # force memory release
                del state_dict
                gc.collect()

            if offload_index is not None and len(offload_index) > 0:
                if model != model_to_load:
                    # We need to add the prefix of the base model
                    prefix = cls.base_model_prefix
                    if not is_safetensors:
                        for weight_name in offload_index:
                            shutil.move(
                                os.path.join(offload_folder, f"{weight_name}.dat"),
                                os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
                            )
                    offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
                if not is_safetensors:
                    save_offload_index(offload_index, offload_folder)
                    offload_index = None

            if offload_state_dict:
                # Load back temporarily offloaded state dict
                load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
                shutil.rmtree(state_dict_folder)
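
        # Illustrative sketch (not executed): with a `device_map` that offloads to disk and
        # non-safetensors weights, each offloaded tensor is written as
        # `<offload_folder>/<weight_name>.dat` together with an `offload_index` entry describing its
        # dtype and shape, which `save_offload_index` then persists inside the offload folder.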

        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
            if "size mismatch" in error_msg:
                error_msg += (
                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
                )
            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

        if len(unexpected_keys) > 0:
            archs = [] if model.config.architectures is None else model.config.architectures
            warner = logger.warning if model.__class__.__name__ in archs else logger.info
            warner(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
                " training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
                " to use it for predictions and inference."
            )

        return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
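

# Illustrative sketch (not executed): a concrete model would subclass `BaseCausalLM`; the
# `_load_pretrained_model` override above is then picked up automatically by `from_pretrained`, e.g.
#
#   class MyCausalLM(BaseCausalLM):  # hypothetical subclass name
#       ...
#
#   model = MyCausalLM.from_pretrained("path/or/repo-id", torch_dtype=torch.float16,
#                                      low_cpu_mem_usage=True)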