# coding=utf-8
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import partial
from pickle import UnpicklingError
from typing import Dict, Set, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey

from transformers.configuration_utils import PretrainedConfig
from transformers.file_utils import (
    FLAX_WEIGHTS_NAME,
    WEIGHTS_NAME,
    PushToHubMixin,
    cached_path,
    hf_bucket_url,
    is_offline_mode,
    is_remote_url,
)
from transformers.modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from transformers.utils import logging

from .generation_clip_vision_marian_utils import FlaxGenerationMixin


logger = logging.get_logger(__name__)


def quick_gelu(x):
    """Return ``x * sigmoid(1.702 * x)`` (the "quick GELU" activation)."""
    return x * jax.nn.sigmoid(1.702 * x)


# Maps an activation name to the callable implementing it.
ACT2FN = {
    "gelu": partial(nn.gelu, approximate=False),
    "relu": nn.relu,
    "silu": nn.swish,
    "swish": nn.swish,
    "gelu_new": partial(nn.gelu, approximate=True),
    "quick_gelu": quick_gelu,
}


class FlaxCLIPVisionMarianPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
    """
    Base class holding a configuration, a Flax module and its parameters.

    It provides weight loading (:meth:`from_pretrained`) and saving (:meth:`save_pretrained`). Derived classes are
    expected to set ``config_class`` and ``base_model_prefix`` and to implement :meth:`init_weights`.
    """

    # Configuration class used by `from_pretrained` to load a config; must be set by derived classes.
    config_class = None
    # Top-level key under which the base model parameters live in the params dict of head models.
    base_model_prefix = ""

    def __init__(
        self,
        config: PretrainedConfig,
        module: nn.Module,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
    ):
        """
        Store the config/module and randomly initialize the parameters.

        Arguments:
            config: Model configuration; must not be ``None``.
            module: Flax module; must not be ``None``.
            input_shape: Shape forwarded to :meth:`init_weights`.
            seed: Seed used to build the ``PRNGKey`` stored in ``self.key``.
            dtype: Stored as ``self.dtype``.

        Raises:
            ValueError: If ``config`` or ``module`` is ``None``.
        """
        if config is None:
            raise ValueError("config cannot be None")

        if module is None:
            raise ValueError("module cannot be None")

        # Those are private to be exposed as typed property on derived classes.
        self._config = config
        self._module = module

        # Those are public as their type is generic to every derived classes.
        self.key = PRNGKey(seed)
        self.dtype = dtype

        # randomly initialized parameters
        random_params = self.init_weights(self.key, input_shape)

        # save required_params as set; this must happen before assigning `self.params`
        # because the `params` setter validates against `required_params`.
        self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
        self.params = random_params

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
        """Return freshly initialized parameters; has to be implemented by derived classes."""
        raise NotImplementedError(f"init method has to be implemented for {self}")

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
        """
        return cls(config, **kwargs)

    @property
    def config(self) -> PretrainedConfig:
        """The configuration passed at construction time."""
        return self._config

    @property
    def module(self) -> nn.Module:
        """The Flax module passed at construction time."""
        return self._module

    @property
    def params(self) -> Union[Dict, FrozenDict]:
        """The current model parameters."""
        return self._params

    @property
    def required_params(self) -> Set:
        """Flattened keys of the randomly initialized parameters; every assigned ``params`` must contain them."""
        return self._required_params

    @params.setter
    def params(self, params: Union[Dict, FrozenDict]):
        """
        Set the model parameters.

        A ``FrozenDict`` is unfrozen before being stored.

        Raises:
            ValueError: If any key of ``required_params`` is absent from ``params``.
        """
        if isinstance(params, FrozenDict):
            params = unfreeze(params)
        param_keys = set(flatten_dict(params).keys())
        missing_params = self.required_params - param_keys
        if missing_params:
            raise ValueError(
                "Some parameters are missing. Make sure that `params` include the following "
                f"parameters {missing_params}"
            )
        self._params = params

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        dtype: jnp.dtype = jnp.float32,
        *model_args,
        **kwargs
    ):
        """
        Instantiate a model and load weights from a Flax (or, with ``from_pt=True``, PyTorch) checkpoint.

        Arguments:
            pretrained_model_name_or_path: A local directory containing the weights file, a path or remote URL
                to a weights file, or otherwise an identifier resolved through ``hf_bucket_url``.
            dtype: Passed to the model constructor as the ``dtype`` keyword argument.
            model_args: Positional arguments forwarded to the model constructor (and to
                ``config_class.from_pretrained`` when the config has to be loaded).
            kwargs: The following keys are consumed here: ``config``, ``cache_dir``, ``from_pt``,
                ``ignore_mismatched_sizes``, ``force_download``, ``resume_download``, ``proxies``,
                ``local_files_only``, ``use_auth_token``, ``revision``, ``_from_pipeline`` and ``_from_auto``.
                If ``config`` is not a ``PretrainedConfig`` instance, the remaining kwargs are handed to
                ``config_class.from_pretrained`` and the ones it returns as unused go to the model constructor;
                otherwise the remaining kwargs go to the model constructor directly.

        Returns:
            An instance of ``cls`` whose ``params`` hold the checkpoint weights; parameters missing from the
            checkpoint (and, with ``ignore_mismatched_sizes=True``, those with a mismatching shape) keep their
            random initialization.

        Raises:
            ValueError: If ``pretrained_model_name_or_path`` is ``None``, or if a checkpoint weight has a shape
                different from the model's and ``ignore_mismatched_sizes`` is ``False``.
            EnvironmentError: If the weights file cannot be found, resolved or deserialized.
        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)

        user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype

        # Without a path there is nothing to load: fail early with a clear message instead of
        # building a random model and crashing later while trying to open `None`.
        if pretrained_model_name_or_path is None:
            raise ValueError(
                "`pretrained_model_name_or_path` cannot be None: a model identifier, a directory or a weights "
                "file is required to load pretrained weights."
            )

        # Locate the weights file
        if os.path.isdir(pretrained_model_name_or_path):
            if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
                # Load from a Flax checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
            else:
                raise EnvironmentError(
                    f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
                    f"{pretrained_model_name_or_path} or `from_pt` set to False"
                )
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            archive_file = pretrained_model_name_or_path
        else:
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
                revision=revision,
            )

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
            )
        except EnvironmentError as err:
            logger.error(err)
            # Name the file that was actually looked for (PyTorch vs. Flax weights).
            expected_file = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {expected_file}.\n\n"
            )
            # Chain the original error so the root cause stays visible in the traceback.
            raise EnvironmentError(msg) from err

        if resolved_archive_file == archive_file:
            logger.info(f"loading weights file {archive_file}")
        else:
            logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")

        # init random models
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
        else:
            with open(resolved_archive_file, "rb") as state_f:
                try:
                    state = from_bytes(cls, state_f.read())
                except UnpicklingError as err:
                    raise EnvironmentError(
                        f"Unable to convert {archive_file} to Flax deserializable object. "
                    ) from err
            # make sure all arrays are stored as jnp.arrays
            # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
            # https://github.com/google/flax/issues/1261
            state = jax.tree_util.tree_map(jnp.array, state)

        # if model is base model only use model_prefix key
        if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
            state = state[cls.base_model_prefix]

        # if model is head model and we are loading weights from base model
        # we initialize new params dict with base_model_prefix
        if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state:
            state = {cls.base_model_prefix: state}

        # flatten dicts
        state = flatten_dict(state)

        random_state = flatten_dict(unfreeze(model.params))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params

        # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
        # matching the weights in the model.
        mismatched_keys = []
        for key in state:
            if key in random_state and state[key].shape != random_state[key].shape:
                if ignore_mismatched_sizes:
                    mismatched_keys.append((key, state[key].shape, random_state[key].shape))
                    # Only an existing key is re-assigned, so the dict size is unchanged while iterating.
                    state[key] = random_state[key]
                else:
                    raise ValueError(
                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
                        "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
                        "model."
                    )

        # add missing keys as random parameters
        for missing_key in missing_keys:
            state[missing_key] = random_state[missing_key]

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )

        # set correct parameters
        model.params = unflatten_dict(state)

        return model

    def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.FlaxPreTrainedModel.from_pretrained` class method

        Arguments:
            save_directory (:obj:`str` or :obj:`os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            params (`optional`, defaults to :obj:`None`):
                Parameters to serialize. When :obj:`None`, ``self.params`` is saved.
            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to push your model to the Hugging Face model hub after saving it.

                .. warning::

                    Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
                    :obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
                    pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary directory
                    instead.

            kwargs:
                Additional key word arguments passed along to the
                :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
        """
        if os.path.isfile(save_directory):
            # Best-effort behavior: log and do nothing rather than raise.
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo = self._create_or_get_repo(save_directory, **kwargs)

        os.makedirs(save_directory, exist_ok=True)

        # get abs dir
        save_directory = os.path.abspath(save_directory)
        # save config as well; the first four characters of the class name are dropped
        # (presumably the `Flax` prefix -- confirm against derived class names).
        self.config.architectures = [self.__class__.__name__[4:]]
        self.config.save_pretrained(save_directory)

        # save model; serialize before opening the file so that a serialization failure
        # does not leave an empty/truncated weights file behind.
        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
        params = params if params is not None else self.params
        model_bytes = to_bytes(params)
        with open(output_model_file, "wb") as f:
            f.write(model_bytes)

        logger.info(f"Model weights saved in {output_model_file}")

        if push_to_hub:
            url = self._push_to_hub(repo, commit_message=commit_message)
            logger.info(f"Model pushed to the hub in this commit: {url}")