# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""
import functools
import os
import re
import warnings
from typing import Dict, List, Optional, Union
import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.saving import hdf5_format
from .configuration_utils import PretrainedConfig
from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .generation_tf_utils import TFGenerationMixin
from .utils import logging
logger = logging.get_logger(__name__)
class TFModelUtilsMixin:
    """
    Mixin providing small helper utilities for :obj:`tf.keras.Model` subclasses.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Count the parameters of the model, optionally restricted to the trainable ones.

        Args:
            only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
                When :obj:`True`, only the variables in ``self.trainable_variables`` are counted; otherwise the
                result of ``self.count_params()`` is returned.

        Returns:
            :obj:`int`: The number of parameters.
        """
        if not only_trainable:
            return self.count_params()

        # Sum the element count of every trainable variable (product of its static shape).
        total = 0
        for variable in self.trainable_variables:
            total += np.prod(variable.shape.as_list())
        return int(total)
def keras_serializable(cls):
    """
    Class decorator that makes a Keras Layer class compatible with Keras serialization.

    The decorator:

    1. Adds the layer's configuration, as a dict stored under the :obj:`config` key, to the Keras config dictionary
       returned by :obj:`get_config` (called by Keras at serialization time).
    2. Wraps :obj:`__init__` so that it also accepts that dict (passed by Keras at deserialization time) and converts
       it to a config object before calling the actual layer initializer.
    3. Registers the class as a custom object in Keras (if the TensorFlow version supports this), so that it does not
       need to be supplied in :obj:`custom_objects` in the call to :obj:`tf.keras.models.load_model`.

    Args:
        cls (a :obj:`tf.keras.layers.Layer` subclass):
            Typically a :obj:`TF.MainLayer` class in this project, in general must accept a :obj:`config` argument to
            its initializer and must define a :obj:`config_class` attribute.

    Returns:
        The same class object, with modifications for Keras deserialization.

    Raises:
        AttributeError: If :obj:`cls` has no (non-``None``) :obj:`config_class` attribute.
        TypeError: If :obj:`cls` has no :obj:`get_config` attribute.
    """
    original_init = cls.__init__

    config_class = getattr(cls, "config_class", None)
    if config_class is None:
        raise AttributeError("Must set `config_class` to use @keras_serializable")

    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        # The config is either the first positional argument (when it is a config object) or the `config` keyword.
        if args and isinstance(args[0], PretrainedConfig):
            config = args[0]
        else:
            config = kwargs.pop("config", None)

        if isinstance(config, dict):
            # Deserialization path: Keras hands the config back as a plain dict.
            config = config_class.from_dict(config)
            original_init(self, config, *args, **kwargs)
        elif isinstance(config, PretrainedConfig):
            if args:
                original_init(self, *args, **kwargs)
            else:
                original_init(self, config, **kwargs)
        else:
            raise ValueError("Must pass either `config` (PretrainedConfig) or `config` (dict)")

        # Remembered so that `get_config` can reproduce the constructor arguments.
        self._config = config
        self._kwargs = kwargs

    cls.__init__ = wrapped_init

    if not hasattr(cls, "get_config"):
        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")

    # Only replace `get_config` when it carries the `_is_default` marker.
    if hasattr(cls.get_config, "_is_default"):

        def get_config(self):
            serialized = super(cls, self).get_config()
            serialized["config"] = self._config.to_dict()
            serialized.update(self._kwargs)
            return serialized

        cls.get_config = get_config

    cls._keras_serializable = True

    if hasattr(tf.keras.utils, "register_keras_serializable"):
        cls = tf.keras.utils.register_keras_serializable()(cls)
    return cls
class TFCausalLanguageModelingLoss:
    """
    Loss mixin for causal language modeling (CLM), i.e. predicting the next token.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """

    def compute_loss(self, labels, logits):
        # Only positions whose label differs from -100 contribute to the loss.
        keep_mask = tf.not_equal(tf.reshape(labels, (-1,)), -100)

        # Flatten the logits to one row per position (last size taken from axis 2) and drop the ignored rows.
        flat_logits = tf.reshape(logits, (-1, shape_list(logits)[2]))
        kept_logits = tf.boolean_mask(flat_logits, keep_mask)
        kept_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), keep_mask)

        # Reduction.NONE: one loss value per kept position, no averaging here.
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        return cross_entropy(kept_labels, kept_logits)
class TFQuestionAnsweringLoss:
    """
    Loss mixin for question answering: the mean of the start-position and end-position losses.
    """

    def compute_loss(self, labels, logits):
        # Reduction.NONE keeps one loss value per example.
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )

        # `logits[0]` is scored against the start positions, `logits[1]` against the end positions.
        loss_on_start = cross_entropy(labels["start_position"], logits[0])
        loss_on_end = cross_entropy(labels["end_position"], logits[1])

        return (loss_on_start + loss_on_end) / 2.0
class TFTokenClassificationLoss:
    """
    Loss mixin for token classification.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """

    def compute_loss(self, labels, logits):
        flat_labels = tf.reshape(labels, (-1,))

        # -100 marks positions to ignore. If any label equals -1, the deprecated -1 marker is used instead
        # (and -100 is then NOT treated as an ignore marker).
        if tf.math.reduce_any(labels == -1):
            warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
            keep_mask = flat_labels != -1
        else:
            keep_mask = flat_labels != -100

        # One row of logits per position, restricted to the positions that are kept.
        flat_logits = tf.reshape(logits, (-1, shape_list(logits)[2]))
        kept_logits = tf.boolean_mask(flat_logits, keep_mask)
        kept_labels = tf.boolean_mask(flat_labels, keep_mask)

        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        return cross_entropy(kept_labels, kept_logits)
class TFSequenceClassificationLoss:
    """
    Loss mixin for sequence classification; falls back to a regression loss when there is a single output.
    """

    def compute_loss(self, labels, logits):
        logits_shape = shape_list(logits)
        # Rank-1 logits, or a second dimension of size 1, select mean squared error instead of cross-entropy.
        single_output = len(logits_shape) == 1 or logits_shape[1] == 1

        if single_output:
            criterion = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        else:
            criterion = tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True, reduction=tf.keras.losses.Reduction.NONE
            )

        return criterion(labels, logits)
class TFMultipleChoiceLoss(TFSequenceClassificationLoss):
    """
    Loss mixin for multiple choice tasks.

    Adds nothing of its own: :obj:`compute_loss` is inherited from :class:`TFSequenceClassificationLoss`.
    """
class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
    """
    Loss mixin for masked language modeling (MLM), i.e. predicting the masked tokens.

    Adds nothing of its own: :obj:`compute_loss` is inherited from :class:`TFCausalLanguageModelingLoss`.

    .. note::

        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """
def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
    """
    Compare the layers and weights of a model with those stored in an H5 file.

    Layer names are compared first. Then, for every model layer that is also present in the file, weight names are
    compared after dropping their first two ``/``-separated components.

    Args:
        model (:obj:`tf.keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (:obj:`str`):
            The location of the H5 file.

    Returns:
        Two lists, one for the missing layers, and another one for the unexpected layers. Each list holds layer
        names followed by (shortened) weight names.
    """

    def _strip_prefix(full_name):
        # Drop the first two path components of a weight name.
        return "/".join(full_name.split("/")[2:])

    with h5py.File(resolved_archive_file, "r") as h5_file:
        names_in_file = set(hdf5_format.load_attributes_from_hdf5_group(h5_file, "layer_names"))
        names_in_model = {layer.name for layer in model.layers}

        # Layer-level differences.
        missing_layers = list(names_in_model - names_in_file)
        unexpected_layers = list(names_in_file - names_in_model)

        # Weight-level differences, only for layers present on both sides.
        for layer in model.layers:
            if layer.name not in names_in_file:
                continue

            layer_group = h5_file[layer.name]
            weights_in_file = {
                _strip_prefix(weight_name)
                for weight_name in hdf5_format.load_attributes_from_hdf5_group(layer_group, "weight_names")
            }
            weights_in_model = {
                _strip_prefix(variable.name)
                for variable in layer.trainable_weights + layer.non_trainable_weights
            }

            missing_layers.extend(list(weights_in_model - weights_in_file))
            unexpected_layers.extend(list(weights_in_file - weights_in_model))

    return missing_layers, unexpected_layers
def load_tf_weights(model, resolved_archive_file):
    """
    Load the TF weights from a H5 file.

    For every model layer whose name is listed in the file's ``layer_names`` attribute, saved weights are matched to
    the layer's variables by name, after dropping the first ``/``-separated component of both names. A saved value
    whose shape differs from the variable's static shape is reshaped to it. All matched values are then assigned in a
    single batch.

    Args:
        model (:obj:`tf.keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (:obj:`str`):
            The location of the H5 file.

    Raises:
        ValueError: If a saved value cannot be reshaped to the shape of the variable it is matched to. The variable
            shape and the saved shape are appended to the exception's ``args``.
    """
    weight_value_tuples = []

    with h5py.File(resolved_archive_file, "r") as f:
        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))

        for layer in model.layers:
            if layer.name not in saved_layer_names:
                continue

            g = f[layer.name]
            saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
            symbolic_weights = layer.trainable_weights + layer.non_trainable_weights

            # Index the saved arrays by weight name without its first path component.
            saved_weight_names_values = {
                "/".join(weight_name.split("/")[1:]): np.asarray(g[weight_name]) for weight_name in saved_weight_names
            }

            for symbolic_weight in symbolic_weights:
                symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])
                if symbolic_weight_name not in saved_weight_names_values:
                    continue

                saved_weight_value = saved_weight_names_values[symbolic_weight_name]
                target_shape = K.int_shape(symbolic_weight)

                if target_shape != saved_weight_value.shape:
                    try:
                        array = np.reshape(saved_weight_value, target_shape)
                    except (AssertionError, ValueError) as e:
                        # `np.reshape` signals incompatible shapes with ValueError; attach both shapes so the
                        # failing weight can be identified. AssertionError is kept for backward compatibility.
                        e.args += (target_shape, saved_weight_value.shape)
                        raise
                else:
                    array = saved_weight_value

                weight_value_tuples.append((symbolic_weight, array))

    # The values are already materialized as NumPy arrays, so the file can be closed before assigning.
    K.batch_set_value(weight_value_tuples)
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
    r"""
    Base class for all TF models.

    :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods
    for loading, downloading and saving models as well as a few methods common to all models to:

        * resize the input embeddings,
        * prune heads in the self-attention heads.

    Class attributes (overridden by derived classes):

        - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
          :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
        - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
          derived classes of the same architecture adding modules on top of the base model.
        - **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to ignore
          from the model when loading the model weights (and avoid unnecessary warnings).
        - **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to
          ignore from the weights when loading the model weights (and avoid unnecessary warnings).
    """
    config_class = None
    base_model_prefix = ""
    authorized_missing_keys = None
    authorized_unexpected_keys = None

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """
        Dummy inputs to build the network.

        Returns:
            :obj:`Dict[str, tf.Tensor]`: The dummy inputs.
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config and origin of the pretrained weights if given in model
        self.config = config
        self.name_or_path = config.name_or_path

    def get_output_embeddings(self) -> tf.keras.layers.Layer:
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`: A layer mapping hidden states to vocabulary. This base implementation
            returns :obj:`None`.
        """
        return None  # Overwrite for models with output embeddings

    def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
        """
        Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.

        Arguments:
            new_num_tokens (:obj:`int`, `optional`):
                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or :obj:`None`,
                just returns the input embeddings of the model without resizing them and without changing
                ``config.vocab_size``.

        Return:
            The input embeddings of the base model, as returned by its :obj:`get_input_embeddings()` method.
        """
        return self._resize_token_embeddings(new_num_tokens)

    def _resize_token_embeddings(self, new_num_tokens):
        # get_input_embeddings and set_input_embeddings need to be implemented in base layer.
        base_model = getattr(self, self.base_model_prefix, self)
        old_embeddings = base_model.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        base_model.set_input_embeddings(new_embeddings)

        # Update base model and current model config. When no new size was requested the embeddings are unchanged,
        # so the recorded vocabulary size must not be overwritten with `None`.
        if new_num_tokens is not None:
            self.config.vocab_size = new_num_tokens
            base_model.vocab_size = new_num_tokens

        return base_model.get_input_embeddings()

    def _get_word_embeddings(self, embeddings):
        if hasattr(embeddings, "word_embeddings"):
            # TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings
            return embeddings.word_embeddings
        elif hasattr(embeddings, "weight"):
            # TFSharedEmbeddings
            return embeddings.weight
        else:
            raise ValueError("word embedding is not defined.")

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
        """
        Build a resized embedding matrix from a provided token embedding module. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end.

        Args:
            old_embeddings:
                Embedding module exposing its matrix as a ``word_embeddings`` or ``weight`` attribute.
            new_num_tokens (:obj:`int`, `optional`):
                New number of tokens in the embedding matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or :obj:`None`, just returns the current embedding matrix
                without doing anything.

        Return:
            :obj:`tf.Variable`: The resized embedding matrix, or the current one if :obj:`new_num_tokens` is
            :obj:`None` or equal to the current number of tokens.
        """
        word_embeddings = self._get_word_embeddings(old_embeddings)
        if new_num_tokens is None:
            return word_embeddings

        old_num_tokens, old_embedding_dim = word_embeddings.shape
        if old_num_tokens == new_num_tokens:
            return word_embeddings

        # initialize new embeddings
        # todo: initializer range is not always passed in config.
        init_range = getattr(self.config, "initializer_range", 0.02)
        new_embeddings = self.add_weight(
            "weight",
            shape=[new_num_tokens, old_embedding_dim],
            initializer=get_initializer(init_range),
            dtype=tf.float32,
        )
        init_weights = new_embeddings.numpy()

        # Copy token embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        init_weights[:num_tokens_to_copy] = word_embeddings[:num_tokens_to_copy, :]
        new_embeddings.assign(init_weights)

        return new_embeddings

    def prune_heads(self, heads_to_prune):
        """
        Prunes heads of the base model. Not implemented in this base class.

        Arguments:
            heads_to_prune (:obj:`Dict[int, List[int]]`):
                Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list of
                heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads
                0 and 2 on layer 1 and heads 2 and 3 on layer 2.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def save_pretrained(self, save_directory):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.TFPreTrainedModel.from_pretrained` class method.

        Arguments:
            save_directory (:obj:`str`):
                Directory to which to save. Will be created if it doesn't exist. If the path is an existing file, an
                error is logged and nothing is saved.
        """
        if os.path.isfile(save_directory):
            logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
            return
        os.makedirs(save_directory, exist_ok=True)

        # Save configuration file
        self.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
        self.save_weights(output_model_file)
        logger.info("Model weights saved in {}".format(output_model_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""
        Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (:obj:`str`, `optional`):
                Can be either:

                    - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
                      ``bert-base-uncased``.
                    - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
                      ``dbmdz/bert-base-german-cased``.
                    - A path to a `directory` containing model weights saved using
                      :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                    - A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
                      this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                      as ``config`` argument. This loading path is slower than converting the PyTorch model in a
                      TensorFlow model using the provided conversion scripts and loading the TensorFlow model
                      afterwards.
                    - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
                      arguments ``config`` and ``state_dict``).
            model_args (sequence of positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config (:obj:`Union[PretrainedConfig, str]`, `optional`):
                Can be either:

                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                    - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                    - The model is a model provided by the library (loaded with the `shortcut name` string of a
                      pretrained model).
                    - The model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is reloaded
                      by supplying the save directory.
                    - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                      configuration JSON file named `config.json` is found in the directory.
            from_pt: (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Load the model weights from a PyTorch state_dict save file (see docstring of
                ``pretrained_model_name_or_path`` argument).
            cache_dir (:obj:`str`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies: (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to also return a dictionary containing missing keys and unexpected keys.
            local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (e.g., not try downloading the model).
            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
                identifier allowed by git.
            mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information.
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
                automatically loaded:

                    - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                      underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                      already been done)
                    - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                      initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                      ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                      with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                      attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            >>> from transformers import BertConfig, TFBertModel
            >>> # Download model and configuration from S3 and cache.
            >>> model = TFBertModel.from_pretrained('bert-base-uncased')
            >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
            >>> model = TFBertModel.from_pretrained('./test/saved_model/')
            >>> # Update configuration during loading.
            >>> model = TFBertModel.from_pretrained('bert-base-uncased', output_attentions=True)
            >>> assert model.config.output_attentions == True
            >>> # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
            >>> config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            >>> model = TFBertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)
        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        mirror = kwargs.pop("mirror", None)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                revision=revision,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint in priority if from_pt
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_pt` set to False".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
                    revision=revision,
                    mirror=mirror,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
            except EnvironmentError as err:
                logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
                )
                # Chain explicitly so the underlying download/cache failure stays attached as the cause.
                raise EnvironmentError(msg) from err
            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        config.name_or_path = pretrained_model_name_or_path

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

            # Load from a PyTorch checkpoint
            return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

        model(model.dummy_inputs, training=False)  # build the network with dummy inputs

        assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
        # 'by_name' allow us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        try:
            load_tf_weights(model, resolved_archive_file)
        except OSError as err:
            raise OSError(
                "Unable to load weights from h5 file. "
                "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
            ) from err

        model(model.dummy_inputs, training=False)  # Make sure restore ops are run

        missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file)

        # Drop the keys the model class explicitly declares as ignorable.
        if cls.authorized_missing_keys is not None:
            for pat in cls.authorized_missing_keys:
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

        if cls.authorized_unexpected_keys is not None:
            for pat in cls.authorized_unexpected_keys:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            logger.warning(
                f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.warning(
                f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )

        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
            return model, loading_info

        return model
class TFConv1D(tf.keras.layers.Layer):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).

    Behaves like a linear layer whose weight matrix is stored as ``[nx, nf]`` (input features first).

    Args:
        nf (:obj:`int`):
            The number of output features.
        nx (:obj:`int`):
            The number of input features.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation to use to initialize the weights.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
    """

    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        super().__init__(**kwargs)
        self.nf, self.nx = nf, nx
        self.initializer_range = initializer_range

    def build(self, input_shape):
        kernel_init = get_initializer(self.initializer_range)
        self.weight = self.add_weight("weight", shape=[self.nx, self.nf], initializer=kernel_init)
        self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())

    def call(self, x):
        # The first two input dimensions are restored on the output.
        batch_size, seq_length = shape_list(x)[:2]

        # Collapse to 2D, apply the affine map, then restore the leading dimensions.
        flat_inputs = tf.reshape(x, [-1, self.nx])
        projected = tf.matmul(flat_inputs, self.weight) + self.bias
        return tf.reshape(projected, [batch_size, seq_length, self.nf])
[docs]class TFSharedEmbeddings(tf.keras.layers.Layer):
r"""
Construct shared token embeddings.
The weights of the embedding layer is usually shared with the weights of the linear decoder when doing language
modeling.
Args:
vocab_size (:obj:`int`):
The size of the vocabulary, e.g., the number of unique tokens.
hidden_size (:obj:`int`):
The size of the embedding vectors.
initializer_range (:obj:`float`, `optional`):
The standard deviation to use when initializing the weights. If no value is provided, it will default to
:math:`1/\sqrt{hidden\_size}`.
kwargs:
Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
"""
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: Optional[float] = None, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range
def build(self, input_shape):
"""
Build shared token embedding layer Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
self.weight = self.add_weight(
"weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
)
super().build(input_shape)
def get_config(self):
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
[docs] def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor:
"""
Get token embeddings of inputs or decode final hidden state.
Args:
inputs (:obj:`tf.Tensor`):
In embedding mode, should be an int64 tensor with shape :obj:`[batch_size, length]`.
In linear mode, should be a float tensor with shape :obj:`[batch_size, length, hidden_size]`.
mode (:obj:`str`, defaults to :obj:`"embedding"`):
A valid value is either :obj:`"embedding"` or :obj:`"linear"`, the first one indicates that the layer
should be used as an embedding layer, the second one that the layer should be used as a linear decoder.
Returns:
:obj:`tf.Tensor`: In embedding mode, the output is a float32 embedding tensor, with shape
:obj:`[batch_size, length, embedding_size]`.
In linear mode, the output is a float32 with shape :obj:`[batch_size, length, vocab_size]`.
Raises:
ValueError: if :obj:`mode` is not valid.
Shared weights logic is adapted from `here
<https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24>`__.
"""
if mode == "embedding":
return self._embedding(inputs)
elif mode == "linear":
return self._linear(inputs)
else:
raise ValueError("mode {} is not valid.".format(mode))
def _embedding(self, input_ids):
    """Gather the rows of ``self.weight`` selected by ``input_ids``."""
    embedding_table = self.weight
    return tf.gather(embedding_table, input_ids)
def _linear(self, inputs):
    """
    Project hidden states onto the vocabulary using the transposed embedding matrix.

    Args:
        inputs: A float32 tensor with shape [..., hidden_size]

    Returns:
        float32 tensor with shape [..., vocab_size].
    """
    # Remember every leading dimension so the result can be restored to the input layout.
    leading_dims = shape_list(inputs)[:-1]
    # Collapse all leading dimensions into one to run a single 2-D matmul.
    flat_inputs = tf.reshape(inputs, [-1, self.hidden_size])
    flat_logits = tf.matmul(flat_inputs, self.weight, transpose_b=True)
    output_shape = leading_dims + [self.vocab_size]
    return tf.reshape(flat_logits, output_shape)
class TFSequenceSummary(tf.keras.layers.Layer):
    """
    Compute a single vector summary of a sequence hidden states.

    Args:
        config (:class:`~transformers.PretrainedConfig`):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (:obj:`str`) -- The method to use to make this summary. Defaults to :obj:`"last"`
              when the config does not define it. Accepted values are:

                - :obj:`"last"` -- Take the last token hidden state (like XLNet)
                - :obj:`"first"` -- Take the first token hidden state (like Bert)
                - :obj:`"mean"` -- Take the mean of all tokens hidden states
                - :obj:`"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - :obj:`"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (:obj:`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (:obj:`bool`) -- If :obj:`True`, the projection outputs to
              :obj:`config.num_labels` classes (otherwise to :obj:`config.hidden_size`).
            - **summary_activation** (:obj:`Optional[str]`) -- Set to :obj:`"tanh"` to add a tanh activation to the
              output, another string or :obj:`None` will add no activation.
            - **summary_first_dropout** (:obj:`float`) -- Optional dropout probability before the projection and
              activation.
            - **summary_last_dropout** (:obj:`float`)-- Optional dropout probability after the projection and
              activation.

        initializer_range (:obj:`float`, defaults to 0.02): The standard deviation to use to initialize the weights.
        kwargs:
            Additional keyword arguments passed along to the :obj:`__init__` of :obj:`tf.keras.layers.Layer`.
    """

    def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs):
        super().__init__(**kwargs)

        # Read the summary type from its own attribute. It used to be gated on the unrelated `summary_use_proj`
        # attribute, which ignored a configured type or raised AttributeError depending on which one was missing.
        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            raise NotImplementedError

        self.has_summary = getattr(config, "summary_use_proj", False)
        if self.has_summary:
            project_to_labels = getattr(config, "summary_proj_to_labels", False)
            if project_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = tf.keras.layers.Dense(
                num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
            )

        # Only "tanh" enables an activation; any other value (or a missing attribute) means no activation.
        self.has_activation = getattr(config, "summary_activation", None) == "tanh"
        if self.has_activation:
            self.activation = tf.keras.activations.tanh

        self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
        if self.has_first_dropout:
            self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
        if self.has_last_dropout:
            self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

    def call(self, inputs, cls_index=None, training=False):
        """
        Summarize the hidden states according to ``self.summary_type``, then apply the optional first dropout,
        projection, activation and last dropout (in that order).

        Args:
            inputs: Either the hidden states directly, a tuple/list ``(hidden_states[, cls_index])`` of at most
                two elements, or a dict read through the keys ``"hidden_states"`` and ``"cls_index"``. When a
                tuple, list or dict is given, the ``cls_index`` it carries (or :obj:`None` when absent) replaces
                the ``cls_index`` argument.
            cls_index: Positions of the classification token, only used when ``summary_type`` is
                ``"cls_index"``. When :obj:`None`, the last position of the sequence is used.
            training: Forwarded to the dropout layers.

        Returns:
            The summary tensor.

        Raises:
            AssertionError: if a tuple/list with more than two elements is given.
            NotImplementedError: if ``summary_type`` is ``"attn"``.
            ValueError: if ``summary_type`` is not one of the supported values.
        """
        if isinstance(inputs, dict):
            hidden_states = inputs.get("hidden_states")
            cls_index = inputs.get("cls_index", None)
        elif isinstance(inputs, (tuple, list)):
            assert len(inputs) <= 2, "Too many inputs."
            hidden_states = inputs[0]
            cls_index = inputs[1] if len(inputs) > 1 else None
        else:
            hidden_states = inputs

        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = tf.reduce_mean(hidden_states, axis=1)
        elif self.summary_type == "cls_index":
            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
            if cls_index is None:
                # Default to the last position: a tensor of shape [batch] or [batch, num choices] filled with
                # (sequence length - 1).
                cls_index = tf.fill(hidden_shape[:-2], hidden_shape[-2] - 1)
            cls_shape = shape_list(cls_index)
            if len(cls_shape) <= len(hidden_shape) - 2:
                # Add a trailing axis so that one position is gathered per leading (batch) entry.
                cls_index = cls_index[..., tf.newaxis]
            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
            # Drop the singleton sequence axis left by the gather.
            output = tf.squeeze(
                output, axis=len(hidden_shape) - 2
            )  # shape of output: (batch, num choices, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError
        else:
            # Fail with a clear message instead of an UnboundLocalError on `output` further down.
            raise ValueError("summary_type {} is not valid.".format(self.summary_type))

        if self.has_first_dropout:
            output = self.first_dropout(output, training=training)

        if self.has_summary:
            output = self.summary(output)

        if self.has_activation:
            output = self.activation(output)

        if self.has_last_dropout:
            output = self.last_dropout(output, training=training)

        return output
def shape_list(x: tf.Tensor) -> List[int]:
    """
    Return the shape of a tensor as a list, mixing static and dynamic dimensions.

    Every dimension that is known statically is returned as-is; every dimension that is statically :obj:`None` is
    replaced by the matching entry of ``tf.shape(x)``.

    Args:
        x (:obj:`tf.Tensor`): The tensor we want the shape of.

    Returns:
        :obj:`List[int]`: The shape of the tensor as a list.
    """
    static_dims = x.shape.as_list()
    dynamic_dims = tf.shape(x)
    return [size if size is not None else dynamic_dims[axis] for axis, size in enumerate(static_dims)]
def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal:
    """
    Build a truncated normal initializer whose standard deviation is ``initializer_range``.

    Args:
        initializer_range (`float`, defaults to 0.02): Standard deviation of the initializer range.

    Returns:
        :obj:`tf.initializers.TruncatedNormal`: The truncated normal initializer.
    """
    truncated_normal = tf.keras.initializers.TruncatedNormal
    return truncated_normal(stddev=initializer_range)
def cast_bool_to_primitive(bool_variable: Union[tf.Tensor, bool], default_tensor_to_true=False) -> bool:
    """
    Convert a boolean argument that may have been passed as a tensor into a Python boolean.

    Function arguments can be inserted as boolean tensor and bool variables; to cope with Keras serialization the
    bool arguments (like :obj:`output_attentions` for instance) are cast to a correct boolean if they are a tensor.

    Args:
        bool_variable (:obj:`Union[tf.Tensor, bool]`):
            The variable to convert to a boolean.
        default_tensor_to_true (:obj:`bool`, `optional`, defaults to `False`):
            The default value to use in case the tensor has no numpy attribute.

    Returns:
        :obj:`bool`: The converted value. A non-tensor input is returned unchanged, and so is a tensor that has no
        ``numpy`` attribute when :obj:`default_tensor_to_true` is falsy.
    """
    # Anything that is not a tensor is handed back untouched.
    if not tf.is_tensor(bool_variable):
        return bool_variable

    # A tensor exposing a numpy value can be converted directly.
    if hasattr(bool_variable, "numpy"):
        return bool(bool_variable.numpy())

    # No concrete value available: fall back to True when requested, otherwise return the tensor itself.
    if default_tensor_to_true:
        return True
    return bool_variable
class TFWrappedEmbeddings:
    """
    Plain Python (non-Keras) wrapper around a TFSharedEmbeddingTokens layer, used to avoid problems with weight
    restoring. It also makes sure that the layer is called from the correct scope to avoid problems with
    saving/storing the correct weights.
    """

    def __init__(self, layer, abs_scope_name=None):
        self._layer = layer
        self._abs_scope_name = abs_scope_name

    def _invoke(self, target, inputs, mode):
        """Call ``target(inputs, mode)``, inside the absolute variable/name scope when one was given."""
        scope_name = self._abs_scope_name
        if scope_name is None:
            return target(inputs, mode)

        # if an abs scope name is given to the embedding variable, call variable from absolute scope
        with tf.compat.v1.variable_scope(scope_name, auxiliary_name_scope=False) as abs_scope:
            with tf.name_scope(abs_scope.original_name_scope):
                return target(inputs, mode)

    def call(self, inputs, mode="embedding"):
        """Run the wrapped layer's ``call`` method directly."""
        return self._invoke(self._layer.call, inputs, mode)

    def __call__(self, inputs, mode="embedding"):
        """Run the wrapped layer through its ``__call__``."""
        return self._invoke(self._layer, inputs, mode)