Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import gc | |
import json | |
import os | |
import re | |
import warnings | |
from functools import partial | |
from pickle import UnpicklingError | |
from typing import Any, Dict, Optional, Set, Tuple, Union | |
import flax.linen as nn | |
import jax | |
import jax.numpy as jnp | |
import msgpack.exceptions | |
from flax.core.frozen_dict import FrozenDict, unfreeze | |
from flax.serialization import from_bytes, to_bytes | |
from flax.traverse_util import flatten_dict, unflatten_dict | |
from jax.random import PRNGKey | |
from .configuration_utils import PretrainedConfig | |
from .dynamic_module_utils import custom_object_save | |
from .generation import FlaxGenerationMixin, GenerationConfig | |
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict | |
from .utils import ( | |
FLAX_WEIGHTS_INDEX_NAME, | |
FLAX_WEIGHTS_NAME, | |
SAFE_WEIGHTS_INDEX_NAME, | |
SAFE_WEIGHTS_NAME, | |
WEIGHTS_INDEX_NAME, | |
WEIGHTS_NAME, | |
PushToHubMixin, | |
add_code_sample_docstrings, | |
add_start_docstrings_to_model_forward, | |
cached_file, | |
copy_func, | |
download_url, | |
has_file, | |
is_offline_mode, | |
is_remote_url, | |
logging, | |
replace_return_docstrings, | |
) | |
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files | |
from .utils.import_utils import is_safetensors_available | |
if is_safetensors_available(): | |
from safetensors import safe_open | |
from safetensors.flax import load_file as safe_load_file | |
from safetensors.flax import save_file as safe_save_file | |
# Module-level logger obtained through the library's own `logging` utilities (imported from `.utils`).
logger = logging.get_logger(__name__)
def quick_gelu(x):
    """Sigmoid-based GELU approximation: returns ``x * sigmoid(1.702 * x)``."""
    gate = jax.nn.sigmoid(1.702 * x)
    return x * gate
# Lookup table from activation-function name to its Flax/JAX implementation.
ACT2FN = dict(
    gelu=partial(nn.gelu, approximate=False),  # exact GELU
    relu=nn.relu,
    silu=nn.swish,  # "silu" and "swish" are aliases for the same function
    swish=nn.swish,
    gelu_new=partial(nn.gelu, approximate=True),  # approximate GELU
    quick_gelu=quick_gelu,
)
def dtype_byte_size(dtype):
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`.

    The bit width is parsed from the dtype name (e.g. `float32` -> 32 bits -> 4 bytes). Besides `np.dtype`
    instances, anything `jnp.dtype` can convert (scalar types such as `np.float32`, dtype strings) is accepted.

    Example:

    ```py
    >>> dtype_byte_size(np.float32)
    4
    ```

    Returns:
        `int` for whole-byte dtypes; `float` for booleans (1/8) and sub-byte dtypes such as `int4` (0.5).

    Raises:
        ValueError: if no bit width can be parsed from the dtype name.
        TypeError: propagated from `jnp.dtype` when `dtype` has no `.name` and cannot be converted to a dtype.
    """
    if dtype == bool:
        return 1 / 8
    name = getattr(dtype, "name", None)
    if not isinstance(name, str):
        # Scalar types like `np.float32` expose no string `.name`; normalise them to a real dtype first.
        name = jnp.dtype(dtype).name
    if name == "bool":
        return 1 / 8
    # Use the first run of digits that follows a non-digit: "bfloat16" -> 16, "float8_e4m3fn" -> 8.
    # Anchoring at the end of the name would read "float8_e5m2" as 2 bits and reject "float8_e4m3fn".
    bit_search = re.search(r"[^\d](\d+)", name)
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.group(1))
    # Keep integer results for whole-byte dtypes; sub-byte dtypes get a fractional size instead of 0.
    return bit_size // 8 if bit_size % 8 == 0 else bit_size / 8
def flax_shard_checkpoint(params, max_shard_size="10GB"):
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so
    there is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For
    example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as
    [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
    have a size greater than `max_shard_size`.

    </Tip>

    Args:
        params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).

    Returns:
        `Tuple[Dict, Optional[Dict]]`: A mapping from shard file name to the flat dict of weights stored in that
        shard (keys are "/"-joined parameter paths), and an index dict with `"metadata"` (`"total_size"` in bytes)
        and `"weight_map"` (parameter path -> shard file name). The index is `None` when everything fits in a single
        shard named `FLAX_WEIGHTS_NAME`.
    """
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = []
    current_block = {}
    current_block_size = 0
    total_size = 0

    # flatten the weights to chunk
    weights = flatten_dict(params, sep="/")
    for item, weight in weights.items():
        weight_size = weight.size * dtype_byte_size(weight.dtype)

        # If this weight is going to tip us over the maximal size, we split -- but only when the current block
        # already holds something; otherwise an oversized first weight would emit an empty shard.
        if current_block and current_block_size + weight_size > max_shard_size:
            sharded_state_dicts.append(current_block)
            current_block = {}
            current_block_size = 0

        current_block[item] = weight
        current_block_size += weight_size
        total_size += weight_size

    # Add the last block
    sharded_state_dicts.append(current_block)

    # If we only have one shard, we return it
    if len(sharded_state_dicts) == 1:
        return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index
    num_shards = len(sharded_state_dicts)
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx + 1:05d}-of-{num_shards:05d}.msgpack")
        shards[shard_file] = shard
        for weight_name in shard:
            weight_map[weight_name] = shard_file

    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
    r"""
    Base class for all models.

    [`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
    downloading and saving models.

    Class attributes (overridden by derived classes):

        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
          for this model architecture.
        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
          classes of the same architecture adding modules on top of the base model.
        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
          models, `pixel_values` for vision models and `input_values` for speech models).
    """

    # Configuration class of the architecture; `from_pretrained` calls `cls.config_class.from_pretrained` when no
    # config instance is passed, so derived classes must set it.
    config_class = None
    base_model_prefix = ""
    main_input_name = "input_ids"
    # NOTE(review): not read in the visible code; presumably set when the model is registered with an Auto class.
    _auto_class = None
    # NOTE(review): a mutable class-level set, shared by every subclass/instance that does not rebind it; not used
    # in the visible code -- confirm intended ownership.
    _missing_keys = set()
def __init__(
    self,
    config: PretrainedConfig,
    module: nn.Module,
    input_shape: Tuple = (1, 1),
    seed: int = 0,
    dtype: jnp.dtype = jnp.float32,
    _do_init: bool = True,
):
    """
    Store the config and module, seed the PRNG key, and record the parameter shape tree.

    With `_do_init=True`, `init_weights` runs eagerly and its result is assigned to `self.params`. With
    `_do_init=False`, `init_weights` is only traced through `jax.eval_shape`, so shapes/dtypes are known but no
    parameters are kept on the model; the caller must initialize and hold them.

    Raises:
        ValueError: if `config` or `module` is None.
    """
    if config is None:
        raise ValueError("config cannot be None")
    if module is None:
        raise ValueError("module cannot be None")

    # Private so that derived classes can re-expose them through typed properties.
    self._config = config
    self._module = module

    # Generic public attributes shared by every derived class.
    self.key = PRNGKey(seed)
    self.dtype = dtype
    self.input_shape = input_shape
    if self.can_generate():
        self.generation_config = GenerationConfig.from_model_config(config)
    else:
        self.generation_config = None

    # Must be set before any `self.params` assignment: the params setter checks it.
    self._is_initialized = _do_init

    if _do_init:
        random_params = self.init_weights(self.key, input_shape)
        shape_tree = jax.eval_shape(lambda tree: tree, random_params)
    else:
        shape_tree = jax.eval_shape(partial(self.init_weights, input_shape=input_shape), self.key)
        logger.info(
            "Model weights are not initialized as `_do_init` is set to `False`. "
            f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
        )

    self._params_shape_tree = shape_tree
    # Flattened key paths that any params tree assigned through the setter must contain.
    self._required_params = set(flatten_dict(unfreeze(shape_tree)).keys())

    if _do_init:
        self.params = random_params
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
    """
    Hook that derived classes override to build the model's parameter tree.

    Raises:
        NotImplementedError: always, in this base class.
    """
    message = f"init method has to be implemented for {self}"
    raise NotImplementedError(message)
def enable_gradient_checkpointing(self):
    """
    Hook that derived classes override to switch on gradient checkpointing.

    Raises:
        NotImplementedError: always, in this base class.
    """
    message = f"gradient checkpointing method has to be implemented for {self}"
    raise NotImplementedError(message)
def _from_config(cls, config, **kwargs): | |
""" | |
All context managers that the model should be initialized under go here. | |
""" | |
return cls(config, **kwargs) | |
@property
def framework(self) -> str:
    """
    :str: Identifies that this is a Flax model.
    """
    return "flax"
@property
def config(self) -> PretrainedConfig:
    """The configuration object passed to `__init__`."""
    return self._config
@property
def module(self) -> nn.Module:
    """The Flax module passed to `__init__`."""
    return self._module
@property
def params(self) -> Union[Dict, FrozenDict]:
    """
    The model parameters currently held by this instance.

    Raises:
        ValueError: if the model was created with `_do_init=False` (parameters then live outside the model).
    """
    if not self._is_initialized:
        raise ValueError(
            "`params` cannot be accessed from model when the model is created with `_do_init=False`. "
            "You must call `init_weights` manually and store the params outside of the model and "
            "pass it explicitly where needed."
        )
    return self._params

@property
def required_params(self) -> Set:
    """Set of flattened key paths (tuples) that any parameter tree assigned to `params` must contain."""
    return self._required_params

@property
def params_shape_tree(self) -> Dict:
    """Abstract (shape/dtype) tree of the parameters, as computed with `jax.eval_shape` at construction."""
    return self._params_shape_tree

@params.setter
def params(self, params: Union[Dict, FrozenDict]):
    """
    Replace the model parameters after checking that every required key path is present.

    A `FrozenDict` is unfrozen before being stored. Extra keys are not rejected.

    Raises:
        ValueError: if the model was created with `_do_init=False`, or if required parameters are missing.
    """
    # don't set params if the model is not initialized
    if not self._is_initialized:
        raise ValueError(
            "`params` cannot be set from model when the model is created with `_do_init=False`. "
            "You store the params outside of the model."
        )

    if isinstance(params, FrozenDict):
        params = unfreeze(params)

    param_keys = set(flatten_dict(params).keys())
    missing_keys = self.required_params - param_keys
    if missing_keys:
        raise ValueError(
            "Some parameters are missing. Make sure that `params` include the following "
            f"parameters {missing_keys}"
        )
    self._params = params
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
    """
    Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.

    Args:
        params: Nested dict (or `FrozenDict`) of parameters.
        dtype: Target dtype for floating-point `jnp.ndarray` leaves; other leaves are returned unchanged.
        mask: Optional tree with the same nested-dict structure as `params` whose leaves say whether the parameter
            at the same key path should be cast. Leaves are matched to parameters by key path, so the key order of
            `mask` does not have to match `params`. Parameters without a mask entry are left unchanged. If `None`,
            every floating-point leaf is cast.

    Returns:
        The cast tree. Without a mask the container type is kept (`jax.tree_util.tree_map`); with a mask the tree is
        rebuilt with `unflatten_dict`.
    """

    # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
    def conditional_cast(param):
        if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
            param = param.astype(dtype)
        return param

    if mask is None:
        return jax.tree_util.tree_map(conditional_cast, params)

    flat_params = flatten_dict(params)
    # Match mask leaves to parameters by key path. `jax.tree_util.tree_flatten` orders dict leaves by sorted key while
    # `flatten_dict` keeps insertion order, so zipping the two can pair a mask value with the wrong parameter.
    flat_mask = flatten_dict(mask)
    for key, masked in flat_mask.items():
        if masked and key in flat_params:
            flat_params[key] = conditional_cast(flat_params[key])

    return unflatten_dict(flat_params)
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
    r"""
    Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
    the `params` in place.

    This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
    half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`, *optional*):
            A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
            you want to cast, and should be `False` for those you want to skip. If `None`, every floating-point
            parameter is cast.

    Returns:
        The parameter tree with the selected floating-point leaves cast to `jax.numpy.bfloat16`.

    Examples:

    ```python
    >>> from transformers import FlaxBertModel

    >>> # load model
    >>> model = FlaxBertModel.from_pretrained("bert-base-cased")
    >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
    >>> model.params = model.to_bf16(model.params)
    >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
    >>> # then pass the mask as follows
    >>> from flax import traverse_util

    >>> model = FlaxBertModel.from_pretrained("bert-base-cased")
    >>> flat_params = traverse_util.flatten_dict(model.params)
    >>> mask = {
    ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
    ...     for path in flat_params
    ... }
    >>> mask = traverse_util.unflatten_dict(mask)
    >>> model.params = model.to_bf16(model.params, mask)
    ```"""
    return self._cast_floating_to(params, jnp.bfloat16, mask)
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
    r"""
    Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
    model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`, *optional*):
            A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
            you want to cast, and should be `False` for those you want to skip. If `None`, every floating-point
            parameter is cast.

    Returns:
        The parameter tree with the selected floating-point leaves cast to `jax.numpy.float32`.

    Examples:

    ```python
    >>> from transformers import FlaxBertModel

    >>> # Download model and configuration from huggingface.co
    >>> model = FlaxBertModel.from_pretrained("bert-base-cased")
    >>> # By default, the model params will be in fp32, to illustrate the use of this method,
    >>> # we'll first cast to fp16 and back to fp32
    >>> model.params = model.to_fp16(model.params)
    >>> # now cast back to fp32
    >>> model.params = model.to_fp32(model.params)
    ```"""
    return self._cast_floating_to(params, jnp.float32, mask)
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
    r"""
    Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
    `params` in place.

    This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
    half-precision training or to save weights in float16 for inference in order to save memory and improve speed.

    Arguments:
        params (`Union[Dict, FrozenDict]`):
            A `PyTree` of model parameters.
        mask (`Union[Dict, FrozenDict]`, *optional*):
            A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
            you want to cast, and should be `False` for those you want to skip. If `None`, every floating-point
            parameter is cast.

    Returns:
        The parameter tree with the selected floating-point leaves cast to `jax.numpy.float16`.

    Examples:

    ```python
    >>> from transformers import FlaxBertModel

    >>> # load model
    >>> model = FlaxBertModel.from_pretrained("bert-base-cased")
    >>> # By default, the model params will be in fp32, to cast these to float16
    >>> model.params = model.to_fp16(model.params)
    >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
    >>> # then pass the mask as follows
    >>> from flax import traverse_util

    >>> model = FlaxBertModel.from_pretrained("bert-base-cased")
    >>> flat_params = traverse_util.flatten_dict(model.params)
    >>> mask = {
    ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
    ...     for path in flat_params
    ... }
    >>> mask = traverse_util.unflatten_dict(mask)
    >>> model.params = model.to_fp16(model.params, mask)
    ```"""
    return self._cast_floating_to(params, jnp.float16, mask)
@classmethod
def load_flax_weights(cls, resolved_archive_file):
    """
    Load a single (non-sharded) checkpoint file into a nested parameter dict.

    Files ending in `.safetensors` are read with safetensors' Flax loader and unflattened on "."; any other file is
    read as bytes and deserialized with `flax.serialization.from_bytes`.

    Args:
        resolved_archive_file (`str`): Local path to the checkpoint file.

    Returns:
        The deserialized parameter state.

    Raises:
        OSError: if msgpack deserialization fails and the file's text starts with "version" (an un-pulled git-lfs
            pointer).
        EnvironmentError: if msgpack deserialization fails for any other reason.
    """
    try:
        if resolved_archive_file.endswith(".safetensors"):
            state = safe_load_file(resolved_archive_file)
            state = unflatten_dict(state, sep=".")
        else:
            with open(resolved_archive_file, "rb") as state_f:
                state = from_bytes(cls, state_f.read())
    except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
        try:
            # Only the prefix matters, so avoid decoding a potentially huge binary file into memory.
            with open(resolved_archive_file, encoding="utf-8") as f:
                if f.read(len("version")) == "version":
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please"
                        " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                        " folder you cloned."
                    )
                raise ValueError from e
        except (UnicodeDecodeError, ValueError) as err:
            raise EnvironmentError(
                f"Unable to convert {resolved_archive_file} to Flax deserializable object. "
            ) from err

    return state
@classmethod
def load_flax_sharded_weights(cls, shard_files):
    """
    This is the same as [`flax.serialization.from_bytes`]
    (https://flax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint.

    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
    loaded in the model.

    Args:
        shard_files (`List[str]`):
            The list of shard files to load.

    Returns:
        `Dict`: A nested dictionary of the model parameters, in the same nested format as `model.params`.

    Raises:
        OSError: if a shard fails msgpack deserialization and its text starts with "version" (an un-pulled git-lfs
            pointer).
        EnvironmentError: if a shard cannot be deserialized for any other reason.
    """
    state_sharded_dict = {}

    for shard_file in shard_files:
        # load using msgpack utils
        try:
            with open(shard_file, "rb") as state_f:
                state = from_bytes(cls, state_f.read())
        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
            # Errors raised inside an `except` handler are not caught by sibling `except` clauses, so the
            # diagnosis needs its own `try` to turn non-lfs failures into EnvironmentError.
            try:
                # Only the prefix matters, so avoid decoding a potentially huge binary shard into memory.
                with open(shard_file, encoding="utf-8") as f:
                    if f.read(len("version")) == "version":
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please"
                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                            " folder you cloned."
                        )
                    raise ValueError from e
            except (UnicodeDecodeError, ValueError) as err:
                raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ") from err
        except (UnicodeDecodeError, ValueError) as err:
            # `from_bytes` itself may fail with a decoding/value error on malformed input.
            raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ") from err

        # Flatten with "/" so that keys coming from different shards merge into one flat mapping.
        state = flatten_dict(state, sep="/")
        state_sharded_dict.update(state)
        del state
        gc.collect()

    # the state dict is unflattened to match the format of model.params
    return unflatten_dict(state_sharded_dict, sep="/")
@classmethod
def can_generate(cls) -> bool:
    """
    Returns whether this model can generate sequences with `.generate()`.

    Returns:
        `bool`: Whether this model can generate sequences with `.generate()`.
    """
    # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
    # Alternatively, the model can also have a custom `generate` function.
    if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
        return False
    return True
def from_pretrained( | |
cls, | |
pretrained_model_name_or_path: Union[str, os.PathLike], | |
dtype: jnp.dtype = jnp.float32, | |
*model_args, | |
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, | |
cache_dir: Optional[Union[str, os.PathLike]] = None, | |
ignore_mismatched_sizes: bool = False, | |
force_download: bool = False, | |
local_files_only: bool = False, | |
token: Optional[Union[str, bool]] = None, | |
revision: str = "main", | |
**kwargs, | |
): | |
r""" | |
Instantiate a pretrained flax model from a pre-trained model configuration. | |
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come | |
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning | |
task. | |
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those | |
weights are discarded. | |
Parameters: | |
pretrained_model_name_or_path (`str` or `os.PathLike`): | |
Can be either: | |
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. | |
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a | |
user or organization name, like `dbmdz/bert-base-german-cased`. | |
- A path to a *directory* containing model weights saved using | |
[`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. | |
- A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case, | |
`from_pt` should be set to `True`. | |
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): | |
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and | |
`jax.numpy.bfloat16` (on TPUs). | |
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If | |
specified all the computation will be performed with the given `dtype`. | |
**Note that this only specifies the dtype of the computation and does not influence the dtype of model | |
parameters.** | |
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and | |
[`~FlaxPreTrainedModel.to_bf16`]. | |
model_args (sequence of positional arguments, *optional*): | |
All remaining positional arguments will be passed to the underlying model's `__init__` method. | |
config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): | |
Can be either: | |
- an instance of a class derived from [`PretrainedConfig`], | |
- a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. | |
Configuration for the model to use instead of an automatically loaded configuration. Configuration can | |
be automatically loaded when: | |
- The model is a model provided by the library (loaded with the *model id* string of a pretrained | |
model). | |
- The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the | |
save directory. | |
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a | |
configuration JSON file named *config.json* is found in the directory. | |
cache_dir (`Union[str, os.PathLike]`, *optional*): | |
Path to a directory in which a downloaded pretrained model configuration should be cached if the | |
standard cache should not be used. | |
from_pt (`bool`, *optional*, defaults to `False`): | |
Load the model weights from a PyTorch checkpoint save file (see docstring of | |
`pretrained_model_name_or_path` argument). | |
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): | |
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size | |
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a | |
checkpoint with 3 labels). | |
force_download (`bool`, *optional*, defaults to `False`): | |
Whether or not to force the (re-)download of the model weights and configuration files, overriding the | |
cached versions if they exist. | |
resume_download (`bool`, *optional*, defaults to `False`): | |
Whether or not to delete incompletely received files. Will attempt to resume the download if such a | |
file exists. | |
proxies (`Dict[str, str]`, *optional*): | |
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', | |
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |
local_files_only(`bool`, *optional*, defaults to `False`): | |
Whether or not to only look at local files (i.e., do not try to download the model). | |
token (`str` or `bool`, *optional*): | |
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use | |
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). | |
revision (`str`, *optional*, defaults to `"main"`): | |
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any | |
identifier allowed by git. | |
<Tip> | |
To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>". | |
</Tip> | |
subfolder (`str`, *optional*, defaults to `""`): | |
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can | |
specify the folder name here. | |
kwargs (remaining dictionary of keyword arguments, *optional*): | |
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., | |
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or | |
automatically loaded: | |
- If a configuration is provided with `config`, `**kwargs` will be directly passed to the | |
underlying model's `__init__` method (we assume all relevant updates to the configuration have | |
already been done) | |
- If a configuration is not provided, `kwargs` will be first passed to the configuration class | |
initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that | |
corresponds to a configuration attribute will be used to override said attribute with the | |
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute | |
will be passed to the underlying model's `__init__` function. | |
Examples: | |
```python | |
>>> from transformers import BertConfig, FlaxBertModel | |
>>> # Download model and configuration from huggingface.co and cache. | |
>>> model = FlaxBertModel.from_pretrained("bert-base-cased") | |
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). | |
>>> model = FlaxBertModel.from_pretrained("./test/saved_model/") | |
>>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). | |
>>> config = BertConfig.from_json_file("./pt_model/config.json") | |
>>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config) | |
```""" | |
from_pt = kwargs.pop("from_pt", False) | |
resume_download = kwargs.pop("resume_download", False) | |
proxies = kwargs.pop("proxies", None) | |
use_auth_token = kwargs.pop("use_auth_token", None) | |
trust_remote_code = kwargs.pop("trust_remote_code", None) | |
from_pipeline = kwargs.pop("_from_pipeline", None) | |
from_auto_class = kwargs.pop("_from_auto", False) | |
_do_init = kwargs.pop("_do_init", True) | |
subfolder = kwargs.pop("subfolder", "") | |
commit_hash = kwargs.pop("_commit_hash", None) | |
# Not relevant for Flax Models | |
_ = kwargs.pop("adapter_kwargs", None) | |
if use_auth_token is not None: | |
warnings.warn( | |
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", | |
FutureWarning, | |
) | |
if token is not None: | |
raise ValueError( | |
"`token` and `use_auth_token` are both specified. Please set only the argument `token`." | |
) | |
token = use_auth_token | |
if trust_remote_code is True: | |
logger.warning( | |
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" | |
" ignored." | |
) | |
user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class} | |
if from_pipeline is not None: | |
user_agent["using_pipeline"] = from_pipeline | |
if is_offline_mode() and not local_files_only: | |
logger.info("Offline mode: forcing local_files_only=True") | |
local_files_only = True | |
# Load config if we don't provide a configuration | |
if not isinstance(config, PretrainedConfig): | |
config_path = config if config is not None else pretrained_model_name_or_path | |
config, model_kwargs = cls.config_class.from_pretrained( | |
config_path, | |
cache_dir=cache_dir, | |
return_unused_kwargs=True, | |
force_download=force_download, | |
resume_download=resume_download, | |
proxies=proxies, | |
local_files_only=local_files_only, | |
token=token, | |
revision=revision, | |
subfolder=subfolder, | |
_from_auto=from_auto_class, | |
_from_pipeline=from_pipeline, | |
_commit_hash=commit_hash, | |
**kwargs, | |
) | |
else: | |
model_kwargs = kwargs.copy() | |
if commit_hash is None: | |
commit_hash = getattr(config, "_commit_hash", None) | |
# Add the dtype to model_kwargs | |
model_kwargs["dtype"] = dtype | |
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the | |
# index of the files. | |
is_sharded = False | |
# Load model | |
if pretrained_model_name_or_path is not None: | |
pretrained_model_name_or_path = str(pretrained_model_name_or_path) | |
is_local = os.path.isdir(pretrained_model_name_or_path) | |
if os.path.isdir(pretrained_model_name_or_path): | |
if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)): | |
# Load from a Flax checkpoint | |
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) | |
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)): | |
# Load from a sharded Flax checkpoint | |
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME) | |
is_sharded = True | |
elif is_safetensors_available() and os.path.isfile( | |
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) | |
): | |
# Load from a safetensors checkpoint | |
archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) | |
elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): | |
# Load from a PyTorch checkpoint | |
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) | |
elif from_pt and os.path.isfile( | |
os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME) | |
): | |
# Load from a sharded pytorch checkpoint | |
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME) | |
is_sharded = True | |
# At this stage we don't have a weight file so we will raise an error. | |
elif is_safetensors_available() and os.path.isfile( | |
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) | |
): | |
# Load from a sharded safetensors checkpoint | |
archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) | |
is_sharded = True | |
raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!") | |
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): | |
raise EnvironmentError( | |
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " | |
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " | |
"weights." | |
) | |
else: | |
raise EnvironmentError( | |
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " | |
f"{pretrained_model_name_or_path}." | |
) | |
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): | |
archive_file = pretrained_model_name_or_path | |
is_local = True | |
elif is_remote_url(pretrained_model_name_or_path): | |
filename = pretrained_model_name_or_path | |
resolved_archive_file = download_url(pretrained_model_name_or_path) | |
else: | |
if from_pt: | |
filename = WEIGHTS_NAME | |
else: | |
filename = FLAX_WEIGHTS_NAME | |
try: | |
# Load from URL or cache if already cached | |
cached_file_kwargs = { | |
"cache_dir": cache_dir, | |
"force_download": force_download, | |
"proxies": proxies, | |
"resume_download": resume_download, | |
"local_files_only": local_files_only, | |
"token": token, | |
"user_agent": user_agent, | |
"revision": revision, | |
"subfolder": subfolder, | |
"_raise_exceptions_for_missing_entries": False, | |
"_commit_hash": commit_hash, | |
} | |
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) | |
# Maybe the checkpoint is sharded, we try to grab the index name in this case. | |
if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME: | |
resolved_archive_file = cached_file( | |
pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs | |
) | |
if resolved_archive_file is not None: | |
is_sharded = True | |
# Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case. | |
if resolved_archive_file is None and from_pt: | |
resolved_archive_file = cached_file( | |
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs | |
) | |
if resolved_archive_file is not None: | |
is_sharded = True | |
# If we still haven't found anything, look for `safetensors`. | |
if resolved_archive_file is None: | |
# No support for sharded safetensors yet, so we'll raise an error if that's all we find. | |
filename = SAFE_WEIGHTS_NAME | |
resolved_archive_file = cached_file( | |
pretrained_model_name_or_path, SAFE_WEIGHTS_NAME, **cached_file_kwargs | |
) | |
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None | |
# result when internet is up, the repo and revision exist, but the file does not. | |
if resolved_archive_file is None: | |
# Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error | |
# message. | |
has_file_kwargs = { | |
"revision": revision, | |
"proxies": proxies, | |
"token": token, | |
} | |
if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs): | |
is_sharded = True | |
raise NotImplementedError( | |
"Support for sharded checkpoints using safetensors is coming soon!" | |
) | |
elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): | |
raise EnvironmentError( | |
f"{pretrained_model_name_or_path} does not appear to have a file named" | |
f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to" | |
" load this model from those weights." | |
) | |
elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs): | |
raise EnvironmentError( | |
f"{pretrained_model_name_or_path} does not appear to have a file named" | |
f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use" | |
" `from_pt=True` to load this model from those weights." | |
) | |
else: | |
raise EnvironmentError( | |
f"{pretrained_model_name_or_path} does not appear to have a file named" | |
f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." | |
) | |
except EnvironmentError: | |
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted | |
# to the original exception. | |
raise | |
except Exception: | |
# For any other exception, we throw a generic error. | |
raise EnvironmentError( | |
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" | |
" from 'https://huggingface.co/models', make sure you don't have a local directory with the" | |
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" | |
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." | |
) | |
if is_local: | |
logger.info(f"loading weights file {archive_file}") | |
resolved_archive_file = archive_file | |
filename = resolved_archive_file.split(os.path.sep)[-1] | |
else: | |
logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") | |
else: | |
resolved_archive_file = None | |
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded. | |
if is_sharded: | |
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. | |
resolved_archive_file, _ = get_checkpoint_shard_files( | |
pretrained_model_name_or_path, | |
resolved_archive_file, | |
cache_dir=cache_dir, | |
force_download=force_download, | |
proxies=proxies, | |
resume_download=resume_download, | |
local_files_only=local_files_only, | |
token=token, | |
user_agent=user_agent, | |
revision=revision, | |
subfolder=subfolder, | |
_commit_hash=commit_hash, | |
) | |
safetensors_from_pt = False | |
if filename == SAFE_WEIGHTS_NAME: | |
with safe_open(resolved_archive_file, framework="flax") as f: | |
safetensors_metadata = f.metadata() | |
if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]: | |
raise OSError( | |
f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." | |
" Make sure you save your model with the `save_pretrained` method." | |
) | |
safetensors_from_pt = safetensors_metadata.get("format") == "pt" | |
# init random models | |
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs) | |
if from_pt or safetensors_from_pt: | |
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded) | |
else: | |
if is_sharded: | |
state = cls.load_flax_sharded_weights(resolved_archive_file) | |
else: | |
state = cls.load_flax_weights(resolved_archive_file) | |
# make sure all arrays are stored as jnp.arrays | |
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: | |
# https://github.com/google/flax/issues/1261 | |
if _do_init: | |
state = jax.tree_util.tree_map(jnp.array, state) | |
else: | |
# keep the params on CPU if we don't want to initialize | |
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state) | |
if "batch_stats" in state: # if flax model contains batch norm layers | |
# if model is base model only use model_prefix key | |
if ( | |
cls.base_model_prefix not in dict(model.params_shape_tree["params"]) | |
and cls.base_model_prefix in state["params"] | |
): | |
state["params"] = state["params"][cls.base_model_prefix] | |
state["batch_stats"] = state["batch_stats"][cls.base_model_prefix] | |
# if model is head model and we are loading weights from base model | |
# we initialize new params dict with base_model_prefix | |
if ( | |
cls.base_model_prefix in dict(model.params_shape_tree["params"]) | |
and cls.base_model_prefix not in state["params"] | |
): | |
state = { | |
"params": {cls.base_model_prefix: state["params"]}, | |
"batch_stats": {cls.base_model_prefix: state["batch_stats"]}, | |
} | |
else: | |
# if model is base model only use model_prefix key | |
if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state: | |
state = state[cls.base_model_prefix] | |
# if model is head model and we are loading weights from base model | |
# we initialize new params dict with base_model_prefix | |
if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state: | |
state = {cls.base_model_prefix: state} | |
# flatten dicts | |
state = flatten_dict(state) | |
random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree)) | |
missing_keys = model.required_params - set(state.keys()) | |
unexpected_keys = set(state.keys()) - model.required_params | |
# Disabling warning when porting pytorch weights to flax, flax does not uses num_batches_tracked | |
for unexpected_key in unexpected_keys.copy(): | |
if "num_batches_tracked" in unexpected_key[-1]: | |
unexpected_keys.remove(unexpected_key) | |
if missing_keys and not _do_init: | |
logger.warning( | |
f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " | |
"Make sure to call model.init_weights to initialize the missing weights." | |
) | |
cls._missing_keys = missing_keys | |
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not | |
# matching the weights in the model. | |
mismatched_keys = [] | |
for key in state.keys(): | |
if key in random_state and state[key].shape != random_state[key].shape: | |
if ignore_mismatched_sizes: | |
mismatched_keys.append((key, state[key].shape, random_state[key].shape)) | |
state[key] = random_state[key] | |
else: | |
raise ValueError( | |
f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " | |
f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. " | |
"Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this " | |
"model." | |
) | |
# add missing keys as random parameters if we are initializing | |
if missing_keys and _do_init: | |
for missing_key in missing_keys: | |
state[missing_key] = random_state[missing_key] | |
# remove unexpected keys to not be saved again | |
for unexpected_key in unexpected_keys: | |
del state[unexpected_key] | |
if len(unexpected_keys) > 0: | |
logger.warning( | |
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" | |
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" | |
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" | |
" with another architecture (e.g. initializing a BertForSequenceClassification model from a" | |
" BertForPreTraining model).\n- This IS NOT expected if you are initializing" | |
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" | |
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." | |
) | |
else: | |
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") | |
if len(missing_keys) > 0: | |
logger.warning( | |
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" | |
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" | |
" TRAIN this model on a down-stream task to be able to use it for predictions and inference." | |
) | |
elif len(mismatched_keys) == 0: | |
logger.info( | |
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" | |
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" | |
f" was trained on, you can already use {model.__class__.__name__} for predictions without further" | |
" training." | |
) | |
if len(mismatched_keys) > 0: | |
mismatched_warning = "\n".join( | |
[ | |
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" | |
for key, shape1, shape2 in mismatched_keys | |
] | |
) | |
logger.warning( | |
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" | |
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" | |
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" | |
" to use it for predictions and inference." | |
) | |
# dictionary of key: dtypes for the model params | |
param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state) | |
# extract keys of parameters not in jnp.float32 | |
fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16] | |
bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16] | |
# raise a warning if any of the parameters are not in jnp.float32 | |
if len(fp16_params) > 0: | |
logger.warning( | |
f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from " | |
f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n" | |
"You should probably UPCAST the model weights to float32 if this was not intended. " | |
"See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this." | |
) | |
if len(bf16_params) > 0: | |
logger.warning( | |
f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from " | |
f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n" | |
"You should probably UPCAST the model weights to float32 if this was not intended. " | |
"See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this." | |
) | |
# If it is a model with generation capabilities, attempt to load the generation config | |
if model.can_generate(): | |
try: | |
model.generation_config = GenerationConfig.from_pretrained( | |
pretrained_model_name_or_path, | |
cache_dir=cache_dir, | |
force_download=force_download, | |
resume_download=resume_download, | |
proxies=proxies, | |
local_files_only=local_files_only, | |
token=token, | |
revision=revision, | |
subfolder=subfolder, | |
_from_auto=from_auto_class, | |
_from_pipeline=from_pipeline, | |
**kwargs, | |
) | |
except OSError: | |
logger.info( | |
"Generation config file not found, using a generation config created from the model config." | |
) | |
pass | |
if _do_init: | |
# set correct parameters | |
model.params = unflatten_dict(state) | |
return model | |
else: | |
return model, unflatten_dict(state) | |
def save_pretrained(
    self,
    save_directory: Union[str, os.PathLike],
    params=None,
    push_to_hub=False,
    max_shard_size="10GB",
    token: Optional[Union[str, bool]] = None,
    safe_serialization: bool = False,
    **kwargs,
):
    """
    Save a model and its configuration file to a directory, so that it can be re-loaded using the
    [`~FlaxPreTrainedModel.from_pretrained`] class method

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to which to save. Will be created if it doesn't exist.
        params (*optional*):
            Parameter tree to save. When `None`, `self.params` is saved.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
            repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
            namespace).
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
            lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).

            <Tip warning={true}>

            If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
            which will be bigger than `max_shard_size`.

            </Tip>

        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
            the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
        kwargs (`Dict[str, Any]`, *optional*):
            Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        safe_serialization (`bool`, *optional*, defaults to `False`):
            Whether to save the model using `safetensors` or through msgpack. Only single-file checkpoints can be
            written with `safetensors`; when the checkpoint needs sharding, a warning is logged and the shards are
            written with msgpack.
    """
    use_auth_token = kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
            FutureWarning,
        )
        if token is not None:
            raise ValueError(
                "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
            )
        token = use_auth_token

    if token is not None:
        kwargs["token"] = token

    if os.path.isfile(save_directory):
        logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
        return

    os.makedirs(save_directory, exist_ok=True)

    if push_to_hub:
        commit_message = kwargs.pop("commit_message", None)
        # `os.fspath` makes `os.PathLike` inputs work, and `normpath` drops a trailing separator so the default
        # repo name is the last real path component instead of an empty string.
        default_repo_id = os.path.basename(os.path.normpath(os.fspath(save_directory)))
        repo_id = kwargs.pop("repo_id", default_repo_id)
        repo_id = self._create_repo(repo_id, **kwargs)
        files_timestamps = self._get_files_timestamps(save_directory)

    # get abs dir
    save_directory = os.path.abspath(save_directory)
    # Save config as well. The first four characters of the class name are stripped (presumably the `Flax`
    # prefix) so the architecture name matches the framework-agnostic one.
    self.config.architectures = [self.__class__.__name__[4:]]

    # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
    # loaded from the Hub.
    if self._auto_class is not None:
        custom_object_save(self, save_directory, config=self.config)

    self.config.save_pretrained(save_directory)

    if self.can_generate():
        self.generation_config.save_pretrained(save_directory)

    # save model
    params = params if params is not None else self.params
    shards, index = flax_shard_checkpoint(params, max_shard_size)
    is_sharded = index is not None

    if safe_serialization and is_sharded:
        # Sharded checkpoints are only ever written with msgpack (and `from_pretrained` does not read sharded
        # safetensors yet), so fall back explicitly instead of silently ignoring the flag.
        logger.warning(
            "Saving sharded checkpoints with `safe_serialization=True` is not supported yet; the checkpoint "
            "shards are saved in msgpack format instead."
        )
        safe_serialization = False

    weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME
    index_name = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else FLAX_WEIGHTS_INDEX_NAME
    output_model_file = os.path.join(save_directory, weights_name)

    # Clean the folder from a previous save. Only remove files that are recognizably weight artifacts of the
    # format being written: the single-file checkpoint, its index, or a `<stem>-XXXXX-of-YYYYY<ext>` shard.
    # A bare prefix match would also delete unrelated files sharing the prefix, e.g. a `modeling_*.py` copied
    # by `custom_object_save` above when the prefix is `model`.
    weights_stem, weights_ext = os.path.splitext(weights_name)
    shard_pattern = re.compile(rf"{re.escape(weights_stem)}-\d{{5}}-of-\d{{5}}{re.escape(weights_ext)}")
    for filename in os.listdir(save_directory):
        full_filename = os.path.join(save_directory, filename)
        is_weights_file = filename in (weights_name, index_name) or shard_pattern.fullmatch(filename) is not None
        if is_weights_file and filename not in shards and os.path.isfile(full_filename):
            os.remove(full_filename)

    if not is_sharded:
        if safe_serialization:
            flat_dict = flatten_dict(params, sep=".")
            safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"})
        else:
            with open(output_model_file, "wb") as f:
                f.write(to_bytes(params))
        logger.info(f"Model weights saved in {output_model_file}")
    else:
        save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
        # Save the index as well
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
        logger.info(
            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
            f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
            f"index located at {save_index_file}."
        )
        for shard_file, shard in shards.items():
            # Shards are flat dicts keyed by "/"-joined parameter paths (hence `sep="/"`); rebuild the nested
            # tree before serializing it with msgpack.
            with open(os.path.join(save_directory, shard_file), mode="wb") as f:
                f.write(to_bytes(unflatten_dict(shard, sep="/")))

    if push_to_hub:
        self._upload_modified_files(
            save_directory,
            repo_id,
            files_timestamps,
            commit_message=commit_message,
            token=token,
        )
def register_for_auto_class(cls, auto_class="FlaxAutoModel"):
    """
    Register this class with a given auto class. This should only be used for custom models as the ones in the
    library are already mapped with an auto class.

    <Tip warning={true}>

    This API is experimental and may have some slight breaking changes in the next releases.

    </Tip>

    Args:
        auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`):
            The auto class to register this new model with. A class object is converted to its `__name__`.

    Raises:
        ValueError: if `transformers.models.auto` has no attribute with that name.
    """
    # Accept either a name or the auto class itself; only the name is stored.
    auto_class_name = auto_class if isinstance(auto_class, str) else auto_class.__name__

    import transformers.models.auto as auto_module

    if not hasattr(auto_module, auto_class_name):
        raise ValueError(f"{auto_class_name} is not a valid auto class.")

    cls._auto_class = auto_class_name
# To update the docstring, we need to copy the method, otherwise we change the original docstring. | |
# Give FlaxPreTrainedModel its own copy of `push_to_hub` so that formatting the docstring below does not modify
# the function object it came from (presumably the one shared via `PushToHubMixin`).
FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub)
# `__doc__` can be None (e.g. docstrings stripped by `python -OO`), in which case there is nothing to format.
if FlaxPreTrainedModel.push_to_hub.__doc__ is not None:
    # Fill the `{object}`, `{object_class}` and `{object_files}` placeholders with Flax-model specific wording.
    FlaxPreTrainedModel.push_to_hub.__doc__ = FlaxPreTrainedModel.push_to_hub.__doc__.format(
        object="model", object_class="FlaxAutoModel", object_files="model checkpoint"
    )
def overwrite_call_docstring(model_class, docstring):
    """
    Replace `model_class.__call__` with a copy whose docstring is rebuilt from `docstring`.

    The existing docstring of the copy is cleared first, then the copy is passed through
    `add_start_docstrings_to_model_forward(docstring)`. Copying keeps the change local to `model_class`.
    """
    call_copy = copy_func(model_class.__call__)
    # Start from an empty docstring so nothing from the original __call__ is carried over.
    call_copy.__doc__ = None
    model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(call_copy)
def append_call_sample_docstring(
    model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None
):
    """
    Decorate a copy of `model_class.__call__` with `add_code_sample_docstrings`.

    The function is copied first so the docstring change only affects `model_class`.

    Args:
        model_class: Class whose `__call__` docstring is extended.
        checkpoint: Checkpoint name used in the generated sample.
        output_type: Output class documented in the sample.
        config_class: Configuration class documented in the sample.
        mask (`str`, *optional*):
            Mask token string to use in the generated code sample. When `None`, the default of
            `add_code_sample_docstrings` is used.
        revision (*optional*): Forwarded to `add_code_sample_docstrings`.
        real_checkpoint (*optional*): Forwarded to `add_code_sample_docstrings`.
    """
    sample_kwargs = {
        "checkpoint": checkpoint,
        "output_type": output_type,
        "config_class": config_class,
        "model_cls": model_class.__name__,
        "revision": revision,
        "real_checkpoint": real_checkpoint,
    }
    # `mask` used to be accepted but silently dropped. Forward it only when given, so the decorator's own
    # default still applies for callers that don't pass it.
    if mask is not None:
        sample_kwargs["mask"] = mask

    model_class.__call__ = copy_func(model_class.__call__)
    model_class.__call__ = add_code_sample_docstrings(**sample_kwargs)(model_class.__call__)
def append_replace_return_docstrings(model_class, output_type, config_class):
    """
    Apply `replace_return_docstrings(output_type=..., config_class=...)` to a copy of `model_class.__call__`.

    Copying the function first keeps the docstring change local to `model_class`.
    """
    decorate = replace_return_docstrings(output_type=output_type, config_class=config_class)
    model_class.__call__ = decorate(copy_func(model_class.__call__))