# NEW
import os

# from functools import partial
from pickle import UnpicklingError
from typing import Dict, Set, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from transformers.configuration_utils import PretrainedConfig
from transformers.file_utils import (
    FLAX_WEIGHTS_NAME,
    WEIGHTS_NAME,
    PushToHubMixin,
    cached_path,
    hf_bucket_url,
    is_offline_mode,
    is_remote_url,
)
from transformers.modeling_flax_pytorch_utils import (
    load_pytorch_checkpoint_in_flax_state_dict,
)
from transformers.utils import logging

from .generation_clip_vision_utils import FlaxCLIPVisionMBartGenerationMixin

logger = logging.get_logger(__name__)


class FlaxCLIPVisionMBartPreTrainedModel(
    PushToHubMixin, FlaxCLIPVisionMBartGenerationMixin
):
    r"""
    Base class for all models.

    :class:`~transformers.FlaxPreTrainedModel` takes care of storing the configuration of the models and handles
    methods for loading, downloading and saving models.

    Class attributes (overridden by derived classes):

        - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
          :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
        - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
          derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    base_model_prefix = ""

    def __init__(
        self,
        config: PretrainedConfig,
        module: nn.Module,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
    ):
        """
        Store the configuration and module, then create randomly initialized parameters.

        Args:
            config: Model configuration; must not be ``None``.
            module: The Flax module wrapped by this model; must not be ``None``.
            input_shape: Shape forwarded to :meth:`init_weights`.
            seed: Seed used to build the model's ``PRNGKey``.
            dtype: Stored on the instance as ``self.dtype``.

        Raises:
            ValueError: If ``config`` or ``module`` is ``None``.
        """
        if config is None:
            raise ValueError("config cannot be None")

        if module is None:
            raise ValueError("module cannot be None")

        # Those are private to be exposed as typed property on derived classes.
        self._config = config
        self._module = module

        # Those are public as their type is generic to every derived classes.
        self.key = PRNGKey(seed)
        self.dtype = dtype

        # randomly initialized parameters
        random_params = self.init_weights(self.key, input_shape)

        # save required_params as set; this must happen before `self.params` is
        # assigned because the `params` setter validates against it.
        self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
        self.params = random_params

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
        """Return freshly initialized parameters; must be implemented by derived classes."""
        raise NotImplementedError(f"init method has to be implemented for {self}")

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
        """
        return cls(config, **kwargs)

    @property
    def config(self) -> PretrainedConfig:
        """The configuration object passed at construction time."""
        return self._config

    @property
    def module(self) -> nn.Module:
        """The Flax module passed at construction time."""
        return self._module

    @property
    def params(self) -> Union[Dict, FrozenDict]:
        """The current model parameters."""
        return self._params

    @property
    def required_params(self) -> Set:
        """Set of flattened parameter keys produced by :meth:`init_weights`."""
        return self._required_params

    @params.setter
    def params(self, params: Union[Dict, FrozenDict]):
        """
        Replace the model parameters.

        A ``FrozenDict`` is unfrozen before being stored.

        Raises:
            ValueError: If any key of :attr:`required_params` is absent from ``params``.
        """
        if isinstance(params, FrozenDict):
            params = unfreeze(params)
        param_keys = set(flatten_dict(params).keys())
        missing = self.required_params - param_keys
        if missing:
            raise ValueError(
                "Some parameters are missing. Make sure that `params` include the following "
                f"parameters {missing}"
            )
        self._params = params

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        dtype: jnp.dtype = jnp.float32,
        *model_args,
        **kwargs,
    ):
        r"""
        Instantiate a pretrained flax model from a pre-trained model configuration.

        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                Can be either:

                    - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                      a user or organization name, like ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this
                      case, ``from_pt`` should be set to :obj:`True`.
            dtype (:obj:`jnp.dtype`, `optional`, defaults to :obj:`jnp.float32`):
                Forwarded to the underlying model's ``__init__`` method as the ``dtype`` keyword argument.
            model_args (sequence of positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
                Can be either:

                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                    - a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the `model id` string of a pretrained
                      model).
                    - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
                      by supplying the save directory.
                    - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                      configuration JSON file named `config.json` is found in the directory.
            cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Load the model weights from a PyTorch checkpoint save file (see docstring of
                ``pretrained_model_name_or_path`` argument).
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`optional`):
                Forwarded unchanged to the configuration loader and to the weight-file download/cache lookup.
            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
                automatically loaded:

                    - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                      underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                      initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                      ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                      with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                      attribute will be passed to the underlying model's ``__init__`` function.

        Raises:
            ValueError: If ``pretrained_model_name_or_path`` is :obj:`None` (there is no checkpoint to load).
            EnvironmentError: If no weight file can be located or fetched, or if the Flax weight file cannot be
                deserialized.

        Examples::

            >>> from transformers import BertConfig, FlaxBertModel
            >>> # Download model and configuration from huggingface.co and cache.
            >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
            >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
            >>> model = FlaxBertModel.from_pretrained('./test/saved_model/')
            >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
            >>> config = BertConfig.from_json_file('./pt_model/config.json')
            >>> model = FlaxBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)
        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)

        user_agent = {
            "file_type": "model",
            "framework": "flax",
            "from_auto_class": from_auto_class,
        }
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = (
                config if config is not None else pretrained_model_name_or_path
            )
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype

        # Without a name or path there is no weight file to read. Fail here with
        # a clear message instead of building a model and then crashing while
        # trying to open `None`.
        if pretrained_model_name_or_path is None:
            raise ValueError(
                "`pretrained_model_name_or_path` cannot be None: no checkpoint to load weights from."
            )

        # Name of the weight file that is looked up remotely (and reported in
        # the error message if the lookup fails).
        weights_name = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME

        # Locate the weight file
        if os.path.isdir(pretrained_model_name_or_path):
            if from_pt and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            ):
                # Load from a PyTorch checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, WEIGHTS_NAME
                )
            elif os.path.isfile(
                os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
            ):
                # Load from a Flax checkpoint
                archive_file = os.path.join(
                    pretrained_model_name_or_path, FLAX_WEIGHTS_NAME
                )
            else:
                raise EnvironmentError(
                    f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
                    f"{pretrained_model_name_or_path} or `from_pt` set to False"
                )
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
            pretrained_model_name_or_path
        ):
            archive_file = pretrained_model_name_or_path
        else:
            archive_file = hf_bucket_url(
                pretrained_model_name_or_path,
                filename=weights_name,
                revision=revision,
            )

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(
                archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
            )
        except EnvironmentError as err:
            logger.error(err)
            msg = (
                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {weights_name}.\n\n"
            )
            # Chain the original error so the root cause stays in the traceback.
            raise EnvironmentError(msg) from err

        if resolved_archive_file == archive_file:
            logger.info(f"loading weights file {archive_file}")
        else:
            logger.info(
                f"loading weights file {archive_file} from cache at {resolved_archive_file}"
            )

        # init random models
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            state = load_pytorch_checkpoint_in_flax_state_dict(
                model, resolved_archive_file
            )
        else:
            with open(resolved_archive_file, "rb") as state_f:
                try:
                    state = from_bytes(cls, state_f.read())
                except UnpicklingError as err:
                    raise EnvironmentError(
                        f"Unable to convert {archive_file} to Flax deserializable object. "
                    ) from err
            # make sure all arrays are stored as jnp.arrays
            # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
            # https://github.com/google/flax/issues/1261
            state = jax.tree_util.tree_map(jnp.array, state)

        # if model is base model only use model_prefix key
        if (
            cls.base_model_prefix not in dict(model.params)
            and cls.base_model_prefix in state
        ):
            state = state[cls.base_model_prefix]

        # if model is head model and we are loading weights from base model
        # we initialize new params dict with base_model_prefix
        if (
            cls.base_model_prefix in dict(model.params)
            and cls.base_model_prefix not in state
        ):
            state = {cls.base_model_prefix: state}

        # flatten dicts
        state = flatten_dict(state)

        random_state = flatten_dict(unfreeze(model.params))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params

        # add missing keys as random parameters
        for missing_key in missing_keys:
            state[missing_key] = random_state[missing_key]

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if unexpected_keys:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(
                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
            )

        if missing_keys:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )

        # set correct parameters
        model.params = unflatten_dict(state)

        return model

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        params=None,
        push_to_hub=False,
        **kwargs,
    ):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.FlaxPreTrainedModel.from_pretrained` class method

        If ``save_directory`` is an existing file, an error is logged and nothing is saved.

        Arguments:
            save_directory (:obj:`str` or :obj:`os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            params (`optional`):
                Parameters to serialize. When :obj:`None`, ``self.params`` is saved.
            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to push your model to the Hugging Face model hub after saving it.

                .. warning::

                    Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
                    :obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
                    pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary directory
                    instead.

            kwargs:
                Additional key word arguments passed along to the
                :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
        """
        if os.path.isfile(save_directory):
            logger.error(
                f"Provided path ({save_directory}) should be a directory, not a file"
            )
            return

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo = self._create_or_get_repo(save_directory, **kwargs)

        os.makedirs(save_directory, exist_ok=True)

        # get abs dir
        save_directory = os.path.abspath(save_directory)

        # save config as well; the first four characters of the class name are
        # dropped (assumes the class name starts with "Flax").
        self.config.architectures = [self.__class__.__name__[4:]]
        self.config.save_pretrained(save_directory)

        # Resolve and serialize the parameters before opening the output file so
        # a failure here cannot leave a truncated weights file behind.
        params = params if params is not None else self.params
        model_bytes = to_bytes(params)

        # save model
        output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
        with open(output_model_file, "wb") as f:
            f.write(model_bytes)

        logger.info(f"Model weights saved in {output_model_file}")

        if push_to_hub:
            url = self._push_to_hub(repo, commit_message=commit_message)
            logger.info(f"Model pushed to the hub in this commit: {url}")